Bugfix/Add missing TTL implementation for Redis (#2131)

add missing TTL implementation
This commit is contained in:
Henry Heng
2024-04-09 14:36:24 +01:00
committed by GitHub
parent 827de07e94
commit 057e056257
2 changed files with 17 additions and 3 deletions
@@ -162,7 +162,8 @@ const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Prom
chatHistory: redisChatMessageHistory,
sessionId,
windowSize,
redisClient: client
redisClient: client,
sessionTTL
})
return memory
@@ -172,18 +173,21 @@ interface BufferMemoryExtendedInput {
redisClient: Redis
sessionId: string
windowSize?: number
sessionTTL?: number
}
class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
sessionId = ''
redisClient: Redis
windowSize?: number
sessionTTL?: number
constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
super(fields)
this.sessionId = fields.sessionId
this.redisClient = fields.redisClient
this.windowSize = fields.windowSize
this.sessionTTL = fields.sessionTTL
}
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
@@ -207,12 +211,14 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
const newInputMessage = new HumanMessage(input.text)
const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
if (output) {
const newOutputMessage = new AIMessage(output.text)
const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
}
@@ -88,8 +88,10 @@ class UpstashRedisBackedChatMemory_Memory implements INode {
const initalizeUpstashRedis = async (nodeData: INodeData, options: ICommonObject): Promise<BufferMemory> => {
const baseURL = nodeData.inputs?.baseURL as string
const sessionTTL = nodeData.inputs?.sessionTTL as string
const sessionId = nodeData.inputs?.sessionId as string
const _sessionTTL = nodeData.inputs?.sessionTTL as string
const sessionTTL = _sessionTTL ? parseInt(_sessionTTL, 10) : undefined
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const upstashRestToken = getCredentialParam('upstashRestToken', credentialData, nodeData)
@@ -101,7 +103,7 @@ const initalizeUpstashRedis = async (nodeData: INodeData, options: ICommonObject
const redisChatMessageHistory = new UpstashRedisChatMessageHistory({
sessionId,
sessionTTL: sessionTTL ? parseInt(sessionTTL, 10) : undefined,
sessionTTL,
client
})
@@ -109,6 +111,7 @@ const initalizeUpstashRedis = async (nodeData: INodeData, options: ICommonObject
memoryKey: 'chat_history',
chatHistory: redisChatMessageHistory,
sessionId,
sessionTTL,
redisClient: client
})
@@ -118,16 +121,19 @@ const initalizeUpstashRedis = async (nodeData: INodeData, options: ICommonObject
interface BufferMemoryExtendedInput {
redisClient: Redis
sessionId: string
sessionTTL?: number
}
class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
sessionId = ''
redisClient: Redis
sessionTTL?: number
constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
super(fields)
this.sessionId = fields.sessionId
this.redisClient = fields.redisClient
this.sessionTTL = fields.sessionTTL
}
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
@@ -152,12 +158,14 @@ class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
const newInputMessage = new HumanMessage(input.text)
const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
if (output) {
const newOutputMessage = new AIMessage(output.text)
const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
}