feat: Add support for multiple Ollama providers

Adds support for using Ollama 2, an additional OpenAI-compatible Ollama instance, as a model provider. This includes:

- Adding Ollama 2 to the list of supported providers in the UI
- Updating the model identification logic to properly handle Ollama 2 models
- Modifying the model loading and runtime configuration to work with Ollama 2
- Implementing Ollama 2-specific functionality in the embedding and chat models

This change allows users to leverage the capabilities of Ollama 2 for both embeddings and conversational AI tasks.
n4ze3m 2024-11-10 15:31:28 +05:30
parent c6a62126dd
commit a7f461da0b
5 changed files with 96 additions and 12 deletions
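
All of the model-identification changes below hinge on one convention: a model fetched from an "ollama2" provider is stored with a suffix that encodes the provider's id. A minimal sketch of that scheme (the model name and hex fragments are made-up examples, not values from this commit):

// A model such as "llama3" served by provider "openai-ab12-c34-d567"
// is stored as "llama3_ollama2_openai-ab12-c34-d567"
const OLLAMA2_SUFFIX = /_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/

const tagged = "llama3_ollama2_openai-ab12-c34-d567" // illustrative id
OLLAMA2_SUFFIX.test(tagged) // true: recognized as an ollama2 model
tagged.replace(OLLAMA2_SUFFIX, "") // "llama3": bare id for API calls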


@@ -47,7 +47,7 @@ export const OpenAIApp = () => {
})
setOpen(false)
message.success(t("addSuccess"))
const noPopupProvider = ["lmstudio", "llamafile"]
const noPopupProvider = ["lmstudio", "llamafile", "ollama2"]
if (!noPopupProvider.includes(provider)) {
setOpenaiId(data)
setOpenModelModal(true)


@@ -25,6 +25,7 @@ export const removeModelSuffix = (id: string) => {
.replace(/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/, "")
.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
.replace(/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
.replace(/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
}
export const isLMStudioModel = (model: string) => {
const lmstudioModelRegex =
@@ -37,7 +38,11 @@ export const isLlamafileModel = (model: string) => {
/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
return llamafileModelRegex.test(model)
}
export const isOllamaModel = (model: string) => {
const ollamaModelRegex =
/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
return ollamaModelRegex.test(model)
}
export const getLMStudioModelId = (
model: string
): { model_id: string; provider_id: string } => {
@@ -51,7 +56,19 @@ export const getLMStudioModelId = (
}
return null
}
export const getOllamaModelId = (
model: string
): { model_id: string; provider_id: string } => {
const ollamaModelRegex =
/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
const match = model.match(ollamaModelRegex)
if (match) {
const modelId = match[0]
const providerId = match[0].replace("_ollama2_openai-", "")
return { model_id: modelId, provider_id: providerId }
}
return null
}
export const getLlamafileModelId = (
model: string
): { model_id: string; provider_id: string } => {
@@ -74,6 +91,10 @@ export const isCustomModel = (model: string) => {
return true
}
if (isOllamaModel(model)) {
return true
}
const customModelRegex =
/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
return customModelRegex.test(model)
@@ -246,6 +267,25 @@ export const getModelInfo = async (id: string) => {
}
}
if (isOllamaModel(id)) {
const ollamaId = getOllamaModelId(id)
if (!ollamaId) {
throw new Error("Invalid LMStudio model ID")
}
return {
model_id: id.replace(
/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
""
),
provider_id: `openai-${ollamaId.provider_id}`,
name: id.replace(
/_ollama2_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
""
)
}
}
const model = await db.getById(id)
return model
}
@@ -309,6 +349,27 @@ export const dynamicFetchLMStudio = async ({
return lmstudioModels
}
export const dynamicFetchOllama2 = async ({
baseUrl,
providerId
}: {
baseUrl: string
providerId: string
}) => {
const models = await getAllOpenAIModels(baseUrl)
const ollama2Models = models.map((e) => {
return {
name: e?.name || e?.id,
id: `${e?.id}_ollama2_${providerId}`,
provider: providerId,
lookup: `${e?.id}_${providerId}`,
provider_id: providerId
}
})
return ollama2Models
}
export const dynamicFetchLlamafile = async ({
baseUrl,
providerId
@@ -360,21 +421,33 @@ export const ollamaFormatAllCustomModels = async (
})
)
const ollamaModelsPromises = allProviders.map((provider) => (
dynamicFetchOllama2({
baseUrl: provider.baseUrl,
providerId: provider.id
})
))
const lmModelsFetch = await Promise.all(lmModelsPromises)
const llamafileModelsFetch = await Promise.all(llamafileModelsPromises)
const ollamaModelsFetch = await Promise.all(ollamaModelsPromises)
const lmModels = lmModelsFetch.flat()
const llamafileModels = llamafileModelsFetch.flat()
const ollama2Models = ollamaModelsFetch.flat()
// merge allModels and lmModels
const allModlesWithLMStudio = [
...(modelType !== "all"
? allModles.filter((model) => model.model_type === modelType)
: allModles),
...lmModels,
...llamafileModels
...llamafileModels,
...ollama2Models
]
const ollamaModels = allModlesWithLMStudio.map((model) => {
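
Taken together, the new helpers round-trip a tagged id as follows; a walk-through with made-up values (the hex fragments are illustrative):

import { getModelInfo, getOllamaModelId } from "@/db/models"

const id = "llama3_ollama2_openai-ab12-c34-d567"

getOllamaModelId(id)
// -> { model_id: "_ollama2_openai-ab12-c34-d567", provider_id: "ab12-c34-d567" }
// note: model_id is the matched suffix; getModelInfo strips the suffix
// from the full id itself rather than reading it from here

await getModelInfo(id)
// -> { model_id: "llama3", provider_id: "openai-ab12-c34-d567", name: "llama3" }

dynamicFetchOllama2 applies the same suffix to every model returned by getAllOpenAIModels for the provider's base URL, and ollamaFormatAllCustomModels merges the result into the combined list alongside the LM Studio and Llamafile models.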


@@ -1,4 +1,4 @@
import { getModelInfo, isCustomModel } from "@/db/models"
import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
import { OllamaEmbeddingsPageAssist } from "./OllamaEmbedding"
import { OAIEmbedding } from "./OAIEmbedding"
import { getOpenAIConfigById } from "@/db/openai"


@@ -1,8 +1,9 @@
import { getModelInfo, isCustomModel } from "@/db/models"
import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
import { ChatChromeAI } from "./ChatChromeAi"
import { ChatOllama } from "./ChatOllama"
import { getOpenAIConfigById } from "@/db/openai"
import { ChatOpenAI } from "@langchain/openai"
import { urlRewriteRuntime } from "@/libs/runtime"
export const pageAssistModel = async ({
model,
@@ -43,6 +44,10 @@ export const pageAssistModel = async ({
const modelInfo = await getModelInfo(model)
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
if (isOllamaModel(model)) {
await urlRewriteRuntime(providerInfo.baseUrl || "")
}
return new ChatOpenAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp",
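
The runtime tweak above is small but load-bearing: for an ollama2-tagged model, the provider's base URL is passed through urlRewriteRuntime before the ChatOpenAI client is constructed. Roughly, with illustrative values (the base URL matches the default added to the provider list below):

const model = "llama3_ollama2_openai-ab12-c34-d567" // made-up id
// getModelInfo(model)      -> { model_id: "llama3", provider_id: "openai-ab12-c34-d567" }
// getOpenAIConfigById(...) -> { baseUrl: "http://localhost:11434/v1", ... }
// isOllamaModel(model) is true, so urlRewriteRuntime(baseUrl) runs first;
// ChatOpenAI is then created with modelName "llama3" and the placeholder
// key "temp", since a local Ollama endpoint does not require an API key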


@@ -1,14 +1,24 @@
export const OAI_API_PROVIDERS = [
{
label: "Custom",
value: "custom",
baseUrl: ""
},
{
label: "LM Studio",
value: "lmstudio",
baseUrl: "http://localhost:1234/v1"
},
{
label: "LlamaFile",
label: "Llamafile",
value: "llamafile",
baseUrl: "http://127.0.0.1:8080/v1"
},
{
label: "Ollama",
value: "ollama2",
baseUrl: "http://localhost:11434/v1"
},
{
label: "OpenAI",
value: "openai",
@@ -34,9 +44,5 @@ export const OAI_API_PROVIDERS = [
value: "openrouter",
baseUrl: "https://openrouter.ai/api/v1"
},
{
label: "Custom",
value: "custom",
baseUrl: ""
}
]
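
The new "Ollama" entry points at http://localhost:11434/v1, Ollama's OpenAI-compatible endpoint, which is what lets an additional Ollama instance be registered through the same generic OpenAI provider path as LM Studio and Llamafile. For illustration:

const ollama = OAI_API_PROVIDERS.find((p) => p.value === "ollama2")
// -> { label: "Ollama", value: "ollama2", baseUrl: "http://localhost:11434/v1" }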