diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx index 4dcac4e..3b6b4e5 100644 --- a/src/hooks/useMessage.tsx +++ b/src/hooks/useMessage.tsx @@ -9,7 +9,7 @@ import { } from "~/services/ollama" import { useStoreMessageOption, type Message } from "~/store/option" import { useStoreMessage } from "~/store" -import { HumanMessage, SystemMessage } from "@langchain/core/messages" +import { SystemMessage } from "@langchain/core/messages" import { getDataFromCurrentTab } from "~/libs/get-html" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { memoryEmbedding } from "@/utils/memory-embeddings" @@ -33,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings" import { getSystemPromptForWeb } from "@/web/web" import { pageAssistModel } from "@/models" import { getPrompt } from "@/services/application" +import { humanMessageFormatter } from "@/utils/human-message" export const useMessage = () => { const { @@ -313,7 +314,7 @@ export const useMessage = () => { ] } - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: systemPrompt @@ -321,10 +322,11 @@ export const useMessage = () => { .replace("{question}", query), type: "text" } - ] + ], + model: selectedModel }) - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) const chunks = await ollama.stream( [...applicationChatHistory, humanMessage], @@ -500,16 +502,17 @@ export const useMessage = () => { const prompt = await systemPromptForNonRag() const selectedPrompt = await getPromptById(selectedSystemPrompt) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -519,11 +522,12 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt && !selectedPrompt) { applicationChatHistory.unshift( @@ -760,16 +764,17 @@ export const useMessage = () => { // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -779,11 +784,12 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt) { applicationChatHistory.unshift( @@ -966,16 +972,17 @@ export const useMessage = () => { try { const prompt = await getPrompt(messageType) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: prompt.replace("{text}", message), type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: prompt.replace("{text}", message), @@ -985,7 +992,8 @@ export const useMessage = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx index 4e633f8..bceb5a6 100644 --- a/src/hooks/useMessageOption.tsx +++ b/src/hooks/useMessageOption.tsx @@ -33,6 +33,7 @@ import { useStoreChatModelSettings } from "@/store/model" import { getAllDefaultModelSettings } from "@/services/model-settings" import { pageAssistModel } from "@/models" import { getNoOfRetrievedDocs } from "@/services/app" +import { humanMessageFormatter } from "@/utils/human-message" export const useMessageOption = () => { const { @@ -68,7 +69,7 @@ export const useMessageOption = () => { } = useStoreMessageOption() const currentChatModelSettings = useStoreChatModelSettings() const [selectedModel, setSelectedModel] = useStorage("selectedModel") - const [ speechToTextLanguage, setSpeechToTextLanguage ] = useStorage( + const [speechToTextLanguage, setSpeechToTextLanguage] = useStorage( "speechToTextLanguage", "en-US" ) @@ -207,16 +208,17 @@ export const useMessageOption = () => { // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -226,11 +228,12 @@ export const useMessageOption = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt) { applicationChatHistory.unshift( @@ -412,16 +415,17 @@ export const useMessageOption = () => { const prompt = await systemPromptForNonRagOption() const selectedPrompt = await getPromptById(selectedSystemPrompt) - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: message, type: "text" } - ] + ], + model: selectedModel }) if (image.length > 0) { - humanMessage = new HumanMessage({ + humanMessage = humanMessageFormatter({ content: [ { text: message, @@ -431,11 +435,12 @@ export const useMessageOption = () => { image_url: image, type: "image_url" } - ] + ], + model: selectedModel }) } - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) if (prompt && !selectedPrompt) { applicationChatHistory.unshift( @@ -695,7 +700,7 @@ export const useMessageOption = () => { }) // message = message.trim().replaceAll("\n", " ") - let humanMessage = new HumanMessage({ + let humanMessage = humanMessageFormatter({ content: [ { text: systemPrompt @@ -703,10 +708,11 @@ export const useMessageOption = () => { .replace("{question}", message), type: "text" } - ] + ], + model: selectedModel }) - const applicationChatHistory = generateHistory(history) + const applicationChatHistory = generateHistory(history, selectedModel) const chunks = await ollama.stream( [...applicationChatHistory, humanMessage], diff --git a/src/utils/generate-history.ts b/src/utils/generate-history.ts index dd6f446..cb4b466 100644 --- a/src/utils/generate-history.ts +++ b/src/utils/generate-history.ts @@ -1,55 +1,66 @@ +import { isCustomModel } from "@/db/models" import { - HumanMessage, - AIMessage, - type MessageContent, + HumanMessage, + AIMessage, + type MessageContent } from "@langchain/core/messages" export const generateHistory = ( - messages: { - role: "user" | "assistant" | "system" - content: string - image?: string - }[] + messages: { + role: "user" | "assistant" | "system" + content: string + image?: string + }[], + model: string ) => { - let history = [] - for (const message of messages) { - if (message.role === "user") { - let content: MessageContent = [ - { - type: "text", - text: message.content - } - ] - - if (message.image) { - content = [ - { - type: "image_url", - image_url: message.image - }, - { - type: "text", - text: message.content - } - ] + let history = [] + const isCustom = isCustomModel(model) + for (const message of messages) { + if (message.role === "user") { + let content: MessageContent = isCustom + ? message.content + : [ + { + type: "text", + text: message.content } - history.push( - new HumanMessage({ - content: content - }) - ) - } else if (message.role === "assistant") { - history.push( - new AIMessage({ - content: [ - { - type: "text", - text: message.content - } - ] - }) - ) - } + ] + + if (message.image) { + content = [ + { + type: "image_url", + image_url: !isCustom + ? message.image + : { + url: message.image + } + }, + { + type: "text", + text: message.content + } + ] + } + history.push( + new HumanMessage({ + content: content + }) + ) + } else if (message.role === "assistant") { + history.push( + new AIMessage({ + content: isCustom + ? message.content + : [ + { + type: "text", + text: message.content + } + ] + }) + ) } - return history -} \ No newline at end of file + } + return history +} diff --git a/src/utils/human-message.tsx b/src/utils/human-message.tsx new file mode 100644 index 0000000..6712339 --- /dev/null +++ b/src/utils/human-message.tsx @@ -0,0 +1,43 @@ +import { isCustomModel } from "@/db/models" +import { HumanMessage, type MessageContent } from "@langchain/core/messages" + + +type HumanMessageType = { + content: MessageContent, + model: string +} + +export const humanMessageFormatter = ({ content, model }: HumanMessageType) => { + + const isCustom = isCustomModel(model) + + if(isCustom) { + if(typeof content !== 'string') { + if(content.length > 1) { + // this means that we need to reformat the image_url + const newContent: MessageContent = [ + { + type: "text", + //@ts-ignore + text: content[0].text + }, + { + type: "image_url", + image_url: { + //@ts-ignore + url: content[1].image_url + } + } + ] + + return new HumanMessage({ + content: newContent + }) + } + } + } + + return new HumanMessage({ + content, + }) +} \ No newline at end of file