diff --git a/src/adapters/openai/chat.ts b/src/adapters/openai/chat.ts
index f57498e8..f6d5566d 100644
--- a/src/adapters/openai/chat.ts
+++ b/src/adapters/openai/chat.ts
@@ -29,7 +29,7 @@ import { shallowCopy } from "@/serializer/utils.js";
 import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js";
 import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
 import { Emitter } from "@/emitter/emitter.js";
-import { ClientOptions, OpenAI, AzureOpenAI } from "openai";
+import { ClientOptions, OpenAI, AzureOpenAI, AzureClientOptions } from "openai";
 import { GetRunContext } from "@/context.js";
 import { promptTokensEstimate } from "openai-chat-tokens";
 import { Serializer } from "@/serializer/serializer.js";
@@ -98,6 +98,7 @@ export class OpenAIChatLLMOutput extends ChatLLMOutput {
 interface Input {
   modelId?: ChatModel;
   client?: OpenAI | AzureOpenAI;
+  clientOptions?: ClientOptions | AzureClientOptions;
   parameters?: Partial<Parameters>;
   executionOptions?: ExecutionOptions;
   cache?: LLMCache<OpenAIChatLLMOutput>;
@@ -118,7 +119,15 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
   public readonly client: OpenAI | AzureOpenAI;
   public readonly parameters: Partial<Parameters>;
 
-  constructor({ client, modelId, parameters, executionOptions = {}, cache, azure }: Input = {}) {
+  constructor({
+    client,
+    modelId,
+    parameters,
+    executionOptions = {},
+    clientOptions = {},
+    cache,
+    azure,
+  }: Input = {}) {
     super(modelId || "gpt-4o-mini", executionOptions, cache);
     if (client) {
       this.client = client;
@@ -128,9 +137,10 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {
         endpoint: process.env.AZURE_OPENAI_API_ENDPOINT,
         apiVersion: process.env.AZURE_OPENAI_API_VERSION,
         deployment: process.env.AZURE_OPENAI_API_DEPLOYMENT,
+        ...clientOptions,
       });
     } else {
-      this.client = new OpenAI();
+      this.client = new OpenAI(clientOptions);
     }
     this.parameters = parameters ?? { temperature: 0 };
   }
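
Illustrative usage of the new clientOptions constructor argument (a sketch, not part of this diff; the import path and the option values shown here are assumptions):

// Plain OpenAI: clientOptions is forwarded directly to `new OpenAI(clientOptions)`.
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";

const llm = new OpenAIChatLLM({
  modelId: "gpt-4o-mini",
  clientOptions: {
    baseURL: "https://example.com/v1", // hypothetical proxy endpoint
    timeout: 30_000,
    maxRetries: 2,
  },
});

// Azure OpenAI: because clientOptions is spread after the env-derived defaults,
// anything set here overrides AZURE_OPENAI_API_ENDPOINT / _VERSION / _DEPLOYMENT.
const azureLlm = new OpenAIChatLLM({
  azure: true,
  clientOptions: {
    apiVersion: "2024-06-01", // hypothetical; overrides AZURE_OPENAI_API_VERSION
  },
});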