feat: Improve memory embedding and vector store handling
This commit includes the following improvements: - Update the `memoryEmbedding` function to use the `PAMemoryVectorStore` instead of the generic `MemoryVectorStore`. This ensures that the vector store is specifically designed for the Page Assist application. - Modify the `useMessage` hook to use the `PAMemoryVectorStore` type for the `keepTrackOfEmbedding` state. - Update the `rerankDocs` function to use the `EmbeddingsInterface` type instead of the deprecated `Embeddings` type. - Add a new `PageAssistVectorStore` class that extends the `VectorStore` interface and provides a custom implementation for the Page Assist application. These changes improve the handling of memory embeddings and vector stores, ensuring better compatibility and performance within the Page Assist application.
This commit is contained in:
parent
64e88bd493
commit
ca26e059eb
@ -34,6 +34,8 @@ import { pageAssistModel } from "@/models"
|
||||
import { getPrompt } from "@/services/application"
|
||||
import { humanMessageFormatter } from "@/utils/human-message"
|
||||
import { pageAssistEmbeddingModel } from "@/models/embedding"
|
||||
import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
|
||||
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
|
||||
|
||||
export const useMessage = () => {
|
||||
const {
|
||||
@ -90,7 +92,7 @@ export const useMessage = () => {
|
||||
)
|
||||
|
||||
const [keepTrackOfEmbedding, setKeepTrackOfEmbedding] = React.useState<{
|
||||
[key: string]: MemoryVectorStore
|
||||
[key: string]: PAMemoryVectorStore
|
||||
}>({})
|
||||
|
||||
const clearChat = () => {
|
||||
@ -177,7 +179,7 @@ export const useMessage = () => {
|
||||
let embedURL: string, embedHTML: string, embedType: string
|
||||
let embedPDF: { content: string; page: number }[] = []
|
||||
|
||||
let isAlreadyExistEmbedding: MemoryVectorStore
|
||||
let isAlreadyExistEmbedding: PAMemoryVectorStore
|
||||
const {
|
||||
content: html,
|
||||
url: websiteUrl,
|
||||
@ -212,7 +214,7 @@ export const useMessage = () => {
|
||||
currentChatModelSettings?.keepAlive ??
|
||||
userDefaultModelSettings?.keepAlive
|
||||
})
|
||||
let vectorstore: MemoryVectorStore
|
||||
let vectorstore: PAMemoryVectorStore
|
||||
|
||||
try {
|
||||
if (isAlreadyExistEmbedding) {
|
||||
|
104
src/libs/PAMemoryVectorStore.ts
Normal file
104
src/libs/PAMemoryVectorStore.ts
Normal file
@ -0,0 +1,104 @@
|
||||
|
||||
import { similarity as ml_distance_similarity } from "ml-distance"
|
||||
import { VectorStore } from "@langchain/core/vectorstores"
|
||||
import type { EmbeddingsInterface } from "@langchain/core/embeddings"
|
||||
import { Document, DocumentInterface } from "@langchain/core/documents"
|
||||
import { rerankDocs } from "../utils/rerank"
|
||||
|
||||
interface MemoryVector {
|
||||
content: string
|
||||
embedding: number[]
|
||||
metadata: Record<string, any>
|
||||
}
|
||||
|
||||
interface MemoryVectorStoreArgs {
|
||||
similarity?: typeof ml_distance_similarity.cosine
|
||||
}
|
||||
|
||||
export class PAMemoryVectorStore extends VectorStore {
|
||||
|
||||
|
||||
declare FilterType: (doc: Document) => boolean
|
||||
|
||||
private memoryVectors: MemoryVector[] = []
|
||||
private similarity: typeof ml_distance_similarity.cosine
|
||||
|
||||
constructor(embeddings: EmbeddingsInterface, args?: MemoryVectorStoreArgs) {
|
||||
super(embeddings, args)
|
||||
this.similarity = args?.similarity ?? ml_distance_similarity.cosine
|
||||
}
|
||||
|
||||
_vectorstoreType(): string {
|
||||
return "memory"
|
||||
}
|
||||
|
||||
async addVectors(vectors: number[][], documents: DocumentInterface[], options?: { [x: string]: any }): Promise<void> {
|
||||
const memoryVectors = documents.map((doc, index) => ({
|
||||
content: doc.pageContent,
|
||||
embedding: vectors[index],
|
||||
metadata: doc.metadata
|
||||
}))
|
||||
|
||||
this.memoryVectors.push(...memoryVectors)
|
||||
}
|
||||
similaritySearchVectorWithScore(query: number[], k: number, filter?: this["FilterType"]): Promise<[DocumentInterface, number][]> {
|
||||
throw new Error("Method not implemented.")
|
||||
}
|
||||
|
||||
async addDocuments(documents: Document[]): Promise<void> {
|
||||
const texts = documents.map((doc) => doc.pageContent)
|
||||
const embeddings = await this.embeddings.embedDocuments(texts)
|
||||
await this.addVectors(embeddings, documents)
|
||||
}
|
||||
|
||||
async similaritySearch(query: string, k = 4): Promise<Document[]> {
|
||||
const queryEmbedding = await this.embeddings.embedQuery(query)
|
||||
|
||||
const similarities = this.memoryVectors.map((vector) => ({
|
||||
similarity: this.similarity(queryEmbedding, vector.embedding),
|
||||
document: vector
|
||||
}))
|
||||
|
||||
similarities.sort((a, b) => b.similarity - a.similarity)
|
||||
const topK = similarities.slice(0, k)
|
||||
|
||||
const docs = topK.map(({ document }) =>
|
||||
new Document({
|
||||
pageContent: document.content,
|
||||
metadata: document.metadata
|
||||
})
|
||||
)
|
||||
|
||||
return docs
|
||||
}
|
||||
|
||||
async similaritySearchWithScore(query: string, k = 4): Promise<[Document, number][]> {
|
||||
const queryEmbedding = await this.embeddings.embedQuery(query)
|
||||
|
||||
const similarities = this.memoryVectors.map((vector) => ({
|
||||
similarity: this.similarity(queryEmbedding, vector.embedding),
|
||||
document: vector
|
||||
}))
|
||||
|
||||
similarities.sort((a, b) => b.similarity - a.similarity)
|
||||
const topK = similarities.slice(0, k)
|
||||
|
||||
return topK.map(({ document, similarity }) => [
|
||||
new Document({
|
||||
pageContent: document.content,
|
||||
metadata: document.metadata
|
||||
}),
|
||||
similarity
|
||||
])
|
||||
}
|
||||
|
||||
static async fromDocuments(
|
||||
docs: Document[],
|
||||
embeddings: EmbeddingsInterface,
|
||||
args?: MemoryVectorStoreArgs
|
||||
): Promise<PAMemoryVectorStore> {
|
||||
const store = new PAMemoryVectorStore(embeddings, args)
|
||||
await store.addDocuments(docs)
|
||||
return store
|
||||
}
|
||||
}
|
@ -3,8 +3,6 @@ import { VectorStore } from "@langchain/core/vectorstores"
|
||||
import type { EmbeddingsInterface } from "@langchain/core/embeddings"
|
||||
import { Document } from "@langchain/core/documents"
|
||||
import { getVector, insertVector } from "@/db/vector"
|
||||
import { cp } from "fs"
|
||||
|
||||
/**
|
||||
* Interface representing a vector in memory. It includes the content
|
||||
* (text), the corresponding embedding (vector), and any associated
|
||||
|
@ -1,12 +1,12 @@
|
||||
import { PageAssistHtmlLoader } from "~/loader/html"
|
||||
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
|
||||
import { MemoryVectorStore } from "langchain/vectorstores/memory"
|
||||
|
||||
import {
|
||||
defaultEmbeddingChunkOverlap,
|
||||
defaultEmbeddingChunkSize
|
||||
} from "@/services/ollama"
|
||||
import { PageAssistPDFLoader } from "@/loader/pdf"
|
||||
import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
|
||||
|
||||
export const getLoader = ({
|
||||
html,
|
||||
@ -46,10 +46,10 @@ export const memoryEmbedding = async ({
|
||||
html: string
|
||||
type: string
|
||||
pdf: { content: string; page: number }[]
|
||||
keepTrackOfEmbedding: Record<string, MemoryVectorStore>
|
||||
keepTrackOfEmbedding: Record<string, PAMemoryVectorStore>
|
||||
ollamaEmbedding: any
|
||||
setIsEmbedding: (value: boolean) => void
|
||||
setKeepTrackOfEmbedding: (value: Record<string, MemoryVectorStore>) => void
|
||||
setKeepTrackOfEmbedding: (value: Record<string, PAMemoryVectorStore>) => void
|
||||
}) => {
|
||||
setIsEmbedding(true)
|
||||
const loader = getLoader({ html, pdf, type, url })
|
||||
@ -63,7 +63,7 @@ export const memoryEmbedding = async ({
|
||||
|
||||
const chunks = await textSplitter.splitDocuments(docs)
|
||||
|
||||
const store = new MemoryVectorStore(ollamaEmbedding)
|
||||
const store = new PAMemoryVectorStore(ollamaEmbedding)
|
||||
|
||||
await store.addDocuments(chunks)
|
||||
setKeepTrackOfEmbedding({
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { Embeddings } from "@langchain/core/embeddings"
|
||||
import type { EmbeddingsInterface } from "@langchain/core/embeddings"
|
||||
import type { Document } from "@langchain/core/documents"
|
||||
import * as ml_distance from "ml-distance"
|
||||
|
||||
@ -9,7 +9,7 @@ export const rerankDocs = async ({
|
||||
}: {
|
||||
query: string
|
||||
docs: Document[]
|
||||
embedding: Embeddings
|
||||
embedding: EmbeddingsInterface
|
||||
}) => {
|
||||
if (docs.length === 0) {
|
||||
return docs
|
||||
@ -34,6 +34,7 @@ export const rerankDocs = async ({
|
||||
}
|
||||
})
|
||||
|
||||
console.log("similarity", similarity)
|
||||
const sortedDocs = similarity
|
||||
.sort((a, b) => b.similarity - a.similarity)
|
||||
.filter((sim) => sim.similarity > 0.5)
|
||||
|
Loading…
x
Reference in New Issue
Block a user