From b3a455382c35c464345d5d58bb9ac5feb86c4e57 Mon Sep 17 00:00:00 2001
From: n4ze3m
Date: Thu, 23 May 2024 00:39:44 +0530
Subject: [PATCH] chore: Update version to 1.1.9 and add Model Settings to Ollama settings page

---
 src/assets/locale/en/common.json              |  36 +-
 .../Common/CurrentChatModelSettings.tsx       | 138 ++++++
 src/components/Icons/ChatSettings.tsx         |  29 ++
 src/components/Layouts/Layout.tsx             |  20 +-
 .../Option/Settings/model-settings.tsx        | 123 ++++++
 src/components/Option/Settings/ollama.tsx     |   4 +
 .../Option/Settings/search-mode.tsx           |   1 -
 src/hooks/useMessageOption.tsx                |  92 +++-
 src/models/ChatOllama.ts                      | 406 ++++++++++++++++++
 src/models/utils/ollama.ts                    | 201 +++++++++
 src/services/model-settings.ts                | 101 +++++
 src/store/model.tsx                           | 136 ++++++
 wxt.config.ts                                 |   2 +-
 13 files changed, 1271 insertions(+), 18 deletions(-)
 create mode 100644 src/components/Common/CurrentChatModelSettings.tsx
 create mode 100644 src/components/Icons/ChatSettings.tsx
 create mode 100644 src/components/Option/Settings/model-settings.tsx
 create mode 100644 src/models/ChatOllama.ts
 create mode 100644 src/models/utils/ollama.ts
 create mode 100644 src/services/model-settings.ts
 create mode 100644 src/store/model.tsx

diff --git a/src/assets/locale/en/common.json b/src/assets/locale/en/common.json
index 41df538..553c5f7 100644
--- a/src/assets/locale/en/common.json
+++ b/src/assets/locale/en/common.json
@@ -50,5 +50,39 @@
   "noHistory": "No chat history",
   "chatWithCurrentPage": "Chat with current page",
   "beta": "Beta",
-  "tts": "Read aloud"
+  "tts": "Read aloud",
+  "modelSettings": {
+    "label": "Model Settings",
+    "currentChatModelSettings":"Current Chat Model Settings",
+    "description": "Set the model options globally for all chats",
+    "form": {
+      "keepAlive": {
+        "label": "Keep Alive",
+        "help": "controls how long the model will stay loaded into memory following the request (default: 5m)",
+        "placeholder": "Enter Keep Alive duration (e.g. 5m, 10m, 1h)"
+      },
+      "temperature": {
+        "label": "Temperature",
+        "placeholder": "Enter Temperature value (e.g. 0.7, 1.0)"
+      },
+      "numCtx": {
+        "label": "Number of Contexts",
+        "placeholder": "Enter Number of Contexts value (default: 2048)"
+      },
+      "seed": {
+        "label": "Seed",
+        "placeholder": "Enter Seed value (e.g. 1234)",
+        "help": "Reproducibility of the model output"
+      },
+      "topK": {
+        "label": "Top K",
+        "placeholder": "Enter Top K value (e.g. 40, 100)"
+      },
+      "topP": {
+        "label": "Top P",
+        "placeholder": "Enter Top P value (e.g. 0.9, 0.95)"
+      }
+    },
+    "advanced": "More Model Settings"
+  }
 }
