diff --git a/src/assets/locale/en/common.json b/src/assets/locale/en/common.json index fe6fae2..70fde08 100644 --- a/src/assets/locale/en/common.json +++ b/src/assets/locale/en/common.json @@ -96,5 +96,9 @@ "translate": "Translate", "custom": "Custom" }, - "citations": "Citations" + "citations": "Citations", + "segmented": { + "ollama": "Ollama Models", + "custom": "Custom Models" + } } \ No newline at end of file diff --git a/src/assets/locale/en/openai.json b/src/assets/locale/en/openai.json index e9babc3..48f0430 100644 --- a/src/assets/locale/en/openai.json +++ b/src/assets/locale/en/openai.json @@ -26,13 +26,37 @@ "required": "API Key is required.", "placeholder": "Enter API Key" }, - "submit": "Submit", + "submit": "Save", "update": "Update", - "deleteConfirm": "Are you sure you want to delete this provider?" + "deleteConfirm": "Are you sure you want to delete this provider?", + "model": { + "title": "Model List", + "subheading": "Please select the models you want to use with this provider.", + "success": "Successfully added new models." + } }, "addSuccess": "Provider added successfully.", "deleteSuccess": "Provider deleted successfully.", "updateSuccess": "Provider updated successfully.", "delete": "Delete", - "edit": "Edit" + "edit": "Edit", + "refetch": "Refech Model List", + "searchModel": "Search Model", + "selectAll": "Select All", + "save": "Save", + "saving": "Saving...", + "manageModels": { + "columns": { + "name": "Model Name", + "model_id": "Model ID", + "provider": "Provider Name", + "actions": "Action" + }, + "tooltip": { + "delete": "Delete" + }, + "confirm": { + "delete": "Are you sure you want to delete this model?" + } + } } \ No newline at end of file diff --git a/src/components/Option/Models/CustomModelsTable.tsx b/src/components/Option/Models/CustomModelsTable.tsx new file mode 100644 index 0000000..4bc57b9 --- /dev/null +++ b/src/components/Option/Models/CustomModelsTable.tsx @@ -0,0 +1,85 @@ +import { getAllCustomModels, deleteModel } from "@/db/models" +import { useStorage } from "@plasmohq/storage/hook" +import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query" +import { Skeleton, Table, Tooltip } from "antd" +import { Trash2 } from "lucide-react" +import { useTranslation } from "react-i18next" + +export const CustomModelsTable = () => { + const [selectedModel, setSelectedModel] = useStorage("selectedModel") + + const { t } = useTranslation(["openai", "common"]) + + + const queryClient = useQueryClient() + + const { data, status } = useQuery({ + queryKey: ["fetchCustomModels"], + queryFn: () => getAllCustomModels() + }) + + const { mutate: deleteCustomModel } = useMutation({ + mutationFn: deleteModel, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchCustomModels"] + }) + } + }) + + + return ( +
+
+ {status === "pending" && } + + {status === "success" && ( +
+ record.provider.name + }, + { + title: t("manageModels.columns.actions"), + render: (_, record) => ( + + + + ) + } + ]} + bordered + dataSource={data} + /> + + )} + + + ) +} diff --git a/src/components/Option/Models/OllamaModelsTable.tsx b/src/components/Option/Models/OllamaModelsTable.tsx new file mode 100644 index 0000000..72335fa --- /dev/null +++ b/src/components/Option/Models/OllamaModelsTable.tsx @@ -0,0 +1,199 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd" +import { bytePerSecondFormatter } from "~/libs/byte-formater" +import { deleteModel, getAllModels } from "~/services/ollama" +import dayjs from "dayjs" +import relativeTime from "dayjs/plugin/relativeTime" +import { useForm } from "@mantine/form" +import { RotateCcw, Trash2 } from "lucide-react" +import { useTranslation } from "react-i18next" +import { useStorage } from "@plasmohq/storage/hook" + +dayjs.extend(relativeTime) + +export const OllamaModelsTable = () => { + const queryClient = useQueryClient() + const { t } = useTranslation(["settings", "common"]) + const [selectedModel, setSelectedModel] = useStorage("selectedModel") + + const form = useForm({ + initialValues: { + model: "" + } + }) + + const { data, status } = useQuery({ + queryKey: ["fetchAllModels"], + queryFn: () => getAllModels({ returnEmpty: true }) + }) + + const { mutate: deleteOllamaModel } = useMutation({ + mutationFn: deleteModel, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchAllModels"] + }) + notification.success({ + message: t("manageModels.notification.success"), + description: t("manageModels.notification.successDeleteDescription") + }) + }, + onError: (error) => { + notification.error({ + message: "Error", + description: error?.message || t("manageModels.notification.someError") + }) + } + }) + + const pullModel = async (modelName: string) => { + notification.info({ + message: t("manageModels.notification.pullModel"), + description: t("manageModels.notification.pullModelDescription", { + modelName + }) + }) + + form.reset() + + browser.runtime.sendMessage({ + type: "pull_model", + modelName + }) + + return true + } + + const { mutate: pullOllamaModel } = useMutation({ + mutationFn: pullModel + }) + + return ( +
+
+ {status === "pending" && } + + {status === "success" && ( +
+
( + + {`${text?.slice(0, 5)}...${text?.slice(-4)}`} + + ) + }, + { + title: t("manageModels.columns.modifiedAt"), + dataIndex: "modified_at", + key: "modified_at", + render: (text: string) => dayjs(text).fromNow(true) + }, + { + title: t("manageModels.columns.size"), + dataIndex: "size", + key: "size", + render: (text: number) => bytePerSecondFormatter(text) + }, + { + title: t("manageModels.columns.actions"), + render: (_, record) => ( +
+ + + + + + +
+ ) + } + ]} + expandable={{ + expandedRowRender: (record) => ( +
+ ), + defaultExpandAllRows: false + }} + bordered + dataSource={data} + rowKey={(record) => `${record.model}-${record.digest}`} + /> + + )} + + + ) +} diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx index 1fd12ba..af3c866 100644 --- a/src/components/Option/Models/index.tsx +++ b/src/components/Option/Models/index.tsx @@ -1,22 +1,30 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" -import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd" -import { bytePerSecondFormatter } from "~/libs/byte-formater" -import { deleteModel, getAllModels } from "~/services/ollama" +import { + Skeleton, + Table, + Tag, + Tooltip, + notification, + Modal, + Input, + Segmented +} from "antd" import dayjs from "dayjs" import relativeTime from "dayjs/plugin/relativeTime" import { useState } from "react" import { useForm } from "@mantine/form" -import { Download, RotateCcw, Trash2 } from "lucide-react" +import { Download } from "lucide-react" import { useTranslation } from "react-i18next" -import { useStorage } from "@plasmohq/storage/hook" +import { OllamaModelsTable } from "./OllamaModelsTable" +import { CustomModelsTable } from "./CustomModelsTable" dayjs.extend(relativeTime) export const ModelsBody = () => { - const queryClient = useQueryClient() const [open, setOpen] = useState(false) - const { t } = useTranslation(["settings", "common"]) - const [selectedModel, setSelectedModel] = useStorage("selectedModel") + const [segmented, setSegmented] = useState("ollama") + + const { t } = useTranslation(["settings", "common", "openai"]) const form = useForm({ initialValues: { @@ -24,30 +32,6 @@ export const ModelsBody = () => { } }) - const { data, status } = useQuery({ - queryKey: ["fetchAllModels"], - queryFn: () => getAllModels({ returnEmpty: true }) - }) - - const { mutate: deleteOllamaModel } = useMutation({ - mutationFn: deleteModel, - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: ["fetchAllModels"] - }) - notification.success({ - message: t("manageModels.notification.success"), - description: t("manageModels.notification.successDeleteDescription") - }) - }, - onError: (error) => { - notification.error({ - message: "Error", - description: error?.message || t("manageModels.notification.someError") - }) - } - }) - const pullModel = async (modelName: string) => { notification.info({ message: t("manageModels.notification.pullModel"), @@ -86,130 +70,26 @@ export const ModelsBody = () => { - - - {status === "pending" && } - - {status === "success" && ( -
-
+ ( - - {`${text?.slice(0, 5)}...${text?.slice(-4)}`} - - ) - }, - { - title: t("manageModels.columns.modifiedAt"), - dataIndex: "modified_at", - key: "modified_at", - render: (text: string) => dayjs(text).fromNow(true) - }, - { - title: t("manageModels.columns.size"), - dataIndex: "size", - key: "size", - render: (text: number) => bytePerSecondFormatter(text) - }, - { - title: t("manageModels.columns.actions"), - render: (_, record) => ( -
- - - - - - -
- ) + label: t("common:segmented.custom"), + value: "custom" } ]} - expandable={{ - expandedRowRender: (record) => ( -
- ), - defaultExpandAllRows: false + onChange={(value) => { + setSegmented(value) }} - bordered - dataSource={data} - rowKey={(record) => `${record.model}-${record.digest}`} /> - )} + + + {segmented === "ollama" ? : } void +} + +export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { + const { t } = useTranslation(["openai"]) + const [selectedModels, setSelectedModels] = useState([]) + const [searchTerm, setSearchTerm] = useState("") + + const { data, status } = useQuery({ + queryKey: ["openAIConfigs", openaiId], + queryFn: async () => { + const config = await getOpenAIConfigById(openaiId) + const models = await getAllOpenAIModels(config.baseUrl, config.apiKey) + return models + }, + enabled: !!openaiId + }) + + const filteredModels = useMemo(() => { + return ( + data?.filter((model) => + (model.name ?? model.id) + .toLowerCase() + .includes(searchTerm.toLowerCase()) + ) || [] + ) + }, [data, searchTerm]) + + const handleSelectAll = (checked: boolean) => { + if (checked) { + setSelectedModels(filteredModels.map((model) => model.id)) + } else { + setSelectedModels([]) + } + } + + const handleModelSelect = (modelId: string, checked: boolean) => { + if (checked) { + setSelectedModels((prev) => [...prev, modelId]) + } else { + setSelectedModels((prev) => prev.filter((id) => id !== modelId)) + } + } + + const onSave = async (models: string[]) => { + const payload = models.map((id) => ({ + model_id: id, + name: filteredModels.find((model) => model.id === id)?.name ?? id, + provider_id: openaiId + })) + + await createManyModels(payload) + + return true + } + + const { mutate: saveModels, isPending: isSaving } = useMutation({ + mutationFn: onSave, + onSuccess: () => { + setOpenModelModal(false) + message.success(t("modal.model.success")) + } + }) + + const handleSave = () => { + saveModels(selectedModels) + } + + if (status === "pending") { + return + } + + if (status === "error" || !data || data.length === 0) { + return
{t("noModelFound")}
+ } + + return ( +
+

+ {t("modal.model.subheading")} +

+ setSearchTerm(e.target.value)} + className="w-full" + /> +
+ 0 && + selectedModels.length < filteredModels.length + } + onChange={(e) => handleSelectAll(e.target.checked)}> + {t("selectAll")} + +
+ {`${selectedModels?.length} / ${data?.length}`} +
+
+
+
+ {filteredModels.map((model) => ( + handleModelSelect(model.id, e.target.checked)}> + {model?.name || model.id} + + ))} +
+
+ +
+ ) +} diff --git a/src/components/Option/Settings/openai.tsx b/src/components/Option/Settings/openai.tsx index ff3c9b4..5178628 100644 --- a/src/components/Option/Settings/openai.tsx +++ b/src/components/Option/Settings/openai.tsx @@ -8,7 +8,8 @@ import { updateOpenAIConfig } from "@/db/openai" import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" -import { Pencil, Trash2, Plus } from "lucide-react" +import { Pencil, Trash2, RotateCwIcon } from "lucide-react" +import { OpenAIFetchModel } from "./openai-fetch-model" export const OpenAIApp = () => { const { t } = useTranslation("openai") @@ -16,6 +17,8 @@ export const OpenAIApp = () => { const [editingConfig, setEditingConfig] = useState(null) const queryClient = useQueryClient() const [form] = Form.useForm() + const [openaiId, setOpenaiId] = useState(null) + const [openModelModal, setOpenModelModal] = useState(false) const { data: configs, isLoading } = useQuery({ queryKey: ["openAIConfigs"], @@ -24,12 +27,14 @@ export const OpenAIApp = () => { const addMutation = useMutation({ mutationFn: addOpenAICofig, - onSuccess: () => { + onSuccess: (data) => { queryClient.invalidateQueries({ queryKey: ["openAIConfigs"] }) setOpen(false) message.success(t("addSuccess")) + setOpenaiId(data) + setOpenModelModal(true) } }) @@ -129,6 +134,18 @@ export const OpenAIApp = () => { + + + +
+ + setOpenModelModal(false)}> + {openaiId ? ( + + ) : null} + ) diff --git a/src/db/models.ts b/src/db/models.ts new file mode 100644 index 0000000..207fe97 --- /dev/null +++ b/src/db/models.ts @@ -0,0 +1,176 @@ +import { getOpenAIConfigById as providerInfo } from "./openai" + +type Model = { + id: string + model_id: string + name: string + provider_id: string + lookup: string + db_type: string +} +export const generateID = () => { + return "model-xxxx-xxxx-xxx-xxxx".replace(/[x]/g, () => { + const r = Math.floor(Math.random() * 16) + return r.toString(16) + }) +} + +export const removeModelPrefix = (id: string) => { + return id.replace(/^model-/, "") +} +export class ModelDb { + db: chrome.storage.StorageArea + + constructor() { + this.db = chrome.storage.local + } + + getAll = async (): Promise => { + return new Promise((resolve, reject) => { + this.db.get(null, (result) => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + const data = Object.keys(result).map((key) => result[key]) + resolve(data) + } + }) + }) + } + + create = async (model: Model): Promise => { + return new Promise((resolve, reject) => { + this.db.set({ [model.id]: model }, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + getById = async (id: string): Promise => { + return new Promise((resolve, reject) => { + this.db.get(id, (result) => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve(result[id]) + } + }) + }) + } + + update = async (model: Model): Promise => { + return new Promise((resolve, reject) => { + this.db.set({ [model.id]: model }, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + delete = async (id: string): Promise => { + return new Promise((resolve, reject) => { + this.db.remove(id, () => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } + + deleteAll = async (): Promise => { + return new Promise((resolve, reject) => { + this.db.clear(() => { + if (chrome.runtime.lastError) { + reject(chrome.runtime.lastError) + } else { + resolve() + } + }) + }) + } +} + +export const createManyModels = async ( + data: { model_id: string; name: string; provider_id: string }[] +) => { + const db = new ModelDb() + + const models = data.map((item) => { + return { + ...item, + lookup: `${item.model_id}_${item.provider_id}`, + id: `${item.model_id}_${generateID()}`, + db_type: "openai_model" + } + }) + + for (const model of models) { + const isExist = await isLookupExist(model.lookup) + + if (isExist) { + continue + } + + await db.create(model) + } +} + +export const createModel = async ( + model_id: string, + name: string, + provider_id: string +) => { + const db = new ModelDb() + const id = generateID() + const model: Model = { + id: `${model_id}_${id}`, + model_id, + name, + provider_id, + lookup: `${model_id}_${provider_id}`, + db_type: "openai_model" + } + await db.create(model) + return model +} + +export const getModelInfo = async (id: string) => { + const db = new ModelDb() + const model = await db.getById(id) + return model +} + +export const getAllCustomModels = async () => { + const db = new ModelDb() + const models = (await db.getAll()).filter( + (model) => model.db_type === "openai_model" + ) + const modelsWithProvider = await Promise.all( + models.map(async (model) => { + const provider = await providerInfo(model.provider_id) + return { ...model, provider } + }) + ) + return modelsWithProvider +} + +export const deleteModel = async (id: string) => { + const db = new ModelDb() + await db.delete(id) +} + +export const isLookupExist = async (lookup: string) => { + const db = new ModelDb() + const models = await db.getAll() + const model = models.find((model) => model.lookup === lookup) + return model ? true : false +} diff --git a/src/db/openai.ts b/src/db/openai.ts index 501ecfd..45963cf 100644 --- a/src/db/openai.ts +++ b/src/db/openai.ts @@ -1,9 +1,12 @@ +import { cleanUrl } from "@/libs/clean-url" + type OpenAIModelConfig = { id: string name: string baseUrl: string apiKey?: string createdAt: number + db_type: string } export const generateID = () => { return "openai-xxxx-xxx-xxxx".replace(/[x]/g, () => { @@ -95,9 +98,10 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string, const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.create(config) return id @@ -107,7 +111,7 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string, export const getAllOpenAIConfig = async () => { const openaiDb = new OpenAIModelDb() const configs = await openaiDb.getAll() - return configs + return configs.filter(config => config.db_type === "openai") } export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: string, name: string, baseUrl: string, apiKey: string }) => { @@ -115,9 +119,10 @@ export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: st const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.update(config) @@ -137,10 +142,18 @@ export const updateOpenAIConfigApiKey = async (id: string, { name, baseUrl, apiK const config: OpenAIModelConfig = { id, name, - baseUrl, + baseUrl: cleanUrl(baseUrl), apiKey, - createdAt: Date.now() + createdAt: Date.now(), + db_type: "openai" } await openaiDb.update(config) +} + + +export const getOpenAIConfigById = async (id: string) => { + const openaiDb = new OpenAIModelDb() + const config = await openaiDb.getById(id) + return config } \ No newline at end of file diff --git a/src/libs/openai.ts b/src/libs/openai.ts new file mode 100644 index 0000000..8b6230e --- /dev/null +++ b/src/libs/openai.ts @@ -0,0 +1,25 @@ +type Model = { + id: string + name?: string +} + +export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => { + const url = `${baseUrl}/models` + const headers = apiKey + ? { + Authorization: `Bearer ${apiKey}` + } + : {} + + const res = await fetch(url, { + headers + }) + + if (!res.ok) { + return [] + } + + const data = (await res.json()) as { data: Model[] } + + return data.data +}