From 192e3893bb26de84617bd2794fcdfe9e48167c34 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 29 Sep 2024 23:59:15 +0530 Subject: [PATCH] feat: support custom models for messages This commit introduces support for custom models in the message history generation process. Previously, the history would format messages using LangChain's standard message structure, which is not compatible with custom models. This change allows for correct history formatting regardless of the selected model type, enhancing compatibility and user experience. --- src/hooks/useMessage.tsx | 44 ++++++++------ src/hooks/useMessageOption.tsx | 34 ++++++----- src/utils/generate-history.ts | 107 ++++++++++++++++++--------------- src/utils/human-message.tsx | 43 +++++++++++++ 4 files changed, 148 insertions(+), 80 deletions(-) create mode 100644 src/utils/human-message.tsx diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx index 4dcac4e..3b6b4e5 100644 --- a/src/hooks/useMessage.tsx +++ b/src/hooks/useMessage.tsx @@ -9,7 +9,7 @@ import { } from "~/services/ollama" import { useStoreMessageOption, type Message } from "~/store/option" import { useStoreMessage } from "~/store" -import { HumanMessage, SystemMessage } from "@langchain/core/messages" +import { SystemMessage } from "@langchain/core/messages" import { getDataFromCurrentTab } from "~/libs/get-html" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { memoryEmbedding } from "@/utils/memory-embeddings" @@ -33,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings" import { getSystemPromptForWeb } from "@/web/web" import { pageAssistModel } from "@/models" import { getPrompt } from "@/services/application" +import { humanMessageFormatter } from "@/utils/human-message" export const useMessage = () => { const { @@ -313,7 +314,7 @@ export const useMessage = () => { ] } - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: systemPrompt @@ -321,10 +322,11 @@ export const useMessage = () => { .replace("{question}", query), type: "text" } - ] + ], + model: selectedModel }) - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) const chunks = await ollama.stream( [...applicationChatHistory, humanMessage], @@ -500,16 +502,17 @@ export const useMessage = () => { const prompt = await systemPromptForNonRag() const selectedPrompt = await getPromptById(selectedSystemPrompt) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -519,11 +522,12 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt && !selectedPrompt) { applicationChatHistory.unshift( @@ -760,16 +764,17 @@ export const useMessage = () => { // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -779,11 +784,12 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt) { applicationChatHistory.unshift( @@ -966,16 +972,17 @@ export const useMessage = () => { try { const prompt = await getPrompt(messageType) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: prompt.replace("{text}", message), type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: prompt.replace("{text}", message), @@ -985,7 +992,8 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx index 4e633f8..bceb5a6 100644 --- a/src/hooks/useMessageOption.tsx +++ b/src/hooks/useMessageOption.tsx @@ -33,6 +33,7 @@ import { useStoreChatModelSettings } from "@/store/model" import { getAllDefaultModelSettings } from "@/services/model-settings" import { pageAssistModel } from "@/models" import { getNoOfRetrievedDocs } from "@/services/app" +import { humanMessageFormatter } from "@/utils/human-message" export const useMessageOption = () => { const { @@ -68,7 +69,7 @@ export const useMessageOption = () => { } = useStoreMessageOption() const currentChatModelSettings = useStoreChatModelSettings() const [selectedModel, setSelectedModel] = useStorage("selectedModel") - const [ speechToTextLanguage, setSpeechToTextLanguage ] = useStorage( + const [speechToTextLanguage, setSpeechToTextLanguage] = useStorage( "speechToTextLanguage", "en-US" ) @@ -207,16 +208,17 @@ export const useMessageOption = () => { // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -226,11 +228,12 @@ export const useMessageOption = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt) { applicationChatHistory.unshift( @@ -412,16 +415,17 @@ export const useMessageOption = () => { const prompt = await systemPromptForNonRagOption() const selectedPrompt = await getPromptById(selectedSystemPrompt) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -431,11 +435,12 @@ export const useMessageOption = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt && !selectedPrompt) { applicationChatHistory.unshift( @@ -695,7 +700,7 @@ export const useMessageOption = () => { }) // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: systemPrompt @@ -703,10 +708,11 @@ export const useMessageOption = () => { .replace("{question}", message), type: "text" } - ] + ], + model: selectedModel }) - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) const chunks = await ollama.stream( [...applicationChatHistory, humanMessage], diff --git a/src/utils/generate-history.ts b/src/utils/generate-history.ts index dd6f446..cb4b466 100644 --- a/src/utils/generate-history.ts +++ b/src/utils/generate-history.ts @@ -1,55 +1,66 @@ +import { isCustomModel } from "@/db/models" import { - HumanMessage, - AIMessage, - type MessageContent, + HumanMessage, + AIMessage, + type MessageContent } from "@langchain/core/messages" export const generateHistory = ( - messages: { - role: "user" | "assistant" | "system" - content: string - image?: string - }[] + messages: { + role: "user" | "assistant" | "system" + content: string + image?: string + }[], + model: string ) => { - let history = [] - for (const message of messages) { - if (message.role === "user") { - let content: MessageContent = [ - { - type: "text", - text: message.content - } - ] - - if (message.image) { - content = [ - { - type: "image_url", - image_url: message.image - }, - { - type: "text", - text: message.content - } - ] + let history = [] + const isCustom = isCustomModel(model) + for (const message of messages) { + if (message.role === "user") { + let content: MessageContent = isCustom + ? message.content + : [ + { + type: "text", + text: message.content } - history.push( - new HumanMessage({ - content: content - }) - ) - } else if (message.role === "assistant") { - history.push( - new AIMessage({ - content: [ - { - type: "text", - text: message.content - } - ] - }) - ) - } + ] + + if (message.image) { + content = [ + { + type: "image_url", + image_url: !isCustom + ? message.image + : { + url: message.image + } + }, + { + type: "text", + text: message.content + } + ] + } + history.push( + new HumanMessage({ + content: content + }) + ) + } else if (message.role === "assistant") { + history.push( + new AIMessage({ + content: isCustom + ? message.content + : [ + { + type: "text", + text: message.content + } + ] + }) + ) } - return history -} \ No newline at end of file + } + return history +} diff --git a/src/utils/human-message.tsx b/src/utils/human-message.tsx new file mode 100644 index 0000000..6712339 --- /dev/null +++ b/src/utils/human-message.tsx @@ -0,0 +1,43 @@ +import { isCustomModel } from "@/db/models" +import { HumanMessage, type MessageContent } from "@langchain/core/messages" + + +type HumanMessageType = { + content: MessageContent, + model: string +} + +export const humanMessageFormatter = ({ content, model }: HumanMessageType) => { + + const isCustom = isCustomModel(model) + + if(isCustom) { + if(typeof content !== 'string') { + if(content.length > 1) { + // this means that we need to reformat the image_url + const newContent: MessageContent = [ + { + type: "text", + //@ts-ignore + text: content[0].text + }, + { + type: "image_url", + image_url: { + //@ts-ignore + url: content[1].image_url + } + } + ] + + return new HumanMessage({ + content: newContent + }) + } + } + } + + return new HumanMessage({ + content, + }) +} \ No newline at end of file