feat: support custom models for messages
This commit adds support for custom models to message-history generation. Previously, history was always built with LangChain's standard message structure, which is not compatible with custom models. History is now formatted correctly for whichever model type is selected, improving compatibility and user experience.
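Concretely, the hooks below now build the outgoing human message through a new humanMessageFormatter helper and pass the selected model into generateHistory, as the diff shows. A rough sketch of the resulting call pattern (identifiers are taken from the diff; message, history, and selectedModel are hook state assumed to be in scope):

    // Sketch only — not the literal hook code.
    const humanMessage = humanMessageFormatter({
      content: [{ text: message, type: "text" }],
      model: selectedModel // lets the helper pick the right payload shape
    })
    // history conversion is now model-aware as well
    const applicationChatHistory = generateHistory(history, selectedModel)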
parent: c8620637f8
commit: 192e3893bb
@@ -9,7 +9,7 @@ import {
 } from "~/services/ollama"
 import { useStoreMessageOption, type Message } from "~/store/option"
 import { useStoreMessage } from "~/store"
-import { HumanMessage, SystemMessage } from "@langchain/core/messages"
+import { SystemMessage } from "@langchain/core/messages"
 import { getDataFromCurrentTab } from "~/libs/get-html"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
 import { memoryEmbedding } from "@/utils/memory-embeddings"
@@ -33,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { getSystemPromptForWeb } from "@/web/web"
 import { pageAssistModel } from "@/models"
 import { getPrompt } from "@/services/application"
+import { humanMessageFormatter } from "@/utils/human-message"
 
 export const useMessage = () => {
   const {
@@ -313,7 +314,7 @@ export const useMessage = () => {
         ]
       }
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: systemPrompt
@@ -321,10 +322,11 @@ export const useMessage = () => {
               .replace("{question}", query),
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
       const chunks = await ollama.stream(
         [...applicationChatHistory, humanMessage],
@@ -500,16 +502,17 @@ export const useMessage = () => {
       const prompt = await systemPromptForNonRag()
       const selectedPrompt = await getPromptById(selectedSystemPrompt)
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: message,
@@ -519,11 +522,12 @@ export const useMessage = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
       if (prompt && !selectedPrompt) {
         applicationChatHistory.unshift(
@@ -760,16 +764,17 @@ export const useMessage = () => {
 
       //  message = message.trim().replaceAll("\n", " ")
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: message,
@@ -779,11 +784,12 @@ export const useMessage = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
      if (prompt) {
         applicationChatHistory.unshift(
@@ -966,16 +972,17 @@ export const useMessage = () => {
 
     try {
       const prompt = await getPrompt(messageType)
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: prompt.replace("{text}", message),
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: prompt.replace("{text}", message),
@@ -985,7 +992,8 @@ export const useMessage = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }
 
@@ -33,6 +33,7 @@ import { useStoreChatModelSettings } from "@/store/model"
 import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { pageAssistModel } from "@/models"
 import { getNoOfRetrievedDocs } from "@/services/app"
+import { humanMessageFormatter } from "@/utils/human-message"
 
 export const useMessageOption = () => {
   const {
@@ -68,7 +69,7 @@ export const useMessageOption = () => {
   } = useStoreMessageOption()
   const currentChatModelSettings = useStoreChatModelSettings()
   const [selectedModel, setSelectedModel] = useStorage("selectedModel")
-  const [ speechToTextLanguage, setSpeechToTextLanguage ] = useStorage(
+  const [speechToTextLanguage, setSpeechToTextLanguage] = useStorage(
     "speechToTextLanguage",
     "en-US"
   )
@@ -207,16 +208,17 @@ export const useMessageOption = () => {
 
       //  message = message.trim().replaceAll("\n", " ")
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: message,
@@ -226,11 +228,12 @@ export const useMessageOption = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
       if (prompt) {
         applicationChatHistory.unshift(
@@ -412,16 +415,17 @@ export const useMessageOption = () => {
       const prompt = await systemPromptForNonRagOption()
       const selectedPrompt = await getPromptById(selectedSystemPrompt)
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: message,
@@ -431,11 +435,12 @@ export const useMessageOption = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
       if (prompt && !selectedPrompt) {
         applicationChatHistory.unshift(
@@ -695,7 +700,7 @@ export const useMessageOption = () => {
       })
       //  message = message.trim().replaceAll("\n", " ")
 
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: systemPrompt
@@ -703,10 +708,11 @@ export const useMessageOption = () => {
               .replace("{question}", message),
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
 
-      const applicationChatHistory = generateHistory(history)
+      const applicationChatHistory = generateHistory(history, selectedModel)
 
       const chunks = await ollama.stream(
         [...applicationChatHistory, humanMessage],
@@ -1,55 +1,66 @@
+import { isCustomModel } from "@/db/models"
 import {
-    HumanMessage,
-    AIMessage,
-    type MessageContent,
+  HumanMessage,
+  AIMessage,
+  type MessageContent
 } from "@langchain/core/messages"
 
 export const generateHistory = (
-    messages: {
-        role: "user" | "assistant" | "system"
-        content: string
-        image?: string
-    }[]
+  messages: {
+    role: "user" | "assistant" | "system"
+    content: string
+    image?: string
+  }[],
+  model: string
 ) => {
-    let history = []
-    for (const message of messages) {
-        if (message.role === "user") {
-            let content: MessageContent = [
-                {
-                    type: "text",
-                    text: message.content
-                }
-            ]
-
-            if (message.image) {
-                content = [
-                    {
-                        type: "image_url",
-                        image_url: message.image
-                    },
-                    {
-                        type: "text",
-                        text: message.content
-                    }
-                ]
-            }
-            history.push(
-                new HumanMessage({
-                    content: content
-                })
-            )
-        } else if (message.role === "assistant") {
-            history.push(
-                new AIMessage({
-                    content: [
-                        {
-                            type: "text",
-                            text: message.content
-                        }
-                    ]
-                })
-            )
-        }
-    }
-    return history
-}
+  let history = []
+  const isCustom = isCustomModel(model)
+  for (const message of messages) {
+    if (message.role === "user") {
+      let content: MessageContent = isCustom
+        ? message.content
+        : [
+            {
+              type: "text",
+              text: message.content
+            }
+          ]
+
+      if (message.image) {
+        content = [
+          {
+            type: "image_url",
+            image_url: !isCustom
+              ? message.image
+              : {
+                  url: message.image
+                }
+          },
+          {
+            type: "text",
+            text: message.content
+          }
+        ]
+      }
+      history.push(
+        new HumanMessage({
+          content: content
+        })
+      )
+    } else if (message.role === "assistant") {
+      history.push(
+        new AIMessage({
+          content: isCustom
+            ? message.content
+            : [
+                {
+                  type: "text",
+                  text: message.content
+                }
+              ]
+        })
+      )
+    }
+  }
+  return history
+}
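In effect, for a custom model generateHistory now keeps plain string content (and, for images, wraps the URL in an object), while the default path keeps LangChain's content-part arrays. A hedged illustration for a text-only history entry { role: "user", content: "hi" }:

    // default path:       new HumanMessage({ content: [{ type: "text", text: "hi" }] })
    // custom-model path:  new HumanMessage({ content: "hi" })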
src/utils/human-message.tsx (new file, 43 lines)
@@ -0,0 +1,43 @@
+import { isCustomModel } from "@/db/models"
+import { HumanMessage, type MessageContent } from "@langchain/core/messages"
+
+
+type HumanMessageType = {
+    content: MessageContent,
+    model: string
+}
+
+export const humanMessageFormatter = ({ content, model }: HumanMessageType) => {
+
+    const isCustom = isCustomModel(model)
+
+    if(isCustom) {
+        if(typeof content !== 'string') {
+            if(content.length > 1) {
+                // this means that we need to reformat the image_url
+                const newContent: MessageContent = [
+                    {
+                        type: "text",
+                        //@ts-ignore
+                        text: content[0].text
+                    },
+                    {
+                        type: "image_url",
+                        image_url: {
+                            //@ts-ignore
+                            url: content[1].image_url
+                        }
+                    }
+                ]
+
+                return new HumanMessage({
+                    content: newContent
+                })
+            }
+        }
+    }
+
+    return new HumanMessage({
+        content,
+    })
+}
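For reference, a hypothetical call to the new helper (the model id below is made up and merely stands in for whatever isCustomModel() treats as a custom model in this codebase):

    const msg = humanMessageFormatter({
      content: [
        { text: "Describe this image", type: "text" },
        { image_url: "data:image/png;base64,...", type: "image_url" }
      ],
      model: "my-custom-model" // hypothetical custom model id
    })
    // For a custom model the image part is rewritten to { image_url: { url: ... } }
    // (the OpenAI-style content-part shape); otherwise the content is passed to
    // HumanMessage unchanged.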