mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-23 15:00:29 +03:00
Merge pull request #879 from lpongetti/feature/add-options-to-sql-chain
Added some options for SqlDatabaseChain
This commit is contained in:
@@ -43,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode {
|
||||
constructor() {
|
||||
this.label = 'Sql Database Chain'
|
||||
this.name = 'sqlDatabaseChain'
|
||||
this.version = 2.0
|
||||
this.version = 3.0
|
||||
this.type = 'SqlDatabaseChain'
|
||||
this.icon = 'sqlchain.svg'
|
||||
this.category = 'Chains'
|
||||
@@ -85,6 +85,41 @@ class SqlDatabaseChain_Chains implements INode {
|
||||
type: 'string',
|
||||
placeholder: '1270.0.0.1:5432/chinook'
|
||||
},
|
||||
{
|
||||
label: 'Include Tables',
|
||||
name: 'includesTables',
|
||||
type: 'string',
|
||||
description: 'Tables to include for queries.',
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
label: 'Ignore Tables',
|
||||
name: 'ignoreTables',
|
||||
type: 'string',
|
||||
description: 'Tables to ignore for queries.',
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
label: "Sample table's rows info",
|
||||
name: 'sampleRowsInTableInfo',
|
||||
type: 'number',
|
||||
description: 'Number of sample row for tables to load for info.',
|
||||
placeholder: '3',
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
label: 'Top Keys',
|
||||
name: 'topK',
|
||||
type: 'number',
|
||||
description:
|
||||
'If you are querying for several rows of a table you can select the maximum number of results you want to get by using the "top_k" parameter (default is 10). This is useful for avoiding query results that exceed the prompt max length or consume tokens unnecessarily.',
|
||||
placeholder: '10',
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
label: 'Custom Prompt',
|
||||
name: 'customPrompt',
|
||||
@@ -104,20 +139,50 @@ class SqlDatabaseChain_Chains implements INode {
|
||||
async init(nodeData: INodeData): Promise<any> {
|
||||
const databaseType = nodeData.inputs?.database as DatabaseType
|
||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||
const url = nodeData.inputs?.url
|
||||
const url = nodeData.inputs?.url as string
|
||||
const includesTables = nodeData.inputs?.includesTables
|
||||
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
|
||||
const ignoreTables = nodeData.inputs?.ignoreTables
|
||||
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
|
||||
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
|
||||
const topK = nodeData.inputs?.topK as number
|
||||
const customPrompt = nodeData.inputs?.customPrompt as string
|
||||
|
||||
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
|
||||
const chain = await getSQLDBChain(
|
||||
databaseType,
|
||||
url,
|
||||
model,
|
||||
splittedIncludesTables,
|
||||
splittedIgnoreTables,
|
||||
sampleRowsInTableInfo,
|
||||
topK,
|
||||
customPrompt
|
||||
)
|
||||
return chain
|
||||
}
|
||||
|
||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
|
||||
const databaseType = nodeData.inputs?.database as DatabaseType
|
||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||
const url = nodeData.inputs?.url
|
||||
const url = nodeData.inputs?.url as string
|
||||
const includesTables = nodeData.inputs?.includesTables
|
||||
const splittedIncludesTables = includesTables == '' ? undefined : includesTables?.split(',')
|
||||
const ignoreTables = nodeData.inputs?.ignoreTables
|
||||
const splittedIgnoreTables = ignoreTables == '' ? undefined : ignoreTables?.split(',')
|
||||
const sampleRowsInTableInfo = nodeData.inputs?.sampleRowsInTableInfo as number
|
||||
const topK = nodeData.inputs?.topK as number
|
||||
const customPrompt = nodeData.inputs?.customPrompt as string
|
||||
|
||||
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
|
||||
const chain = await getSQLDBChain(
|
||||
databaseType,
|
||||
url,
|
||||
model,
|
||||
splittedIncludesTables,
|
||||
splittedIgnoreTables,
|
||||
sampleRowsInTableInfo,
|
||||
topK,
|
||||
customPrompt
|
||||
)
|
||||
const loggerHandler = new ConsoleCallbackHandler(options.logger)
|
||||
|
||||
if (options.socketIO && options.socketIOClientId) {
|
||||
@@ -131,7 +196,16 @@ class SqlDatabaseChain_Chains implements INode {
|
||||
}
|
||||
}
|
||||
|
||||
const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => {
|
||||
const getSQLDBChain = async (
|
||||
databaseType: DatabaseType,
|
||||
url: string,
|
||||
llm: BaseLanguageModel,
|
||||
includesTables?: string[],
|
||||
ignoreTables?: string[],
|
||||
sampleRowsInTableInfo?: number,
|
||||
topK?: number,
|
||||
customPrompt?: string
|
||||
) => {
|
||||
const datasource = new DataSource(
|
||||
databaseType === 'sqlite'
|
||||
? {
|
||||
@@ -145,13 +219,17 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL
|
||||
)
|
||||
|
||||
const db = await SqlDatabase.fromDataSourceParams({
|
||||
appDataSource: datasource
|
||||
appDataSource: datasource,
|
||||
includesTables: includesTables,
|
||||
ignoreTables: ignoreTables,
|
||||
sampleRowsInTableInfo: sampleRowsInTableInfo
|
||||
})
|
||||
|
||||
const obj: SqlDatabaseChainInput = {
|
||||
llm,
|
||||
database: db,
|
||||
verbose: process.env.DEBUG === 'true' ? true : false
|
||||
verbose: process.env.DEBUG === 'true' ? true : false,
|
||||
topK: topK
|
||||
}
|
||||
|
||||
if (customPrompt) {
|
||||
|
||||
Reference in New Issue
Block a user