diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 04d704a5..f5fd0ccc 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -43,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode { constructor() { this.label = 'Sql Database Chain' this.name = 'sqlDatabaseChain' - this.version = 2.0 + this.version = 3.0 this.type = 'SqlDatabaseChain' this.icon = 'sqlchain.svg' this.category = 'Chains' @@ -85,6 +85,41 @@ class SqlDatabaseChain_Chains implements INode { type: 'string', placeholder: '1270.0.0.1:5432/chinook' }, + { + label: 'Include Tables', + name: 'includesTables', + type: 'string', + description: 'Tables to include for queries.', + additionalParams: true, + optional: true + }, + { + label: 'Ignore Tables', + name: 'ignoreTables', + type: 'string', + description: 'Tables to ignore for queries.', + additionalParams: true, + optional: true + }, + { + label: "Sample table's rows info", + name: 'sampleRowsInTableInfo', + type: 'number', + description: 'Number of sample row for tables to load for info.', + placeholder: '3', + additionalParams: true, + optional: true + }, + { + label: 'Top Keys', + name: 'topK', + type: 'number', + description: + 'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.', + placeholder: '10', + additionalParams: true, + optional: true + }, { label: 'Custom Prompt', name: 'customPrompt', @@ -104,20 +139,50 @@ class SqlDatabaseChain_Chains implements INode { async init(nodeData: INodeData): Promise { const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel - const url = nodeData.inputs?.url + const url = nodeData.inputs?.url as string + const includesTables = nodeData.inputs?.includesTables + const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',') + const ignoreTables = nodeData.inputs?.ignoreTables + const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',') + const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number + const topK = nodeData.inputs?.topK as number const customPrompt = nodeData.inputs?.customPrompt as string - const chain = await getSQLDBChain(databaseType, url, model, customPrompt) + const chain = await getSQLDBChain( + databaseType, + url, + model, + splittedIncludesTables, + splittedIgnoreTables, + sampleRowsInTableInfo, + topK, + customPrompt + ) return chain } async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel - const url = nodeData.inputs?.url + const url = nodeData.inputs?.url as string + const includesTables = nodeData.inputs?.includesTables + const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',') + const ignoreTables = nodeData.inputs?.ignoreTables + const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',') + const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number + const topK = nodeData.inputs?.topK as number const customPrompt = nodeData.inputs?.customPrompt as string - const chain = await getSQLDBChain(databaseType, url, model, customPrompt) + const chain = await getSQLDBChain( + databaseType, + url, + model, + splittedIncludesTables, + splittedIgnoreTables, + sampleRowsInTableInfo, + topK, + customPrompt + ) const loggerHandler = new ConsoleCallbackHandler(options.logger) if (options.socketIO && options.socketIOClientId) { @@ -131,7 +196,16 @@ class SqlDatabaseChain_Chains implements INode { } } -const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => { +const getSQLDBChain = async ( + databaseType: DatabaseType, + url: string, + llm: BaseLanguageModel, + includesTables?: string[], + ignoreTables?: string[], + sampleRowsInTableInfo?: number, + topK?: number, + customPrompt?: string +) => { const datasource = new DataSource( databaseType === 'sqlite' ? { @@ -145,13 +219,17 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL ) const db = await SqlDatabase.fromDataSourceParams({ - appDataSource: datasource + appDataSource: datasource, + includesTables: includesTables, + ignoreTables: ignoreTables, + sampleRowsInTableInfo: sampleRowsInTableInfo }) const obj: SqlDatabaseChainInput = { llm, database: db, - verbose: process.env.DEBUG === 'true' ? true : false + verbose: process.env.DEBUG === 'true' ? true : false, + topK: topK } if (customPrompt) {