Chore/MongoDB Connection (#3469)

getting rid of singleton design and properly close connection after every interaction
This commit is contained in:
Henry Heng
2024-11-07 19:05:00 +00:00
committed by GitHub
parent a6183aba7a
commit eeb1d17f50
3 changed files with 287 additions and 102 deletions
@@ -1,5 +1,4 @@
import { MongoClient, Collection, Document } from 'mongodb' import { MongoClient } from 'mongodb'
import { MongoDBChatMessageHistory } from '@langchain/mongodb'
import { BufferMemory, BufferMemoryInput } from 'langchain/memory' import { BufferMemory, BufferMemoryInput } from 'langchain/memory'
import { mapStoredMessageToChatMessage, AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages' import { mapStoredMessageToChatMessage, AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages'
import { import {
@@ -7,28 +6,13 @@ import {
getBaseClasses, getBaseClasses,
getCredentialData, getCredentialData,
getCredentialParam, getCredentialParam,
getVersion,
mapChatMessageToBaseMessage mapChatMessageToBaseMessage
} from '../../../src/utils' } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods, MessageType } from '../../../src/Interface' import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods, MessageType } from '../../../src/Interface'
let mongoClientSingleton: MongoClient // TODO: Add ability to specify env variable and use singleton pattern (i.e initialize MongoDB on server and pass to component)
let mongoUrl: string
const getMongoClient = async (newMongoUrl: string) => {
if (!mongoClientSingleton) {
// if client does not exist
mongoClientSingleton = new MongoClient(newMongoUrl)
mongoUrl = newMongoUrl
return mongoClientSingleton
} else if (mongoClientSingleton && newMongoUrl !== mongoUrl) {
// if client exists but url changed
mongoClientSingleton.close()
mongoClientSingleton = new MongoClient(newMongoUrl)
mongoUrl = newMongoUrl
return mongoClientSingleton
}
return mongoClientSingleton
}
class MongoDB_Memory implements INode { class MongoDB_Memory implements INode {
label: string label: string
name: string name: string
@@ -102,62 +86,43 @@ const initializeMongoDB = async (nodeData: INodeData, options: ICommonObject): P
const credentialData = await getCredentialData(nodeData.credential ?? '', options) const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData) const mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData)
const driverInfo = { name: 'Flowise', version: (await getVersion()).version }
const client = await getMongoClient(mongoDBConnectUrl)
const collection = client.db(databaseName).collection(collectionName)
const mongoDBChatMessageHistory = new MongoDBChatMessageHistory({
collection,
sessionId
})
// @ts-ignore
mongoDBChatMessageHistory.getMessages = async (): Promise<BaseMessage[]> => {
const document = await collection.findOne({
sessionId: (mongoDBChatMessageHistory as any).sessionId
})
const messages = document?.messages || []
return messages.map(mapStoredMessageToChatMessage)
}
// @ts-ignore
mongoDBChatMessageHistory.addMessage = async (message: BaseMessage): Promise<void> => {
const messages = [message].map((msg) => msg.toDict())
await collection.updateOne(
{ sessionId: (mongoDBChatMessageHistory as any).sessionId },
{
$push: { messages: { $each: messages } }
},
{ upsert: true }
)
}
mongoDBChatMessageHistory.clear = async (): Promise<void> => {
await collection.deleteOne({ sessionId: (mongoDBChatMessageHistory as any).sessionId })
}
return new BufferMemoryExtended({ return new BufferMemoryExtended({
memoryKey: memoryKey ?? 'chat_history', memoryKey: memoryKey ?? 'chat_history',
// @ts-ignore
chatHistory: mongoDBChatMessageHistory,
sessionId, sessionId,
collection mongoConnection: {
databaseName,
collectionName,
mongoDBConnectUrl,
driverInfo
}
}) })
} }
interface BufferMemoryExtendedInput { interface BufferMemoryExtendedInput {
collection: Collection<Document>
sessionId: string sessionId: string
mongoConnection: {
databaseName: string
collectionName: string
mongoDBConnectUrl: string
driverInfo: { name: string; version: string }
}
} }
class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods { class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
sessionId = '' sessionId = ''
collection: Collection<Document> mongoConnection: {
databaseName: string
collectionName: string
mongoDBConnectUrl: string
driverInfo: { name: string; version: string }
}
constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) { constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
super(fields) super(fields)
this.sessionId = fields.sessionId this.sessionId = fields.sessionId
this.collection = fields.collection this.mongoConnection = fields.mongoConnection
} }
async getChatMessages( async getChatMessages(
@@ -165,20 +130,24 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
returnBaseMessages = false, returnBaseMessages = false,
prependMessages?: IMessage[] prependMessages?: IMessage[]
): Promise<IMessage[] | BaseMessage[]> { ): Promise<IMessage[] | BaseMessage[]> {
if (!this.collection) return [] const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo })
const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName)
const id = overrideSessionId ? overrideSessionId : this.sessionId const id = overrideSessionId ? overrideSessionId : this.sessionId
const document = await this.collection.findOne({ sessionId: id }) const document = await collection.findOne({ sessionId: id })
const messages = document?.messages || [] const messages = document?.messages || []
const baseMessages = messages.map(mapStoredMessageToChatMessage) const baseMessages = messages.map(mapStoredMessageToChatMessage)
if (prependMessages?.length) { if (prependMessages?.length) {
baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages))) baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages)))
} }
await client.close()
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages) return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
} }
async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise<void> { async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise<void> {
if (!this.collection) return const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo })
const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName)
const id = overrideSessionId ? overrideSessionId : this.sessionId const id = overrideSessionId ? overrideSessionId : this.sessionId
const input = msgArray.find((msg) => msg.type === 'userMessage') const input = msgArray.find((msg) => msg.type === 'userMessage')
@@ -187,7 +156,7 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
if (input) { if (input) {
const newInputMessage = new HumanMessage(input.text) const newInputMessage = new HumanMessage(input.text)
const messageToAdd = [newInputMessage].map((msg) => msg.toDict()) const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
await this.collection.updateOne( await collection.updateOne(
{ sessionId: id }, { sessionId: id },
{ {
$push: { messages: { $each: messageToAdd } } $push: { messages: { $each: messageToAdd } }
@@ -199,7 +168,7 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
if (output) { if (output) {
const newOutputMessage = new AIMessage(output.text) const newOutputMessage = new AIMessage(output.text)
const messageToAdd = [newOutputMessage].map((msg) => msg.toDict()) const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
await this.collection.updateOne( await collection.updateOne(
{ sessionId: id }, { sessionId: id },
{ {
$push: { messages: { $each: messageToAdd } } $push: { messages: { $each: messageToAdd } }
@@ -207,14 +176,19 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
{ upsert: true } { upsert: true }
) )
} }
await client.close()
} }
async clearChatMessages(overrideSessionId = ''): Promise<void> { async clearChatMessages(overrideSessionId = ''): Promise<void> {
if (!this.collection) return const client = new MongoClient(this.mongoConnection.mongoDBConnectUrl, { driverInfo: this.mongoConnection.driverInfo })
const collection = client.db(this.mongoConnection.databaseName).collection(this.mongoConnection.collectionName)
const id = overrideSessionId ? overrideSessionId : this.sessionId const id = overrideSessionId ? overrideSessionId : this.sessionId
await this.collection.deleteOne({ sessionId: id }) await collection.deleteOne({ sessionId: id })
await this.clear() await this.clear()
await client.close()
} }
} }
@@ -1,13 +1,12 @@
import { flatten } from 'lodash' import { flatten } from 'lodash'
import { MongoClient } from 'mongodb'
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb'
import { Embeddings } from '@langchain/core/embeddings' import { Embeddings } from '@langchain/core/embeddings'
import { Document } from '@langchain/core/documents' import { Document } from '@langchain/core/documents'
import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams, IndexingResult } from '../../../src/Interface' import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams, IndexingResult } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam, getVersion } from '../../../src/utils' import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { addMMRInputParams, resolveVectorStoreOrRetriever } from '../VectorStoreUtils' import { addMMRInputParams, resolveVectorStoreOrRetriever } from '../VectorStoreUtils'
import { VectorStore } from '@langchain/core/vectorstores' import { MongoDBAtlasVectorSearch } from './core'
// TODO: Add ability to specify env variable and use singleton pattern (i.e initialize MongoDB on server and pass to component)
class MongoDBAtlas_VectorStores implements INode { class MongoDBAtlas_VectorStores implements INode {
label: string label: string
name: string name: string
@@ -142,20 +141,18 @@ class MongoDBAtlas_VectorStores implements INode {
} }
} }
const mongoClient = await getMongoClient(mongoDBConnectUrl)
try { try {
const collection = mongoClient.db(databaseName).collection(collectionName)
if (!textKey || textKey === '') textKey = 'text' if (!textKey || textKey === '') textKey = 'text'
if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding' if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding'
const mongoDBAtlasVectorSearch = new MongoDBAtlasVectorSearch(embeddings, { const mongoDBAtlasVectorSearch = new MongoDBAtlasVectorSearch(embeddings, {
collection, connectionDetails: { mongoDBConnectUrl, databaseName, collectionName },
indexName, indexName,
textKey, textKey,
embeddingKey embeddingKey
}) })
await mongoDBAtlasVectorSearch.addDocuments(finalDocs) await mongoDBAtlasVectorSearch.addDocuments(finalDocs)
return { numAdded: finalDocs.length, addedDocs: finalDocs } return { numAdded: finalDocs.length, addedDocs: finalDocs }
} catch (e) { } catch (e) {
throw new Error(e) throw new Error(e)
@@ -175,28 +172,25 @@ class MongoDBAtlas_VectorStores implements INode {
let mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData) let mongoDBConnectUrl = getCredentialParam('mongoDBConnectUrl', credentialData, nodeData)
const filter: MongoDBAtlasVectorSearch['FilterType'] = {} const mongoDbFilter: MongoDBAtlasVectorSearch['FilterType'] = {}
const mongoClient = await getMongoClient(mongoDBConnectUrl)
try { try {
const collection = mongoClient.db(databaseName).collection(collectionName)
if (!textKey || textKey === '') textKey = 'text' if (!textKey || textKey === '') textKey = 'text'
if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding' if (!embeddingKey || embeddingKey === '') embeddingKey = 'embedding'
const vectorStore = new MongoDBAtlasVectorSearch(embeddings, { const vectorStore = new MongoDBAtlasVectorSearch(embeddings, {
collection, connectionDetails: { mongoDBConnectUrl, databaseName, collectionName },
indexName, indexName,
textKey, textKey,
embeddingKey embeddingKey
}) as unknown as VectorStore })
if (mongoMetadataFilter) { if (mongoMetadataFilter) {
const metadataFilter = typeof mongoMetadataFilter === 'object' ? mongoMetadataFilter : JSON.parse(mongoMetadataFilter) const metadataFilter = typeof mongoMetadataFilter === 'object' ? mongoMetadataFilter : JSON.parse(mongoMetadataFilter)
for (const key in metadataFilter) { for (const key in metadataFilter) {
filter.preFilter = { mongoDbFilter.preFilter = {
...filter.preFilter, ...mongoDbFilter.preFilter,
[key]: { [key]: {
$eq: metadataFilter[key] $eq: metadataFilter[key]
} }
@@ -204,31 +198,11 @@ class MongoDBAtlas_VectorStores implements INode {
} }
} }
return resolveVectorStoreOrRetriever(nodeData, vectorStore, filter) return resolveVectorStoreOrRetriever(nodeData, vectorStore, mongoDbFilter)
} catch (e) { } catch (e) {
throw new Error(e) throw new Error(e)
} }
} }
} }
let mongoClientSingleton: MongoClient
let mongoUrl: string
const getMongoClient = async (newMongoUrl: string) => {
const driverInfo = { name: 'Flowise', version: (await getVersion()).version }
if (!mongoClientSingleton) {
// if client does not exist
mongoClientSingleton = new MongoClient(newMongoUrl, { driverInfo })
mongoUrl = newMongoUrl
return mongoClientSingleton
} else if (mongoClientSingleton && newMongoUrl !== mongoUrl) {
// if client exists but url changed
mongoClientSingleton.close()
mongoClientSingleton = new MongoClient(newMongoUrl, { driverInfo })
mongoUrl = newMongoUrl
return mongoClientSingleton
}
return mongoClientSingleton
}
module.exports = { nodeClass: MongoDBAtlas_VectorStores } module.exports = { nodeClass: MongoDBAtlas_VectorStores }
@@ -0,0 +1,237 @@
import { MongoClient, type Document as MongoDBDocument } from 'mongodb'
import { MaxMarginalRelevanceSearchOptions, VectorStore } from '@langchain/core/vectorstores'
import type { EmbeddingsInterface } from '@langchain/core/embeddings'
import { chunkArray } from '@langchain/core/utils/chunk_array'
import { Document } from '@langchain/core/documents'
import { maximalMarginalRelevance } from '@langchain/core/utils/math'
import { AsyncCaller, AsyncCallerParams } from '@langchain/core/utils/async_caller'
import { getVersion } from '../../../src/utils'
export interface MongoDBAtlasVectorSearchLibArgs extends AsyncCallerParams {
readonly connectionDetails: {
readonly mongoDBConnectUrl: string
readonly databaseName: string
readonly collectionName: string
}
readonly indexName?: string
readonly textKey?: string
readonly embeddingKey?: string
readonly primaryKey?: string
}
type MongoDBAtlasFilter = {
preFilter?: MongoDBDocument
postFilterPipeline?: MongoDBDocument[]
includeEmbeddings?: boolean
} & MongoDBDocument
export class MongoDBAtlasVectorSearch extends VectorStore {
declare FilterType: MongoDBAtlasFilter
private readonly connectionDetails: {
readonly mongoDBConnectUrl: string
readonly databaseName: string
readonly collectionName: string
}
private readonly indexName: string
private readonly textKey: string
private readonly embeddingKey: string
private readonly primaryKey: string
private caller: AsyncCaller
_vectorstoreType(): string {
return 'mongodb_atlas'
}
constructor(embeddings: EmbeddingsInterface, args: MongoDBAtlasVectorSearchLibArgs) {
super(embeddings, args)
this.connectionDetails = args.connectionDetails
this.indexName = args.indexName ?? 'default'
this.textKey = args.textKey ?? 'text'
this.embeddingKey = args.embeddingKey ?? 'embedding'
this.primaryKey = args.primaryKey ?? '_id'
this.caller = new AsyncCaller(args)
}
async getClient() {
const driverInfo = { name: 'Flowise', version: (await getVersion()).version }
const mongoClient = new MongoClient(this.connectionDetails.mongoDBConnectUrl, { driverInfo })
return mongoClient
}
async closeConnection(client: MongoClient) {
await client.close()
}
async addVectors(vectors: number[][], documents: Document[], options?: { ids?: string[] }) {
const client = await this.getClient()
const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName)
const docs = vectors.map((embedding, idx) => ({
[this.textKey]: documents[idx].pageContent,
[this.embeddingKey]: embedding,
...documents[idx].metadata
}))
if (options?.ids === undefined) {
await collection.insertMany(docs)
} else {
if (options.ids.length !== vectors.length) {
throw new Error(`If provided, "options.ids" must be an array with the same length as "vectors".`)
}
const { ids } = options
for (let i = 0; i < docs.length; i += 1) {
await this.caller.call(async () => {
await collection.updateOne(
{ [this.primaryKey]: ids[i] },
{ $set: { [this.primaryKey]: ids[i], ...docs[i] } },
{ upsert: true }
)
})
}
}
await this.closeConnection(client)
return options?.ids ?? docs.map((doc) => doc[this.primaryKey])
}
async addDocuments(documents: Document[], options?: { ids?: string[] }) {
const texts = documents.map(({ pageContent }) => pageContent)
return this.addVectors(await this.embeddings.embedDocuments(texts), documents, options)
}
async similaritySearchVectorWithScore(query: number[], k: number, filter?: MongoDBAtlasFilter): Promise<[Document, number][]> {
const client = await this.getClient()
const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName)
const postFilterPipeline = filter?.postFilterPipeline ?? []
const preFilter: MongoDBDocument | undefined =
filter?.preFilter || filter?.postFilterPipeline || filter?.includeEmbeddings ? filter.preFilter : filter
const removeEmbeddingsPipeline = !filter?.includeEmbeddings
? [
{
$project: {
[this.embeddingKey]: 0
}
}
]
: []
const pipeline: MongoDBDocument[] = [
{
$vectorSearch: {
queryVector: this.fixArrayPrecision(query),
index: this.indexName,
path: this.embeddingKey,
limit: k,
numCandidates: 10 * k,
...(preFilter && { filter: preFilter })
}
},
{
$set: {
score: { $meta: 'vectorSearchScore' }
}
},
...removeEmbeddingsPipeline,
...postFilterPipeline
]
const results = await collection
.aggregate(pipeline)
.map<[Document, number]>((result) => {
const { score, [this.textKey]: text, ...metadata } = result
return [new Document({ pageContent: text, metadata }), score]
})
.toArray()
await this.closeConnection(client)
return results
}
async maxMarginalRelevanceSearch(query: string, options: MaxMarginalRelevanceSearchOptions<this['FilterType']>): Promise<Document[]> {
const { k, fetchK = 20, lambda = 0.5, filter } = options
const queryEmbedding = await this.embeddings.embedQuery(query)
// preserve the original value of includeEmbeddings
const includeEmbeddingsFlag = options.filter?.includeEmbeddings || false
// update filter to include embeddings, as they will be used in MMR
const includeEmbeddingsFilter = {
...filter,
includeEmbeddings: true
}
const resultDocs = await this.similaritySearchVectorWithScore(
this.fixArrayPrecision(queryEmbedding),
fetchK,
includeEmbeddingsFilter
)
const embeddingList = resultDocs.map((doc) => doc[0].metadata[this.embeddingKey])
const mmrIndexes = maximalMarginalRelevance(queryEmbedding, embeddingList, lambda, k)
return mmrIndexes.map((idx) => {
const doc = resultDocs[idx][0]
// remove embeddings if they were not requested originally
if (!includeEmbeddingsFlag) {
delete doc.metadata[this.embeddingKey]
}
return doc
})
}
async delete(params: { ids: any[] }): Promise<void> {
const client = await this.getClient()
const collection = client.db(this.connectionDetails.databaseName).collection(this.connectionDetails.collectionName)
const CHUNK_SIZE = 50
const chunkIds: any[][] = chunkArray(params.ids, CHUNK_SIZE)
for (const chunk of chunkIds) {
await collection.deleteMany({ _id: { $in: chunk } })
}
await this.closeConnection(client)
}
static async fromTexts(
texts: string[],
metadatas: object[] | object,
embeddings: EmbeddingsInterface,
dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] }
): Promise<MongoDBAtlasVectorSearch> {
const docs: Document[] = []
for (let i = 0; i < texts.length; i += 1) {
const metadata = Array.isArray(metadatas) ? metadatas[i] : metadatas
const newDoc = new Document({
pageContent: texts[i],
metadata
})
docs.push(newDoc)
}
return MongoDBAtlasVectorSearch.fromDocuments(docs, embeddings, dbConfig)
}
static async fromDocuments(
docs: Document[],
embeddings: EmbeddingsInterface,
dbConfig: MongoDBAtlasVectorSearchLibArgs & { ids?: string[] }
): Promise<MongoDBAtlasVectorSearch> {
const instance = new this(embeddings, dbConfig)
await instance.addDocuments(docs, { ids: dbConfig.ids })
return instance
}
fixArrayPrecision(array: number[]) {
return array.map((value) => {
if (Number.isInteger(value)) {
return value + 0.000000000000001
}
return value
})
}
}