feat: implement CustomAIMessageChunk class and enhance reasoning handling in useMessageOption hook
commit 926f4e1a4a
parent 8381f1c996
@@ -206,7 +206,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {

   return (
     <div className="flex w-full flex-col items-center p-2 pt-1  pb-4">
-      <form className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
+      <div className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
         <div className="relative flex w-full flex-row justify-center gap-2 lg:w-4/5">
           <div
             className={` bg-neutral-50  dark:bg-[#262626] relative w-full max-w-[48rem] p-1 backdrop-blur-lg duration-100 border border-gray-300 rounded-xl  dark:border-gray-600
@@ -449,7 +449,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
             </div>
           </div>
         </div>
-      </form>
+      </div>
     </div>
   )
 }

@@ -36,7 +36,12 @@ import { humanMessageFormatter } from "@/utils/human-message"
 import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
 import { getScreenshotFromCurrentTab } from "@/libs/get-screenshot"
-import { isReasoningEnded, isReasoningStarted, removeReasoning } from "@/libs/reasoning"
+import {
+  isReasoningEnded,
+  isReasoningStarted,
+  mergeReasoningContent,
+  removeReasoning
+} from "@/libs/reasoning"

 export const useMessage = () => {
   const {
@@ -413,7 +418,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -680,7 +702,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -950,8 +989,25 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false

       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1279,7 +1335,24 @@ export const useMessage = () => {
       let timetaken = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1527,7 +1600,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {

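Note: every streaming loop above (and the useMessageOption loops below) now applies the same pattern. While the provider emits `reasoning_content` deltas, they are folded into the running text under a `<think>` tag via `mergeReasoningContent`, and the tag is closed as soon as the first regular content delta arrives. A minimal standalone sketch of that pattern; the `StreamChunk` shape and `collect` helper are illustrative, not part of the commit:

```ts
import { mergeReasoningContent } from "@/libs/reasoning"

// Hypothetical chunk shape, mirroring CustomAIMessageChunk introduced below.
interface StreamChunk {
  content: string
  additional_kwargs?: { reasoning_content?: string }
}

async function collect(chunks: AsyncIterable<StreamChunk>): Promise<string> {
  let fullText = ""
  let apiReasoning = false
  for await (const chunk of chunks) {
    if (chunk?.additional_kwargs?.reasoning_content) {
      // Reasoning delta: fold it into the <think> block at the front.
      fullText = mergeReasoningContent(
        fullText,
        chunk.additional_kwargs.reasoning_content
      )
      apiReasoning = true
    } else if (apiReasoning) {
      // First normal delta after reasoning: close the block.
      fullText += "</think>"
      apiReasoning = false
    }
    fullText += chunk?.content
  }
  return fullText
}
```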
@@ -332,7 +332,24 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -649,19 +666,27 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
+      let apiReasoning: boolean = false

       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
-        // console.log(chunk)
-        // if (chunk?.reasoning_content) {
-        //   const reasoningContent = mergeReasoningContent(
-        //     fullText,
-        //     chunk?.reasoning_content || ""
-        //   )
-        //   contentToSave += reasoningContent
-        //   fullText += reasoningContent
-        // }

         if (isReasoningStarted(fullText) && !reasoningStartTime) {
           reasoningStartTime = new Date()
@@ -992,8 +1017,25 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false

       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {

@@ -25,7 +25,7 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
     clearTimeout(timeoutId)

     // if Google API fails to return models, try another approach
-    if (res.status === 401 && res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
+    if (res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
       const urlGoogle = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`
       const resGoogle = await fetch(urlGoogle, {
         signal: controller.signal
       })

@@ -1,7 +1,5 @@
 const tags = ["think", "reason", "reasoning", "thought"]
-export function parseReasoning(
-    text: string
-): {
+export function parseReasoning(text: string): {
   type: "reasoning" | "text"
   content: string
   reasoning_running?: boolean
@@ -90,20 +88,13 @@ export function removeReasoning(text: string): string {
   )
   return text.replace(tagPattern, "").trim()
 }
-export function mergeReasoningContent(originalText: string, reasoning: string): string {
-    const defaultReasoningTag = "think"
-    const tagPattern = new RegExp(`<(${tags.join("|")})>(.*?)</(${tags.join("|")})>`, "is")
-    const hasReasoningTag = tagPattern.test(originalText)
-
-    if (hasReasoningTag) {
-        const match = originalText.match(tagPattern)
-        if (match) {
-            const [fullMatch, tag, existingContent] = match
-            const remainingText = originalText.replace(fullMatch, '').trim()
-            const newContent = `${existingContent.trim()}${reasoning}`
-            return `<${tag}>${newContent}</${tag}> ${remainingText}`
-        }
-    }
-
-    return `<${defaultReasoningTag}>${reasoning}</${defaultReasoningTag}> ${originalText.trim()}`.trim()
+export function mergeReasoningContent(
+  originalText: string,
+  reasoning: string
+): string {
+  const reasoningTag = "<think>"
+
+  originalText = originalText.replace(reasoningTag, "")
+
+  return `${reasoningTag}${originalText + reasoning}`.trim()
 }
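Note: the rewritten `mergeReasoningContent` no longer re-wraps an existing tagged block; it keeps a single `<think>` prefix and appends each new reasoning delta to it, which fits incremental streaming. A quick illustration of the new behavior:

```ts
import { mergeReasoningContent } from "@/libs/reasoning"

mergeReasoningContent("", "First")             // => "<think>First"
mergeReasoningContent("<think>First", " then") // => "<think>First then"
```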
							
								
								
									
src/models/CustomAIMessageChunk.ts (new file, 78 lines)
@@ -0,0 +1,78 @@
+interface BaseMessageFields {
+    content: string;
+    name?: string;
+    additional_kwargs?: {
+        [key: string]: unknown;
+    };
+}
+
+export class CustomAIMessageChunk {
+    /** The text of the message. */
+    content: string;
+
+    /** The name of the message sender in a multi-user chat. */
+    name?: string;
+
+    /** Additional keyword arguments */
+    additional_kwargs: NonNullable<BaseMessageFields["additional_kwargs"]>;
+
+    constructor(fields: BaseMessageFields) {
+        // Make sure the default value for additional_kwargs is passed into super() for serialization
+        if (!fields.additional_kwargs) {
+            // eslint-disable-next-line no-param-reassign
+            fields.additional_kwargs = {};
+        }
+
+        this.name = fields.name;
+        this.content = fields.content;
+        this.additional_kwargs = fields.additional_kwargs;
+    }
+
+    static _mergeAdditionalKwargs(
+        left: NonNullable<BaseMessageFields["additional_kwargs"]>,
+        right: NonNullable<BaseMessageFields["additional_kwargs"]>
+    ): NonNullable<BaseMessageFields["additional_kwargs"]> {
+        const merged = { ...left };
+        for (const [key, value] of Object.entries(right)) {
+            if (merged[key] === undefined) {
+                merged[key] = value;
+            } else if (typeof merged[key] === "string") {
+                merged[key] = (merged[key] as string) + value;
+            } else if (
+                !Array.isArray(merged[key]) &&
+                typeof merged[key] === "object"
+            ) {
+                merged[key] = this._mergeAdditionalKwargs(
+                    merged[key] as NonNullable<BaseMessageFields["additional_kwargs"]>,
+                    value as NonNullable<BaseMessageFields["additional_kwargs"]>
+                );
+            } else {
+                throw new Error(
+                    `additional_kwargs[${key}] already exists in this message chunk.`
+                );
+            }
+        }
+        return merged;
+    }
+
+    concat(chunk: CustomAIMessageChunk) {
+        return new CustomAIMessageChunk({
+            content: this.content + chunk.content,
+            additional_kwargs: CustomAIMessageChunk._mergeAdditionalKwargs(
+                this.additional_kwargs,
+                chunk.additional_kwargs
+            ),
+        });
+    }
+}
+
+function isAiMessageChunkFields(value: unknown): value is BaseMessageFields {
+    if (typeof value !== "object" || value == null) return false;
+    return "content" in value && typeof value["content"] === "string";
+}
+
+function isAiMessageChunkFieldsList(
+    value: unknown[]
+): value is BaseMessageFields[] {
+    return value.length > 0 && value.every((x) => isAiMessageChunkFields(x));
+}
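Note: `concat` merges string values in `additional_kwargs` by concatenation, which is what lets `reasoning_content` accumulate across streamed deltas in `CustomChatOpenAI` below. An illustrative usage; the values are made up:

```ts
import { CustomAIMessageChunk } from "@/models/CustomAIMessageChunk"

const a = new CustomAIMessageChunk({
  content: "",
  additional_kwargs: { reasoning_content: "Let me" }
})
const b = new CustomAIMessageChunk({
  content: "Hello",
  additional_kwargs: { reasoning_content: " think." }
})

const merged = a.concat(b)
// merged.content === "Hello"
// merged.additional_kwargs.reasoning_content === "Let me think."
```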
							
								
								
									
src/models/CustomChatOpenAI.ts (new file, 914 lines)
@@ -0,0 +1,914 @@
+import { type ClientOptions, OpenAI as OpenAIClient } from "openai"
+import {
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    ChatMessageChunk,
+    FunctionMessageChunk,
+    HumanMessageChunk,
+    SystemMessageChunk,
+    ToolMessageChunk,
+} from "@langchain/core/messages"
+import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"
+import { getEnvironmentVariable } from "@langchain/core/utils/env"
+import { BaseChatModel, BaseChatModelParams } from "@langchain/core/language_models/chat_models"
+import { convertToOpenAITool } from "@langchain/core/utils/function_calling"
+import {
+    RunnablePassthrough,
+    RunnableSequence
+} from "@langchain/core/runnables"
+import {
+    JsonOutputParser,
+    StructuredOutputParser
+} from "@langchain/core/output_parsers"
+import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"
+import { wrapOpenAIClientError } from "./utils/openai.js"
+import {
+    ChatOpenAICallOptions,
+    getEndpoint,
+    OpenAIChatInput,
+    OpenAICoreRequestOptions
+} from "@langchain/openai"
+import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"
+import { TokenUsage } from "@langchain/core/language_models/base"
+import { LegacyOpenAIInput } from "./types.js"
+import { CustomAIMessageChunk } from "./CustomAIMessageChunk.js"
+
+type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool"
+
+function extractGenericMessageCustomRole(message: ChatMessage) {
+    if (
+        message.role !== "system" &&
+        message.role !== "assistant" &&
+        message.role !== "user" &&
+        message.role !== "function" &&
+        message.role !== "tool"
+    ) {
+        console.warn(`Unknown message role: ${message.role}`)
+    }
+    return message.role
+}
+export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
+    const type = message._getType()
+    switch (type) {
+        case "system":
+            return "system"
+        case "ai":
+            return "assistant"
+        case "human":
+            return "user"
+        case "function":
+            return "function"
+        case "tool":
+            return "tool"
+        case "generic": {
+            if (!ChatMessage.isInstance(message))
+                throw new Error("Invalid generic chat message")
+            return extractGenericMessageCustomRole(message) as OpenAIRoleEnum
+        }
+        default:
+            return type
+    }
+}
+function openAIResponseToChatMessage(
+    message: OpenAIClient.Chat.Completions.ChatCompletionMessage
+) {
+    switch (message.role) {
+        case "assistant":
+            return new AIMessage(message.content || "", {
+                // function_call: message.function_call,
+                // tool_calls: message.tool_calls
+                // reasoning_content: message?.reasoning_content || null
+            })
+        default:
+            return new ChatMessage(message.content || "", message.role ?? "unknown")
+    }
+}
+function _convertDeltaToMessageChunk(
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    delta: Record<string, any>,
+    defaultRole?: OpenAIRoleEnum
+) {
+    const role = delta.role ?? defaultRole
+    const content = delta.content ?? ""
+    const reasoning_content: string | null = delta.reasoning_content ?? null
+    let additional_kwargs
+    if (delta.function_call) {
+        additional_kwargs = {
+            function_call: delta.function_call
+        }
+    } else if (delta.tool_calls) {
+        additional_kwargs = {
+            tool_calls: delta.tool_calls
+        }
+    } else {
+        additional_kwargs = {}
+    }
+    if (role === "user") {
+        return new HumanMessageChunk({ content })
+    } else if (role === "assistant") {
+        return new CustomAIMessageChunk({
+            content,
+            additional_kwargs: {
+                ...additional_kwargs,
+                reasoning_content
+            }
+        }) as any
+    } else if (role === "system") {
+        return new SystemMessageChunk({ content })
+    } else if (role === "function") {
+        return new FunctionMessageChunk({
+            content,
+            additional_kwargs,
+            name: delta.name
+        })
+    } else if (role === "tool") {
+        return new ToolMessageChunk({
+            content,
+            additional_kwargs,
+            tool_call_id: delta.tool_call_id
+        })
+    } else {
+        return new ChatMessageChunk({ content, role })
+    }
+}
+function convertMessagesToOpenAIParams(messages: any[]) {
+    // TODO: Function messages do not support array content, fix cast
+    return messages.map((message) => {
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        const completionParam: { role: string; content: string; name?: string } = {
+            role: messageToOpenAIRole(message),
+            content: message.content
+        }
+        if (message.name != null) {
+            completionParam.name = message.name
+        }
+
+        return completionParam
+    })
+}
+export class CustomChatOpenAI<
+    CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions
+>
+    extends BaseChatModel<CallOptions>
+    implements OpenAIChatInput {
+    temperature = 1
+
+    topP = 1
+
+    frequencyPenalty = 0
+
+    presencePenalty = 0
+
+    n = 1
+
+    logitBias?: Record<string, number>
+
+    modelName = "gpt-3.5-turbo"
+
+    model = "gpt-3.5-turbo"
+
+    modelKwargs?: OpenAIChatInput["modelKwargs"]
+
+    stop?: string[]
+
+    stopSequences?: string[]
+
+    user?: string
+
+    timeout?: number
+
+    streaming = false
+
+    streamUsage = true
+
+    maxTokens?: number
+
+    logprobs?: boolean
+
+    topLogprobs?: number
+
+    openAIApiKey?: string
+
+    apiKey?: string
+
+    azureOpenAIApiVersion?: string
+
+    azureOpenAIApiKey?: string
+
+    azureADTokenProvider?: () => Promise<string>
+
+    azureOpenAIApiInstanceName?: string
+
+    azureOpenAIApiDeploymentName?: string
+
+    azureOpenAIBasePath?: string
+
+    organization?: string
+
+    protected client: OpenAIClient
+
+    protected clientConfig: ClientOptions
+    static lc_name() {
+        return "ChatOpenAI"
+    }
+    get callKeys() {
+        return [
+            ...super.callKeys,
+            "options",
+            "function_call",
+            "functions",
+            "tools",
+            "tool_choice",
+            "promptIndex",
+            "response_format",
+            "seed"
+        ]
+    }
+    get lc_secrets() {
+        return {
+            openAIApiKey: "OPENAI_API_KEY",
+            azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
+            organization: "OPENAI_ORGANIZATION"
+        }
+    }
+    get lc_aliases() {
+        return {
+            modelName: "model",
+            openAIApiKey: "openai_api_key",
+            azureOpenAIApiVersion: "azure_openai_api_version",
+            azureOpenAIApiKey: "azure_openai_api_key",
+            azureOpenAIApiInstanceName: "azure_openai_api_instance_name",
+            azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name"
+        }
+    }
+    constructor(
+        fields?: Partial<OpenAIChatInput> &
+            BaseChatModelParams & {
+                configuration?: ClientOptions & LegacyOpenAIInput;
+            },
+        /** @deprecated */
+        configuration?: ClientOptions & LegacyOpenAIInput
+    ) {
+        super(fields ?? {})
+        Object.defineProperty(this, "lc_serializable", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: true
+        })
+        Object.defineProperty(this, "temperature", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: 1
+        })
+        Object.defineProperty(this, "topP", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: 1
+        })
+        Object.defineProperty(this, "frequencyPenalty", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: 0
+        })
+        Object.defineProperty(this, "presencePenalty", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: 0
+        })
+        Object.defineProperty(this, "n", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: 1
+        })
+        Object.defineProperty(this, "logitBias", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "modelName", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: "gpt-3.5-turbo"
+        })
+        Object.defineProperty(this, "modelKwargs", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "stop", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "user", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "timeout", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "streaming", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: false
+        })
+        Object.defineProperty(this, "maxTokens", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "logprobs", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "topLogprobs", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "openAIApiKey", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "azureOpenAIApiVersion", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "azureOpenAIApiKey", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "azureOpenAIApiInstanceName", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "azureOpenAIApiDeploymentName", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "azureOpenAIBasePath", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "organization", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "client", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        Object.defineProperty(this, "clientConfig", {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: void 0
+        })
+        this.openAIApiKey =
+            fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY")
+
+        this.modelName = fields?.modelName ?? this.modelName
+        this.modelKwargs = fields?.modelKwargs ?? {}
+        this.timeout = fields?.timeout
+        this.temperature = fields?.temperature ?? this.temperature
+        this.topP = fields?.topP ?? this.topP
+        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
+        this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty
+        this.maxTokens = fields?.maxTokens
+        this.logprobs = fields?.logprobs
+        this.topLogprobs = fields?.topLogprobs
+        this.n = fields?.n ?? this.n
+        this.logitBias = fields?.logitBias
+        this.stop = fields?.stop
+        this.user = fields?.user
+        this.streaming = fields?.streaming ?? false
+        this.clientConfig = {
+            apiKey: this.openAIApiKey,
+            organization: this.organization,
+            baseURL: configuration?.basePath ?? fields?.configuration?.basePath,
+            dangerouslyAllowBrowser: true,
+            defaultHeaders:
+                configuration?.baseOptions?.headers ??
+                fields?.configuration?.baseOptions?.headers,
+            defaultQuery:
+                configuration?.baseOptions?.params ??
+                fields?.configuration?.baseOptions?.params,
+            ...configuration,
+            ...fields?.configuration
+        }
+    }
|  |     /** | ||||||
|  |      * Get the parameters used to invoke the model | ||||||
|  |      */ | ||||||
|  |     invocationParams(options) { | ||||||
|  |         function isStructuredToolArray(tools) { | ||||||
|  |             return ( | ||||||
|  |                 tools !== undefined && | ||||||
|  |                 tools.every((tool) => Array.isArray(tool.lc_namespace)) | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |         const params = { | ||||||
|  |             model: this.modelName, | ||||||
|  |             temperature: this.temperature, | ||||||
|  |             top_p: this.topP, | ||||||
|  |             frequency_penalty: this.frequencyPenalty, | ||||||
|  |             presence_penalty: this.presencePenalty, | ||||||
|  |             max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens, | ||||||
|  |             logprobs: this.logprobs, | ||||||
|  |             top_logprobs: this.topLogprobs, | ||||||
|  |             n: this.n, | ||||||
|  |             logit_bias: this.logitBias, | ||||||
|  |             stop: options?.stop ?? this.stop, | ||||||
|  |             user: this.user, | ||||||
|  |             stream: this.streaming, | ||||||
|  |             functions: options?.functions, | ||||||
|  |             function_call: options?.function_call, | ||||||
|  |             tools: isStructuredToolArray(options?.tools) | ||||||
|  |                 ? options?.tools.map(convertToOpenAITool) | ||||||
|  |                 : options?.tools, | ||||||
|  |             tool_choice: options?.tool_choice, | ||||||
|  |             response_format: options?.response_format, | ||||||
|  |             seed: options?.seed, | ||||||
|  |             ...this.modelKwargs | ||||||
|  |         } | ||||||
|  |         return params | ||||||
|  |     } | ||||||
|  |     /** @ignore */ | ||||||
|  |     _identifyingParams() { | ||||||
|  |         return { | ||||||
|  |             model_name: this.modelName, | ||||||
|  |             //@ts-ignore
 | ||||||
|  |             ...this?.invocationParams(), | ||||||
|  |             ...this.clientConfig | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     async *_streamResponseChunks( | ||||||
|  |         messages: BaseMessage[], | ||||||
|  |         options: this["ParsedCallOptions"], | ||||||
|  |         runManager?: CallbackManagerForLLMRun | ||||||
|  |     ): AsyncGenerator<ChatGenerationChunk> { | ||||||
|  |         const messagesMapped = convertMessagesToOpenAIParams(messages) | ||||||
|  |         const params = { | ||||||
|  |             ...this.invocationParams(options), | ||||||
|  |             messages: messagesMapped, | ||||||
|  |             stream: true | ||||||
|  |         } | ||||||
|  |         let defaultRole | ||||||
|  |         //@ts-ignore
 | ||||||
|  |         const streamIterable = await this.completionWithRetry(params, options) | ||||||
|  |         for await (const data of streamIterable) { | ||||||
|  |             const choice = data?.choices[0] | ||||||
|  |             if (!choice) { | ||||||
|  |                 continue | ||||||
|  |             } | ||||||
|  |             const { delta } = choice | ||||||
|  |             if (!delta) { | ||||||
|  |                 continue | ||||||
|  |             } | ||||||
|  |             const chunk = _convertDeltaToMessageChunk(delta, defaultRole) | ||||||
|  |             defaultRole = delta.role ?? defaultRole | ||||||
|  |             const newTokenIndices = { | ||||||
|  |                 //@ts-ignore
 | ||||||
|  |                 prompt: options?.promptIndex ?? 0, | ||||||
|  |                 completion: choice.index ?? 0 | ||||||
|  |             } | ||||||
|  |             if (typeof chunk.content !== "string") { | ||||||
|  |                 console.log( | ||||||
|  |                     "[WARNING]: Received non-string content from OpenAI. This is currently not supported." | ||||||
|  |                 ) | ||||||
|  |                 continue | ||||||
|  |             } | ||||||
|  |             // eslint-disable-next-line @typescript-eslint/no-explicit-any
 | ||||||
|  |             const generationInfo = { ...newTokenIndices } as any | ||||||
|  |             if (choice.finish_reason !== undefined) { | ||||||
|  |                 generationInfo.finish_reason = choice.finish_reason | ||||||
|  |             } | ||||||
|  |             if (this.logprobs) { | ||||||
|  |                 generationInfo.logprobs = choice.logprobs | ||||||
|  |             } | ||||||
|  |             const generationChunk = new ChatGenerationChunk({ | ||||||
|  |                 message: chunk, | ||||||
|  |                 text: chunk.content, | ||||||
|  |                 generationInfo | ||||||
|  |             }) | ||||||
|  |             yield generationChunk | ||||||
|  |             // eslint-disable-next-line no-void
 | ||||||
|  |             void runManager?.handleLLMNewToken( | ||||||
|  |                 generationChunk.text ?? "", | ||||||
|  |                 newTokenIndices, | ||||||
|  |                 undefined, | ||||||
|  |                 undefined, | ||||||
|  |                 undefined, | ||||||
|  |                 { chunk: generationChunk } | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |         if (options.signal?.aborted) { | ||||||
|  |             throw new Error("AbortError") | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     /** | ||||||
|  |      * Get the identifying parameters for the model | ||||||
|  |      * | ||||||
|  |      */ | ||||||
|  |     identifyingParams() { | ||||||
|  |         return this._identifyingParams() | ||||||
|  |     } | ||||||
|  |     /** @ignore */ | ||||||
|  |     async _generate( | ||||||
|  |         messages: BaseMessage[], | ||||||
|  |         options: this["ParsedCallOptions"], | ||||||
|  |         runManager?: CallbackManagerForLLMRun | ||||||
|  |     ): Promise<ChatResult> { | ||||||
|  |         const tokenUsage: TokenUsage = {} | ||||||
|  |         const params = this.invocationParams(options) | ||||||
|  |         const messagesMapped: any[] = convertMessagesToOpenAIParams(messages) | ||||||
|  |         if (params.stream) { | ||||||
|  |             const stream = this._streamResponseChunks(messages, options, runManager) | ||||||
|  |             const finalChunks: Record<number, ChatGenerationChunk> = {} | ||||||
|  |             for await (const chunk of stream) { | ||||||
|  |                 //@ts-ignore
 | ||||||
|  |                 chunk.message.response_metadata = { | ||||||
|  |                     ...chunk.generationInfo, | ||||||
|  |                     //@ts-ignore
 | ||||||
|  |                     ...chunk.message.response_metadata | ||||||
|  |                 } | ||||||
|  |                 const index = chunk.generationInfo?.completion ?? 0 | ||||||
|  |                 if (finalChunks[index] === undefined) { | ||||||
|  |                     finalChunks[index] = chunk | ||||||
|  |                 } else { | ||||||
|  |                     finalChunks[index] = finalChunks[index].concat(chunk) | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             const generations = Object.entries(finalChunks) | ||||||
|  |                 .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) | ||||||
|  |                 .map(([_, value]) => value) | ||||||
|  |             const { functions, function_call } = this.invocationParams(options) | ||||||
|  |             // OpenAI does not support token usage report under stream mode,
 | ||||||
|  |             // fallback to estimation.
 | ||||||
|  |             const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt( | ||||||
|  |                 messages, | ||||||
|  |                 functions, | ||||||
|  |                 function_call | ||||||
|  |             ) | ||||||
|  |             const completionTokenUsage = | ||||||
|  |                 await this.getNumTokensFromGenerations(generations) | ||||||
|  |             tokenUsage.promptTokens = promptTokenUsage | ||||||
|  |             tokenUsage.completionTokens = completionTokenUsage | ||||||
|  |             tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage | ||||||
|  |             return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } } | ||||||
|  |         } else { | ||||||
|  |             const data = await this.completionWithRetry( | ||||||
|  |                 { | ||||||
|  |                     ...params, | ||||||
|  |                     //@ts-ignore
 | ||||||
|  |                     stream: false, | ||||||
|  |                     messages: messagesMapped | ||||||
|  |                 }, | ||||||
|  |                 { | ||||||
|  |                     signal: options?.signal, | ||||||
|  |                     //@ts-ignore
 | ||||||
|  |                     ...options?.options | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |             const { | ||||||
|  |                 completion_tokens: completionTokens, | ||||||
|  |                 prompt_tokens: promptTokens, | ||||||
|  |                 total_tokens: totalTokens | ||||||
|  |                 //@ts-ignore
 | ||||||
|  |             } = data?.usage ?? {} | ||||||
|  |             if (completionTokens) { | ||||||
|  |                 tokenUsage.completionTokens = | ||||||
|  |                     (tokenUsage.completionTokens ?? 0) + completionTokens | ||||||
|  |             } | ||||||
|  |             if (promptTokens) { | ||||||
|  |                 tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens | ||||||
|  |             } | ||||||
|  |             if (totalTokens) { | ||||||
|  |                 tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens | ||||||
|  |             } | ||||||
|  |             const generations = [] | ||||||
|  |             //@ts-ignore
 | ||||||
|  |             for (const part of data?.choices ?? []) { | ||||||
|  |                 const text = part.message?.content ?? "" | ||||||
|  |                 const generation = { | ||||||
|  |                     text, | ||||||
|  |                     message: openAIResponseToChatMessage( | ||||||
|  |                         part.message ?? { role: "assistant" } | ||||||
|  |                     ) | ||||||
|  |                 } | ||||||
|  |                 //@ts-ignore
 | ||||||
|  |                 generation.generationInfo = { | ||||||
|  |                     ...(part.finish_reason ? { finish_reason: part.finish_reason } : {}), | ||||||
|  |                     ...(part.logprobs ? { logprobs: part.logprobs } : {}) | ||||||
|  |                 } | ||||||
|  |                 generations.push(generation) | ||||||
|  |             } | ||||||
|  |             return { | ||||||
|  |                 generations, | ||||||
|  |                 llmOutput: { tokenUsage } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     /** | ||||||
|  |      * Estimate the number of tokens a prompt will use. | ||||||
|  |      * Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts
 | ||||||
|  |      */ | ||||||
|  |     async getEstimatedTokenCountFromPrompt(messages, functions, function_call) { | ||||||
|  |         let tokens = (await this.getNumTokensFromMessages(messages)).totalCount | ||||||
|  |         if (functions && messages.find((m) => m._getType() === "system")) { | ||||||
|  |             tokens -= 4 | ||||||
|  |         } | ||||||
|  |         if (function_call === "none") { | ||||||
|  |             tokens += 1 | ||||||
|  |         } else if (typeof function_call === "object") { | ||||||
|  |             tokens += (await this.getNumTokens(function_call.name)) + 4 | ||||||
|  |         } | ||||||
|  |         return tokens | ||||||
|  |     } | ||||||
|  |     /** | ||||||
|  |      * Estimate the number of tokens an array of generations have used. | ||||||
|  |      */ | ||||||
|  |     async getNumTokensFromGenerations(generations) { | ||||||
|  |         const generationUsages = await Promise.all( | ||||||
|  |             generations.map(async (generation) => { | ||||||
|  |                 if (generation.message.additional_kwargs?.function_call) { | ||||||
|  |                     return (await this.getNumTokensFromMessages([generation.message])) | ||||||
|  |                         .countPerMessage[0] | ||||||
|  |                 } else { | ||||||
|  |                     return await this.getNumTokens(generation.message.content) | ||||||
|  |                 } | ||||||
|  |             }) | ||||||
|  |         ) | ||||||
|  |         return generationUsages.reduce((a, b) => a + b, 0) | ||||||
|  |     } | ||||||
|  |     async getNumTokensFromMessages(messages) { | ||||||
|  |         let totalCount = 0 | ||||||
|  |         let tokensPerMessage = 0 | ||||||
|  |         let tokensPerName = 0 | ||||||
|  |         // From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
 | ||||||
|  |         if (this.modelName === "gpt-3.5-turbo-0301") { | ||||||
|  |             tokensPerMessage = 4 | ||||||
|  |             tokensPerName = -1 | ||||||
|  |         } else { | ||||||
|  |             tokensPerMessage = 3 | ||||||
|  |             tokensPerName = 1 | ||||||
|  |         } | ||||||
|  |         const countPerMessage = await Promise.all( | ||||||
|  |             messages.map(async (message) => { | ||||||
|  |                 const textCount = await this.getNumTokens(message.content) | ||||||
|  |                 const roleCount = await this.getNumTokens(messageToOpenAIRole(message)) | ||||||
|  |                 const nameCount = | ||||||
|  |                     message.name !== undefined | ||||||
|  |                         ? tokensPerName + (await this.getNumTokens(message.name)) | ||||||
|  |                         : 0 | ||||||
|  |                 let count = textCount + tokensPerMessage + roleCount + nameCount | ||||||
|  |                 // From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate
 | ||||||
|  |                 const openAIMessage = message | ||||||
|  |                 if (openAIMessage._getType() === "function") { | ||||||
|  |                     count -= 2 | ||||||
|  |                 } | ||||||
|  |                 if (openAIMessage.additional_kwargs?.function_call) { | ||||||
|  |                     count += 3 | ||||||
|  |                 } | ||||||
|  |                 if (openAIMessage?.additional_kwargs.function_call?.name) { | ||||||
|  |                     count += await this.getNumTokens( | ||||||
|  |                         openAIMessage.additional_kwargs.function_call?.name | ||||||
|  |                     ) | ||||||
|  |                 } | ||||||
|  |                 if (openAIMessage.additional_kwargs?.function_call?.arguments) { | ||||||
|  |                     try { | ||||||
|  |                         count += await this.getNumTokens( | ||||||
|  |                             // Remove newlines and spaces | ||||||
|  |                             JSON.stringify( | ||||||
|  |                                 JSON.parse( | ||||||
|  |                                     openAIMessage.additional_kwargs.function_call?.arguments | ||||||
|  |                                 ) | ||||||
|  |                             ) | ||||||
|  |                         ) | ||||||
|  |                     } catch (error) { | ||||||
|  |                         console.error( | ||||||
|  |                             "Error parsing function arguments", | ||||||
|  |                             error, | ||||||
|  |                             JSON.stringify(openAIMessage.additional_kwargs.function_call) | ||||||
|  |                         ) | ||||||
|  |                         count += await this.getNumTokens( | ||||||
|  |                             openAIMessage.additional_kwargs.function_call?.arguments | ||||||
|  |                         ) | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 totalCount += count | ||||||
|  |                 return count | ||||||
|  |             }) | ||||||
|  |         ) | ||||||
|  |         totalCount += 3 // every reply is primed with <|start|>assistant<|message|> | ||||||
|  |         return { totalCount, countPerMessage } | ||||||
|  |     } | ||||||
|  |     async completionWithRetry( | ||||||
|  |         request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming, | ||||||
|  |         options?: OpenAICoreRequestOptions | ||||||
|  |     ) { | ||||||
|  |         const requestOptions = this._getClientOptions(options) | ||||||
|  |         return this.caller.call(async () => { | ||||||
|  |             try { | ||||||
|  |                 const res = await this.client.chat.completions.create( | ||||||
|  |                     request, | ||||||
|  |                     requestOptions | ||||||
|  |                 ) | ||||||
|  |                 return res | ||||||
|  |             } catch (e) { | ||||||
|  |                 const error = wrapOpenAIClientError(e) | ||||||
|  |                 throw error | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  |     _getClientOptions(options) { | ||||||
|  |         if (!this.client) { | ||||||
|  |             const openAIEndpointConfig = { | ||||||
|  |                 azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, | ||||||
|  |                 azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, | ||||||
|  |                 azureOpenAIApiKey: this.azureOpenAIApiKey, | ||||||
|  |                 azureOpenAIBasePath: this.azureOpenAIBasePath, | ||||||
|  |                 baseURL: this.clientConfig.baseURL | ||||||
|  |             } | ||||||
|  |             const endpoint = getEndpoint(openAIEndpointConfig) | ||||||
|  |             const params = { | ||||||
|  |                 ...this.clientConfig, | ||||||
|  |                 baseURL: endpoint, | ||||||
|  |                 timeout: this.timeout, | ||||||
|  |                 maxRetries: 0 | ||||||
|  |             } | ||||||
|  |             if (!params.baseURL) { | ||||||
|  |                 delete params.baseURL | ||||||
|  |             } | ||||||
|  |             this.client = new OpenAIClient(params) | ||||||
|  |         } | ||||||
|  |         const requestOptions = { | ||||||
|  |             ...this.clientConfig, | ||||||
|  |             ...options | ||||||
|  |         } | ||||||
|  |         if (this.azureOpenAIApiKey) { | ||||||
|  |             requestOptions.headers = { | ||||||
|  |                 "api-key": this.azureOpenAIApiKey, | ||||||
|  |                 ...requestOptions.headers | ||||||
|  |             } | ||||||
|  |             requestOptions.query = { | ||||||
|  |                 "api-version": this.azureOpenAIApiVersion, | ||||||
|  |                 ...requestOptions.query | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         return requestOptions | ||||||
|  |     } | ||||||
|  |     _llmType() { | ||||||
|  |         return "openai" | ||||||
|  |     } | ||||||
|  |     /** @ignore */ | ||||||
|  |     _combineLLMOutput(...llmOutputs) { | ||||||
|  |         return llmOutputs.reduce( | ||||||
|  |             (acc, llmOutput) => { | ||||||
|  |                 if (llmOutput && llmOutput.tokenUsage) { | ||||||
|  |                     acc.tokenUsage.completionTokens += | ||||||
|  |                         llmOutput.tokenUsage.completionTokens ?? 0 | ||||||
|  |                     acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0 | ||||||
|  |                     acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 0 | ||||||
|  |                 } | ||||||
|  |                 return acc | ||||||
|  |             }, | ||||||
|  |             { | ||||||
|  |                 tokenUsage: { | ||||||
|  |                     completionTokens: 0, | ||||||
|  |                     promptTokens: 0, | ||||||
|  |                     totalTokens: 0 | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  |     withStructuredOutput(outputSchema, config) { | ||||||
|  |         // eslint-disable-next-line @typescript-eslint/no-explicit-any | ||||||
|  |         let schema | ||||||
|  |         let name | ||||||
|  |         let method | ||||||
|  |         let includeRaw | ||||||
|  |         if (isStructuredOutputMethodParams(outputSchema)) { | ||||||
|  |             schema = outputSchema.schema | ||||||
|  |             name = outputSchema.name | ||||||
|  |             method = outputSchema.method | ||||||
|  |             includeRaw = outputSchema.includeRaw | ||||||
|  |         } else { | ||||||
|  |             schema = outputSchema | ||||||
|  |             name = config?.name | ||||||
|  |             method = config?.method | ||||||
|  |             includeRaw = config?.includeRaw | ||||||
|  |         } | ||||||
|  |         let llm | ||||||
|  |         let outputParser | ||||||
|  |         if (method === "jsonMode") { | ||||||
|  |             llm = this.bind({}) | ||||||
|  |             if (isZodSchema(schema)) { | ||||||
|  |                 outputParser = StructuredOutputParser.fromZodSchema(schema) | ||||||
|  |             } else { | ||||||
|  |                 outputParser = new JsonOutputParser() | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             let functionName = name ?? "extract" | ||||||
|  |             // Is function calling | ||||||
|  | 
 | ||||||
|  |             let openAIFunctionDefinition | ||||||
|  |             if ( | ||||||
|  |                 typeof schema.name === "string" && | ||||||
|  |                 typeof schema.parameters === "object" && | ||||||
|  |                 schema.parameters != null | ||||||
|  |             ) { | ||||||
|  |                 openAIFunctionDefinition = schema | ||||||
|  |                 functionName = schema.name | ||||||
|  |             } else { | ||||||
|  |                 openAIFunctionDefinition = { | ||||||
|  |                     name: schema.title ?? functionName, | ||||||
|  |                     description: schema.description ?? "", | ||||||
|  |                     parameters: schema | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             llm = this.bind({}) | ||||||
|  |             outputParser = new JsonOutputKeyToolsParser({ | ||||||
|  |                 returnSingle: true, | ||||||
|  |                 keyName: functionName | ||||||
|  |             }) | ||||||
|  |         } | ||||||
|  |         if (!includeRaw) { | ||||||
|  |             return llm.pipe(outputParser) | ||||||
|  |         } | ||||||
|  |         const parserAssign = RunnablePassthrough.assign({ | ||||||
|  |             // eslint-disable-next-line @typescript-eslint/no-explicit-any | ||||||
|  |             parsed: (input, config) => outputParser.invoke(input.raw, config) | ||||||
|  |         }) | ||||||
|  |         const parserNone = RunnablePassthrough.assign({ | ||||||
|  |             parsed: () => null | ||||||
|  |         }) | ||||||
|  |         const parsedWithFallback = parserAssign.withFallbacks({ | ||||||
|  |             fallbacks: [parserNone] | ||||||
|  |         }) | ||||||
|  |         return RunnableSequence.from([ | ||||||
|  |             { | ||||||
|  |                 raw: llm | ||||||
|  |             }, | ||||||
|  |             parsedWithFallback | ||||||
|  |         ] as any) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | function isZodSchema( | ||||||
|  |     // eslint-disable-next-line @typescript-eslint/no-explicit-any | ||||||
|  |     input | ||||||
|  | ) { | ||||||
|  |     // Check for a characteristic method of Zod schemas | ||||||
|  |     return typeof input?.parse === "function" | ||||||
|  | } | ||||||
|  | function isStructuredOutputMethodParams( | ||||||
|  |     x | ||||||
|  |     // eslint-disable-next-line @typescript-eslint/no-explicit-any | ||||||
|  | ) { | ||||||
|  |     return ( | ||||||
|  |         x !== undefined && | ||||||
|  |         // eslint-disable-next-line @typescript-eslint/no-explicit-any | ||||||
|  |         typeof x.schema === "object" | ||||||
|  |     ) | ||||||
|  | } | ||||||
| @ -2,9 +2,9 @@ import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models" | |||||||
| import { ChatChromeAI } from "./ChatChromeAi" | import { ChatChromeAI } from "./ChatChromeAi" | ||||||
| import { ChatOllama } from "./ChatOllama" | import { ChatOllama } from "./ChatOllama" | ||||||
| import { getOpenAIConfigById } from "@/db/openai" | import { getOpenAIConfigById } from "@/db/openai" | ||||||
| import { ChatOpenAI } from "@langchain/openai" |  | ||||||
| import { urlRewriteRuntime } from "@/libs/runtime" | import { urlRewriteRuntime } from "@/libs/runtime" | ||||||
| import { ChatGoogleAI } from "./ChatGoogleAI" | import { ChatGoogleAI } from "./ChatGoogleAI" | ||||||
|  | import { CustomChatOpenAI } from "./CustomChatOpenAI" | ||||||
| 
 | 
 | ||||||
| export const pageAssistModel = async ({ | export const pageAssistModel = async ({ | ||||||
|   model, |   model, | ||||||
| @ -76,7 +76,7 @@ export const pageAssistModel = async ({ | |||||||
|       }) as any |       }) as any | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return new ChatOpenAI({ |     return new CustomChatOpenAI({ | ||||||
|       modelName: modelInfo.model_id, |       modelName: modelInfo.model_id, | ||||||
|       openAIApiKey: providerInfo.apiKey || "temp", |       openAIApiKey: providerInfo.apiKey || "temp", | ||||||
|       temperature, |       temperature, | ||||||
|  | |||||||
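With this hunk, OpenAI-compatible providers are built via CustomChatOpenAI instead of the stock ChatOpenAI. A rough usage sketch, assuming the providerInfo/modelInfo shapes from the db lookups earlier in this file and a baseURL field on the provider config (both assumptions, as the diff is truncated here):

import { CustomChatOpenAI } from "./CustomChatOpenAI"

// providerInfo and modelInfo are assumed to come from getOpenAIConfigById /
// getModelInfo as in the surrounding file; the baseURL field name is an assumption.
const chat = new CustomChatOpenAI({
    modelName: modelInfo.model_id,
    openAIApiKey: providerInfo.apiKey || "temp",
    temperature,
    configuration: {
        baseURL: providerInfo.baseUrl
    }
}) as any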