feat: Add support for Google AI (Gemini) as a custom model provider

This commit is contained in:
n4ze3m
2024-12-21 18:13:21 +05:30
parent a3c76f0757
commit 772000bff4
5 changed files with 56 additions and 13 deletions

View File

@@ -0,0 +1,11 @@
import { ChatOpenAI } from "@langchain/openai";
export class ChatGoogleAI extends ChatOpenAI {
frequencyPenalty: number = undefined;
presencePenalty: number = undefined;
static lc_name() {
return "ChatGoogleAI";
}
}

View File

@@ -4,6 +4,7 @@ import { ChatOllama } from "./ChatOllama"
import { getOpenAIConfigById } from "@/db/openai"
import { ChatOpenAI } from "@langchain/openai"
import { urlRewriteRuntime } from "@/libs/runtime"
import { ChatGoogleAI } from "./ChatGoogleAI"
export const pageAssistModel = async ({
model,
@@ -30,18 +31,15 @@ export const pageAssistModel = async ({
numPredict?: number
useMMap?: boolean
}) => {
if (model === "chrome::gemini-nano::page-assist") {
return new ChatChromeAI({
temperature,
topK,
topK
})
}
const isCustom = isCustomModel(model)
if (isCustom) {
const modelInfo = await getModelInfo(model)
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
@@ -50,6 +48,20 @@ export const pageAssistModel = async ({
await urlRewriteRuntime(providerInfo.baseUrl || "")
}
if (providerInfo.provider === "gemini") {
return new ChatGoogleAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp",
temperature,
topP,
maxTokens: numPredict,
configuration: {
apiKey: providerInfo.apiKey || "temp",
baseURL: providerInfo.baseUrl || ""
}
}) as any
}
return new ChatOpenAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp",
@@ -58,13 +70,11 @@ export const pageAssistModel = async ({
maxTokens: numPredict,
configuration: {
apiKey: providerInfo.apiKey || "temp",
baseURL: providerInfo.baseUrl || "",
},
baseURL: providerInfo.baseUrl || ""
}
}) as any
}
return new ChatOllama({
baseUrl,
keepAlive,
@@ -76,9 +86,6 @@ export const pageAssistModel = async ({
model,
numGpu,
numPredict,
useMMap,
useMMap
})
}