feat: Add support for Google AI (Gemini) as a custom model provider
This commit is contained in:
parent
a3c76f0757
commit
772000bff4
@ -7,6 +7,7 @@ import { OpenAiIcon } from "../Icons/OpenAI"
|
|||||||
import { TogtherMonoIcon } from "../Icons/Togther"
|
import { TogtherMonoIcon } from "../Icons/Togther"
|
||||||
import { OpenRouterIcon } from "../Icons/OpenRouter"
|
import { OpenRouterIcon } from "../Icons/OpenRouter"
|
||||||
import { LLamaFile } from "../Icons/Llamafile"
|
import { LLamaFile } from "../Icons/Llamafile"
|
||||||
|
import { GeminiIcon } from "../Icons/GeminiIcon"
|
||||||
|
|
||||||
export const ProviderIcons = ({
|
export const ProviderIcons = ({
|
||||||
provider,
|
provider,
|
||||||
@ -34,6 +35,8 @@ export const ProviderIcons = ({
|
|||||||
return <OpenRouterIcon className={className} />
|
return <OpenRouterIcon className={className} />
|
||||||
case "llamafile":
|
case "llamafile":
|
||||||
return <LLamaFile className={className} />
|
return <LLamaFile className={className} />
|
||||||
|
case "gemini":
|
||||||
|
return <GeminiIcon className={className} />
|
||||||
default:
|
default:
|
||||||
return <OllamaIcon className={className} />
|
return <OllamaIcon className={className} />
|
||||||
}
|
}
|
||||||
|
18
src/components/Icons/GeminiIcon.tsx
Normal file
18
src/components/Icons/GeminiIcon.tsx
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import React from "react"
|
||||||
|
|
||||||
|
export const GeminiIcon = React.forwardRef<
|
||||||
|
SVGSVGElement,
|
||||||
|
React.SVGProps<SVGSVGElement>
|
||||||
|
>((props, ref) => {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
fill="currentColor"
|
||||||
|
fillRule="evenodd"
|
||||||
|
ref={ref}
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
{...props}>
|
||||||
|
<path d="M12 24A14.304 14.304 0 000 12 14.304 14.304 0 0012 0a14.305 14.305 0 0012 12 14.305 14.305 0 00-12 12"></path>
|
||||||
|
</svg>
|
||||||
|
)
|
||||||
|
})
|
11
src/models/ChatGoogleAI.ts
Normal file
11
src/models/ChatGoogleAI.ts
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import { ChatOpenAI } from "@langchain/openai";
|
||||||
|
|
||||||
|
export class ChatGoogleAI extends ChatOpenAI {
|
||||||
|
|
||||||
|
frequencyPenalty: number = undefined;
|
||||||
|
presencePenalty: number = undefined;
|
||||||
|
|
||||||
|
static lc_name() {
|
||||||
|
return "ChatGoogleAI";
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ import { ChatOllama } from "./ChatOllama"
|
|||||||
import { getOpenAIConfigById } from "@/db/openai"
|
import { getOpenAIConfigById } from "@/db/openai"
|
||||||
import { ChatOpenAI } from "@langchain/openai"
|
import { ChatOpenAI } from "@langchain/openai"
|
||||||
import { urlRewriteRuntime } from "@/libs/runtime"
|
import { urlRewriteRuntime } from "@/libs/runtime"
|
||||||
|
import { ChatGoogleAI } from "./ChatGoogleAI"
|
||||||
|
|
||||||
export const pageAssistModel = async ({
|
export const pageAssistModel = async ({
|
||||||
model,
|
model,
|
||||||
@ -30,18 +31,15 @@ export const pageAssistModel = async ({
|
|||||||
numPredict?: number
|
numPredict?: number
|
||||||
useMMap?: boolean
|
useMMap?: boolean
|
||||||
}) => {
|
}) => {
|
||||||
|
|
||||||
if (model === "chrome::gemini-nano::page-assist") {
|
if (model === "chrome::gemini-nano::page-assist") {
|
||||||
return new ChatChromeAI({
|
return new ChatChromeAI({
|
||||||
temperature,
|
temperature,
|
||||||
topK,
|
topK
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const isCustom = isCustomModel(model)
|
const isCustom = isCustomModel(model)
|
||||||
|
|
||||||
|
|
||||||
if (isCustom) {
|
if (isCustom) {
|
||||||
const modelInfo = await getModelInfo(model)
|
const modelInfo = await getModelInfo(model)
|
||||||
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
|
const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
|
||||||
@ -50,6 +48,20 @@ export const pageAssistModel = async ({
|
|||||||
await urlRewriteRuntime(providerInfo.baseUrl || "")
|
await urlRewriteRuntime(providerInfo.baseUrl || "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (providerInfo.provider === "gemini") {
|
||||||
|
return new ChatGoogleAI({
|
||||||
|
modelName: modelInfo.model_id,
|
||||||
|
openAIApiKey: providerInfo.apiKey || "temp",
|
||||||
|
temperature,
|
||||||
|
topP,
|
||||||
|
maxTokens: numPredict,
|
||||||
|
configuration: {
|
||||||
|
apiKey: providerInfo.apiKey || "temp",
|
||||||
|
baseURL: providerInfo.baseUrl || ""
|
||||||
|
}
|
||||||
|
}) as any
|
||||||
|
}
|
||||||
|
|
||||||
return new ChatOpenAI({
|
return new ChatOpenAI({
|
||||||
modelName: modelInfo.model_id,
|
modelName: modelInfo.model_id,
|
||||||
openAIApiKey: providerInfo.apiKey || "temp",
|
openAIApiKey: providerInfo.apiKey || "temp",
|
||||||
@ -58,13 +70,11 @@ export const pageAssistModel = async ({
|
|||||||
maxTokens: numPredict,
|
maxTokens: numPredict,
|
||||||
configuration: {
|
configuration: {
|
||||||
apiKey: providerInfo.apiKey || "temp",
|
apiKey: providerInfo.apiKey || "temp",
|
||||||
baseURL: providerInfo.baseUrl || "",
|
baseURL: providerInfo.baseUrl || ""
|
||||||
},
|
}
|
||||||
}) as any
|
}) as any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return new ChatOllama({
|
return new ChatOllama({
|
||||||
baseUrl,
|
baseUrl,
|
||||||
keepAlive,
|
keepAlive,
|
||||||
@ -76,9 +86,6 @@ export const pageAssistModel = async ({
|
|||||||
model,
|
model,
|
||||||
numGpu,
|
numGpu,
|
||||||
numPredict,
|
numPredict,
|
||||||
useMMap,
|
useMMap
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -44,5 +44,9 @@ export const OAI_API_PROVIDERS = [
|
|||||||
value: "openrouter",
|
value: "openrouter",
|
||||||
baseUrl: "https://openrouter.ai/api/v1"
|
baseUrl: "https://openrouter.ai/api/v1"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: "Google AI",
|
||||||
|
value: "gemini",
|
||||||
|
baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||||
|
}
|
||||||
]
|
]
|
Loading…
x
Reference in New Issue
Block a user