From dd496b7b98e07a91efebfbb8455b1fee28636624 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 6 Apr 2024 00:30:23 +0530 Subject: [PATCH] Webui chat with x added --- src/chain/chat-with-x.ts | 154 +++++++++++ .../Option/Knowledge/SelectedKnwledge.tsx | 4 +- .../Option/Playground/PlaygroundForm.tsx | 9 +- src/db/vector.ts | 4 +- src/hooks/useMessageOption.tsx | 245 +++++++++++++++++- src/libs/PageAssistVectorStore.ts | 11 +- src/loader/pdf-url.ts | 2 +- 7 files changed, 401 insertions(+), 28 deletions(-) create mode 100644 src/chain/chat-with-x.ts diff --git a/src/chain/chat-with-x.ts b/src/chain/chat-with-x.ts new file mode 100644 index 0000000..7c378a2 --- /dev/null +++ b/src/chain/chat-with-x.ts @@ -0,0 +1,154 @@ +import { BaseLanguageModel } from "@langchain/core/language_models/base" +import { Document } from "@langchain/core/documents" +import { + ChatPromptTemplate, + MessagesPlaceholder, + PromptTemplate +} from "@langchain/core/prompts" +import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages" +import { StringOutputParser } from "@langchain/core/output_parsers" +import { + Runnable, + RunnableBranch, + RunnableLambda, + RunnableMap, + RunnableSequence +} from "@langchain/core/runnables" +type RetrievalChainInput = { + chat_history: string + question: string +} + +const formatChatHistoryAsString = (history: BaseMessage[]) => { + return history + .map((message) => `${message._getType()}: ${message.content}`) + .join("\n") +} + +export const formatDocs = (docs: Document[]) => { + return docs + .map((doc, i) => `${doc.pageContent}`) + .join("\n") +} + +const serializeHistory = (input: any) => { + const chatHistory = input.chat_history || [] + const convertedChatHistory = [] + for (const message of chatHistory) { + if (message.human !== undefined) { + convertedChatHistory.push(new HumanMessage({ content: message.human })) + } + if (message["ai"] !== undefined) { + convertedChatHistory.push(new AIMessage({ content: message.ai })) + } + } + return convertedChatHistory +} + +const createRetrieverChain = ( + llm: BaseLanguageModel, + retriever: Runnable, + question_template: string +) => { + const CONDENSE_QUESTION_PROMPT = + PromptTemplate.fromTemplate(question_template) + const condenseQuestionChain = RunnableSequence.from([ + CONDENSE_QUESTION_PROMPT, + llm, + new StringOutputParser() + ]).withConfig({ + runName: "CondenseQuestion" + }) + const hasHistoryCheckFn = RunnableLambda.from( + (input: RetrievalChainInput) => input.chat_history.length > 0 + ).withConfig({ runName: "HasChatHistoryCheck" }) + const conversationChain = condenseQuestionChain.pipe(retriever).withConfig({ + runName: "RetrievalChainWithHistory" + }) + const basicRetrievalChain = RunnableLambda.from( + (input: RetrievalChainInput) => input.question + ) + .withConfig({ + runName: "Itemgetter:question" + }) + .pipe(retriever) + .withConfig({ runName: "RetrievalChainWithNoHistory" }) + + return RunnableBranch.from([ + [hasHistoryCheckFn, conversationChain], + basicRetrievalChain + ]).withConfig({ + runName: "FindDocs" + }) +} + +export const createChatWithXChain = ({ + llm, + question_template, + question_llm, + retriever, + response_template +}: { + llm: BaseLanguageModel + question_llm: BaseLanguageModel + retriever: Runnable + question_template: string + response_template: string +}) => { + const retrieverChain = createRetrieverChain( + question_llm, + retriever, + question_template + ) + const context = RunnableMap.from({ + context: RunnableSequence.from([ + ({ question, chat_history }) => { + 
return { + question: question, + chat_history: formatChatHistoryAsString(chat_history) + } + }, + retrieverChain, + RunnableLambda.from(formatDocs).withConfig({ + runName: "FormatDocumentChunks" + }) + ]), + question: RunnableLambda.from( + (input: RetrievalChainInput) => input.question + ).withConfig({ + runName: "Itemgetter:question" + }), + chat_history: RunnableLambda.from( + (input: RetrievalChainInput) => input.chat_history + ).withConfig({ + runName: "Itemgetter:chat_history" + }) + }).withConfig({ tags: ["RetrieveDocs"] }) + const prompt = ChatPromptTemplate.fromMessages([ + ["system", response_template], + new MessagesPlaceholder("chat_history"), + ["human", "{question}"] + ]) + + const responseSynthesizerChain = RunnableSequence.from([ + prompt, + llm, + new StringOutputParser() + ]).withConfig({ + tags: ["GenerateResponse"] + }) + return RunnableSequence.from([ + { + question: RunnableLambda.from( + (input: RetrievalChainInput) => input.question + ).withConfig({ + runName: "Itemgetter:question" + }), + chat_history: RunnableLambda.from(serializeHistory).withConfig({ + runName: "SerializeHistory", + }) + }, + context, + responseSynthesizerChain + ]) +} diff --git a/src/components/Option/Knowledge/SelectedKnwledge.tsx b/src/components/Option/Knowledge/SelectedKnwledge.tsx index 26ab68f..7cdc0ca 100644 --- a/src/components/Option/Knowledge/SelectedKnwledge.tsx +++ b/src/components/Option/Knowledge/SelectedKnwledge.tsx @@ -12,7 +12,7 @@ export const SelectedKnowledge = ({ knowledge, onClose }: Props) => {
-

+

{knowledge.title}

@@ -20,7 +20,7 @@ export const SelectedKnowledge = ({ knowledge, onClose }: Props) => { {knowledge.source.map((source, index) => (
+ className="inline-flex gap-2 text-xs border rounded-md p-1 dark:border-gray-600 dark:text-gray-100"> {source.filename}
diff --git a/src/components/Option/Playground/PlaygroundForm.tsx b/src/components/Option/Playground/PlaygroundForm.tsx index a387ad6..4c15e6d 100644 --- a/src/components/Option/Playground/PlaygroundForm.tsx +++ b/src/components/Option/Playground/PlaygroundForm.tsx @@ -158,14 +158,7 @@ export const PlaygroundForm = ({ dropedFile }: Props) => { } return (
- {selectedKnowledge && ( - { - setSelectedKnowledge(null) - }} - knowledge={selectedKnowledge} - /> - )} +
{ + console.log("Creating new vector", vector) + this.db.set({ [id]: { id, vectors: vector } }, () => { if (chrome.runtime.lastError) { reject(chrome.runtime.lastError) } else { diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx index 7d67415..760739a 100644 --- a/src/hooks/useMessageOption.tsx +++ b/src/hooks/useMessageOption.tsx @@ -1,8 +1,10 @@ import React from "react" import { cleanUrl } from "~/libs/clean-url" import { + defaultEmbeddingModelForRag, geWebSearchFollowUpPrompt, getOllamaURL, + promptForRag, systemPromptForNonRagOption } from "~/services/ollama" import { type ChatHistory, type Message } from "~/store/option" @@ -23,13 +25,16 @@ import { generateHistory } from "@/utils/generate-history" import { useTranslation } from "react-i18next" import { saveMessageOnError, saveMessageOnSuccess } from "./chat-helper" import { usePageAssist } from "@/context" +import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama" +import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore" +import { formatDocs } from "@/chain/chat-with-x" export const useMessageOption = () => { const { controller: abortController, setController: setAbortController, messages, - setMessages, + setMessages } = usePageAssist() const { history, @@ -502,6 +507,213 @@ export const useMessageOption = () => { } } + const ragMode = async ( + message: string, + image: string, + isRegenerate: boolean, + messages: Message[], + history: ChatHistory, + signal: AbortSignal + ) => { + const url = await getOllamaURL() + + const ollama = new ChatOllama({ + model: selectedModel!, + baseUrl: cleanUrl(url) + }) + + let newMessage: Message[] = [] + let generateMessageId = generateID() + + if (!isRegenerate) { + newMessage = [ + ...messages, + { + isBot: false, + name: "You", + message, + sources: [], + images: [] + }, + { + isBot: true, + name: selectedModel, + message: "▋", + sources: [], + id: generateMessageId + } + ] + } else { + newMessage = [ + ...messages, + { + isBot: true, + name: selectedModel, + message: "▋", + sources: [], + id: generateMessageId + } + ] + } + setMessages(newMessage) + let fullText = "" + let contentToSave = "" + + const embeddingModle = await defaultEmbeddingModelForRag() + const ollamaUrl = await getOllamaURL() + const ollamaEmbedding = new OllamaEmbeddings({ + model: embeddingModle || selectedModel, + baseUrl: cleanUrl(ollamaUrl) + }) + + let vectorstore = await PageAssistVectorStore.fromExistingIndex( + ollamaEmbedding, + { + file_id: null, + knownledge_id: selectedKnowledge.id + } + ) + + try { + let query = message + const { ragPrompt: systemPrompt, ragQuestionPrompt: questionPrompt } = + await promptForRag() + if (newMessage.length > 2) { + const lastTenMessages = newMessage.slice(-10) + lastTenMessages.pop() + const chat_history = lastTenMessages + .map((message) => { + return `${message.isBot ? 
"Assistant: " : "Human: "}${message.message}` + }) + .join("\n") + const promptForQuestion = questionPrompt + .replaceAll("{chat_history}", chat_history) + .replaceAll("{question}", message) + const questionOllama = new ChatOllama({ + model: selectedModel!, + baseUrl: cleanUrl(url) + }) + const response = await questionOllama.invoke(promptForQuestion) + query = response.content.toString() + } + + const docs = await vectorstore.similaritySearch(query, 4) + const context = formatDocs(docs) + const source = docs.map((doc) => { + return { + name: doc?.metadata?.source || "untitled", + type: doc?.metadata?.type || "unknown", + url: "" + } + }) + message = message.trim().replaceAll("\n", " ") + + let humanMessage = new HumanMessage({ + content: [ + { + text: systemPrompt + .replace("{context}", context) + .replace("{question}", message), + type: "text" + } + ] + }) + + const applicationChatHistory = generateHistory(history) + + const chunks = await ollama.stream( + [...applicationChatHistory, humanMessage], + { + signal: signal + } + ) + let count = 0 + for await (const chunk of chunks) { + contentToSave += chunk.content + fullText += chunk.content + if (count === 0) { + setIsProcessing(true) + } + setMessages((prev) => { + return prev.map((message) => { + if (message.id === generateMessageId) { + return { + ...message, + message: fullText.slice(0, -1) + "▋" + } + } + return message + }) + }) + count++ + } + // update the message with the full text + setMessages((prev) => { + return prev.map((message) => { + if (message.id === generateMessageId) { + return { + ...message, + message: fullText, + sources: source + } + } + return message + }) + }) + + setHistory([ + ...history, + { + role: "user", + content: message, + image + }, + { + role: "assistant", + content: fullText + } + ]) + + await saveMessageOnSuccess({ + historyId, + setHistoryId, + isRegenerate, + selectedModel: selectedModel, + message, + image, + fullText, + source + }) + + setIsProcessing(false) + setStreaming(false) + } catch (e) { + const errorSave = await saveMessageOnError({ + e, + botMessage: fullText, + history, + historyId, + image, + selectedModel, + setHistory, + setHistoryId, + userMessage: message, + isRegenerating: isRegenerate + }) + + if (!errorSave) { + notification.error({ + message: t("error"), + description: e?.message || t("somethingWentWrong") + }) + } + setIsProcessing(false) + setStreaming(false) + } finally { + setAbortController(null) + } + } + const onSubmit = async ({ message, image, @@ -527,8 +739,8 @@ export const useMessageOption = () => { setAbortController(controller) signal = controller.signal } - if (webSearch) { - await searchChatMode( + if (selectedKnowledge) { + await ragMode( message, image, isRegenerate, @@ -537,14 +749,25 @@ export const useMessageOption = () => { signal ) } else { - await normalChatMode( - message, - image, - isRegenerate, - chatHistory || messages, - memory || history, - signal - ) + if (webSearch) { + await searchChatMode( + message, + image, + isRegenerate, + chatHistory || messages, + memory || history, + signal + ) + } else { + await normalChatMode( + message, + image, + isRegenerate, + chatHistory || messages, + memory || history, + signal + ) + } } } diff --git a/src/libs/PageAssistVectorStore.ts b/src/libs/PageAssistVectorStore.ts index f3ae7b4..ca0bbb2 100644 --- a/src/libs/PageAssistVectorStore.ts +++ b/src/libs/PageAssistVectorStore.ts @@ -3,6 +3,7 @@ import { VectorStore } from "@langchain/core/vectorstores" import type { EmbeddingsInterface } from 
"@langchain/core/embeddings" import { Document } from "@langchain/core/documents" import { getVector, insertVector } from "@/db/vector" +import { cp } from "fs" /** * Interface representing a vector in memory. It includes the content @@ -116,8 +117,10 @@ export class PageAssistVectorStore extends VectorStore { }) return filter(doc) } - const pgVector = await getVector(`vector:${this.knownledge_id}`) - const filteredMemoryVectors = pgVector.vectors.filter(filterFunction) + const data = await getVector(`vector:${this.knownledge_id}`) + const pgVector = [...data.vectors] + const filteredMemoryVectors = pgVector.filter(filterFunction) + console.log(filteredMemoryVectors) const searches = filteredMemoryVectors .map((vector, index) => ({ similarity: this.similarity(query, vector.embedding), @@ -125,7 +128,7 @@ export class PageAssistVectorStore extends VectorStore { })) .sort((a, b) => (a.similarity > b.similarity ? -1 : 0)) .slice(0, k) - + console.log(searches) const result: [Document, number][] = searches.map((search) => [ new Document({ metadata: filteredMemoryVectors[search.index].metadata, @@ -133,7 +136,7 @@ export class PageAssistVectorStore extends VectorStore { }), search.similarity ]) - + console.log(result) return result } diff --git a/src/loader/pdf-url.ts b/src/loader/pdf-url.ts index 3124085..73edf63 100644 --- a/src/loader/pdf-url.ts +++ b/src/loader/pdf-url.ts @@ -40,7 +40,7 @@ export class PageAssistPDFUrlLoader .trim() documents.push({ pageContent: text, - metadata: { source: this.name, page: i } + metadata: { source: this.name, page: i, type: "pdf" } }) }