feat: Add support for Google AI (Gemini) as a custom model provider

This commit is contained in:
n4ze3m 2024-12-21 18:13:21 +05:30
parent a3c76f0757
commit 772000bff4
5 changed files with 56 additions and 13 deletions

View File

@ -7,6 +7,7 @@ import { OpenAiIcon } from "../Icons/OpenAI"
import { TogtherMonoIcon } from "../Icons/Togther" import { TogtherMonoIcon } from "../Icons/Togther"
import { OpenRouterIcon } from "../Icons/OpenRouter" import { OpenRouterIcon } from "../Icons/OpenRouter"
import { LLamaFile } from "../Icons/Llamafile" import { LLamaFile } from "../Icons/Llamafile"
import { GeminiIcon } from "../Icons/GeminiIcon"
export const ProviderIcons = ({ export const ProviderIcons = ({
provider, provider,
@ -34,6 +35,8 @@ export const ProviderIcons = ({
return <OpenRouterIcon className={className} /> return <OpenRouterIcon className={className} />
case "llamafile": case "llamafile":
return <LLamaFile className={className} /> return <LLamaFile className={className} />
case "gemini":
return <GeminiIcon className={className} />
default: default:
return <OllamaIcon className={className} /> return <OllamaIcon className={className} />
} }

View File

@ -0,0 +1,18 @@
import React from "react"
/**
 * Google Gemini logo rendered as an inline SVG icon.
 *
 * Forwards its ref to the underlying `<svg>` element and spreads all
 * standard SVG props (e.g. `className`) onto it, matching the other
 * provider icons used by `ProviderIcons`.
 */
export const GeminiIcon = React.forwardRef<
  SVGSVGElement,
  React.SVGProps<SVGSVGElement>
>((props, ref) => {
  return (
    <svg
      fill="currentColor"
      fillRule="evenodd"
      ref={ref}
      viewBox="0 0 24 24"
      xmlns="http://www.w3.org/2000/svg"
      {...props}>
      <path d="M12 24A14.304 14.304 0 000 12 14.304 14.304 0 0012 0a14.305 14.305 0 0012 12 14.305 14.305 0 00-12 12"></path>
    </svg>
  )
})

// Without an explicit displayName, forwardRef components appear as an
// anonymous "ForwardRef" in React DevTools and component stack traces.
GeminiIcon.displayName = "GeminiIcon"

View File

@ -0,0 +1,11 @@
import { ChatOpenAI } from "@langchain/openai";
/**
 * Chat model wrapper for Google's Gemini API via its OpenAI-compatible
 * endpoint. Reuses the full `ChatOpenAI` implementation unchanged; only
 * the penalty fields are overridden.
 *
 * NOTE(review): assigning `undefined` to fields typed `number` only
 * compiles with `strictNullChecks` off — presumably intentional so the
 * penalty params are omitted from the request payload, since the Gemini
 * endpoint rejects them; confirm against the base-class serialization.
 */
export class ChatGoogleAI extends ChatOpenAI {
// Force the penalty parameters to be unset for Gemini requests.
frequencyPenalty: number = undefined;
presencePenalty: number = undefined;
// Serialization name used by LangChain's lc_name registry.
static lc_name() {
return "ChatGoogleAI";
}
}

View File

@ -4,6 +4,7 @@ import { ChatOllama } from "./ChatOllama"
import { getOpenAIConfigById } from "@/db/openai" import { getOpenAIConfigById } from "@/db/openai"
import { ChatOpenAI } from "@langchain/openai" import { ChatOpenAI } from "@langchain/openai"
import { urlRewriteRuntime } from "@/libs/runtime" import { urlRewriteRuntime } from "@/libs/runtime"
import { ChatGoogleAI } from "./ChatGoogleAI"
export const pageAssistModel = async ({ export const pageAssistModel = async ({
model, model,
@ -30,18 +31,15 @@ export const pageAssistModel = async ({
numPredict?: number numPredict?: number
useMMap?: boolean useMMap?: boolean
}) => { }) => {
if (model === "chrome::gemini-nano::page-assist") { if (model === "chrome::gemini-nano::page-assist") {
return new ChatChromeAI({ return new ChatChromeAI({
temperature, temperature,
topK, topK
}) })
} }
const isCustom = isCustomModel(model) const isCustom = isCustomModel(model)
if (isCustom) { if (isCustom) {
const modelInfo = await getModelInfo(model) const modelInfo = await getModelInfo(model)
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id) const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
@ -50,6 +48,20 @@ export const pageAssistModel = async ({
await urlRewriteRuntime(providerInfo.baseUrl || "") await urlRewriteRuntime(providerInfo.baseUrl || "")
} }
if (providerInfo.provider === "gemini") {
return new ChatGoogleAI({
modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp",
temperature,
topP,
maxTokens: numPredict,
configuration: {
apiKey: providerInfo.apiKey || "temp",
baseURL: providerInfo.baseUrl || ""
}
}) as any
}
return new ChatOpenAI({ return new ChatOpenAI({
modelName: modelInfo.model_id, modelName: modelInfo.model_id,
openAIApiKey: providerInfo.apiKey || "temp", openAIApiKey: providerInfo.apiKey || "temp",
@ -58,13 +70,11 @@ export const pageAssistModel = async ({
maxTokens: numPredict, maxTokens: numPredict,
configuration: { configuration: {
apiKey: providerInfo.apiKey || "temp", apiKey: providerInfo.apiKey || "temp",
baseURL: providerInfo.baseUrl || "", baseURL: providerInfo.baseUrl || ""
}, }
}) as any }) as any
} }
return new ChatOllama({ return new ChatOllama({
baseUrl, baseUrl,
keepAlive, keepAlive,
@ -76,9 +86,6 @@ export const pageAssistModel = async ({
model, model,
numGpu, numGpu,
numPredict, numPredict,
useMMap, useMMap
}) })
} }

View File

@ -44,5 +44,9 @@ export const OAI_API_PROVIDERS = [
value: "openrouter", value: "openrouter",
baseUrl: "https://openrouter.ai/api/v1" baseUrl: "https://openrouter.ai/api/v1"
}, },
{
label: "Google AI",
value: "gemini",
baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai"
}
] ]