From ff4473c35b8fcdb41668519d81f2d403f8b43b83 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 13 Oct 2024 18:22:16 +0530 Subject: [PATCH] feat: Add model type support Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models. Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process. --- src/assets/locale/en/openai.json | 32 ++++- .../Option/Models/AddCustomModelModal.tsx | 129 ++++++++++++++++++ .../Option/Models/AddOllamaModelModal.tsx | 3 +- .../Option/Models/CustomModelsTable.tsx | 18 +-- src/components/Option/Models/index.tsx | 9 ++ .../Option/Settings/openai-fetch-model.tsx | 44 +++++- src/components/Option/Settings/openai.tsx | 30 ++-- src/db/models.ts | 49 ++++--- src/services/ollama.ts | 6 +- 9 files changed, 277 insertions(+), 43 deletions(-) create mode 100644 src/components/Option/Models/AddCustomModelModal.tsx diff --git a/src/assets/locale/en/openai.json b/src/assets/locale/en/openai.json index 18a5f81..c0838cb 100644 --- a/src/assets/locale/en/openai.json +++ b/src/assets/locale/en/openai.json @@ -31,7 +31,7 @@ "deleteConfirm": "Are you sure you want to delete this provider?", "model": { "title": "Model List", - "subheading": "Please select the models you want to use with this provider.", + "subheading": "Please select the chat models you want to use with this provider.", "success": "Successfully added new models." }, "tipLMStudio": "Page Assist will automatically fetch the models you loaded on LM Studio. You don't need to add them manually." @@ -41,7 +41,8 @@ "updateSuccess": "Provider updated successfully.", "delete": "Delete", "edit": "Edit", - "refetch": "Refech Model List", + "newModel": "Add Models to Provider", + "noNewModel": "For LMStudio, we fetch dynamically. No manual addition needed.", "searchModel": "Search Model", "selectAll": "Select All", "save": "Save", @@ -49,6 +50,7 @@ "manageModels": { "columns": { "name": "Model Name", + "model_type": "Model Type", "model_id": "Model ID", "provider": "Provider Name", "actions": "Action" @@ -58,7 +60,31 @@ }, "confirm": { "delete": "Are you sure you want to delete this model?" + }, + "modal": { + "title": "Add Custom Model", + "form": { + "name": { + "label": "Model ID", + "placeholder": "llama3.2", + "required": "Model ID is required." + }, + "provider": { + "label": "Provider", + "placeholder": "Select provider", + "required": "Provider is required." + }, + "type": { + "label": "Model Type" + } + } } }, - "noModelFound": "No model found. Make sure you have added correct provider with base URL and API key." + "noModelFound": "No model found. Make sure you have added correct provider with base URL and API key.", + "radio": { + "chat": "Chat Model", + "embedding": "Embedding Model", + "chatInfo": "is used for chat completion and conversation generation", + "embeddingInfo": "is used for RAG and other semantic search related tasks." + } } \ No newline at end of file diff --git a/src/components/Option/Models/AddCustomModelModal.tsx b/src/components/Option/Models/AddCustomModelModal.tsx new file mode 100644 index 0000000..c0315af --- /dev/null +++ b/src/components/Option/Models/AddCustomModelModal.tsx @@ -0,0 +1,129 @@ +import { createModel } from "@/db/models" +import { getAllOpenAIConfig } from "@/db/openai" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { Input, Modal, Form, Select, Radio } from "antd" +import { Loader2 } from "lucide-react" +import { useTranslation } from "react-i18next" + +type Props = { + open: boolean + setOpen: (open: boolean) => void +} + +export const AddCustomModelModal: React.FC = ({ open, setOpen }) => { + const { t } = useTranslation(["openai"]) + const [form] = Form.useForm() + const queryClient = useQueryClient() + + const { data, isPending } = useQuery({ + queryKey: ["fetchProviders"], + queryFn: async () => { + const providers = await getAllOpenAIConfig() + return providers.filter((provider) => provider.provider !== "lmstudio") + } + }) + + const onFinish = async (values: { + model_id: string + model_type: "chat" | "embedding" + provider_id: string + }) => { + await createModel( + values.model_id, + values.model_id, + values.provider_id, + values.model_type + ) + + return true + } + + const { mutate: createModelMutation, isPending: isSaving } = useMutation({ + mutationFn: onFinish, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchCustomModels"] + }) + queryClient.invalidateQueries({ + queryKey: ["fetchModel"] + }) + setOpen(false) + form.resetFields() + } + }) + + return ( + setOpen(false)}> +
+ + + + + + + + + + + {t("radio.chat")} + {t("radio.embedding")} + + + + + + +
+
+ ) +} diff --git a/src/components/Option/Models/AddOllamaModelModal.tsx b/src/components/Option/Models/AddOllamaModelModal.tsx index dd2bd01..7ca2972 100644 --- a/src/components/Option/Models/AddOllamaModelModal.tsx +++ b/src/components/Option/Models/AddOllamaModelModal.tsx @@ -1,5 +1,5 @@ import { useForm } from "@mantine/form" -import { useMutation } from "@tanstack/react-query" +import { useMutation, useQueryClient } from "@tanstack/react-query" import { Input, Modal, notification } from "antd" import { Download } from "lucide-react" import { useTranslation } from "react-i18next" @@ -11,6 +11,7 @@ type Props = { export const AddOllamaModelModal: React.FC = ({ open, setOpen }) => { const { t } = useTranslation(["settings", "common", "openai"]) + const queryClient = useQueryClient() const form = useForm({ initialValues: { diff --git a/src/components/Option/Models/CustomModelsTable.tsx b/src/components/Option/Models/CustomModelsTable.tsx index 4bc57b9..74ed12b 100644 --- a/src/components/Option/Models/CustomModelsTable.tsx +++ b/src/components/Option/Models/CustomModelsTable.tsx @@ -1,7 +1,7 @@ import { getAllCustomModels, deleteModel } from "@/db/models" import { useStorage } from "@plasmohq/storage/hook" import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query" -import { Skeleton, Table, Tooltip } from "antd" +import { Skeleton, Table, Tag, Tooltip } from "antd" import { Trash2 } from "lucide-react" import { useTranslation } from "react-i18next" @@ -10,7 +10,6 @@ export const CustomModelsTable = () => { const { t } = useTranslation(["openai", "common"]) - const queryClient = useQueryClient() const { data, status } = useQuery({ @@ -27,7 +26,6 @@ export const CustomModelsTable = () => { } }) - return (
@@ -37,16 +35,20 @@ export const CustomModelsTable = () => {
( + + {t(`radio.${txt}`)} + + ) + }, { title: t("manageModels.columns.provider"), dataIndex: "provider", diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx index b2cab08..6a63700 100644 --- a/src/components/Option/Models/index.tsx +++ b/src/components/Option/Models/index.tsx @@ -6,11 +6,13 @@ import { useTranslation } from "react-i18next" import { OllamaModelsTable } from "./OllamaModelsTable" import { CustomModelsTable } from "./CustomModelsTable" import { AddOllamaModelModal } from "./AddOllamaModelModal" +import { AddCustomModelModal } from "./AddCustomModelModal" dayjs.extend(relativeTime) export const ModelsBody = () => { const [open, setOpen] = useState(false) + const [openAddModelModal, setOpenAddModelModal] = useState(false) const [segmented, setSegmented] = useState("ollama") const { t } = useTranslation(["settings", "common", "openai"]) @@ -26,6 +28,8 @@ export const ModelsBody = () => { onClick={() => { if (segmented === "ollama") { setOpen(true) + } else { + setOpenAddModelModal(true) } }} className="inline-flex items-center rounded-md border border-transparent bg-black px-2 py-2 text-md font-medium leading-4 text-white shadow-sm hover:bg-gray-800 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50"> @@ -56,6 +60,11 @@ export const ModelsBody = () => { + + ) } diff --git a/src/components/Option/Settings/openai-fetch-model.tsx b/src/components/Option/Settings/openai-fetch-model.tsx index c64e030..5a67f01 100644 --- a/src/components/Option/Settings/openai-fetch-model.tsx +++ b/src/components/Option/Settings/openai-fetch-model.tsx @@ -1,10 +1,12 @@ import { getOpenAIConfigById } from "@/db/openai" import { getAllOpenAIModels } from "@/libs/openai" -import { useMutation, useQuery } from "@tanstack/react-query" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" import { useTranslation } from "react-i18next" -import { Checkbox, Input, Spin, message } from "antd" +import { Checkbox, Input, Spin, message, Radio } from "antd" import { useState, useMemo } from "react" import { createManyModels } from "@/db/models" +import { Popover } from "antd" +import { InfoIcon } from "lucide-react" type Props = { openaiId: string @@ -15,6 +17,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { const { t } = useTranslation(["openai"]) const [selectedModels, setSelectedModels] = useState([]) const [searchTerm, setSearchTerm] = useState("") + const [modelType, setModelType] = useState("chat") + const queryClient = useQueryClient() const { data, status } = useQuery({ queryKey: ["openAIConfigs", openaiId], @@ -56,7 +60,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { const payload = models.map((id) => ({ model_id: id, name: filteredModels.find((model) => model.id === id)?.name ?? id, - provider_id: openaiId + provider_id: openaiId, + model_type: modelType })) await createManyModels(payload) @@ -68,6 +73,9 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { mutationFn: onSave, onSuccess: () => { setOpenModelModal(false) + queryClient.invalidateQueries({ + queryKey: ["fetchModel"] + }) message.success(t("modal.model.success")) } }) @@ -97,6 +105,7 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {

{t("modal.model.subheading")}

+ { ))} + +
+ setModelType(e.target.value)} + value={modelType}> + {t("radio.chat")} + {t("radio.embedding")} + + +

+ + {t("radio.chat")} + {" "} + {t("radio.chatInfo")} +

+

+ + {t("radio.embedding")} + {" "} + {t("radio.embeddingInfo")} +

+
+ }> + + + + - + +