feat: Add model type support
Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models. Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process.
This commit is contained in:
129
src/components/Option/Models/AddCustomModelModal.tsx
Normal file
129
src/components/Option/Models/AddCustomModelModal.tsx
Normal file
@@ -0,0 +1,129 @@
|
||||
import { createModel } from "@/db/models"
|
||||
import { getAllOpenAIConfig } from "@/db/openai"
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
|
||||
import { Input, Modal, Form, Select, Radio } from "antd"
|
||||
import { Loader2 } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
|
||||
type Props = {
|
||||
open: boolean
|
||||
setOpen: (open: boolean) => void
|
||||
}
|
||||
|
||||
export const AddCustomModelModal: React.FC<Props> = ({ open, setOpen }) => {
|
||||
const { t } = useTranslation(["openai"])
|
||||
const [form] = Form.useForm()
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isPending } = useQuery({
|
||||
queryKey: ["fetchProviders"],
|
||||
queryFn: async () => {
|
||||
const providers = await getAllOpenAIConfig()
|
||||
return providers.filter((provider) => provider.provider !== "lmstudio")
|
||||
}
|
||||
})
|
||||
|
||||
const onFinish = async (values: {
|
||||
model_id: string
|
||||
model_type: "chat" | "embedding"
|
||||
provider_id: string
|
||||
}) => {
|
||||
await createModel(
|
||||
values.model_id,
|
||||
values.model_id,
|
||||
values.provider_id,
|
||||
values.model_type
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
const { mutate: createModelMutation, isPending: isSaving } = useMutation({
|
||||
mutationFn: onFinish,
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["fetchCustomModels"]
|
||||
})
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["fetchModel"]
|
||||
})
|
||||
setOpen(false)
|
||||
form.resetFields()
|
||||
}
|
||||
})
|
||||
|
||||
return (
|
||||
<Modal
|
||||
footer={null}
|
||||
open={open}
|
||||
title={t("manageModels.modal.title")}
|
||||
onCancel={() => setOpen(false)}>
|
||||
<Form form={form} onFinish={createModelMutation} layout="vertical">
|
||||
<Form.Item
|
||||
name="model_id"
|
||||
label={t("manageModels.modal.form.name.label")}
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.name.required")
|
||||
}
|
||||
]}>
|
||||
<Input
|
||||
placeholder={t("manageModels.modal.form.name.placeholder")}
|
||||
size="large"
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="provider_id"
|
||||
label={t("manageModels.modal.form.provider.label")}
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.provider.required")
|
||||
}
|
||||
]}>
|
||||
<Select
|
||||
placeholder={t("manageModels.modal.form.provider.placeholder")}
|
||||
size="large"
|
||||
loading={isPending}>
|
||||
{data?.map((provider: any) => (
|
||||
<Select.Option key={provider.id} value={provider.id}>
|
||||
{provider.name}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="model_type"
|
||||
label={t("manageModels.modal.form.type.label")}
|
||||
initialValue="chat"
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.type.required")
|
||||
}
|
||||
]}>
|
||||
<Radio.Group>
|
||||
<Radio value="chat">{t("radio.chat")}</Radio>
|
||||
<Radio value="embedding">{t("radio.embedding")}</Radio>
|
||||
</Radio.Group>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isSaving}
|
||||
className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50 ">
|
||||
{!isSaving ? (
|
||||
t("common:save")
|
||||
) : (
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
)}
|
||||
</button>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useForm } from "@mantine/form"
|
||||
import { useMutation } from "@tanstack/react-query"
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query"
|
||||
import { Input, Modal, notification } from "antd"
|
||||
import { Download } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
@@ -11,6 +11,7 @@ type Props = {
|
||||
|
||||
export const AddOllamaModelModal: React.FC<Props> = ({ open, setOpen }) => {
|
||||
const { t } = useTranslation(["settings", "common", "openai"])
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const form = useForm({
|
||||
initialValues: {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { getAllCustomModels, deleteModel } from "@/db/models"
|
||||
import { useStorage } from "@plasmohq/storage/hook"
|
||||
import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query"
|
||||
import { Skeleton, Table, Tooltip } from "antd"
|
||||
import { Skeleton, Table, Tag, Tooltip } from "antd"
|
||||
import { Trash2 } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
|
||||
@@ -10,7 +10,6 @@ export const CustomModelsTable = () => {
|
||||
|
||||
const { t } = useTranslation(["openai", "common"])
|
||||
|
||||
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, status } = useQuery({
|
||||
@@ -27,7 +26,6 @@ export const CustomModelsTable = () => {
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>
|
||||
@@ -37,16 +35,20 @@ export const CustomModelsTable = () => {
|
||||
<div className="overflow-x-auto">
|
||||
<Table
|
||||
columns={[
|
||||
{
|
||||
title: t("manageModels.columns.name"),
|
||||
dataIndex: "name",
|
||||
key: "name"
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.model_id"),
|
||||
dataIndex: "model_id",
|
||||
key: "model_id"
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.model_type"),
|
||||
dataIndex: "model_type",
|
||||
render: (txt) => (
|
||||
<Tag color={txt === "chat" ? "green" : "blue"}>
|
||||
{t(`radio.${txt}`)}
|
||||
</Tag>
|
||||
)
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.provider"),
|
||||
dataIndex: "provider",
|
||||
|
||||
@@ -6,11 +6,13 @@ import { useTranslation } from "react-i18next"
|
||||
import { OllamaModelsTable } from "./OllamaModelsTable"
|
||||
import { CustomModelsTable } from "./CustomModelsTable"
|
||||
import { AddOllamaModelModal } from "./AddOllamaModelModal"
|
||||
import { AddCustomModelModal } from "./AddCustomModelModal"
|
||||
|
||||
dayjs.extend(relativeTime)
|
||||
|
||||
export const ModelsBody = () => {
|
||||
const [open, setOpen] = useState(false)
|
||||
const [openAddModelModal, setOpenAddModelModal] = useState(false)
|
||||
const [segmented, setSegmented] = useState<string>("ollama")
|
||||
|
||||
const { t } = useTranslation(["settings", "common", "openai"])
|
||||
@@ -26,6 +28,8 @@ export const ModelsBody = () => {
|
||||
onClick={() => {
|
||||
if (segmented === "ollama") {
|
||||
setOpen(true)
|
||||
} else {
|
||||
setOpenAddModelModal(true)
|
||||
}
|
||||
}}
|
||||
className="inline-flex items-center rounded-md border border-transparent bg-black px-2 py-2 text-md font-medium leading-4 text-white shadow-sm hover:bg-gray-800 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50">
|
||||
@@ -56,6 +60,11 @@ export const ModelsBody = () => {
|
||||
</div>
|
||||
|
||||
<AddOllamaModelModal open={open} setOpen={setOpen} />
|
||||
|
||||
<AddCustomModelModal
|
||||
open={openAddModelModal}
|
||||
setOpen={setOpenAddModelModal}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { getOpenAIConfigById } from "@/db/openai"
|
||||
import { getAllOpenAIModels } from "@/libs/openai"
|
||||
import { useMutation, useQuery } from "@tanstack/react-query"
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
|
||||
import { useTranslation } from "react-i18next"
|
||||
import { Checkbox, Input, Spin, message } from "antd"
|
||||
import { Checkbox, Input, Spin, message, Radio } from "antd"
|
||||
import { useState, useMemo } from "react"
|
||||
import { createManyModels } from "@/db/models"
|
||||
import { Popover } from "antd"
|
||||
import { InfoIcon } from "lucide-react"
|
||||
|
||||
type Props = {
|
||||
openaiId: string
|
||||
@@ -15,6 +17,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
|
||||
const { t } = useTranslation(["openai"])
|
||||
const [selectedModels, setSelectedModels] = useState<string[]>([])
|
||||
const [searchTerm, setSearchTerm] = useState("")
|
||||
const [modelType, setModelType] = useState("chat")
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, status } = useQuery({
|
||||
queryKey: ["openAIConfigs", openaiId],
|
||||
@@ -56,7 +60,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
|
||||
const payload = models.map((id) => ({
|
||||
model_id: id,
|
||||
name: filteredModels.find((model) => model.id === id)?.name ?? id,
|
||||
provider_id: openaiId
|
||||
provider_id: openaiId,
|
||||
model_type: modelType
|
||||
}))
|
||||
|
||||
await createManyModels(payload)
|
||||
@@ -68,6 +73,9 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
|
||||
mutationFn: onSave,
|
||||
onSuccess: () => {
|
||||
setOpenModelModal(false)
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["fetchModel"]
|
||||
})
|
||||
message.success(t("modal.model.success"))
|
||||
}
|
||||
})
|
||||
@@ -97,6 +105,7 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
|
||||
<p className="text-sm text-gray-500 dark:text-gray-400">
|
||||
{t("modal.model.subheading")}
|
||||
</p>
|
||||
|
||||
<Input
|
||||
placeholder={t("searchModel")}
|
||||
value={searchTerm}
|
||||
@@ -134,6 +143,35 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center">
|
||||
<Radio.Group
|
||||
onChange={(e) => setModelType(e.target.value)}
|
||||
value={modelType}>
|
||||
<Radio value="chat">{t("radio.chat")}</Radio>
|
||||
<Radio value="embedding">{t("radio.embedding")}</Radio>
|
||||
</Radio.Group>
|
||||
<Popover
|
||||
content={
|
||||
<div>
|
||||
<p>
|
||||
<b className="text-gray-800 dark:text-gray-100">
|
||||
{t("radio.chat")}
|
||||
</b>{" "}
|
||||
{t("radio.chatInfo")}
|
||||
</p>
|
||||
<p>
|
||||
<b className="text-gray-800 dark:text-gray-100">
|
||||
{t("radio.embedding")}
|
||||
</b>{" "}
|
||||
{t("radio.embeddingInfo")}
|
||||
</p>
|
||||
</div>
|
||||
}>
|
||||
<InfoIcon className="ml-2 h-4 w-4 text-gray-500 cursor-pointer" />
|
||||
</Popover>
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={isSaving}
|
||||
|
||||
@@ -14,7 +14,13 @@ import {
|
||||
updateOpenAIConfig
|
||||
} from "@/db/openai"
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
|
||||
import { Pencil, Trash2, RotateCwIcon, DownloadIcon, AlertTriangle } from "lucide-react"
|
||||
import {
|
||||
Pencil,
|
||||
Trash2,
|
||||
RotateCwIcon,
|
||||
DownloadIcon,
|
||||
AlertTriangle
|
||||
} from "lucide-react"
|
||||
import { OpenAIFetchModel } from "./openai-fetch-model"
|
||||
import { OAI_API_PROVIDERS } from "@/utils/oai-api-providers"
|
||||
|
||||
@@ -149,17 +155,23 @@ export const OpenAIApp = () => {
|
||||
</button>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip title={t("refetch")}>
|
||||
<Tooltip
|
||||
title={
|
||||
record.provider !== "lmstudio"
|
||||
? t("newModel")
|
||||
: t("noNewModel")
|
||||
}>
|
||||
<button
|
||||
className="text-gray-700 dark:text-gray-400"
|
||||
className="text-gray-700 dark:text-gray-400 disabled:opacity-50"
|
||||
onClick={() => {
|
||||
setOpenModelModal(true)
|
||||
setOpenaiId(record.id)
|
||||
}}
|
||||
disabled={!record.id}>
|
||||
disabled={!record.id || record.provider === "lmstudio"}>
|
||||
<DownloadIcon className="size-4" />
|
||||
</button>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip title={t("delete")}>
|
||||
<button
|
||||
className="text-red-500 dark:text-red-400"
|
||||
@@ -251,11 +263,11 @@ export const OpenAIApp = () => {
|
||||
placeholder={t("modal.apiKey.placeholder")}
|
||||
/>
|
||||
</Form.Item>
|
||||
{
|
||||
provider === "lmstudio" && <div className="text-xs text-gray-600 dark:text-gray-400 mb-4">
|
||||
{t("modal.tipLMStudio")}
|
||||
</div>
|
||||
}
|
||||
{provider === "lmstudio" && (
|
||||
<div className="text-xs text-gray-600 dark:text-gray-400 mb-4">
|
||||
{t("modal.tipLMStudio")}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
type="submit"
|
||||
className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50">
|
||||
|
||||
Reference in New Issue
Block a user