feat: Improve model selection and embedding

Refactor embedding model handling to simplify the code path and improve performance.
Add a new embedding-model selection mechanism (getEmbeddingModels) and enhance the model-selection UI with provider icons, offering clearer and more user-friendly options for embedding models.
Refactor embeddings to use a common factory (pageAssistEmbeddingModel) shared by Page Assist and RAG, further streamlining the workflow.
n4ze3m 2024-10-12 23:32:00 +05:30
parent ba071ffeb1
commit 768ff2e555
14 changed files with 98 additions and 43 deletions
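
Note: the sketch below is illustrative, not part of the diff. It shows how the two new entry points introduced by this commit are meant to fit together; the Ollama URL and the model index are assumptions.

    import { getEmbeddingModels } from "~/services/ollama"
    import { pageAssistEmbeddingModel } from "@/models/embedding"

    const demo = async () => {
      // Lists Ollama models (tagged provider: "ollama") alongside any
      // OpenAI-compatible custom models, for the embedding-model picker.
      const models = await getEmbeddingModels({ returnEmpty: true })

      // The factory returns OAIEmbedding for custom models and
      // OllamaEmbeddingsPageAssist for everything else.
      const embeddings = await pageAssistEmbeddingModel({
        model: models[0].model,
        baseUrl: "http://localhost:11434" // assumed local Ollama URL
      })

      // Either branch satisfies LangChain's Embeddings interface.
      const vector = await embeddings.embedQuery("hello world")
      console.log(vector.length)
    }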

View File

@@ -5,13 +5,14 @@ import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
   defaultEmbeddingModelForRag,
-  getAllModels,
+  getEmbeddingModels,
   saveForRag
 } from "~/services/ollama"
 import { SettingPrompt } from "./prompt"
 import { useTranslation } from "react-i18next"
 import { getNoOfRetrievedDocs, getTotalFilePerKB } from "@/services/app"
 import { SidepanelRag } from "./sidepanel-rag"
+import { ProviderIcons } from "@/components/Common/ProviderIcon"

 export const RagSettings = () => {
   const { t } = useTranslation("settings")

@@ -29,7 +30,7 @@ export const RagSettings = () => {
         totalFilePerKB,
         noOfRetrievedDocs
       ] = await Promise.all([
-        getAllModels({ returnEmpty: true }),
+        getEmbeddingModels({ returnEmpty: true }),
         defaultEmbeddingChunkOverlap(),
         defaultEmbeddingChunkSize(),
         defaultEmbeddingModelForRag(),

@@ -113,18 +114,27 @@ export const RagSettings = () => {
             ]}>
             <Select
               size="large"
-              filterOption={(input, option) =>
-                option!.label.toLowerCase().indexOf(input.toLowerCase()) >=
-                  0 ||
-                option!.value.toLowerCase().indexOf(input.toLowerCase()) >=
-                  0
-              }
               showSearch
               placeholder={t("rag.ragSettings.model.placeholder")}
               style={{ width: "100%" }}
               className="mt-4"
+              filterOption={(input, option) =>
+                option.label.key
+                  .toLowerCase()
+                  .indexOf(input.toLowerCase()) >= 0
+              }
               options={ollamaInfo.models?.map((model) => ({
-                label: model.name,
+                label: (
+                  <span
+                    key={model.model}
+                    className="flex flex-row gap-3 items-center truncate">
+                    <ProviderIcons
+                      provider={model?.provider}
+                      className="w-5 h-5"
+                    />
+                    <span className="truncate">{model.name}</span>
+                  </span>
+                ),
                 value: model.model
               }))}
             />
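
(Note on the filterOption change above: the label is now a JSX element rather than a string, so the search filter can no longer lowercase option.label directly. It matches option.label.key instead, which the new markup sets to model.model via key={model.model}; typing in the search box therefore filters by model id rather than display name.)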

View File

@@ -26,7 +26,6 @@ import { notification } from "antd"
 import { useTranslation } from "react-i18next"
 import { usePageAssist } from "@/context"
 import { formatDocs } from "@/chain/chat-with-x"
-import { OllamaEmbeddingsPageAssist } from "@/models/OllamaEmbedding"
 import { useStorage } from "@plasmohq/storage/hook"
 import { useStoreChatModelSettings } from "@/store/model"
 import { getAllDefaultModelSettings } from "@/services/model-settings"

@@ -34,6 +33,7 @@ import { getSystemPromptForWeb } from "@/web/web"
 import { pageAssistModel } from "@/models"
 import { getPrompt } from "@/services/application"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const useMessage = () => {
   const {

@@ -202,7 +202,7 @@ export const useMessage = () => {
     const ollamaUrl = await getOllamaURL()
     const embeddingModle = await defaultEmbeddingModelForRag()
-    const ollamaEmbedding = new OllamaEmbeddingsPageAssist({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       signal: embeddingSignal,

View File

@@ -24,7 +24,6 @@ import { generateHistory } from "@/utils/generate-history"
 import { useTranslation } from "react-i18next"
 import { saveMessageOnError, saveMessageOnSuccess } from "./chat-helper"
 import { usePageAssist } from "@/context"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
 import { formatDocs } from "@/chain/chat-with-x"
 import { useWebUI } from "@/store/webui"

@@ -34,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { pageAssistModel } from "@/models"
 import { getNoOfRetrievedDocs } from "@/services/app"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const useMessageOption = () => {
   const {

@@ -628,7 +628,7 @@ export const useMessageOption = () => {
     const embeddingModle = await defaultEmbeddingModelForRag()
     const ollamaUrl = await getOllamaURL()
-    const ollamaEmbedding = new OllamaEmbeddings({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       keepAlive:

View File

@@ -5,7 +5,6 @@ import {
   defaultEmbeddingChunkSize,
   getOllamaURL
 } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { PageAssistVectorStore } from "./PageAssistVectorStore"
 import { PageAssisCSVUrlLoader } from "@/loader/csv"

@@ -13,6 +12,7 @@ import { PageAssisTXTUrlLoader } from "@/loader/txt"
 import { PageAssistDocxLoader } from "@/loader/docx"
 import { cleanUrl } from "./clean-url"
 import { sendEmbeddingCompleteNotification } from "./send-notification"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const processKnowledge = async (msg: any, id: string): Promise<void> => {

@@ -28,7 +28,7 @@ export const processKnowledge = async (msg: any, id: string): Promise<void> => {
     await updateKnowledgeStatus(id, "processing")

-    const ollamaEmbedding = new OllamaEmbeddings({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       baseUrl: cleanUrl(ollamaUrl),
       model: knowledge.embedding_model
     })

View File

@@ -44,21 +44,6 @@ export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
   signal?: AbortSignal
 }

-/**
- * Class for generating embeddings using the OpenAI API. Extends the
- * Embeddings class and implements OpenAIEmbeddingsParams and
- * AzureOpenAIInput.
- * @example
- * ```typescript
- * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text
- * const model = new OpenAIEmbeddings();
- * const res = await model.embedQuery(
- *   "What would be a good company name for a company that makes colorful socks?",
- * );
- * console.log({ res });
- *
- * ```
- */
 export class OAIEmbedding
   extends Embeddings
   implements OpenAIEmbeddingsParams {

@@ -96,6 +81,7 @@ export class OAIEmbedding
   protected client: OpenAIClient
   protected clientConfig: ClientOptions
   signal?: AbortSignal
+
   constructor(

src/models/embedding.ts (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
+import { getModelInfo, isCustomModel } from "@/db/models"
+import { OllamaEmbeddingsPageAssist } from "./OllamaEmbedding"
+import { OAIEmbedding } from "./OAIEmbedding"
+import { getOpenAIConfigById } from "@/db/openai"
+
+type EmbeddingModel = {
+  model: string
+  baseUrl: string
+  signal?: AbortSignal
+  keepAlive?: string
+}
+
+export const pageAssistEmbeddingModel = async ({ baseUrl, model, keepAlive, signal }: EmbeddingModel) => {
+  const isCustom = isCustomModel(model)
+  if (isCustom) {
+    const modelInfo = await getModelInfo(model)
+    const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
+    return new OAIEmbedding({
+      modelName: modelInfo.model_id,
+      model: modelInfo.model_id,
+      signal,
+      openAIApiKey: providerInfo.apiKey || "temp",
+      configuration: {
+        apiKey: providerInfo.apiKey || "temp",
+        baseURL: providerInfo.baseUrl || "",
+      }
+    }) as any
+  }
+
+  return new OllamaEmbeddingsPageAssist({
+    model,
+    baseUrl,
+    keepAlive,
+    signal
+  })
+}
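
A hedged sketch of the factory's dispatch behavior (model ids here are made up; note that the baseUrl argument is only consumed on the Ollama path, since the custom path reads providerInfo.baseUrl instead):

    import { pageAssistEmbeddingModel } from "@/models/embedding"

    // Ollama path: a plain model id yields OllamaEmbeddingsPageAssist.
    const local = await pageAssistEmbeddingModel({
      model: "nomic-embed-text:latest", // hypothetical Ollama model id
      baseUrl: "http://localhost:11434" // assumed local Ollama URL
    })

    // Custom path: an id recognized by isCustomModel() yields OAIEmbedding,
    // configured from the stored provider record (apiKey / baseURL).
    const remote = await pageAssistEmbeddingModel({
      model: "my-openai-provider-model", // hypothetical custom model id
      baseUrl: "" // unused on this path
    })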

View File

@@ -40,7 +40,7 @@ export const pageAssistModel = async ({
   if (isCustom) {
     const modelInfo = await getModelInfo(model)
     const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
-    console.log(modelInfo, providerInfo)
+
     return new ChatOpenAI({
       modelName: modelInfo.model_id,
       openAIApiKey: providerInfo.apiKey || "temp",

View File

@@ -133,6 +133,28 @@ export const getAllModels = async ({
   }
 }

+export const getEmbeddingModels = async ({ returnEmpty }: {
+  returnEmpty?: boolean
+}) => {
+  try {
+    const ollamaModels = await getAllModels({ returnEmpty })
+    const customModels = await ollamaFormatAllCustomModels()
+
+    return [
+      ...ollamaModels.map((model) => {
+        return {
+          ...model,
+          provider: "ollama"
+        }
+      }),
+      ...customModels
+    ]
+  } catch (e) {
+    console.error(e)
+    return []
+  }
+}
+
 export const deleteModel = async (model: string) => {
   const baseUrl = await getOllamaURL()
   const response = await fetcher(`${cleanUrl(baseUrl)}/api/delete`, {

View File

@@ -1,7 +1,7 @@
 import { PageAssistHtmlLoader } from "~/loader/html"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,

@@ -11,7 +12,7 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"

@@ -81,7 +82,7 @@ export const webBraveSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,

@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"

@@ -85,7 +85,7 @@ export const webDuckDuckGoSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,8 +1,8 @@
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"

@@ -84,7 +84,7 @@ export const webGoogleSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,

@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"

@@ -99,7 +99,7 @@ export const webSogouSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,7 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"

@@ -15,7 +15,7 @@ export const processSingleWebsite = async (url: string, query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })