diff --git a/src/components/Option/Playground/PlaygroundForm.tsx b/src/components/Option/Playground/PlaygroundForm.tsx
index 245fd67..5e65524 100644
--- a/src/components/Option/Playground/PlaygroundForm.tsx
+++ b/src/components/Option/Playground/PlaygroundForm.tsx
@@ -206,7 +206,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
   return (
-
+
     {
-
+
   )
 }
diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx
index 4559295..0049539 100644
--- a/src/hooks/useMessage.tsx
+++ b/src/hooks/useMessage.tsx
@@ -36,7 +36,12 @@ import { humanMessageFormatter } from "@/utils/human-message"
 import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
 import { getScreenshotFromCurrentTab } from "@/libs/get-screenshot"
-import { isReasoningEnded, isReasoningStarted, removeReasoning } from "@/libs/reasoning"
+import {
+  isReasoningEnded,
+  isReasoningStarted,
+  mergeReasoningContent,
+  removeReasoning
+} from "@/libs/reasoning"
 
 export const useMessage = () => {
   const {
@@ -413,7 +418,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -618,7 +640,7 @@ export const useMessage = () => {
       const applicationChatHistory = []
 
       const data = await getScreenshotFromCurrentTab()
-
+
       const visionImage = data?.screenshot || ""
 
       if (visionImage === "") {
@@ -680,7 +702,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -950,8 +989,25 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1279,7 +1335,24 @@ export const useMessage = () => {
       let timetaken = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1527,7 +1600,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx
index 24d930d..8639452 100644
--- a/src/hooks/useMessageOption.tsx
+++ b/src/hooks/useMessageOption.tsx
@@ -332,7 +332,24 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -649,19 +666,27 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
+      let apiReasoning: boolean = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
-        // console.log(chunk)
-        // if (chunk?.reasoning_content) {
-        //   const reasoningContent = mergeReasoningContent(
-        //     fullText,
-        //     chunk?.reasoning_content || ""
-        //   )
-        //   contentToSave += reasoningContent
-        //   fullText += reasoningContent
-        // }
 
         if (isReasoningStarted(fullText) && !reasoningStartTime) {
           reasoningStartTime = new Date()
@@ -992,8 +1017,25 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
diff --git a/src/libs/openai.ts b/src/libs/openai.ts
index 63905c8..c9f105a 100644
--- a/src/libs/openai.ts
+++ b/src/libs/openai.ts
@@ -25,7 +25,7 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
     clearTimeout(timeoutId)
 
     // if Google API fails to return models, try another approach
-    if (res.status === 401 && res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
+    if (res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
       const urlGoogle = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`
       const resGoogle = await fetch(urlGoogle, {
         signal: controller.signal
diff --git a/src/libs/reasoning.ts b/src/libs/reasoning.ts
index a44dc4c..d6c1b23 100644
--- a/src/libs/reasoning.ts
+++ b/src/libs/reasoning.ts
@@ -1,109 +1,100 @@
 const tags = ["think", "reason", "reasoning", "thought"]
 
-export function parseReasoning(
-    text: string
-): {
-    type: "reasoning" | "text"
-    content: string
-    reasoning_running?: boolean
+export function parseReasoning(text: string): {
+  type: "reasoning" | "text"
+  content: string
+  reasoning_running?: boolean
 }[] {
-    try {
-        const result: {
-            type: "reasoning" | "text"
-            content: string
-            reasoning_running?: boolean
-        }[] = []
-        const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
-        const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
+  try {
+    const result: {
+      type: "reasoning" | "text"
+      content: string
+      reasoning_running?: boolean
+    }[] = []
+    const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
+    const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
 
-        let currentIndex = 0
-        let isReasoning = false
+    let currentIndex = 0
+    let isReasoning = false
 
-        while (currentIndex < text.length) {
-            const openTagMatch = text.slice(currentIndex).match(tagPattern)
-            const closeTagMatch = text.slice(currentIndex).match(closeTagPattern)
+    while (currentIndex < text.length) {
+      const openTagMatch = text.slice(currentIndex).match(tagPattern)
+      const closeTagMatch = text.slice(currentIndex).match(closeTagPattern)
 
-            if (!isReasoning && openTagMatch) {
-                const beforeText = text.slice(
-                    currentIndex,
-                    currentIndex + openTagMatch.index
-                )
-                if (beforeText.trim()) {
-                    result.push({ type: "text", content: beforeText.trim() })
-                }
-
-                isReasoning = true
-                currentIndex += openTagMatch.index! + openTagMatch[0].length
-                continue
-            }
-
-            if (isReasoning && closeTagMatch) {
-                const reasoningContent = text.slice(
-                    currentIndex,
-                    currentIndex + closeTagMatch.index
-                )
-                if (reasoningContent.trim()) {
-                    result.push({ type: "reasoning", content: reasoningContent.trim() })
-                }
-
-                isReasoning = false
-                currentIndex += closeTagMatch.index! + closeTagMatch[0].length
-                continue
-            }
-
-            if (currentIndex < text.length) {
-                const remainingText = text.slice(currentIndex)
-                result.push({
-                    type: isReasoning ? "reasoning" : "text",
-                    content: remainingText.trim(),
-                    reasoning_running: isReasoning
-                })
-                break
-            }
+      if (!isReasoning && openTagMatch) {
+        const beforeText = text.slice(
+          currentIndex,
+          currentIndex + openTagMatch.index
+        )
+        if (beforeText.trim()) {
+          result.push({ type: "text", content: beforeText.trim() })
         }
 
-        return result
-    } catch (e) {
-        console.error(`Error parsing reasoning: ${e}`)
-        return [
-            {
-                type: "text",
-                content: text
-            }
-        ]
+        isReasoning = true
+        currentIndex += openTagMatch.index! + openTagMatch[0].length
+        continue
+      }
+
+      if (isReasoning && closeTagMatch) {
+        const reasoningContent = text.slice(
+          currentIndex,
+          currentIndex + closeTagMatch.index
+        )
+        if (reasoningContent.trim()) {
+          result.push({ type: "reasoning", content: reasoningContent.trim() })
+        }
+
+        isReasoning = false
+        currentIndex += closeTagMatch.index! + closeTagMatch[0].length
+        continue
+      }
+
+      if (currentIndex < text.length) {
+        const remainingText = text.slice(currentIndex)
+        result.push({
+          type: isReasoning ? "reasoning" : "text",
+          content: remainingText.trim(),
+          reasoning_running: isReasoning
+        })
+        break
+      }
     }
+
+    return result
+  } catch (e) {
+    console.error(`Error parsing reasoning: ${e}`)
+    return [
+      {
+        type: "text",
+        content: text
+      }
+    ]
+  }
 }
 
 export function isReasoningStarted(text: string): boolean {
-    const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
-    return tagPattern.test(text)
+  const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
+  return tagPattern.test(text)
 }
 
 export function isReasoningEnded(text: string): boolean {
-    const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
-    return closeTagPattern.test(text)
+  const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
+  return closeTagPattern.test(text)
 }
 
 export function removeReasoning(text: string): string {
-    const tagPattern = new RegExp(
-        `<(${tags.join("|")})>.*?</\\1>`,
-        "gis"
-    )
-    return text.replace(tagPattern, "").trim()
+  const tagPattern = new RegExp(
+    `<(${tags.join("|")})>.*?</\\1>`,
+    "gis"
+  )
+  return text.replace(tagPattern, "").trim()
 }
 
-export function mergeReasoningContent(originalText: string, reasoning: string): string {
-    const defaultReasoningTag = "think"
-    const tagPattern = new RegExp(`<(${tags.join("|")})>(.*?)</\\1>`, "is")
-    const hasReasoningTag = tagPattern.test(originalText)
+export function mergeReasoningContent(
+  originalText: string,
+  reasoning: string
+): string {
+  const reasoningTag = "<think>"
 
-    if (hasReasoningTag) {
-        const match = originalText.match(tagPattern)
-        if (match) {
-            const [fullMatch, tag, existingContent] = match
-            const remainingText = originalText.replace(fullMatch, '').trim()
-            const newContent = `${existingContent.trim()}${reasoning}`
-            return `<${tag}>${newContent}</${tag}> ${remainingText}`
-        }
-    }
+  originalText = originalText.replace(reasoningTag, "")
 
-    return `<${defaultReasoningTag}>${reasoning}</${defaultReasoningTag}> ${originalText.trim()}`.trim()
-}
\ No newline at end of file
+  return `${reasoningTag}${originalText + reasoning}`.trim()
+}
diff --git a/src/models/CustomAIMessageChunk.ts b/src/models/CustomAIMessageChunk.ts
new file mode 100644
index 0000000..72479ed
--- /dev/null
+++ b/src/models/CustomAIMessageChunk.ts
@@ -0,0 +1,78 @@
+interface BaseMessageFields {
+  content: string;
+  name?: string;
+  additional_kwargs?: {
+    [key: string]: unknown;
+  };
+}
+
+export class CustomAIMessageChunk {
+  /** The text of the message. */
+  content: string;
+
+  /** The name of the message sender in a multi-user chat.
*/ + name?: string; + + /** Additional keyword arguments */ + additional_kwargs: NonNullable; + + constructor(fields: BaseMessageFields) { + // Make sure the default value for additional_kwargs is passed into super() for serialization + if (!fields.additional_kwargs) { + // eslint-disable-next-line no-param-reassign + fields.additional_kwargs = {}; + } + + this.name = fields.name; + this.content = fields.content; + this.additional_kwargs = fields.additional_kwargs; + } + + static _mergeAdditionalKwargs( + left: NonNullable, + right: NonNullable + ): NonNullable { + const merged = { ...left }; + for (const [key, value] of Object.entries(right)) { + if (merged[key] === undefined) { + merged[key] = value; + }else if (typeof merged[key] === "string") { + merged[key] = (merged[key] as string) + value; + } else if ( + !Array.isArray(merged[key]) && + typeof merged[key] === "object" + ) { + merged[key] = this._mergeAdditionalKwargs( + merged[key] as NonNullable, + value as NonNullable + ); + } else { + throw new Error( + `additional_kwargs[${key}] already exists in this message chunk.` + ); + } + } + return merged; + } + + concat(chunk: CustomAIMessageChunk) { + return new CustomAIMessageChunk({ + content: this.content + chunk.content, + additional_kwargs: CustomAIMessageChunk._mergeAdditionalKwargs( + this.additional_kwargs, + chunk.additional_kwargs + ), + }); + } +} + +function isAiMessageChunkFields(value: unknown): value is BaseMessageFields { + if (typeof value !== "object" || value == null) return false; + return "content" in value && typeof value["content"] === "string"; +} + +function isAiMessageChunkFieldsList( + value: unknown[] +): value is BaseMessageFields[] { + return value.length > 0 && value.every((x) => isAiMessageChunkFields(x)); +} diff --git a/src/models/CustomChatOpenAI.ts b/src/models/CustomChatOpenAI.ts new file mode 100644 index 0000000..ec74e10 --- /dev/null +++ b/src/models/CustomChatOpenAI.ts @@ -0,0 +1,914 @@ +import { type ClientOptions, OpenAI as OpenAIClient } from "openai" +import { + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, + SystemMessageChunk, + ToolMessageChunk, +} from "@langchain/core/messages" +import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs" +import { getEnvironmentVariable } from "@langchain/core/utils/env" +import { BaseChatModel, BaseChatModelParams } from "@langchain/core/language_models/chat_models" +import { convertToOpenAITool } from "@langchain/core/utils/function_calling" +import { + RunnablePassthrough, + RunnableSequence +} from "@langchain/core/runnables" +import { + JsonOutputParser, + StructuredOutputParser +} from "@langchain/core/output_parsers" +import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools" +import { wrapOpenAIClientError } from "./utils/openai.js" +import { + ChatOpenAICallOptions, + getEndpoint, + OpenAIChatInput, + OpenAICoreRequestOptions +} from "@langchain/openai" +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager" +import { TokenUsage } from "@langchain/core/language_models/base" +import { LegacyOpenAIInput } from "./types.js" +import { CustomAIMessageChunk } from "./CustomAIMessageChunk.js" + +type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool" + +function extractGenericMessageCustomRole(message: ChatMessage) { + if ( + message.role !== "system" && + message.role !== "assistant" && + message.role !== "user" && + message.role !== 
"function" && + message.role !== "tool" + ) { + console.warn(`Unknown message role: ${message.role}`) + } + return message.role +} +export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum { + const type = message._getType() + switch (type) { + case "system": + return "system" + case "ai": + return "assistant" + case "human": + return "user" + case "function": + return "function" + case "tool": + return "tool" + case "generic": { + if (!ChatMessage.isInstance(message)) + throw new Error("Invalid generic chat message") + return extractGenericMessageCustomRole(message) as OpenAIRoleEnum + } + default: + return type + } +} +function openAIResponseToChatMessage( + message: OpenAIClient.Chat.Completions.ChatCompletionMessage +) { + switch (message.role) { + case "assistant": + return new AIMessage(message.content || "", { + // function_call: message.function_call, + // tool_calls: message.tool_calls + // reasoning_content: message?.reasoning_content || null + }) + default: + return new ChatMessage(message.content || "", message.role ?? "unknown") + } +} +function _convertDeltaToMessageChunk( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delta: Record, + defaultRole?: OpenAIRoleEnum +) { + const role = delta.role ?? defaultRole + const content = delta.content ?? "" + const reasoning_content: string | null = delta.reasoning_content ?? null + let additional_kwargs + if (delta.function_call) { + additional_kwargs = { + function_call: delta.function_call + } + } else if (delta.tool_calls) { + additional_kwargs = { + tool_calls: delta.tool_calls + } + } else { + additional_kwargs = {} + } + if (role === "user") { + return new HumanMessageChunk({ content }) + } else if (role === "assistant") { + return new CustomAIMessageChunk({ + content, + additional_kwargs: { + ...additional_kwargs, + reasoning_content + } + }) as any + } else if (role === "system") { + return new SystemMessageChunk({ content }) + } else if (role === "function") { + return new FunctionMessageChunk({ + content, + additional_kwargs, + name: delta.name + }) + } else if (role === "tool") { + return new ToolMessageChunk({ + content, + additional_kwargs, + tool_call_id: delta.tool_call_id + }) + } else { + return new ChatMessageChunk({ content, role }) + } +} +function convertMessagesToOpenAIParams(messages: any[]) { + // TODO: Function messages do not support array content, fix cast + return messages.map((message) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const completionParam: { role: string; content: string; name?: string } = { + role: messageToOpenAIRole(message), + content: message.content + } + if (message.name != null) { + completionParam.name = message.name + } + + return completionParam + }) +} +export class CustomChatOpenAI< + CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions +> + extends BaseChatModel + implements OpenAIChatInput { + temperature = 1 + + topP = 1 + + frequencyPenalty = 0 + + presencePenalty = 0 + + n = 1 + + logitBias?: Record + + modelName = "gpt-3.5-turbo" + + model = "gpt-3.5-turbo" + + modelKwargs?: OpenAIChatInput["modelKwargs"] + + stop?: string[] + + stopSequences?: string[] + + user?: string + + timeout?: number + + streaming = false + + streamUsage = true + + maxTokens?: number + + logprobs?: boolean + + topLogprobs?: number + + openAIApiKey?: string + + apiKey?: string + + azureOpenAIApiVersion?: string + + azureOpenAIApiKey?: string + + azureADTokenProvider?: () => Promise + + azureOpenAIApiInstanceName?: string + + 
azureOpenAIApiDeploymentName?: string + + azureOpenAIBasePath?: string + + organization?: string + + protected client: OpenAIClient + + protected clientConfig: ClientOptions + static lc_name() { + return "ChatOpenAI" + } + get callKeys() { + return [ + ...super.callKeys, + "options", + "function_call", + "functions", + "tools", + "tool_choice", + "promptIndex", + "response_format", + "seed" + ] + } + get lc_secrets() { + return { + openAIApiKey: "OPENAI_API_KEY", + azureOpenAIApiKey: "AZURE_OPENAI_API_KEY", + organization: "OPENAI_ORGANIZATION" + } + } + get lc_aliases() { + return { + modelName: "model", + openAIApiKey: "openai_api_key", + azureOpenAIApiVersion: "azure_openai_api_version", + azureOpenAIApiKey: "azure_openai_api_key", + azureOpenAIApiInstanceName: "azure_openai_api_instance_name", + azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name" + } + } + constructor( + fields?: Partial & + BaseChatModelParams & { + configuration?: ClientOptions & LegacyOpenAIInput; + }, + /** @deprecated */ + configuration?: ClientOptions & LegacyOpenAIInput + ) { + super(fields ?? {}) + Object.defineProperty(this, "lc_serializable", { + enumerable: true, + configurable: true, + writable: true, + value: true + }) + Object.defineProperty(this, "temperature", { + enumerable: true, + configurable: true, + writable: true, + value: 1 + }) + Object.defineProperty(this, "topP", { + enumerable: true, + configurable: true, + writable: true, + value: 1 + }) + Object.defineProperty(this, "frequencyPenalty", { + enumerable: true, + configurable: true, + writable: true, + value: 0 + }) + Object.defineProperty(this, "presencePenalty", { + enumerable: true, + configurable: true, + writable: true, + value: 0 + }) + Object.defineProperty(this, "n", { + enumerable: true, + configurable: true, + writable: true, + value: 1 + }) + Object.defineProperty(this, "logitBias", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "modelName", { + enumerable: true, + configurable: true, + writable: true, + value: "gpt-3.5-turbo" + }) + Object.defineProperty(this, "modelKwargs", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "stop", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "user", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "timeout", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "streaming", { + enumerable: true, + configurable: true, + writable: true, + value: false + }) + Object.defineProperty(this, "maxTokens", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "logprobs", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "topLogprobs", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "openAIApiKey", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "azureOpenAIApiVersion", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "azureOpenAIApiKey", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, 
"azureOpenAIApiInstanceName", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "azureOpenAIApiDeploymentName", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "azureOpenAIBasePath", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "organization", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "client", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + Object.defineProperty(this, "clientConfig", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }) + this.openAIApiKey = + fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY") + + this.modelName = fields?.modelName ?? this.modelName + this.modelKwargs = fields?.modelKwargs ?? {} + this.timeout = fields?.timeout + this.temperature = fields?.temperature ?? this.temperature + this.topP = fields?.topP ?? this.topP + this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty + this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty + this.maxTokens = fields?.maxTokens + this.logprobs = fields?.logprobs + this.topLogprobs = fields?.topLogprobs + this.n = fields?.n ?? this.n + this.logitBias = fields?.logitBias + this.stop = fields?.stop + this.user = fields?.user + this.streaming = fields?.streaming ?? false + this.clientConfig = { + apiKey: this.openAIApiKey, + organization: this.organization, + baseURL: configuration?.basePath ?? fields?.configuration?.basePath, + dangerouslyAllowBrowser: true, + defaultHeaders: + configuration?.baseOptions?.headers ?? + fields?.configuration?.baseOptions?.headers, + defaultQuery: + configuration?.baseOptions?.params ?? + fields?.configuration?.baseOptions?.params, + ...configuration, + ...fields?.configuration + } + } + /** + * Get the parameters used to invoke the model + */ + invocationParams(options) { + function isStructuredToolArray(tools) { + return ( + tools !== undefined && + tools.every((tool) => Array.isArray(tool.lc_namespace)) + ) + } + const params = { + model: this.modelName, + temperature: this.temperature, + top_p: this.topP, + frequency_penalty: this.frequencyPenalty, + presence_penalty: this.presencePenalty, + max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens, + logprobs: this.logprobs, + top_logprobs: this.topLogprobs, + n: this.n, + logit_bias: this.logitBias, + stop: options?.stop ?? this.stop, + user: this.user, + stream: this.streaming, + functions: options?.functions, + function_call: options?.function_call, + tools: isStructuredToolArray(options?.tools) + ? 
options?.tools.map(convertToOpenAITool) + : options?.tools, + tool_choice: options?.tool_choice, + response_format: options?.response_format, + seed: options?.seed, + ...this.modelKwargs + } + return params + } + /** @ignore */ + _identifyingParams() { + return { + model_name: this.modelName, + //@ts-ignore + ...this?.invocationParams(), + ...this.clientConfig + } + } + async *_streamResponseChunks( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const messagesMapped = convertMessagesToOpenAIParams(messages) + const params = { + ...this.invocationParams(options), + messages: messagesMapped, + stream: true + } + let defaultRole + //@ts-ignore + const streamIterable = await this.completionWithRetry(params, options) + for await (const data of streamIterable) { + const choice = data?.choices[0] + if (!choice) { + continue + } + const { delta } = choice + if (!delta) { + continue + } + const chunk = _convertDeltaToMessageChunk(delta, defaultRole) + defaultRole = delta.role ?? defaultRole + const newTokenIndices = { + //@ts-ignore + prompt: options?.promptIndex ?? 0, + completion: choice.index ?? 0 + } + if (typeof chunk.content !== "string") { + console.log( + "[WARNING]: Received non-string content from OpenAI. This is currently not supported." + ) + continue + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const generationInfo = { ...newTokenIndices } as any + if (choice.finish_reason !== undefined) { + generationInfo.finish_reason = choice.finish_reason + } + if (this.logprobs) { + generationInfo.logprobs = choice.logprobs + } + const generationChunk = new ChatGenerationChunk({ + message: chunk, + text: chunk.content, + generationInfo + }) + yield generationChunk + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken( + generationChunk.text ?? "", + newTokenIndices, + undefined, + undefined, + undefined, + { chunk: generationChunk } + ) + } + if (options.signal?.aborted) { + throw new Error("AbortError") + } + } + /** + * Get the identifying parameters for the model + * + */ + identifyingParams() { + return this._identifyingParams() + } + /** @ignore */ + async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const tokenUsage: TokenUsage = {} + const params = this.invocationParams(options) + const messagesMapped: any[] = convertMessagesToOpenAIParams(messages) + if (params.stream) { + const stream = this._streamResponseChunks(messages, options, runManager) + const finalChunks: Record = {} + for await (const chunk of stream) { + //@ts-ignore + chunk.message.response_metadata = { + ...chunk.generationInfo, + //@ts-ignore + ...chunk.message.response_metadata + } + const index = chunk.generationInfo?.completion ?? 0 + if (finalChunks[index] === undefined) { + finalChunks[index] = chunk + } else { + finalChunks[index] = finalChunks[index].concat(chunk) + } + } + const generations = Object.entries(finalChunks) + .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) + .map(([_, value]) => value) + const { functions, function_call } = this.invocationParams(options) + // OpenAI does not support token usage report under stream mode, + // fallback to estimation. 
+ const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt( + messages, + functions, + function_call + ) + const completionTokenUsage = + await this.getNumTokensFromGenerations(generations) + tokenUsage.promptTokens = promptTokenUsage + tokenUsage.completionTokens = completionTokenUsage + tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage + return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } } + } else { + const data = await this.completionWithRetry( + { + ...params, + //@ts-ignore + stream: false, + messages: messagesMapped + }, + { + signal: options?.signal, + //@ts-ignore + ...options?.options + } + ) + const { + completion_tokens: completionTokens, + prompt_tokens: promptTokens, + total_tokens: totalTokens + //@ts-ignore + } = data?.usage ?? {} + if (completionTokens) { + tokenUsage.completionTokens = + (tokenUsage.completionTokens ?? 0) + completionTokens + } + if (promptTokens) { + tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens + } + if (totalTokens) { + tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens + } + const generations = [] + //@ts-ignore + for (const part of data?.choices ?? []) { + const text = part.message?.content ?? "" + const generation = { + text, + message: openAIResponseToChatMessage( + part.message ?? { role: "assistant" } + ) + } + //@ts-ignore + generation.generationInfo = { + ...(part.finish_reason ? { finish_reason: part.finish_reason } : {}), + ...(part.logprobs ? { logprobs: part.logprobs } : {}) + } + generations.push(generation) + } + return { + generations, + llmOutput: { tokenUsage } + } + } + } + /** + * Estimate the number of tokens a prompt will use. + * Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts + */ + async getEstimatedTokenCountFromPrompt(messages, functions, function_call) { + let tokens = (await this.getNumTokensFromMessages(messages)).totalCount + if (functions && messages.find((m) => m._getType() === "system")) { + tokens -= 4 + } + if (function_call === "none") { + tokens += 1 + } else if (typeof function_call === "object") { + tokens += (await this.getNumTokens(function_call.name)) + 4 + } + return tokens + } + /** + * Estimate the number of tokens an array of generations have used. + */ + async getNumTokensFromGenerations(generations) { + const generationUsages = await Promise.all( + generations.map(async (generation) => { + if (generation.message.additional_kwargs?.function_call) { + return (await this.getNumTokensFromMessages([generation.message])) + .countPerMessage[0] + } else { + return await this.getNumTokens(generation.message.content) + } + }) + ) + return generationUsages.reduce((a, b) => a + b, 0) + } + async getNumTokensFromMessages(messages) { + let totalCount = 0 + let tokensPerMessage = 0 + let tokensPerName = 0 + // From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb + if (this.modelName === "gpt-3.5-turbo-0301") { + tokensPerMessage = 4 + tokensPerName = -1 + } else { + tokensPerMessage = 3 + tokensPerName = 1 + } + const countPerMessage = await Promise.all( + messages.map(async (message) => { + const textCount = await this.getNumTokens(message.content) + const roleCount = await this.getNumTokens(messageToOpenAIRole(message)) + const nameCount = + message.name !== undefined + ? 
tokensPerName + (await this.getNumTokens(message.name)) + : 0 + let count = textCount + tokensPerMessage + roleCount + nameCount + // From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate + const openAIMessage = message + if (openAIMessage._getType() === "function") { + count -= 2 + } + if (openAIMessage.additional_kwargs?.function_call) { + count += 3 + } + if (openAIMessage?.additional_kwargs.function_call?.name) { + count += await this.getNumTokens( + openAIMessage.additional_kwargs.function_call?.name + ) + } + if (openAIMessage.additional_kwargs.function_call?.arguments) { + try { + count += await this.getNumTokens( + // Remove newlines and spaces + JSON.stringify( + JSON.parse( + openAIMessage.additional_kwargs.function_call?.arguments + ) + ) + ) + } catch (error) { + console.error( + "Error parsing function arguments", + error, + JSON.stringify(openAIMessage.additional_kwargs.function_call) + ) + count += await this.getNumTokens( + openAIMessage.additional_kwargs.function_call?.arguments + ) + } + } + totalCount += count + return count + }) + ) + totalCount += 3 // every reply is primed with <|start|>assistant<|message|> + return { totalCount, countPerMessage } + } + async completionWithRetry( + request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, + options?: OpenAICoreRequestOptions + ) { + const requestOptions = this._getClientOptions(options) + return this.caller.call(async () => { + try { + const res = await this.client.chat.completions.create( + request, + requestOptions + ) + return res + } catch (e) { + const error = wrapOpenAIClientError(e) + throw error + } + }) + } + _getClientOptions(options) { + if (!this.client) { + const openAIEndpointConfig = { + azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, + azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, + azureOpenAIApiKey: this.azureOpenAIApiKey, + azureOpenAIBasePath: this.azureOpenAIBasePath, + baseURL: this.clientConfig.baseURL + } + const endpoint = getEndpoint(openAIEndpointConfig) + const params = { + ...this.clientConfig, + baseURL: endpoint, + timeout: this.timeout, + maxRetries: 0 + } + if (!params.baseURL) { + delete params.baseURL + } + this.client = new OpenAIClient(params) + } + const requestOptions = { + ...this.clientConfig, + ...options + } + if (this.azureOpenAIApiKey) { + requestOptions.headers = { + "api-key": this.azureOpenAIApiKey, + ...requestOptions.headers + } + requestOptions.query = { + "api-version": this.azureOpenAIApiVersion, + ...requestOptions.query + } + } + return requestOptions + } + _llmType() { + return "openai" + } + /** @ignore */ + _combineLLMOutput(...llmOutputs) { + return llmOutputs.reduce( + (acc, llmOutput) => { + if (llmOutput && llmOutput.tokenUsage) { + acc.tokenUsage.completionTokens += + llmOutput.tokenUsage.completionTokens ?? 0 + acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0 + acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 
0 + } + return acc + }, + { + tokenUsage: { + completionTokens: 0, + promptTokens: 0, + totalTokens: 0 + } + } + ) + } + withStructuredOutput(outputSchema, config) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let schema + let name + let method + let includeRaw + if (isStructuredOutputMethodParams(outputSchema)) { + schema = outputSchema.schema + name = outputSchema.name + method = outputSchema.method + includeRaw = outputSchema.includeRaw + } else { + schema = outputSchema + name = config?.name + method = config?.method + includeRaw = config?.includeRaw + } + let llm + let outputParser + if (method === "jsonMode") { + llm = this.bind({ + }) + if (isZodSchema(schema)) { + outputParser = StructuredOutputParser.fromZodSchema(schema) + } else { + outputParser = new JsonOutputParser() + } + } else { + let functionName = name ?? "extract" + // Is function calling + + let openAIFunctionDefinition + if ( + typeof schema.name === "string" && + typeof schema.parameters === "object" && + schema.parameters != null + ) { + openAIFunctionDefinition = schema + functionName = schema.name + } else { + openAIFunctionDefinition = { + name: schema.title ?? functionName, + description: schema.description ?? "", + parameters: schema + } + } + llm = this.bind({ + + }) + outputParser = new JsonOutputKeyToolsParser({ + returnSingle: true, + keyName: functionName + }) + } + if (!includeRaw) { + return llm.pipe(outputParser) + } + const parserAssign = RunnablePassthrough.assign({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parsed: (input, config) => outputParser.invoke(input.raw, config) + }) + const parserNone = RunnablePassthrough.assign({ + parsed: () => null + }) + const parsedWithFallback = parserAssign.withFallbacks({ + fallbacks: [parserNone] + }) + return RunnableSequence.from([ + { + raw: llm + }, + parsedWithFallback + ] as any) + } +} +function isZodSchema( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + input +) { + // Check for a characteristic method of Zod schemas + return typeof input?.parse === "function" +} +function isStructuredOutputMethodParams( + x + // eslint-disable-next-line @typescript-eslint/no-explicit-any +) { + return ( + x !== undefined && + // eslint-disable-next-line @typescript-eslint/no-explicit-any + typeof x.schema === "object" + ) +} diff --git a/src/models/index.ts b/src/models/index.ts index d76dd15..b2b4d8f 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -2,9 +2,9 @@ import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models" import { ChatChromeAI } from "./ChatChromeAi" import { ChatOllama } from "./ChatOllama" import { getOpenAIConfigById } from "@/db/openai" -import { ChatOpenAI } from "@langchain/openai" import { urlRewriteRuntime } from "@/libs/runtime" import { ChatGoogleAI } from "./ChatGoogleAI" +import { CustomChatOpenAI } from "./CustomChatOpenAI" export const pageAssistModel = async ({ model, @@ -76,7 +76,7 @@ export const pageAssistModel = async ({ }) as any } - return new ChatOpenAI({ + return new CustomChatOpenAI({ modelName: modelInfo.model_id, openAIApiKey: providerInfo.apiKey || "temp", temperature,
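Reviewer note, not part of the patch: a minimal sketch of what the new apiReasoning branches in useMessage.tsx and useMessageOption.tsx are doing. When a provider streams reasoning out-of-band via additional_kwargs.reasoning_content, that text is folded into a single <think>...</think> block with mergeReasoningContent, and the block is closed as soon as ordinary content resumes, so the existing tag-based helpers (parseReasoning, removeReasoning, isReasoningStarted/Ended) keep working unchanged. The Chunk type and foldChunks helper below are illustrative assumptions, not code from the patch.

// Sketch only: mirrors the merge logic added in src/libs/reasoning.ts and the
// streaming loops in the hooks; Chunk and foldChunks are hypothetical names.
type Chunk = {
  content?: string
  additional_kwargs?: { reasoning_content?: string }
}

const mergeReasoningContent = (originalText: string, reasoning: string): string => {
  const reasoningTag = "<think>"
  // Drop any previously opened tag so repeated merges keep a single block.
  originalText = originalText.replace(reasoningTag, "")
  return `${reasoningTag}${originalText + reasoning}`.trim()
}

function foldChunks(chunks: Chunk[]): string {
  let fullText = ""
  let apiReasoning = false
  for (const chunk of chunks) {
    if (chunk?.additional_kwargs?.reasoning_content) {
      // Reasoning arrives out-of-band: accumulate it inside one <think> block.
      fullText = mergeReasoningContent(
        fullText,
        chunk.additional_kwargs.reasoning_content
      )
      apiReasoning = true
    } else if (apiReasoning) {
      // First ordinary chunk after the reasoning stream: close the block.
      fullText += "</think>"
      apiReasoning = false
    }
    fullText += chunk?.content ?? ""
  }
  return fullText
}

// foldChunks([
//   { additional_kwargs: { reasoning_content: "Consider the question first." } },
//   { content: "The answer is 42." }
// ]) returns "<think>Consider the question first.</think>The answer is 42."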
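A second sketch, also not part of the patch: why CustomAIMessageChunk._mergeAdditionalKwargs concatenates string values when chunks are combined. It lets reasoning_content accumulate across streamed deltas the same way content does, which the stock AIMessageChunk kwarg merge would reject. The values below are made up, and the import path assumes the @/ alias used elsewhere in the repo.

import { CustomAIMessageChunk } from "@/models/CustomAIMessageChunk"

// Two consecutive streamed deltas: the first carries only reasoning,
// the second carries visible content plus the tail of the reasoning.
const first = new CustomAIMessageChunk({
  content: "",
  additional_kwargs: { reasoning_content: "Weighing the options" }
})
const second = new CustomAIMessageChunk({
  content: "Here is my answer.",
  additional_kwargs: { reasoning_content: " before answering." }
})

const merged = first.concat(second)
// merged.content === "Here is my answer."
// merged.additional_kwargs.reasoning_content === "Weighing the options before answering."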