feat: Implement Ollama embedding

Adds support for Ollama embedding, enabling the use of Ollama as an embedding model for RAG.

This allows users to leverage Ollama's advanced embedding capabilities for better document understanding and retrieval.
This commit is contained in:
n4ze3m 2024-10-13 19:07:01 +05:30
parent ff4473c35b
commit d64c84fc83
4 changed files with 23 additions and 60 deletions

View File

@@ -11,7 +11,8 @@ import {
defaultEmbeddingChunkOverlap, defaultEmbeddingChunkOverlap,
defaultEmbeddingChunkSize, defaultEmbeddingChunkSize,
defaultEmbeddingModelForRag, defaultEmbeddingModelForRag,
saveForRag saveForRag,
getEmbeddingModels
} from "~/services/ollama" } from "~/services/ollama"
import { import {
@@ -77,7 +78,7 @@ export const SettingsBody = () => {
getOllamaURL(), getOllamaURL(),
systemPromptForNonRag(), systemPromptForNonRag(),
promptForRag(), promptForRag(),
getAllModels({ returnEmpty: true }), getEmbeddingModels({ returnEmpty: true }),
defaultEmbeddingChunkOverlap(), defaultEmbeddingChunkOverlap(),
defaultEmbeddingChunkSize(), defaultEmbeddingChunkSize(),
defaultEmbeddingModelForRag(), defaultEmbeddingModelForRag(),

View File

@@ -1,52 +1,21 @@
import { type ClientOptions, OpenAI as OpenAIClient } from "openai" import { type ClientOptions, OpenAI as OpenAIClient } from "openai"
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings" import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"
import { chunkArray } from "@langchain/core/utils/chunk_array" import { chunkArray } from "@langchain/core/utils/chunk_array"
import { OpenAICoreRequestOptions, LegacyOpenAIInput } from "./types" import { OpenAICoreRequestOptions, LegacyOpenAIInput } from "./types"
import { wrapOpenAIClientError } from "./utils/openai" import { wrapOpenAIClientError } from "./utils/openai"
/**
* Interface for OpenAIEmbeddings parameters. Extends EmbeddingsParams and
* defines additional parameters specific to the OpenAIEmbeddings class.
*/
export interface OpenAIEmbeddingsParams extends EmbeddingsParams { export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
/**
* Model name to use
* Alias for `model`
*/
modelName: string modelName: string
/** Model name to use */
model: string model: string
/**
* The number of dimensions the resulting output embeddings should have.
* Only supported in `text-embedding-3` and later models.
*/
dimensions?: number dimensions?: number
/**
* Timeout to use when making requests to OpenAI.
*/
timeout?: number timeout?: number
/**
* The maximum number of documents to embed in a single request. This is
* limited by the OpenAI API to a maximum of 2048.
*/
batchSize?: number batchSize?: number
/**
* Whether to strip new lines from the input text. This is recommended by
* OpenAI for older models, but may not be suitable for all use cases.
* See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
*/
stripNewLines?: boolean stripNewLines?: boolean
signal?: AbortSignal signal?: AbortSignal
} }
export class OAIEmbedding export class OAIEmbedding
extends Embeddings extends Embeddings {
implements OpenAIEmbeddingsParams {
modelName = "text-embedding-ada-002" modelName = "text-embedding-ada-002"
model = "text-embedding-ada-002" model = "text-embedding-ada-002"
@@ -107,7 +76,7 @@ export class OAIEmbedding
this.modelName = this.modelName =
fieldsWithDefaults?.model ?? fieldsWithDefaults?.modelName ?? this.model fieldsWithDefaults?.model ?? fieldsWithDefaults?.modelName ?? this.model
this.model = this.modelName this.model = this.modelName
this.batchSize = fieldsWithDefaults?.batchSize this.batchSize = fieldsWithDefaults?.batchSize || this.batchSize
this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines
this.timeout = fieldsWithDefaults?.timeout this.timeout = fieldsWithDefaults?.timeout
this.dimensions = fieldsWithDefaults?.dimensions this.dimensions = fieldsWithDefaults?.dimensions
@@ -127,15 +96,12 @@ export class OAIEmbedding
...configuration, ...configuration,
...fields?.configuration ...fields?.configuration
} }
// initialize the client
this.client = new OpenAIClient(this.clientConfig)
} }
/**
* Method to generate embeddings for an array of documents. Splits the
* documents into batches and makes requests to the OpenAI API to generate
* embeddings.
* @param texts Array of documents to generate embeddings for.
* @returns Promise that resolves to a 2D array of embeddings for each document.
*/
async embedDocuments(texts: string[]): Promise<number[][]> { async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray( const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts, this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
@@ -165,12 +131,6 @@ export class OAIEmbedding
return embeddings return embeddings
} }
/**
* Method to generate an embedding for a single document. Calls the
* embeddingWithRetry method with the document as the input.
* @param text Document to generate an embedding for.
* @returns Promise that resolves to an embedding for the document.
*/
async embedQuery(text: string): Promise<number[]> { async embedQuery(text: string): Promise<number[]> {
const params: OpenAIClient.EmbeddingCreateParams = { const params: OpenAIClient.EmbeddingCreateParams = {
model: this.model, model: this.model,
@@ -183,16 +143,19 @@ export class OAIEmbedding
return data[0].embedding return data[0].embedding
} }
/** async _embed(texts: string[]): Promise<number[][]> {
* Private method to make a request to the OpenAI API to generate const embeddings: number[][] = await Promise.all(
* embeddings. Handles the retry logic and returns the response from the texts.map((text) => this.caller.call(() => this.embedQuery(text)))
* API. )
* @param request Request to send to the OpenAI API.
* @returns Promise that resolves to the response from the API. return embeddings
*/ }
protected async embeddingWithRetry( protected async embeddingWithRetry(
request: OpenAIClient.EmbeddingCreateParams request: OpenAIClient.EmbeddingCreateParams
) { ) {
const requestOptions: OpenAICoreRequestOptions = {} const requestOptions: OpenAICoreRequestOptions = {}
if (this.azureOpenAIApiKey) { if (this.azureOpenAIApiKey) {
requestOptions.headers = { requestOptions.headers = {

View File

@@ -47,12 +47,11 @@ export const memoryEmbedding = async ({
type: string type: string
pdf: { content: string; page: number }[] pdf: { content: string; page: number }[]
keepTrackOfEmbedding: Record<string, MemoryVectorStore> keepTrackOfEmbedding: Record<string, MemoryVectorStore>
ollamaEmbedding: OllamaEmbeddings ollamaEmbedding: any
setIsEmbedding: (value: boolean) => void setIsEmbedding: (value: boolean) => void
setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void
}) => { }) => {
setIsEmbedding(true) setIsEmbedding(true)
const loader = getLoader({ html, pdf, type, url }) const loader = getLoader({ html, pdf, type, url })
const docs = await loader.load() const docs = await loader.load()
const chunkSize = await defaultEmbeddingChunkSize() const chunkSize = await defaultEmbeddingChunkSize()