feat: Add option to restore last used model for previous chats

This commit is contained in:
n4ze3m 2024-07-08 23:56:25 +05:30
parent c80ef8bf1f
commit 031e74f609
5 changed files with 155 additions and 86 deletions

View File

@ -23,6 +23,9 @@
}, },
"hideCurrentChatModelSettings": { "hideCurrentChatModelSettings": {
"label": "Hide the current Chat Model Settings" "label": "Hide the current Chat Model Settings"
},
"restoreLastChatModel": {
"label": "Restore last used model for previous chats"
} }
}, },
"webSearch": { "webSearch": {

View File

@ -24,6 +24,11 @@ export const GeneralSettings = () => {
false false
) )
const [restoreLastChatModel, setRestoreLastChatModel] = useStorage(
"restoreLastChatModel",
false
)
const [hideCurrentChatModelSettings, setHideCurrentChatModelSettings] = const [hideCurrentChatModelSettings, setHideCurrentChatModelSettings] =
useStorage("hideCurrentChatModelSettings", false) useStorage("hideCurrentChatModelSettings", false)
@ -107,6 +112,18 @@ export const GeneralSettings = () => {
onChange={(checked) => setHideCurrentChatModelSettings(checked)} onChange={(checked) => setHideCurrentChatModelSettings(checked)}
/> />
</div> </div>
<div className="flex flex-row justify-between">
<div className="inline-flex items-center gap-2">
<span className="text-gray-700 dark:text-neutral-50">
{t("generalSettings.settings.restoreLastChatModel.label")}
</span>
</div>
<Switch
checked={restoreLastChatModel}
onChange={(checked) => setRestoreLastChatModel(checked)}
/>
</div>
<div className="flex flex-row justify-between"> <div className="flex flex-row justify-between">
<span className="text-gray-700 dark:text-neutral-50 "> <span className="text-gray-700 dark:text-neutral-50 ">
{t("generalSettings.settings.darkMode.label")} {t("generalSettings.settings.darkMode.label")}

View File

@ -11,14 +11,24 @@ import { useMessageOption } from "~/hooks/useMessageOption"
import { PencilIcon, Trash2 } from "lucide-react" import { PencilIcon, Trash2 } from "lucide-react"
import { useNavigate } from "react-router-dom" import { useNavigate } from "react-router-dom"
import { useTranslation } from "react-i18next" import { useTranslation } from "react-i18next"
import {
getLastUsedChatModel,
lastUsedChatModelEnabled
} from "@/services/model-settings"
type Props = { type Props = {
onClose: () => void onClose: () => void
} }
export const Sidebar = ({ onClose }: Props) => { export const Sidebar = ({ onClose }: Props) => {
const { setMessages, setHistory, setHistoryId, historyId, clearChat } = const {
useMessageOption() setMessages,
setHistory,
setHistoryId,
historyId,
clearChat,
setSelectedModel
} = useMessageOption()
const { t } = useTranslation(["option", "common"]) const { t } = useTranslation(["option", "common"])
const client = useQueryClient() const client = useQueryClient()
const navigate = useNavigate() const navigate = useNavigate()
@ -88,6 +98,13 @@ export const Sidebar = ({ onClose }: Props) => {
setHistoryId(chat.id) setHistoryId(chat.id)
setHistory(formatToChatHistory(history)) setHistory(formatToChatHistory(history))
setMessages(formatToMessage(history)) setMessages(formatToMessage(history))
const isLastUsedChatModel = await lastUsedChatModelEnabled()
if (isLastUsedChatModel) {
const currentChatModel = await getLastUsedChatModel(chat.id)
if (currentChatModel) {
setSelectedModel(currentChatModel)
}
}
navigate("/") navigate("/")
onClose() onClose()
}}> }}>

View File

@ -1,4 +1,5 @@
import { saveHistory, saveMessage } from "@/db" import { saveHistory, saveMessage } from "@/db"
import { setLastUsedChatModel } from "@/services/model-settings"
import { ChatHistory } from "@/store/option" import { ChatHistory } from "@/store/option"
export const saveMessageOnError = async ({ export const saveMessageOnError = async ({
@ -23,7 +24,7 @@ export const saveMessageOnError = async ({
historyId: string | null historyId: string | null
selectedModel: string selectedModel: string
setHistoryId: (historyId: string) => void setHistoryId: (historyId: string) => void
isRegenerating: boolean, isRegenerating: boolean
message_source?: "copilot" | "web-ui" message_source?: "copilot" | "web-ui"
}) => { }) => {
if ( if (
@ -66,6 +67,7 @@ export const saveMessageOnError = async ({
[], [],
2 2
) )
await setLastUsedChatModel(historyId, selectedModel)
} else { } else {
const newHistoryId = await saveHistory(userMessage, false, message_source) const newHistoryId = await saveHistory(userMessage, false, message_source)
if (!isRegenerating) { if (!isRegenerating) {
@ -89,6 +91,7 @@ export const saveMessageOnError = async ({
2 2
) )
setHistoryId(newHistoryId.id) setHistoryId(newHistoryId.id)
await setLastUsedChatModel(newHistoryId.id, selectedModel)
} }
return true return true
@ -115,7 +118,7 @@ export const saveMessageOnSuccess = async ({
message: string message: string
image: string image: string
fullText: string fullText: string
source: any[], source: any[]
message_source?: "copilot" | "web-ui" message_source?: "copilot" | "web-ui"
}) => { }) => {
if (historyId) { if (historyId) {
@ -139,6 +142,7 @@ export const saveMessageOnSuccess = async ({
source, source,
2 2
) )
await setLastUsedChatModel(historyId, selectedModel!)
} else { } else {
const newHistoryId = await saveHistory(message, false, message_source) const newHistoryId = await saveHistory(message, false, message_source)
await saveMessage( await saveMessage(
@ -160,5 +164,6 @@ export const saveMessageOnSuccess = async ({
2 2
) )
setHistoryId(newHistoryId.id) setHistoryId(newHistoryId.id)
await setLastUsedChatModel(newHistoryId.id, selectedModel!)
} }
} }

View File

@ -2,100 +2,127 @@ import { Storage } from "@plasmohq/storage"
const storage = new Storage() const storage = new Storage()
type ModelSettings = { type ModelSettings = {
f16KV?: boolean f16KV?: boolean
frequencyPenalty?: number frequencyPenalty?: number
keepAlive?: string keepAlive?: string
logitsAll?: boolean logitsAll?: boolean
mirostat?: number mirostat?: number
mirostatEta?: number mirostatEta?: number
mirostatTau?: number mirostatTau?: number
numBatch?: number numBatch?: number
numCtx?: number numCtx?: number
numGpu?: number numGpu?: number
numGqa?: number numGqa?: number
numKeep?: number numKeep?: number
numPredict?: number numPredict?: number
numThread?: number numThread?: number
penalizeNewline?: boolean penalizeNewline?: boolean
presencePenalty?: number presencePenalty?: number
repeatLastN?: number repeatLastN?: number
repeatPenalty?: number repeatPenalty?: number
ropeFrequencyBase?: number ropeFrequencyBase?: number
ropeFrequencyScale?: number ropeFrequencyScale?: number
temperature?: number temperature?: number
tfsZ?: number tfsZ?: number
topK?: number topK?: number
topP?: number topP?: number
typicalP?: number typicalP?: number
useMLock?: boolean useMLock?: boolean
useMMap?: boolean useMMap?: boolean
vocabOnly?: boolean vocabOnly?: boolean
} }
const keys = [ const keys = [
"f16KV", "f16KV",
"frequencyPenalty", "frequencyPenalty",
"keepAlive", "keepAlive",
"logitsAll", "logitsAll",
"mirostat", "mirostat",
"mirostatEta", "mirostatEta",
"mirostatTau", "mirostatTau",
"numBatch", "numBatch",
"numCtx", "numCtx",
"numGpu", "numGpu",
"numGqa", "numGqa",
"numKeep", "numKeep",
"numPredict", "numPredict",
"numThread", "numThread",
"penalizeNewline", "penalizeNewline",
"presencePenalty", "presencePenalty",
"repeatLastN", "repeatLastN",
"repeatPenalty", "repeatPenalty",
"ropeFrequencyBase", "ropeFrequencyBase",
"ropeFrequencyScale", "ropeFrequencyScale",
"temperature", "temperature",
"tfsZ", "tfsZ",
"topK", "topK",
"topP", "topP",
"typicalP", "typicalP",
"useMLock", "useMLock",
"useMMap", "useMMap",
"vocabOnly" "vocabOnly"
] ]
const getAllModelSettings = async () => { const getAllModelSettings = async () => {
try { try {
const settings: ModelSettings = {} const settings: ModelSettings = {}
for (const key of keys) { for (const key of keys) {
const value = await storage.get(key) const value = await storage.get(key)
settings[key] = value settings[key] = value
if (!value && key === "keepAlive") { if (!value && key === "keepAlive") {
settings[key] = "5m" settings[key] = "5m"
} }
}
return settings
} catch (error) {
console.error(error)
return {}
} }
return settings
} catch (error) {
console.error(error)
return {}
}
} }
const setModelSetting = async (key: string, const setModelSetting = async (
value: string | number | boolean) => { key: string,
await storage.set(key, value) value: string | number | boolean
) => {
await storage.set(key, value)
} }
export const getAllDefaultModelSettings = async (): Promise<ModelSettings> => { export const getAllDefaultModelSettings = async (): Promise<ModelSettings> => {
const settings: ModelSettings = {} const settings: ModelSettings = {}
for (const key of keys) { for (const key of keys) {
const value = await storage.get(key) const value = await storage.get(key)
settings[key] = value settings[key] = value
if (!value && key === "keepAlive") { if (!value && key === "keepAlive") {
settings[key] = "5m" settings[key] = "5m"
}
} }
return settings }
return settings
}
export const lastUsedChatModelEnabled = async (): Promise<boolean> => {
const isLastUsedChatModelEnabled = await storage.get<boolean | undefined>(
"restoreLastChatModel"
)
return isLastUsedChatModelEnabled ?? false
}
export const setLastUsedChatModelEnabled = async (
enabled: boolean
): Promise<void> => {
await storage.set("restoreLastChatModel", enabled)
}
export const getLastUsedChatModel = async (
historyId: string
): Promise<string | undefined> => {
return await storage.get<string | undefined>(`lastUsedChatModel-${historyId}`)
}
export const setLastUsedChatModel = async (
historyId: string,
model: string
): Promise<void> => {
await storage.set(`lastUsedChatModel-${historyId}`, model)
} }
export { getAllModelSettings, setModelSetting } export { getAllModelSettings, setModelSetting }