feat: add OpenAI model support

Adds support for OpenAI models, allowing users to leverage various OpenAI models directly from the application. This includes custom OpenAI models and OpenAI-specific configurations for seamless integration.
This commit is contained in:
n4ze3m 2024-09-29 19:57:26 +05:30
parent 2a2610afb8
commit c8620637f8
9 changed files with 97 additions and 31 deletions

BIN
bun.lockb

Binary file not shown.

View File

@ -19,6 +19,7 @@
"@headlessui/react": "^1.7.18", "@headlessui/react": "^1.7.18",
"@heroicons/react": "^2.1.1", "@heroicons/react": "^2.1.1",
"@langchain/community": "^0.0.41", "@langchain/community": "^0.0.41",
"@langchain/openai": "0.0.24",
"@mantine/form": "^7.5.0", "@mantine/form": "^7.5.0",
"@mantine/hooks": "^7.5.3", "@mantine/hooks": "^7.5.3",
"@mozilla/readability": "^0.5.0", "@mozilla/readability": "^0.5.0",
@ -39,6 +40,7 @@
"lucide-react": "^0.350.0", "lucide-react": "^0.350.0",
"mammoth": "^1.7.2", "mammoth": "^1.7.2",
"ml-distance": "^4.0.1", "ml-distance": "^4.0.1",
"openai": "^4.65.0",
"pdfjs-dist": "4.0.379", "pdfjs-dist": "4.0.379",
"property-information": "^6.4.1", "property-information": "^6.4.1",
"pubsub-js": "^1.9.4", "pubsub-js": "^1.9.4",

View File

@ -38,10 +38,10 @@ export const ModelSelect: React.FC = () => {
</div> </div>
), ),
onClick: () => { onClick: () => {
if (selectedModel === d.name) { if (selectedModel === d.model) {
setSelectedModel(null) setSelectedModel(null)
} else { } else {
setSelectedModel(d.name) setSelectedModel(d.model)
} }
} }
})) || [], })) || [],

View File

