diff --git a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts index f10f25ce..3e3697d1 100644 --- a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts +++ b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts @@ -1,9 +1,9 @@ -import { INode, INodeData, INodeParams } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' -import { ICommonObject } from '../../../src' +import { INode, INodeData, INodeParams, ICommonObject } from '../../../src/Interface' +import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { BufferMemory, BufferMemoryInput } from 'langchain/memory' -import { RedisChatMessageHistory, RedisChatMessageHistoryInput } from 'langchain/stores/message/redis' -import { createClient } from 'redis' +import { RedisChatMessageHistory, RedisChatMessageHistoryInput } from 'langchain/stores/message/ioredis' +import { mapStoredMessageToChatMessage, BaseMessage } from 'langchain/schema' +import { Redis } from 'ioredis' class RedisBackedChatMemory_Memory implements INode { label: string @@ -15,23 +15,25 @@ class RedisBackedChatMemory_Memory implements INode { category: string baseClasses: string[] inputs: INodeParams[] + credential: INodeParams constructor() { this.label = 'Redis-Backed Chat Memory' this.name = 'RedisBackedChatMemory' - this.version = 1.0 + this.version = 2.0 this.type = 'RedisBackedChatMemory' this.icon = 'redis.svg' this.category = 'Memory' this.description = 'Summarizes the conversation and stores the memory in Redis server' this.baseClasses = [this.type, ...getBaseClasses(BufferMemory)] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + optional: true, + credentialNames: ['redisCacheApi', 'redisCacheUrlApi'] + } this.inputs = [ - { - label: 'Base URL', - name: 'baseURL', - type: 'string', - default: 'redis://localhost:6379' - }, { label: 'Session Id', name: 'sessionId', @@ -60,11 +62,11 @@ class RedisBackedChatMemory_Memory implements INode { } async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { - return initalizeRedis(nodeData, options) + return await initalizeRedis(nodeData, options) } async clearSessionMemory(nodeData: INodeData, options: ICommonObject): Promise { - const redis = initalizeRedis(nodeData, options) + const redis = await initalizeRedis(nodeData, options) const sessionId = nodeData.inputs?.sessionId as string const chatId = options?.chatId as string options.logger.info(`Clearing Redis memory session ${sessionId ? sessionId : chatId}`) @@ -73,8 +75,7 @@ class RedisBackedChatMemory_Memory implements INode { } } -const initalizeRedis = (nodeData: INodeData, options: ICommonObject): BufferMemory => { - const baseURL = nodeData.inputs?.baseURL as string +const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Promise => { const sessionId = nodeData.inputs?.sessionId as string const sessionTTL = nodeData.inputs?.sessionTTL as number const memoryKey = nodeData.inputs?.memoryKey as string @@ -83,10 +84,29 @@ const initalizeRedis = (nodeData: INodeData, options: ICommonObject): BufferMemo let isSessionIdUsingChatMessageId = false if (!sessionId && chatId) isSessionIdUsingChatMessageId = true - const redisClient = createClient({ url: baseURL }) + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const redisUrl = getCredentialParam('redisUrl', credentialData, nodeData) + + let client: Redis + if (!redisUrl || redisUrl === '') { + const username = getCredentialParam('redisCacheUser', credentialData, nodeData) + const password = getCredentialParam('redisCachePwd', credentialData, nodeData) + const portStr = getCredentialParam('redisCachePort', credentialData, nodeData) + const host = getCredentialParam('redisCacheHost', credentialData, nodeData) + + client = new Redis({ + port: portStr ? parseInt(portStr) : 6379, + host, + username, + password + }) + } else { + client = new Redis(redisUrl) + } + let obj: RedisChatMessageHistoryInput = { sessionId: sessionId ? sessionId : chatId, - client: redisClient + client } if (sessionTTL) { @@ -98,6 +118,24 @@ const initalizeRedis = (nodeData: INodeData, options: ICommonObject): BufferMemo const redisChatMessageHistory = new RedisChatMessageHistory(obj) + redisChatMessageHistory.getMessages = async (): Promise => { + const rawStoredMessages = await client.lrange(sessionId ? sessionId : chatId, 0, -1) + const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message)) + return orderedMessages.map(mapStoredMessageToChatMessage) + } + + redisChatMessageHistory.addMessage = async (message: BaseMessage): Promise => { + const messageToAdd = [message].map((msg) => msg.toDict()) + await client.lpush(sessionId ? sessionId : chatId, JSON.stringify(messageToAdd[0])) + if (sessionTTL) { + await client.expire(sessionId ? sessionId : chatId, sessionTTL) + } + } + + redisChatMessageHistory.clear = async (): Promise => { + await client.del(sessionId ? sessionId : chatId) + } + const memory = new BufferMemoryExtended({ memoryKey, chatHistory: redisChatMessageHistory,