From ff7ee41758dd4fc24893cadb8c7551706f111529 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Mon, 14 Aug 2023 00:31:07 +0200 Subject: [PATCH] Added postgres, cockroachdb, mssql, mysql, mariadb, mongodb and oracle to SqlDatabaseChain_Chains --- .../SqlDatabaseChain/SqlDatabaseChain.ts | 68 +++++++++++++++---- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 9416371b..6bfc2f8a 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -5,6 +5,9 @@ import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler' +import { DataSourceOptions } from 'typeorm/data-source' + +type DatabaseType = 'sqlite' | 'postgres' | 'cockroachdb' | 'mssql' | 'mysql' | 'mariadb' | 'mongodb' | 'oracle' class SqlDatabaseChain_Chains implements INode { label: string @@ -38,36 +41,64 @@ class SqlDatabaseChain_Chains implements INode { type: 'options', options: [ { - label: 'SQlite', + label: 'SQLite', name: 'sqlite' + }, + { + label: 'PostgreSQL', + name: 'postgres' + }, + { + label: 'CockroachDB', + name: 'cockroachdb' + }, + { + label: 'MSSQL', + name: 'mssql' + }, + { + label: 'MySQL', + name: 'mysql' + }, + { + label: 'MariaDB', + name: 'mariadb' + }, + { + label: 'MongoDB', + name: 'mongodb' + }, + { + label: 'Oracle', + name: 'oracle' } ], default: 'sqlite' }, { - label: 'Database File Path', - name: 'dbFilePath', + label: 'Connection string or file path (sqlite only)', + name: 'url', type: 'string', - placeholder: 'C:/Users/chinook.db' + placeholder: '1270.0.0.1:5432/chinook' } ] } async init(nodeData: INodeData): Promise { - const databaseType = nodeData.inputs?.database as 'sqlite' + const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel - const dbFilePath = nodeData.inputs?.dbFilePath + const url = nodeData.inputs?.url - const chain = await getSQLDBChain(databaseType, dbFilePath, model) + const chain = await getSQLDBChain(databaseType, url, model) return chain } async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { - const databaseType = nodeData.inputs?.database as 'sqlite' + const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel - const dbFilePath = nodeData.inputs?.dbFilePath + const url = nodeData.inputs?.url - const chain = await getSQLDBChain(databaseType, dbFilePath, model) + const chain = await getSQLDBChain(databaseType, url, model) const loggerHandler = new ConsoleCallbackHandler(options.logger) if (options.socketIO && options.socketIOClientId) { @@ -81,11 +112,18 @@ class SqlDatabaseChain_Chains implements INode { } } -const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLanguageModel) => { - const datasource = new DataSource({ - type: databaseType, - database: dbFilePath - }) +const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel) => { + const datasource = new DataSource( + databaseType === 'sqlite' + ? { + type: databaseType, + database: url + } + : ({ + type: databaseType, + url: url + } as DataSourceOptions) + ) const db = await SqlDatabase.fromDataSourceParams({ appDataSource: datasource