feat: implement CustomAIMessageChunk class and enhance reasoning handling in useMessageOption hook
This commit is contained in:
parent 8381f1c996
commit 926f4e1a4a
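The core change, repeated across the streaming loops below: when a provider returns its chain-of-thought via `additional_kwargs.reasoning_content` instead of inline `<think>` tags, the hooks fold it into the visible text under a `<think>` tag and close the tag once normal content resumes. A minimal sketch of that pattern (illustrative only; `chunks`, `fullText` and `contentToSave` stand in for the hooks' local state, and `mergeReasoningContent` comes from "@/libs/reasoning"):

// Sketch of the reasoning-aware streaming loop added in this commit.
let fullText = ""
let contentToSave = ""
let apiReasoning = false

for await (const chunk of chunks) {
  if (chunk?.additional_kwargs?.reasoning_content) {
    // Provider-side reasoning: wrap/extend it inside a <think> block.
    const merged = mergeReasoningContent(
      fullText,
      chunk?.additional_kwargs?.reasoning_content || ""
    )
    fullText = merged
    contentToSave = merged
    apiReasoning = true
  } else if (apiReasoning) {
    // First normal chunk after reasoning: close the <think> block once.
    fullText += "</think>"
    contentToSave += "</think>"
    apiReasoning = false
  }
  // Regular content is appended for every chunk (it is empty on reasoning-only deltas).
  fullText += chunk?.content
  contentToSave += chunk?.content
}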
@@ -206,7 +206,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
return (
<div className="flex w-full flex-col items-center p-2 pt-1 pb-4">
<form className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
<div className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
<div className="relative flex w-full flex-row justify-center gap-2 lg:w-4/5">
<div
className={` bg-neutral-50 dark:bg-[#262626] relative w-full max-w-[48rem] p-1 backdrop-blur-lg duration-100 border border-gray-300 rounded-xl dark:border-gray-600
@@ -449,7 +449,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
</div>
</div>
</div>
</form>
</div>
</div>
)
}
@@ -36,7 +36,12 @@ import { humanMessageFormatter } from "@/utils/human-message"
import { pageAssistEmbeddingModel } from "@/models/embedding"
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
import { getScreenshotFromCurrentTab } from "@/libs/get-screenshot"
import { isReasoningEnded, isReasoningStarted, removeReasoning } from "@/libs/reasoning"
import {
isReasoningEnded,
isReasoningStarted,
mergeReasoningContent,
removeReasoning
} from "@/libs/reasoning"

export const useMessage = () => {
const {
@@ -413,7 +418,24 @@ export const useMessage = () => {
let reasoningStartTime: Date | null = null
let reasoningEndTime: Date | null = null
let timetaken = 0
let apiReasoning = false
for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -618,7 +640,7 @@ export const useMessage = () => {
const applicationChatHistory = []

const data = await getScreenshotFromCurrentTab()

const visionImage = data?.screenshot || ""

if (visionImage === "") {
@@ -680,7 +702,24 @@ export const useMessage = () => {
let reasoningStartTime: Date | undefined = undefined
let reasoningEndTime: Date | undefined = undefined
let timetaken = 0
let apiReasoning = false
for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -950,8 +989,25 @@ export const useMessage = () => {
let reasoningStartTime: Date | null = null
let reasoningEndTime: Date | null = null
let timetaken = 0
let apiReasoning = false

for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -1279,7 +1335,24 @@ export const useMessage = () => {
let timetaken = 0
let reasoningStartTime: Date | undefined = undefined
let reasoningEndTime: Date | undefined = undefined
let apiReasoning = false
for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -1527,7 +1600,24 @@ export const useMessage = () => {
let reasoningStartTime: Date | null = null
let reasoningEndTime: Date | null = null
let timetaken = 0
let apiReasoning = false
for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -332,7 +332,24 @@ export const useMessageOption = () => {
let count = 0
let reasoningStartTime: Date | undefined = undefined
let reasoningEndTime: Date | undefined = undefined
let apiReasoning = false
for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -649,19 +666,27 @@ export const useMessageOption = () => {
let count = 0
let reasoningStartTime: Date | null = null
let reasoningEndTime: Date | null = null
let apiReasoning: boolean = false

for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
// console.log(chunk)
// if (chunk?.reasoning_content) {
//   const reasoningContent = mergeReasoningContent(
//     fullText,
//     chunk?.reasoning_content || ""
//   )
//   contentToSave += reasoningContent
//   fullText += reasoningContent
// }

if (isReasoningStarted(fullText) && !reasoningStartTime) {
reasoningStartTime = new Date()
@@ -992,8 +1017,25 @@ export const useMessageOption = () => {
let count = 0
let reasoningStartTime: Date | undefined = undefined
let reasoningEndTime: Date | undefined = undefined
let apiReasoning = false

for await (const chunk of chunks) {
if (chunk?.additional_kwargs?.reasoning_content) {
const reasoningContent = mergeReasoningContent(
fullText,
chunk?.additional_kwargs?.reasoning_content || ""
)
contentToSave = reasoningContent
fullText = reasoningContent
apiReasoning = true
} else {
if (apiReasoning) {
fullText += "</think>"
contentToSave += "</think>"
apiReasoning = false
}
}

contentToSave += chunk?.content
fullText += chunk?.content
if (count === 0) {
@@ -25,7 +25,7 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
clearTimeout(timeoutId)

// if Google API fails to return models, try another approach
if (res.status === 401 && res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
if (res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
const urlGoogle = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`
const resGoogle = await fetch(urlGoogle, {
signal: controller.signal
@@ -1,109 +1,100 @@
const tags = ["think", "reason", "reasoning", "thought"]
export function parseReasoning(
text: string
): {
type: "reasoning" | "text"
content: string
reasoning_running?: boolean
export function parseReasoning(text: string): {
type: "reasoning" | "text"
content: string
reasoning_running?: boolean
}[] {
try {
const result: {
type: "reasoning" | "text"
content: string
reasoning_running?: boolean
}[] = []
const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
try {
const result: {
type: "reasoning" | "text"
content: string
reasoning_running?: boolean
}[] = []
const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")

let currentIndex = 0
let isReasoning = false
let currentIndex = 0
let isReasoning = false

while (currentIndex < text.length) {
const openTagMatch = text.slice(currentIndex).match(tagPattern)
const closeTagMatch = text.slice(currentIndex).match(closeTagPattern)
while (currentIndex < text.length) {
const openTagMatch = text.slice(currentIndex).match(tagPattern)
const closeTagMatch = text.slice(currentIndex).match(closeTagPattern)

if (!isReasoning && openTagMatch) {
const beforeText = text.slice(
currentIndex,
currentIndex + openTagMatch.index
)
if (beforeText.trim()) {
result.push({ type: "text", content: beforeText.trim() })
}

isReasoning = true
currentIndex += openTagMatch.index! + openTagMatch[0].length
continue
}

if (isReasoning && closeTagMatch) {
const reasoningContent = text.slice(
currentIndex,
currentIndex + closeTagMatch.index
)
if (reasoningContent.trim()) {
result.push({ type: "reasoning", content: reasoningContent.trim() })
}

isReasoning = false
currentIndex += closeTagMatch.index! + closeTagMatch[0].length
continue
}

if (currentIndex < text.length) {
const remainingText = text.slice(currentIndex)
result.push({
type: isReasoning ? "reasoning" : "text",
content: remainingText.trim(),
reasoning_running: isReasoning
})
break
}
if (!isReasoning && openTagMatch) {
const beforeText = text.slice(
currentIndex,
currentIndex + openTagMatch.index
)
if (beforeText.trim()) {
result.push({ type: "text", content: beforeText.trim() })
}

return result
} catch (e) {
console.error(`Error parsing reasoning: ${e}`)
return [
{
type: "text",
content: text
}
]
isReasoning = true
currentIndex += openTagMatch.index! + openTagMatch[0].length
continue
}

if (isReasoning && closeTagMatch) {
const reasoningContent = text.slice(
currentIndex,
currentIndex + closeTagMatch.index
)
if (reasoningContent.trim()) {
result.push({ type: "reasoning", content: reasoningContent.trim() })
}

isReasoning = false
currentIndex += closeTagMatch.index! + closeTagMatch[0].length
continue
}

if (currentIndex < text.length) {
const remainingText = text.slice(currentIndex)
result.push({
type: isReasoning ? "reasoning" : "text",
content: remainingText.trim(),
reasoning_running: isReasoning
})
break
}
}

return result
} catch (e) {
console.error(`Error parsing reasoning: ${e}`)
return [
{
type: "text",
content: text
}
]
}
}

export function isReasoningStarted(text: string): boolean {
const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
return tagPattern.test(text)
const tagPattern = new RegExp(`<(${tags.join("|")})>`, "i")
return tagPattern.test(text)
}

export function isReasoningEnded(text: string): boolean {
const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
return closeTagPattern.test(text)
const closeTagPattern = new RegExp(`</(${tags.join("|")})>`, "i")
return closeTagPattern.test(text)
}

export function removeReasoning(text: string): string {
const tagPattern = new RegExp(
`<(${tags.join("|")})>.*?</(${tags.join("|")})>`,
"gis"
)
return text.replace(tagPattern, "").trim()
const tagPattern = new RegExp(
`<(${tags.join("|")})>.*?</(${tags.join("|")})>`,
"gis"
)
return text.replace(tagPattern, "").trim()
}
export function mergeReasoningContent(originalText: string, reasoning: string): string {
const defaultReasoningTag = "think"
const tagPattern = new RegExp(`<(${tags.join("|")})>(.*?)</(${tags.join("|")})>`, "is")
const hasReasoningTag = tagPattern.test(originalText)
export function mergeReasoningContent(
originalText: string,
reasoning: string
): string {
const reasoningTag = "<think>"

if (hasReasoningTag) {
const match = originalText.match(tagPattern)
if (match) {
const [fullMatch, tag, existingContent] = match
const remainingText = originalText.replace(fullMatch, '').trim()
const newContent = `${existingContent.trim()}${reasoning}`
return `<${tag}>${newContent}</${tag}> ${remainingText}`
}
}
originalText = originalText.replace(reasoningTag, "")

return `<${defaultReasoningTag}>${reasoning}</${defaultReasoningTag}> ${originalText.trim()}`.trim()
}
return `${reasoningTag}${originalText + reasoning}`.trim()
}
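For context, a small sketch (not part of the diff) of how the reworked helpers are expected to compose, assuming the simplified mergeReasoningContent variant above that strips and re-applies a single opening <think> tag:

// Illustrative only.
let fullText = ""
fullText = mergeReasoningContent(fullText, "First,")       // "<think>First,"
fullText = mergeReasoningContent(fullText, " check X.")    // "<think>First, check X."
// When regular content starts streaming, the caller closes the tag and appends it:
fullText += "</think>" + "The answer is 42."
parseReasoning(fullText)
// => [ { type: "reasoning", content: "First, check X." },
//      { type: "text", content: "The answer is 42.", reasoning_running: false } ]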
78  src/models/CustomAIMessageChunk.ts  Normal file
@@ -0,0 +1,78 @@
interface BaseMessageFields {
content: string;
name?: string;
additional_kwargs?: {
[key: string]: unknown;
};
}

export class CustomAIMessageChunk {
/** The text of the message. */
content: string;

/** The name of the message sender in a multi-user chat. */
name?: string;

/** Additional keyword arguments */
additional_kwargs: NonNullable<BaseMessageFields["additional_kwargs"]>;

constructor(fields: BaseMessageFields) {
// Make sure the default value for additional_kwargs is passed into super() for serialization
if (!fields.additional_kwargs) {
// eslint-disable-next-line no-param-reassign
fields.additional_kwargs = {};
}

this.name = fields.name;
this.content = fields.content;
this.additional_kwargs = fields.additional_kwargs;
}

static _mergeAdditionalKwargs(
left: NonNullable<BaseMessageFields["additional_kwargs"]>,
right: NonNullable<BaseMessageFields["additional_kwargs"]>
): NonNullable<BaseMessageFields["additional_kwargs"]> {
const merged = { ...left };
for (const [key, value] of Object.entries(right)) {
if (merged[key] === undefined) {
merged[key] = value;
} else if (typeof merged[key] === "string") {
merged[key] = (merged[key] as string) + value;
} else if (
!Array.isArray(merged[key]) &&
typeof merged[key] === "object"
) {
merged[key] = this._mergeAdditionalKwargs(
merged[key] as NonNullable<BaseMessageFields["additional_kwargs"]>,
value as NonNullable<BaseMessageFields["additional_kwargs"]>
);
} else {
throw new Error(
`additional_kwargs[${key}] already exists in this message chunk.`
);
}
}
return merged;
}

concat(chunk: CustomAIMessageChunk) {
return new CustomAIMessageChunk({
content: this.content + chunk.content,
additional_kwargs: CustomAIMessageChunk._mergeAdditionalKwargs(
this.additional_kwargs,
chunk.additional_kwargs
),
});
}
}

function isAiMessageChunkFields(value: unknown): value is BaseMessageFields {
if (typeof value !== "object" || value == null) return false;
return "content" in value && typeof value["content"] === "string";
}

function isAiMessageChunkFieldsList(
value: unknown[]
): value is BaseMessageFields[] {
return value.length > 0 && value.every((x) => isAiMessageChunkFields(x));
}
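A brief sketch (not part of the commit) of how two streamed deltas combine with the class above; string kwargs such as reasoning_content are concatenated by _mergeAdditionalKwargs:

// Illustrative only.
const a = new CustomAIMessageChunk({ content: "", additional_kwargs: { reasoning_content: "Let me " } })
const b = new CustomAIMessageChunk({ content: "Hi", additional_kwargs: { reasoning_content: "think." } })
const merged = a.concat(b)
// merged.content === "Hi"
// merged.additional_kwargs.reasoning_content === "Let me think."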
914  src/models/CustomChatOpenAI.ts  Normal file
@@ -0,0 +1,914 @@
import { type ClientOptions, OpenAI as OpenAIClient } from "openai"
import {
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
} from "@langchain/core/messages"
import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"
import { getEnvironmentVariable } from "@langchain/core/utils/env"
import { BaseChatModel, BaseChatModelParams } from "@langchain/core/language_models/chat_models"
import { convertToOpenAITool } from "@langchain/core/utils/function_calling"
import {
RunnablePassthrough,
RunnableSequence
} from "@langchain/core/runnables"
import {
JsonOutputParser,
StructuredOutputParser
} from "@langchain/core/output_parsers"
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"
import { wrapOpenAIClientError } from "./utils/openai.js"
import {
ChatOpenAICallOptions,
getEndpoint,
OpenAIChatInput,
OpenAICoreRequestOptions
} from "@langchain/openai"
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"
import { TokenUsage } from "@langchain/core/language_models/base"
import { LegacyOpenAIInput } from "./types.js"
import { CustomAIMessageChunk } from "./CustomAIMessageChunk.js"

type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool"

function extractGenericMessageCustomRole(message: ChatMessage) {
if (
message.role !== "system" &&
message.role !== "assistant" &&
message.role !== "user" &&
message.role !== "function" &&
message.role !== "tool"
) {
console.warn(`Unknown message role: ${message.role}`)
}
return message.role
}
export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
const type = message._getType()
switch (type) {
case "system":
return "system"
case "ai":
return "assistant"
case "human":
return "user"
case "function":
return "function"
case "tool":
return "tool"
case "generic": {
if (!ChatMessage.isInstance(message))
throw new Error("Invalid generic chat message")
return extractGenericMessageCustomRole(message) as OpenAIRoleEnum
}
default:
return type
}
}
function openAIResponseToChatMessage(
message: OpenAIClient.Chat.Completions.ChatCompletionMessage
) {
switch (message.role) {
case "assistant":
return new AIMessage(message.content || "", {
// function_call: message.function_call,
// tool_calls: message.tool_calls
// reasoning_content: message?.reasoning_content || null
})
default:
return new ChatMessage(message.content || "", message.role ?? "unknown")
}
}
function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>,
defaultRole?: OpenAIRoleEnum
) {
const role = delta.role ?? defaultRole
const content = delta.content ?? ""
const reasoning_content: string | null = delta.reasoning_content ?? null
let additional_kwargs
if (delta.function_call) {
additional_kwargs = {
function_call: delta.function_call
}
} else if (delta.tool_calls) {
additional_kwargs = {
tool_calls: delta.tool_calls
}
} else {
additional_kwargs = {}
}
if (role === "user") {
return new HumanMessageChunk({ content })
} else if (role === "assistant") {
return new CustomAIMessageChunk({
content,
additional_kwargs: {
...additional_kwargs,
reasoning_content
}
}) as any
} else if (role === "system") {
return new SystemMessageChunk({ content })
} else if (role === "function") {
return new FunctionMessageChunk({
content,
additional_kwargs,
name: delta.name
})
} else if (role === "tool") {
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: delta.tool_call_id
})
} else {
return new ChatMessageChunk({ content, role })
}
}
function convertMessagesToOpenAIParams(messages: any[]) {
// TODO: Function messages do not support array content, fix cast
return messages.map((message) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const completionParam: { role: string; content: string; name?: string } = {
role: messageToOpenAIRole(message),
content: message.content
}
if (message.name != null) {
completionParam.name = message.name
}

return completionParam
})
}
export class CustomChatOpenAI<
CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions
>
extends BaseChatModel<CallOptions>
implements OpenAIChatInput {
temperature = 1

topP = 1

frequencyPenalty = 0

presencePenalty = 0

n = 1

logitBias?: Record<string, number>

modelName = "gpt-3.5-turbo"

model = "gpt-3.5-turbo"

modelKwargs?: OpenAIChatInput["modelKwargs"]

stop?: string[]

stopSequences?: string[]

user?: string

timeout?: number

streaming = false

streamUsage = true

maxTokens?: number

logprobs?: boolean

topLogprobs?: number

openAIApiKey?: string

apiKey?: string

azureOpenAIApiVersion?: string

azureOpenAIApiKey?: string

azureADTokenProvider?: () => Promise<string>

azureOpenAIApiInstanceName?: string

azureOpenAIApiDeploymentName?: string

azureOpenAIBasePath?: string

organization?: string

protected client: OpenAIClient

protected clientConfig: ClientOptions
static lc_name() {
return "ChatOpenAI"
}
get callKeys() {
return [
...super.callKeys,
"options",
"function_call",
"functions",
"tools",
"tool_choice",
"promptIndex",
"response_format",
"seed"
]
}
get lc_secrets() {
return {
openAIApiKey: "OPENAI_API_KEY",
azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
organization: "OPENAI_ORGANIZATION"
}
}
get lc_aliases() {
return {
modelName: "model",
openAIApiKey: "openai_api_key",
azureOpenAIApiVersion: "azure_openai_api_version",
azureOpenAIApiKey: "azure_openai_api_key",
azureOpenAIApiInstanceName: "azure_openai_api_instance_name",
azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name"
}
}
constructor(
fields?: Partial<OpenAIChatInput> &
BaseChatModelParams & {
configuration?: ClientOptions & LegacyOpenAIInput;
},
/** @deprecated */
configuration?: ClientOptions & LegacyOpenAIInput
) {
super(fields ?? {})
Object.defineProperty(this, "lc_serializable", {
enumerable: true,
configurable: true,
writable: true,
value: true
})
Object.defineProperty(this, "temperature", {
enumerable: true,
configurable: true,
writable: true,
value: 1
})
Object.defineProperty(this, "topP", {
enumerable: true,
configurable: true,
writable: true,
value: 1
})
Object.defineProperty(this, "frequencyPenalty", {
enumerable: true,
configurable: true,
writable: true,
value: 0
})
Object.defineProperty(this, "presencePenalty", {
enumerable: true,
configurable: true,
writable: true,
value: 0
})
Object.defineProperty(this, "n", {
enumerable: true,
configurable: true,
writable: true,
value: 1
})
Object.defineProperty(this, "logitBias", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "modelName", {
enumerable: true,
configurable: true,
writable: true,
value: "gpt-3.5-turbo"
})
Object.defineProperty(this, "modelKwargs", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "stop", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "user", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "timeout", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "streaming", {
enumerable: true,
configurable: true,
writable: true,
value: false
})
Object.defineProperty(this, "maxTokens", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "logprobs", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "topLogprobs", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "openAIApiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "azureOpenAIApiVersion", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "azureOpenAIApiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "azureOpenAIApiInstanceName", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "azureOpenAIApiDeploymentName", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "azureOpenAIBasePath", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "organization", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "client", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
Object.defineProperty(this, "clientConfig", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
})
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY")

this.modelName = fields?.modelName ?? this.modelName
this.modelKwargs = fields?.modelKwargs ?? {}
this.timeout = fields?.timeout
this.temperature = fields?.temperature ?? this.temperature
this.topP = fields?.topP ?? this.topP
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty
this.maxTokens = fields?.maxTokens
this.logprobs = fields?.logprobs
this.topLogprobs = fields?.topLogprobs
this.n = fields?.n ?? this.n
this.logitBias = fields?.logitBias
this.stop = fields?.stop
this.user = fields?.user
this.streaming = fields?.streaming ?? false
this.clientConfig = {
apiKey: this.openAIApiKey,
organization: this.organization,
baseURL: configuration?.basePath ?? fields?.configuration?.basePath,
dangerouslyAllowBrowser: true,
defaultHeaders:
configuration?.baseOptions?.headers ??
fields?.configuration?.baseOptions?.headers,
defaultQuery:
configuration?.baseOptions?.params ??
fields?.configuration?.baseOptions?.params,
...configuration,
...fields?.configuration
}
}
/**
* Get the parameters used to invoke the model
*/
invocationParams(options) {
function isStructuredToolArray(tools) {
return (
tools !== undefined &&
tools.every((tool) => Array.isArray(tool.lc_namespace))
)
}
const params = {
model: this.modelName,
temperature: this.temperature,
top_p: this.topP,
frequency_penalty: this.frequencyPenalty,
presence_penalty: this.presencePenalty,
max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
logprobs: this.logprobs,
top_logprobs: this.topLogprobs,
n: this.n,
logit_bias: this.logitBias,
stop: options?.stop ?? this.stop,
user: this.user,
stream: this.streaming,
functions: options?.functions,
function_call: options?.function_call,
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(convertToOpenAITool)
: options?.tools,
tool_choice: options?.tool_choice,
response_format: options?.response_format,
seed: options?.seed,
...this.modelKwargs
}
return params
}
/** @ignore */
_identifyingParams() {
return {
model_name: this.modelName,
//@ts-ignore
...this?.invocationParams(),
...this.clientConfig
}
}
async *_streamResponseChunks(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const messagesMapped = convertMessagesToOpenAIParams(messages)
const params = {
...this.invocationParams(options),
messages: messagesMapped,
stream: true
}
let defaultRole
//@ts-ignore
const streamIterable = await this.completionWithRetry(params, options)
for await (const data of streamIterable) {
const choice = data?.choices[0]
if (!choice) {
continue
}
const { delta } = choice
if (!delta) {
continue
}
const chunk = _convertDeltaToMessageChunk(delta, defaultRole)
defaultRole = delta.role ?? defaultRole
const newTokenIndices = {
//@ts-ignore
prompt: options?.promptIndex ?? 0,
completion: choice.index ?? 0
}
if (typeof chunk.content !== "string") {
console.log(
"[WARNING]: Received non-string content from OpenAI. This is currently not supported."
)
continue
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const generationInfo = { ...newTokenIndices } as any
if (choice.finish_reason !== undefined) {
generationInfo.finish_reason = choice.finish_reason
}
if (this.logprobs) {
generationInfo.logprobs = choice.logprobs
}
const generationChunk = new ChatGenerationChunk({
message: chunk,
text: chunk.content,
generationInfo
})
yield generationChunk
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(
generationChunk.text ?? "",
newTokenIndices,
undefined,
undefined,
undefined,
{ chunk: generationChunk }
)
}
if (options.signal?.aborted) {
throw new Error("AbortError")
}
}
/**
* Get the identifying parameters for the model
*
*/
identifyingParams() {
return this._identifyingParams()
}
/** @ignore */
async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const tokenUsage: TokenUsage = {}
const params = this.invocationParams(options)
const messagesMapped: any[] = convertMessagesToOpenAIParams(messages)
if (params.stream) {
const stream = this._streamResponseChunks(messages, options, runManager)
const finalChunks: Record<number, ChatGenerationChunk> = {}
for await (const chunk of stream) {
//@ts-ignore
chunk.message.response_metadata = {
...chunk.generationInfo,
//@ts-ignore
...chunk.message.response_metadata
}
const index = chunk.generationInfo?.completion ?? 0
if (finalChunks[index] === undefined) {
finalChunks[index] = chunk
} else {
finalChunks[index] = finalChunks[index].concat(chunk)
}
}
const generations = Object.entries(finalChunks)
.sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
.map(([_, value]) => value)
const { functions, function_call } = this.invocationParams(options)
// OpenAI does not support token usage report under stream mode,
// fallback to estimation.
const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt(
messages,
functions,
function_call
)
const completionTokenUsage =
await this.getNumTokensFromGenerations(generations)
tokenUsage.promptTokens = promptTokenUsage
tokenUsage.completionTokens = completionTokenUsage
tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
} else {
const data = await this.completionWithRetry(
{
...params,
//@ts-ignore
stream: false,
messages: messagesMapped
},
{
signal: options?.signal,
//@ts-ignore
...options?.options
}
)
const {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens
//@ts-ignore
} = data?.usage ?? {}
if (completionTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens
}
if (promptTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens
}
if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens
}
const generations = []
//@ts-ignore
for (const part of data?.choices ?? []) {
const text = part.message?.content ?? ""
const generation = {
text,
message: openAIResponseToChatMessage(
part.message ?? { role: "assistant" }
)
}
//@ts-ignore
generation.generationInfo = {
...(part.finish_reason ? { finish_reason: part.finish_reason } : {}),
...(part.logprobs ? { logprobs: part.logprobs } : {})
}
generations.push(generation)
}
return {
generations,
llmOutput: { tokenUsage }
}
}
}
/**
* Estimate the number of tokens a prompt will use.
* Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts
*/
async getEstimatedTokenCountFromPrompt(messages, functions, function_call) {
let tokens = (await this.getNumTokensFromMessages(messages)).totalCount
if (functions && messages.find((m) => m._getType() === "system")) {
tokens -= 4
}
if (function_call === "none") {
tokens += 1
} else if (typeof function_call === "object") {
tokens += (await this.getNumTokens(function_call.name)) + 4
}
return tokens
}
/**
* Estimate the number of tokens an array of generations have used.
*/
async getNumTokensFromGenerations(generations) {
const generationUsages = await Promise.all(
generations.map(async (generation) => {
if (generation.message.additional_kwargs?.function_call) {
return (await this.getNumTokensFromMessages([generation.message]))
.countPerMessage[0]
} else {
return await this.getNumTokens(generation.message.content)
}
})
)
return generationUsages.reduce((a, b) => a + b, 0)
}
async getNumTokensFromMessages(messages) {
let totalCount = 0
let tokensPerMessage = 0
let tokensPerName = 0
// From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
if (this.modelName === "gpt-3.5-turbo-0301") {
tokensPerMessage = 4
tokensPerName = -1
} else {
tokensPerMessage = 3
tokensPerName = 1
}
const countPerMessage = await Promise.all(
messages.map(async (message) => {
const textCount = await this.getNumTokens(message.content)
const roleCount = await this.getNumTokens(messageToOpenAIRole(message))
const nameCount =
message.name !== undefined
? tokensPerName + (await this.getNumTokens(message.name))
: 0
let count = textCount + tokensPerMessage + roleCount + nameCount
// From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate
const openAIMessage = message
if (openAIMessage._getType() === "function") {
count -= 2
}
if (openAIMessage.additional_kwargs?.function_call) {
count += 3
}
if (openAIMessage?.additional_kwargs.function_call?.name) {
count += await this.getNumTokens(
openAIMessage.additional_kwargs.function_call?.name
)
}
if (openAIMessage.additional_kwargs.function_call?.arguments) {
try {
count += await this.getNumTokens(
// Remove newlines and spaces
JSON.stringify(
JSON.parse(
openAIMessage.additional_kwargs.function_call?.arguments
)
)
)
} catch (error) {
console.error(
"Error parsing function arguments",
error,
JSON.stringify(openAIMessage.additional_kwargs.function_call)
)
count += await this.getNumTokens(
openAIMessage.additional_kwargs.function_call?.arguments
)
}
}
totalCount += count
return count
})
)
totalCount += 3 // every reply is primed with <|start|>assistant<|message|>
return { totalCount, countPerMessage }
}
async completionWithRetry(
request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming,
options?: OpenAICoreRequestOptions
) {
const requestOptions = this._getClientOptions(options)
return this.caller.call(async () => {
try {
const res = await this.client.chat.completions.create(
request,
requestOptions
)
return res
} catch (e) {
const error = wrapOpenAIClientError(e)
throw error
}
})
}
_getClientOptions(options) {
if (!this.client) {
const openAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL
}
const endpoint = getEndpoint(openAIEndpointConfig)
const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0
}
if (!params.baseURL) {
delete params.baseURL
}
this.client = new OpenAIClient(params)
}
const requestOptions = {
...this.clientConfig,
...options
}
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers
}
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query
}
}
return requestOptions
}
_llmType() {
return "openai"
}
/** @ignore */
_combineLLMOutput(...llmOutputs) {
return llmOutputs.reduce(
(acc, llmOutput) => {
if (llmOutput && llmOutput.tokenUsage) {
acc.tokenUsage.completionTokens +=
llmOutput.tokenUsage.completionTokens ?? 0
acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0
acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 0
}
return acc
},
{
tokenUsage: {
completionTokens: 0,
promptTokens: 0,
totalTokens: 0
}
}
)
}
withStructuredOutput(outputSchema, config) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let schema
let name
let method
let includeRaw
if (isStructuredOutputMethodParams(outputSchema)) {
schema = outputSchema.schema
name = outputSchema.name
method = outputSchema.method
includeRaw = outputSchema.includeRaw
} else {
schema = outputSchema
name = config?.name
method = config?.method
includeRaw = config?.includeRaw
}
let llm
let outputParser
if (method === "jsonMode") {
llm = this.bind({
})
if (isZodSchema(schema)) {
outputParser = StructuredOutputParser.fromZodSchema(schema)
} else {
outputParser = new JsonOutputParser()
}
} else {
let functionName = name ?? "extract"
// Is function calling

let openAIFunctionDefinition
if (
typeof schema.name === "string" &&
typeof schema.parameters === "object" &&
schema.parameters != null
) {
openAIFunctionDefinition = schema
functionName = schema.name
} else {
openAIFunctionDefinition = {
name: schema.title ?? functionName,
description: schema.description ?? "",
parameters: schema
}
}
llm = this.bind({

})
outputParser = new JsonOutputKeyToolsParser({
returnSingle: true,
keyName: functionName
})
}
if (!includeRaw) {
return llm.pipe(outputParser)
}
const parserAssign = RunnablePassthrough.assign({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsed: (input, config) => outputParser.invoke(input.raw, config)
})
const parserNone = RunnablePassthrough.assign({
parsed: () => null
})
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone]
})
return RunnableSequence.from([
{
raw: llm
},
parsedWithFallback
] as any)
}
}
function isZodSchema(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
input
) {
// Check for a characteristic method of Zod schemas
return typeof input?.parse === "function"
}
function isStructuredOutputMethodParams(
x
// eslint-disable-next-line @typescript-eslint/no-explicit-any
) {
return (
x !== undefined &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
typeof x.schema === "object"
)
}
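A hedged usage sketch (not in the commit): how a caller might stream from CustomChatOpenAI and read the reasoning side-channel. The model name, API key, and base URL below are placeholders, and the exact shape of the streamed chunks depends on the provider.

// Illustrative only. Assumes an OpenAI-compatible endpoint that emits
// `reasoning_content` on streamed deltas (e.g. DeepSeek-style APIs).
const model = new CustomChatOpenAI({
  modelName: "some-reasoning-model",                        // placeholder
  openAIApiKey: "sk-...",                                   // placeholder
  streaming: true,
  configuration: { basePath: "https://example.com/v1" }     // placeholder
})

for await (const chunk of await model.stream("Why is the sky blue?")) {
  // Assistant deltas are built as CustomAIMessageChunk, so reasoning tokens
  // ride along in additional_kwargs rather than in `content`.
  const reasoning = (chunk as any)?.additional_kwargs?.reasoning_content
  if (reasoning) console.log("[reasoning]", reasoning)
  if (chunk.content) console.log(chunk.content)
}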
@@ -2,9 +2,9 @@ import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
import { ChatChromeAI } from "./ChatChromeAi"
import { ChatOllama } from "./ChatOllama"
import { getOpenAIConfigById } from "@/db/openai"
import { ChatOpenAI } from "@langchain/openai"
import { urlRewriteRuntime } from "@/libs/runtime"
import { ChatGoogleAI } from "./ChatGoogleAI"
import { CustomChatOpenAI } from "./CustomChatOpenAI"

export const pageAssistModel = async ({
model,
@@ -76,7 +76,7 @@ export const pageAssistModel = async ({
}) as any
}

return new ChatOpenAI({
return new CustomChatOpenAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp",
temperature,