feat: Improve model selection and embedding

Refactor embedding models and their handling to improve performance and simplify the process.
Add a new model selection mechanism and enhance the model-selection UI, offering clearer and more user-friendly options for embedding models.
Refactor embeddings to use a common model for Page Assist and RAG, further streamlining the workflow.
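
Under the hood, every embedding call site now goes through a single async factory, pageAssistEmbeddingModel (src/models/embedding.ts below), instead of constructing an Ollama embedding class directly. A minimal before/after sketch of the call-site change, drawn from the diffs below:

    // before: hard-wired to Ollama embeddings
    const ollamaEmbedding = new OllamaEmbeddings({
      model: embeddingModle || selectedModel,
      baseUrl: cleanUrl(ollamaUrl)
    })

    // after: the factory returns an Ollama or OpenAI-compatible
    // embedding client depending on the model id
    const ollamaEmbedding = await pageAssistEmbeddingModel({
      model: embeddingModle || selectedModel,
      baseUrl: cleanUrl(ollamaUrl)
    })

The factory is async because a custom model needs database lookups (getModelInfo, getOpenAIConfigById) before the client can be constructed.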
n4ze3m 2024-10-12 23:32:00 +05:30
parent ba071ffeb1
commit 768ff2e555
14 changed files with 98 additions and 43 deletions

View File

@@ -5,13 +5,14 @@ import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
   defaultEmbeddingModelForRag,
-  getAllModels,
+  getEmbeddingModels,
   saveForRag
 } from "~/services/ollama"
 import { SettingPrompt } from "./prompt"
 import { useTranslation } from "react-i18next"
 import { getNoOfRetrievedDocs, getTotalFilePerKB } from "@/services/app"
 import { SidepanelRag } from "./sidepanel-rag"
+import { ProviderIcons } from "@/components/Common/ProviderIcon"

 export const RagSettings = () => {
   const { t } = useTranslation("settings")
@@ -29,7 +30,7 @@ export const RagSettings = () => {
     totalFilePerKB,
     noOfRetrievedDocs
   ] = await Promise.all([
-    getAllModels({ returnEmpty: true }),
+    getEmbeddingModels({ returnEmpty: true }),
     defaultEmbeddingChunkOverlap(),
     defaultEmbeddingChunkSize(),
     defaultEmbeddingModelForRag(),
@@ -113,18 +114,27 @@ export const RagSettings = () => {
   ]}>
   <Select
     size="large"
-    filterOption={(input, option) =>
-      option!.label.toLowerCase().indexOf(input.toLowerCase()) >=
-      0 ||
-      option!.value.toLowerCase().indexOf(input.toLowerCase()) >=
-      0
-    }
     showSearch
     placeholder={t("rag.ragSettings.model.placeholder")}
     style={{ width: "100%" }}
     className="mt-4"
+    filterOption={(input, option) =>
+      option.label.key
+        .toLowerCase()
+        .indexOf(input.toLowerCase()) >= 0
+    }
     options={ollamaInfo.models?.map((model) => ({
-      label: model.name,
+      label: (
+        <span
+          key={model.model}
+          className="flex flex-row gap-3 items-center truncate">
+          <ProviderIcons
+            provider={model?.provider}
+            className="w-5 h-5"
+          />
+          <span className="truncate">{model.name}</span>
+        </span>
+      ),
       value: model.model
     }))}
   />
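
One subtlety in the Select changes above: the old filterOption matched the input against both the label string and the value, but the label is now a React element, so the new predicate matches option.label.key, which the markup sets to model.model. Search therefore now filters on the model id rather than the display name. A minimal sketch of the new predicate under that assumption:

    // option.label is the React element built above; its key holds the model id
    type Option = { label: { key: string }; value: string }
    const matchesInput = (input: string, option: Option) =>
      option.label.key.toLowerCase().indexOf(input.toLowerCase()) >= 0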

View File

@@ -26,7 +26,6 @@ import { notification } from "antd"
 import { useTranslation } from "react-i18next"
 import { usePageAssist } from "@/context"
 import { formatDocs } from "@/chain/chat-with-x"
-import { OllamaEmbeddingsPageAssist } from "@/models/OllamaEmbedding"
 import { useStorage } from "@plasmohq/storage/hook"
 import { useStoreChatModelSettings } from "@/store/model"
 import { getAllDefaultModelSettings } from "@/services/model-settings"
@@ -34,6 +33,7 @@ import { getSystemPromptForWeb } from "@/web/web"
 import { pageAssistModel } from "@/models"
 import { getPrompt } from "@/services/application"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const useMessage = () => {
   const {
@@ -202,7 +202,7 @@ export const useMessage = () => {
     const ollamaUrl = await getOllamaURL()
     const embeddingModle = await defaultEmbeddingModelForRag()
-    const ollamaEmbedding = new OllamaEmbeddingsPageAssist({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       signal: embeddingSignal,

View File

@@ -24,7 +24,6 @@ import { generateHistory } from "@/utils/generate-history"
 import { useTranslation } from "react-i18next"
 import { saveMessageOnError, saveMessageOnSuccess } from "./chat-helper"
 import { usePageAssist } from "@/context"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
 import { formatDocs } from "@/chain/chat-with-x"
 import { useWebUI } from "@/store/webui"
@@ -34,6 +33,7 @@ import { getAllDefaultModelSettings } from "@/services/model-settings"
 import { pageAssistModel } from "@/models"
 import { getNoOfRetrievedDocs } from "@/services/app"
 import { humanMessageFormatter } from "@/utils/human-message"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const useMessageOption = () => {
   const {
@@ -628,7 +628,7 @@ export const useMessageOption = () => {
     const embeddingModle = await defaultEmbeddingModelForRag()
     const ollamaUrl = await getOllamaURL()
-    const ollamaEmbedding = new OllamaEmbeddings({
+    const ollamaEmbedding = await pageAssistEmbeddingModel({
       model: embeddingModle || selectedModel,
       baseUrl: cleanUrl(ollamaUrl),
       keepAlive:

View File

@@ -5,7 +5,6 @@ import {
   defaultEmbeddingChunkSize,
   getOllamaURL
 } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { PageAssistVectorStore } from "./PageAssistVectorStore"
 import { PageAssisCSVUrlLoader } from "@/loader/csv"
@@ -13,6 +12,7 @@ import { PageAssisTXTUrlLoader } from "@/loader/txt"
 import { PageAssistDocxLoader } from "@/loader/docx"
 import { cleanUrl } from "./clean-url"
 import { sendEmbeddingCompleteNotification } from "./send-notification"
+import { pageAssistEmbeddingModel } from "@/models/embedding"

 export const processKnowledge = async (msg: any, id: string): Promise<void> => {
@@ -28,7 +28,7 @@ export const processKnowledge = async (msg: any, id: string): Promise<void> => {
   await updateKnowledgeStatus(id, "processing")

-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     baseUrl: cleanUrl(ollamaUrl),
     model: knowledge.embedding_model
   })

View File

@@ -44,21 +44,6 @@ export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
   signal?: AbortSignal
 }

-/**
- * Class for generating embeddings using the OpenAI API. Extends the
- * Embeddings class and implements OpenAIEmbeddingsParams and
- * AzureOpenAIInput.
- * @example
- * ```typescript
- * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text
- * const model = new OpenAIEmbeddings();
- * const res = await model.embedQuery(
- *   "What would be a good company name for a company that makes colorful socks?",
- * );
- * console.log({ res });
- *
- * ```
- */
 export class OAIEmbedding
   extends Embeddings
   implements OpenAIEmbeddingsParams {
@@ -96,6 +81,7 @@ export class OAIEmbedding
   protected client: OpenAIClient
   protected clientConfig: ClientOptions
+  signal?: AbortSignal

   constructor(
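
The signal?: AbortSignal field added to OAIEmbedding mirrors the field on OpenAIEmbeddingsParams and is threaded through by the new factory below; useMessage already passes an embeddingSignal. A minimal sketch of the intended wiring (model id and URL are hypothetical):

    const controller = new AbortController()
    const embeddings = await pageAssistEmbeddingModel({
      model: "nomic-embed-text:latest",  // hypothetical model id
      baseUrl: "http://localhost:11434", // hypothetical Ollama URL
      signal: controller.signal
    })
    // cancel in-flight embedding requests, e.g. when the user stops a run
    controller.abort()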

src/models/embedding.ts (new file, 36 lines)

View File

@@ -0,0 +1,36 @@
+import { getModelInfo, isCustomModel } from "@/db/models"
+import { OllamaEmbeddingsPageAssist } from "./OllamaEmbedding"
+import { OAIEmbedding } from "./OAIEmbedding"
+import { getOpenAIConfigById } from "@/db/openai"
+
+type EmbeddingModel = {
+  model: string
+  baseUrl: string
+  signal?: AbortSignal
+  keepAlive?: string
+}
+
+export const pageAssistEmbeddingModel = async ({ baseUrl, model, keepAlive, signal }: EmbeddingModel) => {
+  const isCustom = isCustomModel(model)
+  if (isCustom) {
+    const modelInfo = await getModelInfo(model)
+    const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
+    return new OAIEmbedding({
+      modelName: modelInfo.model_id,
+      model: modelInfo.model_id,
+      signal,
+      openAIApiKey: providerInfo.apiKey || "temp",
+      configuration: {
+        apiKey: providerInfo.apiKey || "temp",
+        baseURL: providerInfo.baseUrl || "",
+      }
+    }) as any
+  }
+
+  return new OllamaEmbeddingsPageAssist({
+    model,
+    baseUrl,
+    keepAlive,
+    signal
+  })
+}
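
pageAssistEmbeddingModel centralizes the Ollama-versus-custom decision that pageAssistModel already makes for chat models: a custom model id is resolved through getModelInfo and getOpenAIConfigById to an OAIEmbedding client, and anything else falls through to OllamaEmbeddingsPageAssist. A hedged usage sketch (model ids and URL are hypothetical; the custom-id format is whatever isCustomModel recognizes):

    // an Ollama model id resolves to OllamaEmbeddingsPageAssist
    const local = await pageAssistEmbeddingModel({
      model: "nomic-embed-text:latest",
      baseUrl: "http://localhost:11434"
    })

    // a custom-provider id resolves to OAIEmbedding, configured from the
    // stored provider record; the baseUrl argument is not used on this path
    const hosted = await pageAssistEmbeddingModel({
      model: "my-custom-embedding-id", // hypothetical custom model id
      baseUrl: "http://localhost:11434"
    })
    const vectors = await hosted.embedDocuments(["hello", "world"])

The "temp" fallback for openAIApiKey appears to satisfy the OpenAI client for providers that do not require a key, and the as any cast papers over the type difference between the two embedding clients.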

View File

@@ -40,7 +40,7 @@ export const pageAssistModel = async ({
   if (isCustom) {
     const modelInfo = await getModelInfo(model)
     const providerInfo = await getOpenAIConfigById(modelInfo.provider_id)
-    console.log(modelInfo, providerInfo)
     return new ChatOpenAI({
       modelName: modelInfo.model_id,
       openAIApiKey: providerInfo.apiKey || "temp",

View File

@@ -133,6 +133,28 @@ export const getAllModels = async ({
   }
 }

+export const getEmbeddingModels = async ({ returnEmpty }: {
+  returnEmpty?: boolean
+}) => {
+  try {
+    const ollamaModels = await getAllModels({ returnEmpty })
+    const customModels = await ollamaFormatAllCustomModels()
+
+    return [
+      ...ollamaModels.map((model) => {
+        return {
+          ...model,
+          provider: "ollama"
+        }
+      }),
+      ...customModels
+    ]
+  } catch (e) {
+    console.error(e)
+    return []
+  }
+}
+
 export const deleteModel = async (model: string) => {
   const baseUrl = await getOllamaURL()
   const response = await fetcher(`${cleanUrl(baseUrl)}/api/delete`, {
@@ -341,7 +363,7 @@ export const saveForRag = async (
   await setDefaultEmbeddingChunkSize(chunkSize)
   await setDefaultEmbeddingChunkOverlap(overlap)
   await setTotalFilePerKB(totalFilePerKB)
-  if(noOfRetrievedDocs) {
+  if (noOfRetrievedDocs) {
     await setNoOfRetrievedDocs(noOfRetrievedDocs)
   }
 }
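
getEmbeddingModels is what feeds the reworked RAG settings dropdown: it merges the Ollama list, tagging each entry with provider: "ollama", with the OpenAI-compatible custom models from ollamaFormatAllCustomModels, so the UI can render a ProviderIcons badge per option. The fields the dropdown reads imply roughly this shape (the real objects may carry more):

    type EmbeddingModelOption = {
      name: string     // display name shown next to the provider icon
      model: string    // id handed to pageAssistEmbeddingModel
      provider: string // "ollama" or a custom provider identifier
    }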

View File

@@ -1,7 +1,7 @@
 import { PageAssistHtmlLoader } from "~/loader/html"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,7 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -81,7 +82,7 @@ export const webBraveSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -85,7 +85,7 @@ export const webDuckDuckGoSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,8 +1,8 @@
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
@@ -84,7 +84,7 @@ export const webGoogleSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,6 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { urlRewriteRuntime } from "@/libs/runtime"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import {
   defaultEmbeddingChunkOverlap,
   defaultEmbeddingChunkSize,
@@ -11,7 +12,6 @@ import {
   getIsSimpleInternetSearch,
   totalSearchResults
 } from "@/services/search"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import type { Document } from "@langchain/core/documents"
 import * as cheerio from "cheerio"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
@@ -99,7 +99,7 @@ export const webSogouSearch = async (query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })

View File

@@ -1,7 +1,7 @@
 import { cleanUrl } from "@/libs/clean-url"
 import { PageAssistHtmlLoader } from "@/loader/html"
+import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, getOllamaURL } from "@/services/ollama"
-import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
 import { MemoryVectorStore } from "langchain/vectorstores/memory"
@@ -15,7 +15,7 @@ export const processSingleWebsite = async (url: string, query: string) => {
   const ollamaUrl = await getOllamaURL()
   const embeddingModle = await defaultEmbeddingModelForRag()
-  const ollamaEmbedding = new OllamaEmbeddings({
+  const ollamaEmbedding = await pageAssistEmbeddingModel({
     model: embeddingModle || "",
     baseUrl: cleanUrl(ollamaUrl)
   })