feat: Improve memory embedding and vector store handling

This commit includes the following improvements:

- Update the `memoryEmbedding` function to use the new `PAMemoryVectorStore` instead of LangChain's generic `MemoryVectorStore`, so in-memory embeddings go through a vector store tailored to the Page Assist application.
- Modify the `useMessage` hook to use the `PAMemoryVectorStore` type for the `keepTrackOfEmbedding` state.
- Update the `rerankDocs` function to use the `EmbeddingsInterface` type instead of the deprecated `Embeddings` type.
- Add a new `PAMemoryVectorStore` class that extends the LangChain `VectorStore` base class and provides a custom in-memory vector store implementation for the Page Assist application.

Together, these changes route the in-memory embedding path through Page Assist's own vector store and align the embedding-related types with LangChain's current `EmbeddingsInterface`, improving compatibility within the Page Assist application.
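For orientation, here is a minimal sketch of the call pattern this swap enables. `PAMemoryVectorStore` keeps the `addDocuments`/`similaritySearch` surface of the LangChain store, so the swap is mechanical; the query string and `k` below are illustrative only, not code from this commit.

```ts
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
import type { Document } from "@langchain/core/documents"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"

// Sketch only: build an in-memory store for already-split page chunks and
// query it. Any EmbeddingsInterface implementation can be passed in.
async function embedAndSearch(embeddings: EmbeddingsInterface, chunks: Document[]) {
  const store = new PAMemoryVectorStore(embeddings)
  await store.addDocuments(chunks) // embeds the chunks and keeps the vectors in memory
  return store.similaritySearch("what is this page about?", 4) // top-4 chunks by cosine similarity
}
```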
n4ze3m 2024-11-16 19:33:51 +05:30
parent 64e88bd493
commit ca26e059eb
5 changed files with 116 additions and 11 deletions

View File

@@ -34,6 +34,8 @@ import { pageAssistModel } from "@/models"
import { getPrompt } from "@/services/application"
import { humanMessageFormatter } from "@/utils/human-message"
import { pageAssistEmbeddingModel } from "@/models/embedding"
+ import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
+ import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
export const useMessage = () => {
const {
@@ -90,7 +92,7 @@ export const useMessage = () => {
)
const [keepTrackOfEmbedding, setKeepTrackOfEmbedding] = React.useState<{
- [key: string]: MemoryVectorStore
+ [key: string]: PAMemoryVectorStore
}>({})
const clearChat = () => {
@@ -177,7 +179,7 @@ export const useMessage = () => {
let embedURL: string, embedHTML: string, embedType: string
let embedPDF: { content: string; page: number }[] = []
- let isAlreadyExistEmbedding: MemoryVectorStore
+ let isAlreadyExistEmbedding: PAMemoryVectorStore
const {
content: html,
url: websiteUrl,
@@ -212,7 +214,7 @@ export const useMessage = () => {
currentChatModelSettings?.keepAlive ??
userDefaultModelSettings?.keepAlive
})
- let vectorstore: MemoryVectorStore
+ let vectorstore: PAMemoryVectorStore
try {
if (isAlreadyExistEmbedding) {
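The hunk above ends at the cache check. A hedged sketch of how the per-URL cache of stores might be consulted follows; keying the record by the page URL and the helper name are assumptions, not code from this commit.

```ts
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"

// Sketch only: check the hook state shown above before embedding the page
// again. Keying the record by the page URL is an assumption.
function findExistingEmbedding(
  keepTrackOfEmbedding: Record<string, PAMemoryVectorStore>,
  url: string
): PAMemoryVectorStore | undefined {
  return keepTrackOfEmbedding[url] // reuse the in-memory vectors if present
}
```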

View File

@@ -0,0 +1,104 @@
import { similarity as ml_distance_similarity } from "ml-distance"
import { VectorStore } from "@langchain/core/vectorstores"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"
import { Document, DocumentInterface } from "@langchain/core/documents"
import { rerankDocs } from "../utils/rerank"
interface MemoryVector {
content: string
embedding: number[]
metadata: Record<string, any>
}
interface MemoryVectorStoreArgs {
similarity?: typeof ml_distance_similarity.cosine
}
export class PAMemoryVectorStore extends VectorStore {
declare FilterType: (doc: Document) => boolean
private memoryVectors: MemoryVector[] = []
private similarity: typeof ml_distance_similarity.cosine
constructor(embeddings: EmbeddingsInterface, args?: MemoryVectorStoreArgs) {
super(embeddings, args)
this.similarity = args?.similarity ?? ml_distance_similarity.cosine
}
_vectorstoreType(): string {
return "memory"
}
async addVectors(vectors: number[][], documents: DocumentInterface[], options?: { [x: string]: any }): Promise<void> {
const memoryVectors = documents.map((doc, index) => ({
content: doc.pageContent,
embedding: vectors[index],
metadata: doc.metadata
}))
this.memoryVectors.push(...memoryVectors)
}
similaritySearchVectorWithScore(query: number[], k: number, filter?: this["FilterType"]): Promise<[DocumentInterface, number][]> {
throw new Error("Method not implemented.")
}
async addDocuments(documents: Document[]): Promise<void> {
const texts = documents.map((doc) => doc.pageContent)
const embeddings = await this.embeddings.embedDocuments(texts)
await this.addVectors(embeddings, documents)
}
async similaritySearch(query: string, k = 4): Promise<Document[]> {
const queryEmbedding = await this.embeddings.embedQuery(query)
const similarities = this.memoryVectors.map((vector) => ({
similarity: this.similarity(queryEmbedding, vector.embedding),
document: vector
}))
similarities.sort((a, b) => b.similarity - a.similarity)
const topK = similarities.slice(0, k)
const docs = topK.map(({ document }) =>
new Document({
pageContent: document.content,
metadata: document.metadata
})
)
return docs
}
async similaritySearchWithScore(query: string, k = 4): Promise<[Document, number][]> {
const queryEmbedding = await this.embeddings.embedQuery(query)
const similarities = this.memoryVectors.map((vector) => ({
similarity: this.similarity(queryEmbedding, vector.embedding),
document: vector
}))
similarities.sort((a, b) => b.similarity - a.similarity)
const topK = similarities.slice(0, k)
return topK.map(({ document, similarity }) => [
new Document({
pageContent: document.content,
metadata: document.metadata
}),
similarity
])
}
static async fromDocuments(
docs: Document[],
embeddings: EmbeddingsInterface,
args?: MemoryVectorStoreArgs
): Promise<PAMemoryVectorStore> {
const store = new PAMemoryVectorStore(embeddings, args)
await store.addDocuments(docs)
return store
}
}
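A hedged usage sketch for the class above; the documents, query, and `k` are illustrative, and any `EmbeddingsInterface` implementation (such as the Ollama embedding model) can be supplied.

```ts
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
import { Document } from "@langchain/core/documents"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"

// Sketch only: build a store from documents, then retrieve scored matches.
async function demo(embeddings: EmbeddingsInterface) {
  const docs = [
    new Document({ pageContent: "Page Assist is a browser extension.", metadata: { page: 1 } }),
    new Document({ pageContent: "It lets you chat with the current web page.", metadata: { page: 2 } })
  ]

  const store = await PAMemoryVectorStore.fromDocuments(docs, embeddings)

  // Top-2 chunks with their cosine-similarity scores.
  const scored = await store.similaritySearchWithScore("what does the extension do?", 2)
  for (const [doc, score] of scored) {
    console.log(score.toFixed(3), doc.pageContent)
  }
}
```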

View File

@@ -3,8 +3,6 @@ import { VectorStore } from "@langchain/core/vectorstores"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"
import { Document } from "@langchain/core/documents"
import { getVector, insertVector } from "@/db/vector"
- import { cp } from "fs"
/**
* Interface representing a vector in memory. It includes the content
* (text), the corresponding embedding (vector), and any associated

View File

@@ -1,12 +1,12 @@
import { PageAssistHtmlLoader } from "~/loader/html"
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
- import { MemoryVectorStore } from "langchain/vectorstores/memory"
import {
defaultEmbeddingChunkOverlap,
defaultEmbeddingChunkSize
} from "@/services/ollama"
import { PageAssistPDFLoader } from "@/loader/pdf"
+ import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
export const getLoader = ({
html,
@@ -46,10 +46,10 @@ export const memoryEmbedding = async ({
html: string
type: string
pdf: { content: string; page: number }[]
- keepTrackOfEmbedding: Record<string, MemoryVectorStore>
+ keepTrackOfEmbedding: Record<string, PAMemoryVectorStore>
ollamaEmbedding: any
setIsEmbedding: (value: boolean) => void
- setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void
+ setKeepTrackOfEmbedding: (value: Record<string, PAMemoryVectorStore>) => void
}) => {
setIsEmbedding(true)
const loader = getLoader({ html, pdf, type, url })
@@ -63,7 +63,7 @@ export const memoryEmbedding = async ({
const chunks = await textSplitter.splitDocuments(docs)
- const store = new MemoryVectorStore(ollamaEmbedding)
+ const store = new PAMemoryVectorStore(ollamaEmbedding)
await store.addDocuments(chunks)
setKeepTrackOfEmbedding({
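The hunk is cut off at the state update. Below is a hedged sketch of the tail of this flow; the numeric chunk sizes are stand-ins for `defaultEmbeddingChunkSize`/`defaultEmbeddingChunkOverlap`, and keying the record by URL is an assumption rather than something shown in the hunk.

```ts
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
import type { Document } from "@langchain/core/documents"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"

// Sketch only: split the loaded page, embed the chunks into a fresh store,
// then hand the store back to the hook state so the page is not re-embedded.
async function embedAndCache(
  docs: Document[],
  url: string,
  ollamaEmbedding: EmbeddingsInterface,
  keepTrackOfEmbedding: Record<string, PAMemoryVectorStore>,
  setKeepTrackOfEmbedding: (value: Record<string, PAMemoryVectorStore>) => void,
  setIsEmbedding: (value: boolean) => void
): Promise<PAMemoryVectorStore> {
  const textSplitter = new RecursiveCharacterTextSplitter({
    chunkSize: 1000,  // stand-in for defaultEmbeddingChunkSize
    chunkOverlap: 200 // stand-in for defaultEmbeddingChunkOverlap
  })
  const chunks = await textSplitter.splitDocuments(docs)

  const store = new PAMemoryVectorStore(ollamaEmbedding)
  await store.addDocuments(chunks)

  setKeepTrackOfEmbedding({ ...keepTrackOfEmbedding, [url]: store })
  setIsEmbedding(false)
  return store
}
```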

View File

@@ -1,4 +1,4 @@
- import type { Embeddings } from "@langchain/core/embeddings"
+ import type { EmbeddingsInterface } from "@langchain/core/embeddings"
import type { Document } from "@langchain/core/documents"
import * as ml_distance from "ml-distance"
@@ -9,7 +9,7 @@ export const rerankDocs = async ({
}: {
query: string
docs: Document[]
- embedding: Embeddings
+ embedding: EmbeddingsInterface
}) => {
if (docs.length === 0) {
return docs
@@ -34,6 +34,7 @@ export const rerankDocs = async ({
}
})
console.log("similarity", similarity)
const sortedDocs = similarity
.sort((a, b) => b.similarity - a.similarity)
.filter((sim) => sim.similarity > 0.5)
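For reference, here is a self-contained sketch of this kind of reranker. It is named differently to avoid claiming it is the actual `rerankDocs` body; the 0.5 cut-off mirrors the filter above, everything else is illustrative.

```ts
import { similarity as ml_distance_similarity } from "ml-distance"
import type { Document } from "@langchain/core/documents"
import type { EmbeddingsInterface } from "@langchain/core/embeddings"

// Sketch only: embed the query and the documents, score each document by
// cosine similarity, then keep the best matches above a fixed threshold.
export const rerankBySimilarity = async ({
  query,
  docs,
  embedding
}: {
  query: string
  docs: Document[]
  embedding: EmbeddingsInterface
}): Promise<Document[]> => {
  if (docs.length === 0) {
    return docs
  }

  const [queryEmbedding, docEmbeddings] = await Promise.all([
    embedding.embedQuery(query),
    embedding.embedDocuments(docs.map((doc) => doc.pageContent))
  ])

  return docs
    .map((doc, i) => ({
      doc,
      similarity: ml_distance_similarity.cosine(queryEmbedding, docEmbeddings[i])
    }))
    .sort((a, b) => b.similarity - a.similarity)
    .filter((scored) => scored.similarity > 0.5) // same cut-off as the hunk above
    .map((scored) => scored.doc)
}
```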