@ -1,4 +1,4 @@
import { ChromeIcon } from "lucide-react" import { ChromeIcon, CloudCog } from "lucide-react"
import { OllamaIcon } from "../Icons/Ollama" import { OllamaIcon } from "../Icons/Ollama"
export const ProviderIcons = ({ export const ProviderIcons = ({
@ -11,6 +11,8 @@ export const ProviderIcons = ({
switch (provider) { switch (provider) {
case "chrome": case "chrome":
return <ChromeIcon className={className} /> return <ChromeIcon className={className} />
case "custom":
return <CloudCog className={className} />
default: default:
return <OllamaIcon className={className} /> return <OllamaIcon className={className} />
} }

View File

@ -11,7 +11,6 @@ import {
} from "lucide-react" } from "lucide-react"
import { useTranslation } from "react-i18next" import { useTranslation } from "react-i18next"
import { useLocation, NavLink } from "react-router-dom" import { useLocation, NavLink } from "react-router-dom"
import { OllamaIcon } from "../Icons/Ollama"
import { SelectedKnowledge } from "../Option/Knowledge/SelectedKnwledge" import { SelectedKnowledge } from "../Option/Knowledge/SelectedKnwledge"
import { ModelSelect } from "../Common/ModelSelect" import { ModelSelect } from "../Common/ModelSelect"
import { PromptSelect } from "../Common/PromptSelect" import { PromptSelect } from "../Common/PromptSelect"

View File

@ -1,9 +1,5 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" import { useMutation, } from "@tanstack/react-query"
import { import {
Skeleton,
Table,
Tag,
Tooltip,
notification, notification,
Modal, Modal,
Input, Input,

View File

@ -18,6 +18,11 @@ export const generateID = () => {
export const removeModelPrefix = (id: string) => { export const removeModelPrefix = (id: string) => {
return id.replace(/^model-/, "") return id.replace(/^model-/, "")
} }
export const isCustomModel = (model: string) => {
  // Custom (OpenAI-compatible) model ids carry a generated suffix of the
  // form `_model-xxxx-xxxx-xxx(x)-xxxx` built from lowercase hex groups;
  // presence of that suffix marks the model as user-defined rather than
  // a native Ollama/Chrome model.
  const customIdSuffix = /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
  return customIdSuffix.test(model)
}
export class ModelDb { export class ModelDb {
db: chrome.storage.StorageArea db: chrome.storage.StorageArea
@ -174,3 +179,30 @@ export const isLookupExist = async (lookup: string) => {
const model = models.find((model) => model.lookup === lookup) const model = models.find((model) => model.lookup === lookup)
return model ? true : false return model ? true : false
} }
/**
 * Load every user-defined (OpenAI-compatible) model and project each one
 * into the same record shape the Ollama tags listing uses, so custom models
 * can be merged into the regular chat-model list.
 *
 * Ollama-specific fields that have no custom-model equivalent (size, digest,
 * details, modified_at) are filled with neutral defaults; `provider` is set
 * to "custom" so the UI can pick the matching icon.
 *
 * @returns array of Ollama-shaped model records, one per custom model
 */
export const ollamaFormatAllCustomModels = async () => {
  // fix: local was previously misspelled `allModles`
  const allModels = await getAllCustomModels()
  return allModels.map((model) => ({
    name: model.name,
    // the stored id doubles as the selectable model identifier
    model: model.id,
    modified_at: "",
    provider: "custom",
    size: 0,
    digest: "",
    details: {
      parent_model: "",
      format: "",
      family: "",
      families: [],
      parameter_size: "",
      quantization_level: ""
    }
  }))
}

View File

@ -1,5 +1,8 @@
import { getModelInfo, isCustomModel } from "@/db/models"
import { ChatChromeAI } from "./ChatChromeAi" import { ChatChromeAI } from "./ChatChromeAi"
import { ChatOllama } from "./ChatOllama" import { ChatOllama } from "./ChatOllama"
import { getOpenAIConfigById } from "@/db/openai"
import { ChatOpenAI } from "@langchain/openai"
export const pageAssistModel = async ({ export const pageAssistModel = async ({
model, model,
@ -22,23 +25,49 @@ export const pageAssistModel = async ({
seed?: number seed?: number
numGpu?: number numGpu?: number
}) => { }) => {
switch (model) {
case "chrome::gemini-nano::page-assist": if (model === "chrome::gemini-nano::page-assist") {
return new ChatChromeAI({ return new ChatChromeAI({
temperature, temperature,
topK topK
}) })
default:
return new ChatOllama({
baseUrl,
keepAlive,
temperature,
topK,
topP,
numCtx,
seed,
model,
numGpu
})
} }
const isCustom = isCustomModel(model)
console.log("isCustom", isCustom, model)
if (isCustom) {
const modelInfo = await getModelInfo(model)
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
return new ChatOpenAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "",
temperature,
topP,
configuration: {
apiKey: providerInfo.apiKey || "",
baseURL: providerInfo.baseUrl || "",
}
}) as any
}
return new ChatOllama({
baseUrl,
keepAlive,
temperature,
topK,
topP,
numCtx,
seed,
model,
numGpu
})
} }

View File

@ -4,6 +4,7 @@ import { urlRewriteRuntime } from "../libs/runtime"
import { getChromeAIModel } from "./chrome" import { getChromeAIModel } from "./chrome"
import { setNoOfRetrievedDocs, setTotalFilePerKB } from "./app" import { setNoOfRetrievedDocs, setTotalFilePerKB } from "./app"
import fetcher from "@/libs/fetcher" import fetcher from "@/libs/fetcher"
import { ollamaFormatAllCustomModels } from "@/db/models"
const storage = new Storage() const storage = new Storage()
@ -193,9 +194,13 @@ export const fetchChatModels = async ({
} }
}) })
const chromeModel = await getChromeAIModel() const chromeModel = await getChromeAIModel()
const customModels = await ollamaFormatAllCustomModels()
return [ return [
...chatModels, ...chatModels,
...chromeModel ...chromeModel,
...customModels
] ]
} catch (e) { } catch (e) {
console.error(e) console.error(e)
@ -207,10 +212,11 @@ export const fetchChatModels = async ({
} }
}) })
const chromeModel = await getChromeAIModel() const chromeModel = await getChromeAIModel()
const customModels = await ollamaFormatAllCustomModels()
return [ return [
...models, ...models,
...chromeModel ...chromeModel,
...customModels
] ]
} }
} }