feat: Add LlamaFile support

Add support for LlamaFile, a new model provider that allows users to interact with models stored in LlamaFile format. This includes:

- Adding an icon for LlamaFile in the provider selection menu.
- Updating the model provider selection to include LlamaFile.
- Updating the model handling logic to properly identify and process LlamaFile models.
- Updating the API providers list to include LlamaFile.

This enables users to leverage the capabilities of LlamaFile models within the application.
This commit is contained in:
n4ze3m
2024-11-10 14:02:44 +05:30
parent f52e3d564a
commit c6a62126dd
5 changed files with 117 additions and 2 deletions

View File

@@ -24,6 +24,7 @@ export const removeModelSuffix = (id: string) => {
return id
.replace(/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/, "")
.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
.replace(/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
}
export const isLMStudioModel = (model: string) => {
const lmstudioModelRegex =
@@ -31,6 +32,12 @@ export const isLMStudioModel = (model: string) => {
return lmstudioModelRegex.test(model)
}
export const isLlamafileModel = (model: string) => {
const llamafileModelRegex =
/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
return llamafileModelRegex.test(model)
}
export const getLMStudioModelId = (
model: string
): { model_id: string; provider_id: string } => {
@@ -44,10 +51,29 @@ export const getLMStudioModelId = (
}
return null
}
export const getLlamafileModelId = (
model: string
): { model_id: string; provider_id: string } => {
const llamafileModelRegex =
/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
const match = model.match(llamafileModelRegex)
if (match) {
const modelId = match[0]
const providerId = match[0].replace("_llamafile_openai-", "")
return { model_id: modelId, provider_id: providerId }
}
return null
}
export const isCustomModel = (model: string) => {
if (isLMStudioModel(model)) {
return true
}
if (isLlamafileModel(model)) {
return true
}
const customModelRegex =
/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
return customModelRegex.test(model)
@@ -201,6 +227,25 @@ export const getModelInfo = async (id: string) => {
}
}
if (isLlamafileModel(id)) {
const llamafileId = getLlamafileModelId(id)
if (!llamafileId) {
throw new Error("Invalid LMStudio model ID")
}
return {
model_id: id.replace(
/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
""
),
provider_id: `openai-${llamafileId.provider_id}`,
name: id.replace(
/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/,
""
)
}
}
const model = await db.getById(id)
return model
}
@@ -264,6 +309,27 @@ export const dynamicFetchLMStudio = async ({
return lmstudioModels
}
export const dynamicFetchLlamafile = async ({
baseUrl,
providerId
}: {
baseUrl: string
providerId: string
}) => {
const models = await getAllOpenAIModels(baseUrl)
const llamafileModels = models.map((e) => {
return {
name: e?.name || e?.id,
id: `${e?.id}_llamafile_${providerId}`,
provider: providerId,
lookup: `${e?.id}_${providerId}`,
provider_id: providerId
}
})
return llamafileModels
}
export const ollamaFormatAllCustomModels = async (
modelType: "all" | "chat" | "embedding" = "all"
) => {
@@ -276,6 +342,10 @@ export const ollamaFormatAllCustomModels = async (
(provider) => provider.provider === "lmstudio"
)
const llamafileProviders = allProviders.filter(
(provider) => provider.provider === "llamafile"
)
const lmModelsPromises = lmstudioProviders.map((provider) =>
dynamicFetchLMStudio({
baseUrl: provider.baseUrl,
@@ -283,16 +353,28 @@ export const ollamaFormatAllCustomModels = async (
})
)
const llamafileModelsPromises = llamafileProviders.map((provider) =>
dynamicFetchLlamafile({
baseUrl: provider.baseUrl,
providerId: provider.id
})
)
const lmModelsFetch = await Promise.all(lmModelsPromises)
const llamafileModelsFetch = await Promise.all(llamafileModelsPromises)
const lmModels = lmModelsFetch.flat()
const llamafileModels = llamafileModelsFetch.flat()
// merge allModels and lmModels
const allModlesWithLMStudio = [
...(modelType !== "all"
? allModles.filter((model) => model.model_type === modelType)
: allModles),
...lmModels
...lmModels,
...llamafileModels
]
const ollamaModels = allModlesWithLMStudio.map((model) => {