From 1d9d704c76a9f0f7fd4763cd5706afc62c4b85b9 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 4 Jan 2025 20:16:23 +0530 Subject: [PATCH 1/3] feat: Add Mistral provider support feat: Update manifest version to 1.4.1 --- src/components/Common/ProviderIcon.tsx | 3 +++ src/components/Icons/Mistral.tsx | 32 ++++++++++++++++++++++++++ src/utils/oai-api-providers.ts | 5 ++++ wxt.config.ts | 2 +- 4 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 src/components/Icons/Mistral.tsx diff --git a/src/components/Common/ProviderIcon.tsx b/src/components/Common/ProviderIcon.tsx index 1db9d6d..adb9574 100644 --- a/src/components/Common/ProviderIcon.tsx +++ b/src/components/Common/ProviderIcon.tsx @@ -8,6 +8,7 @@ import { TogtherMonoIcon } from "../Icons/Togther" import { OpenRouterIcon } from "../Icons/OpenRouter" import { LLamaFile } from "../Icons/Llamafile" import { GeminiIcon } from "../Icons/GeminiIcon" +import { MistarlIcon } from "../Icons/Mistral" export const ProviderIcons = ({ provider, @@ -37,6 +38,8 @@ export const ProviderIcons = ({ return case "gemini": return + case "mistral": + return default: return } diff --git a/src/components/Icons/Mistral.tsx b/src/components/Icons/Mistral.tsx new file mode 100644 index 0000000..14784f6 --- /dev/null +++ b/src/components/Icons/Mistral.tsx @@ -0,0 +1,32 @@ +import React from "react" + +export const MistarlIcon = React.forwardRef< + SVGSVGElement, + React.SVGProps +>((props, ref) => { + return ( + + + + + + + + + + + ) +}) diff --git a/src/utils/oai-api-providers.ts b/src/utils/oai-api-providers.ts index 06a59e7..f0d3239 100644 --- a/src/utils/oai-api-providers.ts +++ b/src/utils/oai-api-providers.ts @@ -48,5 +48,10 @@ export const OAI_API_PROVIDERS = [ label: "Google AI", value: "gemini", baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai" + }, + { + label: "Mistral", + value: "mistral", + baseUrl: "https://api.mistral.ai/v1" } ] \ No newline at end of file diff --git a/wxt.config.ts b/wxt.config.ts index 0aa73c2..975795b 100644 --- a/wxt.config.ts +++ b/wxt.config.ts @@ -51,7 +51,7 @@ export default defineConfig({ outDir: "build", manifest: { - version: "1.4.0", + version: "1.4.1", name: process.env.TARGET === "firefox" ? 
"Page Assist - A Web UI for Local AI Models" From 0af69a3be87a8c4820ba84783383321d3ac30916 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 4 Jan 2025 23:24:23 +0530 Subject: [PATCH 2/3] feat: Add text splitting configuration options --- src/assets/locale/ar/settings.json | 11 +++- src/assets/locale/da/settings.json | 8 +++ src/assets/locale/de/settings.json | 8 +++ src/assets/locale/en/settings.json | 10 +++- src/assets/locale/es/settings.json | 8 +++ src/assets/locale/fa/settings.json | 8 +++ src/assets/locale/fr/settings.json | 8 +++ src/assets/locale/it/settings.json | 8 +++ src/assets/locale/ja-JP/settings.json | 8 +++ src/assets/locale/ko/settings.json | 8 +++ src/assets/locale/ml/settings.json | 8 +++ src/assets/locale/no/settings.json | 8 +++ src/assets/locale/pt-BR/settings.json | 8 +++ src/assets/locale/ru/settings.json | 8 +++ src/assets/locale/sv/settings.json | 8 +++ src/assets/locale/uk/settings.json | 8 +++ src/assets/locale/zh/settings.json | 8 +++ src/components/Option/Settings/rag.tsx | 78 +++++++++++++++++++++++--- src/libs/process-knowledge.ts | 27 ++++----- src/services/ollama.ts | 37 +++++++++++- src/utils/memory-embeddings.ts | 13 +---- src/utils/text-splitter.ts | 37 ++++++++++++ src/web/search-engines/brave-api.ts | 11 +--- src/web/search-engines/brave.ts | 12 +--- src/web/search-engines/duckduckgo.ts | 11 +--- src/web/search-engines/google.ts | 14 ++--- src/web/search-engines/searxng.ts | 12 +--- src/web/search-engines/sogou.ts | 11 +--- src/web/website/index.ts | 13 ++--- 29 files changed, 315 insertions(+), 102 deletions(-) create mode 100644 src/utils/text-splitter.ts diff --git a/src/assets/locale/ar/settings.json b/src/assets/locale/ar/settings.json index 6de45fd..23e964d 100644 --- a/src/assets/locale/ar/settings.json +++ b/src/assets/locale/ar/settings.json @@ -334,6 +334,14 @@ "label": "عدد المستندات المسترجعة", "placeholder": "أدخل عدد المستندات المسترجعة", "required": "الرجاء إدخال عدد المستندات المسترجعة" + }, + "splittingSeparator": { + "label": "الفاصل", + "placeholder": "أدخل الفاصل (مثال: \\n\\n)", + "required": "الرجاء إدخال الفاصل" + }, + "splittingStrategy": { + "label": "مقسم النص" } }, "prompt": { @@ -355,4 +363,5 @@ }, "chromeAiSettings": { "title": "إعدادات Chrome AI" - }} + } +} diff --git a/src/assets/locale/da/settings.json b/src/assets/locale/da/settings.json index 1712a45..7efe9eb 100644 --- a/src/assets/locale/da/settings.json +++ b/src/assets/locale/da/settings.json @@ -331,6 +331,14 @@ "label": "Antal Hentede Dokumenter", "placeholder": "Indtast Number of Retrieved Documents", "required": "Venligst indtast the number of retrieved documents" + }, + "splittingSeparator": { + "label": "Separator", + "placeholder": "Indtast Separator (f.eks. \\n\\n)", + "required": "Indtast venligst en separator" + }, + "splittingStrategy": { + "label": "Tekst Splitter" } }, "prompt": { diff --git a/src/assets/locale/de/settings.json b/src/assets/locale/de/settings.json index 581cc28..9d5822b 100644 --- a/src/assets/locale/de/settings.json +++ b/src/assets/locale/de/settings.json @@ -331,6 +331,14 @@ "label": "Anzahl der abgerufenen Dokumente", "placeholder": "Anzahl der abgerufenen Dokumente eingeben", "required": "Bitte geben Sie die Anzahl der abgerufenen Dokumente ein" + }, + "splittingSeparator": { + "label": "Separator", + "placeholder": "Separator eingeben (z.B. 
\\n\\n)", + "required": "Bitte geben Sie einen Separator ein" + }, + "splittingStrategy": { + "label": "Text-Splitter" } }, "prompt": { diff --git a/src/assets/locale/en/settings.json b/src/assets/locale/en/settings.json index 3f6af27..2d7cdcd 100644 --- a/src/assets/locale/en/settings.json +++ b/src/assets/locale/en/settings.json @@ -72,7 +72,7 @@ } }, "braveApi": { - "label": "Brave API Key", + "label": "Brave API Key", "placeholder": "Enter your Brave API key" }, "googleDomain": { @@ -337,6 +337,14 @@ "label": "Number of Retrieved Documents", "placeholder": "Enter Number of Retrieved Documents", "required": "Please enter the number of retrieved documents" + }, + "splittingSeparator": { + "label": "Separator", + "placeholder": "Enter Separator (e.g., \\n\\n)", + "required": "Please enter a separator" + }, + "splittingStrategy": { + "label": "Text Splitter" } }, "prompt": { diff --git a/src/assets/locale/es/settings.json b/src/assets/locale/es/settings.json index 3809ff1..41a24cb 100644 --- a/src/assets/locale/es/settings.json +++ b/src/assets/locale/es/settings.json @@ -331,6 +331,14 @@ "label": "Número de Documentos Recuperados", "placeholder": "Ingrese el Número de Documentos Recuperados", "required": "Por favor, ingrese el número de documentos recuperados" + }, + "splittingSeparator": { + "label": "Separador", + "placeholder": "Ingrese el separador (ej., \\n\\n)", + "required": "Por favor, ingrese un separador" + }, + "splittingStrategy": { + "label": "Divisor de Texto" } }, "prompt": { diff --git a/src/assets/locale/fa/settings.json b/src/assets/locale/fa/settings.json index bb037ff..48aaf5d 100644 --- a/src/assets/locale/fa/settings.json +++ b/src/assets/locale/fa/settings.json @@ -327,6 +327,14 @@ "label": "تعداد اسناد بازیابی شده", "placeholder": "تعداد اسناد بازیابی شده را وارد کنید", "required": "لطفاً تعداد اسناد بازیابی شده را وارد کنید" + }, + "splittingSeparator": { + "label": "جداکننده", + "placeholder": "جداکننده را وارد کنید (مثلاً \\n\\n)", + "required": "لطفاً یک جداکننده وارد کنید" + }, + "splittingStrategy": { + "label": "تقسیم‌کننده متن" } }, "prompt": { diff --git a/src/assets/locale/fr/settings.json b/src/assets/locale/fr/settings.json index 93ba685..dbd7fad 100644 --- a/src/assets/locale/fr/settings.json +++ b/src/assets/locale/fr/settings.json @@ -331,6 +331,14 @@ "label": "Nombre de documents récupérés", "placeholder": "Entrez le nombre de documents récupérés", "required": "Veuillez saisir le nombre de documents récupérés" + }, + "splittingSeparator": { + "label": "Séparateur", + "placeholder": "Entrez le séparateur (par exemple, \\n\\n)", + "required": "Veuillez saisir un séparateur" + }, + "splittingStrategy": { + "label": "Diviseur de texte" } }, "prompt": { diff --git a/src/assets/locale/it/settings.json b/src/assets/locale/it/settings.json index d7d7007..3b61c5c 100644 --- a/src/assets/locale/it/settings.json +++ b/src/assets/locale/it/settings.json @@ -331,6 +331,14 @@ "label": "Numero di Documenti Recuperati", "placeholder": "Inserisci il Numero di Documenti Recuperati", "required": "Inserisci il numero di documenti recuperati" + }, + "splittingSeparator": { + "label": "Separatore", + "placeholder": "Inserisci il Separatore (es. 
\\n\\n)", + "required": "Inserisci un separatore" + }, + "splittingStrategy": { + "label": "Divisore di Testo" } }, "prompt": { diff --git a/src/assets/locale/ja-JP/settings.json b/src/assets/locale/ja-JP/settings.json index 62363a5..239ffa4 100644 --- a/src/assets/locale/ja-JP/settings.json +++ b/src/assets/locale/ja-JP/settings.json @@ -334,6 +334,14 @@ "label": "取得ドキュメント数", "placeholder": "取得ドキュメント数を入力", "required": "取得ドキュメント数を入力してください" + }, + "splittingSeparator": { + "label": "セパレーター", + "placeholder": "セパレーターを入力(例:\\n\\n)", + "required": "セパレーターを入力してください" + }, + "splittingStrategy": { + "label": "テキスト分割方式" } }, "prompt": { diff --git a/src/assets/locale/ko/settings.json b/src/assets/locale/ko/settings.json index f4f019a..9728e98 100644 --- a/src/assets/locale/ko/settings.json +++ b/src/assets/locale/ko/settings.json @@ -334,6 +334,14 @@ "label": "검색 문서 수", "placeholder": "검색 문서 수 입력", "required": "검색 문서 수를 입력해주세요" + }, + "splittingSeparator": { + "label": "구분자", + "placeholder": "구분자 입력 (예: \\n\\n)", + "required": "구분자를 입력해주세요" + }, + "splittingStrategy": { + "label": "텍스트 분할기" } }, "prompt": { diff --git a/src/assets/locale/ml/settings.json b/src/assets/locale/ml/settings.json index e18ef5e..fab6e6f 100644 --- a/src/assets/locale/ml/settings.json +++ b/src/assets/locale/ml/settings.json @@ -334,6 +334,14 @@ "label": "വീണ്ടെടുത്ത രേഖകളുടെ എണ്ണം", "placeholder": "വീണ്ടെടുത്ത രേഖകളുടെ എണ്ണം നൽകുക", "required": "ദയവായി വീണ്ടെടുത്ത രേഖകളുടെ എണ്ണം നൽകുക" + }, + "splittingSeparator": { + "label": "വിഭജന ചിഹ്നം", + "placeholder": "വിഭജന ചിഹ്നം നൽകുക (ഉദാ: \\n\\n)", + "required": "ദയവായി ഒരു വിഭജന ചിഹ്നം നൽകുക" + }, + "splittingStrategy": { + "label": "ടെക്സ്റ്റ് സ്പ്ലിറ്റർ" } }, "prompt": { diff --git a/src/assets/locale/no/settings.json b/src/assets/locale/no/settings.json index cd96712..40dece1 100644 --- a/src/assets/locale/no/settings.json +++ b/src/assets/locale/no/settings.json @@ -331,6 +331,14 @@ "label": "Antall hentede dokumenter", "placeholder": "Skriv inn antall hentede dokumenter", "required": "Vennligst skriv inn antall hentede dokumenter" + }, + "splittingSeparator": { + "label": "Separator", + "placeholder": "Skriv inn separator (f.eks. 
\\n\\n)", + "required": "Vennligst skriv inn en separator" + }, + "splittingStrategy": { + "label": "Tekstdeler" } }, "prompt": { diff --git a/src/assets/locale/pt-BR/settings.json b/src/assets/locale/pt-BR/settings.json index 6dbd407..e296201 100644 --- a/src/assets/locale/pt-BR/settings.json +++ b/src/assets/locale/pt-BR/settings.json @@ -331,6 +331,14 @@ "label": "Número de Documentos Recuperados", "placeholder": "Digite o Número de Documentos Recuperados", "required": "Por favor, insira o número de documentos recuperados" + }, + "splittingSeparator": { + "label": "Separador", + "placeholder": "Digite o Separador (ex: \\n\\n)", + "required": "Por favor, insira um separador" + }, + "splittingStrategy": { + "label": "Divisor de Texto" } }, "prompt": { diff --git a/src/assets/locale/ru/settings.json b/src/assets/locale/ru/settings.json index c71c037..1f985ae 100644 --- a/src/assets/locale/ru/settings.json +++ b/src/assets/locale/ru/settings.json @@ -333,6 +333,14 @@ "label": "Количество извлеченных документов", "placeholder": "Введите количество извлеченных документов", "required": "Пожалуйста, введите количество извлеченных документов" + }, + "splittingSeparator": { + "label": "Разделитель", + "placeholder": "Введите разделитель (например, \\n\\n)", + "required": "Пожалуйста, введите разделитель" + }, + "splittingStrategy": { + "label": "Разделитель текста" } }, "prompt": { diff --git a/src/assets/locale/sv/settings.json b/src/assets/locale/sv/settings.json index 712e9a3..aa049c6 100644 --- a/src/assets/locale/sv/settings.json +++ b/src/assets/locale/sv/settings.json @@ -331,6 +331,14 @@ "label": "Antal hämtade dokument", "placeholder": "Ange antal hämtade dokument", "required": "Vänligen ange antal hämtade dokument" + }, + "splittingSeparator": { + "label": "Separator", + "placeholder": "Ange separator (t.ex. 
\\n\\n)", + "required": "Vänligen ange en separator" + }, + "splittingStrategy": { + "label": "Textdelare" } }, "prompt": { diff --git a/src/assets/locale/uk/settings.json b/src/assets/locale/uk/settings.json index 7462317..34752c9 100644 --- a/src/assets/locale/uk/settings.json +++ b/src/assets/locale/uk/settings.json @@ -331,6 +331,14 @@ "label": "Кількість отриманих документів", "placeholder": "Ввести кількість отриманих документів", "required": "Будь ласка, введіть кількість документів" + }, + "splittingSeparator": { + "label": "Роздільник", + "placeholder": "Введіть роздільник (напр., \\n\\n)", + "required": "Будь ласка, введіть роздільник" + }, + "splittingStrategy": { + "label": "Розділювач тексту" } }, "prompt": { diff --git a/src/assets/locale/zh/settings.json b/src/assets/locale/zh/settings.json index 2557c44..1a66351 100644 --- a/src/assets/locale/zh/settings.json +++ b/src/assets/locale/zh/settings.json @@ -336,6 +336,14 @@ "label": "检索文档数量", "placeholder": "输入检索文档数量", "required": "请输入检索文档数量" + }, + "splittingSeparator": { + "label": "分隔符", + "placeholder": "输入分隔符(例如:\\n\\n)", + "required": "请输入分隔符" + }, + "splittingStrategy": { + "label": "文本分割器" } }, "prompt": { diff --git a/src/components/Option/Settings/rag.tsx b/src/components/Option/Settings/rag.tsx index 8cc6bbf..534b9be 100644 --- a/src/components/Option/Settings/rag.tsx +++ b/src/components/Option/Settings/rag.tsx @@ -1,10 +1,12 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" -import { Form, InputNumber, Select, Skeleton } from "antd" +import { Form, Input, InputNumber, Select, Skeleton } from "antd" import { SaveButton } from "~/components/Common/SaveButton" import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, + defaultSplittingStrategy, + defaultSsplttingSeparator, getEmbeddingModels, saveForRag } from "~/services/ollama" @@ -16,7 +18,8 @@ import { ProviderIcons } from "@/components/Common/ProviderIcon" export const RagSettings = () => { const { t } = useTranslation("settings") - + const [form] = Form.useForm() + const splittingStrategy = Form.useWatch("splittingStrategy", form) const queryClient = useQueryClient() const { data: ollamaInfo, status } = useQuery({ @@ -28,14 +31,18 @@ export const RagSettings = () => { chunkSize, defaultEM, totalFilePerKB, - noOfRetrievedDocs + noOfRetrievedDocs, + splittingStrategy, + splittingSeparator ] = await Promise.all([ getEmbeddingModels({ returnEmpty: true }), defaultEmbeddingChunkOverlap(), defaultEmbeddingChunkSize(), defaultEmbeddingModelForRag(), getTotalFilePerKB(), - getNoOfRetrievedDocs() + getNoOfRetrievedDocs(), + defaultSplittingStrategy(), + defaultSsplttingSeparator() ]) return { models: allModels, @@ -43,7 +50,9 @@ export const RagSettings = () => { chunkSize, defaultEM, totalFilePerKB, - noOfRetrievedDocs + noOfRetrievedDocs, + splittingStrategy, + splittingSeparator } } }) @@ -55,13 +64,17 @@ export const RagSettings = () => { overlap: number totalFilePerKB: number noOfRetrievedDocs: number + strategy: string + separator: string }) => { await saveForRag( data.model, data.chunkSize, data.overlap, data.totalFilePerKB, - data.noOfRetrievedDocs + data.noOfRetrievedDocs, + data.strategy, + data.separator ) return true }, @@ -85,6 +98,7 @@ export const RagSettings = () => {
{ saveRAG({ @@ -92,7 +106,9 @@ export const RagSettings = () => { chunkSize: data.chunkSize, overlap: data.chunkOverlap, totalFilePerKB: data.totalFilePerKB, - noOfRetrievedDocs: data.noOfRetrievedDocs + noOfRetrievedDocs: data.noOfRetrievedDocs, + separator: data.splittingSeparator, + strategy: data.splittingStrategy }) }} initialValues={{ @@ -100,7 +116,9 @@ export const RagSettings = () => { chunkOverlap: ollamaInfo?.chunkOverlap, defaultEM: ollamaInfo?.defaultEM, totalFilePerKB: ollamaInfo?.totalFilePerKB, - noOfRetrievedDocs: ollamaInfo?.noOfRetrievedDocs + noOfRetrievedDocs: ollamaInfo?.noOfRetrievedDocs, + splittingStrategy: ollamaInfo?.splittingStrategy, + splittingSeparator: ollamaInfo?.splittingSeparator }}> { /> + + + + )} + => { console.log(`Processing knowledge with id: ${id}`) @@ -32,12 +27,8 @@ export const processKnowledge = async (msg: any, id: string): Promise => { baseUrl: cleanUrl(ollamaUrl), model: knowledge.embedding_model }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + + const textSplitter = await getPageAssistTextSplitter() for (const doc of knowledge.source) { if (doc.type === "pdf" || doc.type === "application/pdf") { @@ -65,13 +56,15 @@ export const processKnowledge = async (msg: any, id: string): Promise => { knownledge_id: knowledge.id, file_id: doc.source_id }) - } else if (doc.type === "docx" || doc.type === "application/vnd.openxmlformats-officedocument.wordprocessingml.document") { + } else if ( + doc.type === "docx" || + doc.type === + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ) { try { const loader = new PageAssistDocxLoader({ fileName: doc.filename, - buffer: await toArrayBufferFromBase64( - doc.content - ) + buffer: await toArrayBufferFromBase64(doc.content) }) let docs = await loader.load() diff --git a/src/services/ollama.ts b/src/services/ollama.ts index e19d704..f5e3f71 100644 --- a/src/services/ollama.ts +++ b/src/services/ollama.ts @@ -8,6 +8,9 @@ import { ollamaFormatAllCustomModels } from "@/db/models" const storage = new Storage() +const storage2 = new Storage({ + area: "local" +}) const DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434" const DEFAULT_ASK_FOR_MODEL_SELECTION_EVERY_TIME = true @@ -310,6 +313,22 @@ export const defaultEmbeddingChunkSize = async () => { return parseInt(embeddingChunkSize) } +export const defaultSplittingStrategy = async () => { + const splittingStrategy = await storage.get("defaultSplittingStrategy") + if (!splittingStrategy || splittingStrategy.length === 0) { + return "RecursiveCharacterTextSplitter" + } + return splittingStrategy +} + +export const defaultSsplttingSeparator = async () => { + const splittingSeparator = await storage.get("defaultSplittingSeparator") + if (!splittingSeparator || splittingSeparator.length === 0) { + return "\\n\\n" + } + return splittingSeparator +} + export const defaultEmbeddingChunkOverlap = async () => { const embeddingChunkOverlap = await storage.get( "defaultEmbeddingChunkOverlap" @@ -320,6 +339,14 @@ export const defaultEmbeddingChunkOverlap = async () => { return parseInt(embeddingChunkOverlap) } +export const setDefaultSplittingStrategy = async (strategy: string) => { + await storage.set("defaultSplittingStrategy", strategy) +} + +export const setDefaultSplittingSeparator = async (separator: string) => { + await storage.set("defaultSplittingSeparator", separator) +} + export const 
setDefaultEmbeddingModelForRag = async (model: string) => { await storage.set("defaultEmbeddingModel", model) } @@ -337,7 +364,9 @@ export const saveForRag = async ( chunkSize: number, overlap: number, totalFilePerKB: number, - noOfRetrievedDocs?: number + noOfRetrievedDocs?: number, + strategy?: string, + separator?: string ) => { await setDefaultEmbeddingModelForRag(model) await setDefaultEmbeddingChunkSize(chunkSize) @@ -346,6 +375,12 @@ export const saveForRag = async ( if (noOfRetrievedDocs) { await setNoOfRetrievedDocs(noOfRetrievedDocs) } + if (strategy) { + await setDefaultSplittingStrategy(strategy) + } + if (separator) { + await setDefaultSplittingSeparator(separator) + } } export const getWebSearchPrompt = async () => { diff --git a/src/utils/memory-embeddings.ts b/src/utils/memory-embeddings.ts index 9cb1b82..1128cf4 100644 --- a/src/utils/memory-embeddings.ts +++ b/src/utils/memory-embeddings.ts @@ -1,12 +1,8 @@ import { PageAssistHtmlLoader } from "~/loader/html" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" -import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize -} from "@/services/ollama" import { PageAssistPDFLoader } from "@/loader/pdf" import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore" +import { getPageAssistTextSplitter } from "./text-splitter" export const getLoader = ({ html, @@ -54,12 +50,7 @@ export const memoryEmbedding = async ({ setIsEmbedding(true) const loader = getLoader({ html, pdf, type, url }) const docs = await loader.load() - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + const textSplitter = await getPageAssistTextSplitter() const chunks = await textSplitter.splitDocuments(docs) diff --git a/src/utils/text-splitter.ts b/src/utils/text-splitter.ts new file mode 100644 index 0000000..67a0d7d --- /dev/null +++ b/src/utils/text-splitter.ts @@ -0,0 +1,37 @@ +import { + RecursiveCharacterTextSplitter, + CharacterTextSplitter +} from "langchain/text_splitter" + +import { + defaultEmbeddingChunkOverlap, + defaultEmbeddingChunkSize, + defaultSsplttingSeparator, + defaultSplittingStrategy +} from "@/services/ollama" + +export const getPageAssistTextSplitter = async () => { + const chunkSize = await defaultEmbeddingChunkSize() + const chunkOverlap = await defaultEmbeddingChunkOverlap() + const splittingStrategy = await defaultSplittingStrategy() + + switch (splittingStrategy) { + case "CharacterTextSplitter": + console.log("Using CharacterTextSplitter") + const splittingSeparator = await defaultSsplttingSeparator() + const processedSeparator = splittingSeparator + .replace(/\\n/g, "\n") + .replace(/\\t/g, "\t") + .replace(/\\r/g, "\r") + return new CharacterTextSplitter({ + chunkSize, + chunkOverlap, + separator: processedSeparator + }) + default: + return new RecursiveCharacterTextSplitter({ + chunkSize, + chunkOverlap + }) + } +} diff --git a/src/web/search-engines/brave-api.ts b/src/web/search-engines/brave-api.ts index 5e13312..b95c37c 100644 --- a/src/web/search-engines/brave-api.ts +++ b/src/web/search-engines/brave-api.ts @@ -2,15 +2,13 @@ import { cleanUrl } from "~/libs/clean-url" import { getIsSimpleInternetSearch, totalSearchResults, getBraveApiKey } from "@/services/search" import { pageAssistEmbeddingModel } from "@/models/embedding" import type { Document } from "@langchain/core/documents" -import { RecursiveCharacterTextSplitter } from 
"langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { PageAssistHtmlLoader } from "~/loader/html" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "~/services/ollama" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" interface BraveAPIResult { title: string @@ -70,12 +68,7 @@ export const braveAPISearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + const textSplitter = await getPageAssistTextSplitter() const chunks = await textSplitter.splitDocuments(docs) const store = new MemoryVectorStore(ollamaEmbedding) diff --git a/src/web/search-engines/brave.ts b/src/web/search-engines/brave.ts index b795d8b..71b7670 100644 --- a/src/web/search-engines/brave.ts +++ b/src/web/search-engines/brave.ts @@ -3,8 +3,6 @@ import { urlRewriteRuntime } from "@/libs/runtime" import { PageAssistHtmlLoader } from "@/loader/html" import { pageAssistEmbeddingModel } from "@/models/embedding" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama" @@ -12,10 +10,10 @@ import { getIsSimpleInternetSearch, totalSearchResults } from "@/services/search" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" import type { Document } from "@langchain/core/documents" import * as cheerio from "cheerio" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" export const localBraveSearch = async (query: string) => { @@ -87,12 +85,8 @@ export const webBraveSearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + + const textSplitter = await getPageAssistTextSplitter(); const chunks = await textSplitter.splitDocuments(docs) diff --git a/src/web/search-engines/duckduckgo.ts b/src/web/search-engines/duckduckgo.ts index e368500..9552b9d 100644 --- a/src/web/search-engines/duckduckgo.ts +++ b/src/web/search-engines/duckduckgo.ts @@ -3,8 +3,6 @@ import { urlRewriteRuntime } from "@/libs/runtime" import { PageAssistHtmlLoader } from "@/loader/html" import { pageAssistEmbeddingModel } from "@/models/embedding" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama" @@ -12,9 +10,9 @@ import { getIsSimpleInternetSearch, totalSearchResults } from "@/services/search" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" import type { Document } from "@langchain/core/documents" import * as cheerio from "cheerio" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" export const localDuckDuckGoSearch = async (query: string) => { @@ -90,12 +88,7 @@ export const webDuckDuckGoSearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - 
chunkSize, - chunkOverlap - }) + const textSplitter = await getPageAssistTextSplitter() const chunks = await textSplitter.splitDocuments(docs) diff --git a/src/web/search-engines/google.ts b/src/web/search-engines/google.ts index 8c0a92d..94dd3c4 100644 --- a/src/web/search-engines/google.ts +++ b/src/web/search-engines/google.ts @@ -4,15 +4,13 @@ import { getIsSimpleInternetSearch, totalSearchResults } from "@/services/search" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" import type { Document } from "@langchain/core/documents" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { cleanUrl } from "~/libs/clean-url" import { urlRewriteRuntime } from "~/libs/runtime" import { PageAssistHtmlLoader } from "~/loader/html" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "~/services/ollama" @@ -91,13 +89,9 @@ export const webGoogleSearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) - + + const textSplitter = await getPageAssistTextSplitter() + const chunks = await textSplitter.splitDocuments(docs) const store = new MemoryVectorStore(ollamaEmbedding) diff --git a/src/web/search-engines/searxng.ts b/src/web/search-engines/searxng.ts index 0dc2e64..d3277bf 100644 --- a/src/web/search-engines/searxng.ts +++ b/src/web/search-engines/searxng.ts @@ -3,15 +3,13 @@ import { cleanUrl } from "~/libs/clean-url" import { getSearxngURL, isSearxngJSONMode, getIsSimpleInternetSearch, totalSearchResults } from "@/services/search" import { pageAssistEmbeddingModel } from "@/models/embedding" import type { Document } from "@langchain/core/documents" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" import { PageAssistHtmlLoader } from "~/loader/html" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "~/services/ollama" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" interface SearxNGJSONResult { title: string @@ -73,13 +71,9 @@ export const searxngSearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + const textSplitter = await getPageAssistTextSplitter(); + const chunks = await textSplitter.splitDocuments(docs) const store = new MemoryVectorStore(ollamaEmbedding) await store.addDocuments(chunks) diff --git a/src/web/search-engines/sogou.ts b/src/web/search-engines/sogou.ts index d1a6090..7bc0126 100644 --- a/src/web/search-engines/sogou.ts +++ b/src/web/search-engines/sogou.ts @@ -3,8 +3,6 @@ import { urlRewriteRuntime } from "@/libs/runtime" import { PageAssistHtmlLoader } from "@/loader/html" import { pageAssistEmbeddingModel } from "@/models/embedding" import { - defaultEmbeddingChunkOverlap, - defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama" @@ -12,9 +10,9 @@ import { getIsSimpleInternetSearch, totalSearchResults } from "@/services/search" +import { 
getPageAssistTextSplitter } from "@/utils/text-splitter" import type { Document } from "@langchain/core/documents" import * as cheerio from "cheerio" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" import { MemoryVectorStore } from "langchain/vectorstores/memory" const getCorrectTargeUrl = async (url: string) => { if (!url) return "" @@ -104,12 +102,7 @@ export const webSogouSearch = async (query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + const textSplitter = await getPageAssistTextSplitter() const chunks = await textSplitter.splitDocuments(docs) diff --git a/src/web/website/index.ts b/src/web/website/index.ts index 817fdb7..d817160 100644 --- a/src/web/website/index.ts +++ b/src/web/website/index.ts @@ -1,8 +1,9 @@ import { cleanUrl } from "@/libs/clean-url" import { PageAssistHtmlLoader } from "@/loader/html" import { pageAssistEmbeddingModel } from "@/models/embedding" -import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama" -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" +import { defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama" +import { getPageAssistTextSplitter } from "@/utils/text-splitter" + import { MemoryVectorStore } from "langchain/vectorstores/memory" export const processSingleWebsite = async (url: string, query: string) => { @@ -20,12 +21,8 @@ export const processSingleWebsite = async (url: string, query: string) => { baseUrl: cleanUrl(ollamaUrl) }) - const chunkSize = await defaultEmbeddingChunkSize() - const chunkOverlap = await defaultEmbeddingChunkOverlap() - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize, - chunkOverlap - }) + + const textSplitter = await getPageAssistTextSplitter() const chunks = await textSplitter.splitDocuments(docs) From 9674b842ef98ec07fc6578f9aad175d034c20fea Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 5 Jan 2025 15:11:43 +0530 Subject: [PATCH 3/3] feat: Add Ollama model settings for tfsZ, numKeep, numThread, and useMlock --- src/assets/locale/en/common.json | 15 ++++ .../Settings/CurrentChatModelSettings.tsx | 42 +++++++++- .../Option/Settings/model-settings.tsx | 33 ++++++++ src/hooks/useMessage.tsx | 76 +++++++++++++++++-- src/hooks/useMessageOption.tsx | 56 ++++++++++++-- src/models/ChatOllama.ts | 5 +- src/models/index.ts | 16 +++- src/models/utils/ollama.ts | 1 + src/services/model-settings.ts | 2 + src/store/model.tsx | 4 + 10 files changed, 232 insertions(+), 18 deletions(-) diff --git a/src/assets/locale/en/common.json b/src/assets/locale/en/common.json index b00e03e..6f14588 100644 --- a/src/assets/locale/en/common.json +++ b/src/assets/locale/en/common.json @@ -90,6 +90,21 @@ "useMMap": { "label": "useMmap" }, + "tfsZ": { + "label": "TFS-Z", + "placeholder": "e.g. 1.0, 1.1" + }, + "numKeep": { + "label": "Num Keep", + "placeholder": "e.g. 256, 512" + }, + "numThread": { + "label": "Num Thread", + "placeholder": "e.g. 8, 16" + }, + "useMlock": { + "label": "useMlock" + }, "minP": { "label": "Min P", "placeholder": "e.g. 
0.05" diff --git a/src/components/Common/Settings/CurrentChatModelSettings.tsx b/src/components/Common/Settings/CurrentChatModelSettings.tsx index 57a6926..7474f68 100644 --- a/src/components/Common/Settings/CurrentChatModelSettings.tsx +++ b/src/components/Common/Settings/CurrentChatModelSettings.tsx @@ -13,9 +13,8 @@ import { Modal, Skeleton, Switch, - Button } from "antd" -import React, { useState, useCallback } from "react" +import React, { useCallback } from "react" import { useTranslation } from "react-i18next" import { SaveButton } from "../SaveButton" @@ -79,7 +78,11 @@ export const CurrentChatModelSettings = ({ useMMap: cUserSettings.useMMap ?? data.useMMap, minP: cUserSettings.minP ?? data.minP, repeatLastN: cUserSettings.repeatLastN ?? data.repeatLastN, - repeatPenalty: cUserSettings.repeatPenalty ?? data.repeatPenalty + repeatPenalty: cUserSettings.repeatPenalty ?? data.repeatPenalty, + useMlock: cUserSettings.useMlock ?? data.useMlock, + tfsZ: cUserSettings.tfsZ ?? data.tfsZ, + numKeep: cUserSettings.numKeep ?? data.numKeep, + numThread: cUserSettings.numThread ?? data.numThread }) return data }, @@ -230,11 +233,44 @@ export const CurrentChatModelSettings = ({ )} /> + + + + + + + + + + + + ) } diff --git a/src/components/Option/Settings/model-settings.tsx b/src/components/Option/Settings/model-settings.tsx index 43753eb..7ed8eb9 100644 --- a/src/components/Option/Settings/model-settings.tsx +++ b/src/components/Option/Settings/model-settings.tsx @@ -150,11 +150,44 @@ export const ModelSettings = () => { )} /> + + + + + + + + + + + + ) } diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx index 1cdca63..a32fedc 100644 --- a/src/hooks/useMessage.tsx +++ b/src/hooks/useMessage.tsx @@ -150,7 +150,15 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -293,7 +301,18 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: + currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? + userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? + userDefaultModelSettings?.useMlock }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() @@ -514,7 +533,15 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? 
+ userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -758,7 +785,15 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -997,7 +1032,15 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -1087,7 +1130,18 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: + currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? + userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? + userDefaultModelSettings?.useMlock }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() @@ -1286,7 +1340,15 @@ export const useMessage = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx index 56100be..5fcd4d1 100644 --- a/src/hooks/useMessageOption.tsx +++ b/src/hooks/useMessageOption.tsx @@ -141,7 +141,15 @@ export const useMessageOption = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? 
userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -231,7 +239,18 @@ export const useMessageOption = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: + currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? + userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? + userDefaultModelSettings?.useMlock }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() @@ -464,7 +483,15 @@ export const useMessageOption = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -719,7 +746,15 @@ export const useMessageOption = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? userDefaultModelSettings?.useMlock }) let newMessage: Message[] = [] @@ -825,7 +860,18 @@ export const useMessageOption = () => { userDefaultModelSettings?.repeatLastN, repeatPenalty: currentChatModelSettings?.repeatPenalty ?? - userDefaultModelSettings?.repeatPenalty + userDefaultModelSettings?.repeatPenalty, + tfsZ: + currentChatModelSettings?.tfsZ ?? userDefaultModelSettings?.tfsZ, + numKeep: + currentChatModelSettings?.numKeep ?? + userDefaultModelSettings?.numKeep, + numThread: + currentChatModelSettings?.numThread ?? + userDefaultModelSettings?.numThread, + useMlock: + currentChatModelSettings?.useMlock ?? 
+ userDefaultModelSettings?.useMlock }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() diff --git a/src/models/ChatOllama.ts b/src/models/ChatOllama.ts index e046f68..18d18eb 100644 --- a/src/models/ChatOllama.ts +++ b/src/models/ChatOllama.ts @@ -103,6 +103,8 @@ export class ChatOllama useMMap?: boolean; + useMlock?: boolean; + vocabOnly?: boolean; seed?: number; @@ -148,6 +150,7 @@ export class ChatOllama this.typicalP = fields.typicalP; this.useMLock = fields.useMLock; this.useMMap = fields.useMMap; + this.useMlock = fields.useMlock; this.vocabOnly = fields.vocabOnly; this.format = fields.format; this.seed = fields.seed; @@ -210,7 +213,7 @@ export class ChatOllama top_p: this.topP, min_p: this.minP, typical_p: this.typicalP, - use_mlock: this.useMLock, + use_mlock: this.useMlock, use_mmap: this.useMMap, vocab_only: this.vocabOnly, seed: this.seed, diff --git a/src/models/index.ts b/src/models/index.ts index b50baf2..752f114 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -20,7 +20,11 @@ export const pageAssistModel = async ({ useMMap, minP, repeatLastN, - repeatPenalty + repeatPenalty, + tfsZ, + numKeep, + numThread, + useMlock, }: { model: string baseUrl: string @@ -36,6 +40,10 @@ export const pageAssistModel = async ({ minP?: number repeatPenalty?: number repeatLastN?: number + tfsZ?: number, + numKeep?: number, + numThread?: number, + useMlock?: boolean, }) => { if (model === "chrome::gemini-nano::page-assist") { return new ChatChromeAI({ @@ -80,7 +88,7 @@ export const pageAssistModel = async ({ } }) as any } - + console.log('useMlock', useMlock) return new ChatOllama({ baseUrl, keepAlive, @@ -96,5 +104,9 @@ export const pageAssistModel = async ({ minP: minP, repeatPenalty: repeatPenalty, repeatLastN: repeatLastN, + tfsZ, + numKeep, + numThread, + useMlock }) } diff --git a/src/models/utils/ollama.ts b/src/models/utils/ollama.ts index a311ca9..8a48a9c 100644 --- a/src/models/utils/ollama.ts +++ b/src/models/utils/ollama.ts @@ -40,6 +40,7 @@ export interface OllamaInput { useMLock?: boolean useMMap?: boolean vocabOnly?: boolean + useMlock?: boolean seed?: number format?: StringWithAutocomplete<"json"> } diff --git a/src/services/model-settings.ts b/src/services/model-settings.ts index 730d619..82afca7 100644 --- a/src/services/model-settings.ts +++ b/src/services/model-settings.ts @@ -33,6 +33,7 @@ type ModelSettings = { useMMap?: boolean vocabOnly?: boolean minP?: number + useMlock?: boolean } const keys = [ @@ -65,6 +66,7 @@ const keys = [ "useMMap", "vocabOnly", "minP", + "useMlock" ] export const getAllModelSettings = async () => { diff --git a/src/store/model.tsx b/src/store/model.tsx index 9ef8546..23aab9c 100644 --- a/src/store/model.tsx +++ b/src/store/model.tsx @@ -66,6 +66,8 @@ type CurrentChatModelSettings = { reset: () => void systemPrompt?: string setSystemPrompt: (systemPrompt: string) => void + useMlock?: boolean + setUseMlock: (useMlock: boolean) => void setMinP: (minP: number) => void } @@ -108,6 +110,7 @@ export const useStoreChatModelSettings = create( systemPrompt: undefined, setMinP: (minP: number) => set({ minP }), setSystemPrompt: (systemPrompt: string) => set({ systemPrompt }), + setUseMlock: (useMlock: boolean) => set({ useMlock }), reset: () => set({ f16KV: undefined, @@ -141,6 +144,7 @@ export const useStoreChatModelSettings = create( seed: undefined, systemPrompt: undefined, minP: undefined, + useMlock: undefined, }) }) )
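
Usage sketch (not part of the patches above): PATCH 2/3 persists the splitting strategy and separator via setDefaultSplittingStrategy / setDefaultSplittingSeparator and reads them back in getPageAssistTextSplitter(), which returns a CharacterTextSplitter built from the unescaped separator, or falls back to RecursiveCharacterTextSplitter. The sample document below is illustrative only; chunk size and overlap still come from the existing defaultEmbeddingChunkSize / defaultEmbeddingChunkOverlap settings.

    import { Document } from "@langchain/core/documents"
    import {
      setDefaultSplittingStrategy,
      setDefaultSplittingSeparator
    } from "@/services/ollama"
    import { getPageAssistTextSplitter } from "@/utils/text-splitter"

    // Illustrative values: select the character-based splitter and a paragraph
    // separator. The separator is stored escaped ("\\n\\n") and unescaped inside
    // the helper before it is handed to CharacterTextSplitter.
    await setDefaultSplittingStrategy("CharacterTextSplitter")
    await setDefaultSplittingSeparator("\\n\\n")

    const splitter = await getPageAssistTextSplitter()
    const chunks = await splitter.splitDocuments([
      new Document({ pageContent: "First paragraph.\n\nSecond paragraph." })
    ])

This makes the helper a drop-in replacement for the RecursiveCharacterTextSplitter construction it removes from the search-engine and knowledge-processing pipelines.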
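
Similarly, the options added in PATCH 3/3 are plain pass-throughs on pageAssistModel. A hedged sketch of a direct call follows; the model name and option values are placeholders, not taken from the patch.

    import { pageAssistModel } from "@/models"

    // tfsZ, numKeep, numThread and useMlock are all optional and are intended to
    // map to Ollama's tfs_z, num_keep, num_thread and use_mlock request options.
    const chat = await pageAssistModel({
      model: "llama3.2:latest",          // placeholder model id
      baseUrl: "http://127.0.0.1:11434", // default Ollama URL from services/ollama.ts
      tfsZ: 1.0,
      numKeep: 256,
      numThread: 8,
      useMlock: true
    })

    const reply = await chat.invoke("Hello")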