From c6a62126dd2c14917dc4a25de90617e747f7ebbe Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 10 Nov 2024 14:02:44 +0530 Subject: [PATCH] feat: Add LlamaFile support Add support for LlamaFile, a new model provider that allows users to interact with models stored in LlamaFile format. This includes: - Adding an icon for LlamaFile in the provider selection menu. - Updating the model provider selection to include LlamaFile. - Updating the model handling logic to properly identify and process LlamaFile models. - Updating the API providers list to include LlamaFile. This enables users to leverage the capabilities of LlamaFile models within the application. --- src/components/Common/ProviderIcon.tsx | 3 + src/components/Icons/Llamafile.tsx | 24 +++++++ src/components/Option/Settings/openai.tsx | 3 +- src/db/models.ts | 84 ++++++++++++++++++++++- src/utils/oai-api-providers.ts | 5 ++ 5 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 src/components/Icons/Llamafile.tsx diff --git a/src/components/Common/ProviderIcon.tsx b/src/components/Common/ProviderIcon.tsx index 38ba504..8142adc 100644 --- a/src/components/Common/ProviderIcon.tsx +++ b/src/components/Common/ProviderIcon.tsx @@ -6,6 +6,7 @@ import { LMStudioIcon } from "../Icons/LMStudio" import { OpenAiIcon } from "../Icons/OpenAI" import { TogtherMonoIcon } from "../Icons/Togther" import { OpenRouterIcon } from "../Icons/OpenRouter" +import { LLamaFile } from "../Icons/Llamafile" export const ProviderIcons = ({ provider, @@ -31,6 +32,8 @@ export const ProviderIcons = ({ return case "openrouter": return + case "llamafile": + return default: return } diff --git a/src/components/Icons/Llamafile.tsx b/src/components/Icons/Llamafile.tsx new file mode 100644 index 0000000..734cdd9 --- /dev/null +++ b/src/components/Icons/Llamafile.tsx @@ -0,0 +1,24 @@ +// copied logo from Hugging Face webiste +import React from "react" + +export const LLamaFile = React.forwardRef< + SVGSVGElement, + React.SVGProps +>((props, ref) => { + return ( + + + + ) +}) diff --git a/src/components/Option/Settings/openai.tsx b/src/components/Option/Settings/openai.tsx index 273b1ce..fe4dcdb 100644 --- a/src/components/Option/Settings/openai.tsx +++ b/src/components/Option/Settings/openai.tsx @@ -47,7 +47,8 @@ export const OpenAIApp = () => { }) setOpen(false) message.success(t("addSuccess")) - if (provider !== "lmstudio") { + const noPopupProvider = ["lmstudio", "llamafile"] + if (!noPopupProvider.includes(provider)) { setOpenaiId(data) setOpenModelModal(true) } diff --git a/src/db/models.ts b/src/db/models.ts index 4575c67..eaa7650 100644 --- a/src/db/models.ts +++ b/src/db/models.ts @@ -24,6 +24,7 @@ export const removeModelSuffix = (id: string) => { return id .replace(/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/, "") .replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "") + .replace(/_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "") } export const isLMStudioModel = (model: string) => { const lmstudioModelRegex = @@ -31,6 +32,12 @@ export const isLMStudioModel = (model: string) => { return lmstudioModelRegex.test(model) } +export const isLlamafileModel = (model: string) => { + const llamafileModelRegex = + /_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/ + return llamafileModelRegex.test(model) +} + export const getLMStudioModelId = ( model: string ): { model_id: string; provider_id: string } => { @@ -44,10 +51,29 @@ export const getLMStudioModelId = ( } return null } + +export const getLlamafileModelId = ( + model: string +): { model_id: string; provider_id: string } => { + const llamafileModelRegex = + /_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/ + const match = model.match(llamafileModelRegex) + if (match) { + const modelId = match[0] + const providerId = match[0].replace("_llamafile_openai-", "") + return { model_id: modelId, provider_id: providerId } + } + return null +} export const isCustomModel = (model: string) => { if (isLMStudioModel(model)) { return true } + + if (isLlamafileModel(model)) { + return true + } + const customModelRegex = /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/ return customModelRegex.test(model) @@ -201,6 +227,25 @@ export const getModelInfo = async (id: string) => { } } + + if (isLlamafileModel(id)) { + const llamafileId = getLlamafileModelId(id) + if (!llamafileId) { + throw new Error("Invalid LMStudio model ID") + } + return { + model_id: id.replace( + /_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, + "" + ), + provider_id: `openai-${llamafileId.provider_id}`, + name: id.replace( + /_llamafile_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, + "" + ) + } + } + const model = await db.getById(id) return model } @@ -264,6 +309,27 @@ export const dynamicFetchLMStudio = async ({ return lmstudioModels } +export const dynamicFetchLlamafile = async ({ + baseUrl, + providerId +}: { + baseUrl: string + providerId: string +}) => { + const models = await getAllOpenAIModels(baseUrl) + const llamafileModels = models.map((e) => { + return { + name: e?.name || e?.id, + id: `${e?.id}_llamafile_${providerId}`, + provider: providerId, + lookup: `${e?.id}_${providerId}`, + provider_id: providerId + } + }) + + return llamafileModels +} + export const ollamaFormatAllCustomModels = async ( modelType: "all" | "chat" | "embedding" = "all" ) => { @@ -276,6 +342,10 @@ export const ollamaFormatAllCustomModels = async ( (provider) => provider.provider === "lmstudio" ) + const llamafileProviders = allProviders.filter( + (provider) => provider.provider === "llamafile" + ) + const lmModelsPromises = lmstudioProviders.map((provider) => dynamicFetchLMStudio({ baseUrl: provider.baseUrl, @@ -283,16 +353,28 @@ export const ollamaFormatAllCustomModels = async ( }) ) + const llamafileModelsPromises = llamafileProviders.map((provider) => + dynamicFetchLlamafile({ + baseUrl: provider.baseUrl, + providerId: provider.id + }) + ) + const lmModelsFetch = await Promise.all(lmModelsPromises) + const llamafileModelsFetch = await Promise.all(llamafileModelsPromises) + const lmModels = lmModelsFetch.flat() + const llamafileModels = llamafileModelsFetch.flat() + // merge allModels and lmModels const allModlesWithLMStudio = [ ...(modelType !== "all" ? allModles.filter((model) => model.model_type === modelType) : allModles), - ...lmModels + ...lmModels, + ...llamafileModels ] const ollamaModels = allModlesWithLMStudio.map((model) => { diff --git a/src/utils/oai-api-providers.ts b/src/utils/oai-api-providers.ts index 5d65105..52bfcf9 100644 --- a/src/utils/oai-api-providers.ts +++ b/src/utils/oai-api-providers.ts @@ -4,6 +4,11 @@ export const OAI_API_PROVIDERS = [ value: "lmstudio", baseUrl: "http://localhost:1234/v1" }, + { + label: "LlamaFile", + value: "llamafile", + baseUrl: "http://127.0.0.1:8080/v1" + }, { label: "OpenAI", value: "openai",