feat: Improve model selection and embedding
Refactor embedding-model handling to simplify the pipeline and improve performance. Add a new model-selection mechanism and a clearer, more user-friendly embedding-model picker in the settings UI. Route Page Assist and RAG through a common embedding factory so both flows share the same model-selection logic.
parent ba071ffeb1
commit 768ff2e555
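The core of the change is a new shared factory, `pageAssistEmbeddingModel` (added in `src/models/embedding.ts` below), that every embedding call site now goes through. As a rough sketch of the resulting call pattern, pieced together from the call sites in this diff (the `embedPageContent` helper itself is hypothetical):

```typescript
import { pageAssistEmbeddingModel } from "@/models/embedding"
import { defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama"
import { cleanUrl } from "@/libs/clean-url"

// Hypothetical helper: resolve the user's configured RAG embedding model
// once, then embed text with whichever backend the factory selects
// (OAIEmbedding for custom OpenAI-compatible models, Ollama otherwise).
const embedPageContent = async (texts: string[]) => {
  const embeddingModel = await defaultEmbeddingModelForRag()
  const ollamaUrl = await getOllamaURL()
  const embeddings = await pageAssistEmbeddingModel({
    model: embeddingModel || "",
    baseUrl: cleanUrl(ollamaUrl)
  })
  // Both backends are LangChain Embeddings subclasses, so
  // embedDocuments / embedQuery behave the same either way.
  return await embeddings.embedDocuments(texts)
}
```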
@@ -5,13 +5,14 @@ import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
   defaultEmbeddingModelForRag,
-  getAllModels,
+  getEmbeddingModels,
   saveForRag
 } from "~/services/ollama"
 import { SettingPrompt } from "./prompt"
 import { useTranslation } from "react-i18next"
 import { getNoOfRetrievedDocs, getTotalFilePerKB } from "@/services/app"
 import { SidepanelRag } from "./sidepanel-rag"
+import { ProviderIcons } from "@/components/Common/ProviderIcon"
 
 export const RagSettings = () => {
   const { t } = useTranslation("settings")
@@ -29,7 +30,7 @@ export const RagSettings = () => {
     totalFilePerKB,
     noOfRetrievedDocs
   ] = await Promise.all([
-    getAllModels({ returnEmpty: true }),
+    getEmbeddingModels({ returnEmpty: true }),
     defaultEmbeddingChunkOverlap(),
     defaultEmbeddingChunkSize(),
     defaultEmbeddingModelForRag(),
@@ -113,18 +114,27 @@ export const RagSettings = () => {
             ]}>
             <Select
               size="large"
-              filterOption={(input, option) =>
-                option!.label.toLowerCase().indexOf(input.toLowerCase()) >=
-                  0 ||
-                option!.value.toLowerCase().indexOf(input.toLowerCase()) >=
-                  0
-              }
               showSearch
               placeholder={t("rag.ragSettings.model.placeholder")}
               style={{ width: "100%" }}
               className="mt-4"
+              filterOption={(input, option) =>
+                option.label.key
+                  .toLowerCase()
+                  .indexOf(input.toLowerCase()) >= 0
+              }
               options={ollamaInfo.models?.map((model) => ({
-                label: model.name,
+                label: (
+                  <span
+                    key={model.model}
+                    className="flex flex-row gap-3 items-center truncate">
+                    <ProviderIcons
+                      provider={model?.provider}
+                      className="w-5 h-5"
+                    />
+                    <span className="truncate">{model.name}</span>
+                  </span>
+                ),
                 value: model.model
               }))}
             />
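With the label now a React element instead of a string, the rewritten `filterOption` matches typed input against `option.label.key`, which is set to `model.model` on the wrapping `<span>`, so the search appears to filter on the model id rather than the display name.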
@@ -26,7 +26,6 @@ import { notification } from "antd"
 import { useTranslation } from "react-i18next"
 import { usePageAssist } from "@/context"
 import { formatDocs } from "@/chain/chat-with-x"
-import { OllamaEmbeddingsPageAssist } from "@/models/OllamaEmbedding"
 import { useStorage } from "@plasmohq/storage/hook"
 import { useStoreChatModelSettings } from "@/store/model"
 import { getAllDefaultModelSettings } from "@/services/model-settings"
@@ -34,6 +33,7 @@ import { getSystemPromptForWeb } from "@/web/web"
 import { pageAssistModel } from "@/models"
 import { getPrompt } from "@/services/application"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 
 export const useMessage = () => {
   const {
@@ -202,7 +202,7 @@ export const useMessage = () => {
     const ollamaUrl = await getOllamaURL()
     const embeddingModle = await defaultEmbeddingModelForRag()
 
-    const ollamaEmbedding = new OllamaEmbeddingsPageAssist({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       signal: embeddingSignal,
@@ -24,7 +24,6 @@ import { generateHistory } from "@/utils/generate-history"
 import { useTranslation } from "react-i18next"
 import { saveMessageOnError, saveMessageOnSuccess } from "./chat-helper"
 import { usePageAssist } from "@/context"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
 import { formatDocs } from "@/chain/chat-with-x"
 import { useWebUI } from "@/store/webui"
@@ -34,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { pageAssistModel } from "@/models"
 import { getNoOfRetrievedDocs } from "@/services/app"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 
 export const useMessageOption = () => {
   const {
@@ -628,7 +628,7 @@ export const useMessageOption = () => {
 
     const embeddingModle = await defaultEmbeddingModelForRag()
     const ollamaUrl = await getOllamaURL()
-    const ollamaEmbedding = new OllamaEmbeddings({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
      model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       keepAlive:
@@ -5,7 +5,6 @@ import {
   defaultEmbeddingChunkSize,
   getOllamaURL
 } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { PageAssistVectorStore } from "./PageAssistVectorStore"
 import { PageAssisCSVUrlLoader } from "@/loader/csv"
@@ -13,6 +12,7 @@ import { PageAssisTXTUrlLoader } from "@/loader/txt"
 import { PageAssistDocxLoader } from "@/loader/docx"
 import { cleanUrl } from "./clean-url"
 import { sendEmbeddingCompleteNotification } from "./send-notification"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 
 
 export const processKnowledge = async (msg: any, id: string): Promise<void> => {
@@ -28,7 +28,7 @@ export const processKnowledge = async (msg: any, id: string): Promise<void> => {
 
   await updateKnowledgeStatus(id, "processing")
 
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     baseUrl: cleanUrl(ollamaUrl),
     model: knowledge.embedding_model
   })
@@ -44,21 +44,6 @@ export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
   signal?: AbortSignal
 }
 
-/**
- * Class for generating embeddings using the OpenAI API. Extends the
- * Embeddings class and implements OpenAIEmbeddingsParams and
- * AzureOpenAIInput.
- * @example
- * ```typescript
- * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text
- * const model = new OpenAIEmbeddings();
- * const res = await model.embedQuery(
- *   "What would be a good company name for a company that makes colorful socks?",
- * );
- * console.log({ res });
- *
- * ```
- */
 export class OAIEmbedding
   extends Embeddings
   implements OpenAIEmbeddingsParams {
@@ -96,6 +81,7 @@ export class OAIEmbedding
   protected client: OpenAIClient
 
   protected clientConfig: ClientOptions
 
+  signal?: AbortSignal
 
   constructor(
src/models/embedding.ts (new file, 36 lines)
@@ -0,0 +1,36 @@
+import { getModelInfo, isCustomModel } from "@/db/models"
+import { OllamaEmbeddingsPageAssist } from "./OllamaEmbedding"
+import { OAIEmbedding } from "./OAIEmbedding"
+import { getOpenAIConfigById } from "@/db/openai"
+
+type EmbeddingModel = {
+  model: string
+  baseUrl: string
+  signal?: AbortSignal
+  keepAlive?: string
+}
+
+export const pageAssistEmbeddingModel = async ({ baseUrl, model, keepAlive, signal }: EmbeddingModel) => {
+  const isCustom = isCustomModel(model)
+  if (isCustom) {
+    const modelInfo = await getModelInfo(model)
+    const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
+    return new OAIEmbedding({
+      modelName: modelInfo.model_id,
+      model: modelInfo.model_id,
+      signal,
+      openAIApiKey: providerInfo.apiKey || "temp",
+      configuration: {
+        apiKey: providerInfo.apiKey || "temp",
+        baseURL: providerInfo.baseUrl || "",
+      }
+    }) as any
+  }
+
+  return new OllamaEmbeddingsPageAssist({
+    model,
+    baseUrl,
+    keepAlive,
+    signal
+  })
+}
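Two details worth noting in the new factory: the `"temp"` fallback for `openAIApiKey` presumably exists so the OpenAI client can still be constructed for OpenAI-compatible endpoints configured without a key, and the `as any` cast papers over the type difference between `OAIEmbedding` and `OllamaEmbeddingsPageAssist` so callers can treat either return value as a generic embeddings instance.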
@@ -40,7 +40,7 @@ export const pageAssistModel = async ({
   if (isCustom) {
     const modelInfo = await getModelInfo(model)
     const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
     console.log(modelInfo, providerInfo)
 
     return new ChatOpenAI({
       modelName: modelInfo.model_id,
       openAIApiKey: providerInfo.apiKey || "temp",
@@ -133,6 +133,28 @@ export const getAllModels = async ({
   }
 }
 
+export const getEmbeddingModels = async ({ returnEmpty }: {
+  returnEmpty?: boolean
+}) => {
+  try {
+    const ollamaModels = await getAllModels({ returnEmpty })
+    const customModels = await ollamaFormatAllCustomModels()
+
+    return [
+      ...ollamaModels.map((model) => {
+        return {
+          ...model,
+          provider: "ollama"
+        }
+      }),
+      ...customModels
+    ]
+  } catch (e) {
+    console.error(e)
+    return []
+  }
+}
+
 export const deleteModel = async (model: string) => {
   const baseUrl = await getOllamaURL()
   const response = await fetcher(`${cleanUrl(baseUrl)}/api/delete`, {
@@ -341,7 +363,7 @@ export const saveForRag = async (
   await setDefaultEmbeddingChunkSize(chunkSize)
   await setDefaultEmbeddingChunkOverlap(overlap)
   await setTotalFilePerKB(totalFilePerKB)
-  if(noOfRetrievedDocs) {
+  if (noOfRetrievedDocs) {
    await setNoOfRetrievedDocs(noOfRetrievedDocs)
   }
 }
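The new `getEmbeddingModels` mirrors `getAllModels` but tags each Ollama model with `provider: "ollama"` and appends the custom OpenAI-compatible models, which is what allows the RagSettings `<Select>` above to render a provider icon for every option.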
@@ -1,7 +1,7 @@
 import { PageAssistHtmlLoader } from "~/loader/html"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
 import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize
@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,7 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
+
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -81,7 +82,7 @@ export const webBraveSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
 
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })
@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -85,7 +85,7 @@ export const webDuckDuckGoSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
 
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })
@@ -1,8 +1,8 @@
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
@@ -84,7 +84,7 @@ export const webGoogleSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
 
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })
@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -99,7 +99,7 @@ export const webSogouSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
 
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })
@@ -1,7 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
 
@@ -15,7 +15,7 @@ export const processSingleWebsite = async (url: string, query: string) => {
   const ollamaUrl = await getOllamaURL()
 
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })