diff --git a/src/assets/locale/en/settings.json b/src/assets/locale/en/settings.json
index 677fe2b..e7e5fb3 100644
--- a/src/assets/locale/en/settings.json
+++ b/src/assets/locale/en/settings.json
@@ -51,6 +51,23 @@
       "success": "Import Success",
       "error": "Import Error"
     }
+  },
+  "tts": {
+    "heading": "Text-to-Speech Settings",
+    "ttsEnabled": {
+      "label": "Enable Text-to-Speech"
+    },
+    "ttsProvider": {
+      "label": "Text-to-Speech Provider",
+      "placeholder": "Select a provider"
+    },
+    "ttsVoice": {
+      "label": "Text-to-Speech Voice",
+      "placeholder": "Select a voice"
+    },
+    "ssmlEnabled": {
+      "label": "Enable SSML (Speech Synthesis Markup Language)"
+    }
   }
 },
 "manageModels": {
diff --git a/src/assets/locale/ja-JP/settings.json b/src/assets/locale/ja-JP/settings.json
index ee3e0b6..dc4d508 100644
--- a/src/assets/locale/ja-JP/settings.json
+++ b/src/assets/locale/ja-JP/settings.json
@@ -54,6 +54,23 @@
       "success": "インポート成功",
       "error": "インポートエラー"
     }
+  },
+  "tts": {
+    "heading": "テキスト読み上げ設定",
+    "ttsEnabled": {
+      "label": "テキスト読み上げを有効にする"
+    },
+    "ttsProvider": {
+      "label": "テキスト読み上げプロバイダー",
+      "placeholder": "プロバイダーを選択"
+    },
+    "ttsVoice": {
+      "label": "テキスト読み上げの音声",
+      "placeholder": "音声を選択"
+    },
+    "ssmlEnabled": {
+      "label": "SSML (Speech Synthesis Markup Language) を有効にする"
+    }
   }
 },
 "manageModels": {
diff --git a/src/assets/locale/ml/settings.json b/src/assets/locale/ml/settings.json
index 3f77730..1ed5902 100644
--- a/src/assets/locale/ml/settings.json
+++ b/src/assets/locale/ml/settings.json
@@ -54,6 +54,23 @@
       "success": "ഇമ്പോർട്ട് വിജയകരമായി",
       "error": "ഇമ്പോർട്ട് പരാജയപ്പെട്ടു"
     }
+  },
+  "tts": {
+    "heading": "ടെക്സ്റ്റ്-ടു-സ്പീച്ച് ക്രമീകരണങ്ങൾ",
+    "ttsEnabled": {
+      "label": "ടെക്സ്റ്റ്-ടു-സ്പീച്ച് പ്രവർത്തനക്ഷമമാക്കുക"
+    },
+    "ttsProvider": {
+      "label": "ടെക്സ്റ്റ്-ടു-സ്പീച്ച് പ്രോവൈഡർ",
+      "placeholder": "ഒരു പ്രോവൈഡർ തിരഞ്ഞെടുക്കുക"
+    },
+    "ttsVoice": {
+      "label": "ടെക്സ്റ്റ്-ടു-സ്പീച്ച് വോയ്സ്",
+      "placeholder": "ഒരു വോയ്സ് തിരഞ്ഞെടുക്കുക"
+    },
+    "ssmlEnabled": {
+      "label": "SSML (സ്പീച്ച് സിന്തസിസ് മാർക്കപ്പ് ലാംഗ്വേജ്) പ്രവർത്തനക്ഷമമാക്കുക"
+    }
   }
 },
 "manageModels": {
diff --git a/src/assets/locale/zh/settings.json b/src/assets/locale/zh/settings.json
index 4a7cdda..b83adea 100644
--- a/src/assets/locale/zh/settings.json
+++ b/src/assets/locale/zh/settings.json
@@ -54,6 +54,23 @@
       "success": "导入成功",
       "error": "导入错误"
     }
+  },
+  "tts": {
+    "heading": "文本转语音设置",
+    "ttsEnabled": {
+      "label": "启用文本转语音"
+    },
+    "ttsProvider": {
+      "label": "文本转语音提供商",
+      "placeholder": "选择一个提供商"
+    },
+    "ttsVoice": {
+      "label": "文本转语音的语音",
+      "placeholder": "选择一种语音"
+    },
+    "ssmlEnabled": {
+      "label": "启用 SSML（语音合成标记语言）"
+    }
   }
 },
 "manageModels": {
diff --git a/src/components/Common/PageAssistProvider.tsx b/src/components/Common/PageAssistProvider.tsx
index fb75f83..cd5cac8 100644
--- a/src/components/Common/PageAssistProvider.tsx
+++ b/src/components/Common/PageAssistProvider.tsx
@@ -11,6 +11,8 @@ export const PageAssistProvider = ({
   const [controller, setController] = React.useState<AbortController | null>(
     null
   )
+  const [embeddingController, setEmbeddingController] =
+    React.useState<AbortController | null>(null)
 
   return (
     <PageAssistContext.Provider
       {children}
diff --git a/src/components/Sidepanel/Chat/form.tsx b/src/components/Sidepanel/Chat/form.tsx
index 70f5c34..b8b3fc9 100644
--- a/src/components/Sidepanel/Chat/form.tsx
+++ b/src/components/Sidepanel/Chat/form.tsx
@@ -8,7 +8,7 @@ import { Checkbox, Dropdown, Image, Tooltip } from "antd"
 import { useSpeechRecognition } from "~/hooks/useSpeechRecognition"
 import { useWebUI } from "~/store/webui"
 import { defaultEmbeddingModelForRag } from "~/services/ollama"
"~/services/ollama" -import { ImageIcon, MicIcon, X } from "lucide-react" +import { ImageIcon, MicIcon, StopCircleIcon, X } from "lucide-react" import { useTranslation } from "react-i18next" type Props = { @@ -56,8 +56,13 @@ export const SidepanelForm = ({ dropedFile }: Props) => { useDynamicTextareaSize(textareaRef, form.values.message, 120) - const { onSubmit, selectedModel, chatMode, speechToTextLanguage } = - useMessage() + const { + onSubmit, + selectedModel, + chatMode, + speechToTextLanguage, + stopStreamingRequest + } = useMessage() const { isListening, start, stop, transcript } = useSpeechRecognition() React.useEffect(() => { @@ -217,59 +222,70 @@ export const SidepanelForm = ({ dropedFile }: Props) => { - - - - } - menu={{ - items: [ - { - key: 1, - label: ( - - setSendWhenEnter(e.target.checked) - }> - {t("sendWhenEnter")} - - ) - } - ] - }}> -
- {sendWhenEnter ? ( + {!isSending ? ( + - - + className="w-5 h-5"> + - ) : null} - {t("common:submit")} -
-
+ } + menu={{ + items: [ + { + key: 1, + label: ( + + setSendWhenEnter(e.target.checked) + }> + {t("sendWhenEnter")} + + ) + } + ] + }}> +
+ {sendWhenEnter ? ( + + + + + ) : null} + {t("common:submit")} +
+              </Dropdown.Button>
+            ) : (
+              <button
+                type="button"
+                onClick={stopStreamingRequest}>
+                <StopCircleIcon className="w-6 h-6" />
+              </button>
+            )}
diff --git a/src/context/index.tsx b/src/context/index.tsx
index 7e91bae..574dc59 100644
--- a/src/context/index.tsx
+++ b/src/context/index.tsx
@@ -7,6 +7,9 @@ interface PageAssistContext {
 
   controller: AbortController | null
   setController: Dispatch<SetStateAction<AbortController | null>>
+
+  embeddingController: AbortController | null
+  setEmbeddingController: Dispatch<SetStateAction<AbortController | null>>
 }
 
 export const PageAssistContext = createContext<PageAssistContext>({
@@ -14,7 +17,10 @@ export const PageAssistContext = createContext<PageAssistContext>({
   setMessages: () => {},
 
   controller: null,
-  setController: () => {}
+  setController: () => {},
+
+  embeddingController: null,
+  setEmbeddingController: () => {}
 })
 
 export const usePageAssist = () => {
diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx
index c501185..89f5b27 100644
--- a/src/hooks/useMessage.tsx
+++ b/src/hooks/useMessage.tsx
@@ -6,24 +6,36 @@
   promptForRag,
   systemPromptForNonRag
 } from "~/services/ollama"
-import { useStoreMessage, type Message } from "~/store"
+import { type Message } from "~/store/option"
+import { useStoreMessage } from "~/store"
 import { ChatOllama } from "@langchain/community/chat_models/ollama"
 import { HumanMessage, SystemMessage } from "@langchain/core/messages"
 import { getDataFromCurrentTab } from "~/libs/get-html"
 import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
-import {
-  createChatWithWebsiteChain,
-  groupMessagesByConversation
-} from "~/chain/chat-with-website"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
 import { memoryEmbedding } from "@/utils/memory-embeddings"
+import { ChatHistory } from "@/store/option"
+import { generateID } from "@/db"
+import { saveMessageOnError, saveMessageOnSuccess } from "./chat-helper"
+import { notification } from "antd"
+import { useTranslation } from "react-i18next"
+import { usePageAssist } from "@/context"
+import { formatDocs } from "@/chain/chat-with-x"
+import { OllamaEmbeddingsPageAssist } from "@/models/OllamaEmbedding"
 
 export const useMessage = () => {
   const {
-    history,
+    controller: abortController,
+    setController: setAbortController,
     messages,
-    setHistory,
     setMessages,
+    embeddingController,
+    setEmbeddingController
+  } = usePageAssist()
+  const { t } = useTranslation("option")
+  const {
+    history,
+    setHistory,
     setStreaming,
     streaming,
     setIsFirstMessage,
@@ -45,8 +57,6 @@ export const useMessage = () => {
     setCurrentURL
   } = useStoreMessage()
 
-  const abortControllerRef = React.useRef<AbortController | null>(null)
-
   const [keepTrackOfEmbedding, setKeepTrackOfEmbedding] = React.useState<{
     [key: string]: MemoryVectorStore
   }>({})
@@ -62,57 +72,87 @@ export const useMessage = () => {
     setStreaming(false)
   }
 
-  const chatWithWebsiteMode = async (message: string) => {
-    try {
-      let isAlreadyExistEmbedding: MemoryVectorStore
-      let embedURL: string, embedHTML: string, embedType: string
-      let embedPDF: { content: string; page: number }[] = []
+  const chatWithWebsiteMode = async (
+    message: string,
+    image: string,
+    isRegenerate: boolean,
+    messages: Message[],
+    history: ChatHistory,
+    signal: AbortSignal,
+    embeddingSignal: AbortSignal
+  ) => {
+    const url = await getOllamaURL()
 
-      if (messages.length === 0) {
-        const { content: html, url, type, pdf } = await getDataFromCurrentTab()
-        embedHTML = html
-        embedURL = url
-        embedType = type
-        embedPDF = pdf
-        setCurrentURL(url)
-        isAlreadyExistEmbedding = keepTrackOfEmbedding[currentURL]
-      } else {
-        isAlreadyExistEmbedding = keepTrackOfEmbedding[currentURL]
-        embedURL = currentURL
-      }
-      let newMessage: Message[] = [
+    const ollama = new ChatOllama({
+      model: selectedModel!,
+      baseUrl: cleanUrl(url)
+    })
+
+    let newMessage: Message[] = []
+    let generateMessageId = generateID()
+
+    if (!isRegenerate) {
+      newMessage = [
         ...messages,
         {
           isBot: false,
           name: "You",
           message,
-          sources: []
+          sources: [],
+          images: []
         },
         {
           isBot: true,
           name: selectedModel,
           message: "▋",
-          sources: []
+          sources: [],
+          id: generateMessageId
         }
       ]
+    } else {
+      newMessage = [
+        ...messages,
+        {
+          isBot: true,
+          name: selectedModel,
+          message: "▋",
+          sources: [],
+          id: generateMessageId
+        }
+      ]
+    }
+    setMessages(newMessage)
+    let fullText = ""
+    let contentToSave = ""
+    let isAlreadyExistEmbedding: MemoryVectorStore
+    let embedURL: string, embedHTML: string, embedType: string
+    let embedPDF: { content: string; page: number }[] = []
 
-      const appendingIndex = newMessage.length - 1
-      setMessages(newMessage)
-      const ollamaUrl = await getOllamaURL()
-      const embeddingModle = await defaultEmbeddingModelForRag()
+    if (messages.length === 0) {
+      const { content: html, url, type, pdf } = await getDataFromCurrentTab()
+      embedHTML = html
+      embedURL = url
+      embedType = type
+      embedPDF = pdf
+      setCurrentURL(url)
+      isAlreadyExistEmbedding = keepTrackOfEmbedding[currentURL]
+    } else {
+      isAlreadyExistEmbedding = keepTrackOfEmbedding[currentURL]
+      embedURL = currentURL
+    }
 
-      const ollamaEmbedding = new OllamaEmbeddings({
-        model: embeddingModle || selectedModel,
-        baseUrl: cleanUrl(ollamaUrl)
-      })
+    setMessages(newMessage)
+    const ollamaUrl = await getOllamaURL()
+    const embeddingModel = await defaultEmbeddingModelForRag()
 
-      const ollamaChat = new ChatOllama({
-        model: selectedModel,
-        baseUrl: cleanUrl(ollamaUrl)
-      })
-
-      let vectorstore: MemoryVectorStore
+    const ollamaEmbedding = new OllamaEmbeddingsPageAssist({
+      model: embeddingModel || selectedModel,
+      baseUrl: cleanUrl(ollamaUrl),
+      signal: embeddingSignal
+    })
 
+    let vectorstore: MemoryVectorStore
+    try {
       if (isAlreadyExistEmbedding) {
         vectorstore = isAlreadyExistEmbedding
       } else {
@@ -127,109 +167,206 @@
           url: embedURL
         })
       }
-
+      let query = message
       const {
         ragPrompt: systemPrompt,
         ragQuestionPrompt: questionPrompt
       } = await promptForRag()
-
-      const sanitizedQuestion = message.trim().replaceAll("\n", " ")
-
-      const chain = createChatWithWebsiteChain({
-        llm: ollamaChat,
-        question_llm: ollamaChat,
-        question_template: questionPrompt,
-        response_template: systemPrompt,
-        retriever: vectorstore.asRetriever()
-      })
-
-      const chunks = await chain.stream({
-        question: sanitizedQuestion,
-        chat_history: groupMessagesByConversation(history)
-      })
-      let count = 0
-      for await (const chunk of chunks) {
-        if (count === 0) {
-          setIsProcessing(true)
-          newMessage[appendingIndex].message = chunk + "▋"
-          setMessages(newMessage)
-        } else {
-          newMessage[appendingIndex].message =
-            newMessage[appendingIndex].message.slice(0, -1) + chunk + "▋"
-          setMessages(newMessage)
-        }
-
-        count++
+      if (newMessage.length > 2) {
+        const lastTenMessages = newMessage.slice(-10)
+        lastTenMessages.pop()
+        const chat_history = lastTenMessages
+          .map((message) => {
+            return `${message.isBot ? 
"Assistant: " : "Human: "}${message.message}` + }) + .join("\n") + const promptForQuestion = questionPrompt + .replaceAll("{chat_history}", chat_history) + .replaceAll("{question}", message) + const questionOllama = new ChatOllama({ + model: selectedModel!, + baseUrl: cleanUrl(url) + }) + const response = await questionOllama.invoke(promptForQuestion) + query = response.content.toString() } - newMessage[appendingIndex].message = newMessage[ - appendingIndex - ].message.slice(0, -1) + const docs = await vectorstore.similaritySearch(query, 4) + const context = formatDocs(docs) + const source = docs.map((doc) => { + return { + ...doc, + name: doc?.metadata?.source || "untitled", + type: doc?.metadata?.type || "unknown", + mode: "chat", + url: "" + } + }) + message = message.trim().replaceAll("\n", " ") + + let humanMessage = new HumanMessage({ + content: [ + { + text: systemPrompt + .replace("{context}", context) + .replace("{question}", message), + type: "text" + } + ] + }) + + const applicationChatHistory = generateHistory(history) + + const chunks = await ollama.stream( + [...applicationChatHistory, humanMessage], + { + signal: signal + } + ) + let count = 0 + for await (const chunk of chunks) { + contentToSave += chunk.content + fullText += chunk.content + if (count === 0) { + setIsProcessing(true) + } + setMessages((prev) => { + return prev.map((message) => { + if (message.id === generateMessageId) { + return { + ...message, + message: fullText.slice(0, -1) + "▋" + } + } + return message + }) + }) + count++ + } + // update the message with the full text + setMessages((prev) => { + return prev.map((message) => { + if (message.id === generateMessageId) { + return { + ...message, + message: fullText, + sources: source + } + } + return message + }) + }) setHistory([ ...history, { role: "user", - content: message + content: message, + image }, { role: "assistant", - content: newMessage[appendingIndex].message + content: fullText } ]) - setIsProcessing(false) - } catch (e) { + await saveMessageOnSuccess({ + historyId, + setHistoryId, + isRegenerate, + selectedModel: selectedModel, + message, + image, + fullText, + source + }) + setIsProcessing(false) setStreaming(false) + } catch (e) { + const errorSave = await saveMessageOnError({ + e, + botMessage: fullText, + history, + historyId, + image, + selectedModel, + setHistory, + setHistoryId, + userMessage: message, + isRegenerating: isRegenerate + }) - setMessages([ - ...messages, - { - isBot: true, - name: selectedModel, - message: `Error in chat with website mode. 
Check out the following logs:
-
-~~~
-${e?.message}
- ~~~
- `,
-          sources: []
-        }
-      ])
+      if (!errorSave) {
+        notification.error({
+          message: t("error"),
+          description: e?.message || t("somethingWentWrong")
+        })
+      }
+      setIsProcessing(false)
+      setStreaming(false)
+      setIsEmbedding(false)
+    } finally {
+      setAbortController(null)
+      setEmbeddingController(null)
+    }
   }
 
-  const normalChatMode = async (message: string, image: string) => {
+  const normalChatMode = async (
+    message: string,
+    image: string,
+    isRegenerate: boolean,
+    messages: Message[],
+    history: ChatHistory,
+    signal: AbortSignal
+  ) => {
     const url = await getOllamaURL()
 
     if (image.length > 0) {
       image = `data:image/jpeg;base64,${image.split(",")[1]}`
     }
 
-    abortControllerRef.current = new AbortController()
-
     const ollama = new ChatOllama({
-      model: selectedModel,
+      model: selectedModel!,
       baseUrl: cleanUrl(url)
     })
 
-    let newMessage: Message[] = [
-      ...messages,
-      {
-        isBot: false,
-        name: "You",
-        message,
-        sources: [],
-        images: [image]
-      },
-      {
-        isBot: true,
-        name: selectedModel,
-        message: "▋",
-        sources: []
-      }
-    ]
+    let newMessage: Message[] = []
+    let generateMessageId = generateID()
 
-    const appendingIndex = newMessage.length - 1
+    if (!isRegenerate) {
+      newMessage = [
+        ...messages,
+        {
+          isBot: false,
+          name: "You",
+          message,
+          sources: [],
+          images: [image]
+        },
+        {
+          isBot: true,
+          name: selectedModel,
+          message: "▋",
+          sources: [],
+          id: generateMessageId
+        }
+      ]
+    } else {
+      newMessage = [
+        ...messages,
+        {
+          isBot: true,
+          name: selectedModel,
+          message: "▋",
+          sources: [],
+          id: generateMessageId
+        }
+      ]
+    }
     setMessages(newMessage)
+    let fullText = ""
+    let contentToSave = ""
 
     try {
       const prompt = await systemPromptForNonRag()
@@ -277,29 +414,41 @@
       const chunks = await ollama.stream(
         [...applicationChatHistory, humanMessage],
         {
-          signal: abortControllerRef.current.signal
+          signal: signal
         }
       )
       let count = 0
       for await (const chunk of chunks) {
+        contentToSave += chunk.content
+        fullText += chunk.content
        if (count === 0) {
           setIsProcessing(true)
-          newMessage[appendingIndex].message = chunk.content + "▋"
-          setMessages(newMessage)
-        } else {
-          newMessage[appendingIndex].message =
-            newMessage[appendingIndex].message.slice(0, -1) +
-            chunk.content +
-            "▋"
-          setMessages(newMessage)
         }
-
+        setMessages((prev) => {
+          return prev.map((message) => {
+            if (message.id === generateMessageId) {
+              return {
+                ...message,
+                message: fullText.slice(0, -1) + "▋"
+              }
+            }
+            return message
+          })
+        })
         count++
       }
 
-      newMessage[appendingIndex].message = newMessage[
-        appendingIndex
-      ].message.slice(0, -1)
+      setMessages((prev) => {
+        return prev.map((message) => {
+          if (message.id === generateMessageId) {
+            return {
+              ...message,
+              message: fullText.slice(0, -1)
+            }
+          }
+          return message
+        })
+      })
 
       setHistory([
         ...history,
@@ -310,28 +459,49 @@
         },
         {
           role: "assistant",
-          content: newMessage[appendingIndex].message
+          content: fullText
         }
       ])
 
-      setIsProcessing(false)
-    } catch (e) {
+      await saveMessageOnSuccess({
+        historyId,
+        setHistoryId,
+        isRegenerate,
+        selectedModel: selectedModel,
+        message,
+        image,
+        fullText,
+        source: []
+      })
+
+      setIsProcessing(false)
       setStreaming(false)
+    } catch (e) {
+      const errorSave = await saveMessageOnError({
+        e,
+        botMessage: fullText,
+        history,
+        historyId,
+        image,
+        selectedModel,
+        setHistory,
+        setHistoryId,
+        userMessage: message,
+        isRegenerating: isRegenerate
+      })
 
-      setMessages([
-        ...messages,
-        {
-          isBot: true,
-          name: selectedModel,
-          message: `Something went wrong. Check out the following logs:
-        \`\`\`
-        ${e?.message}
-        \`\`\`
-        `,
-          sources: []
-        }
-      ])
+      if (!errorSave) {
+        notification.error({
+          message: t("error"),
+          description: e?.message || t("somethingWentWrong")
+        })
+      }
+      setIsProcessing(false)
+      setStreaming(false)
+    } finally {
+      setAbortController(null)
+    }
   }
 
@@ -342,20 +512,40 @@
     message: string
     image: string
   }) => {
+    const newController = new AbortController()
+    let signal = newController.signal
+    setAbortController(newController)
+
     if (chatMode === "normal") {
-      await normalChatMode(message, image)
+      await normalChatMode(message, image, false, messages, history, signal)
     } else {
-      await chatWithWebsiteMode(message)
+      const newEmbeddingController = new AbortController()
+      let embeddingSignal = newEmbeddingController.signal
+      setEmbeddingController(newEmbeddingController)
+      await chatWithWebsiteMode(
+        message,
+        image,
+        false,
+        messages,
+        history,
+        signal,
+        embeddingSignal
+      )
     }
   }
 
   const stopStreamingRequest = () => {
-    if (abortControllerRef.current) {
-      abortControllerRef.current.abort()
-      abortControllerRef.current = null
+    if (isEmbedding) {
+      if (embeddingController) {
+        embeddingController.abort()
+        setEmbeddingController(null)
+      }
+    }
+    if (abortController) {
+      abortController.abort()
+      setAbortController(null)
     }
   }
-
   return {
     messages,
     setMessages,
diff --git a/src/models/OllamaEmbedding.ts b/src/models/OllamaEmbedding.ts
new file mode 100644
index 0000000..2d57ef2
--- /dev/null
+++ b/src/models/OllamaEmbedding.ts
@@ -0,0 +1,255 @@
+import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"
+import type { StringWithAutocomplete } from "@langchain/core/utils/types"
+
+export interface OllamaInput {
+  embeddingOnly?: boolean
+  f16KV?: boolean
+  frequencyPenalty?: number
+  headers?: Record<string, string>
+  keepAlive?: string
+  logitsAll?: boolean
+  lowVram?: boolean
+  mainGpu?: number
+  model?: string
+  baseUrl?: string
+  mirostat?: number
+  mirostatEta?: number
+  mirostatTau?: number
+  numBatch?: number
+  numCtx?: number
+  numGpu?: number
+  numGqa?: number
+  numKeep?: number
+  numPredict?: number
+  numThread?: number
+  penalizeNewline?: boolean
+  presencePenalty?: number
+  repeatLastN?: number
+  repeatPenalty?: number
+  ropeFrequencyBase?: number
+  ropeFrequencyScale?: number
+  temperature?: number
+  stop?: string[]
+  tfsZ?: number
+  topK?: number
+  topP?: number
+  typicalP?: number
+  useMLock?: boolean
+  useMMap?: boolean
+  vocabOnly?: boolean
+  format?: StringWithAutocomplete<"json">
+}
+export interface OllamaRequestParams {
+  model: string
+  format?: StringWithAutocomplete<"json">
+  images?: string[]
+  options: {
+    embedding_only?: boolean
+    f16_kv?: boolean
+    frequency_penalty?: number
+    logits_all?: boolean
+    low_vram?: boolean
+    main_gpu?: number
+    mirostat?: number
+    mirostat_eta?: number
+    mirostat_tau?: number
+    num_batch?: number
+    num_ctx?: number
+    num_gpu?: number
+    num_gqa?: number
+    num_keep?: number
+    num_thread?: number
+    num_predict?: number
+    penalize_newline?: boolean
+    presence_penalty?: number
+    repeat_last_n?: number
+    repeat_penalty?: number
+    rope_frequency_base?: number
+    rope_frequency_scale?: number
+    temperature?: number
+    stop?: string[]
+    tfs_z?: number
+    top_k?: number
+    top_p?: number
+    typical_p?: number
+    use_mlock?: boolean
+    use_mmap?: boolean
+    vocab_only?: boolean
+  }
+}
+
+type CamelCasedRequestOptions = Omit<
+  OllamaInput,
+  "baseUrl" | "model" | "format" | "headers"
+>
+
+/**
+ * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
+ * defines additional parameters specific to the OllamaEmbeddings class.
+ */
+interface OllamaEmbeddingsParams extends EmbeddingsParams {
+  /** The Ollama model to use, e.g: "llama2:13b" */
+  model?: string
+
+  /** Base URL of the Ollama server, defaults to "http://localhost:11434" */
+  baseUrl?: string
+
+  /** Extra headers to include in the Ollama API request */
+  headers?: Record<string, string>
+
+  /** Defaults to "5m" */
+  keepAlive?: string
+
+  /** Advanced Ollama API request parameters in camelCase, see
+   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+   * for details of the available parameters.
+   */
+  requestOptions?: CamelCasedRequestOptions
+
+  signal?: AbortSignal
+}
+
+export class OllamaEmbeddingsPageAssist extends Embeddings {
+  model = "llama2"
+
+  baseUrl = "http://localhost:11434"
+
+  headers?: Record<string, string>
+
+  keepAlive = "5m"
+
+  requestOptions?: OllamaRequestParams["options"]
+
+  signal?: AbortSignal
+
+  constructor(params?: OllamaEmbeddingsParams) {
+    super({ maxConcurrency: 1, ...params })
+
+    if (params?.model) {
+      this.model = params.model
+    }
+
+    if (params?.baseUrl) {
+      this.baseUrl = params.baseUrl
+    }
+
+    if (params?.headers) {
+      this.headers = params.headers
+    }
+
+    if (params?.keepAlive) {
+      this.keepAlive = params.keepAlive
+    }
+
+    if (params?.requestOptions) {
+      this.requestOptions = this._convertOptions(params.requestOptions)
+    }
+
+    if (params?.signal) {
+      this.signal = params.signal
+    }
+  }
+
+  /** convert camelCased Ollama request options like "useMMap" to
+   * the snake_cased equivalent which the ollama API actually uses.
+   * Used only for consistency with the llms/Ollama and chatModels/Ollama classes
+   */
+  _convertOptions(requestOptions: CamelCasedRequestOptions) {
+    const snakeCasedOptions: Record<string, unknown> = {}
+    const mapping: Record<keyof CamelCasedRequestOptions, string> = {
+      embeddingOnly: "embedding_only",
+      f16KV: "f16_kv",
+      frequencyPenalty: "frequency_penalty",
+      keepAlive: "keep_alive",
+      logitsAll: "logits_all",
+      lowVram: "low_vram",
+      mainGpu: "main_gpu",
+      mirostat: "mirostat",
+      mirostatEta: "mirostat_eta",
+      mirostatTau: "mirostat_tau",
+      numBatch: "num_batch",
+      numCtx: "num_ctx",
+      numGpu: "num_gpu",
+      numGqa: "num_gqa",
+      numKeep: "num_keep",
+      numPredict: "num_predict",
+      numThread: "num_thread",
+      penalizeNewline: "penalize_newline",
+      presencePenalty: "presence_penalty",
+      repeatLastN: "repeat_last_n",
+      repeatPenalty: "repeat_penalty",
+      ropeFrequencyBase: "rope_frequency_base",
+      ropeFrequencyScale: "rope_frequency_scale",
+      temperature: "temperature",
+      stop: "stop",
+      tfsZ: "tfs_z",
+      topK: "top_k",
+      topP: "top_p",
+      typicalP: "typical_p",
+      useMLock: "use_mlock",
+      useMMap: "use_mmap",
+      vocabOnly: "vocab_only"
+    }
+
+    for (const [key, value] of Object.entries(requestOptions)) {
+      const snakeCasedOption = mapping[key as keyof CamelCasedRequestOptions]
+      if (snakeCasedOption) {
+        snakeCasedOptions[snakeCasedOption] = value
+      }
+    }
+    return snakeCasedOptions
+  }
+
+  async _request(prompt: string): Promise<number[]> {
+    const { model, baseUrl, keepAlive, requestOptions } = this
+
+    let formattedBaseUrl = baseUrl
+    if (formattedBaseUrl.startsWith("http://localhost:")) {
+      // Node 18 has issues with resolving "localhost"
+      // See https://github.com/node-fetch/node-fetch/issues/1624
+      formattedBaseUrl = formattedBaseUrl.replace(
+        "http://localhost:",
+        "http://127.0.0.1:"
+      )
+    }
+
+    const response = await fetch(`${formattedBaseUrl}/api/embeddings`, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
"application/json", + ...this.headers + }, + body: JSON.stringify({ + prompt, + model, + keep_alive: keepAlive, + options: requestOptions + }), + signal: this.signal + }) + if (!response.ok) { + throw new Error( + `Request to Ollama server failed: ${response.status} ${response.statusText}` + ) + } + + const json = await response.json() + return json.embedding + } + + async _embed(texts: string[]): Promise { + const embeddings: number[][] = await Promise.all( + texts.map((text) => this.caller.call(() => this._request(text))) + ) + + return embeddings + } + + async embedDocuments(documents: string[]) { + return this._embed(documents) + } + + async embedQuery(document: string) { + return (await this.embedDocuments([document]))[0] + } +} diff --git a/src/public/_locales/en/messages.json b/src/public/_locales/en/messages.json index 77ece77..6e3f66b 100644 --- a/src/public/_locales/en/messages.json +++ b/src/public/_locales/en/messages.json @@ -6,6 +6,6 @@ "message": "Use your locally running AI models to assist you in your web browsing." }, "openSidePanelToChat": { - "message": "Open Side Panel to Chat" + "message": "Open Copilot to Chat" } } \ No newline at end of file diff --git a/src/public/_locales/ja/messages.json b/src/public/_locales/ja/messages.json index 569ec1a..ced8c60 100644 --- a/src/public/_locales/ja/messages.json +++ b/src/public/_locales/ja/messages.json @@ -1,11 +1,11 @@ { - "extName": { - "message": "Page Assist - ローカルAIモデル用のWeb UI" - }, - "extDescription": { - "message": "ローカルで実行中のAIモデルを使って、Webブラウジングをアシストします。" - }, - "openSidePanelToChat": { - "message": "サイドパネルを開いてチャット" - } - } \ No newline at end of file + "extName": { + "message": "Page Assist - ローカルAIモデル用のWeb UI" + }, + "extDescription": { + "message": "ローカルで実行中のAIモデルを使って、Webブラウジングをアシストします。" + }, + "openSidePanelToChat": { + "message": "チャットするためにCopilotを開く" + } +} \ No newline at end of file diff --git a/src/public/_locales/zh_CN/messages.json b/src/public/_locales/zh_CN/messages.json index 9b4fefe..49dadd1 100644 --- a/src/public/_locales/zh_CN/messages.json +++ b/src/public/_locales/zh_CN/messages.json @@ -6,6 +6,6 @@ "message": "使用本地运行的 AI 模型来辅助您的网络浏览。" }, "openSidePanelToChat": { - "message": "打开侧边栏进行聊天" + "message": "打开Copilot进行聊天" } -} \ No newline at end of file +} \ No newline at end of file diff --git a/src/utils/memory-embeddings.ts b/src/utils/memory-embeddings.ts index e3572d1..f58f4b9 100644 --- a/src/utils/memory-embeddings.ts +++ b/src/utils/memory-embeddings.ts @@ -2,62 +2,75 @@ import { PageAssistHtmlLoader } from "~/loader/html" import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama" -import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize } from "@/services/ollama" +import { + defaultEmbeddingChunkOverlap, + defaultEmbeddingChunkSize +} from "@/services/ollama" import { PageAssistPDFLoader } from "@/loader/pdf" - -export const getLoader = ({ html, pdf, type, url }: { - url: string, - html: string, - type: string, - pdf: { content: string, page: number }[] +export const getLoader = ({ + html, + pdf, + type, + url +}: { + url: string + html: string + type: string + pdf: { content: string; page: number }[] }) => { - if (type === "pdf") { - return new PageAssistPDFLoader({ - pdf, - url - }) - } else { - return new PageAssistHtmlLoader({ - html, - url - }) - } + if (type === "pdf") { + return new PageAssistPDFLoader({ + pdf, + url + }) + } else 
+    return new PageAssistHtmlLoader({
+      html,
+      url
+    })
+  }
 }
 
-export const memoryEmbedding = async (
-    { html,
-        keepTrackOfEmbedding, ollamaEmbedding, pdf, setIsEmbedding, setKeepTrackOfEmbedding, type, url }: {
-            url: string,
-            html: string,
-            type: string,
-            pdf: { content: string, page: number }[],
-            keepTrackOfEmbedding: Record<string, MemoryVectorStore>,
-            ollamaEmbedding: OllamaEmbeddings,
-            setIsEmbedding: (value: boolean) => void,
-            setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void
-    }
-) => {
-    setIsEmbedding(true)
+export const memoryEmbedding = async ({
+  html,
+  keepTrackOfEmbedding,
+  ollamaEmbedding,
+  pdf,
+  setIsEmbedding,
+  setKeepTrackOfEmbedding,
+  type,
+  url
+}: {
+  url: string
+  html: string
+  type: string
+  pdf: { content: string; page: number }[]
+  keepTrackOfEmbedding: Record<string, MemoryVectorStore>
+  ollamaEmbedding: OllamaEmbeddings
+  setIsEmbedding: (value: boolean) => void
+  setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void
+}) => {
+  setIsEmbedding(true)
 
-    const loader = getLoader({ html, pdf, type, url })
-    const docs = await loader.load()
-    const chunkSize = await defaultEmbeddingChunkSize()
-    const chunkOverlap = await defaultEmbeddingChunkOverlap()
-    const textSplitter = new RecursiveCharacterTextSplitter({
-        chunkSize,
-        chunkOverlap
-    })
+  const loader = getLoader({ html, pdf, type, url })
+  const docs = await loader.load()
+  const chunkSize = await defaultEmbeddingChunkSize()
+  const chunkOverlap = await defaultEmbeddingChunkOverlap()
+  const textSplitter = new RecursiveCharacterTextSplitter({
+    chunkSize,
+    chunkOverlap
+  })
 
-    const chunks = await textSplitter.splitDocuments(docs)
+  const chunks = await textSplitter.splitDocuments(docs)
 
-    const store = new MemoryVectorStore(ollamaEmbedding)
+  const store = new MemoryVectorStore(ollamaEmbedding)
 
-    await store.addDocuments(chunks)
-    setKeepTrackOfEmbedding({
-        ...keepTrackOfEmbedding,
-        [url]: store
-    })
-    setIsEmbedding(false)
-    return store
-}
\ No newline at end of file
+  await store.addDocuments(chunks)
+  setKeepTrackOfEmbedding({
+    ...keepTrackOfEmbedding,
+    [url]: store
+  })
+  setIsEmbedding(false)
+  return store
+}
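
A minimal sketch of how the new cancellation wiring fits together, using only the `OllamaEmbeddingsPageAssist` class added in this patch. The model name and query below are illustrative, not part of the diff; in the extension itself the controllers live in `PageAssistContext` and are aborted by `stopStreamingRequest`:

```ts
import { OllamaEmbeddingsPageAssist } from "@/models/OllamaEmbedding"

// One controller per embedding run, mirroring setEmbeddingController in onSubmit.
const embeddingController = new AbortController()

const embeddings = new OllamaEmbeddingsPageAssist({
  model: "nomic-embed-text", // illustrative; any embedding-capable Ollama model
  baseUrl: "http://127.0.0.1:11434",
  signal: embeddingController.signal
})

// embedQuery() issues POST /api/embeddings with the signal attached to fetch().
const pending = embeddings.embedQuery("What does this page say about pricing?")

// The sidepanel's stop button reaches this via stopStreamingRequest()
// while isEmbedding is true.
embeddingController.abort()

pending.catch((err) => {
  // fetch() rejects with an AbortError DOMException once the signal fires.
  console.log("embedding cancelled:", err?.name)
})
```

Because the embedding pass and the chat stream get separate controllers, aborting a slow page-embedding step does not tear down an unrelated in-flight chat completion, and vice versa.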