feat: Add support for Google AI (Gemini) as a custom model provider
src/models/ChatGoogleAI.ts (new file, 11 additions)
@@ -0,0 +1,11 @@
+import { ChatOpenAI } from "@langchain/openai";
+
+export class ChatGoogleAI extends ChatOpenAI {
+
+  frequencyPenalty: number = undefined;
+  presencePenalty: number = undefined;
+
+  static lc_name() {
+    return "ChatGoogleAI";
+  }
+}
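The subclass changes nothing about ChatOpenAI except pinning frequencyPenalty and presencePenalty to undefined, so those fields drop out of the serialized request (JSON serialization omits undefined values); presumably Google's OpenAI-compatible endpoint rejects them. A minimal standalone sketch, not part of the commit; the base URL below is Google's OpenAI-compatible endpoint, and the model id and env var are illustrative placeholders:

    import { ChatGoogleAI } from "./src/models/ChatGoogleAI"

    // Sketch only: exercise the subclass directly against Google's
    // OpenAI-compatible endpoint. Model id and env var are placeholders.
    const chat = new ChatGoogleAI({
      modelName: "gemini-1.5-flash",
      openAIApiKey: process.env.GOOGLE_API_KEY ?? "temp",
      configuration: {
        apiKey: process.env.GOOGLE_API_KEY ?? "temp",
        baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/"
      }
    })

    // Top-level await assumes an ESM entry point.
    const res = await chat.invoke("Say hello in one word.")
    console.log(res.content)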
@@ -4,6 +4,7 @@ import { ChatOllama } from "./ChatOllama"
 import { getOpenAIConfigById } from "@/db/openai"
 import { ChatOpenAI } from "@langchain/openai"
 import { urlRewriteRuntime } from "@/libs/runtime"
+import { ChatGoogleAI } from "./ChatGoogleAI"
 
 export const pageAssistModel = async ({
   model,
@@ -30,18 +31,15 @@ export const pageAssistModel = async ({
   numPredict?: number
   useMMap?: boolean
 }) => {
-
   if (model === "chrome::gemini-nano::page-assist") {
     return new ChatChromeAI({
       temperature,
-      topK,
+      topK
     })
   }
 
-
   const isCustom = isCustomModel(model)
 
-
   if (isCustom) {
     const modelInfo = await getModelInfo(model)
     const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
@@ -50,6 +48,20 @@ export const pageAssistModel = async ({
       await urlRewriteRuntime(providerInfo.baseUrl || "")
     }
 
+    if (providerInfo.provider === "gemini") {
+      return new ChatGoogleAI({
+        modelName: modelInfo.model_id,
+        openAIApiKey: providerInfo.apiKey || "temp",
+        temperature,
+        topP,
+        maxTokens: numPredict,
+        configuration: {
+          apiKey: providerInfo.apiKey || "temp",
+          baseURL: providerInfo.baseUrl || ""
+        }
+      }) as any
+    }
+
     return new ChatOpenAI({
       modelName: modelInfo.model_id,
       openAIApiKey: providerInfo.apiKey || "temp",
@@ -58,13 +70,11 @@ export const pageAssistModel = async ({
       maxTokens: numPredict,
       configuration: {
         apiKey: providerInfo.apiKey || "temp",
-        baseURL: providerInfo.baseUrl || "",
-      },
+        baseURL: providerInfo.baseUrl || ""
+      }
     }) as any
   }
 
-
-
   return new ChatOllama({
     baseUrl,
     keepAlive,
@@ -76,9 +86,6 @@ export const pageAssistModel = async ({
     model,
     numGpu,
     numPredict,
-    useMMap,
+    useMMap
   })
-
-
-
 }
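With these changes, pageAssistModel routes any custom model whose saved provider record has provider === "gemini" through ChatGoogleAI; everything else keeps the existing ChatOpenAI and ChatOllama paths. A hypothetical call, assuming such a provider record has been saved; the custom-model identifier format is a placeholder and other options are omitted for brevity:

    // Hypothetical usage: the returned instance is ChatGoogleAI when the
    // resolved provider config's `provider` field equals "gemini".
    const llm = await pageAssistModel({
      model: "gemini-1.5-pro::custom",  // placeholder custom-model id
      temperature: 0.5,
      topP: 0.9,
      numPredict: 2048
    })
    const answer = await llm.invoke("Summarize this page.")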