feat: Support LMStudio models

Adds support for LMStudio models, allowing users to access and use them within the application. This involves:

- Adding new functions to `db/models.ts` to handle LMStudio model IDs and resolve their metadata via LMStudio's OpenAI-compatible API.
- Modifying the `ollamaFormatAllCustomModels` function to include LMStudio models in the list of available models.
- Introducing a timeout mechanism in `libs/openai.ts` to prevent API requests from hanging.

This change enhances the model selection experience, providing users with a wider range of models to choose from.
This commit is contained in:
n4ze3m 2024-10-12 19:05:21 +05:30
parent f1e40d5908
commit ddb8993f17
3 changed files with 89 additions and 8 deletions

View File

@ -1,3 +1,4 @@
import { getAllOpenAIModels } from "@/libs/openai"
import { import {
getAllOpenAIConfig, getAllOpenAIConfig,
getOpenAIConfigById as providerInfo getOpenAIConfigById as providerInfo
@ -22,10 +23,27 @@ export const removeModelSuffix = (id: string) => {
return id.replace( return id.replace(
/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/, /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/,
"" ""
) ).replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
}
export const isLMStudioModel = (model: string) => {
const lmstudioModelRegex = /_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
return lmstudioModelRegex.test(model)
} }
/**
 * Extract the LMStudio suffix and provider id from a model id.
 *
 * @param model - Model id, e.g. "llama3_lmstudio_openai-ab12-cd3-ef45".
 * @returns `model_id` (the matched suffix itself — callers strip the suffix
 *          from the full id separately) and `provider_id` (the suffix with
 *          the leading "_lmstudio_openai-" removed), or `null` when the id
 *          carries no LMStudio suffix.
 */
export const getLMStudioModelId = (
  model: string
): { model_id: string; provider_id: string } | null => {
  const lmstudioModelRegex = /_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/
  const match = model.match(lmstudioModelRegex)
  if (!match) {
    // Original annotation claimed a non-null return; under strictNullChecks
    // the `| null` in the signature is required for this branch.
    return null
  }
  return {
    model_id: match[0],
    provider_id: match[0].replace("_lmstudio_openai-", "")
  }
}
export const isCustomModel = (model: string) => { export const isCustomModel = (model: string) => {
if (isLMStudioModel(model)) {
return true
}
const customModelRegex = const customModelRegex =
/_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/ /_model-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{3,4}-[a-f0-9]{4}/
return customModelRegex.test(model) return customModelRegex.test(model)
@ -158,6 +176,19 @@ export const createModel = async (
export const getModelInfo = async (id: string) => { export const getModelInfo = async (id: string) => {
const db = new ModelDb() const db = new ModelDb()
if (isLMStudioModel(id)) {
const lmstudioId = getLMStudioModelId(id)
if (!lmstudioId) {
throw new Error("Invalid LMStudio model ID")
}
return {
model_id: id.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, ""),
provider_id: `openai-${lmstudioId.provider_id}`,
name: id.replace(/_lmstudio_openai-[a-f0-9]{4}-[a-f0-9]{3}-[a-f0-9]{4}/, "")
}
}
const model = await db.getById(id) const model = await db.getById(id)
return model return model
} }
@ -199,12 +230,52 @@ export const isLookupExist = async (lookup: string) => {
return model ? true : false return model ? true : false
} }
/**
 * Fetch every model exposed by an LMStudio server (OpenAI-compatible API)
 * and tag each entry with the owning provider's id, producing the shape
 * used by the custom-model listing.
 */
export const dynamicFetchLMStudio = async ({
  baseUrl,
  providerId
}: {
  baseUrl: string
  providerId: string
}) => {
  const fetched = await getAllOpenAIModels(baseUrl)
  return fetched.map((model) => ({
    // Some servers omit a display name; fall back to the raw id.
    name: model?.name || model?.id,
    id: `${model?.id}_lmstudio_${providerId}`,
    provider: providerId,
    lookup: `${model?.id}_${providerId}`,
    provider_id: providerId
  }))
}
export const ollamaFormatAllCustomModels = async () => { export const ollamaFormatAllCustomModels = async () => {
const allModles = await getAllCustomModels() const [allModles, allProviders] = await Promise.all([
getAllCustomModels(),
getAllOpenAIConfig()
])
const allProviders = await getAllOpenAIConfig() const lmstudioProviders = allProviders.filter(
(provider) => provider.provider === "lmstudio"
)
const ollamaModels = allModles.map((model) => { const lmModelsPromises = lmstudioProviders.map((provider) =>
dynamicFetchLMStudio({
baseUrl: provider.baseUrl,
providerId: provider.id
})
)
const lmModelsFetch = await Promise.all(lmModelsPromises)
const lmModels = lmModelsFetch.flat()
// merge allModels and lmModels
const allModlesWithLMStudio = [...allModles, ...lmModels]
const ollamaModels = allModlesWithLMStudio.map((model) => {
return { return {
name: model.name, name: model.name,
model: model.id, model: model.id,

View File

@ -14,10 +14,16 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
} }
: {} : {}
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), 10000)
const res = await fetch(url, { const res = await fetch(url, {
headers headers,
signal: controller.signal
}) })
clearTimeout(timeoutId)
if (!res.ok) { if (!res.ok) {
return [] return []
} }
@ -27,14 +33,18 @@ export const getAllOpenAIModels = async (baseUrl: string, apiKey?: string) => {
return data.map(model => ({ return data.map(model => ({
id: model.id, id: model.id,
name: model.display_name, name: model.display_name,
})) })) as Model[]
} }
const data = (await res.json()) as { data: Model[] } const data = (await res.json()) as { data: Model[] }
return data.data return data.data
} catch (e) { } catch (e) {
if (e instanceof DOMException && e.name === 'AbortError') {
console.log('Request timed out')
} else {
console.log(e) console.log(e)
}
return [] return []
} }
} }

View File

@ -40,7 +40,7 @@ export const pageAssistModel = async ({
if (isCustom) { if (isCustom) {
const modelInfo = await getModelInfo(model) const modelInfo = await getModelInfo(model)
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id) const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
console.log(modelInfo, providerInfo)
return new ChatOpenAI({ return new ChatOpenAI({
modelName: modelInfo.model_id, modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp", openAIApiKey: providerInfo.apiKey || "temp",