From 2a2610afb8643bfefc1751ec6fa931cd381a7e27 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 29 Sep 2024 19:12:19 +0530 Subject: [PATCH] feat: add model management UI This commit introduces a new UI for managing models within the OpenAI integration. This UI allows users to view, add, and delete OpenAI models associated with their OpenAI providers. It includes functionality to fetch and refresh model lists, as well as to search for specific models. These changes enhance the user experience by offering greater control over their OpenAI model interactions. This commit also includes improvements to the existing OpenAI configuration UI, enabling users to seamlessly manage multiple OpenAI providers and associated models. --- src/assets/locale/en/common.json | 6 +- src/assets/locale/en/openai.json | 30 ++- .../Option/Models/CustomModelsTable.tsx | 85 ++++++++ .../Option/Models/OllamaModelsTable.tsx | 199 ++++++++++++++++++ src/components/Option/Models/index.tsx | 176 +++------------- .../Option/Settings/openai-fetch-model.tsx | 132 ++++++++++++ src/components/Option/Settings/openai.tsx | 34 ++- src/db/models.ts | 176 ++++++++++++++++ src/db/openai.ts | 27 ++- src/libs/openai.ts | 25 +++ 10 files changed, 729 insertions(+), 161 deletions(-) create mode 100644 src/components/Option/Models/CustomModelsTable.tsx create mode 100644 src/components/Option/Models/OllamaModelsTable.tsx create mode 100644 src/components/Option/Settings/openai-fetch-model.tsx create mode 100644 src/db/models.ts create mode 100644 src/libs/openai.ts diff --git a/src/assets/locale/en/common.json b/src/assets/locale/en/common.json index fe6fae2..70fde08 100644 --- a/src/assets/locale/en/common.json +++ b/src/assets/locale/en/common.json @@ -96,5 +96,9 @@ "translate": "Translate", "custom": "Custom" }, - "citations": "Citations" + "citations": "Citations", + "segmented": { + "ollama": "Ollama Models", + "custom": "Custom Models" + } } \ No newline at end of file diff --git a/src/assets/locale/en/openai.json b/src/assets/locale/en/openai.json index e9babc3..48f0430 100644 --- a/src/assets/locale/en/openai.json +++ b/src/assets/locale/en/openai.json @@ -26,13 +26,37 @@ "required": "API Key is required.", "placeholder": "Enter API Key" }, - "submit": "Submit", + "submit": "Save", "update": "Update", - "deleteConfirm": "Are you sure you want to delete this provider?" + "deleteConfirm": "Are you sure you want to delete this provider?", + "model": { + "title": "Model List", + "subheading": "Please select the models you want to use with this provider.", + "success": "Successfully added new models." + } }, "addSuccess": "Provider added successfully.", "deleteSuccess": "Provider deleted successfully.", "updateSuccess": "Provider updated successfully.", "delete": "Delete", - "edit": "Edit" + "edit": "Edit", + "refetch": "Refech Model List", + "searchModel": "Search Model", + "selectAll": "Select All", + "save": "Save", + "saving": "Saving...", + "manageModels": { + "columns": { + "name": "Model Name", + "model_id": "Model ID", + "provider": "Provider Name", + "actions": "Action" + }, + "tooltip": { + "delete": "Delete" + }, + "confirm": { + "delete": "Are you sure you want to delete this model?" + } + } } \ No newline at end of file diff --git a/src/components/Option/Models/CustomModelsTable.tsx b/src/components/Option/Models/CustomModelsTable.tsx new file mode 100644 index 0000000..4bc57b9 --- /dev/null +++ b/src/components/Option/Models/CustomModelsTable.tsx @@ -0,0 +1,85 @@ +import { getAllCustomModels, deleteModel } from "@/db/models" +import { useStorage } from "@plasmohq/storage/hook" +import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query" +import { Skeleton, Table, Tooltip } from "antd" +import { Trash2 } from "lucide-react" +import { useTranslation } from "react-i18next" + +export const CustomModelsTable = () => { + const [selectedModel, setSelectedModel] = useStorage("selectedModel") + + const { t } = useTranslation(["openai", "common"]) + + + const queryClient = useQueryClient() + + const { data, status } = useQuery({ + queryKey: ["fetchCustomModels"], + queryFn: () => getAllCustomModels() + }) + + const { mutate: deleteCustomModel } = useMutation({ + mutationFn: deleteModel, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchCustomModels"] + }) + } + }) + + + return ( +
+
+ {status === "pending" && } + + {status === "success" && ( +
+ record.provider.name + }, + { + title: t("manageModels.columns.actions"), + render: (_, record) => ( + + + + ) + } + ]} + bordered + dataSource={data} + /> + + )} + + + ) +} diff --git a/src/components/Option/Models/OllamaModelsTable.tsx b/src/components/Option/Models/OllamaModelsTable.tsx new file mode 100644 index 0000000..72335fa --- /dev/null +++ b/src/components/Option/Models/OllamaModelsTable.tsx @@ -0,0 +1,199 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd" +import { bytePerSecondFormatter } from "~/libs/byte-formater" +import { deleteModel, getAllModels } from "~/services/ollama" +import dayjs from "dayjs" +import relativeTime from "dayjs/plugin/relativeTime" +import { useForm } from "@mantine/form" +import { RotateCcw, Trash2 } from "lucide-react" +import { useTranslation } from "react-i18next" +import { useStorage } from "@plasmohq/storage/hook" + +dayjs.extend(relativeTime) + +export const OllamaModelsTable = () => { + const queryClient = useQueryClient() + const { t } = useTranslation(["settings", "common"]) + const [selectedModel, setSelectedModel] = useStorage("selectedModel") + + const form = useForm({ + initialValues: { + model: "" + } + }) + + const { data, status } = useQuery({ + queryKey: ["fetchAllModels"], + queryFn: () => getAllModels({ returnEmpty: true }) + }) + + const { mutate: deleteOllamaModel } = useMutation({ + mutationFn: deleteModel, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchAllModels"] + }) + notification.success({ + message: t("manageModels.notification.success"), + description: t("manageModels.notification.successDeleteDescription") + }) + }, + onError: (error) => { + notification.error({ + message: "Error", + description: error?.message || t("manageModels.notification.someError") + }) + } + }) + + const pullModel = async (modelName: string) => { + notification.info({ + message: t("manageModels.notification.pullModel"), + description: t("manageModels.notification.pullModelDescription", { + modelName + }) + }) + + form.reset() + + browser.runtime.sendMessage({ + type: "pull_model", + modelName + }) + + return true + } + + const { mutate: pullOllamaModel } = useMutation({ + mutationFn: pullModel + }) + + return ( +
+
+ {status === "pending" && } + + {status === "success" && ( +
+
( + + {`${text?.slice(0, 5)}...${text?.slice(-4)}`} + + ) + }, + { + title: t("manageModels.columns.modifiedAt"), + dataIndex: "modified_at", + key: "modified_at", + render: (text: string) => dayjs(text).fromNow(true) + }, + { + title: t("manageModels.columns.size"), + dataIndex: "size", + key: "size", + render: (text: number) => bytePerSecondFormatter(text) + }, + { + title: t("manageModels.columns.actions"), + render: (_, record) => ( +
+ + + + + + +
+ ) + } + ]} + expandable={{ + expandedRowRender: (record) => ( +
+ ), + defaultExpandAllRows: false + }} + bordered + dataSource={data} + rowKey={(record) => `${record.model}-${record.digest}`} + /> + + )} + + + ) +} diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx index 1fd12ba..af3c866 100644 --- a/src/components/Option/Models/index.tsx +++ b/src/components/Option/Models/index.tsx @@ -1,22 +1,30 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" -import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd" -import { bytePerSecondFormatter } from "~/libs/byte-formater" -import { deleteModel, getAllModels } from "~/services/ollama" +import { + Skeleton, + Table, + Tag, + Tooltip, + notification, + Modal, + Input, + Segmented +} from "antd" import dayjs from "dayjs" import relativeTime from "dayjs/plugin/relativeTime" import { useState } from "react" import { useForm } from "@mantine/form" -import { Download, RotateCcw, Trash2 } from "lucide-react" +import { Download } from "lucide-react" import { useTranslation } from "react-i18next" -import { useStorage } from "@plasmohq/storage/hook" +import { OllamaModelsTable } from "./OllamaModelsTable" +import { CustomModelsTable } from "./CustomModelsTable" dayjs.extend(relativeTime) export const ModelsBody = () => { - const queryClient = useQueryClient() const [open, setOpen] = useState(false) - const { t } = useTranslation(["settings", "common"]) - const [selectedModel, setSelectedModel] = useStorage("selectedModel") + const [segmented, setSegmented] = useState("ollama") + + const { t } = useTranslation(["settings", "common", "openai"]) const form = useForm({ initialValues: { @@ -24,30 +32,6 @@ export const ModelsBody = () => { } }) - const { data, status } = useQuery({ - queryKey: ["fetchAllModels"], - queryFn: () => getAllModels({ returnEmpty: true }) - }) - - const { mutate: deleteOllamaModel } = useMutation({ - mutationFn: deleteModel, - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: ["fetchAllModels"] - }) - notification.success({ - message: t("manageModels.notification.success"), - description: t("manageModels.notification.successDeleteDescription") - }) - }, - onError: (error) => { - notification.error({ - message: "Error", - description: error?.message || t("manageModels.notification.someError") - }) - } - }) - const pullModel = async (modelName: string) => { notification.info({ message: t("manageModels.notification.pullModel"), @@ -86,130 +70,26 @@ export const ModelsBody = () => { - - - {status === "pending" && } - - {status === "success" && ( -
-
+ ( - - {`${text?.slice(0, 5)}...${text?.slice(-4)}`} - - ) - }, - { - title: t("manageModels.columns.modifiedAt"), - dataIndex: "modified_at", - key: "modified_at", - render: (text: string) => dayjs(text).fromNow(true) - }, - { - title: t("manageModels.columns.size"), - dataIndex: "size", - key: "size", - render: (text: number) => bytePerSecondFormatter(text) - }, - { - title: t("manageModels.columns.actions"), - render: (_, record) => ( -
- - - - - - -
- ) + label: t("common:segmented.custom"), + value: "custom" } ]} - expandable={{ - expandedRowRender: (record) => ( -
- ), - defaultExpandAllRows: false + onChange={(value) => { + setSegmented(value) }} - bordered - dataSource={data} - rowKey={(record) => `${record.model}-${record.digest}`} /> - )} + + + {segmented === "ollama" ? : } void +} + +export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { + const { t } = useTranslation(["openai"]) + const [selectedModels, setSelectedModels] = useState([]) + const [searchTerm, setSearchTerm] = useState("") + + const { data, status } = useQuery({ + queryKey: ["openAIConfigs", openaiId], + queryFn: async () => { + const config = await getOpenAIConfigById(openaiId) + const models = await getAllOpenAIModels(config.baseUrl, config.apiKey) + return models + }, + enabled: !!openaiId + }) + + const filteredModels = useMemo(() => { + return ( + data?.filter((model) => + (model.name ?? model.id) + .toLowerCase() + .includes(searchTerm.toLowerCase()) + ) || [] + ) + }, [data, searchTerm]) + + const handleSelectAll = (checked: boolean) => { + if (checked) { + setSelectedModels(filteredModels.map((model) => model.id)) + } else { + setSelectedModels([]) + } + } + + const handleModelSelect = (modelId: string, checked: boolean) => { + if (checked) { + setSelectedModels((prev) => [...prev, modelId]) + } else { + setSelectedModels((prev) => prev.filter((id) => id !== modelId)) + } + } + + const onSave = async (models: string[]) => { + const payload = models.map((id) => ({ + model_id: id, + name: filteredModels.find((model) => model.id === id)?.name ?? id, + provider_id: openaiId + })) + + await createManyModels(payload) + + return true + } + + const { mutate: saveModels, isPending: isSaving } = useMutation({ + mutationFn: onSave, + onSuccess: () => { + setOpenModelModal(false) + message.success(t("modal.model.success")) + } + }) + + const handleSave = () => { + saveModels(selectedModels) + } + + if (status === "pending") { + return + } + + if (status === "error" || !data || data.length === 0) { + return
{t("noModelFound")}
+ } + + return ( +
+

+ {t("modal.model.subheading")} +

+ setSearchTerm(e.target.value)} + className="w-full" + /> +
+ 0 && + selectedModels.length < filteredModels.length + } + onChange={(e) => handleSelectAll(e.target.checked)}> + {t("selectAll")} + +
+ {`${selectedModels?.length} / ${data?.length}`} +
+
+
+
+ {filteredModels.map((model) => ( + handleModelSelect(model.id, e.target.checked)}> + {model?.name || model.id} + + ))} +
+
+ +
+ ) +} diff --git a/src/components/Option/Settings/openai.tsx b/src/components/Option/Settings/openai.tsx index ff3c9b4..5178628 100644 --- a/src/components/Option/Settings/openai.tsx +++ b/src/components/Option/Settings/openai.tsx @@ -8,7 +8,8 @@ import { updateOpenAIConfig } from "@/db/openai" import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" -import { Pencil, Trash2, Plus } from "lucide-react" +import { Pencil, Trash2, RotateCwIcon } from "lucide-react" +import { OpenAIFetchModel } from "./openai-fetch-model" export const OpenAIApp = () => { const { t } = useTranslation("openai") @@ -16,6 +17,8 @@ export const OpenAIApp = () => { const [editingConfig, setEditingConfig] = useState(null) const queryClient = useQueryClient() const [form] = Form.useForm() + const [openaiId, setOpenaiId] = useState(null) + const [openModelModal, setOpenModelModal] = useState(false) const { data: configs, isLoading } = useQuery({ queryKey: ["openAIConfigs"], @@ -24,12 +27,14 @@ export const OpenAIApp = () => { const addMutation = useMutation({ mutationFn: addOpenAICofig, - onSuccess: () => { + onSuccess: (data) => { queryClient.invalidateQueries({ queryKey: ["openAIConfigs"] }) setOpen(false) message.success(t("addSuccess")) + setOpenaiId(data) + setOpenModelModal(true) } }) @@ -129,6 +134,18 @@ export const OpenAIApp = () => { + + + +
+ + setOpenModelModal(false)}> + {openaiId ? ( + + ) : null} + ) diff --git a/src/db/models.ts b/src/db/models.ts new file mode 100644 index 0000000..207fe97 --- /dev/null +++ b/src/db/models.ts @@ -0,0 +1,176 @@ +import { getOpenAIConfigById as providerInfo } from "./openai" + +type Model = { + id: string + model_id: string + name: string + provider_id: string + lookup: string + db_type: string +} +export const generateID = () => { + return "model-xxxx-xxxx-xxx-xxxx".replace(/[x]/g, () => { + const r = Math.floor(Math.random() * 16) + return r.toString(16) + }) +} + +export const removeModelPrefix = (id: string) => { + return id.replace(/^model-/, "") +} +export class ModelDb { + db: chrome.storage.StorageArea + + constructor() { + this.db = chrome.storage.local + } + + getAll = async (): Promise => { + return new Promise((resolve, reject) => { + this.db.get(null, (result) => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + const data = Object.keys(result).map((key) => result[key]) + resolve(data) + } + }) + }) + } + + create = async (model: Model): Promise => { + return new Promise((resolve, reject) => { + this.db.set({ [model.id]: model }, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + getById = async (id: string): Promise => { + return new Promise((resolve, reject) => { + this.db.get(id, (result) => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve(result[id]) + } + }) + }) + } + + update = async (model: Model): Promise => { + return new Promise((resolve, reject) => { + this.db.set({ [model.id]: model }, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + delete = async (id: string): Promise => { + return new Promise((resolve, reject) => { + this.db.remove(id, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + deleteAll = async (): Promise => { + return new Promise((resolve, reject) => { + this.db.clear(() => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } +} + +export const createManyModels = async ( + data: { model_id: string; name: string; provider_id: string }[] +) => { + const db = new ModelDb() + + const models = data.map((item) => { + return { + ...item, + lookup: `${item.model_id}_${item.provider_id}`, + id: `${item.model_id}_${generateID()}`, + db_type: "openai_model" + } + }) + + for (const model of models) { + const isExist = await isLookupExist(model.lookup) + + if (isExist) { + continue + } + + await db.create(model) + } +} + +export const createModel = async ( + model_id: string, + name: string, + provider_id: string +) => { + const db = new ModelDb() + const id = generateID() + const model: Model = { + id: `${model_id}_${id}`, + model_id, + name, + provider_id, + lookup: `${model_id}_${provider_id}`, + db_type: "openai_model" + } + await db.create(model) + return model +} + +export const getModelInfo = async (id: string) => { + const db = new ModelDb() + const model = await db.getById(id) + return model +} + +export const getAllCustomModels = async () => { + const db = new ModelDb() + const models = (await db.getAll()).filter( + (model) => model.db_type === "openai_model" + ) + const modelsWithProvider = await Promise.all( + models.map(async (model) => { + const provider = await providerInfo(model.provider_id) + return { ...model, provider } + }) + ) + return modelsWithProvider +} + +export const deleteModel = async (id: string) => { + const db = new ModelDb() + await db.delete(id) +} + +export const isLookupExist = async (lookup: string) => { + const db = new ModelDb() + const models = await db.getAll() + const model = models.find((model) => model.lookup === lookup) + return model ? true : false +} diff --git a/src/db/openai.ts b/src/db/openai.ts index 501ecfd..45963cf 100644 --- a/src/db/openai.ts +++ b/src/db/openai.ts @@ -1,9 +1,12 @@ +import { cleanUrl } from "@/libs/clean-url" + type OpenAIModelConfig = { id: string name: string baseUrl: string apiKey?: string createdAt: number + db_type: string } export const generateID = () => { return "openai-xxxx-xxx-xxxx".replace(/[x]/g, () => { @@ -95,9 +98,10 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string, const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.create(config) return id @@ -107,7 +111,7 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string, export const getAllOpenAIConfig = async () => { const openaiDb = new OpenAIModelDb() const configs = await openaiDb.getAll() - return configs + return configs.filter(config => config.db_type === "openai") } export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: string, name: string, baseUrl: string, apiKey: string }) => { @@ -115,9 +119,10 @@ export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: st const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.update(config) @@ -137,10 +142,18 @@ export const updateOpenAIConfigApiKey = async (id: string, { name, baseUrl, apiK const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.update(config) +} + + +export const getOpenAIConfigById = async (id: string) => { + const openaiDb = new OpenAIModelDb() + const config = await openaiDb.getById(id) + return config } \ No newline at end of file diff --git a/src/libs/openai.ts b/src/libs/openai.ts new file mode 100644 index 0000000..8b6230e --- /dev/null +++ b/src/libs/openai.ts @@ -0,0 +1,25 @@ +type Model = { + id: string + name?: string +} + +export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => { + const url = `${baseUrl}/models` + const headers = apiKey + ? { + Authorization: `Bearer ${apiKey}` + } + : {} + + const res = await fetch(url, { + headers + }) + + if (!res.ok) { + return [] + } + + const data = (await res.json()) as { data: Model[] } + + return data.data +}