\ No newline at end of file
diff --git a/src/components/Common/CurrentChatModelSettings.tsx b/src/components/Common/CurrentChatModelSettings.tsx
new file mode 100644
index 0000000..d2b7cba
--- /dev/null
+++ b/src/components/Common/CurrentChatModelSettings.tsx
@@ -0,0 +1,138 @@
+import { getAllModelSettings } from "@/services/model-settings"
+import { useStoreChatModelSettings } from "@/store/model"
+import { useQuery } from "@tanstack/react-query"
+import { Collapse, Form, Input, InputNumber, Modal, Skeleton } from "antd"
+import React from "react"
+import { useTranslation } from "react-i18next"
+
+type Props = {
+  open: boolean
+  setOpen: (open: boolean) => void
+}
+
+export const CurrentChatModelSettings = ({ open, setOpen }: Props) => {
+  const { t } = useTranslation("common")
+  const [form] = Form.useForm()
+  const cUserSettings = useStoreChatModelSettings()
+  const { isPending: isLoading } = useQuery({
+    queryKey: ["fetchModelConfig2", open],
+    queryFn: async () => {
+      const data = await getAllModelSettings()
+      form.setFieldsValue({
+        temperature: cUserSettings.temperature ?? data.temperature,
+        topK: cUserSettings.topK ?? data.topK,
+        topP: cUserSettings.topP ?? data.topP,
+        keepAlive: cUserSettings.keepAlive ?? data.keepAlive,
+        numCtx: cUserSettings.numCtx ?? data.numCtx,
+        seed: cUserSettings.seed
+      })
+      return data
+    },
+    enabled: open,
+    refetchOnMount: true
+  })
+  return (
+    <Modal
+      title={t("modelSettings.currentChatModelSettings")}
+      open={open}
+      onOk={() => setOpen(false)}
+      onCancel={() => setOpen(false)}
+      footer={null}>
+      {!isLoading ? (
+        <Form
+          onFinish={(values) =>
{ + Object.entries(values).forEach(([key, value]) => { + cUserSettings.setX(key, value) + setOpen(false) + }) + }} + form={form} + layout="vertical"> + + + + + + + + + + + + + + + + + + + + + + + + ) + } + ]} + /> + + + + ) : ( + + )} +
+ ) +} diff --git a/src/components/Icons/ChatSettings.tsx b/src/components/Icons/ChatSettings.tsx new file mode 100644 index 0000000..e9cbf27 --- /dev/null +++ b/src/components/Icons/ChatSettings.tsx @@ -0,0 +1,29 @@ +import React from "react" + +export const ChatSettings = React.forwardRef< + SVGSVGElement, + React.SVGProps +>((props, ref) => { + return ( + + + + + + ) +}) diff --git a/src/components/Layouts/Layout.tsx b/src/components/Layouts/Layout.tsx index 2156710..676623d 100644 --- a/src/components/Layouts/Layout.tsx +++ b/src/components/Layouts/Layout.tsx @@ -7,6 +7,7 @@ import { useQuery } from "@tanstack/react-query" import { fetchChatModels, getAllModels } from "~/services/ollama" import { useMessageOption } from "~/hooks/useMessageOption" import { + BrainCog, ChevronLeft, CogIcon, ComputerIcon, @@ -24,6 +25,8 @@ import { SelectedKnowledge } from "../Option/Knowledge/SelectedKnwledge" import { useStorage } from "@plasmohq/storage/hook" import { ModelSelect } from "../Common/ModelSelect" import { PromptSelect } from "../Common/PromptSelect" +import { ChatSettings } from "../Icons/ChatSettings" +import { CurrentChatModelSettings } from "../Common/CurrentChatModelSettings" export default function OptionLayout({ children @@ -33,6 +36,7 @@ export default function OptionLayout({ const [sidebarOpen, setSidebarOpen] = useState(false) const { t } = useTranslation(["option", "common"]) const [shareModeEnabled] = useStorage("shareMode", false) + const [openModelSettings, setOpenModelSettings] = useState(false) const { selectedModel, @@ -108,9 +112,7 @@ export default function OptionLayout({ onClick={clearChat} className="inline-flex dark:bg-transparent bg-white items-center rounded-lg border dark:border-gray-700 bg-transparent px-3 py-2.5 text-xs lg:text-sm font-medium leading-4 text-gray-800 dark:text-white disabled:opacity-50 ease-in-out transition-colors duration-200 hover:bg-gray-100 dark:hover:bg-gray-800 dark:hover:text-white"> - - {t("newChat")} - + {t("newChat")} @@ -193,6 +195,13 @@ export default function OptionLayout({
+ + + {pathname === "/" && messages.length > 0 && !streaming && @@ -228,6 +237,11 @@ export default function OptionLayout({ open={sidebarOpen}> setSidebarOpen(false)} /> + +
   )
 }
diff --git a/src/components/Option/Settings/model-settings.tsx b/src/components/Option/Settings/model-settings.tsx
new file mode 100644
index 0000000..a7ee8ef
--- /dev/null
+++ b/src/components/Option/Settings/model-settings.tsx
@@ -0,0 +1,123 @@
+import { SaveButton } from "@/components/Common/SaveButton"
+import { getAllModelSettings, setModelSetting } from "@/services/model-settings"
+import { useQuery, useQueryClient } from "@tanstack/react-query"
+import { Form, Skeleton, Input, Switch, InputNumber, Collapse } from "antd"
+import React from "react"
+import { useTranslation } from "react-i18next"
+// keepAlive?: string
+// temperature?: number
+// topK?: number
+// topP?: number
+
+export const ModelSettings = () => {
+  const { t } = useTranslation("common")
+  const [form] = Form.useForm()
+  const client = useQueryClient()
+  const { isPending: isLoading } = useQuery({
+    queryKey: ["fetchModelConfig"],
+    queryFn: async () => {
+      const data = await getAllModelSettings()
+      form.setFieldsValue(data)
+      return data
+    }
+  })
+
+  return (
+
+

+ {t("modelSettings.label")} +

+

+ {t("modelSettings.description")} +

+
+
+ {!isLoading ? ( +
{ + Object.entries(values).forEach(([key, value]) => { + setModelSetting(key, value) + }) + client.invalidateQueries({ + queryKey: ["fetchModelConfig"] + }) + }} + form={form} + layout="vertical"> + + + + + + + + + + + + + + + + + + + + + ) + } + ]} + /> + +
+ +
+ + ) : ( + + )} +
+ ) +} diff --git a/src/components/Option/Settings/ollama.tsx b/src/components/Option/Settings/ollama.tsx index 969a845..688ae96 100644 --- a/src/components/Option/Settings/ollama.tsx +++ b/src/components/Option/Settings/ollama.tsx @@ -15,6 +15,7 @@ import { SettingPrompt } from "./prompt" import { Trans, useTranslation } from "react-i18next" import { useStorage } from "@plasmohq/storage/hook" import { AdvanceOllamaSettings } from "@/components/Common/AdvanceOllamaSettings" +import { ModelSettings } from "./model-settings" export const SettingsOllama = () => { const [ollamaURL, setOllamaURL] = useState("") @@ -219,6 +220,7 @@ export const SettingsOllama = () => {
+
@@ -229,6 +231,8 @@ export const SettingsOllama = () => {
+ + )} diff --git a/src/components/Option/Settings/search-mode.tsx b/src/components/Option/Settings/search-mode.tsx index cf09183..8973ec0 100644 --- a/src/components/Option/Settings/search-mode.tsx +++ b/src/components/Option/Settings/search-mode.tsx @@ -8,7 +8,6 @@ import { useTranslation } from "react-i18next" export const SearchModeSettings = () => { const { t } = useTranslation("settings") - const queryClient = useQueryClient() const form = useForm({ initialValues: { diff --git a/src/hooks/useMessageOption.tsx b/src/hooks/useMessageOption.tsx index 64100fa..40341ff 100644 --- a/src/hooks/useMessageOption.tsx +++ b/src/hooks/useMessageOption.tsx @@ -8,7 +8,6 @@ import { systemPromptForNonRagOption } from "~/services/ollama" import { type ChatHistory, type Message } from "~/store/option" -import { ChatOllama } from "@langchain/community/chat_models/ollama" import { HumanMessage, SystemMessage } from "@langchain/core/messages" import { useStoreMessageOption } from "~/store/option" import { @@ -29,8 +28,10 @@ import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama" import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore" import { formatDocs } from "@/chain/chat-with-x" import { useWebUI } from "@/store/webui" -import { isTTSEnabled } from "@/services/tts" import { useStorage } from "@plasmohq/storage/hook" +import { useStoreChatModelSettings } from "@/store/model" +import { getAllDefaultModelSettings } from "@/services/model-settings" +import { ChatOllama } from "@/models/ChatOllama" export const useMessageOption = () => { const { @@ -66,6 +67,7 @@ export const useMessageOption = () => { selectedKnowledge, setSelectedKnowledge } = useStoreMessageOption() + const currentChatModelSettings = useStoreChatModelSettings() const [selectedModel, setSelectedModel] = useStorage("selectedModel") const { ttsEnabled } = useWebUI() @@ -75,7 +77,6 @@ export const useMessageOption = () => { const navigate = useNavigate() const textareaRef = React.useRef(null) - const clearChat = () => { navigate("/") setMessages([]) @@ -85,6 +86,7 @@ export const useMessageOption = () => { setIsLoading(false) setIsProcessing(false) setStreaming(false) + currentChatModelSettings.reset() textareaRef?.current?.focus() } @@ -97,14 +99,25 @@ export const useMessageOption = () => { signal: AbortSignal ) => { const url = await getOllamaURL() - + const userDefaultModelSettings = await getAllDefaultModelSettings() if (image.length > 0) { image = `data:image/jpeg;base64,${image.split(",")[1]}` } const ollama = new ChatOllama({ model: selectedModel!, - baseUrl: cleanUrl(url) + baseUrl: cleanUrl(url), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive, + temperature: + currentChatModelSettings?.temperature ?? + userDefaultModelSettings?.temperature, + topK: currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK, + topP: currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP, + numCtx: + currentChatModelSettings?.numCtx ?? userDefaultModelSettings?.numCtx, + seed: currentChatModelSettings?.seed }) let newMessage: Message[] = [] @@ -163,7 +176,21 @@ export const useMessageOption = () => { .replaceAll("{question}", message) const questionOllama = new ChatOllama({ model: selectedModel!, - baseUrl: cleanUrl(url) + baseUrl: cleanUrl(url), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive, + temperature: + currentChatModelSettings?.temperature ?? 
+ userDefaultModelSettings?.temperature, + topK: + currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK, + topP: + currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP, + numCtx: + currentChatModelSettings?.numCtx ?? + userDefaultModelSettings?.numCtx, + seed: currentChatModelSettings?.seed }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() @@ -172,7 +199,7 @@ export const useMessageOption = () => { const { prompt, source } = await getSystemPromptForWeb(query) setIsSearchingInternet(false) - // message = message.trim().replaceAll("\n", " ") + // message = message.trim().replaceAll("\n", " ") let humanMessage = new HumanMessage({ content: [ @@ -314,6 +341,7 @@ export const useMessageOption = () => { signal: AbortSignal ) => { const url = await getOllamaURL() + const userDefaultModelSettings = await getAllDefaultModelSettings() if (image.length > 0) { image = `data:image/jpeg;base64,${image.split(",")[1]}` @@ -321,7 +349,18 @@ export const useMessageOption = () => { const ollama = new ChatOllama({ model: selectedModel!, - baseUrl: cleanUrl(url) + baseUrl: cleanUrl(url), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive, + temperature: + currentChatModelSettings?.temperature ?? + userDefaultModelSettings?.temperature, + topK: currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK, + topP: currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP, + numCtx: + currentChatModelSettings?.numCtx ?? userDefaultModelSettings?.numCtx, + seed: currentChatModelSettings?.seed }) let newMessage: Message[] = [] @@ -521,10 +560,22 @@ export const useMessageOption = () => { signal: AbortSignal ) => { const url = await getOllamaURL() + const userDefaultModelSettings = await getAllDefaultModelSettings() const ollama = new ChatOllama({ model: selectedModel!, - baseUrl: cleanUrl(url) + baseUrl: cleanUrl(url), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive, + temperature: + currentChatModelSettings?.temperature ?? + userDefaultModelSettings?.temperature, + topK: currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK, + topP: currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP, + numCtx: + currentChatModelSettings?.numCtx ?? userDefaultModelSettings?.numCtx, + seed: currentChatModelSettings?.seed }) let newMessage: Message[] = [] @@ -568,7 +619,10 @@ export const useMessageOption = () => { const ollamaUrl = await getOllamaURL() const ollamaEmbedding = new OllamaEmbeddings({ model: embeddingModle || selectedModel, - baseUrl: cleanUrl(ollamaUrl) + baseUrl: cleanUrl(ollamaUrl), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive }) let vectorstore = await PageAssistVectorStore.fromExistingIndex( @@ -596,7 +650,21 @@ export const useMessageOption = () => { .replaceAll("{question}", message) const questionOllama = new ChatOllama({ model: selectedModel!, - baseUrl: cleanUrl(url) + baseUrl: cleanUrl(url), + keepAlive: + currentChatModelSettings?.keepAlive ?? + userDefaultModelSettings?.keepAlive, + temperature: + currentChatModelSettings?.temperature ?? + userDefaultModelSettings?.temperature, + topK: + currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK, + topP: + currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP, + numCtx: + currentChatModelSettings?.numCtx ?? 
+ userDefaultModelSettings?.numCtx, + seed: currentChatModelSettings?.seed }) const response = await questionOllama.invoke(promptForQuestion) query = response.content.toString() @@ -613,7 +681,7 @@ export const useMessageOption = () => { url: "" } }) - // message = message.trim().replaceAll("\n", " ") + // message = message.trim().replaceAll("\n", " ") let humanMessage = new HumanMessage({ content: [ diff --git a/src/models/ChatOllama.ts b/src/models/ChatOllama.ts new file mode 100644 index 0000000..69138bf --- /dev/null +++ b/src/models/ChatOllama.ts @@ -0,0 +1,406 @@ +import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { + SimpleChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + AIMessageChunk, + BaseMessage, + ChatMessage, +} from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import type { StringWithAutocomplete } from "@langchain/core/utils/types"; + +import { + createOllamaChatStream, + createOllamaGenerateStream, + type OllamaInput, + type OllamaMessage, +} from "./utils/ollama"; + +export interface ChatOllamaInput extends OllamaInput { } + +export interface ChatOllamaCallOptions extends BaseLanguageModelCallOptions { } + +export class ChatOllama + extends SimpleChatModel + implements ChatOllamaInput { + static lc_name() { + return "ChatOllama"; + } + + lc_serializable = true; + + model = "llama2"; + + baseUrl = "http://localhost:11434"; + + keepAlive = "5m"; + + embeddingOnly?: boolean; + + f16KV?: boolean; + + frequencyPenalty?: number; + + headers?: Record; + + logitsAll?: boolean; + + lowVram?: boolean; + + mainGpu?: number; + + mirostat?: number; + + mirostatEta?: number; + + mirostatTau?: number; + + numBatch?: number; + + numCtx?: number; + + numGpu?: number; + + numGqa?: number; + + numKeep?: number; + + numPredict?: number; + + numThread?: number; + + penalizeNewline?: boolean; + + presencePenalty?: number; + + repeatLastN?: number; + + repeatPenalty?: number; + + ropeFrequencyBase?: number; + + ropeFrequencyScale?: number; + + temperature?: number; + + stop?: string[]; + + tfsZ?: number; + + topK?: number; + + topP?: number; + + typicalP?: number; + + useMLock?: boolean; + + useMMap?: boolean; + + vocabOnly?: boolean; + + seed?: number; + + format?: StringWithAutocomplete<"json">; + + constructor(fields: OllamaInput & BaseChatModelParams) { + super(fields); + this.model = fields.model ?? this.model; + this.baseUrl = fields.baseUrl?.endsWith("/") + ? fields.baseUrl.slice(0, -1) + : fields.baseUrl ?? this.baseUrl; + this.keepAlive = fields.keepAlive ?? 
this.keepAlive; + this.embeddingOnly = fields.embeddingOnly; + this.f16KV = fields.f16KV; + this.frequencyPenalty = fields.frequencyPenalty; + this.headers = fields.headers; + this.logitsAll = fields.logitsAll; + this.lowVram = fields.lowVram; + this.mainGpu = fields.mainGpu; + this.mirostat = fields.mirostat; + this.mirostatEta = fields.mirostatEta; + this.mirostatTau = fields.mirostatTau; + this.numBatch = fields.numBatch; + this.numCtx = fields.numCtx; + this.numGpu = fields.numGpu; + this.numGqa = fields.numGqa; + this.numKeep = fields.numKeep; + this.numPredict = fields.numPredict; + this.numThread = fields.numThread; + this.penalizeNewline = fields.penalizeNewline; + this.presencePenalty = fields.presencePenalty; + this.repeatLastN = fields.repeatLastN; + this.repeatPenalty = fields.repeatPenalty; + this.ropeFrequencyBase = fields.ropeFrequencyBase; + this.ropeFrequencyScale = fields.ropeFrequencyScale; + this.temperature = fields.temperature; + this.stop = fields.stop; + this.tfsZ = fields.tfsZ; + this.topK = fields.topK; + this.topP = fields.topP; + this.typicalP = fields.typicalP; + this.useMLock = fields.useMLock; + this.useMMap = fields.useMMap; + this.vocabOnly = fields.vocabOnly; + this.format = fields.format; + this.seed = fields.seed; + } + + protected getLsParams(options: this["ParsedCallOptions"]) { + const params = this.invocationParams(options); + return { + ls_provider: "ollama", + ls_model_name: this.model, + ls_model_type: "chat", + ls_temperature: this.temperature ?? undefined, + ls_stop: this.stop, + ls_max_tokens: params.options.num_predict, + }; + } + + _llmType() { + return "ollama"; + } + + /** + * A method that returns the parameters for an Ollama API call. It + * includes model and options parameters. + * @param options Optional parsed call options. + * @returns An object containing the parameters for an Ollama API call. + */ + invocationParams(options?: this["ParsedCallOptions"]) { + return { + model: this.model, + format: this.format, + keep_alive: this.keepAlive, + options: { + embedding_only: this.embeddingOnly, + f16_kv: this.f16KV, + frequency_penalty: this.frequencyPenalty, + logits_all: this.logitsAll, + low_vram: this.lowVram, + main_gpu: this.mainGpu, + mirostat: this.mirostat, + mirostat_eta: this.mirostatEta, + mirostat_tau: this.mirostatTau, + num_batch: this.numBatch, + num_ctx: this.numCtx, + num_gpu: this.numGpu, + num_gqa: this.numGqa, + num_keep: this.numKeep, + num_predict: this.numPredict, + num_thread: this.numThread, + penalize_newline: this.penalizeNewline, + presence_penalty: this.presencePenalty, + repeat_last_n: this.repeatLastN, + repeat_penalty: this.repeatPenalty, + rope_frequency_base: this.ropeFrequencyBase, + rope_frequency_scale: this.ropeFrequencyScale, + temperature: this.temperature, + stop: options?.stop ?? 
this.stop, + tfs_z: this.tfsZ, + top_k: this.topK, + top_p: this.topP, + typical_p: this.typicalP, + use_mlock: this.useMLock, + use_mmap: this.useMMap, + vocab_only: this.vocabOnly, + seed: this.seed, + }, + }; + } + + _combineLLMOutput() { + return {}; + } + + /** @deprecated */ + async *_streamResponseChunksLegacy( + input: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const stream = createOllamaGenerateStream( + this.baseUrl, + { + ...this.invocationParams(options), + prompt: this._formatMessagesAsPrompt(input), + }, + { + ...options, + headers: this.headers, + } + ); + for await (const chunk of stream) { + if (!chunk.done) { + yield new ChatGenerationChunk({ + text: chunk.response, + message: new AIMessageChunk({ content: chunk.response }), + }); + await runManager?.handleLLMNewToken(chunk.response ?? ""); + } else { + yield new ChatGenerationChunk({ + text: "", + message: new AIMessageChunk({ content: "" }), + generationInfo: { + model: chunk.model, + total_duration: chunk.total_duration, + load_duration: chunk.load_duration, + prompt_eval_count: chunk.prompt_eval_count, + prompt_eval_duration: chunk.prompt_eval_duration, + eval_count: chunk.eval_count, + eval_duration: chunk.eval_duration, + }, + }); + } + } + } + + async *_streamResponseChunks( + input: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + try { + const stream = await this.caller.call(async () => + createOllamaChatStream( + this.baseUrl, + { + ...this.invocationParams(options), + messages: this._convertMessagesToOllamaMessages(input), + }, + { + ...options, + headers: this.headers, + } + ) + ); + for await (const chunk of stream) { + if (!chunk.done) { + yield new ChatGenerationChunk({ + text: chunk.message.content, + message: new AIMessageChunk({ content: chunk.message.content }), + }); + await runManager?.handleLLMNewToken(chunk.message.content ?? ""); + } else { + yield new ChatGenerationChunk({ + text: "", + message: new AIMessageChunk({ content: "" }), + generationInfo: { + model: chunk.model, + total_duration: chunk.total_duration, + load_duration: chunk.load_duration, + prompt_eval_count: chunk.prompt_eval_count, + prompt_eval_duration: chunk.prompt_eval_duration, + eval_count: chunk.eval_count, + eval_duration: chunk.eval_duration, + }, + }); + } + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.response?.status === 404) { + console.warn( + "[WARNING]: It seems you are using a legacy version of Ollama. Please upgrade to a newer version for better chat support." 
+ ); + yield* this._streamResponseChunksLegacy(input, options, runManager); + } else { + throw e; + } + } + } + + protected _convertMessagesToOllamaMessages( + messages: BaseMessage[] + ): OllamaMessage[] { + return messages.map((message) => { + let role; + if (message._getType() === "human") { + role = "user"; + } else if (message._getType() === "ai") { + role = "assistant"; + } else if (message._getType() === "system") { + role = "system"; + } else { + throw new Error( + `Unsupported message type for Ollama: ${message._getType()}` + ); + } + let content = ""; + const images = []; + if (typeof message.content === "string") { + content = message.content; + } else { + for (const contentPart of message.content) { + if (contentPart.type === "text") { + content = `${content}\n${contentPart.text}`; + } else if ( + contentPart.type === "image_url" && + typeof contentPart.image_url === "string" + ) { + const imageUrlComponents = contentPart.image_url.split(","); + // Support both data:image/jpeg;base64, format as well + images.push(imageUrlComponents[1] ?? imageUrlComponents[0]); + } else { + throw new Error( + `Unsupported message content type. Must either have type "text" or type "image_url" with a string "image_url" field.` + ); + } + } + } + return { + role, + content, + images, + }; + }); + } + + /** @deprecated */ + protected _formatMessagesAsPrompt(messages: BaseMessage[]): string { + const formattedMessages = messages + .map((message) => { + let messageText; + if (message._getType() === "human") { + messageText = `[INST] ${message.content} [/INST]`; + } else if (message._getType() === "ai") { + messageText = message.content; + } else if (message._getType() === "system") { + messageText = `<> ${message.content} <>`; + } else if (ChatMessage.isInstance(message)) { + messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice( + 1 + )}: ${message.content}`; + } else { + console.warn( + `Unsupported message type passed to Ollama: "${message._getType()}"` + ); + messageText = ""; + } + return messageText; + }) + .join("\n"); + return formattedMessages; + } + + /** @ignore */ + async _call( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const chunks = []; + for await (const chunk of this._streamResponseChunks( + messages, + options, + runManager + )) { + chunks.push(chunk.message.content); + } + return chunks.join(""); + } +} \ No newline at end of file diff --git a/src/models/utils/ollama.ts b/src/models/utils/ollama.ts new file mode 100644 index 0000000..d3524fb --- /dev/null +++ b/src/models/utils/ollama.ts @@ -0,0 +1,201 @@ +import { IterableReadableStream } from "@langchain/core/utils/stream"; +import type { StringWithAutocomplete } from "@langchain/core/utils/types"; +import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; + +export interface OllamaInput { + embeddingOnly?: boolean; + f16KV?: boolean; + frequencyPenalty?: number; + headers?: Record; + keepAlive?: string; + logitsAll?: boolean; + lowVram?: boolean; + mainGpu?: number; + model?: string; + baseUrl?: string; + mirostat?: number; + mirostatEta?: number; + mirostatTau?: number; + numBatch?: number; + numCtx?: number; + numGpu?: number; + numGqa?: number; + numKeep?: number; + numPredict?: number; + numThread?: number; + penalizeNewline?: boolean; + presencePenalty?: number; + repeatLastN?: number; + repeatPenalty?: number; + ropeFrequencyBase?: number; + ropeFrequencyScale?: number; + temperature?: number; + 
stop?: string[]; + tfsZ?: number; + topK?: number; + topP?: number; + typicalP?: number; + useMLock?: boolean; + useMMap?: boolean; + vocabOnly?: boolean; + seed?: number; + format?: StringWithAutocomplete<"json">; +} + +export interface OllamaRequestParams { + model: string; + format?: StringWithAutocomplete<"json">; + images?: string[]; + options: { + embedding_only?: boolean; + f16_kv?: boolean; + frequency_penalty?: number; + logits_all?: boolean; + low_vram?: boolean; + main_gpu?: number; + mirostat?: number; + mirostat_eta?: number; + mirostat_tau?: number; + num_batch?: number; + num_ctx?: number; + num_gpu?: number; + num_gqa?: number; + num_keep?: number; + num_thread?: number; + num_predict?: number; + penalize_newline?: boolean; + presence_penalty?: number; + repeat_last_n?: number; + repeat_penalty?: number; + rope_frequency_base?: number; + rope_frequency_scale?: number; + temperature?: number; + stop?: string[]; + tfs_z?: number; + top_k?: number; + top_p?: number; + typical_p?: number; + use_mlock?: boolean; + use_mmap?: boolean; + vocab_only?: boolean; + }; +} + +export type OllamaMessage = { + role: StringWithAutocomplete<"user" | "assistant" | "system">; + content: string; + images?: string[]; +}; + +export interface OllamaGenerateRequestParams extends OllamaRequestParams { + prompt: string; +} + +export interface OllamaChatRequestParams extends OllamaRequestParams { + messages: OllamaMessage[]; +} + +export type BaseOllamaGenerationChunk = { + model: string; + created_at: string; + done: boolean; + total_duration?: number; + load_duration?: number; + prompt_eval_count?: number; + prompt_eval_duration?: number; + eval_count?: number; + eval_duration?: number; +}; + +export type OllamaGenerationChunk = BaseOllamaGenerationChunk & { + response: string; +}; + +export type OllamaChatGenerationChunk = BaseOllamaGenerationChunk & { + message: OllamaMessage; +}; + +export type OllamaCallOptions = BaseLanguageModelCallOptions & { + headers?: Record; +}; + +async function* createOllamaStream( + url: string, + params: OllamaRequestParams, + options: OllamaCallOptions +) { + let formattedUrl = url; + if (formattedUrl.startsWith("http://localhost:")) { + // Node 18 has issues with resolving "localhost" + // See https://github.com/node-fetch/node-fetch/issues/1624 + formattedUrl = formattedUrl.replace( + "http://localhost:", + "http://127.0.0.1:" + ); + } + const response = await fetch(formattedUrl, { + method: "POST", + body: JSON.stringify(params), + headers: { + "Content-Type": "application/json", + ...options.headers, + }, + signal: options.signal, + }); + if (!response.ok) { + let error; + const responseText = await response.text(); + try { + const json = JSON.parse(responseText); + error = new Error( + `Ollama call failed with status code ${response.status}: ${json.error}` + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + error = new Error( + `Ollama call failed with status code ${response.status}: ${responseText}` + ); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).response = response; + throw error; + } + if (!response.body) { + throw new Error( + "Could not begin Ollama stream. Please check the given URL and try again." 
+ ); + } + + const stream = IterableReadableStream.fromReadableStream(response.body); + + const decoder = new TextDecoder(); + let extra = ""; + for await (const chunk of stream) { + const decoded = extra + decoder.decode(chunk); + const lines = decoded.split("\n"); + extra = lines.pop() || ""; + for (const line of lines) { + try { + yield JSON.parse(line); + } catch (e) { + console.warn(`Received a non-JSON parseable chunk: ${line}`); + } + } + } +} + +export async function* createOllamaGenerateStream( + baseUrl: string, + params: OllamaGenerateRequestParams, + options: OllamaCallOptions +): AsyncGenerator { + yield* createOllamaStream(`${baseUrl}/api/generate`, params, options); +} + +export async function* createOllamaChatStream( + baseUrl: string, + params: OllamaChatRequestParams, + options: OllamaCallOptions +): AsyncGenerator { + yield* createOllamaStream(`${baseUrl}/api/chat`, params, options); +} \ No newline at end of file diff --git a/src/services/model-settings.ts b/src/services/model-settings.ts new file mode 100644 index 0000000..278576d --- /dev/null +++ b/src/services/model-settings.ts @@ -0,0 +1,101 @@ +import { Storage } from "@plasmohq/storage" +const storage = new Storage() + +type ModelSettings = { + f16KV?: boolean + frequencyPenalty?: number + keepAlive?: string + logitsAll?: boolean + mirostat?: number + mirostatEta?: number + mirostatTau?: number + numBatch?: number + numCtx?: number + numGpu?: number + numGqa?: number + numKeep?: number + numPredict?: number + numThread?: number + penalizeNewline?: boolean + presencePenalty?: number + repeatLastN?: number + repeatPenalty?: number + ropeFrequencyBase?: number + ropeFrequencyScale?: number + temperature?: number + tfsZ?: number + topK?: number + topP?: number + typicalP?: number + useMLock?: boolean + useMMap?: boolean + vocabOnly?: boolean +} + +const keys = [ + "f16KV", + "frequencyPenalty", + "keepAlive", + "logitsAll", + "mirostat", + "mirostatEta", + "mirostatTau", + "numBatch", + "numCtx", + "numGpu", + "numGqa", + "numKeep", + "numPredict", + "numThread", + "penalizeNewline", + "presencePenalty", + "repeatLastN", + "repeatPenalty", + "ropeFrequencyBase", + "ropeFrequencyScale", + "temperature", + "tfsZ", + "topK", + "topP", + "typicalP", + "useMLock", + "useMMap", + "vocabOnly" +] + +const getAllModelSettings = async () => { + try { + const settings: ModelSettings = {} + for (const key of keys) { + const value = await storage.get(key) + settings[key] = value + if (!value && key === "keepAlive") { + settings[key] = "5m" + } + + } + return settings + } catch (error) { + console.error(error) + return {} + } +} + +const setModelSetting = async (key: string, + value: string | number | boolean) => { + await storage.set(key, value) +} + +export const getAllDefaultModelSettings = async (): Promise => { + const settings: ModelSettings = {} + for (const key of keys) { + const value = await storage.get(key) + settings[key] = value + if (!value && key === "keepAlive") { + settings[key] = "5m" + } + } + return settings +} + +export { getAllModelSettings, setModelSetting } \ No newline at end of file diff --git a/src/store/model.tsx b/src/store/model.tsx new file mode 100644 index 0000000..8fadca8 --- /dev/null +++ b/src/store/model.tsx @@ -0,0 +1,136 @@ +import { create } from "zustand" + +type CurrentChatModelSettings = { + f16KV?: boolean + frequencyPenalty?: number + keepAlive?: string + logitsAll?: boolean + mirostat?: number + mirostatEta?: number + mirostatTau?: number + numBatch?: number + numCtx?: number + 
numGpu?: number + numGqa?: number + numKeep?: number + numPredict?: number + numThread?: number + penalizeNewline?: boolean + presencePenalty?: number + repeatLastN?: number + repeatPenalty?: number + ropeFrequencyBase?: number + ropeFrequencyScale?: number + temperature?: number + tfsZ?: number + topK?: number + topP?: number + typicalP?: number + useMLock?: boolean + useMMap?: boolean + vocabOnly?: boolean + seed?: number + + setF16KV?: (f16KV: boolean) => void + setFrequencyPenalty?: (frequencyPenalty: number) => void + setKeepAlive?: (keepAlive: string) => void + setLogitsAll?: (logitsAll: boolean) => void + setMirostat?: (mirostat: number) => void + setMirostatEta?: (mirostatEta: number) => void + setMirostatTau?: (mirostatTau: number) => void + setNumBatch?: (numBatch: number) => void + setNumCtx?: (numCtx: number) => void + setNumGpu?: (numGpu: number) => void + setNumGqa?: (numGqa: number) => void + setNumKeep?: (numKeep: number) => void + setNumPredict?: (numPredict: number) => void + setNumThread?: (numThread: number) => void + setPenalizeNewline?: (penalizeNewline: boolean) => void + setPresencePenalty?: (presencePenalty: number) => void + setRepeatLastN?: (repeatLastN: number) => void + setRepeatPenalty?: (repeatPenalty: number) => void + setRopeFrequencyBase?: (ropeFrequencyBase: number) => void + setRopeFrequencyScale?: (ropeFrequencyScale: number) => void + setTemperature?: (temperature: number) => void + setTfsZ?: (tfsZ: number) => void + setTopK?: (topK: number) => void + setTopP?: (topP: number) => void + setTypicalP?: (typicalP: number) => void + setUseMLock?: (useMLock: boolean) => void + setUseMMap?: (useMMap: boolean) => void + setVocabOnly?: (vocabOnly: boolean) => void + seetSeed?: (seed: number) => void + + setX: (key: string, value: any) => void + reset: () => void +} + +export const useStoreChatModelSettings = create( + (set) => ({ + setF16KV: (f16KV: boolean) => set({ f16KV }), + setFrequencyPenalty: (frequencyPenalty: number) => + set({ frequencyPenalty }), + setKeepAlive: (keepAlive: string) => set({ keepAlive }), + setLogitsAll: (logitsAll: boolean) => set({ logitsAll }), + setMirostat: (mirostat: number) => set({ mirostat }), + setMirostatEta: (mirostatEta: number) => set({ mirostatEta }), + setMirostatTau: (mirostatTau: number) => set({ mirostatTau }), + setNumBatch: (numBatch: number) => set({ numBatch }), + setNumCtx: (numCtx: number) => set({ numCtx }), + setNumGpu: (numGpu: number) => set({ numGpu }), + setNumGqa: (numGqa: number) => set({ numGqa }), + setNumKeep: (numKeep: number) => set({ numKeep }), + setNumPredict: (numPredict: number) => set({ numPredict }), + setNumThread: (numThread: number) => set({ numThread }), + setPenalizeNewline: (penalizeNewline: boolean) => set({ penalizeNewline }), + setPresencePenalty: (presencePenalty: number) => set({ presencePenalty }), + setRepeatLastN: (repeatLastN: number) => set({ repeatLastN }), + setRepeatPenalty: (repeatPenalty: number) => set({ repeatPenalty }), + setRopeFrequencyBase: (ropeFrequencyBase: number) => + set({ ropeFrequencyBase }), + setRopeFrequencyScale: (ropeFrequencyScale: number) => + set({ ropeFrequencyScale }), + setTemperature: (temperature: number) => set({ temperature }), + setTfsZ: (tfsZ: number) => set({ tfsZ }), + setTopK: (topK: number) => set({ topK }), + setTopP: (topP: number) => set({ topP }), + setTypicalP: (typicalP: number) => set({ typicalP }), + setUseMLock: (useMLock: boolean) => set({ useMLock }), + setUseMMap: (useMMap: boolean) => set({ useMMap }), + setVocabOnly: 
(vocabOnly: boolean) => set({ vocabOnly }), + seetSeed: (seed: number) => set({ seed }), + setX: (key: string, value: any) => set({ [key]: value }), + reset: () => + set({ + f16KV: undefined, + frequencyPenalty: undefined, + keepAlive: undefined, + logitsAll: undefined, + mirostat: undefined, + mirostatEta: undefined, + mirostatTau: undefined, + numBatch: undefined, + numCtx: undefined, + numGpu: undefined, + numGqa: undefined, + numKeep: undefined, + numPredict: undefined, + numThread: undefined, + penalizeNewline: undefined, + presencePenalty: undefined, + repeatLastN: undefined, + repeatPenalty: undefined, + ropeFrequencyBase: undefined, + ropeFrequencyScale: undefined, + temperature: undefined, + tfsZ: undefined, + topK: undefined, + topP: undefined, + typicalP: undefined, + useMLock: undefined, + useMMap: undefined, + vocabOnly: undefined, + seed: undefined + }) + }) +) diff --git a/wxt.config.ts b/wxt.config.ts index aa890c4..69fab0a 100644 --- a/wxt.config.ts +++ b/wxt.config.ts @@ -48,7 +48,7 @@ export default defineConfig({ outDir: "build", manifest: { - version: "1.1.8", + version: "1.1.9", name: process.env.TARGET === "firefox" ? "Page Assist - A Web UI for Local AI Models"