Merge branch 'main' into feature/Vec2Doc

This commit is contained in:
Henry
2023-08-16 11:33:41 +01:00
3 changed files with 75 additions and 17 deletions
@@ -5,6 +5,8 @@ import { flatten } from 'lodash'
import { BaseChatMemory } from 'langchain/memory' import { BaseChatMemory } from 'langchain/memory'
import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler' import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
const defaultMessage = `Do your best to answer the questions. Feel free to use any tools available to look up relevant information, only if necessary.`
class ConversationalRetrievalAgent_Agents implements INode { class ConversationalRetrievalAgent_Agents implements INode {
label: string label: string
name: string name: string
@@ -46,6 +48,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
label: 'System Message', label: 'System Message',
name: 'systemMessage', name: 'systemMessage',
type: 'string', type: 'string',
default: defaultMessage,
rows: 4, rows: 4,
optional: true, optional: true,
additionalParams: true additionalParams: true
@@ -65,7 +68,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
agentType: 'openai-functions', agentType: 'openai-functions',
verbose: process.env.DEBUG === 'true' ? true : false, verbose: process.env.DEBUG === 'true' ? true : false,
agentArgs: { agentArgs: {
prefix: systemMessage ?? `You are a helpful AI assistant.` prefix: systemMessage ?? defaultMessage
}, },
returnIntermediateSteps: true returnIntermediateSteps: true
}) })
@@ -5,6 +5,9 @@ import { DataSource } from 'typeorm'
import { SqlDatabase } from 'langchain/sql_db' import { SqlDatabase } from 'langchain/sql_db'
import { BaseLanguageModel } from 'langchain/base_language' import { BaseLanguageModel } from 'langchain/base_language'
import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler' import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
import { DataSourceOptions } from 'typeorm/data-source'
type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql'
class SqlDatabaseChain_Chains implements INode { class SqlDatabaseChain_Chains implements INode {
label: string label: string
@@ -38,36 +41,48 @@ class SqlDatabaseChain_Chains implements INode {
type: 'options', type: 'options',
options: [ options: [
{ {
label: 'SQlite', label: 'SQLite',
name: 'sqlite' name: 'sqlite'
},
{
label: 'PostgreSQL',
name: 'postgres'
},
{
label: 'MSSQL',
name: 'mssql'
},
{
label: 'MySQL',
name: 'mysql'
} }
], ],
default: 'sqlite' default: 'sqlite'
}, },
{ {
label: 'Database File Path', label: 'Connection string or file path (sqlite only)',
name: 'dbFilePath', name: 'url',
type: 'string', type: 'string',
placeholder: 'C:/Users/chinook.db' placeholder: '1270.0.0.1:5432/chinook'
} }
] ]
} }
async init(nodeData: INodeData): Promise<any> { async init(nodeData: INodeData): Promise<any> {
const databaseType = nodeData.inputs?.database as 'sqlite' const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel const model = nodeData.inputs?.model as BaseLanguageModel
const dbFilePath = nodeData.inputs?.dbFilePath const url = nodeData.inputs?.url
const chain = await getSQLDBChain(databaseType, dbFilePath, model) const chain = await getSQLDBChain(databaseType, url, model)
return chain return chain
} }
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const databaseType = nodeData.inputs?.database as 'sqlite' const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel const model = nodeData.inputs?.model as BaseLanguageModel
const dbFilePath = nodeData.inputs?.dbFilePath const url = nodeData.inputs?.url
const chain = await getSQLDBChain(databaseType, dbFilePath, model) const chain = await getSQLDBChain(databaseType, url, model)
const loggerHandler = new ConsoleCallbackHandler(options.logger) const loggerHandler = new ConsoleCallbackHandler(options.logger)
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
@@ -81,11 +96,18 @@ class SqlDatabaseChain_Chains implements INode {
} }
} }
const getSQLDBChain = async (databaseType: 'sqlite', dbFilePath: string, llm: BaseLanguageModel) => { const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel) => {
const datasource = new DataSource({ const datasource = new DataSource(
databaseType === 'sqlite'
? {
type: databaseType, type: databaseType,
database: dbFilePath database: url
}) }
: ({
type: databaseType,
url: url
} as DataSourceOptions)
)
const db = await SqlDatabase.fromDataSourceParams({ const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource appDataSource: datasource
@@ -61,7 +61,40 @@ class Folder_DocumentLoaders implements INode {
'.csv': (path) => new CSVLoader(path), '.csv': (path) => new CSVLoader(path),
'.docx': (path) => new DocxLoader(path), '.docx': (path) => new DocxLoader(path),
// @ts-ignore // @ts-ignore
'.pdf': (path) => new PDFLoader(path, { pdfjs: () => import('pdf-parse/lib/pdf.js/v1.10.100/build/pdf.js') }) '.pdf': (path) => new PDFLoader(path, { pdfjs: () => import('pdf-parse/lib/pdf.js/v1.10.100/build/pdf.js') }),
'.aspx': (path) => new TextLoader(path),
'.asp': (path) => new TextLoader(path),
'.cpp': (path) => new TextLoader(path), // C++
'.c': (path) => new TextLoader(path),
'.cs': (path) => new TextLoader(path),
'.css': (path) => new TextLoader(path),
'.go': (path) => new TextLoader(path), // Go
'.h': (path) => new TextLoader(path), // C++ Header files
'.java': (path) => new TextLoader(path), // Java
'.js': (path) => new TextLoader(path), // JavaScript
'.less': (path) => new TextLoader(path), // Less files
'.ts': (path) => new TextLoader(path), // TypeScript
'.php': (path) => new TextLoader(path), // PHP
'.proto': (path) => new TextLoader(path), // Protocol Buffers
'.python': (path) => new TextLoader(path), // Python
'.py': (path) => new TextLoader(path), // Python
'.rst': (path) => new TextLoader(path), // reStructuredText
'.ruby': (path) => new TextLoader(path), // Ruby
'.rb': (path) => new TextLoader(path), // Ruby
'.rs': (path) => new TextLoader(path), // Rust
'.scala': (path) => new TextLoader(path), // Scala
'.sc': (path) => new TextLoader(path), // Scala
'.scss': (path) => new TextLoader(path), // Sass
'.sol': (path) => new TextLoader(path), // Solidity
'.sql': (path) => new TextLoader(path), //SQL
'.swift': (path) => new TextLoader(path), // Swift
'.markdown': (path) => new TextLoader(path), // Markdown
'.md': (path) => new TextLoader(path), // Markdown
'.tex': (path) => new TextLoader(path), // LaTeX
'.ltx': (path) => new TextLoader(path), // LaTeX
'.html': (path) => new TextLoader(path), // HTML
'.vb': (path) => new TextLoader(path), // Visual Basic
'.xml': (path) => new TextLoader(path) // XML
}) })
let docs = [] let docs = []