Feature/Buffer Memory SessionId (#2111)

* add sessionId to buffer memory, add conversation summary buffer memory

* add fix for conv retrieval qa chain
This commit is contained in:
Henry Heng
2024-04-11 11:18:39 +01:00
committed by GitHub
parent 57b716c7d7
commit c33642cdf9
39 changed files with 784 additions and 574 deletions
@@ -1,8 +1,18 @@
import { FlowiseSummaryMemory, IMessage, INode, INodeData, INodeParams, MemoryMethods } from '../../../src/Interface'
import { convertBaseMessagetoIMessage, getBaseClasses } from '../../../src/utils'
import {
FlowiseSummaryMemory,
IMessage,
IDatabaseEntity,
INode,
INodeData,
INodeParams,
MemoryMethods,
ICommonObject
} from '../../../src/Interface'
import { getBaseClasses, mapChatMessageToBaseMessage } from '../../../src/utils'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { BaseMessage } from '@langchain/core/messages'
import { BaseMessage, SystemMessage } from '@langchain/core/messages'
import { ConversationSummaryMemory, ConversationSummaryMemoryInput } from 'langchain/memory'
import { DataSource } from 'typeorm'
class ConversationSummaryMemory_Memory implements INode {
label: string
@@ -18,7 +28,7 @@ class ConversationSummaryMemory_Memory implements INode {
constructor() {
this.label = 'Conversation Summary Memory'
this.name = 'conversationSummaryMemory'
this.version = 1.0
this.version = 2.0
this.type = 'ConversationSummaryMemory'
this.icon = 'memory.svg'
this.category = 'Memory'
@@ -30,67 +40,123 @@ class ConversationSummaryMemory_Memory implements INode {
name: 'model',
type: 'BaseChatModel'
},
{
label: 'Session Id',
name: 'sessionId',
type: 'string',
description:
'If not specified, a random id will be used. Learn <a target="_blank" href="https://docs.flowiseai.com/memory#ui-and-embedded-chat">more</a>',
default: '',
optional: true,
additionalParams: true
},
{
label: 'Memory Key',
name: 'memoryKey',
type: 'string',
default: 'chat_history'
},
{
label: 'Input Key',
name: 'inputKey',
type: 'string',
default: 'input'
default: 'chat_history',
additionalParams: true
}
]
}
async init(nodeData: INodeData): Promise<any> {
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
const model = nodeData.inputs?.model as BaseLanguageModel
const memoryKey = nodeData.inputs?.memoryKey as string
const inputKey = nodeData.inputs?.inputKey as string
const sessionId = nodeData.inputs?.sessionId as string
const memoryKey = (nodeData.inputs?.memoryKey as string) ?? 'chat_history'
const obj: ConversationSummaryMemoryInput = {
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
const chatflowid = options.chatflowid as string
const obj: ConversationSummaryMemoryInput & BufferMemoryExtendedInput = {
llm: model,
returnMessages: true,
memoryKey,
inputKey
returnMessages: true,
sessionId,
appDataSource,
databaseEntities,
chatflowid
}
return new ConversationSummaryMemoryExtended(obj)
}
}
interface BufferMemoryExtendedInput {
sessionId: string
appDataSource: DataSource
databaseEntities: IDatabaseEntity
chatflowid: string
}
class ConversationSummaryMemoryExtended extends FlowiseSummaryMemory implements MemoryMethods {
constructor(fields: ConversationSummaryMemoryInput) {
appDataSource: DataSource
databaseEntities: IDatabaseEntity
chatflowid: string
sessionId = ''
constructor(fields: ConversationSummaryMemoryInput & BufferMemoryExtendedInput) {
super(fields)
this.sessionId = fields.sessionId
this.appDataSource = fields.appDataSource
this.databaseEntities = fields.databaseEntities
this.chatflowid = fields.chatflowid
}
async getChatMessages(_?: string, returnBaseMessages = false, prevHistory: IMessage[] = []): Promise<IMessage[] | BaseMessage[]> {
await this.chatHistory.clear()
this.buffer = ''
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
const id = overrideSessionId ? overrideSessionId : this.sessionId
if (!id) return []
for (const msg of prevHistory) {
if (msg.type === 'userMessage') await this.chatHistory.addUserMessage(msg.message)
else if (msg.type === 'apiMessage') await this.chatHistory.addAIChatMessage(msg.message)
}
this.buffer = ''
let chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
where: {
sessionId: id,
chatflowid: this.chatflowid
},
order: {
createdDate: 'ASC'
}
})
const baseMessages = mapChatMessageToBaseMessage(chatMessage)
// Get summary
const chatMessages = await this.chatHistory.getMessages()
this.buffer = chatMessages.length ? await this.predictNewSummary(chatMessages.slice(-2), this.buffer) : ''
if (this.llm && typeof this.llm !== 'string') {
this.buffer = baseMessages.length ? await this.predictNewSummary(baseMessages.slice(-2), this.buffer) : ''
}
const memoryResult = await this.loadMemoryVariables({})
const baseMessages = memoryResult[this.memoryKey ?? 'chat_history']
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
if (returnBaseMessages) {
return [new SystemMessage(this.buffer)]
}
if (this.buffer) {
return [
{
message: this.buffer,
type: 'apiMessage'
}
]
}
let returnIMessages: IMessage[] = []
for (const m of chatMessage) {
returnIMessages.push({
message: m.content as string,
type: m.role
})
}
return returnIMessages
}
async addChatMessages(): Promise<void> {
// adding chat messages will be done on the fly in getChatMessages()
// adding chat messages is done on server level
return
}
async clearChatMessages(): Promise<void> {
await this.clear()
// clearing chat messages is done on server level
return
}
}