From ba071ffeb19ec96931d19fdb00ad3df4cac19ce7 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 12 Oct 2024 21:12:45 +0530 Subject: [PATCH] add embedding support --- src/models/OAIEmbedding.ts | 234 +++++++++++++++++++++++++++++++++++++ src/models/types.ts | 26 +++++ src/models/utils/openai.ts | 70 +++++++++++ 3 files changed, 330 insertions(+) create mode 100644 src/models/OAIEmbedding.ts create mode 100644 src/models/types.ts create mode 100644 src/models/utils/openai.ts diff --git a/src/models/OAIEmbedding.ts b/src/models/OAIEmbedding.ts new file mode 100644 index 0000000..b2df653 --- /dev/null +++ b/src/models/OAIEmbedding.ts @@ -0,0 +1,234 @@ +import { type ClientOptions, OpenAI as OpenAIClient } from "openai" +import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings" +import { chunkArray } from "@langchain/core/utils/chunk_array" +import { OpenAICoreRequestOptions, LegacyOpenAIInput } from "./types" +import { wrapOpenAIClientError } from "./utils/openai" + +/** + * Interface for OpenAIEmbeddings parameters. Extends EmbeddingsParams and + * defines additional parameters specific to the OpenAIEmbeddings class. + */ +export interface OpenAIEmbeddingsParams extends EmbeddingsParams { + /** + * Model name to use + * Alias for `model` + */ + modelName: string + /** Model name to use */ + model: string + + /** + * The number of dimensions the resulting output embeddings should have. + * Only supported in `text-embedding-3` and later models. + */ + dimensions?: number + + /** + * Timeout to use when making requests to OpenAI. + */ + timeout?: number + + /** + * The maximum number of documents to embed in a single request. This is + * limited by the OpenAI API to a maximum of 2048. + */ + batchSize?: number + + /** + * Whether to strip new lines from the input text. This is recommended by + * OpenAI for older models, but may not be suitable for all use cases. + * See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + */ + stripNewLines?: boolean + + signal?: AbortSignal +} + +/** + * Class for generating embeddings using the OpenAI API. Extends the + * Embeddings class and implements OpenAIEmbeddingsParams and + * AzureOpenAIInput. + * @example + * ```typescript + * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text + * const model = new OpenAIEmbeddings(); + * const res = await model.embedQuery( + * "What would be a good company name for a company that makes colorful socks?", + * ); + * console.log({ res }); + * + * ``` + */ +export class OAIEmbedding + extends Embeddings + implements OpenAIEmbeddingsParams { + modelName = "text-embedding-ada-002" + + model = "text-embedding-ada-002" + + batchSize = 512 + + // TODO: Update to `false` on next minor release (see: https://github.com/langchain-ai/langchainjs/pull/3612) + stripNewLines = true + + /** + * The number of dimensions the resulting output embeddings should have. + * Only supported in `text-embedding-3` and later models. + */ + dimensions?: number + + timeout?: number + + azureOpenAIApiVersion?: string + + azureOpenAIApiKey?: string + + azureADTokenProvider?: () => Promise + + azureOpenAIApiInstanceName?: string + + azureOpenAIApiDeploymentName?: string + + azureOpenAIBasePath?: string + + organization?: string + + protected client: OpenAIClient + + protected clientConfig: ClientOptions + signal?: AbortSignal + + constructor( + fields?: Partial & { + verbose?: boolean + /** + * The OpenAI API key to use. + * Alias for `apiKey`. + */ + openAIApiKey?: string + /** The OpenAI API key to use. */ + apiKey?: string + configuration?: ClientOptions + }, + configuration?: ClientOptions & LegacyOpenAIInput + ) { + const fieldsWithDefaults = { maxConcurrency: 2, ...fields } + + super(fieldsWithDefaults) + + let apiKey = fieldsWithDefaults?.apiKey ?? fieldsWithDefaults?.openAIApiKey + + this.modelName = + fieldsWithDefaults?.model ?? fieldsWithDefaults?.modelName ?? this.model + this.model = this.modelName + this.batchSize = fieldsWithDefaults?.batchSize + this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines + this.timeout = fieldsWithDefaults?.timeout + this.dimensions = fieldsWithDefaults?.dimensions + + if (fields.signal) { + this.signal = fields.signal + } + + + this.clientConfig = { + apiKey, + organization: this.organization, + baseURL: configuration?.basePath, + dangerouslyAllowBrowser: true, + defaultHeaders: configuration?.baseOptions?.headers, + defaultQuery: configuration?.baseOptions?.params, + ...configuration, + ...fields?.configuration + } + } + + /** + * Method to generate embeddings for an array of documents. Splits the + * documents into batches and makes requests to the OpenAI API to generate + * embeddings. + * @param texts Array of documents to generate embeddings for. + * @returns Promise that resolves to a 2D array of embeddings for each document. + */ + async embedDocuments(texts: string[]): Promise { + const batches = chunkArray( + this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts, + this.batchSize + ) + + const batchRequests = batches.map((batch) => { + const params: OpenAIClient.EmbeddingCreateParams = { + model: this.model, + input: batch + } + if (this.dimensions) { + params.dimensions = this.dimensions + } + return this.embeddingWithRetry(params) + }) + const batchResponses = await Promise.all(batchRequests) + + const embeddings: number[][] = [] + for (let i = 0; i < batchResponses.length; i += 1) { + const batch = batches[i] + const { data: batchResponse } = batchResponses[i] + for (let j = 0; j < batch.length; j += 1) { + embeddings.push(batchResponse[j].embedding) + } + } + return embeddings + } + + /** + * Method to generate an embedding for a single document. Calls the + * embeddingWithRetry method with the document as the input. + * @param text Document to generate an embedding for. + * @returns Promise that resolves to an embedding for the document. + */ + async embedQuery(text: string): Promise { + const params: OpenAIClient.EmbeddingCreateParams = { + model: this.model, + input: this.stripNewLines ? text.replace(/\n/g, " ") : text + } + if (this.dimensions) { + params.dimensions = this.dimensions + } + const { data } = await this.embeddingWithRetry(params) + return data[0].embedding + } + + /** + * Private method to make a request to the OpenAI API to generate + * embeddings. Handles the retry logic and returns the response from the + * API. + * @param request Request to send to the OpenAI API. + * @returns Promise that resolves to the response from the API. + */ + protected async embeddingWithRetry( + request: OpenAIClient.EmbeddingCreateParams + ) { + const requestOptions: OpenAICoreRequestOptions = {} + if (this.azureOpenAIApiKey) { + requestOptions.headers = { + "api-key": this.azureOpenAIApiKey, + ...requestOptions.headers + } + requestOptions.query = { + "api-version": this.azureOpenAIApiVersion, + ...requestOptions.query + } + } + return this.caller.call(async () => { + try { + const res = await this.client.embeddings.create(request, { + ...requestOptions, + signal: this.signal + }) + return res + } catch (e) { + const error = wrapOpenAIClientError(e) + throw error + } + }) + } +} diff --git a/src/models/types.ts b/src/models/types.ts new file mode 100644 index 0000000..0fc96f6 --- /dev/null +++ b/src/models/types.ts @@ -0,0 +1,26 @@ +export type OpenAICoreRequestOptions< + Req extends object = Record +> = { + path?: string; + query?: Req | undefined; + body?: Req | undefined; + headers?: Record | undefined; + + maxRetries?: number; + stream?: boolean | undefined; + timeout?: number; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + httpAgent?: any; + signal?: AbortSignal | undefined | null; + idempotencyKey?: string; +}; + +export interface LegacyOpenAIInput { + /** @deprecated Use baseURL instead */ + basePath?: string; + /** @deprecated Use defaultHeaders and defaultQuery instead */ + baseOptions?: { + headers?: Record; + params?: Record; + }; +} diff --git a/src/models/utils/openai.ts b/src/models/utils/openai.ts new file mode 100644 index 0000000..22ecb56 --- /dev/null +++ b/src/models/utils/openai.ts @@ -0,0 +1,70 @@ +import { + APIConnectionTimeoutError, + APIUserAbortError, + OpenAI as OpenAIClient, + } from "openai"; + import { zodToJsonSchema } from "zod-to-json-schema"; + import type { StructuredToolInterface } from "@langchain/core/tools"; + import { + convertToOpenAIFunction, + convertToOpenAITool, + } from "@langchain/core/utils/function_calling"; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + export function wrapOpenAIClientError(e: any) { + let error; + if (e.constructor.name === APIConnectionTimeoutError.name) { + error = new Error(e.message); + error.name = "TimeoutError"; + } else if (e.constructor.name === APIUserAbortError.name) { + error = new Error(e.message); + error.name = "AbortError"; + } else { + error = e; + } + return error; + } + + export { + convertToOpenAIFunction as formatToOpenAIFunction, + convertToOpenAITool as formatToOpenAITool, + }; + + export function formatToOpenAIAssistantTool(tool: StructuredToolInterface) { + return { + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: zodToJsonSchema(tool.schema), + }, + }; + } + + export type OpenAIToolChoice = + | OpenAIClient.ChatCompletionToolChoiceOption + | "any" + | string; + + export function formatToOpenAIToolChoice( + toolChoice?: OpenAIToolChoice + ): OpenAIClient.ChatCompletionToolChoiceOption | undefined { + if (!toolChoice) { + return undefined; + } else if (toolChoice === "any" || toolChoice === "required") { + return "required"; + } else if (toolChoice === "auto") { + return "auto"; + } else if (toolChoice === "none") { + return "none"; + } else if (typeof toolChoice === "string") { + return { + type: "function", + function: { + name: toolChoice, + }, + }; + } else { + return toolChoice; + } + } \ No newline at end of file