mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-28 15:00:57 +03:00
Feature/Buffer Memory SessionId (#2111)
* add sessionId to buffer memory, add conversation summary buffer memory * add fix for conv retrieval qa chain
This commit is contained in:
@@ -217,7 +217,6 @@ const prepareChatPrompt = (nodeData: INodeData, humanImageMessages: MessageConte
|
||||
}
|
||||
|
||||
const prepareChain = (nodeData: INodeData, options: ICommonObject, sessionId?: string) => {
|
||||
const chatHistory = options.chatHistory
|
||||
let model = nodeData.inputs?.model as BaseChatModel
|
||||
const memory = nodeData.inputs?.memory as FlowiseMemory
|
||||
const memoryKey = memory.memoryKey ?? 'chat_history'
|
||||
@@ -253,7 +252,7 @@ const prepareChain = (nodeData: INodeData, options: ICommonObject, sessionId?: s
|
||||
{
|
||||
[inputKey]: (input: { input: string }) => input.input,
|
||||
[memoryKey]: async () => {
|
||||
const history = await memory.getChatMessages(sessionId, true, chatHistory)
|
||||
const history = await memory.getChatMessages(sessionId, true)
|
||||
return history
|
||||
},
|
||||
...promptVariables
|
||||
|
||||
+59
-15
@@ -1,4 +1,5 @@
|
||||
import { applyPatch } from 'fast-json-patch'
|
||||
import { DataSource } from 'typeorm'
|
||||
import { BaseLanguageModel } from '@langchain/core/language_models/base'
|
||||
import { BaseRetriever } from '@langchain/core/retrievers'
|
||||
import { PromptTemplate, ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
|
||||
@@ -11,9 +12,18 @@ import { StringOutputParser } from '@langchain/core/output_parsers'
|
||||
import type { Document } from '@langchain/core/documents'
|
||||
import { BufferMemoryInput } from 'langchain/memory'
|
||||
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
||||
import { convertBaseMessagetoIMessage, getBaseClasses } from '../../../src/utils'
|
||||
import { getBaseClasses, mapChatMessageToBaseMessage } from '../../../src/utils'
|
||||
import { ConsoleCallbackHandler, additionalCallbacks } from '../../../src/handler'
|
||||
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods } from '../../../src/Interface'
|
||||
import {
|
||||
FlowiseMemory,
|
||||
ICommonObject,
|
||||
IMessage,
|
||||
INode,
|
||||
INodeData,
|
||||
INodeParams,
|
||||
IDatabaseEntity,
|
||||
MemoryMethods
|
||||
} from '../../../src/Interface'
|
||||
import { QA_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE } from './prompts'
|
||||
|
||||
type RetrievalChainInput = {
|
||||
@@ -166,6 +176,10 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
const responsePrompt = nodeData.inputs?.responsePrompt as string
|
||||
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||
|
||||
const appDataSource = options.appDataSource as DataSource
|
||||
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
||||
const chatflowid = options.chatflowid as string
|
||||
|
||||
let customResponsePrompt = responsePrompt
|
||||
// If the deprecated systemMessagePrompt is still exists
|
||||
if (systemMessagePrompt) {
|
||||
@@ -178,7 +192,9 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
memory = new BufferMemory({
|
||||
returnMessages: true,
|
||||
memoryKey: 'chat_history',
|
||||
inputKey: 'input'
|
||||
appDataSource,
|
||||
databaseEntities,
|
||||
chatflowid
|
||||
})
|
||||
}
|
||||
|
||||
@@ -194,7 +210,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
}
|
||||
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)
|
||||
|
||||
const history = ((await memory.getChatMessages(this.sessionId, false, options.chatHistory)) as IMessage[]) ?? []
|
||||
const history = ((await memory.getChatMessages(this.sessionId, false)) as IMessage[]) ?? []
|
||||
|
||||
const loggerHandler = new ConsoleCallbackHandler(options.logger)
|
||||
const additionalCallback = await additionalCallbacks(nodeData, options)
|
||||
@@ -367,31 +383,59 @@ const createChain = (
|
||||
return conversationalQAChain
|
||||
}
|
||||
|
||||
interface BufferMemoryExtendedInput {
|
||||
appDataSource: DataSource
|
||||
databaseEntities: IDatabaseEntity
|
||||
chatflowid: string
|
||||
}
|
||||
|
||||
class BufferMemory extends FlowiseMemory implements MemoryMethods {
|
||||
constructor(fields: BufferMemoryInput) {
|
||||
appDataSource: DataSource
|
||||
databaseEntities: IDatabaseEntity
|
||||
chatflowid: string
|
||||
|
||||
constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
|
||||
super(fields)
|
||||
this.appDataSource = fields.appDataSource
|
||||
this.databaseEntities = fields.databaseEntities
|
||||
this.chatflowid = fields.chatflowid
|
||||
}
|
||||
|
||||
async getChatMessages(_?: string, returnBaseMessages = false, prevHistory: IMessage[] = []): Promise<IMessage[] | BaseMessage[]> {
|
||||
await this.chatHistory.clear()
|
||||
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
|
||||
if (!overrideSessionId) return []
|
||||
|
||||
for (const msg of prevHistory) {
|
||||
if (msg.type === 'userMessage') await this.chatHistory.addUserMessage(msg.message)
|
||||
else if (msg.type === 'apiMessage') await this.chatHistory.addAIChatMessage(msg.message)
|
||||
const chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
|
||||
where: {
|
||||
sessionId: overrideSessionId,
|
||||
chatflowid: this.chatflowid
|
||||
},
|
||||
order: {
|
||||
createdDate: 'ASC'
|
||||
}
|
||||
})
|
||||
|
||||
if (returnBaseMessages) {
|
||||
return mapChatMessageToBaseMessage(chatMessage)
|
||||
}
|
||||
|
||||
const memoryResult = await this.loadMemoryVariables({})
|
||||
const baseMessages = memoryResult[this.memoryKey ?? 'chat_history']
|
||||
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
|
||||
let returnIMessages: IMessage[] = []
|
||||
for (const m of chatMessage) {
|
||||
returnIMessages.push({
|
||||
message: m.content as string,
|
||||
type: m.role
|
||||
})
|
||||
}
|
||||
return returnIMessages
|
||||
}
|
||||
|
||||
async addChatMessages(): Promise<void> {
|
||||
// adding chat messages will be done on the fly in getChatMessages()
|
||||
// adding chat messages is done on server level
|
||||
return
|
||||
}
|
||||
|
||||
async clearChatMessages(): Promise<void> {
|
||||
await this.clear()
|
||||
// clearing chat messages is done on server level
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user