page-assist/src/components/Option/Models/AddCustomModelModal.tsx
n4ze3m ff4473c35b feat: Add model type support
Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models.

Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process.
2024-10-13 18:22:16 +05:30

130 lines
3.9 KiB
TypeScript

import { createModel } from "@/db/models"
import { getAllOpenAIConfig } from "@/db/openai"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { Input, Modal, Form, Select, Radio } from "antd"
import { Loader2 } from "lucide-react"
import { useTranslation } from "react-i18next"
/**
 * Props accepted by the add-custom-model modal.
 */
type Props = {
  /** Whether the modal is currently shown. */
  open: boolean
  /** Callback used to request a change of the modal's visibility. */
  setOpen: (open: boolean) => void
}
/**
 * Modal form for adding a custom model to one of the configured providers.
 *
 * The form collects a model ID, a provider (loaded through
 * `getAllOpenAIConfig`, excluding entries whose `provider` is "lmstudio")
 * and a model type ("chat" or "embedding", defaulting to "chat").
 * On a successful save the "fetchCustomModels" and "fetchModel" queries are
 * invalidated, the modal is closed and the form is reset.
 *
 * @param open - Whether the modal is visible.
 * @param setOpen - Called with `false` when the modal is cancelled or the
 *   model has been saved.
 */
export const AddCustomModelModal: React.FC<Props> = ({ open, setOpen }) => {
  const { t } = useTranslation(["openai"])
  const [form] = Form.useForm()
  const queryClient = useQueryClient()

  // Providers offered in the dropdown.
  const { data, isPending } = useQuery({
    queryKey: ["fetchProviders"],
    queryFn: async () => {
      const providers = await getAllOpenAIConfig()
      return providers.filter((provider) => provider.provider !== "lmstudio")
    }
  })

  /**
   * Persists the submitted model.
   *
   * The ID is trimmed first so that accidental leading/trailing whitespace
   * (e.g. from a copy-paste) is not stored as part of the model ID.
   * The same value is passed for `createModel`'s first two arguments
   * (presumably ID and display name — confirm against `@/db/models`).
   */
  const onFinish = async (values: {
    model_id: string
    model_type: "chat" | "embedding"
    provider_id: string
  }) => {
    const modelId = values.model_id.trim()
    await createModel(
      modelId,
      modelId,
      values.provider_id,
      values.model_type
    )
    return true
  }

  const { mutate: createModelMutation, isPending: isSaving } = useMutation({
    mutationFn: onFinish,
    onSuccess: () => {
      // Refresh every list that may display the newly created model.
      queryClient.invalidateQueries({
        queryKey: ["fetchCustomModels"]
      })
      queryClient.invalidateQueries({
        queryKey: ["fetchModel"]
      })
      setOpen(false)
      form.resetFields()
    }
  })

  return (
    <Modal
      footer={null}
      open={open}
      title={t("manageModels.modal.title")}
      onCancel={() => setOpen(false)}>
      <Form form={form} onFinish={createModelMutation} layout="vertical">
        <Form.Item
          name="model_id"
          label={t("manageModels.modal.form.name.label")}
          rules={[
            {
              required: true,
              // `required` alone accepts a string of only spaces; reject it
              // so a blank model ID can never be submitted.
              whitespace: true,
              message: t("manageModels.modal.form.name.required")
            }
          ]}>
          <Input
            placeholder={t("manageModels.modal.form.name.placeholder")}
            size="large"
          />
        </Form.Item>
        <Form.Item
          name="provider_id"
          label={t("manageModels.modal.form.provider.label")}
          rules={[
            {
              required: true,
              message: t("manageModels.modal.form.provider.required")
            }
          ]}>
          <Select
            placeholder={t("manageModels.modal.form.provider.placeholder")}
            size="large"
            loading={isPending}>
            {/* NOTE(review): `any` kept because the provider record type is
                declared outside this file; replace with the real type. */}
            {data?.map((provider: any) => (
              <Select.Option key={provider.id} value={provider.id}>
                {provider.name}
              </Select.Option>
            ))}
          </Select>
        </Form.Item>
        <Form.Item
          name="model_type"
          label={t("manageModels.modal.form.type.label")}
          initialValue="chat"
          rules={[
            {
              required: true,
              message: t("manageModels.modal.form.type.required")
            }
          ]}>
          <Radio.Group>
            <Radio value="chat">{t("radio.chat")}</Radio>
            <Radio value="embedding">{t("radio.embedding")}</Radio>
          </Radio.Group>
        </Form.Item>
        <Form.Item>
          {/* Disabled while the mutation is in flight to prevent double submits. */}
          <button
            type="submit"
            disabled={isSaving}
            className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50 ">
            {!isSaving ? (
              t("common:save")
            ) : (
              <Loader2 className="w-5 h-5 animate-spin" />
            )}
          </button>
        </Form.Item>
      </Form>
    </Modal>
  )
}