From 772000bff407a9b14627a8925baa9b5b0b0e2c80 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 21 Dec 2024 18:13:21 +0530 Subject: [PATCH] feat: Add support for Google AI (Gemini) as a custom model provider --- src/components/Common/ProviderIcon.tsx | 3 +++ src/components/Icons/GeminiIcon.tsx | 18 +++++++++++++++ src/models/ChatGoogleAI.ts | 11 +++++++++ src/models/index.ts | 31 ++++++++++++++++---------- src/utils/oai-api-providers.ts | 6 ++++- 5 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 src/components/Icons/GeminiIcon.tsx create mode 100644 src/models/ChatGoogleAI.ts diff --git a/src/components/Common/ProviderIcon.tsx b/src/components/Common/ProviderIcon.tsx index 8142adc..1db9d6d 100644 --- a/src/components/Common/ProviderIcon.tsx +++ b/src/components/Common/ProviderIcon.tsx @@ -7,6 +7,7 @@ import { OpenAiIcon } from "../Icons/OpenAI" import { TogtherMonoIcon } from "../Icons/Togther" import { OpenRouterIcon } from "../Icons/OpenRouter" import { LLamaFile } from "../Icons/Llamafile" +import { GeminiIcon } from "../Icons/GeminiIcon" export const ProviderIcons = ({ provider, @@ -34,6 +35,8 @@ export const ProviderIcons = ({ return case "llamafile": return + case "gemini": + return default: return } diff --git a/src/components/Icons/GeminiIcon.tsx b/src/components/Icons/GeminiIcon.tsx new file mode 100644 index 0000000..05cc743 --- /dev/null +++ b/src/components/Icons/GeminiIcon.tsx @@ -0,0 +1,18 @@ +import React from "react" + +export const GeminiIcon = React.forwardRef< + SVGSVGElement, + React.SVGProps +>((props, ref) => { + return ( + + + + ) +}) diff --git a/src/models/ChatGoogleAI.ts b/src/models/ChatGoogleAI.ts new file mode 100644 index 0000000..37d7973 --- /dev/null +++ b/src/models/ChatGoogleAI.ts @@ -0,0 +1,11 @@ +import { ChatOpenAI } from "@langchain/openai"; + +export class ChatGoogleAI extends ChatOpenAI { + + frequencyPenalty: number = undefined; + presencePenalty: number = undefined; + + static lc_name() { + return "ChatGoogleAI"; + } +} \ No newline at end of file diff --git a/src/models/index.ts b/src/models/index.ts index 3481d52..953c968 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -4,6 +4,7 @@ import { ChatOllama } from "./ChatOllama" import { getOpenAIConfigById } from "@/db/openai" import { ChatOpenAI } from "@langchain/openai" import { urlRewriteRuntime } from "@/libs/runtime" +import { ChatGoogleAI } from "./ChatGoogleAI" export const pageAssistModel = async ({ model, @@ -30,18 +31,15 @@ export const pageAssistModel = async ({ numPredict?: number useMMap?: boolean }) => { - if (model === "chrome::gemini-nano::page-assist") { return new ChatChromeAI({ temperature, - topK, + topK }) } - const isCustom = isCustomModel(model) - if (isCustom) { const modelInfo = await getModelInfo(model) const providerInfo = await getOpenAIConfigById(modelInfo.provider_id) @@ -50,6 +48,20 @@ export const pageAssistModel = async ({ await urlRewriteRuntime(providerInfo.baseUrl || "") } + if (providerInfo.provider === "gemini") { + return new ChatGoogleAI({ + modelName: modelInfo.model_id, + openAIApiKey: providerInfo.apiKey || "temp", + temperature, + topP, + maxTokens: numPredict, + configuration: { + apiKey: providerInfo.apiKey || "temp", + baseURL: providerInfo.baseUrl || "" + } + }) as any + } + return new ChatOpenAI({ modelName: modelInfo.model_id, openAIApiKey: providerInfo.apiKey || "temp", @@ -58,13 +70,11 @@ export const pageAssistModel = async ({ maxTokens: numPredict, configuration: { apiKey: providerInfo.apiKey || "temp", - baseURL: providerInfo.baseUrl || "", - }, + baseURL: providerInfo.baseUrl || "" + } }) as any } - - return new ChatOllama({ baseUrl, keepAlive, @@ -76,9 +86,6 @@ export const pageAssistModel = async ({ model, numGpu, numPredict, - useMMap, + useMMap }) - - - } diff --git a/src/utils/oai-api-providers.ts b/src/utils/oai-api-providers.ts index 40c1e21..06a59e7 100644 --- a/src/utils/oai-api-providers.ts +++ b/src/utils/oai-api-providers.ts @@ -44,5 +44,9 @@ export const OAI_API_PROVIDERS = [ value: "openrouter", baseUrl: "https://openrouter.ai/api/v1" }, - + { + label: "Google AI", + value: "gemini", + baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai" + } ] \ No newline at end of file