mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-28 15:00:57 +03:00
Feature/Buffer Memory SessionId (#2111)
* add sessionId to buffer memory, add conversation summary buffer memory * add fix for conv retrieval qa chain
This commit is contained in:
+187
@@ -0,0 +1,187 @@
|
||||
import {
|
||||
IMessage,
|
||||
IDatabaseEntity,
|
||||
INode,
|
||||
INodeData,
|
||||
INodeParams,
|
||||
MemoryMethods,
|
||||
ICommonObject,
|
||||
FlowiseSummaryBufferMemory
|
||||
} from '../../../src/Interface'
|
||||
import { getBaseClasses, mapChatMessageToBaseMessage } from '../../../src/utils'
|
||||
import { BaseLanguageModel } from '@langchain/core/language_models/base'
|
||||
import { BaseMessage, getBufferString } from '@langchain/core/messages'
|
||||
import { ConversationSummaryBufferMemory, ConversationSummaryBufferMemoryInput } from 'langchain/memory'
|
||||
import { DataSource } from 'typeorm'
|
||||
|
||||
class ConversationSummaryBufferMemory_Memory implements INode {
|
||||
label: string
|
||||
name: string
|
||||
version: number
|
||||
description: string
|
||||
type: string
|
||||
icon: string
|
||||
category: string
|
||||
baseClasses: string[]
|
||||
inputs: INodeParams[]
|
||||
|
||||
constructor() {
|
||||
this.label = 'Conversation Summary Buffer Memory'
|
||||
this.name = 'conversationSummaryBufferMemory'
|
||||
this.version = 1.0
|
||||
this.type = 'ConversationSummaryBufferMemory'
|
||||
this.icon = 'memory.svg'
|
||||
this.category = 'Memory'
|
||||
this.description = 'Uses token length to decide when to summarize conversations'
|
||||
this.baseClasses = [this.type, ...getBaseClasses(ConversationSummaryBufferMemory)]
|
||||
this.inputs = [
|
||||
{
|
||||
label: 'Chat Model',
|
||||
name: 'model',
|
||||
type: 'BaseChatModel'
|
||||
},
|
||||
{
|
||||
label: 'Max Token Limit',
|
||||
name: 'maxTokenLimit',
|
||||
type: 'number',
|
||||
default: 2000,
|
||||
description: 'Summarize conversations once token limit is reached. Default to 2000'
|
||||
},
|
||||
{
|
||||
label: 'Session Id',
|
||||
name: 'sessionId',
|
||||
type: 'string',
|
||||
description:
|
||||
'If not specified, a random id will be used. Learn <a target="_blank" href="https://docs.flowiseai.com/memory#ui-and-embedded-chat">more</a>',
|
||||
default: '',
|
||||
optional: true,
|
||||
additionalParams: true
|
||||
},
|
||||
{
|
||||
label: 'Memory Key',
|
||||
name: 'memoryKey',
|
||||
type: 'string',
|
||||
default: 'chat_history',
|
||||
additionalParams: true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
|
||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||
const _maxTokenLimit = nodeData.inputs?.maxTokenLimit as string
|
||||
const maxTokenLimit = _maxTokenLimit ? parseInt(_maxTokenLimit, 10) : 2000
|
||||
const sessionId = nodeData.inputs?.sessionId as string
|
||||
const memoryKey = (nodeData.inputs?.memoryKey as string) ?? 'chat_history'
|
||||
|
||||
const appDataSource = options.appDataSource as DataSource
|
||||
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
||||
const chatflowid = options.chatflowid as string
|
||||
|
||||
const obj: ConversationSummaryBufferMemoryInput & BufferMemoryExtendedInput = {
|
||||
llm: model,
|
||||
sessionId,
|
||||
memoryKey,
|
||||
maxTokenLimit,
|
||||
returnMessages: true,
|
||||
appDataSource,
|
||||
databaseEntities,
|
||||
chatflowid
|
||||
}
|
||||
|
||||
return new ConversationSummaryBufferMemoryExtended(obj)
|
||||
}
|
||||
}
|
||||
|
||||
interface BufferMemoryExtendedInput {
|
||||
sessionId: string
|
||||
appDataSource: DataSource
|
||||
databaseEntities: IDatabaseEntity
|
||||
chatflowid: string
|
||||
}
|
||||
|
||||
class ConversationSummaryBufferMemoryExtended extends FlowiseSummaryBufferMemory implements MemoryMethods {
|
||||
appDataSource: DataSource
|
||||
databaseEntities: IDatabaseEntity
|
||||
chatflowid: string
|
||||
sessionId = ''
|
||||
|
||||
constructor(fields: ConversationSummaryBufferMemoryInput & BufferMemoryExtendedInput) {
|
||||
super(fields)
|
||||
this.sessionId = fields.sessionId
|
||||
this.appDataSource = fields.appDataSource
|
||||
this.databaseEntities = fields.databaseEntities
|
||||
this.chatflowid = fields.chatflowid
|
||||
}
|
||||
|
||||
async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
|
||||
const id = overrideSessionId ? overrideSessionId : this.sessionId
|
||||
if (!id) return []
|
||||
|
||||
let chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
|
||||
where: {
|
||||
sessionId: id,
|
||||
chatflowid: this.chatflowid
|
||||
},
|
||||
order: {
|
||||
createdDate: 'ASC'
|
||||
}
|
||||
})
|
||||
|
||||
let baseMessages = mapChatMessageToBaseMessage(chatMessage)
|
||||
|
||||
// Prune baseMessages if it exceeds max token limit
|
||||
if (this.movingSummaryBuffer) {
|
||||
baseMessages = [new this.summaryChatMessageClass(this.movingSummaryBuffer), ...baseMessages]
|
||||
}
|
||||
|
||||
let currBufferLength = 0
|
||||
|
||||
if (this.llm && typeof this.llm !== 'string') {
|
||||
currBufferLength = await this.llm.getNumTokens(getBufferString(baseMessages, this.humanPrefix, this.aiPrefix))
|
||||
if (currBufferLength > this.maxTokenLimit) {
|
||||
const prunedMemory = []
|
||||
while (currBufferLength > this.maxTokenLimit) {
|
||||
const poppedMessage = baseMessages.shift()
|
||||
if (poppedMessage) {
|
||||
prunedMemory.push(poppedMessage)
|
||||
currBufferLength = await this.llm.getNumTokens(getBufferString(baseMessages, this.humanPrefix, this.aiPrefix))
|
||||
}
|
||||
}
|
||||
this.movingSummaryBuffer = await this.predictNewSummary(prunedMemory, this.movingSummaryBuffer)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------- Finished Pruning ---------------
|
||||
|
||||
if (this.movingSummaryBuffer) {
|
||||
baseMessages = [new this.summaryChatMessageClass(this.movingSummaryBuffer), ...baseMessages]
|
||||
}
|
||||
|
||||
if (returnBaseMessages) {
|
||||
return baseMessages
|
||||
}
|
||||
|
||||
let returnIMessages: IMessage[] = []
|
||||
for (const m of baseMessages) {
|
||||
returnIMessages.push({
|
||||
message: m.content as string,
|
||||
type: m._getType() === 'human' ? 'userMessage' : 'apiMessage'
|
||||
})
|
||||
}
|
||||
|
||||
return returnIMessages
|
||||
}
|
||||
|
||||
async addChatMessages(): Promise<void> {
|
||||
// adding chat messages is done on server level
|
||||
return
|
||||
}
|
||||
|
||||
async clearChatMessages(): Promise<void> {
|
||||
// clearing chat messages is done on server level
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { nodeClass: ConversationSummaryBufferMemory_Memory }
|
||||
@@ -0,0 +1,19 @@
|
||||
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_124_19510)">
|
||||
<path d="M25 15V9C25 7.89543 24.1046 7 23 7H9C7.89543 7 7 7.89543 7 9V23C7 24.1046 7.89543 25 9 25H15" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M20 15V13C20 12.4477 19.5523 12 19 12H13C12.4477 12 12 12.4477 12 13V19C12 19.5523 12.4477 20 13 20H15" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M7 11H5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M7 16H5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M7 21H5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M21 7L21 5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M16 7L16 5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M11 7L11 5" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M11 27L11 25" stroke="black" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M26 19H21C19.8954 19 19 19.8954 19 21V24.2857C19 25.3903 19.8954 26.2857 21 26.2857H21.4545V28L23.5 26.2857H26C27.1046 26.2857 28 25.3903 28 24.2857V21C28 19.8954 27.1046 19 26 19Z" stroke="black" stroke-width="2" stroke-linejoin="round"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_124_19510">
|
||||
<rect width="32" height="32" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
Reference in New Issue
Block a user