mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-28 13:00:56 +03:00
Feature/Add dedicated agent memory nodes (#3649)
add dedicated agent memory nodes
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
import path from 'path'
|
||||
import { getBaseClasses, getUserHome } from '../../../../src/utils'
|
||||
import { SaverOptions } from '../interface'
|
||||
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../../src/Interface'
|
||||
import { SqliteSaver } from './sqliteSaver'
|
||||
import { DataSource } from 'typeorm'
|
||||
|
||||
class SQLiteAgentMemory_Memory implements INode {
|
||||
label: string
|
||||
name: string
|
||||
version: number
|
||||
description: string
|
||||
type: string
|
||||
icon: string
|
||||
category: string
|
||||
badge: string
|
||||
baseClasses: string[]
|
||||
inputs: INodeParams[]
|
||||
credential: INodeParams
|
||||
|
||||
constructor() {
|
||||
this.label = 'SQLite Agent Memory'
|
||||
this.name = 'sqliteAgentMemory'
|
||||
this.version = 1.0
|
||||
this.type = 'SQLiteAgentMemory'
|
||||
this.icon = 'sqlite.png'
|
||||
this.category = 'Memory'
|
||||
this.description = 'Memory for agentflow to remember the state of the conversation using SQLite database'
|
||||
this.baseClasses = [this.type, ...getBaseClasses(SqliteSaver)]
|
||||
this.inputs = [
|
||||
/*{
|
||||
label: 'Database File Path',
|
||||
name: 'databaseFilePath',
|
||||
type: 'string',
|
||||
placeholder: 'C:\\Users\\User\\.flowise\\database.sqlite',
|
||||
description: 'Path to the SQLite database file. Leave empty to use default application database',
|
||||
optional: true
|
||||
},*/
|
||||
{
|
||||
label: 'Additional Connection Configuration',
|
||||
name: 'additionalConfig',
|
||||
type: 'json',
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
|
||||
const additionalConfig = nodeData.inputs?.additionalConfig as string
|
||||
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
||||
const chatflowid = options.chatflowid as string
|
||||
const appDataSource = options.appDataSource as DataSource
|
||||
|
||||
let additionalConfiguration = {}
|
||||
if (additionalConfig) {
|
||||
try {
|
||||
additionalConfiguration = typeof additionalConfig === 'object' ? additionalConfig : JSON.parse(additionalConfig)
|
||||
} catch (exception) {
|
||||
throw new Error('Invalid JSON in the Additional Configuration: ' + exception)
|
||||
}
|
||||
}
|
||||
|
||||
const threadId = options.sessionId || options.chatId
|
||||
|
||||
const database = path.join(process.env.DATABASE_PATH ?? path.join(getUserHome(), '.flowise'), 'database.sqlite')
|
||||
|
||||
let datasourceOptions: ICommonObject = {
|
||||
database,
|
||||
...additionalConfiguration,
|
||||
type: 'sqlite'
|
||||
}
|
||||
|
||||
const args: SaverOptions = {
|
||||
datasourceOptions,
|
||||
threadId,
|
||||
appDataSource,
|
||||
databaseEntities,
|
||||
chatflowid
|
||||
}
|
||||
|
||||
const recordManager = new SqliteSaver(args)
|
||||
return recordManager
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { nodeClass: SQLiteAgentMemory_Memory }
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 49 KiB |
@@ -0,0 +1,242 @@
|
||||
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
|
||||
import { RunnableConfig } from '@langchain/core/runnables'
|
||||
import { BaseMessage } from '@langchain/core/messages'
|
||||
import { DataSource, QueryRunner } from 'typeorm'
|
||||
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
|
||||
import { IMessage, MemoryMethods } from '../../../../src/Interface'
|
||||
import { mapChatMessageToBaseMessage } from '../../../../src/utils'
|
||||
|
||||
export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
|
||||
protected isSetup: boolean
|
||||
|
||||
datasource: DataSource
|
||||
|
||||
queryRunner: QueryRunner
|
||||
|
||||
config: SaverOptions
|
||||
|
||||
threadId: string
|
||||
|
||||
tableName = 'checkpoints'
|
||||
|
||||
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
|
||||
super(serde)
|
||||
this.config = config
|
||||
const { datasourceOptions, threadId } = config
|
||||
this.threadId = threadId
|
||||
this.datasource = new DataSource(datasourceOptions)
|
||||
}
|
||||
|
||||
private async setup(): Promise<void> {
|
||||
if (this.isSetup) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const appDataSource = await this.datasource.initialize()
|
||||
|
||||
this.queryRunner = appDataSource.createQueryRunner()
|
||||
await this.queryRunner.manager.query(`
|
||||
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
||||
thread_id TEXT NOT NULL,
|
||||
checkpoint_id TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
checkpoint BLOB,
|
||||
metadata BLOB,
|
||||
PRIMARY KEY (thread_id, checkpoint_id));`)
|
||||
} catch (error) {
|
||||
console.error(`Error creating ${this.tableName} table`, error)
|
||||
throw new Error(`Error creating ${this.tableName} table`)
|
||||
}
|
||||
|
||||
this.isSetup = true
|
||||
}
|
||||
|
||||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
|
||||
await this.setup()
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
const checkpoint_id = config.configurable?.checkpoint_id
|
||||
|
||||
if (checkpoint_id) {
|
||||
try {
|
||||
const keys = [thread_id, checkpoint_id]
|
||||
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, [...keys])
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
config,
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error retrieving ${this.tableName}`, error)
|
||||
throw new Error(`Error retrieving ${this.tableName}`)
|
||||
}
|
||||
} else {
|
||||
const keys = [thread_id]
|
||||
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
|
||||
|
||||
const rows = await this.queryRunner.manager.query(sql, [...keys])
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
return {
|
||||
config: {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].checkpoint_id
|
||||
}
|
||||
},
|
||||
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
||||
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
||||
parentConfig: rows[0].parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id: rows[0].thread_id,
|
||||
checkpoint_id: rows[0].parent_id
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
|
||||
await this.setup()
|
||||
const thread_id = config.configurable?.thread_id || this.threadId
|
||||
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
|
||||
before ? 'AND checkpoint_id < ?' : ''
|
||||
} ORDER BY checkpoint_id DESC`
|
||||
if (limit) {
|
||||
sql += ` LIMIT ${limit}`
|
||||
}
|
||||
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
|
||||
|
||||
try {
|
||||
const rows = await this.queryRunner.manager.query(sql, [...args])
|
||||
|
||||
if (rows && rows.length > 0) {
|
||||
for (const row of rows) {
|
||||
yield {
|
||||
config: {
|
||||
configurable: {
|
||||
thread_id: row.thread_id,
|
||||
checkpoint_id: row.checkpoint_id
|
||||
}
|
||||
},
|
||||
checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint,
|
||||
metadata: (await this.serde.parse(row.metadata)) as CheckpointMetadata,
|
||||
parentConfig: row.parent_id
|
||||
? {
|
||||
configurable: {
|
||||
thread_id: row.thread_id,
|
||||
checkpoint_id: row.parent_id
|
||||
}
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error listing ${this.tableName}`, error)
|
||||
throw new Error(`Error listing ${this.tableName}`)
|
||||
}
|
||||
}
|
||||
|
||||
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
|
||||
await this.setup()
|
||||
if (!config.configurable?.checkpoint_id) return {}
|
||||
try {
|
||||
const row = [
|
||||
config.configurable?.thread_id || this.threadId,
|
||||
checkpoint.id,
|
||||
config.configurable?.checkpoint_id,
|
||||
this.serde.stringify(checkpoint),
|
||||
this.serde.stringify(metadata)
|
||||
]
|
||||
|
||||
const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`
|
||||
|
||||
await this.queryRunner.manager.query(query, row)
|
||||
} catch (error) {
|
||||
console.error('Error saving checkpoint', error)
|
||||
throw new Error('Error saving checkpoint')
|
||||
}
|
||||
|
||||
return {
|
||||
configurable: {
|
||||
thread_id: config.configurable?.thread_id || this.threadId,
|
||||
checkpoint_id: checkpoint.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async delete(threadId: string): Promise<void> {
|
||||
if (!threadId) {
|
||||
return
|
||||
}
|
||||
await this.setup()
|
||||
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`
|
||||
|
||||
try {
|
||||
await this.queryRunner.manager.query(query, [threadId])
|
||||
} catch (error) {
|
||||
console.error(`Error deleting thread_id ${threadId}`, error)
|
||||
}
|
||||
}
|
||||
|
||||
async getChatMessages(
|
||||
overrideSessionId = '',
|
||||
returnBaseMessages = false,
|
||||
prependMessages?: IMessage[]
|
||||
): Promise<IMessage[] | BaseMessage[]> {
|
||||
if (!overrideSessionId) return []
|
||||
|
||||
const chatMessage = await this.config.appDataSource.getRepository(this.config.databaseEntities['ChatMessage']).find({
|
||||
where: {
|
||||
sessionId: overrideSessionId,
|
||||
chatflowid: this.config.chatflowid
|
||||
},
|
||||
order: {
|
||||
createdDate: 'ASC'
|
||||
}
|
||||
})
|
||||
|
||||
if (prependMessages?.length) {
|
||||
chatMessage.unshift(...prependMessages)
|
||||
}
|
||||
|
||||
if (returnBaseMessages) {
|
||||
return await mapChatMessageToBaseMessage(chatMessage)
|
||||
}
|
||||
|
||||
let returnIMessages: IMessage[] = []
|
||||
for (const m of chatMessage) {
|
||||
returnIMessages.push({
|
||||
message: m.content as string,
|
||||
type: m.role
|
||||
})
|
||||
}
|
||||
return returnIMessages
|
||||
}
|
||||
|
||||
async addChatMessages(): Promise<void> {
|
||||
// Empty as its not being used
|
||||
}
|
||||
|
||||
async clearChatMessages(overrideSessionId = ''): Promise<void> {
|
||||
await this.delete(overrideSessionId)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user