mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-29 05:01:10 +03:00
Bugfix/SQLite agent memory node (#3650)
* add dedicated agent memory nodes * sqlite agent memory fix * Update pnpm-lock.yaml
This commit is contained in:
@@ -1,42 +1,39 @@
|
|||||||
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
|
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
|
||||||
import { RunnableConfig } from '@langchain/core/runnables'
|
import { RunnableConfig } from '@langchain/core/runnables'
|
||||||
import { BaseMessage } from '@langchain/core/messages'
|
import { BaseMessage } from '@langchain/core/messages'
|
||||||
import { DataSource, QueryRunner } from 'typeorm'
|
import { DataSource } from 'typeorm'
|
||||||
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
|
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
|
||||||
import { IMessage, MemoryMethods } from '../../../../src/Interface'
|
import { IMessage, MemoryMethods } from '../../../../src/Interface'
|
||||||
import { mapChatMessageToBaseMessage } from '../../../../src/utils'
|
import { mapChatMessageToBaseMessage } from '../../../../src/utils'
|
||||||
|
|
||||||
export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
|
export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
|
||||||
protected isSetup: boolean
|
protected isSetup: boolean
|
||||||
|
|
||||||
datasource: DataSource
|
|
||||||
|
|
||||||
queryRunner: QueryRunner
|
|
||||||
|
|
||||||
config: SaverOptions
|
config: SaverOptions
|
||||||
|
|
||||||
threadId: string
|
threadId: string
|
||||||
|
|
||||||
tableName = 'checkpoints'
|
tableName = 'checkpoints'
|
||||||
|
|
||||||
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
|
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
|
||||||
super(serde)
|
super(serde)
|
||||||
this.config = config
|
this.config = config
|
||||||
const { datasourceOptions, threadId } = config
|
const { threadId } = config
|
||||||
this.threadId = threadId
|
this.threadId = threadId
|
||||||
this.datasource = new DataSource(datasourceOptions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private async setup(): Promise<void> {
|
private async getDataSource(): Promise<DataSource> {
|
||||||
|
const { datasourceOptions } = this.config
|
||||||
|
const dataSource = new DataSource(datasourceOptions)
|
||||||
|
await dataSource.initialize()
|
||||||
|
return dataSource
|
||||||
|
}
|
||||||
|
|
||||||
|
private async setup(dataSource: DataSource): Promise<void> {
|
||||||
if (this.isSetup) {
|
if (this.isSetup) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const appDataSource = await this.datasource.initialize()
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
|
await queryRunner.manager.query(`
|
||||||
this.queryRunner = appDataSource.createQueryRunner()
|
|
||||||
await this.queryRunner.manager.query(`
|
|
||||||
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
||||||
thread_id TEXT NOT NULL,
|
thread_id TEXT NOT NULL,
|
||||||
checkpoint_id TEXT NOT NULL,
|
checkpoint_id TEXT NOT NULL,
|
||||||
@@ -44,6 +41,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
checkpoint BLOB,
|
checkpoint BLOB,
|
||||||
metadata BLOB,
|
metadata BLOB,
|
||||||
PRIMARY KEY (thread_id, checkpoint_id));`)
|
PRIMARY KEY (thread_id, checkpoint_id));`)
|
||||||
|
await queryRunner.release()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error creating ${this.tableName} table`, error)
|
console.error(`Error creating ${this.tableName} table`, error)
|
||||||
throw new Error(`Error creating ${this.tableName} table`)
|
throw new Error(`Error creating ${this.tableName} table`)
|
||||||
@@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
}
|
}
|
||||||
|
|
||||||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
|
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
|
||||||
await this.setup()
|
const dataSource = await this.getDataSource()
|
||||||
|
await this.setup(dataSource)
|
||||||
|
|
||||||
const thread_id = config.configurable?.thread_id || this.threadId
|
const thread_id = config.configurable?.thread_id || this.threadId
|
||||||
const checkpoint_id = config.configurable?.checkpoint_id
|
const checkpoint_id = config.configurable?.checkpoint_id
|
||||||
|
|
||||||
if (checkpoint_id) {
|
if (checkpoint_id) {
|
||||||
try {
|
try {
|
||||||
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
const keys = [thread_id, checkpoint_id]
|
const keys = [thread_id, checkpoint_id]
|
||||||
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
|
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
|
||||||
|
|
||||||
const rows = await this.queryRunner.manager.query(sql, [...keys])
|
const rows = await queryRunner.manager.query(sql, [...keys])
|
||||||
|
await queryRunner.release()
|
||||||
|
|
||||||
if (rows && rows.length > 0) {
|
if (rows && rows.length > 0) {
|
||||||
return {
|
return {
|
||||||
@@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error retrieving ${this.tableName}`, error)
|
console.error(`Error retrieving ${this.tableName}`, error)
|
||||||
throw new Error(`Error retrieving ${this.tableName}`)
|
throw new Error(`Error retrieving ${this.tableName}`)
|
||||||
|
} finally {
|
||||||
|
await dataSource.destroy()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const keys = [thread_id]
|
try {
|
||||||
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
|
const keys = [thread_id]
|
||||||
|
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
|
||||||
|
|
||||||
const rows = await this.queryRunner.manager.query(sql, [...keys])
|
const rows = await queryRunner.manager.query(sql, [...keys])
|
||||||
|
await queryRunner.release()
|
||||||
|
|
||||||
if (rows && rows.length > 0) {
|
if (rows && rows.length > 0) {
|
||||||
return {
|
return {
|
||||||
config: {
|
config: {
|
||||||
configurable: {
|
configurable: {
|
||||||
thread_id: rows[0].thread_id,
|
thread_id: rows[0].thread_id,
|
||||||
checkpoint_id: rows[0].checkpoint_id
|
checkpoint_id: rows[0].checkpoint_id
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
|
||||||
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
|
||||||
parentConfig: rows[0].parent_id
|
parentConfig: rows[0].parent_id
|
||||||
? {
|
? {
|
||||||
configurable: {
|
configurable: {
|
||||||
thread_id: rows[0].thread_id,
|
thread_id: rows[0].thread_id,
|
||||||
checkpoint_id: rows[0].parent_id
|
checkpoint_id: rows[0].parent_id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
: undefined
|
||||||
: undefined
|
}
|
||||||
}
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error retrieving ${this.tableName}`, error)
|
||||||
|
throw new Error(`Error retrieving ${this.tableName}`)
|
||||||
|
} finally {
|
||||||
|
await dataSource.destroy()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
|
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
|
||||||
await this.setup()
|
const dataSource = await this.getDataSource()
|
||||||
|
await this.setup(dataSource)
|
||||||
|
|
||||||
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
const thread_id = config.configurable?.thread_id || this.threadId
|
const thread_id = config.configurable?.thread_id || this.threadId
|
||||||
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
|
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
|
||||||
before ? 'AND checkpoint_id < ?' : ''
|
before ? 'AND checkpoint_id < ?' : ''
|
||||||
@@ -125,7 +141,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
|
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const rows = await this.queryRunner.manager.query(sql, [...args])
|
const rows = await queryRunner.manager.query(sql, [...args])
|
||||||
|
await queryRunner.release()
|
||||||
|
|
||||||
if (rows && rows.length > 0) {
|
if (rows && rows.length > 0) {
|
||||||
for (const row of rows) {
|
for (const row of rows) {
|
||||||
@@ -152,13 +169,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error listing ${this.tableName}`, error)
|
console.error(`Error listing ${this.tableName}`, error)
|
||||||
throw new Error(`Error listing ${this.tableName}`)
|
throw new Error(`Error listing ${this.tableName}`)
|
||||||
|
} finally {
|
||||||
|
await dataSource.destroy()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
|
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
|
||||||
await this.setup()
|
const dataSource = await this.getDataSource()
|
||||||
|
await this.setup(dataSource)
|
||||||
|
|
||||||
if (!config.configurable?.checkpoint_id) return {}
|
if (!config.configurable?.checkpoint_id) return {}
|
||||||
try {
|
try {
|
||||||
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
const row = [
|
const row = [
|
||||||
config.configurable?.thread_id || this.threadId,
|
config.configurable?.thread_id || this.threadId,
|
||||||
checkpoint.id,
|
checkpoint.id,
|
||||||
@@ -169,10 +191,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
|
|
||||||
const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`
|
const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`
|
||||||
|
|
||||||
await this.queryRunner.manager.query(query, row)
|
await queryRunner.manager.query(query, row)
|
||||||
|
await queryRunner.release()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error saving checkpoint', error)
|
console.error('Error saving checkpoint', error)
|
||||||
throw new Error('Error saving checkpoint')
|
throw new Error('Error saving checkpoint')
|
||||||
|
} finally {
|
||||||
|
await dataSource.destroy()
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -187,13 +212,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
|
|||||||
if (!threadId) {
|
if (!threadId) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
await this.setup()
|
|
||||||
|
const dataSource = await this.getDataSource()
|
||||||
|
await this.setup(dataSource)
|
||||||
|
|
||||||
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`
|
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await this.queryRunner.manager.query(query, [threadId])
|
const queryRunner = dataSource.createQueryRunner()
|
||||||
|
await queryRunner.manager.query(query, [threadId])
|
||||||
|
await queryRunner.release()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error deleting thread_id ${threadId}`, error)
|
console.error(`Error deleting thread_id ${threadId}`, error)
|
||||||
|
} finally {
|
||||||
|
await dataSource.destroy()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user