diff --git a/packages/components/nodes/memory/MongoDBMemory/MongoDBMemory.ts b/packages/components/nodes/memory/MongoDBMemory/MongoDBMemory.ts index 9bb6bbf8..44832466 100644 --- a/packages/components/nodes/memory/MongoDBMemory/MongoDBMemory.ts +++ b/packages/components/nodes/memory/MongoDBMemory/MongoDBMemory.ts @@ -1,5 +1,4 @@ -import { MongoClient, Collection, Document } from 'mongodb' -import { MongoDBChatMessageHistory } from '@langchain/mongodb' +import { MongoClient } from 'mongodb' import { BufferMemory, BufferMemoryInput } from 'langchain/memory' import { mapStoredMessageToChatMessage, AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages' import { @@ -7,28 +6,13 @@ import { getBaseClasses, getCredentialData, getCredentialParam, + getVersion, mapChatMessageToBaseMessage } from '../../../src/utils' import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods, MessageType } from '../../../src/Interface' -let mongoClientSingleton: MongoClient -let mongoUrl: string +// TODO: Add ability to specify env variable and use singleton pattern (i.e initialize MongoDB on server and pass to component) -const getMongoClient = async (newMongoUrl: string) => { - if (!mongoClientSingleton) { - // if client does not exist - mongoClientSingleton = new MongoClient(newMongoUrl) - mongoUrl = newMongoUrl - return mongoClientSingleton - } else if (mongoClientSingleton && newMongoUrl !== mongoUrl) { - // if client exists but url changed - mongoClientSingleton.close() - mongoClientSingleton = new MongoClient(newMongoUrl) - mongoUrl = newMongoUrl - return mongoClientSingleton - } - return mongoClientSingleton -} class MongoDB_Memory implements INode { label: string name: string @@ -102,62 +86,43 @@ const initializeMongoDB = async (nodeData: INodeData, options: ICommonObject): P const credentialData = await getCredentialData(nodeData.credential ?? '', options) const mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData) - - const client = await getMongoClient(mongoDBConnectUrl) - const collection = client.db(databaseName).collection(collectionName) - - const mongoDBChatMessageHistory = new MongoDBChatMessageHistory({ - collection, - sessionId - }) - - // @ts-ignore - mongoDBChatMessageHistory.getMessages = async (): Promise => { - const document = await collection.findOne({ - sessionId: (mongoDBChatMessageHistory as any).sessionId - }) - const messages = document?.messages || [] - return messages.map(mapStoredMessageToChatMessage) - } - - // @ts-ignore - mongoDBChatMessageHistory.addMessage = async (message: BaseMessage): Promise => { - const messages = [message].map((msg) => msg.toDict()) - await collection.updateOne( - { sessionId: (mongoDBChatMessageHistory as any).sessionId }, - { - $push: { messages: { $each: messages } } - }, - { upsert: true } - ) - } - - mongoDBChatMessageHistory.clear = async (): Promise => { - await collection.deleteOne({ sessionId: (mongoDBChatMessageHistory as any).sessionId }) - } + const driverInfo = { name: 'Flowise', version: (await getVersion()).version } return new BufferMemoryExtended({ memoryKey: memoryKey ?? 'chat_history', - // @ts-ignore - chatHistory: mongoDBChatMessageHistory, sessionId, - collection + mongoConnection: { + databaseName, + collectionName, + mongoDBConnectUrl, + driverInfo + } }) } interface BufferMemoryExtendedInput { - collection: Collection sessionId: string + mongoConnection: { + databaseName: string + collectionName: string + mongoDBConnectUrl: string + driverInfo: { name: string; version: string } + } } class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { sessionId = '' - collection: Collection + mongoConnection: { + databaseName: string + collectionName: string + mongoDBConnectUrl: string + driverInfo: { name: string; version: string } + } constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) { super(fields) this.sessionId = fields.sessionId - this.collection = fields.collection + this.mongoConnection = fields.mongoConnection } async getChatMessages( @@ -165,20 +130,24 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { returnBaseMessages = false, prependMessages?: IMessage[] ): Promise { - if (!this.collection) return [] + const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo }) + const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName) const id = overrideSessionId ? overrideSessionId : this.sessionId - const document = await this.collection.findOne({ sessionId: id }) + const document = await collection.findOne({ sessionId: id }) const messages = document?.messages || [] const baseMessages = messages.map(mapStoredMessageToChatMessage) if (prependMessages?.length) { baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages))) } + + await client.close() return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages) } async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise { - if (!this.collection) return + const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo }) + const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName) const id = overrideSessionId ? overrideSessionId : this.sessionId const input = msgArray.find((msg) => msg.type === 'userMessage') @@ -187,7 +156,7 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { if (input) { const newInputMessage = new HumanMessage(input.text) const messageToAdd = [newInputMessage].map((msg) => msg.toDict()) - await this.collection.updateOne( + await collection.updateOne( { sessionId: id }, { $push: { messages: { $each: messageToAdd } } @@ -199,7 +168,7 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { if (output) { const newOutputMessage = new AIMessage(output.text) const messageToAdd = [newOutputMessage].map((msg) => msg.toDict()) - await this.collection.updateOne( + await collection.updateOne( { sessionId: id }, { $push: { messages: { $each: messageToAdd } } @@ -207,14 +176,19 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { { upsert: true } ) } + + await client.close() } async clearChatMessages(overrideSessionId = ''): Promise { - if (!this.collection) return + const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo }) + const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName) const id = overrideSessionId ? overrideSessionId : this.sessionId - await this.collection.deleteOne({ sessionId: id }) + await collection.deleteOne({ sessionId: id }) await this.clear() + + await client.close() } } diff --git a/packages/components/nodes/vectorstores/MongoDBAtlas/MongoDBAtlas.ts b/packages/components/nodes/vectorstores/MongoDBAtlas/MongoDBAtlas.ts index 81c6bc2d..785c6448 100644 --- a/packages/components/nodes/vectorstores/MongoDBAtlas/MongoDBAtlas.ts +++ b/packages/components/nodes/vectorstores/MongoDBAtlas/MongoDBAtlas.ts @@ -1,13 +1,12 @@ import { flatten } from 'lodash' -import { MongoClient } from 'mongodb' -import { MongoDBAtlasVectorSearch } from '@langchain/mongodb' import { Embeddings } from '@langchain/core/embeddings' import { Document } from '@langchain/core/documents' import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams, IndexingResult } from '../../../src/Interface' -import { getBaseClasses, getCredentialData, getCredentialParam, getVersion } from '../../../src/utils' +import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { addMMRInputParams, resolveVectorStoreOrRetriever } from '../VectorStoreUtils' -import { VectorStore } from '@langchain/core/vectorstores' +import { MongoDBAtlasVectorSearch } from './core' +// TODO: Add ability to specify env variable and use singleton pattern (i.e initialize MongoDB on server and pass to component) class MongoDBAtlas_VectorStores implements INode { label: string name: string @@ -142,20 +141,18 @@ class MongoDBAtlas_VectorStores implements INode { } } - const mongoClient = await getMongoClient(mongoDBConnectUrl) try { - const collection = mongoClient.db(databaseName).collection(collectionName) - if (!textKey || textKey === '') textKey = 'text' if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding' const mongoDBAtlasVectorSearch = new MongoDBAtlasVectorSearch(embeddings, { - collection, + connectionDetails: { mongoDBConnectUrl, databaseName, collectionName }, indexName, textKey, embeddingKey }) await mongoDBAtlasVectorSearch.addDocuments(finalDocs) + return { numAdded: finalDocs.length, addedDocs: finalDocs } } catch (e) { throw new Error(e) @@ -175,28 +172,25 @@ class MongoDBAtlas_VectorStores implements INode { let mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData) - const filter: MongoDBAtlasVectorSearch['FilterType'] = {} + const mongoDbFilter: MongoDBAtlasVectorSearch['FilterType'] = {} - const mongoClient = await getMongoClient(mongoDBConnectUrl) try { - const collection = mongoClient.db(databaseName).collection(collectionName) - if (!textKey || textKey === '') textKey = 'text' if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding' const vectorStore = new MongoDBAtlasVectorSearch(embeddings, { - collection, + connectionDetails: { mongoDBConnectUrl, databaseName, collectionName }, indexName, textKey, embeddingKey - }) as unknown as VectorStore + }) if (mongoMetadataFilter) { const metadataFilter = typeof mongoMetadataFilter === 'object' ? mongoMetadataFilter : JSON.parse(mongoMetadataFilter) for (const key in metadataFilter) { - filter.preFilter = { - ...filter.preFilter, + mongoDbFilter.preFilter = { + ...mongoDbFilter.preFilter, [key]: { $eq: metadataFilter[key] } @@ -204,31 +198,11 @@ class MongoDBAtlas_VectorStores implements INode { } } - return resolveVectorStoreOrRetriever(nodeData, vectorStore, filter) + return resolveVectorStoreOrRetriever(nodeData, vectorStore, mongoDbFilter) } catch (e) { throw new Error(e) } } } -let mongoClientSingleton: MongoClient -let mongoUrl: string - -const getMongoClient = async (newMongoUrl: string) => { - const driverInfo = { name: 'Flowise', version: (await getVersion()).version } - - if (!mongoClientSingleton) { - // if client does not exist - mongoClientSingleton = new MongoClient(newMongoUrl, { driverInfo }) - mongoUrl = newMongoUrl - return mongoClientSingleton - } else if (mongoClientSingleton && newMongoUrl !== mongoUrl) { - // if client exists but url changed - mongoClientSingleton.close() - mongoClientSingleton = new MongoClient(newMongoUrl, { driverInfo }) - mongoUrl = newMongoUrl - return mongoClientSingleton - } - return mongoClientSingleton -} module.exports = { nodeClass: MongoDBAtlas_VectorStores } diff --git a/packages/components/nodes/vectorstores/MongoDBAtlas/core.ts b/packages/components/nodes/vectorstores/MongoDBAtlas/core.ts new file mode 100644 index 00000000..eb26178e --- /dev/null +++ b/packages/components/nodes/vectorstores/MongoDBAtlas/core.ts @@ -0,0 +1,237 @@ +import { MongoClient, type Document as MongoDBDocument } from 'mongodb' +import { MaxMarginalRelevanceSearchOptions, VectorStore } from '@langchain/core/vectorstores' +import type { EmbeddingsInterface } from '@langchain/core/embeddings' +import { chunkArray } from '@langchain/core/utils/chunk_array' +import { Document } from '@langchain/core/documents' +import { maximalMarginalRelevance } from '@langchain/core/utils/math' +import { AsyncCaller, AsyncCallerParams } from '@langchain/core/utils/async_caller' +import { getVersion } from '../../../src/utils' + +export interface MongoDBAtlasVectorSearchLibArgs extends AsyncCallerParams { + readonly connectionDetails: { + readonly mongoDBConnectUrl: string + readonly databaseName: string + readonly collectionName: string + } + readonly indexName?: string + readonly textKey?: string + readonly embeddingKey?: string + readonly primaryKey?: string +} + +type MongoDBAtlasFilter = { + preFilter?: MongoDBDocument + postFilterPipeline?: MongoDBDocument[] + includeEmbeddings?: boolean +} & MongoDBDocument + +export class MongoDBAtlasVectorSearch extends VectorStore { + declare FilterType: MongoDBAtlasFilter + + private readonly connectionDetails: { + readonly mongoDBConnectUrl: string + readonly databaseName: string + readonly collectionName: string + } + + private readonly indexName: string + + private readonly textKey: string + + private readonly embeddingKey: string + + private readonly primaryKey: string + + private caller: AsyncCaller + + _vectorstoreType(): string { + return 'mongodb_atlas' + } + + constructor(embeddings: EmbeddingsInterface, args: MongoDBAtlasVectorSearchLibArgs) { + super(embeddings, args) + this.connectionDetails = args.connectionDetails + this.indexName = args.indexName ?? 'default' + this.textKey = args.textKey ?? 'text' + this.embeddingKey = args.embeddingKey ?? 'embedding' + this.primaryKey = args.primaryKey ?? '_id' + this.caller = new AsyncCaller(args) + } + + async getClient() { + const driverInfo = { name: 'Flowise', version: (await getVersion()).version } + const mongoClient = new MongoClient(this.connectionDetails.mongoDBConnectUrl, { driverInfo }) + return mongoClient + } + + async closeConnection(client: MongoClient) { + await client.close() + } + + async addVectors(vectors: number[][], documents: Document[], options?: { ids?: string[] }) { + const client = await this.getClient() + const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName) + const docs = vectors.map((embedding, idx) => ({ + [this.textKey]: documents[idx].pageContent, + [this.embeddingKey]: embedding, + ...documents[idx].metadata + })) + if (options?.ids === undefined) { + await collection.insertMany(docs) + } else { + if (options.ids.length !== vectors.length) { + throw new Error(`If provided, "options.ids" must be an array with the same length as "vectors".`) + } + const { ids } = options + for (let i = 0; i < docs.length; i += 1) { + await this.caller.call(async () => { + await collection.updateOne( + { [this.primaryKey]: ids[i] }, + { $set: { [this.primaryKey]: ids[i], ...docs[i] } }, + { upsert: true } + ) + }) + } + } + await this.closeConnection(client) + return options?.ids ?? docs.map((doc) => doc[this.primaryKey]) + } + + async addDocuments(documents: Document[], options?: { ids?: string[] }) { + const texts = documents.map(({ pageContent }) => pageContent) + return this.addVectors(await this.embeddings.embedDocuments(texts), documents, options) + } + + async similaritySearchVectorWithScore(query: number[], k: number, filter?: MongoDBAtlasFilter): Promise<[Document, number][]> { + const client = await this.getClient() + const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName) + + const postFilterPipeline = filter?.postFilterPipeline ?? [] + const preFilter: MongoDBDocument | undefined = + filter?.preFilter || filter?.postFilterPipeline || filter?.includeEmbeddings ? filter.preFilter : filter + const removeEmbeddingsPipeline = !filter?.includeEmbeddings + ? [ + { + $project: { + [this.embeddingKey]: 0 + } + } + ] + : [] + + const pipeline: MongoDBDocument[] = [ + { + $vectorSearch: { + queryVector: this.fixArrayPrecision(query), + index: this.indexName, + path: this.embeddingKey, + limit: k, + numCandidates: 10 * k, + ...(preFilter && { filter: preFilter }) + } + }, + { + $set: { + score: { $meta: 'vectorSearchScore' } + } + }, + ...removeEmbeddingsPipeline, + ...postFilterPipeline + ] + + const results = await collection + .aggregate(pipeline) + .map<[Document, number]>((result) => { + const { score, [this.textKey]: text, ...metadata } = result + return [new Document({ pageContent: text, metadata }), score] + }) + .toArray() + + await this.closeConnection(client) + + return results + } + + async maxMarginalRelevanceSearch(query: string, options: MaxMarginalRelevanceSearchOptions): Promise { + const { k, fetchK = 20, lambda = 0.5, filter } = options + + const queryEmbedding = await this.embeddings.embedQuery(query) + + // preserve the original value of includeEmbeddings + const includeEmbeddingsFlag = options.filter?.includeEmbeddings || false + + // update filter to include embeddings, as they will be used in MMR + const includeEmbeddingsFilter = { + ...filter, + includeEmbeddings: true + } + + const resultDocs = await this.similaritySearchVectorWithScore( + this.fixArrayPrecision(queryEmbedding), + fetchK, + includeEmbeddingsFilter + ) + + const embeddingList = resultDocs.map((doc) => doc[0].metadata[this.embeddingKey]) + + const mmrIndexes = maximalMarginalRelevance(queryEmbedding, embeddingList, lambda, k) + + return mmrIndexes.map((idx) => { + const doc = resultDocs[idx][0] + + // remove embeddings if they were not requested originally + if (!includeEmbeddingsFlag) { + delete doc.metadata[this.embeddingKey] + } + return doc + }) + } + + async delete(params: { ids: any[] }): Promise { + const client = await this.getClient() + const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName) + const CHUNK_SIZE = 50 + const chunkIds: any[][] = chunkArray(params.ids, CHUNK_SIZE) + for (const chunk of chunkIds) { + await collection.deleteMany({ _id: { $in: chunk } }) + } + await this.closeConnection(client) + } + + static async fromTexts( + texts: string[], + metadatas: object[] | object, + embeddings: EmbeddingsInterface, + dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] } + ): Promise { + const docs: Document[] = [] + for (let i = 0; i < texts.length; i += 1) { + const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas + const newDoc = new Document({ + pageContent: texts[i], + metadata + }) + docs.push(newDoc) + } + return MongoDBAtlasVectorSearch.fromDocuments(docs, embeddings, dbConfig) + } + + static async fromDocuments( + docs: Document[], + embeddings: EmbeddingsInterface, + dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] } + ): Promise { + const instance = new this(embeddings, dbConfig) + await instance.addDocuments(docs, { ids: dbConfig.ids }) + return instance + } + + fixArrayPrecision(array: number[]) { + return array.map((value) => { + if (Number.isInteger(value)) { + return value + 0.000000000000001 + } + return value + }) + } +}