From 031e74f60922ebe1bef07d2fec7052f970ca7af7 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Mon, 8 Jul 2024 23:56:25 +0530 Subject: [PATCH] feat: Add option to restore last used model for previous chats --- src/assets/locale/en/settings.json | 3 + .../Option/Settings/general-settings.tsx | 17 ++ src/components/Option/Sidebar.tsx | 21 +- src/hooks/chat-helper/index.ts | 9 +- src/services/model-settings.ts | 191 ++++++++++-------- 5 files changed, 155 insertions(+), 86 deletions(-) diff --git a/src/assets/locale/en/settings.json b/src/assets/locale/en/settings.json index fb798d2..ed76ddf 100644 --- a/src/assets/locale/en/settings.json +++ b/src/assets/locale/en/settings.json @@ -23,6 +23,9 @@ }, "hideCurrentChatModelSettings": { "label": "Hide the current Chat Model Settings" + }, + "restoreLastChatModel": { + "label": "Restore last used model for previous chats" } }, "webSearch": { diff --git a/src/components/Option/Settings/general-settings.tsx b/src/components/Option/Settings/general-settings.tsx index 1b1e7cc..367f283 100644 --- a/src/components/Option/Settings/general-settings.tsx +++ b/src/components/Option/Settings/general-settings.tsx @@ -24,6 +24,11 @@ export const GeneralSettings = () => { false ) + const [restoreLastChatModel, setRestoreLastChatModel] = useStorage( + "restoreLastChatModel", + false + ) + const [hideCurrentChatModelSettings, setHideCurrentChatModelSettings] = useStorage("hideCurrentChatModelSettings", false) @@ -107,6 +112,18 @@ export const GeneralSettings = () => { onChange={(checked) => setHideCurrentChatModelSettings(checked)} /> +
+
+ + {t("generalSettings.settings.restoreLastChatModel.label")} + +
+ + setRestoreLastChatModel(checked)} + /> +
{t("generalSettings.settings.darkMode.label")} diff --git a/src/components/Option/Sidebar.tsx b/src/components/Option/Sidebar.tsx index e18f567..9535360 100644 --- a/src/components/Option/Sidebar.tsx +++ b/src/components/Option/Sidebar.tsx @@ -11,14 +11,24 @@ import { useMessageOption } from "~/hooks/useMessageOption" import { PencilIcon, Trash2 } from "lucide-react" import { useNavigate } from "react-router-dom" import { useTranslation } from "react-i18next" +import { + getLastUsedChatModel, + lastUsedChatModelEnabled +} from "@/services/model-settings" type Props = { onClose: () => void } export const Sidebar = ({ onClose }: Props) => { - const { setMessages, setHistory, setHistoryId, historyId, clearChat } = - useMessageOption() + const { + setMessages, + setHistory, + setHistoryId, + historyId, + clearChat, + setSelectedModel + } = useMessageOption() const { t } = useTranslation(["option", "common"]) const client = useQueryClient() const navigate = useNavigate() @@ -88,6 +98,13 @@ export const Sidebar = ({ onClose }: Props) => { setHistoryId(chat.id) setHistory(formatToChatHistory(history)) setMessages(formatToMessage(history)) + const isLastUsedChatModel = await lastUsedChatModelEnabled() + if (isLastUsedChatModel) { + const currentChatModel = await getLastUsedChatModel(chat.id) + if (currentChatModel) { + setSelectedModel(currentChatModel) + } + } navigate("/") onClose() }}> diff --git a/src/hooks/chat-helper/index.ts b/src/hooks/chat-helper/index.ts index 73f8494..f7b690f 100644 --- a/src/hooks/chat-helper/index.ts +++ b/src/hooks/chat-helper/index.ts @@ -1,4 +1,5 @@ import { saveHistory, saveMessage } from "@/db" +import { setLastUsedChatModel } from "@/services/model-settings" import { ChatHistory } from "@/store/option" export const saveMessageOnError = async ({ @@ -23,7 +24,7 @@ export const saveMessageOnError = async ({ historyId: string | null selectedModel: string setHistoryId: (historyId: string) => void - isRegenerating: boolean, + isRegenerating: boolean message_source?: "copilot" | "web-ui" }) => { if ( @@ -66,6 +67,7 @@ export const saveMessageOnError = async ({ [], 2 ) + await setLastUsedChatModel(historyId, selectedModel) } else { const newHistoryId = await saveHistory(userMessage, false, message_source) if (!isRegenerating) { @@ -89,6 +91,7 @@ export const saveMessageOnError = async ({ 2 ) setHistoryId(newHistoryId.id) + await setLastUsedChatModel(newHistoryId.id, selectedModel) } return true @@ -115,7 +118,7 @@ export const saveMessageOnSuccess = async ({ message: string image: string fullText: string - source: any[], + source: any[] message_source?: "copilot" | "web-ui" }) => { if (historyId) { @@ -139,6 +142,7 @@ export const saveMessageOnSuccess = async ({ source, 2 ) + await setLastUsedChatModel(historyId, selectedModel!) } else { const newHistoryId = await saveHistory(message, false, message_source) await saveMessage( @@ -160,5 +164,6 @@ export const saveMessageOnSuccess = async ({ 2 ) setHistoryId(newHistoryId.id) + await setLastUsedChatModel(newHistoryId.id, selectedModel!) } } diff --git a/src/services/model-settings.ts b/src/services/model-settings.ts index 278576d..501bf59 100644 --- a/src/services/model-settings.ts +++ b/src/services/model-settings.ts @@ -2,100 +2,127 @@ import { Storage } from "@plasmohq/storage" const storage = new Storage() type ModelSettings = { - f16KV?: boolean - frequencyPenalty?: number - keepAlive?: string - logitsAll?: boolean - mirostat?: number - mirostatEta?: number - mirostatTau?: number - numBatch?: number - numCtx?: number - numGpu?: number - numGqa?: number - numKeep?: number - numPredict?: number - numThread?: number - penalizeNewline?: boolean - presencePenalty?: number - repeatLastN?: number - repeatPenalty?: number - ropeFrequencyBase?: number - ropeFrequencyScale?: number - temperature?: number - tfsZ?: number - topK?: number - topP?: number - typicalP?: number - useMLock?: boolean - useMMap?: boolean - vocabOnly?: boolean + f16KV?: boolean + frequencyPenalty?: number + keepAlive?: string + logitsAll?: boolean + mirostat?: number + mirostatEta?: number + mirostatTau?: number + numBatch?: number + numCtx?: number + numGpu?: number + numGqa?: number + numKeep?: number + numPredict?: number + numThread?: number + penalizeNewline?: boolean + presencePenalty?: number + repeatLastN?: number + repeatPenalty?: number + ropeFrequencyBase?: number + ropeFrequencyScale?: number + temperature?: number + tfsZ?: number + topK?: number + topP?: number + typicalP?: number + useMLock?: boolean + useMMap?: boolean + vocabOnly?: boolean } const keys = [ - "f16KV", - "frequencyPenalty", - "keepAlive", - "logitsAll", - "mirostat", - "mirostatEta", - "mirostatTau", - "numBatch", - "numCtx", - "numGpu", - "numGqa", - "numKeep", - "numPredict", - "numThread", - "penalizeNewline", - "presencePenalty", - "repeatLastN", - "repeatPenalty", - "ropeFrequencyBase", - "ropeFrequencyScale", - "temperature", - "tfsZ", - "topK", - "topP", - "typicalP", - "useMLock", - "useMMap", - "vocabOnly" + "f16KV", + "frequencyPenalty", + "keepAlive", + "logitsAll", + "mirostat", + "mirostatEta", + "mirostatTau", + "numBatch", + "numCtx", + "numGpu", + "numGqa", + "numKeep", + "numPredict", + "numThread", + "penalizeNewline", + "presencePenalty", + "repeatLastN", + "repeatPenalty", + "ropeFrequencyBase", + "ropeFrequencyScale", + "temperature", + "tfsZ", + "topK", + "topP", + "typicalP", + "useMLock", + "useMMap", + "vocabOnly" ] const getAllModelSettings = async () => { - try { - const settings: ModelSettings = {} - for (const key of keys) { - const value = await storage.get(key) - settings[key] = value - if (!value && key === "keepAlive") { - settings[key] = "5m" - } - - } - return settings - } catch (error) { - console.error(error) - return {} + try { + const settings: ModelSettings = {} + for (const key of keys) { + const value = await storage.get(key) + settings[key] = value + if (!value && key === "keepAlive") { + settings[key] = "5m" + } } + return settings + } catch (error) { + console.error(error) + return {} + } } -const setModelSetting = async (key: string, - value: string | number | boolean) => { - await storage.set(key, value) +const setModelSetting = async ( + key: string, + value: string | number | boolean +) => { + await storage.set(key, value) } export const getAllDefaultModelSettings = async (): Promise => { - const settings: ModelSettings = {} - for (const key of keys) { - const value = await storage.get(key) - settings[key] = value - if (!value && key === "keepAlive") { - settings[key] = "5m" - } + const settings: ModelSettings = {} + for (const key of keys) { + const value = await storage.get(key) + settings[key] = value + if (!value && key === "keepAlive") { + settings[key] = "5m" } - return settings + } + return settings } -export { getAllModelSettings, setModelSetting } \ No newline at end of file +export const lastUsedChatModelEnabled = async (): Promise => { + const isLastUsedChatModelEnabled = await storage.get( + "restoreLastChatModel" + ) + return isLastUsedChatModelEnabled ?? false +} + +export const setLastUsedChatModelEnabled = async ( + enabled: boolean +): Promise => { + await storage.set("restoreLastChatModel", enabled) +} + +export const getLastUsedChatModel = async ( + historyId: string +): Promise => { + return await storage.get(`lastUsedChatModel-${historyId}`) +} + +export const setLastUsedChatModel = async ( + historyId: string, + model: string +): Promise => { + await storage.set(`lastUsedChatModel-${historyId}`, model) +} + +export { getAllModelSettings, setModelSetting }