import { getModelInfo, isCustomModel, isOllamaModel } from "@/db/models"
import { getOpenAIConfigById } from "@/db/openai"
import { urlRewriteRuntime } from "@/libs/runtime"
import { ChatChromeAI } from "./ChatChromeAi"
import { ChatGoogleAI } from "./ChatGoogleAI"
import { ChatOllama } from "./ChatOllama"
import { CustomChatOpenAI } from "./CustomChatOpenAI"
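
/**
 * Factory that returns the right chat-model instance for a given model id:
 * Chrome's built-in Gemini Nano, a user-configured OpenAI-compatible
 * provider (including Gemini and custom Ollama endpoints), or a plain
 * Ollama instance as the default.
 */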
export const pageAssistModel = async ({
  model,
  baseUrl,
  keepAlive,
  temperature,
  topK,
  topP,
  numCtx,
  seed,
  numGpu,
  numPredict,
  useMMap,
  minP,
  repeatLastN,
  repeatPenalty,
  tfsZ,
  numKeep,
  numThread,
  useMlock
}: {
  model: string
  baseUrl: string
  keepAlive?: string
  temperature?: number
  topK?: number
  topP?: number
  numCtx?: number
  seed?: number
  numGpu?: number
  numPredict?: number
  useMMap?: boolean
  minP?: number
  repeatPenalty?: number
  repeatLastN?: number
  tfsZ?: number
  numKeep?: number
  numThread?: number
  useMlock?: boolean
}) => {
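  // Chrome's built-in Gemini Nano is selected via a sentinel model id and
  // is only passed the two options it is given here (temperature, topK).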
  if (model === "chrome::gemini-nano::page-assist") {
    return new ChatChromeAI({
      temperature,
      topK
    })
  }
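
  // Models registered via a user-configured provider take the
  // OpenAI-compatible path below.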
  const isCustom = isCustomModel(model)

  if (isCustom) {
    const modelInfo = await getModelInfo(model)
    const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
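
    // Ollama-backed custom models are routed through the extension's
    // URL-rewrite runtime before any request is made.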
    if (isOllamaModel(model)) {
      await urlRewriteRuntime(providerInfo.baseUrl || "")
    }
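
    // Gemini providers get a dedicated client; "temp" is a placeholder so
    // the constructor receives a non-empty key when none is configured.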
    if (providerInfo.provider === "gemini") {
      return new ChatGoogleAI({
        modelName: modelInfo.model_id,
        openAIApiKey: providerInfo.apiKey || "temp",
        temperature,
        topP,
        maxTokens: numPredict,
        configuration: {
          apiKey: providerInfo.apiKey || "temp",
          baseURL: providerInfo.baseUrl || ""
        }
      }) as any
    }
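
    // Any other provider is treated as a generic OpenAI-compatible endpoint.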
    return new CustomChatOpenAI({
      modelName: modelInfo.model_id,
      openAIApiKey: providerInfo.apiKey || "temp",
      temperature,
      topP,
      maxTokens: numPredict,
      configuration: {
        apiKey: providerInfo.apiKey || "temp",
        baseURL: providerInfo.baseUrl || ""
      }
    }) as any
  }
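
  // Default path: talk to the configured Ollama server directly, forwarding
  // the full set of sampling and runtime options.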
  return new ChatOllama({
    baseUrl,
    keepAlive,
    temperature,
    topK,
    topP,
    numCtx,
    seed,
    model,
    numGpu,
    numPredict,
    useMMap,
    minP,
    repeatPenalty,
    repeatLastN,
    tfsZ,
    numKeep,
    numThread,
    useMlock
  })
}
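
// Usage sketch (hypothetical values; the model id, URL, and options are
// illustrative, not defaults defined by this module):
//
//   const chat = await pageAssistModel({
//     model: "llama3.1:8b",
//     baseUrl: "http://localhost:11434",
//     temperature: 0.7
//   })
//   const res = await chat.invoke("Hello!")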