diff --git a/bun.lockb b/bun.lockb
index 546445c..deaa93f 100644
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index 43057df..c088aa4 100644
--- a/package.json
+++ b/package.json
@@ -19,6 +19,7 @@
"@headlessui/react": "^1.7.18",
"@heroicons/react": "^2.1.1",
"@langchain/community": "^0.0.41",
+ "@langchain/openai": "0.0.24",
"@mantine/form": "^7.5.0",
"@mantine/hooks": "^7.5.3",
"@mozilla/readability": "^0.5.0",
@@ -39,6 +40,7 @@
"lucide-react": "^0.350.0",
"mammoth": "^1.7.2",
"ml-distance": "^4.0.1",
+ "openai": "^4.65.0",
"pdfjs-dist": "4.0.379",
"property-information": "^6.4.1",
"pubsub-js": "^1.9.4",
diff --git a/src/components/Common/ModelSelect.tsx b/src/components/Common/ModelSelect.tsx
index e39a9f6..1a9e8d0 100644
--- a/src/components/Common/ModelSelect.tsx
+++ b/src/components/Common/ModelSelect.tsx
@@ -38,10 +38,10 @@ export const ModelSelect: React.FC = () => {
),
onClick: () => {
- if (selectedModel === d.name) {
+ if (selectedModel === d.model) {
setSelectedModel(null)
} else {
- setSelectedModel(d.name)
+ setSelectedModel(d.model)
}
}
})) || [],
diff --git a/src/components/Common/ProviderIcon.tsx b/src/components/Common/ProviderIcon.tsx
index a97776f..83a8cca 100644
--- a/src/components/Common/ProviderIcon.tsx
+++ b/src/components/Common/ProviderIcon.tsx
@@ -1,4 +1,4 @@
-import { ChromeIcon } from "lucide-react"
+import { ChromeIcon, CloudCog } from "lucide-react"
import { OllamaIcon } from "../Icons/Ollama"
export const ProviderIcons = ({
@@ -11,6 +11,8 @@ export const ProviderIcons = ({
switch (provider) {
case "chrome":
return
+ case "custom":
+ return
default:
return
}
diff --git a/src/components/Layouts/Header.tsx b/src/components/Layouts/Header.tsx
index 65fab8e..67338c2 100644
--- a/src/components/Layouts/Header.tsx
+++ b/src/components/Layouts/Header.tsx
@@ -11,7 +11,6 @@ import {
} from "lucide-react"
import { useTranslation } from "react-i18next"
import { useLocation, NavLink } from "react-router-dom"
-import { OllamaIcon } from "../Icons/Ollama"
import { SelectedKnowledge } from "../Option/Knowledge/SelectedKnwledge"
import { ModelSelect } from "../Common/ModelSelect"
import { PromptSelect } from "../Common/PromptSelect"
diff --git a/src/components/Option/Models/index.tsx b/src/components/Option/Models/index.tsx
index af3c866..bd14c30 100644
--- a/src/components/Option/Models/index.tsx
+++ b/src/components/Option/Models/index.tsx
@@ -1,9 +1,5 @@
-import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
+import { useMutation } from "@tanstack/react-query"
import {
- Skeleton,
- Table,
- Tag,
- Tooltip,
notification,
Modal,
Input,
@@ -23,7 +19,7 @@ dayjs.extend(relativeTime)
export const ModelsBody = () => {
const [open, setOpen] = useState(false)
const [segmented, setSegmented] = useState("ollama")
-
+
const { t } = useTranslation(["settings", "common", "openai"])
const form = useForm({
diff --git a/src/db/models.ts b/src/db/models.ts
index 207fe97..ef615a9 100644
--- a/src/db/models.ts
+++ b/src/db/models.ts
@@ -18,6 +18,11 @@ export const generateID = () => {
export const removeModelPrefix = (id: string) => {
return id.replace(/^model-/, "")
}
+
+export const isCustomModel = (model: string) => {
+ const customModelRegex = /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
+ return customModelRegex.test(model)
+}
export class ModelDb {
db: chrome.storage.StorageArea
@@ -174,3 +179,30 @@ export const isLookupExist = async (lookup: string) => {
const model = models.find((model) => model.lookup === lookup)
return model ? true : false
}
+
+
+export const ollamaFormatAllCustomModels = async () => {
+
+  const allModels = await getAllCustomModels()
+
+  const ollamaModels = allModels.map((model) => {
+ return {
+ name: model.name,
+ model: model.id,
+ modified_at: "",
+ provider: "custom",
+ size: 0,
+ digest: "",
+ details: {
+ parent_model: "",
+ format: "",
+ family: "",
+ families: [],
+ parameter_size: "",
+ quantization_level: ""
+ }
+ }
+ })
+
+ return ollamaModels
+}
\ No newline at end of file
diff --git a/src/models/index.ts b/src/models/index.ts
index ce3ab39..07c134e 100644
--- a/src/models/index.ts
+++ b/src/models/index.ts
@@ -1,5 +1,8 @@
+import { getModelInfo, isCustomModel } from "@/db/models"
import { ChatChromeAI } from "./ChatChromeAi"
import { ChatOllama } from "./ChatOllama"
+import { getOpenAIConfigById } from "@/db/openai"
+import { ChatOpenAI } from "@langchain/openai"
export const pageAssistModel = async ({
model,
@@ -22,23 +25,49 @@ export const pageAssistModel = async ({
seed?: number
numGpu?: number
}) => {
- switch (model) {
- case "chrome::gemini-nano::page-assist":
- return new ChatChromeAI({
- temperature,
- topK
- })
- default:
- return new ChatOllama({
- baseUrl,
- keepAlive,
- temperature,
- topK,
- topP,
- numCtx,
- seed,
- model,
- numGpu
- })
+
+ if (model === "chrome::gemini-nano::page-assist") {
+ return new ChatChromeAI({
+ temperature,
+ topK
+ })
}
+
+
+ const isCustom = isCustomModel(model)
+
+
+
+ if (isCustom) {
+ const modelInfo = await getModelInfo(model)
+ const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
+
+ return new ChatOpenAI({
+ modelName: modelInfo.model_id,
+ openAIApiKey: providerInfo.apiKey || "",
+ temperature,
+ topP,
+ configuration: {
+ apiKey: providerInfo.apiKey || "",
+ baseURL: providerInfo.baseUrl || "",
+ }
+ }) as any
+ }
+
+
+
+ return new ChatOllama({
+ baseUrl,
+ keepAlive,
+ temperature,
+ topK,
+ topP,
+ numCtx,
+ seed,
+ model,
+ numGpu
+ })
+
+
+
}
diff --git a/src/services/ollama.ts b/src/services/ollama.ts
index ee58b0e..5aff8c3 100644
--- a/src/services/ollama.ts
+++ b/src/services/ollama.ts
@@ -4,6 +4,7 @@ import { urlRewriteRuntime } from "../libs/runtime"
import { getChromeAIModel } from "./chrome"
import { setNoOfRetrievedDocs, setTotalFilePerKB } from "./app"
import fetcher from "@/libs/fetcher"
+import { ollamaFormatAllCustomModels } from "@/db/models"
const storage = new Storage()
@@ -193,9 +194,13 @@ export const fetchChatModels = async ({
}
})
const chromeModel = await getChromeAIModel()
+
+ const customModels = await ollamaFormatAllCustomModels()
+
return [
...chatModels,
- ...chromeModel
+ ...chromeModel,
+ ...customModels
]
} catch (e) {
console.error(e)
@@ -207,10 +212,11 @@ export const fetchChatModels = async ({
}
})
const chromeModel = await getChromeAIModel()
-
+ const customModels = await ollamaFormatAllCustomModels()
return [
...models,
- ...chromeModel
+ ...chromeModel,
+ ...customModels
]
}
}