feat: support custom models for messages

This commit introduces support for custom models in message history generation. Previously, the chat history (and the outgoing human messages) were always built with LangChain's standard structured message content, which custom models do not accept. History and human messages are now formatted to match the selected model type, so conversations work correctly regardless of which kind of model is selected.
n4ze3m 2024-09-29 23:59:15 +05:30
parent c8620637f8
commit 192e3893bb
4 changed files with 148 additions and 80 deletions
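
At its core the incompatibility is about message content shape: the standard LangChain/Ollama path sends structured content parts, while custom models (which, judging by the { url } reshaping later in this diff, speak the OpenAI message format) expect plain-string text content and OpenAI-style image objects. A minimal illustrative sketch of the two shapes, not part of the diff itself:

import { HumanMessage } from "@langchain/core/messages"

// Standard path: LangChain structured content parts
const forStandardModel = new HumanMessage({
  content: [{ type: "text", text: "Hello" }]
})

// Custom-model path: plain string content
const forCustomModel = new HumanMessage({ content: "Hello" })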

File 1 of 4: useMessage hook

@@ -9,7 +9,7 @@ import {
 } from "~/services/ollama"
 import { useStoreMessageOption, type Message } from "~/store/option"
 import { useStoreMessage } from "~/store"
-import { HumanMessage, SystemMessage } from "@langchain/core/messages"
+import { SystemMessage } from "@langchain/core/messages"
 import { getDataFromCurrentTab } from "~/libs/get-html"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
 import { memoryEmbedding } from "@/utils/memory-embeddings"
@@ -33,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { getSystemPromptForWeb } from "@/web/web"
 import { pageAssistModel } from "@/models"
 import { getPrompt } from "@/services/application"
+import { humanMessageFormatter } from "@/utils/human-message"

 export const useMessage = () => {
   const {
@@ -313,7 +314,7 @@ export const useMessage = () => {
       ]
     }

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: systemPrompt
@@ -321,10 +322,11 @@ export const useMessage = () => {
             .replace("{question}", query),
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     const chunks = await ollama.stream(
       [...applicationChatHistory, humanMessage],
@@ -500,16 +502,17 @@ export const useMessage = () => {
     const prompt = await systemPromptForNonRag()
     const selectedPrompt = await getPromptById(selectedSystemPrompt)

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: message,
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })
     if (image.length > 0) {
-      humanMessage = new HumanMessage({
+      humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
@@ -519,11 +522,12 @@ export const useMessage = () => {
             image_url: image,
             type: "image_url"
           }
-        ]
+        ],
+        model: selectedModel
       })
     }

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     if (prompt && !selectedPrompt) {
       applicationChatHistory.unshift(
@@ -760,16 +764,17 @@ export const useMessage = () => {
     // message = message.trim().replaceAll("\n", " ")

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: message,
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })
     if (image.length > 0) {
-      humanMessage = new HumanMessage({
+      humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
@@ -779,11 +784,12 @@ export const useMessage = () => {
             image_url: image,
             type: "image_url"
           }
-        ]
+        ],
+        model: selectedModel
       })
     }

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     if (prompt) {
       applicationChatHistory.unshift(
@@ -966,16 +972,17 @@ export const useMessage = () => {
     try {
       const prompt = await getPrompt(messageType)
-      let humanMessage = new HumanMessage({
+      let humanMessage = humanMessageFormatter({
         content: [
           {
             text: prompt.replace("{text}", message),
             type: "text"
           }
-        ]
+        ],
+        model: selectedModel
       })
       if (image.length > 0) {
-        humanMessage = new HumanMessage({
+        humanMessage = humanMessageFormatter({
           content: [
             {
               text: prompt.replace("{text}", message),
@@ -985,7 +992,8 @@ export const useMessage = () => {
               image_url: image,
               type: "image_url"
             }
-          ]
+          ],
+          model: selectedModel
         })
       }

File 2 of 4: useMessageOption hook

@@ -33,6 +33,7 @@ import { useStoreChatModelSettings } from "@/store/model"
 import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { pageAssistModel } from "@/models"
 import { getNoOfRetrievedDocs } from "@/services/app"
+import { humanMessageFormatter } from "@/utils/human-message"

 export const useMessageOption = () => {
   const {
@@ -68,7 +69,7 @@ export const useMessageOption = () => {
   } = useStoreMessageOption()
   const currentChatModelSettings = useStoreChatModelSettings()
   const [selectedModel, setSelectedModel] = useStorage("selectedModel")
-  const [ speechToTextLanguage, setSpeechToTextLanguage ] = useStorage(
+  const [speechToTextLanguage, setSpeechToTextLanguage] = useStorage(
     "speechToTextLanguage",
     "en-US"
   )
@@ -207,16 +208,17 @@ export const useMessageOption = () => {
     // message = message.trim().replaceAll("\n", " ")

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: message,
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })
     if (image.length > 0) {
-      humanMessage = new HumanMessage({
+      humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
@@ -226,11 +228,12 @@ export const useMessageOption = () => {
             image_url: image,
             type: "image_url"
           }
-        ]
+        ],
+        model: selectedModel
       })
     }

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     if (prompt) {
       applicationChatHistory.unshift(
@@ -412,16 +415,17 @@ export const useMessageOption = () => {
     const prompt = await systemPromptForNonRagOption()
     const selectedPrompt = await getPromptById(selectedSystemPrompt)

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: message,
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })
     if (image.length > 0) {
-      humanMessage = new HumanMessage({
+      humanMessage = humanMessageFormatter({
         content: [
           {
             text: message,
@@ -431,11 +435,12 @@ export const useMessageOption = () => {
             image_url: image,
             type: "image_url"
           }
-        ]
+        ],
+        model: selectedModel
       })
     }

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     if (prompt && !selectedPrompt) {
       applicationChatHistory.unshift(
@@ -695,7 +700,7 @@ export const useMessageOption = () => {
     })
     // message = message.trim().replaceAll("\n", " ")

-    let humanMessage = new HumanMessage({
+    let humanMessage = humanMessageFormatter({
       content: [
         {
           text: systemPrompt
@@ -703,10 +708,11 @@ export const useMessageOption = () => {
             .replace("{question}", message),
           type: "text"
         }
-      ]
+      ],
+      model: selectedModel
     })

-    const applicationChatHistory = generateHistory(history)
+    const applicationChatHistory = generateHistory(history, selectedModel)

     const chunks = await ollama.stream(
       [...applicationChatHistory, humanMessage],

File 3 of 4: generateHistory utility

@@ -1,55 +1,66 @@
+import { isCustomModel } from "@/db/models"
 import {
   HumanMessage,
   AIMessage,
-  type MessageContent,
+  type MessageContent
 } from "@langchain/core/messages"

 export const generateHistory = (
   messages: {
     role: "user" | "assistant" | "system"
     content: string
     image?: string
-  }[]
+  }[],
+  model: string
 ) => {
   let history = []
+  const isCustom = isCustomModel(model)
   for (const message of messages) {
     if (message.role === "user") {
-      let content: MessageContent = [
-        {
-          type: "text",
-          text: message.content
-        }
-      ]
+      let content: MessageContent = isCustom
+        ? message.content
+        : [
+            {
+              type: "text",
+              text: message.content
+            }
+          ]

       if (message.image) {
         content = [
           {
             type: "image_url",
-            image_url: message.image
+            image_url: !isCustom
+              ? message.image
+              : {
+                  url: message.image
+                }
           },
           {
             type: "text",
             text: message.content
           }
         ]
       }
       history.push(
         new HumanMessage({
           content: content
         })
       )
     } else if (message.role === "assistant") {
       history.push(
         new AIMessage({
-          content: [
-            {
-              type: "text",
-              text: message.content
-            }
-          ]
+          content: isCustom
+            ? message.content
+            : [
+                {
+                  type: "text",
+                  text: message.content
+                }
+              ]
         })
       )
     }
   }
   return history
 }
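
A quick sketch of what the reworked generateHistory returns for each model type. The model IDs below are made up; which branch runs depends only on what isCustomModel reports for the ID:

const history = [
  { role: "user" as const, content: "Hi" },
  { role: "assistant" as const, content: "Hello! How can I help?" }
]

// Regular model: every turn becomes structured content parts
generateHistory(history, "llama3.1")
// HumanMessage content: [{ type: "text", text: "Hi" }]
// AIMessage content:    [{ type: "text", text: "Hello! How can I help?" }]

// Custom model: plain strings pass through untouched
generateHistory(history, "my-openai-endpoint-model")
// HumanMessage content: "Hi"
// AIMessage content:    "Hello! How can I help?"

With an image attached, both branches still build an array of parts; the custom branch just wraps the image as { url: message.image }.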

File 4 of 4: humanMessageFormatter utility (new file, imported above as "@/utils/human-message")

@@ -0,0 +1,43 @@
+import { isCustomModel } from "@/db/models"
+import { HumanMessage, type MessageContent } from "@langchain/core/messages"
+
+type HumanMessageType = {
+  content: MessageContent,
+  model: string
+}
+
+export const humanMessageFormatter = ({ content, model }: HumanMessageType) => {
+  const isCustom = isCustomModel(model)
+
+  if(isCustom) {
+    if(typeof content !== 'string') {
+      if(content.length > 1) {
+        // this means that we need to reformat the image_url
+        const newContent: MessageContent = [
+          {
+            type: "text",
+            //@ts-ignore
+            text: content[0].text
+          },
+          {
+            type: "image_url",
+            image_url: {
+              //@ts-ignore
+              url: content[1].image_url
+            }
+          }
+        ]
+
+        return new HumanMessage({
+          content: newContent
+        })
+      }
+    }
+  }
+
+  return new HumanMessage({
+    content,
+  })
+}
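
Usage sketch for the new helper (again with a made-up model ID). For non-custom models, or plain-string content, it returns the content unchanged; for a custom model with a [text, image] content pair it rewrites the image part into the OpenAI { url } shape. Note that a single-element content array from a custom model also falls through to the unchanged-content return:

humanMessageFormatter({
  content: [
    { type: "text", text: "Describe this image" },
    { type: "image_url", image_url: "data:image/png;base64,..." }
  ],
  model: "my-openai-endpoint-model" // hypothetical custom model ID
})
// -> HumanMessage with content:
//    [
//      { type: "text", text: "Describe this image" },
//      { type: "image_url", image_url: { url: "data:image/png;base64,..." } }
//    ]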