From ca26e059eb12ecda65d4da8ae1f15f0b0a12e794 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 16 Nov 2024 19:33:51 +0530 Subject: [PATCH] feat: Improve memory embedding and vector store handling This commit includes the following improvements: - Update the `memoryEmbedding` function to use the `PAMemoryVectorStore` instead of the generic `MemoryVectorStore`. This ensures that the vector store is specifically designed for the Page Assist application. - Modify the `useMessage` hook to use the `PAMemoryVectorStore` type for the `keepTrackOfEmbedding` state. - Update the `rerankDocs` function to accept the `EmbeddingsInterface` type instead of the concrete `Embeddings` class. - Add a new `PAMemoryVectorStore` class that extends the `VectorStore` base class and provides a custom in-memory implementation for the Page Assist application. These changes improve the handling of memory embeddings and vector stores, ensuring better compatibility and performance within the Page Assist application. --- src/hooks/useMessage.tsx | 8 ++- src/libs/PAMemoryVectorStore.ts | 104 ++++++++++++++++++++++++++++++ src/libs/PageAssistVectorStore.ts | 2 - src/utils/memory-embeddings.ts | 8 +-- src/utils/rerank.ts | 5 +- 5 files changed, 116 insertions(+), 11 deletions(-) create mode 100644 src/libs/PAMemoryVectorStore.ts diff --git a/src/hooks/useMessage.tsx b/src/hooks/useMessage.tsx index d06102d..6fb2175 100644 --- a/src/hooks/useMessage.tsx +++ b/src/hooks/useMessage.tsx @@ -34,6 +34,8 @@ import { pageAssistModel } from "@/models" import { getPrompt } from "@/services/application" import { humanMessageFormatter } from "@/utils/human-message" import { pageAssistEmbeddingModel } from "@/models/embedding" +import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore" +import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore" export const useMessage = () => { const { @@ -90,7 +92,7 @@ export const useMessage = () => { ) const [keepTrackOfEmbedding, setKeepTrackOfEmbedding] = 
React.useState<{ - [key: string]: MemoryVectorStore + [key: string]: PAMemoryVectorStore }>({}) const clearChat = () => { @@ -177,7 +179,7 @@ export const useMessage = () => { let embedURL: string, embedHTML: string, embedType: string let embedPDF: { content: string; page: number }[] = [] - let isAlreadyExistEmbedding: MemoryVectorStore + let isAlreadyExistEmbedding: PAMemoryVectorStore const { content: html, url: websiteUrl, @@ -212,7 +214,7 @@ export const useMessage = () => { currentChatModelSettings?.keepAlive ?? userDefaultModelSettings?.keepAlive }) - let vectorstore: MemoryVectorStore + let vectorstore: PAMemoryVectorStore try { if (isAlreadyExistEmbedding) { diff --git a/src/libs/PAMemoryVectorStore.ts b/src/libs/PAMemoryVectorStore.ts new file mode 100644 index 0000000..3c30083 --- /dev/null +++ b/src/libs/PAMemoryVectorStore.ts @@ -0,0 +1,104 @@ + +import { similarity as ml_distance_similarity } from "ml-distance" +import { VectorStore } from "@langchain/core/vectorstores" +import type { EmbeddingsInterface } from "@langchain/core/embeddings" +import { Document, DocumentInterface } from "@langchain/core/documents" +import { rerankDocs } from "../utils/rerank" + +interface MemoryVector { + content: string + embedding: number[] + metadata: Record +} + +interface MemoryVectorStoreArgs { + similarity?: typeof ml_distance_similarity.cosine +} + +export class PAMemoryVectorStore extends VectorStore { + + + declare FilterType: (doc: Document) => boolean + + private memoryVectors: MemoryVector[] = [] + private similarity: typeof ml_distance_similarity.cosine + + constructor(embeddings: EmbeddingsInterface, args?: MemoryVectorStoreArgs) { + super(embeddings, args) + this.similarity = args?.similarity ?? 
ml_distance_similarity.cosine + } + + _vectorstoreType(): string { + return "memory" + } + + async addVectors(vectors: number[][], documents: DocumentInterface[], options?: { [x: string]: any }): Promise { + const memoryVectors = documents.map((doc, index) => ({ + content: doc.pageContent, + embedding: vectors[index], + metadata: doc.metadata + })) + + this.memoryVectors.push(...memoryVectors) + } + similaritySearchVectorWithScore(query: number[], k: number, filter?: this["FilterType"]): Promise<[DocumentInterface, number][]> { + throw new Error("Method not implemented.") + } + + async addDocuments(documents: Document[]): Promise { + const texts = documents.map((doc) => doc.pageContent) + const embeddings = await this.embeddings.embedDocuments(texts) + await this.addVectors(embeddings, documents) + } + + async similaritySearch(query: string, k = 4): Promise { + const queryEmbedding = await this.embeddings.embedQuery(query) + + const similarities = this.memoryVectors.map((vector) => ({ + similarity: this.similarity(queryEmbedding, vector.embedding), + document: vector + })) + + similarities.sort((a, b) => b.similarity - a.similarity) + const topK = similarities.slice(0, k) + + const docs = topK.map(({ document }) => + new Document({ + pageContent: document.content, + metadata: document.metadata + }) + ) + + return docs + } + + async similaritySearchWithScore(query: string, k = 4): Promise<[Document, number][]> { + const queryEmbedding = await this.embeddings.embedQuery(query) + + const similarities = this.memoryVectors.map((vector) => ({ + similarity: this.similarity(queryEmbedding, vector.embedding), + document: vector + })) + + similarities.sort((a, b) => b.similarity - a.similarity) + const topK = similarities.slice(0, k) + + return topK.map(({ document, similarity }) => [ + new Document({ + pageContent: document.content, + metadata: document.metadata + }), + similarity + ]) + } + + static async fromDocuments( + docs: Document[], + embeddings: 
EmbeddingsInterface, + args?: MemoryVectorStoreArgs + ): Promise { + const store = new PAMemoryVectorStore(embeddings, args) + await store.addDocuments(docs) + return store + } +} diff --git a/src/libs/PageAssistVectorStore.ts b/src/libs/PageAssistVectorStore.ts index fd1d4f6..d52425e 100644 --- a/src/libs/PageAssistVectorStore.ts +++ b/src/libs/PageAssistVectorStore.ts @@ -3,8 +3,6 @@ import { VectorStore } from "@langchain/core/vectorstores" import type { EmbeddingsInterface } from "@langchain/core/embeddings" import { Document } from "@langchain/core/documents" import { getVector, insertVector } from "@/db/vector" -import { cp } from "fs" - /** * Interface representing a vector in memory. It includes the content * (text), the corresponding embedding (vector), and any associated diff --git a/src/utils/memory-embeddings.ts b/src/utils/memory-embeddings.ts index 99d2be4..9cb1b82 100644 --- a/src/utils/memory-embeddings.ts +++ b/src/utils/memory-embeddings.ts @@ -1,12 +1,12 @@ import { PageAssistHtmlLoader } from "~/loader/html" import { RecursiveCharacterTextSplitter } from "langchain/text_splitter" -import { MemoryVectorStore } from "langchain/vectorstores/memory" import { defaultEmbeddingChunkOverlap, defaultEmbeddingChunkSize } from "@/services/ollama" import { PageAssistPDFLoader } from "@/loader/pdf" +import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore" export const getLoader = ({ html, @@ -46,10 +46,10 @@ export const memoryEmbedding = async ({ html: string type: string pdf: { content: string; page: number }[] - keepTrackOfEmbedding: Record + keepTrackOfEmbedding: Record ollamaEmbedding: any setIsEmbedding: (value: boolean) => void - setKeepTrackOfEmbedding: (value: Record) => void + setKeepTrackOfEmbedding: (value: Record) => void }) => { setIsEmbedding(true) const loader = getLoader({ html, pdf, type, url }) @@ -63,7 +63,7 @@ export const memoryEmbedding = async ({ const chunks = await textSplitter.splitDocuments(docs) - const store = new 
MemoryVectorStore(ollamaEmbedding) + const store = new PAMemoryVectorStore(ollamaEmbedding) await store.addDocuments(chunks) setKeepTrackOfEmbedding({ diff --git a/src/utils/rerank.ts b/src/utils/rerank.ts index fcf30cf..e24f7c8 100644 --- a/src/utils/rerank.ts +++ b/src/utils/rerank.ts @@ -1,4 +1,4 @@ -import type { Embeddings } from "@langchain/core/embeddings" +import type { EmbeddingsInterface } from "@langchain/core/embeddings" import type { Document } from "@langchain/core/documents" import * as ml_distance from "ml-distance" @@ -9,7 +9,7 @@ export const rerankDocs = async ({ }: { query: string docs: Document[] - embedding: Embeddings + embedding: EmbeddingsInterface }) => { if (docs.length === 0) { return docs @@ -34,6 +34,7 @@ export const rerankDocs = async ({ } }) + console.log("similarity", similarity) const sortedDocs = similarity .sort((a, b) => b.similarity - a.similarity) .filter((sim) => sim.similarity > 0.5)