diff --git a/src/assets/locale/en/openai.json b/src/assets/locale/en/openai.json index 18a5f81..c0838cb 100644 --- a/src/assets/locale/en/openai.json +++ b/src/assets/locale/en/openai.json @@ -31,7 +31,7 @@ "deleteConfirm": "Are you sure you want to delete this provider?", "model": { "title": "Model List", - "subheading": "Please select the models you want to use with this provider.", + "subheading": "Please select the chat models you want to use with this provider.", "success": "Successfully added new models." }, "tipLMStudio": "Page Assist will automatically fetch the models you loaded on LM Studio. You don't need to add them manually." @@ -41,7 +41,8 @@ "updateSuccess": "Provider updated successfully.", "delete": "Delete", "edit": "Edit", - "refetch": "Refech Model List", + "newModel": "Add Models to Provider", + "noNewModel": "For LMStudio, we fetch dynamically. No manual addition needed.", "searchModel": "Search Model", "selectAll": "Select All", "save": "Save", @@ -49,6 +50,7 @@ "manageModels": { "columns": { "name": "Model Name", + "model_type": "Model Type", "model_id": "Model ID", "provider": "Provider Name", "actions": "Action" @@ -58,7 +60,31 @@ }, "confirm": { "delete": "Are you sure you want to delete this model?" + }, + "modal": { + "title": "Add Custom Model", + "form": { + "name": { + "label": "Model ID", + "placeholder": "llama3.2", + "required": "Model ID is required." + }, + "provider": { + "label": "Provider", + "placeholder": "Select provider", + "required": "Provider is required." + }, + "type": { + "label": "Model Type" + } + } } }, - "noModelFound": "No model found. Make sure you have added correct provider with base URL and API key." + "noModelFound": "No model found. 
Make sure you have added the correct provider with a base URL and API key.", + "radio": { + "chat": "Chat Model", + "embedding": "Embedding Model", + "chatInfo": "is used for chat completion and conversation generation.", + "embeddingInfo": "is used for RAG and other semantic search related tasks." + } } \ No newline at end of file diff --git a/src/components/Option/Models/AddCustomModelModal.tsx b/src/components/Option/Models/AddCustomModelModal.tsx new file mode 100644 index 0000000..c0315af --- /dev/null +++ b/src/components/Option/Models/AddCustomModelModal.tsx @@ -0,0 +1,129 @@ +import { createModel } from "@/db/models" +import { getAllOpenAIConfig } from "@/db/openai" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { Input, Modal, Form, Select, Radio } from "antd" +import { Loader2 } from "lucide-react" +import { useTranslation } from "react-i18next" + +type Props = { + open: boolean + setOpen: (open: boolean) => void +} + +export const AddCustomModelModal: React.FC = ({ open, setOpen }) => { + const { t } = useTranslation(["openai"]) + const [form] = Form.useForm() + const queryClient = useQueryClient() + + const { data, isPending } = useQuery({ + queryKey: ["fetchProviders"], + queryFn: async () => { + const providers = await getAllOpenAIConfig() + return providers.filter((provider) => provider.provider !== "lmstudio") + } + }) + + const onFinish = async (values: { + model_id: string + model_type: "chat" | "embedding" + provider_id: string + }) => { + await createModel( + values.model_id, + values.model_id, + values.provider_id, + values.model_type + ) + + return true + } + + const { mutate: createModelMutation, isPending: isSaving } = useMutation({ + mutationFn: onFinish, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["fetchCustomModels"] + }) + queryClient.invalidateQueries({ + queryKey: ["fetchModel"] + }) + setOpen(false) + form.resetFields() + } + }) + + return ( + setOpen(false)}> +
+ + + + + + + + + + + {t("radio.chat")} + {t("radio.embedding")} + + + + + + +
+
+ ) +} diff --git a/src/components/Option/Models/AddOllamaModelModal.tsx b/src/components/Option/Models/AddOllamaModelModal.tsx index dd2bd01..7ca2972 100644 --- a/src/components/Option/Models/AddOllamaModelModal.tsx +++ b/src/components/Option/Models/AddOllamaModelModal.tsx @@ -1,5 +1,5 @@ import { useForm } from "@mantine/form" -import { useMutation } from "@tanstack/react-query" +import { useMutation, useQueryClient } from "@tanstack/react-query" import { Input, Modal, notification } from "antd" import { Download } from "lucide-react" import { useTranslation } from "react-i18next" @@ -11,6 +11,7 @@ type Props = { export const AddOllamaModelModal: React.FC = ({ open, setOpen }) => { const { t } = useTranslation(["settings", "common", "openai"]) + const queryClient = useQueryClient() const form = useForm({ initialValues: { diff --git a/src/components/Option/Models/CustomModelsTable.tsx b/src/components/Option/Models/CustomModelsTable.tsx index 4bc57b9..74ed12b 100644 --- a/src/components/Option/Models/CustomModelsTable.tsx +++ b/src/components/Option/Models/CustomModelsTable.tsx @@ -1,7 +1,7 @@ import { getAllCustomModels, deleteModel } from "@/db/models" import { useStorage } from "@plasmohq/storage/hook" import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query" -import { Skeleton, Table, Tooltip } from "antd" +import { Skeleton, Table, Tag, Tooltip } from "antd" import { Trash2 } from "lucide-react" import { useTranslation } from "react-i18next" @@ -10,7 +10,6 @@ export const CustomModelsTable = () => { const { t } = useTranslation(["openai", "common"]) - const queryClient = useQueryClient() const { data, status } = useQuery({ @@ -27,7 +26,6 @@ export const CustomModelsTable = () => { } }) - return (
@@ -37,16 +35,20 @@ export const CustomModelsTable = () => {
( + + {t(`radio.${txt}`)} + + ) + }, { title: t("manageModels.columns.provider"), dataIndex: "provider", diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx index b2cab08..6a63700 100644 --- a/src/components/Option/Models/index.tsx +++ b/src/components/Option/Models/index.tsx @@ -6,11 +6,13 @@ import { useTranslation } from "react-i18next" import { OllamaModelsTable } from "./OllamaModelsTable" import { CustomModelsTable } from "./CustomModelsTable" import { AddOllamaModelModal } from "./AddOllamaModelModal" +import { AddCustomModelModal } from "./AddCustomModelModal" dayjs.extend(relativeTime) export const ModelsBody = () => { const [open, setOpen] = useState(false) + const [openAddModelModal, setOpenAddModelModal] = useState(false) const [segmented, setSegmented] = useState("ollama") const { t } = useTranslation(["settings", "common", "openai"]) @@ -26,6 +28,8 @@ export const ModelsBody = () => { onClick={() => { if (segmented === "ollama") { setOpen(true) + } else { + setOpenAddModelModal(true) } }} className="inline-flex items-center rounded-md border border-transparent bg-black px-2 py-2 text-md font-medium leading-4 text-white shadow-sm hover:bg-gray-800 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50"> @@ -56,6 +60,11 @@ export const ModelsBody = () => { + + ) } diff --git a/src/components/Option/Settings/openai-fetch-model.tsx b/src/components/Option/Settings/openai-fetch-model.tsx index c64e030..5a67f01 100644 --- a/src/components/Option/Settings/openai-fetch-model.tsx +++ b/src/components/Option/Settings/openai-fetch-model.tsx @@ -1,10 +1,12 @@ import { getOpenAIConfigById } from "@/db/openai" import { getAllOpenAIModels } from "@/libs/openai" -import { useMutation, useQuery } from "@tanstack/react-query" +import { useMutation, useQuery, useQueryClient } from 
"@tanstack/react-query" import { useTranslation } from "react-i18next" -import { Checkbox, Input, Spin, message } from "antd" +import { Checkbox, Input, Spin, message, Radio } from "antd" import { useState, useMemo } from "react" import { createManyModels } from "@/db/models" +import { Popover } from "antd" +import { InfoIcon } from "lucide-react" type Props = { openaiId: string @@ -15,6 +17,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { const { t } = useTranslation(["openai"]) const [selectedModels, setSelectedModels] = useState([]) const [searchTerm, setSearchTerm] = useState("") + const [modelType, setModelType] = useState("chat") + const queryClient = useQueryClient() const { data, status } = useQuery({ queryKey: ["openAIConfigs", openaiId], @@ -56,7 +60,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { const payload = models.map((id) => ({ model_id: id, name: filteredModels.find((model) => model.id === id)?.name ?? id, - provider_id: openaiId + provider_id: openaiId, + model_type: modelType })) await createManyModels(payload) @@ -68,6 +73,9 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => { mutationFn: onSave, onSuccess: () => { setOpenModelModal(false) + queryClient.invalidateQueries({ + queryKey: ["fetchModel"] + }) message.success(t("modal.model.success")) } }) @@ -97,6 +105,7 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {

{t("modal.model.subheading")}

+ { ))} + +
+ setModelType(e.target.value)} + value={modelType}> + {t("radio.chat")} + {t("radio.embedding")} + + +

+ + {t("radio.chat")} + {" "} + {t("radio.chatInfo")} +

+

+ + {t("radio.embedding")} + {" "} + {t("radio.embeddingInfo")} +

+
+ }> + + + + - + +