Feature/Buffer Memory SessionId (#2111)

* add sessionId to buffer memory, add conversation summary buffer memory

* add fix for conv retrieval qa chain
This commit is contained in:
Henry Heng
2024-04-11 11:18:39 +01:00
committed by GitHub
parent 57b716c7d7
commit c33642cdf9
39 changed files with 784 additions and 574 deletions
@@ -1,4 +1,5 @@
import { applyPatch } from 'fast-json-patch'
import { DataSource } from 'typeorm'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { BaseRetriever } from '@langchain/core/retrievers'
import { PromptTemplate, ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
@@ -11,9 +12,18 @@ import { StringOutputParser } from '@langchain/core/output_parsers'
import type { Document } from '@langchain/core/documents'
import { BufferMemoryInput } from 'langchain/memory'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { convertBaseMessagetoIMessage, getBaseClasses } from '../../../src/utils'
import { getBaseClasses, mapChatMessageToBaseMessage } from '../../../src/utils'
import { ConsoleCallbackHandler, additionalCallbacks } from '../../../src/handler'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods } from '../../../src/Interface'
import {
FlowiseMemory,
ICommonObject,
IMessage,
INode,
INodeData,
INodeParams,
IDatabaseEntity,
MemoryMethods
} from '../../../src/Interface'
import { QA_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE } from './prompts'
type RetrievalChainInput = {
@@ -166,6 +176,10 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const responsePrompt = nodeData.inputs?.responsePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
const chatflowid = options.chatflowid as string
let customResponsePrompt = responsePrompt
// If the deprecated systemMessagePrompt is still exists
if (systemMessagePrompt) {
@@ -178,7 +192,9 @@ class ConversationalRetrievalQAChain_Chains implements INode {
memory = new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
inputKey: 'input'
appDataSource,
databaseEntities,
chatflowid
})
}
@@ -194,7 +210,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
}
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)
const history = ((await memory.getChatMessages(this.sessionId, false, options.chatHistory)) as IMessage[]) ?? []
const history = ((await memory.getChatMessages(this.sessionId, false)) as IMessage[]) ?? []
const loggerHandler = new ConsoleCallbackHandler(options.logger)
const additionalCallback = await additionalCallbacks(nodeData, options)
@@ -367,31 +383,59 @@ const createChain = (
return conversationalQAChain
}
interface BufferMemoryExtendedInput {
appDataSource: DataSource
databaseEntities: IDatabaseEntity
chatflowid: string
}
class BufferMemory extends FlowiseMemory implements MemoryMethods {
constructor(fields: BufferMemoryInput) {
appDataSource: DataSource
databaseEntities: IDatabaseEntity
chatflowid: string
constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
super(fields)
this.appDataSource = fields.appDataSource
this.databaseEntities = fields.databaseEntities
this.chatflowid = fields.chatflowid
}
async getChatMessages(_?: string, returnBaseMessages = false, prevHistory: IMessage[] = []): Promise<IMessage[] | BaseMessage[]> {
await this.chatHistory.clear()
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
if (!overrideSessionId) return []
for (const msg of prevHistory) {
if (msg.type === 'userMessage') await this.chatHistory.addUserMessage(msg.message)
else if (msg.type === 'apiMessage') await this.chatHistory.addAIChatMessage(msg.message)
const chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
where: {
sessionId: overrideSessionId,
chatflowid: this.chatflowid
},
order: {
createdDate: 'ASC'
}
})
if (returnBaseMessages) {
return mapChatMessageToBaseMessage(chatMessage)
}
const memoryResult = await this.loadMemoryVariables({})
const baseMessages = memoryResult[this.memoryKey ?? 'chat_history']
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
let returnIMessages: IMessage[] = []
for (const m of chatMessage) {
returnIMessages.push({
message: m.content as string,
type: m.role
})
}
return returnIMessages
}
async addChatMessages(): Promise<void> {
// adding chat messages will be done on the fly in getChatMessages()
// adding chat messages is done on server level
return
}
async clearChatMessages(): Promise<void> {
await this.clear()
// clearing chat messages is done on server level
return
}
}