diff --git a/bun.lockb b/bun.lockb index 546445c..deaa93f 100644 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 43057df..c088aa4 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ "@headlessui/react": "^1.7.18", "@heroicons/react": "^2.1.1", "@langchain/community": "^0.0.41", + "@langchain/openai": "0.0.24", "@mantine/form": "^7.5.0", "@mantine/hooks": "^7.5.3", "@mozilla/readability": "^0.5.0", @@ -39,6 +40,7 @@ "lucide-react": "^0.350.0", "mammoth": "^1.7.2", "ml-distance": "^4.0.1", + "openai": "^4.65.0", "pdfjs-dist": "4.0.379", "property-information": "^6.4.1", "pubsub-js": "^1.9.4", diff --git a/src/components/Common/ModelSelect.tsx b/src/components/Common/ModelSelect.tsx index e39a9f6..1a9e8d0 100644 --- a/src/components/Common/ModelSelect.tsx +++ b/src/components/Common/ModelSelect.tsx @@ -38,10 +38,10 @@ export const ModelSelect: React.FC = () => { ), onClick: () => { - if (selectedModel === d.name) { + if (selectedModel === d.model) { setSelectedModel(null) } else { - setSelectedModel(d.name) + setSelectedModel(d.model) } } })) || [], diff --git a/src/components/Common/ProviderIcon.tsx b/src/components/Common/ProviderIcon.tsx index a97776f..83a8cca 100644 --- a/src/components/Common/ProviderIcon.tsx +++ b/src/components/Common/ProviderIcon.tsx @@ -1,4 +1,4 @@ -import { ChromeIcon } from "lucide-react" +import { ChromeIcon, CloudCog } from "lucide-react" import { OllamaIcon } from "../Icons/Ollama" export const ProviderIcons = ({ @@ -11,6 +11,8 @@ export const ProviderIcons = ({ switch (provider) { case "chrome": return + case "custom": + return default: return } diff --git a/src/components/Layouts/Header.tsx b/src/components/Layouts/Header.tsx index 65fab8e..67338c2 100644 --- a/src/components/Layouts/Header.tsx +++ b/src/components/Layouts/Header.tsx @@ -11,7 +11,6 @@ import { } from "lucide-react" import { useTranslation } from "react-i18next" import { useLocation, NavLink } from "react-router-dom" -import { OllamaIcon } from "../Icons/Ollama" import { SelectedKnowledge } from "../Option/Knowledge/SelectedKnwledge" import { ModelSelect } from "../Common/ModelSelect" import { PromptSelect } from "../Common/PromptSelect" diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx index af3c866..bd14c30 100644 --- a/src/components/Option/Models/index.tsx +++ b/src/components/Option/Models/index.tsx @@ -1,9 +1,5 @@ -import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { useMutation, } from "@tanstack/react-query" import { - Skeleton, - Table, - Tag, - Tooltip, notification, Modal, Input, @@ -23,7 +19,7 @@ dayjs.extend(relativeTime) export const ModelsBody = () => { const [open, setOpen] = useState(false) const [segmented, setSegmented] = useState("ollama") - + const { t } = useTranslation(["settings", "common", "openai"]) const form = useForm({ diff --git a/src/db/models.ts b/src/db/models.ts index 207fe97..ef615a9 100644 --- a/src/db/models.ts +++ b/src/db/models.ts @@ -18,6 +18,11 @@ export const generateID = () => { export const removeModelPrefix = (id: string) => { return id.replace(/^model-/, "") } + +export const isCustomModel = (model: string) => { + const customModelRegex = /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/ + return customModelRegex.test(model) +} export class ModelDb { db: chrome.storage.StorageArea @@ -174,3 +179,30 @@ export const isLookupExist = async (lookup: string) => { const model = models.find((model) => model.lookup === lookup) return model ? true : false } + + +export const ollamaFormatAllCustomModels = async () => { + + const allModles = await getAllCustomModels() + + const ollamaModels = allModles.map((model) => { + return { + name: model.name, + model: model.id, + modified_at: "", + provider: "custom", + size: 0, + digest: "", + details: { + parent_model: "", + format: "", + family: "", + families: [], + parameter_size: "", + quantization_level: "" + } + } + }) + + return ollamaModels +} \ No newline at end of file diff --git a/src/models/index.ts b/src/models/index.ts index ce3ab39..07c134e 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -1,5 +1,8 @@ +import { getModelInfo, isCustomModel } from "@/db/models" import { ChatChromeAI } from "./ChatChromeAi" import { ChatOllama } from "./ChatOllama" +import { getOpenAIConfigById } from "@/db/openai" +import { ChatOpenAI } from "@langchain/openai" export const pageAssistModel = async ({ model, @@ -22,23 +25,49 @@ export const pageAssistModel = async ({ seed?: number numGpu?: number }) => { - switch (model) { - case "chrome::gemini-nano::page-assist": - return new ChatChromeAI({ - temperature, - topK - }) - default: - return new ChatOllama({ - baseUrl, - keepAlive, - temperature, - topK, - topP, - numCtx, - seed, - model, - numGpu - }) + + if (model === "chrome::gemini-nano::page-assist") { + return new ChatChromeAI({ + temperature, + topK + }) } + + + const isCustom = isCustomModel(model) + + console.log("isCustom", isCustom, model) + + if (isCustom) { + const modelInfo = await getModelInfo(model) + const providerInfo = await getOpenAIConfigById(modelInfo.provider_id) + + return new ChatOpenAI({ + modelName: modelInfo.model_id, + openAIApiKey: providerInfo.apiKey || "", + temperature, + topP, + configuration: { + apiKey: providerInfo.apiKey || "", + baseURL: providerInfo.baseUrl || "", + } + }) as any + } + + + + return new ChatOllama({ + baseUrl, + keepAlive, + temperature, + topK, + topP, + numCtx, + seed, + model, + numGpu + }) + + + } diff --git a/src/services/ollama.ts b/src/services/ollama.ts index ee58b0e..5aff8c3 100644 --- a/src/services/ollama.ts +++ b/src/services/ollama.ts @@ -4,6 +4,7 @@ import { urlRewriteRuntime } from "../libs/runtime" import { getChromeAIModel } from "./chrome" import { setNoOfRetrievedDocs, setTotalFilePerKB } from "./app" import fetcher from "@/libs/fetcher" +import { ollamaFormatAllCustomModels } from "@/db/models" const storage = new Storage() @@ -193,9 +194,13 @@ export const fetchChatModels = async ({ } }) const chromeModel = await getChromeAIModel() + + const customModels = await ollamaFormatAllCustomModels() + return [ ...chatModels, - ...chromeModel + ...chromeModel, + ...customModels ] } catch (e) { console.error(e) @@ -207,10 +212,11 @@ export const fetchChatModels = async ({ } }) const chromeModel = await getChromeAIModel() - + const customModels = await ollamaFormatAllCustomModels() return [ ...models, - ...chromeModel + ...chromeModel, + ...customModels ] } }