diff --git a/src/assets/locale/en/common.json b/src/assets/locale/en/common.json
index fe6fae2..70fde08 100644
--- a/src/assets/locale/en/common.json
+++ b/src/assets/locale/en/common.json
@@ -96,5 +96,9 @@
"translate": "Translate",
"custom": "Custom"
},
- "citations": "Citations"
+ "citations": "Citations",
+ "segmented": {
+ "ollama": "Ollama Models",
+ "custom": "Custom Models"
+ }
}
\ No newline at end of file
diff --git a/src/assets/locale/en/openai.json b/src/assets/locale/en/openai.json
index e9babc3..48f0430 100644
--- a/src/assets/locale/en/openai.json
+++ b/src/assets/locale/en/openai.json
@@ -26,13 +26,37 @@
"required": "API Key is required.",
"placeholder": "Enter API Key"
},
- "submit": "Submit",
+ "submit": "Save",
"update": "Update",
- "deleteConfirm": "Are you sure you want to delete this provider?"
+ "deleteConfirm": "Are you sure you want to delete this provider?",
+ "model": {
+ "title": "Model List",
+ "subheading": "Please select the models you want to use with this provider.",
+ "success": "Successfully added new models."
+ }
},
"addSuccess": "Provider added successfully.",
"deleteSuccess": "Provider deleted successfully.",
"updateSuccess": "Provider updated successfully.",
"delete": "Delete",
- "edit": "Edit"
+ "edit": "Edit",
+ "refetch": "Refech Model List",
+ "searchModel": "Search Model",
+ "selectAll": "Select All",
+ "save": "Save",
+ "saving": "Saving...",
+ "manageModels": {
+ "columns": {
+ "name": "Model Name",
+ "model_id": "Model ID",
+ "provider": "Provider Name",
+ "actions": "Action"
+ },
+ "tooltip": {
+ "delete": "Delete"
+ },
+ "confirm": {
+ "delete": "Are you sure you want to delete this model?"
+ }
+ }
}
\ No newline at end of file
diff --git a/src/components/Option/Models/CustomModelsTable.tsx b/src/components/Option/Models/CustomModelsTable.tsx
new file mode 100644
index 0000000..4bc57b9
--- /dev/null
+++ b/src/components/Option/Models/CustomModelsTable.tsx
@@ -0,0 +1,85 @@
+import { getAllCustomModels, deleteModel } from "@/db/models"
+import { useStorage } from "@plasmohq/storage/hook"
+import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query"
+import { Skeleton, Table, Tooltip } from "antd"
+import { Trash2 } from "lucide-react"
+import { useTranslation } from "react-i18next"
+
+export const CustomModelsTable = () => {
+ const [selectedModel, setSelectedModel] = useStorage("selectedModel")
+
+ const { t } = useTranslation(["openai", "common"])
+
+
+ const queryClient = useQueryClient()
+
+ const { data, status } = useQuery({
+ queryKey: ["fetchCustomModels"],
+ queryFn: () => getAllCustomModels()
+ })
+
+ const { mutate: deleteCustomModel } = useMutation({
+ mutationFn: deleteModel,
+ onSuccess: () => {
+ queryClient.invalidateQueries({
+ queryKey: ["fetchCustomModels"]
+ })
+ }
+ })
+
+
+ return (
+
+
+ {status === "pending" &&
}
+
+ {status === "success" && (
+
+
record.provider.name
+ },
+ {
+ title: t("manageModels.columns.actions"),
+ render: (_, record) => (
+
+
+
+ )
+ }
+ ]}
+ bordered
+ dataSource={data}
+ />
+
+ )}
+
+
+ )
+}
diff --git a/src/components/Option/Models/OllamaModelsTable.tsx b/src/components/Option/Models/OllamaModelsTable.tsx
new file mode 100644
index 0000000..72335fa
--- /dev/null
+++ b/src/components/Option/Models/OllamaModelsTable.tsx
@@ -0,0 +1,199 @@
+import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
+import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd"
+import { bytePerSecondFormatter } from "~/libs/byte-formater"
+import { deleteModel, getAllModels } from "~/services/ollama"
+import dayjs from "dayjs"
+import relativeTime from "dayjs/plugin/relativeTime"
+import { useForm } from "@mantine/form"
+import { RotateCcw, Trash2 } from "lucide-react"
+import { useTranslation } from "react-i18next"
+import { useStorage } from "@plasmohq/storage/hook"
+
+dayjs.extend(relativeTime)
+
+export const OllamaModelsTable = () => {
+ const queryClient = useQueryClient()
+ const { t } = useTranslation(["settings", "common"])
+ const [selectedModel, setSelectedModel] = useStorage("selectedModel")
+
+ const form = useForm({
+ initialValues: {
+ model: ""
+ }
+ })
+
+ const { data, status } = useQuery({
+ queryKey: ["fetchAllModels"],
+ queryFn: () => getAllModels({ returnEmpty: true })
+ })
+
+ const { mutate: deleteOllamaModel } = useMutation({
+ mutationFn: deleteModel,
+ onSuccess: () => {
+ queryClient.invalidateQueries({
+ queryKey: ["fetchAllModels"]
+ })
+ notification.success({
+ message: t("manageModels.notification.success"),
+ description: t("manageModels.notification.successDeleteDescription")
+ })
+ },
+ onError: (error) => {
+ notification.error({
+ message: "Error",
+ description: error?.message || t("manageModels.notification.someError")
+ })
+ }
+ })
+
+ const pullModel = async (modelName: string) => {
+ notification.info({
+ message: t("manageModels.notification.pullModel"),
+ description: t("manageModels.notification.pullModelDescription", {
+ modelName
+ })
+ })
+
+ form.reset()
+
+ browser.runtime.sendMessage({
+ type: "pull_model",
+ modelName
+ })
+
+ return true
+ }
+
+ const { mutate: pullOllamaModel } = useMutation({
+ mutationFn: pullModel
+ })
+
+ return (
+
+
+ {status === "pending" &&
}
+
+ {status === "success" && (
+
+
(
+
+ {`${text?.slice(0, 5)}...${text?.slice(-4)}`}
+
+ )
+ },
+ {
+ title: t("manageModels.columns.modifiedAt"),
+ dataIndex: "modified_at",
+ key: "modified_at",
+ render: (text: string) => dayjs(text).fromNow(true)
+ },
+ {
+ title: t("manageModels.columns.size"),
+ dataIndex: "size",
+ key: "size",
+ render: (text: number) => bytePerSecondFormatter(text)
+ },
+ {
+ title: t("manageModels.columns.actions"),
+ render: (_, record) => (
+
+
+
+
+
+
+
+
+ )
+ }
+ ]}
+ expandable={{
+ expandedRowRender: (record) => (
+
+ ),
+ defaultExpandAllRows: false
+ }}
+ bordered
+ dataSource={data}
+ rowKey={(record) => `${record.model}-${record.digest}`}
+ />
+
+ )}
+
+
+ )
+}
diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx
index 1fd12ba..af3c866 100644
--- a/src/components/Option/Models/index.tsx
+++ b/src/components/Option/Models/index.tsx
@@ -1,22 +1,30 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
-import { Skeleton, Table, Tag, Tooltip, notification, Modal, Input } from "antd"
-import { bytePerSecondFormatter } from "~/libs/byte-formater"
-import { deleteModel, getAllModels } from "~/services/ollama"
+import {
+ Skeleton,
+ Table,
+ Tag,
+ Tooltip,
+ notification,
+ Modal,
+ Input,
+ Segmented
+} from "antd"
import dayjs from "dayjs"
import relativeTime from "dayjs/plugin/relativeTime"
import { useState } from "react"
import { useForm } from "@mantine/form"
-import { Download, RotateCcw, Trash2 } from "lucide-react"
+import { Download } from "lucide-react"
import { useTranslation } from "react-i18next"
-import { useStorage } from "@plasmohq/storage/hook"
+import { OllamaModelsTable } from "./OllamaModelsTable"
+import { CustomModelsTable } from "./CustomModelsTable"
dayjs.extend(relativeTime)
export const ModelsBody = () => {
- const queryClient = useQueryClient()
const [open, setOpen] = useState(false)
- const { t } = useTranslation(["settings", "common"])
- const [selectedModel, setSelectedModel] = useStorage("selectedModel")
+ const [segmented, setSegmented] = useState("ollama")
+
+ const { t } = useTranslation(["settings", "common", "openai"])
const form = useForm({
initialValues: {
@@ -24,30 +32,6 @@ export const ModelsBody = () => {
}
})
- const { data, status } = useQuery({
- queryKey: ["fetchAllModels"],
- queryFn: () => getAllModels({ returnEmpty: true })
- })
-
- const { mutate: deleteOllamaModel } = useMutation({
- mutationFn: deleteModel,
- onSuccess: () => {
- queryClient.invalidateQueries({
- queryKey: ["fetchAllModels"]
- })
- notification.success({
- message: t("manageModels.notification.success"),
- description: t("manageModels.notification.successDeleteDescription")
- })
- },
- onError: (error) => {
- notification.error({
- message: "Error",
- description: error?.message || t("manageModels.notification.someError")
- })
- }
- })
-
const pullModel = async (modelName: string) => {
notification.info({
message: t("manageModels.notification.pullModel"),
@@ -86,130 +70,26 @@ export const ModelsBody = () => {
-
-
- {status === "pending" && }
-
- {status === "success" && (
-
-
+ (
-
- {`${text?.slice(0, 5)}...${text?.slice(-4)}`}
-
- )
- },
- {
- title: t("manageModels.columns.modifiedAt"),
- dataIndex: "modified_at",
- key: "modified_at",
- render: (text: string) => dayjs(text).fromNow(true)
- },
- {
- title: t("manageModels.columns.size"),
- dataIndex: "size",
- key: "size",
- render: (text: number) => bytePerSecondFormatter(text)
- },
- {
- title: t("manageModels.columns.actions"),
- render: (_, record) => (
-
-
-
-
-
-
-
-
- )
+ label: t("common:segmented.custom"),
+ value: "custom"
}
]}
- expandable={{
- expandedRowRender: (record) => (
-
- ),
- defaultExpandAllRows: false
+ onChange={(value) => {
+ setSegmented(value)
}}
- bordered
- dataSource={data}
- rowKey={(record) => `${record.model}-${record.digest}`}
/>
- )}
+
+
+ {segmented === "ollama" ? : }
void
+}
+
+export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
+ const { t } = useTranslation(["openai"])
+ const [selectedModels, setSelectedModels] = useState([])
+ const [searchTerm, setSearchTerm] = useState("")
+
+ const { data, status } = useQuery({
+ queryKey: ["openAIConfigs", openaiId],
+ queryFn: async () => {
+ const config = await getOpenAIConfigById(openaiId)
+ const models = await getAllOpenAIModels(config.baseUrl, config.apiKey)
+ return models
+ },
+ enabled: !!openaiId
+ })
+
+ const filteredModels = useMemo(() => {
+ return (
+ data?.filter((model) =>
+ (model.name ?? model.id)
+ .toLowerCase()
+ .includes(searchTerm.toLowerCase())
+ ) || []
+ )
+ }, [data, searchTerm])
+
+ const handleSelectAll = (checked: boolean) => {
+ if (checked) {
+ setSelectedModels(filteredModels.map((model) => model.id))
+ } else {
+ setSelectedModels([])
+ }
+ }
+
+ const handleModelSelect = (modelId: string, checked: boolean) => {
+ if (checked) {
+ setSelectedModels((prev) => [...prev, modelId])
+ } else {
+ setSelectedModels((prev) => prev.filter((id) => id !== modelId))
+ }
+ }
+
+ const onSave = async (models: string[]) => {
+ const payload = models.map((id) => ({
+ model_id: id,
+ name: filteredModels.find((model) => model.id === id)?.name ?? id,
+ provider_id: openaiId
+ }))
+
+ await createManyModels(payload)
+
+ return true
+ }
+
+ const { mutate: saveModels, isPending: isSaving } = useMutation({
+ mutationFn: onSave,
+ onSuccess: () => {
+ setOpenModelModal(false)
+ message.success(t("modal.model.success"))
+ }
+ })
+
+ const handleSave = () => {
+ saveModels(selectedModels)
+ }
+
+ if (status === "pending") {
+ return
+ }
+
+ if (status === "error" || !data || data.length === 0) {
+ return {t("noModelFound")}
+ }
+
+ return (
+
+
+ {t("modal.model.subheading")}
+
+
setSearchTerm(e.target.value)}
+ className="w-full"
+ />
+
+
0 &&
+ selectedModels.length < filteredModels.length
+ }
+ onChange={(e) => handleSelectAll(e.target.checked)}>
+ {t("selectAll")}
+
+
+ {`${selectedModels?.length} / ${data?.length}`}
+
+
+
+
+ {filteredModels.map((model) => (
+ handleModelSelect(model.id, e.target.checked)}>
+ {model?.name || model.id}
+
+ ))}
+
+
+
+
+ )
+}
diff --git a/src/components/Option/Settings/openai.tsx b/src/components/Option/Settings/openai.tsx
index ff3c9b4..5178628 100644
--- a/src/components/Option/Settings/openai.tsx
+++ b/src/components/Option/Settings/openai.tsx
@@ -8,7 +8,8 @@ import {
updateOpenAIConfig
} from "@/db/openai"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
-import { Pencil, Trash2, Plus } from "lucide-react"
+import { Pencil, Trash2, RotateCwIcon } from "lucide-react"
+import { OpenAIFetchModel } from "./openai-fetch-model"
export const OpenAIApp = () => {
const { t } = useTranslation("openai")
@@ -16,6 +17,8 @@ export const OpenAIApp = () => {
const [editingConfig, setEditingConfig] = useState(null)
const queryClient = useQueryClient()
const [form] = Form.useForm()
+ const [openaiId, setOpenaiId] = useState(null)
+ const [openModelModal, setOpenModelModal] = useState(false)
const { data: configs, isLoading } = useQuery({
queryKey: ["openAIConfigs"],
@@ -24,12 +27,14 @@ export const OpenAIApp = () => {
const addMutation = useMutation({
mutationFn: addOpenAICofig,
- onSuccess: () => {
+ onSuccess: (data) => {
queryClient.invalidateQueries({
queryKey: ["openAIConfigs"]
})
setOpen(false)
message.success(t("addSuccess"))
+ setOpenaiId(data)
+ setOpenModelModal(true)
}
})
@@ -129,6 +134,18 @@ export const OpenAIApp = () => {
+
+
+
+
+
+ setOpenModelModal(false)}>
+ {openaiId ? (
+
+ ) : null}
+
)
diff --git a/src/db/models.ts b/src/db/models.ts
new file mode 100644
index 0000000..207fe97
--- /dev/null
+++ b/src/db/models.ts
@@ -0,0 +1,176 @@
+import { getOpenAIConfigById as providerInfo } from "./openai"
+
+type Model = {
+ id: string
+ model_id: string
+ name: string
+ provider_id: string
+ lookup: string
+ db_type: string
+}
+export const generateID = () => {
+ return "model-xxxx-xxxx-xxx-xxxx".replace(/[x]/g, () => {
+ const r = Math.floor(Math.random() * 16)
+ return r.toString(16)
+ })
+}
+
+export const removeModelPrefix = (id: string) => {
+ return id.replace(/^model-/, "")
+}
+export class ModelDb {
+ db: chrome.storage.StorageArea
+
+ constructor() {
+ this.db = chrome.storage.local
+ }
+
+ getAll = async (): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.get(null, (result) => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ const data = Object.keys(result).map((key) => result[key])
+ resolve(data)
+ }
+ })
+ })
+ }
+
+ create = async (model: Model): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.set({ [model.id]: model }, () => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ resolve()
+ }
+ })
+ })
+ }
+
+ getById = async (id: string): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.get(id, (result) => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ resolve(result[id])
+ }
+ })
+ })
+ }
+
+ update = async (model: Model): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.set({ [model.id]: model }, () => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ resolve()
+ }
+ })
+ })
+ }
+
+ delete = async (id: string): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.remove(id, () => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ resolve()
+ }
+ })
+ })
+ }
+
+ deleteAll = async (): Promise => {
+ return new Promise((resolve, reject) => {
+ this.db.clear(() => {
+ if (chrome.runtime.lastError) {
+ reject(chrome.runtime.lastError)
+ } else {
+ resolve()
+ }
+ })
+ })
+ }
+}
+
+export const createManyModels = async (
+ data: { model_id: string; name: string; provider_id: string }[]
+) => {
+ const db = new ModelDb()
+
+ const models = data.map((item) => {
+ return {
+ ...item,
+ lookup: `${item.model_id}_${item.provider_id}`,
+ id: `${item.model_id}_${generateID()}`,
+ db_type: "openai_model"
+ }
+ })
+
+ for (const model of models) {
+ const isExist = await isLookupExist(model.lookup)
+
+ if (isExist) {
+ continue
+ }
+
+ await db.create(model)
+ }
+}
+
+export const createModel = async (
+ model_id: string,
+ name: string,
+ provider_id: string
+) => {
+ const db = new ModelDb()
+ const id = generateID()
+ const model: Model = {
+ id: `${model_id}_${id}`,
+ model_id,
+ name,
+ provider_id,
+ lookup: `${model_id}_${provider_id}`,
+ db_type: "openai_model"
+ }
+ await db.create(model)
+ return model
+}
+
+export const getModelInfo = async (id: string) => {
+ const db = new ModelDb()
+ const model = await db.getById(id)
+ return model
+}
+
+export const getAllCustomModels = async () => {
+ const db = new ModelDb()
+ const models = (await db.getAll()).filter(
+ (model) => model.db_type === "openai_model"
+ )
+ const modelsWithProvider = await Promise.all(
+ models.map(async (model) => {
+ const provider = await providerInfo(model.provider_id)
+ return { ...model, provider }
+ })
+ )
+ return modelsWithProvider
+}
+
+export const deleteModel = async (id: string) => {
+ const db = new ModelDb()
+ await db.delete(id)
+}
+
+export const isLookupExist = async (lookup: string) => {
+ const db = new ModelDb()
+ const models = await db.getAll()
+ const model = models.find((model) => model.lookup === lookup)
+ return model ? true : false
+}
diff --git a/src/db/openai.ts b/src/db/openai.ts
index 501ecfd..45963cf 100644
--- a/src/db/openai.ts
+++ b/src/db/openai.ts
@@ -1,9 +1,12 @@
+import { cleanUrl } from "@/libs/clean-url"
+
type OpenAIModelConfig = {
id: string
name: string
baseUrl: string
apiKey?: string
createdAt: number
+ db_type: string
}
export const generateID = () => {
return "openai-xxxx-xxx-xxxx".replace(/[x]/g, () => {
@@ -95,9 +98,10 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string,
const config: OpenAIModelConfig = {
id,
name,
- baseUrl,
+ baseUrl: cleanUrl(baseUrl),
apiKey,
- createdAt: Date.now()
+ createdAt: Date.now(),
+ db_type: "openai"
}
await openaiDb.create(config)
return id
@@ -107,7 +111,7 @@ export const addOpenAICofig = async ({ name, baseUrl, apiKey }: { name: string,
export const getAllOpenAIConfig = async () => {
const openaiDb = new OpenAIModelDb()
const configs = await openaiDb.getAll()
- return configs
+ return configs.filter(config => config.db_type === "openai")
}
export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: string, name: string, baseUrl: string, apiKey: string }) => {
@@ -115,9 +119,10 @@ export const updateOpenAIConfig = async ({ id, name, baseUrl, apiKey }: { id: st
const config: OpenAIModelConfig = {
id,
name,
- baseUrl,
+ baseUrl: cleanUrl(baseUrl),
apiKey,
- createdAt: Date.now()
+ createdAt: Date.now(),
+ db_type: "openai"
}
await openaiDb.update(config)
@@ -137,10 +142,18 @@ export const updateOpenAIConfigApiKey = async (id: string, { name, baseUrl, apiK
const config: OpenAIModelConfig = {
id,
name,
- baseUrl,
+ baseUrl: cleanUrl(baseUrl),
apiKey,
- createdAt: Date.now()
+ createdAt: Date.now(),
+ db_type: "openai"
}
await openaiDb.update(config)
+}
+
+
+export const getOpenAIConfigById = async (id: string) => {
+ const openaiDb = new OpenAIModelDb()
+ const config = await openaiDb.getById(id)
+ return config
}
\ No newline at end of file
diff --git a/src/libs/openai.ts b/src/libs/openai.ts
new file mode 100644
index 0000000..8b6230e
--- /dev/null
+++ b/src/libs/openai.ts
@@ -0,0 +1,25 @@
+type Model = {
+ id: string
+ name?: string
+}
+
+export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
+ const url = `${baseUrl}/models`
+ const headers = apiKey
+ ? {
+ Authorization: `Bearer ${apiKey}`
+ }
+ : {}
+
+ const res = await fetch(url, {
+ headers
+ })
+
+ if (!res.ok) {
+ return []
+ }
+
+ const data = (await res.json()) as { data: Model[] }
+
+ return data.data
+}