feat: add CustomAIMessageChunk and CustomChatOpenAI models and enhance reasoning handling in the useMessage and useMessageOption hooks
Commit 926f4e1a4a (parent 8381f1c996)
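The recurring change in the hook hunks further down: when a streamed chunk carries `additional_kwargs.reasoning_content` (reasoning returned as a separate API field rather than inline `<think>` tags), it is merged into the accumulated text as a `<think>` block, and the block is closed as soon as normal content resumes. A condensed sketch of that loop body, using the same `chunks`, `fullText`, and `contentToSave` accumulators as the hooks:

let apiReasoning = false
for await (const chunk of chunks) {
  if (chunk?.additional_kwargs?.reasoning_content) {
    // Reasoning arrived as a separate delta field: fold it into a <think> block.
    const merged = mergeReasoningContent(
      fullText,
      chunk.additional_kwargs.reasoning_content || ""
    )
    contentToSave = merged
    fullText = merged
    apiReasoning = true
  } else if (apiReasoning) {
    // First non-reasoning chunk after reasoning: close the block.
    fullText += "</think>"
    contentToSave += "</think>"
    apiReasoning = false
  }
  contentToSave += chunk?.content
  fullText += chunk?.content
}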
@@ -206,7 +206,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
   return (
     <div className="flex w-full flex-col items-center p-2 pt-1 pb-4">
-      <form className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
+      <div className="relative z-10 flex w-full flex-col items-center justify-center gap-2 text-base">
         <div className="relative flex w-full flex-row justify-center gap-2 lg:w-4/5">
           <div
             className={` bg-neutral-50 dark:bg-[#262626] relative w-full max-w-[48rem] p-1 backdrop-blur-lg duration-100 border border-gray-300 rounded-xl dark:border-gray-600
@@ -449,7 +449,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => {
             </div>
           </div>
         </div>
-      </form>
+      </div>
     </div>
   )
 }
@@ -36,7 +36,12 @@ import { humanMessageFormatter } from "@/utils/human-message"
 import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
 import { getScreenshotFromCurrentTab } from "@/libs/get-screenshot"
-import { isReasoningEnded, isReasoningStarted, removeReasoning } from "@/libs/reasoning"
+import {
+  isReasoningEnded,
+  isReasoningStarted,
+  mergeReasoningContent,
+  removeReasoning
+} from "@/libs/reasoning"
 
 export const useMessage = () => {
   const {
@@ -413,7 +418,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -680,7 +702,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -950,8 +989,25 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
+
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1279,7 +1335,24 @@ export const useMessage = () => {
       let timetaken = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -1527,7 +1600,24 @@ export const useMessage = () => {
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
       let timetaken = 0
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -332,7 +332,24 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -649,19 +666,27 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | null = null
       let reasoningEndTime: Date | null = null
+      let apiReasoning: boolean = false
+
       for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
-        // console.log(chunk)
-        // if (chunk?.reasoning_content) {
-        //   const reasoningContent = mergeReasoningContent(
-        //     fullText,
-        //     chunk?.reasoning_content || ""
-        //   )
-        //   contentToSave += reasoningContent
-        //   fullText += reasoningContent
-        // }
-
         if (isReasoningStarted(fullText) && !reasoningStartTime) {
           reasoningStartTime = new Date()
@@ -992,8 +1017,25 @@ export const useMessageOption = () => {
       let count = 0
       let reasoningStartTime: Date | undefined = undefined
       let reasoningEndTime: Date | undefined = undefined
+      let apiReasoning = false
+
      for await (const chunk of chunks) {
+        if (chunk?.additional_kwargs?.reasoning_content) {
+          const reasoningContent = mergeReasoningContent(
+            fullText,
+            chunk?.additional_kwargs?.reasoning_content || ""
+          )
+          contentToSave = reasoningContent
+          fullText = reasoningContent
+          apiReasoning = true
+        } else {
+          if (apiReasoning) {
+            fullText += "</think>"
+            contentToSave += "</think>"
+            apiReasoning = false
+          }
+        }
+
         contentToSave += chunk?.content
         fullText += chunk?.content
         if (count === 0) {
@@ -25,7 +25,7 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
     clearTimeout(timeoutId)

     // if Google API fails to return models, try another approach
-    if (res.status === 401 && res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
+    if (res.url == 'https://generativelanguage.googleapis.com/v1beta/openai/models') {
       const urlGoogle = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`
       const resGoogle = await fetch(urlGoogle, {
         signal: controller.signal
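For orientation, a hedged sketch of how this code path is reached; the call shape comes from the hunk header above, while the base URL and environment variable are assumptions for illustration:

// Hypothetical call; getAllOpenAIModels(baseUrl, apiKey) is the function patched above.
const models = await getAllOpenAIModels(
  "https://generativelanguage.googleapis.com/v1beta/openai", // assumed OpenAI-compatible Gemini base URL
  process.env.GEMINI_API_KEY                                  // assumed key source
)
// With the 401 requirement dropped, the native /v1beta/models fallback now runs for any
// response that came from the Gemini OpenAI-compatible models endpoint.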
@@ -1,7 +1,5 @@
 const tags = ["think", "reason", "reasoning", "thought"]
-export function parseReasoning(
-  text: string
-): {
+export function parseReasoning(text: string): {
   type: "reasoning" | "text"
   content: string
   reasoning_running?: boolean
@@ -90,20 +88,13 @@ export function removeReasoning(text: string): string {
   )
   return text.replace(tagPattern, "").trim()
 }
-export function mergeReasoningContent(originalText: string, reasoning: string): string {
-  const defaultReasoningTag = "think"
-  const tagPattern = new RegExp(`<(${tags.join("|")})>(.*?)</(${tags.join("|")})>`, "is")
-  const hasReasoningTag = tagPattern.test(originalText)
-
-  if (hasReasoningTag) {
-    const match = originalText.match(tagPattern)
-    if (match) {
-      const [fullMatch, tag, existingContent] = match
-      const remainingText = originalText.replace(fullMatch, '').trim()
-      const newContent = `${existingContent.trim()}${reasoning}`
-      return `<${tag}>${newContent}</${tag}> ${remainingText}`
-    }
-  }
-
-  return `<${defaultReasoningTag}>${reasoning}</${defaultReasoningTag}> ${originalText.trim()}`.trim()
+export function mergeReasoningContent(
+  originalText: string,
+  reasoning: string
+): string {
+  const reasoningTag = "<think>"
+
+  originalText = originalText.replace(reasoningTag, "")
+
+  return `${reasoningTag}${originalText + reasoning}`.trim()
 }
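A quick check of the new behavior, worked out by hand from the function above: repeated merges keep a single leading <think> marker and append each reasoning delta at the end (the closing </think> is added later by the hooks once normal content resumes).

mergeReasoningContent("", "The user asks why…")
// => "<think>The user asks why…"
mergeReasoningContent("<think>The user asks why…", " Let me check the premise.")
// => "<think>The user asks why… Let me check the premise."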
src/models/CustomAIMessageChunk.ts (new file, 78 lines)
@@ -0,0 +1,78 @@
interface BaseMessageFields {
  content: string;
  name?: string;
  additional_kwargs?: {
    [key: string]: unknown;
  };
}

export class CustomAIMessageChunk {
  /** The text of the message. */
  content: string;

  /** The name of the message sender in a multi-user chat. */
  name?: string;

  /** Additional keyword arguments */
  additional_kwargs: NonNullable<BaseMessageFields["additional_kwargs"]>;

  constructor(fields: BaseMessageFields) {
    // Make sure the default value for additional_kwargs is passed into super() for serialization
    if (!fields.additional_kwargs) {
      // eslint-disable-next-line no-param-reassign
      fields.additional_kwargs = {};
    }

    this.name = fields.name;
    this.content = fields.content;
    this.additional_kwargs = fields.additional_kwargs;
  }

  static _mergeAdditionalKwargs(
    left: NonNullable<BaseMessageFields["additional_kwargs"]>,
    right: NonNullable<BaseMessageFields["additional_kwargs"]>
  ): NonNullable<BaseMessageFields["additional_kwargs"]> {
    const merged = { ...left };
    for (const [key, value] of Object.entries(right)) {
      if (merged[key] === undefined) {
        merged[key] = value;
      } else if (typeof merged[key] === "string") {
        merged[key] = (merged[key] as string) + value;
      } else if (
        !Array.isArray(merged[key]) &&
        typeof merged[key] === "object"
      ) {
        merged[key] = this._mergeAdditionalKwargs(
          merged[key] as NonNullable<BaseMessageFields["additional_kwargs"]>,
          value as NonNullable<BaseMessageFields["additional_kwargs"]>
        );
      } else {
        throw new Error(
          `additional_kwargs[${key}] already exists in this message chunk.`
        );
      }
    }
    return merged;
  }

  concat(chunk: CustomAIMessageChunk) {
    return new CustomAIMessageChunk({
      content: this.content + chunk.content,
      additional_kwargs: CustomAIMessageChunk._mergeAdditionalKwargs(
        this.additional_kwargs,
        chunk.additional_kwargs
      )
    });
  }
}

function isAiMessageChunkFields(value: unknown): value is BaseMessageFields {
  if (typeof value !== "object" || value == null) return false;
  return "content" in value && typeof value["content"] === "string";
}

function isAiMessageChunkFieldsList(
  value: unknown[]
): value is BaseMessageFields[] {
  return value.length > 0 && value.every((x) => isAiMessageChunkFields(x));
}
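Rough usage sketch (example values are illustrative): concat() merges string kwargs by concatenation, which is what lets reasoning_content deltas accumulate across streamed chunks.

const a = new CustomAIMessageChunk({ content: "", additional_kwargs: { reasoning_content: "Let me think" } })
const b = new CustomAIMessageChunk({ content: "Paris", additional_kwargs: { reasoning_content: " step by step." } })
const merged = a.concat(b)
// merged.content === "Paris"
// merged.additional_kwargs.reasoning_content === "Let me think step by step."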
src/models/CustomChatOpenAI.ts (new file, 914 lines)
@@ -0,0 +1,914 @@
import { type ClientOptions, OpenAI as OpenAIClient } from "openai"
import {
  AIMessage,
  AIMessageChunk,
  BaseMessage,
  ChatMessage,
  ChatMessageChunk,
  FunctionMessageChunk,
  HumanMessageChunk,
  SystemMessageChunk,
  ToolMessageChunk,
} from "@langchain/core/messages"
import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"
import { getEnvironmentVariable } from "@langchain/core/utils/env"
import { BaseChatModel, BaseChatModelParams } from "@langchain/core/language_models/chat_models"
import { convertToOpenAITool } from "@langchain/core/utils/function_calling"
import {
  RunnablePassthrough,
  RunnableSequence
} from "@langchain/core/runnables"
import {
  JsonOutputParser,
  StructuredOutputParser
} from "@langchain/core/output_parsers"
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"
import { wrapOpenAIClientError } from "./utils/openai.js"
import {
  ChatOpenAICallOptions,
  getEndpoint,
  OpenAIChatInput,
  OpenAICoreRequestOptions
} from "@langchain/openai"
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"
import { TokenUsage } from "@langchain/core/language_models/base"
import { LegacyOpenAIInput } from "./types.js"
import { CustomAIMessageChunk } from "./CustomAIMessageChunk.js"

type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool"

function extractGenericMessageCustomRole(message: ChatMessage) {
  if (
    message.role !== "system" &&
    message.role !== "assistant" &&
    message.role !== "user" &&
    message.role !== "function" &&
    message.role !== "tool"
  ) {
    console.warn(`Unknown message role: ${message.role}`)
  }
  return message.role
}

export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
  const type = message._getType()
  switch (type) {
    case "system":
      return "system"
    case "ai":
      return "assistant"
    case "human":
      return "user"
    case "function":
      return "function"
    case "tool":
      return "tool"
    case "generic": {
      if (!ChatMessage.isInstance(message))
        throw new Error("Invalid generic chat message")
      return extractGenericMessageCustomRole(message) as OpenAIRoleEnum
    }
    default:
      return type
  }
}

function openAIResponseToChatMessage(
  message: OpenAIClient.Chat.Completions.ChatCompletionMessage
) {
  switch (message.role) {
    case "assistant":
      return new AIMessage(message.content || "", {
        // function_call: message.function_call,
        // tool_calls: message.tool_calls
        // reasoning_content: message?.reasoning_content || null
      })
    default:
      return new ChatMessage(message.content || "", message.role ?? "unknown")
  }
}

function _convertDeltaToMessageChunk(
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  delta: Record<string, any>,
  defaultRole?: OpenAIRoleEnum
) {
  const role = delta.role ?? defaultRole
  const content = delta.content ?? ""
  const reasoning_content: string | null = delta.reasoning_content ?? null
  let additional_kwargs
  if (delta.function_call) {
    additional_kwargs = {
      function_call: delta.function_call
    }
  } else if (delta.tool_calls) {
    additional_kwargs = {
      tool_calls: delta.tool_calls
    }
  } else {
    additional_kwargs = {}
  }
  if (role === "user") {
    return new HumanMessageChunk({ content })
  } else if (role === "assistant") {
    return new CustomAIMessageChunk({
      content,
      additional_kwargs: {
        ...additional_kwargs,
        reasoning_content
      }
    }) as any
  } else if (role === "system") {
    return new SystemMessageChunk({ content })
  } else if (role === "function") {
    return new FunctionMessageChunk({
      content,
      additional_kwargs,
      name: delta.name
    })
  } else if (role === "tool") {
    return new ToolMessageChunk({
      content,
      additional_kwargs,
      tool_call_id: delta.tool_call_id
    })
  } else {
    return new ChatMessageChunk({ content, role })
  }
}

function convertMessagesToOpenAIParams(messages: any[]) {
  // TODO: Function messages do not support array content, fix cast
  return messages.map((message) => {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const completionParam: { role: string; content: string; name?: string } = {
      role: messageToOpenAIRole(message),
      content: message.content
    }
    if (message.name != null) {
      completionParam.name = message.name
    }

    return completionParam
  })
}
export class CustomChatOpenAI<
    CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions
  >
  extends BaseChatModel<CallOptions>
  implements OpenAIChatInput {
  temperature = 1
  topP = 1
  frequencyPenalty = 0
  presencePenalty = 0
  n = 1
  logitBias?: Record<string, number>
  modelName = "gpt-3.5-turbo"
  model = "gpt-3.5-turbo"
  modelKwargs?: OpenAIChatInput["modelKwargs"]
  stop?: string[]
  stopSequences?: string[]
  user?: string
  timeout?: number
  streaming = false
  streamUsage = true
  maxTokens?: number
  logprobs?: boolean
  topLogprobs?: number
  openAIApiKey?: string
  apiKey?: string
  azureOpenAIApiVersion?: string
  azureOpenAIApiKey?: string
  azureADTokenProvider?: () => Promise<string>
  azureOpenAIApiInstanceName?: string
  azureOpenAIApiDeploymentName?: string
  azureOpenAIBasePath?: string
  organization?: string

  protected client: OpenAIClient

  protected clientConfig: ClientOptions

  static lc_name() {
    return "ChatOpenAI"
  }

  get callKeys() {
    return [
      ...super.callKeys,
      "options",
      "function_call",
      "functions",
      "tools",
      "tool_choice",
      "promptIndex",
      "response_format",
      "seed"
    ]
  }

  get lc_secrets() {
    return {
      openAIApiKey: "OPENAI_API_KEY",
      azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
      organization: "OPENAI_ORGANIZATION"
    }
  }

  get lc_aliases() {
    return {
      modelName: "model",
      openAIApiKey: "openai_api_key",
      azureOpenAIApiVersion: "azure_openai_api_version",
      azureOpenAIApiKey: "azure_openai_api_key",
      azureOpenAIApiInstanceName: "azure_openai_api_instance_name",
      azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name"
    }
  }
  constructor(
    fields?: Partial<OpenAIChatInput> &
      BaseChatModelParams & {
        configuration?: ClientOptions & LegacyOpenAIInput;
      },
    /** @deprecated */
    configuration?: ClientOptions & LegacyOpenAIInput
  ) {
    super(fields ?? {})
    Object.defineProperty(this, "lc_serializable", { enumerable: true, configurable: true, writable: true, value: true })
    Object.defineProperty(this, "temperature", { enumerable: true, configurable: true, writable: true, value: 1 })
    Object.defineProperty(this, "topP", { enumerable: true, configurable: true, writable: true, value: 1 })
    Object.defineProperty(this, "frequencyPenalty", { enumerable: true, configurable: true, writable: true, value: 0 })
    Object.defineProperty(this, "presencePenalty", { enumerable: true, configurable: true, writable: true, value: 0 })
    Object.defineProperty(this, "n", { enumerable: true, configurable: true, writable: true, value: 1 })
    Object.defineProperty(this, "logitBias", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "modelName", { enumerable: true, configurable: true, writable: true, value: "gpt-3.5-turbo" })
    Object.defineProperty(this, "modelKwargs", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "stop", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "user", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "timeout", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "streaming", { enumerable: true, configurable: true, writable: true, value: false })
    Object.defineProperty(this, "maxTokens", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "logprobs", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "topLogprobs", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "openAIApiKey", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "azureOpenAIApiVersion", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "azureOpenAIApiKey", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "azureOpenAIApiInstanceName", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "azureOpenAIApiDeploymentName", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "azureOpenAIBasePath", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "organization", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "client", { enumerable: true, configurable: true, writable: true, value: void 0 })
    Object.defineProperty(this, "clientConfig", { enumerable: true, configurable: true, writable: true, value: void 0 })
    this.openAIApiKey =
      fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY")

    this.modelName = fields?.modelName ?? this.modelName
    this.modelKwargs = fields?.modelKwargs ?? {}
    this.timeout = fields?.timeout
    this.temperature = fields?.temperature ?? this.temperature
    this.topP = fields?.topP ?? this.topP
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
    this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty
    this.maxTokens = fields?.maxTokens
    this.logprobs = fields?.logprobs
    this.topLogprobs = fields?.topLogprobs
    this.n = fields?.n ?? this.n
    this.logitBias = fields?.logitBias
    this.stop = fields?.stop
    this.user = fields?.user
    this.streaming = fields?.streaming ?? false
    this.clientConfig = {
      apiKey: this.openAIApiKey,
      organization: this.organization,
      baseURL: configuration?.basePath ?? fields?.configuration?.basePath,
      dangerouslyAllowBrowser: true,
      defaultHeaders:
        configuration?.baseOptions?.headers ??
        fields?.configuration?.baseOptions?.headers,
      defaultQuery:
        configuration?.baseOptions?.params ??
        fields?.configuration?.baseOptions?.params,
      ...configuration,
      ...fields?.configuration
    }
  }
  /**
   * Get the parameters used to invoke the model
   */
  invocationParams(options) {
    function isStructuredToolArray(tools) {
      return (
        tools !== undefined &&
        tools.every((tool) => Array.isArray(tool.lc_namespace))
      )
    }
    const params = {
      model: this.modelName,
      temperature: this.temperature,
      top_p: this.topP,
      frequency_penalty: this.frequencyPenalty,
      presence_penalty: this.presencePenalty,
      max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
      logprobs: this.logprobs,
      top_logprobs: this.topLogprobs,
      n: this.n,
      logit_bias: this.logitBias,
      stop: options?.stop ?? this.stop,
      user: this.user,
      stream: this.streaming,
      functions: options?.functions,
      function_call: options?.function_call,
      tools: isStructuredToolArray(options?.tools)
        ? options?.tools.map(convertToOpenAITool)
        : options?.tools,
      tool_choice: options?.tool_choice,
      response_format: options?.response_format,
      seed: options?.seed,
      ...this.modelKwargs
    }
    return params
  }

  /** @ignore */
  _identifyingParams() {
    return {
      model_name: this.modelName,
      //@ts-ignore
      ...this?.invocationParams(),
      ...this.clientConfig
    }
  }
  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    const messagesMapped = convertMessagesToOpenAIParams(messages)
    const params = {
      ...this.invocationParams(options),
      messages: messagesMapped,
      stream: true
    }
    let defaultRole
    //@ts-ignore
    const streamIterable = await this.completionWithRetry(params, options)
    for await (const data of streamIterable) {
      const choice = data?.choices[0]
      if (!choice) {
        continue
      }
      const { delta } = choice
      if (!delta) {
        continue
      }
      const chunk = _convertDeltaToMessageChunk(delta, defaultRole)
      defaultRole = delta.role ?? defaultRole
      const newTokenIndices = {
        //@ts-ignore
        prompt: options?.promptIndex ?? 0,
        completion: choice.index ?? 0
      }
      if (typeof chunk.content !== "string") {
        console.log(
          "[WARNING]: Received non-string content from OpenAI. This is currently not supported."
        )
        continue
      }
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const generationInfo = { ...newTokenIndices } as any
      if (choice.finish_reason !== undefined) {
        generationInfo.finish_reason = choice.finish_reason
      }
      if (this.logprobs) {
        generationInfo.logprobs = choice.logprobs
      }
      const generationChunk = new ChatGenerationChunk({
        message: chunk,
        text: chunk.content,
        generationInfo
      })
      yield generationChunk
      // eslint-disable-next-line no-void
      void runManager?.handleLLMNewToken(
        generationChunk.text ?? "",
        newTokenIndices,
        undefined,
        undefined,
        undefined,
        { chunk: generationChunk }
      )
    }
    if (options.signal?.aborted) {
      throw new Error("AbortError")
    }
  }
  /**
   * Get the identifying parameters for the model
   *
   */
  identifyingParams() {
    return this._identifyingParams()
  }

  /** @ignore */
  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    const tokenUsage: TokenUsage = {}
    const params = this.invocationParams(options)
    const messagesMapped: any[] = convertMessagesToOpenAIParams(messages)
    if (params.stream) {
      const stream = this._streamResponseChunks(messages, options, runManager)
      const finalChunks: Record<number, ChatGenerationChunk> = {}
      for await (const chunk of stream) {
        //@ts-ignore
        chunk.message.response_metadata = {
          ...chunk.generationInfo,
          //@ts-ignore
          ...chunk.message.response_metadata
        }
        const index = chunk.generationInfo?.completion ?? 0
        if (finalChunks[index] === undefined) {
          finalChunks[index] = chunk
        } else {
          finalChunks[index] = finalChunks[index].concat(chunk)
        }
      }
      const generations = Object.entries(finalChunks)
        .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
        .map(([_, value]) => value)
      const { functions, function_call } = this.invocationParams(options)
      // OpenAI does not support token usage report under stream mode,
      // fallback to estimation.
      const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt(
        messages,
        functions,
        function_call
      )
      const completionTokenUsage =
        await this.getNumTokensFromGenerations(generations)
      tokenUsage.promptTokens = promptTokenUsage
      tokenUsage.completionTokens = completionTokenUsage
      tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage
      return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
    } else {
      const data = await this.completionWithRetry(
        {
          ...params,
          //@ts-ignore
          stream: false,
          messages: messagesMapped
        },
        {
          signal: options?.signal,
          //@ts-ignore
          ...options?.options
        }
      )
      const {
        completion_tokens: completionTokens,
        prompt_tokens: promptTokens,
        total_tokens: totalTokens
        //@ts-ignore
      } = data?.usage ?? {}
      if (completionTokens) {
        tokenUsage.completionTokens =
          (tokenUsage.completionTokens ?? 0) + completionTokens
      }
      if (promptTokens) {
        tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens
      }
      if (totalTokens) {
        tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens
      }
      const generations = []
      //@ts-ignore
      for (const part of data?.choices ?? []) {
        const text = part.message?.content ?? ""
        const generation = {
          text,
          message: openAIResponseToChatMessage(
            part.message ?? { role: "assistant" }
          )
        }
        //@ts-ignore
        generation.generationInfo = {
          ...(part.finish_reason ? { finish_reason: part.finish_reason } : {}),
          ...(part.logprobs ? { logprobs: part.logprobs } : {})
        }
        generations.push(generation)
      }
      return {
        generations,
        llmOutput: { tokenUsage }
      }
    }
  }
  /**
   * Estimate the number of tokens a prompt will use.
   * Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts
   */
  async getEstimatedTokenCountFromPrompt(messages, functions, function_call) {
    let tokens = (await this.getNumTokensFromMessages(messages)).totalCount
    if (functions && messages.find((m) => m._getType() === "system")) {
      tokens -= 4
    }
    if (function_call === "none") {
      tokens += 1
    } else if (typeof function_call === "object") {
      tokens += (await this.getNumTokens(function_call.name)) + 4
    }
    return tokens
  }

  /**
   * Estimate the number of tokens an array of generations have used.
   */
  async getNumTokensFromGenerations(generations) {
    const generationUsages = await Promise.all(
      generations.map(async (generation) => {
        if (generation.message.additional_kwargs?.function_call) {
          return (await this.getNumTokensFromMessages([generation.message]))
            .countPerMessage[0]
        } else {
          return await this.getNumTokens(generation.message.content)
        }
      })
    )
    return generationUsages.reduce((a, b) => a + b, 0)
  }

  async getNumTokensFromMessages(messages) {
    let totalCount = 0
    let tokensPerMessage = 0
    let tokensPerName = 0
    // From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
    if (this.modelName === "gpt-3.5-turbo-0301") {
      tokensPerMessage = 4
      tokensPerName = -1
    } else {
      tokensPerMessage = 3
      tokensPerName = 1
    }
    const countPerMessage = await Promise.all(
      messages.map(async (message) => {
        const textCount = await this.getNumTokens(message.content)
        const roleCount = await this.getNumTokens(messageToOpenAIRole(message))
        const nameCount =
          message.name !== undefined
            ? tokensPerName + (await this.getNumTokens(message.name))
            : 0
        let count = textCount + tokensPerMessage + roleCount + nameCount
        // From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate
        const openAIMessage = message
        if (openAIMessage._getType() === "function") {
          count -= 2
        }
        if (openAIMessage.additional_kwargs?.function_call) {
          count += 3
        }
        if (openAIMessage?.additional_kwargs.function_call?.name) {
          count += await this.getNumTokens(
            openAIMessage.additional_kwargs.function_call?.name
          )
        }
        if (openAIMessage.additional_kwargs.function_call?.arguments) {
          try {
            count += await this.getNumTokens(
              // Remove newlines and spaces
              JSON.stringify(
                JSON.parse(
                  openAIMessage.additional_kwargs.function_call?.arguments
                )
              )
            )
          } catch (error) {
            console.error(
              "Error parsing function arguments",
              error,
              JSON.stringify(openAIMessage.additional_kwargs.function_call)
            )
            count += await this.getNumTokens(
              openAIMessage.additional_kwargs.function_call?.arguments
            )
          }
        }
        totalCount += count
        return count
      })
    )
    totalCount += 3 // every reply is primed with <|start|>assistant<|message|>
    return { totalCount, countPerMessage }
  }
  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming,
    options?: OpenAICoreRequestOptions
  ) {
    const requestOptions = this._getClientOptions(options)
    return this.caller.call(async () => {
      try {
        const res = await this.client.chat.completions.create(
          request,
          requestOptions
        )
        return res
      } catch (e) {
        const error = wrapOpenAIClientError(e)
        throw error
      }
    })
  }

  _getClientOptions(options) {
    if (!this.client) {
      const openAIEndpointConfig = {
        azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
        azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
        azureOpenAIApiKey: this.azureOpenAIApiKey,
        azureOpenAIBasePath: this.azureOpenAIBasePath,
        baseURL: this.clientConfig.baseURL
      }
      const endpoint = getEndpoint(openAIEndpointConfig)
      const params = {
        ...this.clientConfig,
        baseURL: endpoint,
        timeout: this.timeout,
        maxRetries: 0
      }
      if (!params.baseURL) {
        delete params.baseURL
      }
      this.client = new OpenAIClient(params)
    }
    const requestOptions = {
      ...this.clientConfig,
      ...options
    }
    if (this.azureOpenAIApiKey) {
      requestOptions.headers = {
        "api-key": this.azureOpenAIApiKey,
        ...requestOptions.headers
      }
      requestOptions.query = {
        "api-version": this.azureOpenAIApiVersion,
        ...requestOptions.query
      }
    }
    return requestOptions
  }

  _llmType() {
    return "openai"
  }
  /** @ignore */
  _combineLLMOutput(...llmOutputs) {
    return llmOutputs.reduce(
      (acc, llmOutput) => {
        if (llmOutput && llmOutput.tokenUsage) {
          acc.tokenUsage.completionTokens +=
            llmOutput.tokenUsage.completionTokens ?? 0
          acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0
          acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 0
        }
        return acc
      },
      {
        tokenUsage: {
          completionTokens: 0,
          promptTokens: 0,
          totalTokens: 0
        }
      }
    )
  }

  withStructuredOutput(outputSchema, config) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    let schema
    let name
    let method
    let includeRaw
    if (isStructuredOutputMethodParams(outputSchema)) {
      schema = outputSchema.schema
      name = outputSchema.name
      method = outputSchema.method
      includeRaw = outputSchema.includeRaw
    } else {
      schema = outputSchema
      name = config?.name
      method = config?.method
      includeRaw = config?.includeRaw
    }
    let llm
    let outputParser
    if (method === "jsonMode") {
      llm = this.bind({})
      if (isZodSchema(schema)) {
        outputParser = StructuredOutputParser.fromZodSchema(schema)
      } else {
        outputParser = new JsonOutputParser()
      }
    } else {
      let functionName = name ?? "extract"
      // Is function calling

      let openAIFunctionDefinition
      if (
        typeof schema.name === "string" &&
        typeof schema.parameters === "object" &&
        schema.parameters != null
      ) {
        openAIFunctionDefinition = schema
        functionName = schema.name
      } else {
        openAIFunctionDefinition = {
          name: schema.title ?? functionName,
          description: schema.description ?? "",
          parameters: schema
        }
      }
      llm = this.bind({})
      outputParser = new JsonOutputKeyToolsParser({
        returnSingle: true,
        keyName: functionName
      })
    }
    if (!includeRaw) {
      return llm.pipe(outputParser)
    }
    const parserAssign = RunnablePassthrough.assign({
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      parsed: (input, config) => outputParser.invoke(input.raw, config)
    })
    const parserNone = RunnablePassthrough.assign({
      parsed: () => null
    })
    const parsedWithFallback = parserAssign.withFallbacks({
      fallbacks: [parserNone]
    })
    return RunnableSequence.from([
      {
        raw: llm
      },
      parsedWithFallback
    ] as any)
  }
}
function isZodSchema(
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  input
) {
  // Check for a characteristic method of Zod schemas
  return typeof input?.parse === "function"
}

function isStructuredOutputMethodParams(
  x
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
) {
  return (
    x !== undefined &&
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    typeof x.schema === "object"
  )
}
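A rough usage sketch of the new model class, not code from this commit; the model name, API key, and endpoint are placeholders. It shows the one piece the hooks rely on: streamed chunks expose reasoning deltas under additional_kwargs.reasoning_content.

// Sketch only: this mirrors how pageAssistModel (patched below) wires OpenAI-compatible providers.
const model = new CustomChatOpenAI({
  modelName: "deepseek-reasoner",                              // placeholder model id
  openAIApiKey: "sk-...",                                      // placeholder key
  configuration: { basePath: "https://api.deepseek.com/v1" },  // placeholder endpoint
  streaming: true
})

for await (const chunk of await model.stream("Why is the sky blue?")) {
  const reasoning = (chunk as any)?.additional_kwargs?.reasoning_content
  if (reasoning) {
    // reasoning streamed as a separate field; the hooks fold it into a <think> block
  } else if (typeof chunk.content === "string") {
    // normal answer tokens
  }
}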
@@ -2,9 +2,9 @@ import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
 import { ChatChromeAI } from "./ChatChromeAi"
 import { ChatOllama } from "./ChatOllama"
 import { getOpenAIConfigById } from "@/db/openai"
-import { ChatOpenAI } from "@langchain/openai"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { ChatGoogleAI } from "./ChatGoogleAI"
+import { CustomChatOpenAI } from "./CustomChatOpenAI"
 
 export const pageAssistModel = async ({
   model,
@@ -76,7 +76,7 @@ export const pageAssistModel = async ({
     }) as any
   }
 
-  return new ChatOpenAI({
+  return new CustomChatOpenAI({
     modelName: modelInfo.model_id,
     openAIApiKey: providerInfo.apiKey || "temp",
     temperature,