diff --git a/src/components/Option/Settings/openai.tsx b/src/components/Option/Settings/openai.tsx
index fe4dcdb..a0807ed 100644
--- a/src/components/Option/Settings/openai.tsx
+++ b/src/components/Option/Settings/openai.tsx
@@ -47,7 +47,7 @@ export const OpenAIApp = () => {
       })
       setOpen(false)
       message.success(t("addSuccess"))
-      const noPopupProvider = ["lmstudio", "llamafile"]
+      const noPopupProvider = ["lmstudio", "llamafile", "ollama2"]
       if (!noPopupProvider.includes(provider)) {
         setOpenaiId(data)
         setOpenModelModal(true)
diff --git a/src/db/models.ts b/src/db/models.ts
index eaa7650..012379c 100644
--- a/src/db/models.ts
+++ b/src/db/models.ts
@@ -25,6 +25,7 @@ export const removeModelSuffix = (id: string) => {
     .replace(/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/, "")
     .replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
     .replace(/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
+    .replace(/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
 }
 export const isLMStudioModel = (model: string) => {
   const lmstudioModelRegex =
@@ -37,7 +38,11 @@ export const isLlamafileModel = (model: string) => {
     /_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
   return llamafileModelRegex.test(model)
 }
-
+export const isOllamaModel = (model: string) => {
+  const ollamaModelRegex =
+    /_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
+  return ollamaModelRegex.test(model)
+}
 export const getLMStudioModelId = (
   model: string
 ): { model_id: string; provider_id: string } => {
@@ -51,7 +56,19 @@ export const getLMStudioModelId = (
   }
   return null
 }
-
+export const getOllamaModelId = (
+  model: string
+): { model_id: string; provider_id: string } => {
+  const ollamaModelRegex =
+    /_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
+  const match = model.match(ollamaModelRegex)
+  if (match) {
+    const modelId = match[0]
+    const providerId = match[0].replace("_ollama2_openai-", "")
+    return { model_id: modelId, provider_id: providerId }
+  }
+  return null
+}
 export const getLlamafileModelId = (
   model: string
 ): { model_id: string; provider_id: string } => {
@@ -74,6 +91,10 @@ export const isCustomModel = (model: string) => {
     return true
   }
 
+  if (isOllamaModel(model)) {
+    return true
+  }
+
   const customModelRegex =
     /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
   return customModelRegex.test(model)
@@ -246,6 +267,25 @@ export const getModelInfo = async (id: string) => {
     }
   }
 
+
+  if (isOllamaModel(id)) {
+    const ollamaId = getOllamaModelId(id)
+    if (!ollamaId) {
+      throw new Error("Invalid Ollama model ID")
+    }
+    return {
+      model_id: id.replace(
+        /_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
+        ""
+      ),
+      provider_id: `openai-${ollamaId.provider_id}`,
+      name: id.replace(
+        /_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
+        ""
+      )
+    }
+  }
+
   const model = await db.getById(id)
   return model
 }
@@ -309,6 +349,27 @@ export const dynamicFetchLMStudio = async ({
   return lmstudioModels
 }
 
+export const dynamicFetchOllama2 = async ({
+  baseUrl,
+  providerId
+}: {
+  baseUrl: string
+  providerId: string
+}) => {
+  const models = await getAllOpenAIModels(baseUrl)
+  const ollama2Models = models.map((e) => {
+    return {
+      name: e?.name || e?.id,
+      id: `${e?.id}_ollama2_${providerId}`,
+      provider: providerId,
+      lookup: `${e?.id}_${providerId}`,
+      provider_id: providerId
+    }
+  })
+
+  return ollama2Models
+}
+
 export const dynamicFetchLlamafile = async ({
   baseUrl,
   providerId
@@ -360,21 +421,33 @@ export const ollamaFormatAllCustomModels = async (
     })
   )
 
+  const ollamaModelsPromises = allProviders.map((provider) => (
+    dynamicFetchOllama2({
+      baseUrl: provider.baseUrl,
+      providerId: provider.id
+    })
+  ))
+
   const lmModelsFetch = await Promise.all(lmModelsPromises)
 
   const llamafileModelsFetch = await Promise.all(llamafileModelsPromises)
 
+  const ollamaModelsFetch = await Promise.all(ollamaModelsPromises)
+
   const lmModels = lmModelsFetch.flat()
 
   const llamafileModels = llamafileModelsFetch.flat()
 
+  const ollama2Models = ollamaModelsFetch.flat()
+
   // merge allModels and lmModels
   const allModlesWithLMStudio = [
     ...(modelType !== "all"
       ? allModles.filter((model) => model.model_type === modelType)
       : allModles),
     ...lmModels,
-    ...llamafileModels
+    ...llamafileModels,
+    ...ollama2Models
   ]
 
   const ollamaModels = allModlesWithLMStudio.map((model) => {
diff --git a/src/models/embedding.ts b/src/models/embedding.ts
index 03eb663..0c1026b 100644
--- a/src/models/embedding.ts
+++ b/src/models/embedding.ts
@@ -1,4 +1,4 @@
-import { getModelInfo, isCustomModel } from "@/db/models"
+import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
 import { OllamaEmbeddingsPageAssist } from "./OllamaEmbedding"
 import { OAIEmbedding } from "./OAIEmbedding"
 import { getOpenAIConfigById } from "@/db/openai"
diff --git a/src/models/index.ts b/src/models/index.ts
index 4798f78..135025f 100644
--- a/src/models/index.ts
+++ b/src/models/index.ts
@@ -1,8 +1,9 @@
-import { getModelInfo, isCustomModel } from "@/db/models"
+import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
 import { ChatChromeAI } from "./ChatChromeAi"
 import { ChatOllama } from "./ChatOllama"
 import { getOpenAIConfigById } from "@/db/openai"
 import { ChatOpenAI } from "@langchain/openai"
+import { urlRewriteRuntime } from "@/libs/runtime"
 
 export const pageAssistModel = async ({
   model,
@@ -43,6 +44,10 @@ export const pageAssistModel = async ({
   const modelInfo = await getModelInfo(model)
   const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
 
+  if (isOllamaModel(model)) {
+    await urlRewriteRuntime(providerInfo.baseUrl || "")
+  }
+
   return new ChatOpenAI({
     modelName: modelInfo.model_id,
     openAIApiKey: providerInfo.apiKey || "temp",
diff --git a/src/utils/oai-api-providers.ts b/src/utils/oai-api-providers.ts
index 52bfcf9..40c1e21 100644
--- a/src/utils/oai-api-providers.ts
+++ b/src/utils/oai-api-providers.ts
@@ -1,14 +1,24 @@
 export const OAI_API_PROVIDERS = [
+  {
+    label: "Custom",
+    value: "custom",
+    baseUrl: ""
+  },
   {
     label: "LM Studio",
     value: "lmstudio",
     baseUrl: "http://localhost:1234/v1"
   },
   {
-    label: "LlamaFile",
+    label: "Llamafile",
     value: "llamafile",
     baseUrl: "http://127.0.0.1:8080/v1"
   },
+  {
+    label: "Ollama",
+    value: "ollama2",
+    baseUrl: "http://localhost:11434/v1"
+  },
   {
     label: "OpenAI",
     value: "openai",
@@ -34,9 +44,5 @@ export const OAI_API_PROVIDERS = [
     value: "openrouter",
     baseUrl: "https://openrouter.ai/api/v1"
   },
-  {
-    label: "Custom",
-    value: "custom",
-    baseUrl: ""
-  }
+
 ]
\ No newline at end of file
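
For reviewers, here is how the `_ollama2_` tagging scheme added in `src/db/models.ts` round-trips. This is a minimal sketch, not part of the patch; the model name `llama3:8b` and the hex segments `ab12-cd3-ef45` are invented values shaped to match the suffix regex, and the provider id is assumed to look like `openai-ab12-cd3-ef45`.

```ts
import { isOllamaModel, getOllamaModelId, removeModelSuffix } from "@/db/models"

// dynamicFetchOllama2 tags every fetched model as `${model}_ollama2_${providerId}`,
// so a hypothetical provider id of "openai-ab12-cd3-ef45" yields:
const id = "llama3:8b_ollama2_openai-ab12-cd3-ef45"

isOllamaModel(id)                 // true: the suffix matches the _ollama2_openai- regex
removeModelSuffix(id)             // "llama3:8b" (the suffix strips cleanly for display)
getOllamaModelId(id)?.provider_id // "ab12-cd3-ef45"; getModelInfo re-prefixes it as
                                  // "openai-ab12-cd3-ef45" before calling getOpenAIConfigById
```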