From ddb8993f17708a6347190ed1b48dc3b3e7332ab4 Mon Sep 17 00:00:00 2001
From: n4ze3m
Date: Sat, 12 Oct 2024 19:05:21 +0530
Subject: [PATCH] feat: Support LMStudio models

Adds support for LMStudio models, allowing users to access and use them
within the application. This involves:

- Adding new functions to `db/models.ts` to handle LMStudio model IDs and
  fetch their information from the provider's OpenAI-compatible API.
- Modifying the `ollamaFormatAllCustomModels` function to include LMStudio
  models in the list of available models.
- Introducing a timeout mechanism in `libs/openai.ts` to prevent API
  requests from hanging.

This change enhances the model selection experience by giving users a wider
range of models to choose from.
---
 src/db/models.ts    | 79 ++++++++++++++++++++++++++++++++++++++++++---
 src/libs/openai.ts  | 16 +++++++--
 src/models/index.ts |  2 +-
 3 files changed, 89 insertions(+), 8 deletions(-)

diff --git a/src/db/models.ts b/src/db/models.ts
index c284e58..e985472 100644
--- a/src/db/models.ts
+++ b/src/db/models.ts
@@ -1,3 +1,4 @@
+import { getAllOpenAIModels } from "@/libs/openai"
 import {
   getAllOpenAIConfig,
   getOpenAIConfigById as providerInfo
@@ -22,10 +23,27 @@ export const removeModelSuffix = (id: string) => {
   return id.replace(
     /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/,
     ""
-  )
+  ).replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
+}
+export const isLMStudioModel = (model: string) => {
+  const lmstudioModelRegex = /_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
+  return lmstudioModelRegex.test(model)
 }
+export const getLMStudioModelId = (model: string): { model_id: string, provider_id: string } => {
+  const lmstudioModelRegex = /_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
+  const match = model.match(lmstudioModelRegex)
+  if (match) {
+    const modelId = match[0]
+    const providerId = match[0].replace("_lmstudio_openai-", "")
+    return { model_id: modelId, provider_id: providerId }
+  }
+  return null
+}
 
 export const isCustomModel = (model: string) => {
+  if (isLMStudioModel(model)) {
+    return true
+  }
   const customModelRegex =
     /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
   return customModelRegex.test(model)
@@ -158,6 +176,19 @@ export const createModel = async (
 
 export const getModelInfo = async (id: string) => {
   const db = new ModelDb()
+
+  if (isLMStudioModel(id)) {
+    const lmstudioId = getLMStudioModelId(id)
+    if (!lmstudioId) {
+      throw new Error("Invalid LMStudio model ID")
+    }
+    return {
+      model_id: id.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, ""),
+      provider_id: `openai-${lmstudioId.provider_id}`,
+      name: id.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
+    }
+  }
+
   const model = await db.getById(id)
   return model
 }
@@ -199,12 +230,52 @@ export const isLookupExist = async (lookup: string) => {
   return model ? true : false
 }
 
+export const dynamicFetchLMStudio = async ({
+  baseUrl,
+  providerId
+}: {
+  baseUrl: string
+  providerId: string
+}) => {
+  const models = await getAllOpenAIModels(baseUrl)
+  const lmstudioModels = models.map((e) => {
+    return {
+      name: e?.name || e?.id,
+      id: `${e?.id}_lmstudio_${providerId}`,
+      provider: providerId,
+      lookup: `${e?.id}_${providerId}`,
+      provider_id: providerId,
+    }
+  })
+
+  return lmstudioModels
+}
+
 export const ollamaFormatAllCustomModels = async () => {
-  const allModles = await getAllCustomModels()
+  const [allModles, allProviders] = await Promise.all([
+    getAllCustomModels(),
+    getAllOpenAIConfig()
+  ])
 
-  const allProviders = await getAllOpenAIConfig()
+  const lmstudioProviders = allProviders.filter(
+    (provider) => provider.provider === "lmstudio"
+  )
 
-  const ollamaModels = allModles.map((model) => {
+  const lmModelsPromises = lmstudioProviders.map((provider) =>
+    dynamicFetchLMStudio({
+      baseUrl: provider.baseUrl,
+      providerId: provider.id
+    })
+  )
+
+  const lmModelsFetch = await Promise.all(lmModelsPromises)
+
+  const lmModels = lmModelsFetch.flat()
+
+  // merge allModels and lmModels
+  const allModlesWithLMStudio = [...allModles, ...lmModels]
+
+  const ollamaModels = allModlesWithLMStudio.map((model) => {
     return {
       name: model.name,
       model: model.id,
diff --git a/src/libs/openai.ts b/src/libs/openai.ts
index 377639f..dad8d72 100644
--- a/src/libs/openai.ts
+++ b/src/libs/openai.ts
@@ -14,10 +14,16 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
         }
       : {}
 
+    const controller = new AbortController()
+    const timeoutId = setTimeout(() => controller.abort(), 10000)
+
     const res = await fetch(url, {
-      headers
+      headers,
+      signal: controller.signal
     })
 
+    clearTimeout(timeoutId)
+
     if (!res.ok) {
       return []
     }
@@ -27,14 +33,18 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
       return data.map(model => ({
         id: model.id,
         name: model.display_name,
-      }))
+      })) as Model[]
     }
 
     const data = (await res.json()) as { data: Model[] }
 
     return data.data
   } catch (e) {
-    console.log(e)
+    if (e instanceof DOMException && e.name === 'AbortError') {
+      console.log('Request timed out')
+    } else {
+      console.log(e)
+    }
     return []
   }
 }
diff --git a/src/models/index.ts b/src/models/index.ts
index d459e66..4df2419 100644
--- a/src/models/index.ts
+++ b/src/models/index.ts
@@ -40,7 +40,7 @@ export const pageAssistModel = async ({
   if (isCustom) {
     const modelInfo = await getModelInfo(model)
     const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
-
+    console.log(modelInfo, providerInfo)
     return new ChatOpenAI({
       modelName: modelInfo.model_id,
       openAIApiKey: providerInfo.apiKey || "temp",
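
Reviewer note (not part of the patch): the short TypeScript sketch below illustrates the `<modelId>_lmstudio_<providerId>` naming scheme that the new helpers in src/db/models.ts rely on. The model name "llama-3.2-1b" and provider id "openai-1a2b-3c4-5d6e" are made-up example values, and the regex is copied from the diff rather than imported, so the snippet can be run standalone to sanity-check the convention.

// Sketch only: mirrors the suffix convention from src/db/models.ts.
// "llama-3.2-1b" and "openai-1a2b-3c4-5d6e" are assumed example values.
const LMSTUDIO_SUFFIX = /_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/

const providerId = "openai-1a2b-3c4-5d6e"                // assumed provider id shape
const composedId = `llama-3.2-1b_lmstudio_${providerId}` // id built the way dynamicFetchLMStudio builds it

// Detection, as in isLMStudioModel()
console.log(LMSTUDIO_SUFFIX.test(composedId))            // true

// Stripping the suffix recovers the model id sent to the OpenAI-compatible
// endpoint, as in removeModelSuffix() / getModelInfo()
console.log(composedId.replace(LMSTUDIO_SUFFIX, ""))     // "llama-3.2-1b"

// Recovering the provider id from the suffix, as in getLMStudioModelId()
const match = composedId.match(LMSTUDIO_SUFFIX)
if (match) {
  console.log(match[0].replace("_lmstudio_openai-", "")) // "1a2b-3c4-5d6e"
}

Because the provider id is embedded in the model id itself, the new branch in getModelInfo() can resolve an LMStudio model without consulting ModelDb, which is why it returns early before db.getById(id).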