Feature/Add dedicated agent memory nodes (#3649)

add dedicated agent memory nodes
This commit is contained in:
Henry Heng
2024-12-06 22:16:01 +00:00
committed by GitHub
parent fe2ed26999
commit cadc3b8fb3
12 changed files with 330 additions and 15 deletions
@@ -0,0 +1,112 @@
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../../src/utils'
import { SaverOptions } from '../interface'
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../../src/Interface'
import { DataSource } from 'typeorm'
import { MySQLSaver } from './mysqlSaver'
class MySQLAgentMemory_Memory implements INode {
label: string
name: string
version: number
description: string
type: string
icon: string
category: string
badge: string
baseClasses: string[]
inputs: INodeParams[]
credential: INodeParams
constructor() {
this.label = 'MySQL Agent Memory'
this.name = 'mySQLAgentMemory'
this.version = 1.0
this.type = 'AgentMemory'
this.icon = 'mysql.png'
this.category = 'Memory'
this.description = 'Memory for agentflow to remember the state of the conversation using MySQL database'
this.baseClasses = [this.type, ...getBaseClasses(MySQLSaver)]
this.credential = {
label: 'Connect Credential',
name: 'credential',
type: 'credential',
credentialNames: ['MySQLApi'],
optional: true
}
this.inputs = [
{
label: 'Host',
name: 'host',
type: 'string'
},
{
label: 'Database',
name: 'database',
type: 'string'
},
{
label: 'Port',
name: 'port',
type: 'number',
default: '3306'
},
{
label: 'Additional Connection Configuration',
name: 'additionalConfig',
type: 'json',
additionalParams: true,
optional: true
}
]
}
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
const additionalConfig = nodeData.inputs?.additionalConfig as string
const databaseEntities = options.databaseEntities as IDatabaseEntity
const chatflowid = options.chatflowid as string
const appDataSource = options.appDataSource as DataSource
let additionalConfiguration = {}
if (additionalConfig) {
try {
additionalConfiguration = typeof additionalConfig === 'object' ? additionalConfig : JSON.parse(additionalConfig)
} catch (exception) {
throw new Error('Invalid JSON in the Additional Configuration: ' + exception)
}
}
const threadId = options.sessionId || options.chatId
let datasourceOptions: ICommonObject = {
...additionalConfiguration,
type: 'mysql'
}
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const user = getCredentialParam('user', credentialData, nodeData)
const password = getCredentialParam('password', credentialData, nodeData)
const _port = (nodeData.inputs?.port as string) || '3306'
const port = parseInt(_port)
datasourceOptions = {
...datasourceOptions,
host: nodeData.inputs?.host as string,
port,
database: nodeData.inputs?.database as string,
username: user,
user: user,
password: password,
charset: 'utf8mb4'
}
const args: SaverOptions = {
datasourceOptions,
threadId,
appDataSource,
databaseEntities,
chatflowid
}
const recordManager = new MySQLSaver(args)
return recordManager
}
}
module.exports = { nodeClass: MySQLAgentMemory_Memory }
Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

@@ -0,0 +1,243 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages'
import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
import { IMessage, MemoryMethods } from '../../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../../src/utils'
export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean
config: SaverOptions
threadId: string
tableName = 'checkpoints'
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde)
this.config = config
const { threadId } = config
this.threadId = threadId
}
private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
}
private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) return
try {
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id VARCHAR(255) NOT NULL,
checkpoint_id VARCHAR(255) NOT NULL,
parent_id VARCHAR(255),
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id)
);`)
await queryRunner.release()
} catch (error) {
console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`)
}
this.isSetup = true
}
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const dataSource = await this.getDataSource()
await this.setup(dataSource)
const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id
try {
const queryRunner = dataSource.createQueryRunner()
const sql = checkpoint_id
? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
: `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id])
await queryRunner.release()
if (rows && rows.length > 0) {
const row = rows[0]
return {
config: {
configurable: {
thread_id: row.thread_id || thread_id,
checkpoint_id: row.checkpoint_id || checkpoint_id
}
},
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id,
checkpoint_id: row.parent_id
}
}
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
return undefined
}
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple, void, unknown> {
const dataSource = await this.getDataSource()
await this.setup(dataSource)
const queryRunner = dataSource.createQueryRunner()
try {
const threadId = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean)
const rows = await queryRunner.manager.query(sql, args)
await queryRunner.release()
if (rows && rows.length > 0) {
for (const row of rows) {
yield {
config: {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.checkpoint_id
}
},
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.parent_id
}
}
: undefined
}
}
}
} catch (error) {
console.error(`Error listing checkpoints`, error)
throw new Error(`Error listing checkpoints`)
} finally {
await dataSource.destroy()
}
}
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
const dataSource = await this.getDataSource()
await this.setup(dataSource)
if (!config.configurable?.checkpoint_id) return {}
try {
const queryRunner = dataSource.createQueryRunner()
const row = [
config.configurable?.thread_id || this.threadId,
checkpoint.id,
config.configurable?.checkpoint_id,
Buffer.from(this.serde.stringify(checkpoint)), // Encode to binary
Buffer.from(this.serde.stringify(metadata)) // Encode to binary
]
const query = `INSERT INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata)
VALUES (?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`
await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) {
console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
}
return {
configurable: {
thread_id: config.configurable?.thread_id || this.threadId,
checkpoint_id: checkpoint.id
}
}
}
async delete(threadId: string): Promise<void> {
if (!threadId) return
const dataSource = await this.getDataSource()
await this.setup(dataSource)
try {
const queryRunner = dataSource.createQueryRunner()
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
}
}
async getChatMessages(
overrideSessionId = '',
returnBaseMessages = false,
prependMessages?: IMessage[]
): Promise<IMessage[] | BaseMessage[]> {
if (!overrideSessionId) return []
const chatMessage = await this.config.appDataSource.getRepository(this.config.databaseEntities['ChatMessage']).find({
where: {
sessionId: overrideSessionId,
chatflowid: this.config.chatflowid
},
order: {
createdDate: 'ASC'
}
})
if (prependMessages?.length) {
chatMessage.unshift(...prependMessages)
}
if (returnBaseMessages) {
return await mapChatMessageToBaseMessage(chatMessage)
}
let returnIMessages: IMessage[] = []
for (const m of chatMessage) {
returnIMessages.push({
message: m.content as string,
type: m.role
})
}
return returnIMessages
}
async addChatMessages(): Promise<void> {
// Empty as it's not being used
}
async clearChatMessages(overrideSessionId = ''): Promise<void> {
if (!overrideSessionId) return
await this.delete(overrideSessionId)
}
}