Merge pull request #879 from lpongetti/feature/add-options-to-sql-chain

Added some options for SqlDatabaseChain
This commit is contained in:
Henry Heng
2023-09-12 19:20:46 +01:00
committed by GitHub
@@ -43,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode {
constructor() {
this.label = 'Sql Database Chain'
this.name = 'sqlDatabaseChain'
this.version = 2.0
this.version = 3.0
this.type = 'SqlDatabaseChain'
this.icon = 'sqlchain.svg'
this.category = 'Chains'
@@ -85,6 +85,41 @@ class SqlDatabaseChain_Chains implements INode {
type: 'string',
placeholder: '1270.0.0.1:5432/chinook'
},
{
label: 'Include Tables',
name: 'includesTables',
type: 'string',
description: 'Tables to include for queries.',
additionalParams: true,
optional: true
},
{
label: 'Ignore Tables',
name: 'ignoreTables',
type: 'string',
description: 'Tables to ignore for queries.',
additionalParams: true,
optional: true
},
{
label: "Sample table's rows info",
name: 'sampleRowsInTableInfo',
type: 'number',
description: 'Number of sample row for tables to load for info.',
placeholder: '3',
additionalParams: true,
optional: true
},
{
label: 'Top Keys',
name: 'topK',
type: 'number',
description:
'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.',
placeholder: '10',
additionalParams: true,
optional: true
},
{
label: 'Custom Prompt',
name: 'customPrompt',
@@ -104,20 +139,50 @@ class SqlDatabaseChain_Chains implements INode {
async init(nodeData: INodeData): Promise<any> {
const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel
const url = nodeData.inputs?.url
const url = nodeData.inputs?.url as string
const includesTables = nodeData.inputs?.includesTables
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
const ignoreTables = nodeData.inputs?.ignoreTables
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
const topK = nodeData.inputs?.topK as number
const customPrompt = nodeData.inputs?.customPrompt as string
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
const chain = await getSQLDBChain(
databaseType,
url,
model,
splittedIncludesTables,
splittedIgnoreTables,
sampleRowsInTableInfo,
topK,
customPrompt
)
return chain
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel
const url = nodeData.inputs?.url
const url = nodeData.inputs?.url as string
const includesTables = nodeData.inputs?.includesTables
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
const ignoreTables = nodeData.inputs?.ignoreTables
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
const topK = nodeData.inputs?.topK as number
const customPrompt = nodeData.inputs?.customPrompt as string
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
const chain = await getSQLDBChain(
databaseType,
url,
model,
splittedIncludesTables,
splittedIgnoreTables,
sampleRowsInTableInfo,
topK,
customPrompt
)
const loggerHandler = new ConsoleCallbackHandler(options.logger)
if (options.socketIO && options.socketIOClientId) {
@@ -131,7 +196,16 @@ class SqlDatabaseChain_Chains implements INode {
}
}
const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => {
const getSQLDBChain = async (
databaseType: DatabaseType,
url: string,
llm: BaseLanguageModel,
includesTables?: string[],
ignoreTables?: string[],
sampleRowsInTableInfo?: number,
topK?: number,
customPrompt?: string
) => {
const datasource = new DataSource(
databaseType === 'sqlite'
? {
@@ -145,13 +219,17 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL
)
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource
appDataSource: datasource,
includesTables: includesTables,
ignoreTables: ignoreTables,
sampleRowsInTableInfo: sampleRowsInTableInfo
})
const obj: SqlDatabaseChainInput = {
llm,
database: db,
verbose: process.env.DEBUG === 'true' ? true : false
verbose: process.env.DEBUG === 'true' ? true : false,
topK: topK
}
if (customPrompt) {