From d64c84fc83adacc3455b388690821fd184e6a80f Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sun, 13 Oct 2024 19:07:01 +0530 Subject: [PATCH] feat: Implement Ollama embedding Adds support for Ollama embedding, enabling the use of Ollama as an embedding model for RAG. This allows users to leverage Ollama's advanced embedding capabilities for better document understanding and retrieval. --- src/components/Sidepanel/Settings/body.tsx | 5 +- src/hooks/useMessage.tsx | 2 +- src/models/OAIEmbedding.ts | 73 ++++++---------------- src/utils/memory-embeddings.ts | 3 +- 4 files changed, 23 insertions(+), 60 deletions(-) diff --git a/src/components/Sidepanel/Settings/body.tsx b/src/components/Sidepanel/Settings/body.tsx index 89410bc..dc36baa 100644 --- a/src/components/Sidepanel/Settings/body.tsx +++ b/src/components/Sidepanel/Settings/body.tsx @@ -11,7 +11,8 @@ import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize, defaultEmbeddingModelForRag, - saveForRag + saveForRag, + getEmbeddingModels } from "~/services/ollama" import { @@ -77,7 +78,7 @@ export const SettingsBody = () => { getOllamaURL(), systemPromptForNonRag(), promptForRag(), - getAllModels({ returnEmpty: true }), + getEmbeddingModels({ returnEmpty: true }), defaultEmbeddingChunkOverlap(), defaultEmbeddingChunkSize(), defaultEmbeddingModelForRag(), diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx index 18cceb5..c85324a 100644 --- a/src/hooks/useMessage.tsx +++ b/src/hooks/useMessage.tsx @@ -539,7 +539,7 @@ export const useMessage = () => { if (selectedPrompt) { applicationChatHistory.unshift( new SystemMessage({ - content: selectedPrompt.content + content: selectedPrompt.content }) ) } diff --git a/src/models/OAIEmbedding.ts b/src/models/OAIEmbedding.ts index 7617021..c7b8a25 100644 --- a/src/models/OAIEmbedding.ts +++ b/src/models/OAIEmbedding.ts @@ -1,52 +1,21 @@ import { type ClientOptions, OpenAI as OpenAIClient } from "openai" -import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings" +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings" import { chunkArray } from "@langchain/core/utils/chunk_array" import { OpenAICoreRequestOptions, LegacyOpenAIInput } from "./types" import { wrapOpenAIClientError } from "./utils/openai" -/** - * Interface for OpenAIEmbeddings parameters. Extends EmbeddingsParams and - * defines additional parameters specific to the OpenAIEmbeddings class. - */ export interface OpenAIEmbeddingsParams extends EmbeddingsParams { - /** - * Model name to use - * Alias for `model` - */ modelName: string - /** Model name to use */ model: string - - /** - * The number of dimensions the resulting output embeddings should have. - * Only supported in `text-embedding-3` and later models. - */ dimensions?: number - - /** - * Timeout to use when making requests to OpenAI. - */ timeout?: number - - /** - * The maximum number of documents to embed in a single request. This is - * limited by the OpenAI API to a maximum of 2048. - */ batchSize?: number - - /** - * Whether to strip new lines from the input text. This is recommended by - * OpenAI for older models, but may not be suitable for all use cases. - * See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 - */ stripNewLines?: boolean - signal?: AbortSignal } export class OAIEmbedding - extends Embeddings - implements OpenAIEmbeddingsParams { + extends Embeddings { modelName = "text-embedding-ada-002" model = "text-embedding-ada-002" @@ -81,7 +50,7 @@ export class OAIEmbedding protected client: OpenAIClient protected clientConfig: ClientOptions - + signal?: AbortSignal constructor( @@ -107,7 +76,7 @@ export class OAIEmbedding this.modelName = fieldsWithDefaults?.model ?? fieldsWithDefaults?.modelName ?? this.model this.model = this.modelName - this.batchSize = fieldsWithDefaults?.batchSize + this.batchSize = fieldsWithDefaults?.batchSize || this.batchSize this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines this.timeout = fieldsWithDefaults?.timeout this.dimensions = fieldsWithDefaults?.dimensions @@ -127,15 +96,12 @@ export class OAIEmbedding ...configuration, ...fields?.configuration } + + + // initialize the client + this.client = new OpenAIClient(this.clientConfig) } - /** - * Method to generate embeddings for an array of documents. Splits the - * documents into batches and makes requests to the OpenAI API to generate - * embeddings. - * @param texts Array of documents to generate embeddings for. - * @returns Promise that resolves to a 2D array of embeddings for each document. - */ async embedDocuments(texts: string[]): Promise { const batches = chunkArray( this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts, @@ -165,12 +131,6 @@ export class OAIEmbedding return embeddings } - /** - * Method to generate an embedding for a single document. Calls the - * embeddingWithRetry method with the document as the input. - * @param text Document to generate an embedding for. - * @returns Promise that resolves to an embedding for the document. - */ async embedQuery(text: string): Promise { const params: OpenAIClient.EmbeddingCreateParams = { model: this.model, @@ -183,16 +143,19 @@ export class OAIEmbedding return data[0].embedding } - /** - * Private method to make a request to the OpenAI API to generate - * embeddings. Handles the retry logic and returns the response from the - * API. - * @param request Request to send to the OpenAI API. - * @returns Promise that resolves to the response from the API. - */ + async _embed(texts: string[]): Promise { + const embeddings: number[][] = await Promise.all( + texts.map((text) => this.caller.call(() => this.embedQuery(text))) + ) + + return embeddings + } + + protected async embeddingWithRetry( request: OpenAIClient.EmbeddingCreateParams ) { + const requestOptions: OpenAICoreRequestOptions = {} if (this.azureOpenAIApiKey) { requestOptions.headers = { diff --git a/src/utils/memory-embeddings.ts b/src/utils/memory-embeddings.ts index 8d0decc..99d2be4 100644 --- a/src/utils/memory-embeddings.ts +++ b/src/utils/memory-embeddings.ts @@ -47,12 +47,11 @@ export const memoryEmbedding = async ({ type: string pdf: { content: string; page: number }[] keepTrackOfEmbedding: Record - ollamaEmbedding: OllamaEmbeddings + ollamaEmbedding: any setIsEmbedding: (value: boolean) => void setKeepTrackOfEmbedding: (value: Record) => void }) => { setIsEmbedding(true) - const loader = getLoader({ html, pdf, type, url }) const docs = await loader.load() const chunkSize = await defaultEmbeddingChunkSize()