From c25eaa11398a782e31ff3b5412cf9191b3a06f79 Mon Sep 17 00:00:00 2001 From: Henry Date: Mon, 28 Aug 2023 13:11:15 +0100 Subject: [PATCH 1/2] add chatOpenAI Finetuned --- .../ChatOpenAIFineTuned.ts | 149 ++++++++++++++++++ .../chatmodels/ChatOpenAIFineTuned/openai.png | Bin 0 -> 3991 bytes .../components/nodes/llms/OpenAI/OpenAI.ts | 21 ++- 3 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 packages/components/nodes/chatmodels/ChatOpenAIFineTuned/ChatOpenAIFineTuned.ts create mode 100644 packages/components/nodes/chatmodels/ChatOpenAIFineTuned/openai.png diff --git a/packages/components/nodes/chatmodels/ChatOpenAIFineTuned/ChatOpenAIFineTuned.ts b/packages/components/nodes/chatmodels/ChatOpenAIFineTuned/ChatOpenAIFineTuned.ts new file mode 100644 index 00000000..bfe3ba7a --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatOpenAIFineTuned/ChatOpenAIFineTuned.ts @@ -0,0 +1,149 @@ +import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' +import { ChatOpenAI, OpenAIChatInput } from 'langchain/chat_models/openai' + +class ChatOpenAIFineTuned_ChatModels implements INode { + label: string + name: string + version: number + type: string + icon: string + category: string + description: string + baseClasses: string[] + credential: INodeParams + inputs: INodeParams[] + + constructor() { + this.label = 'ChatOpenAI Fine-Tuned' + this.name = 'chatOpenAIFineTuned' + this.version = 1.0 + this.type = 'ChatOpenAI-FineTuned' + this.icon = 'openai.png' + this.category = 'Chat Models' + this.description = 'Wrapper around fine-tuned OpenAI LLM that use the Chat endpoint' + this.baseClasses = [this.type, ...getBaseClasses(ChatOpenAI)] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + credentialNames: ['openAIApi'] + } + this.inputs = [ + { + label: 'Model Name', + name: 'modelName', + type: 'string', + placeholder: 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' + }, + { + label: 'Temperature', + name: 'temperature', + type: 'number', + step: 0.1, + default: 0.9, + optional: true + }, + { + label: 'Max Tokens', + name: 'maxTokens', + type: 'number', + step: 1, + optional: true, + additionalParams: true + }, + { + label: 'Top Probability', + name: 'topP', + type: 'number', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'Frequency Penalty', + name: 'frequencyPenalty', + type: 'number', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'Presence Penalty', + name: 'presencePenalty', + type: 'number', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'Timeout', + name: 'timeout', + type: 'number', + step: 1, + optional: true, + additionalParams: true + }, + { + label: 'BasePath', + name: 'basepath', + type: 'string', + optional: true, + additionalParams: true + }, + { + label: 'BaseOptions', + name: 'baseOptions', + type: 'json', + optional: true, + additionalParams: true + } + ] + } + + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { + const temperature = nodeData.inputs?.temperature as string + const modelName = nodeData.inputs?.modelName as string + const maxTokens = nodeData.inputs?.maxTokens as string + const topP = nodeData.inputs?.topP as string + const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string + const presencePenalty = nodeData.inputs?.presencePenalty as string + const timeout = nodeData.inputs?.timeout as string + const streaming = nodeData.inputs?.streaming as boolean + const basePath = nodeData.inputs?.basepath as string + const baseOptions = nodeData.inputs?.baseOptions + + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const openAIApiKey = getCredentialParam('openAIApiKey', credentialData, nodeData) + + const obj: Partial & { openAIApiKey?: string } = { + temperature: parseFloat(temperature), + modelName, + openAIApiKey, + streaming: streaming ?? true + } + + if (maxTokens) obj.maxTokens = parseInt(maxTokens, 10) + if (topP) obj.topP = parseFloat(topP) + if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty) + if (presencePenalty) obj.presencePenalty = parseFloat(presencePenalty) + if (timeout) obj.timeout = parseInt(timeout, 10) + + let parsedBaseOptions: any | undefined = undefined + + if (baseOptions) { + try { + parsedBaseOptions = typeof baseOptions === 'object' ? baseOptions : JSON.parse(baseOptions) + } catch (exception) { + throw new Error("Invalid JSON in the ChatOpenAI's BaseOptions: " + exception) + } + } + const model = new ChatOpenAI(obj, { + basePath, + baseOptions: parsedBaseOptions + }) + return model + } +} + +module.exports = { nodeClass: ChatOpenAIFineTuned_ChatModels } diff --git a/packages/components/nodes/chatmodels/ChatOpenAIFineTuned/openai.png b/packages/components/nodes/chatmodels/ChatOpenAIFineTuned/openai.png new file mode 100644 index 0000000000000000000000000000000000000000..de08a05b28979826c4cc669c4899789763a938a1 GIT binary patch literal 3991 zcmV;I4`}d-P)gjhf~eq#s7M4i1UD8XqJW~vA_xftNZ1L8K*&N!(&?H%y1G)Gbam|=aK3jA=g{eX z@7=H7a^Jo8-Gcvf2q9|6MLi;kA{=m2P6cQ2)V1)TAs~(pT*(!*p&9W+1Lr8@H};dw zHug~X$0Z<~j`Sm$E?i7hfWKF8n(eG&Ik{BUEe-a=MOQM|iyKj+xXEW0-3hDfF58I& z#*>GLh)0tE?>F`_krs8`ZF?Zli!3TN1+P64R?{bBi?U;gWH|YTh4+>Hq!L-zB3MBb zV>qpA;HyoBGo%q+*J7AOBx5KtExwO}V$uTc8RtC&hEr%s{OVDVdLga_y~wvLzK??a z^sQ@gj3R+7Tg3NKu$q>k>9{@Whl|Gh`>Pxu-Wg^SlVy}NwlLW4Twghj6#mHhirCmm~&>D3b&!V;S94?d=RK$ zm*UB~=s*f7bfqyD1^9jm$Joe9zT=R}sBsiYlGd-CA*uv8` zKMGwKhufy*PekMhFLJ3|cWa(uw}INL_=N{(5K8gm_}Vt%i};Y<^0aKgz5Hn6RB@I? zbPgQ>S5aV#@Rh7(5HV7%5!}cUN=(wUoD6&QPY6>=!9usII*OSN;4%;Yvb;%^wNdhi0 zr39VgO}gQd>S)X({7RL@S+7<~8R;YeZP;h9LuD+dzpT*K<95F0oFmWPS2oesE^#A? zCxJw|a3xI*65v6kimg0EL#Z|wSMxTf9Ti?g#7&yINO})Ljp#&oy3m&9G$4{NGMHkB zJb^m0!|Z!DK@i3u7IE0@&m-x^1lDq*hLgi9zTOc~NG8|Iib*^p*`&j1 zVploKP_x`!!)y)&T%0EBCZL>e_$&3LI-|ISFDMO}@ZR#C8C!Ep(m9}7r9J{Y#1yf(U)Cd3yJ5ZTF_yw7nmxtCOht;i^v5v`AaL5#Oie195@ zL7;#|%wrb-N14WQ9^!AZwa^&i1CO7YA700^G_+BC^ETRIRx+FQQtXI;h{z7c@EM~? zHZh)}1Di+u324j&5^*AM##oJRe&RKjQw%@^y-8*nKT<^nS#0D^9yJ_OqN4`_PZ%(7 z>DvbLxCDR~b=T`}9)nI~P=LrGCeuOwv<(vto z1WXJ1tvx&=eGlMLUgU^I>-jvZI7)Y55(k5Rzm&U!i(j9`hQ!xPz#dHke&=<%$g{o) zkFi~sdCbj5Mi4LkE{q<$#~Iac@28X20=Q43K_>_(<#TS4BZBI4Cs~HfW2JmfisJVJ zSgo>;Es>AoD!AM53Ec<*0@DLL!A*>mpI{U{SU{n{K8T2%U^Z9CBd8hwBDGL=bYeygf8|aRgNq-z z@iu~PyunFR;){q>@;&#+9)Jk?vY2A&ZyqLT9j4=05h4Q0Si!9UqdXv*ek{|us|PB@ zd`w>=q}pN`!Vgp;lFQ{<6QH392bVqqauozr@r%MJuGW(W`Ne{h?bXV6Z z{&r{`XvjK;2-syh;ITdf_|}4}+}{(S3h(OZ=8Va1I)_r0Fqo& zy~6A2VG=$CA%_x22(SZ{tY$c)*kCd$xE@1UpcXaeBOsdslikxAoYuxb6bT4G5t$4k zoqUt^bZ0Ju11*U@0*>;tsfw#8cTjw|m{)mR+SLy-nSspXw268|+K|DZV2^7kXH9H_ z5!}VvAkmyT7B9mkkV7R|+#wsnjhMk|DoH`3##$LvhGxjWY(W~ijuEg!+STWCjh`88 zn_-1HVANRktSBO$8x6Q1TM$V;B|tIj`4$iD0_a{RSYT;+jb#{3fM~jsLOcg31j^V% z7H4UvW$E>U03;B{sz5F>fOc#)#HN3AZxqRVR?DoCO@amSm0@@uMBHpv7*WFL$s)rF zb1A9n&7~T)3l;G`H^}~_ct+F+eX);#Y5|mHuxjJE{>iXeJu)el_Y5yCB8Qo(u(-4( zU2#6JN0LVL7MI||o5g<~@ItEG>ph<@M8Z>H5;2sK0QBc#` z*M!GeOmm9_1ov0wO3!k#!JcLYbCiW~kD)o`U-FnJ2SE!YSiA?UMZkWkEu#eF(k@uD z0?7up#M+CDG7R1tT51rm&m;jw+~#c{u;L@Kiu+l}SyP=3<67o0U*YLJ{}AJ|6sv1~ z*^J^bw&KCek)}R(vWRJ1ZaLb-nM*GMiQdN(O!Y11Z3Z%#gC;xCl*jmlCoC?5PNA81 z8PwAK^GIh9nI>(90+%s`4;Uyb%;yiJs4?xsPZb*&#c;k+J3?q6g1)@P8wSndJ?MuE z5FESr744NhH~|-vhzls?Q-&~(Y|L3`8!&_qd0r9Z6b$W2XEC>!XvYY2MUADA!!ruC zu_G^W)YR7KyD##vZr4}_0?WBm#oRmzyVGg?7}~Bv~CU<`e#`@VgFXorz1$zH*YebeC)Mwbq?C{es?{ zCg5Ey9W}9r4t9*0ia47VJgG4_gG~mJb$<7|+cL3M`X#3cn5Z=YM)~>Wynfdl#jY;U znOFJE1O*|#*1Q4c9YH(zhwl4B(;OvWvDK(CB0O^{~(@5xY5g*h@jj>*H7jcq+Y%QmGGz)Z9q$hN_ zf;9@`d8F>t7%w?SfQRR_qsCV1t}a;UvWK0Fm92sTNxaN)o%MPN(7K&&hJc+~3m`P& zM?*s@aOmCZBf=`GoXjf-)XBYs3!7yk^;GpBY2>%at5u-8(9;q>MP)7PEdY~(e* z0I(?myTE>)W1+sQvtDFVV$qOkR{XuZ!vZb&D!V6f?^G5?ug(>yjyw|PvxWLQkFqzl%f#=ONpHAVts zCG(iIQtoWFNaZLow*`7JpLxvr%hcIh#i>4)=Ng|Qv#1oA` zIYpcxpKP{sKuUt;!&NNd66{UTmd{;mwXr^vh@t_FXi8HW6R&DN2xFp!kea|#?BAi# z0PI6cR@*lJJ&0tT7rq8V=xdWo?Lj1uUUe;waR{W^^eV2?48IUx#RXA}qu3G!9z=>5 za~_A`Yap6&7Dj>h>5sWE-$my`BqI$c?W!($47+fjz7GO@SZ!ictR#zG7v|irjTTIh z{Qi1h%9_Xc3vc5K1{d9!NuI9P^5&62SEtmTx*SsBbfiDYT%r16=2L9vYgUkp+o?{} z{hX@#YHpEo3i*wF>|h&volfvm_XK!x-oBju50C!=Nj{KH?md;N0000bbVXQnWMOn= zI%9HWVRU5xGB7eSEigGPF*Z~%IXW;nIx#mZFfckWFtzvO1ONa4C3HntbYx+4Wjbwd xWNBu305UK#GA%GUEipD!FgZFfI65&mD=;uRFfhcbT(|%L002ovPDHLkV1j%(J<0$8 literal 0 HcmV?d00001 diff --git a/packages/components/nodes/llms/OpenAI/OpenAI.ts b/packages/components/nodes/llms/OpenAI/OpenAI.ts index 4e35d659..951d1a70 100644 --- a/packages/components/nodes/llms/OpenAI/OpenAI.ts +++ b/packages/components/nodes/llms/OpenAI/OpenAI.ts @@ -125,6 +125,13 @@ class OpenAI_LLMs implements INode { type: 'string', optional: true, additionalParams: true + }, + { + label: 'BaseOptions', + name: 'baseOptions', + type: 'json', + optional: true, + additionalParams: true } ] } @@ -141,6 +148,7 @@ class OpenAI_LLMs implements INode { const bestOf = nodeData.inputs?.bestOf as string const streaming = nodeData.inputs?.streaming as boolean const basePath = nodeData.inputs?.basepath as string + const baseOptions = nodeData.inputs?.baseOptions const credentialData = await getCredentialData(nodeData.credential ?? '', options) const openAIApiKey = getCredentialParam('openAIApiKey', credentialData, nodeData) @@ -160,8 +168,19 @@ class OpenAI_LLMs implements INode { if (batchSize) obj.batchSize = parseInt(batchSize, 10) if (bestOf) obj.bestOf = parseInt(bestOf, 10) + let parsedBaseOptions: any | undefined = undefined + + if (baseOptions) { + try { + parsedBaseOptions = typeof baseOptions === 'object' ? baseOptions : JSON.parse(baseOptions) + } catch (exception) { + throw new Error("Invalid JSON in the OpenAI's BaseOptions: " + exception) + } + } + const model = new OpenAI(obj, { - basePath + basePath, + baseOptions: parsedBaseOptions }) return model } From 3f0157dab192cb428990b38b958e605ea0b067a7 Mon Sep 17 00:00:00 2001 From: Henry Date: Tue, 29 Aug 2023 15:07:59 +0100 Subject: [PATCH 2/2] add custom prompt to sql db chain --- .../SqlDatabaseChain/SqlDatabaseChain.ts | 53 +++++++++++++++++-- .../marketplaces/chatflows/SQL DB Chain.json | 27 +++++++--- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 2a0c71cf..04d704a5 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -1,14 +1,34 @@ import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains/sql_db' -import { getBaseClasses } from '../../../src/utils' +import { getBaseClasses, getInputVariables } from '../../../src/utils' import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' import { BaseLanguageModel } from 'langchain/base_language' +import { PromptTemplate, PromptTemplateInput } from 'langchain/prompts' import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler' import { DataSourceOptions } from 'typeorm/data-source' type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql' +const defaultPrompt = `Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. + +Never query for all the columns from a specific table, only ask for a the few relevant columns given the question. + +Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. + +Use the following format: + +Question: "Question here" +SQLQuery: "SQL Query to run" +SQLResult: "Result of the SQLQuery" +Answer: "Final answer here" + +Only use the tables listed below. + +{table_info} + +Question: {input}` + class SqlDatabaseChain_Chains implements INode { label: string name: string @@ -23,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode { constructor() { this.label = 'Sql Database Chain' this.name = 'sqlDatabaseChain' - this.version = 1.0 + this.version = 2.0 this.type = 'SqlDatabaseChain' this.icon = 'sqlchain.svg' this.category = 'Chains' @@ -64,6 +84,19 @@ class SqlDatabaseChain_Chains implements INode { name: 'url', type: 'string', placeholder: '1270.0.0.1:5432/chinook' + }, + { + label: 'Custom Prompt', + name: 'customPrompt', + type: 'string', + description: + 'You can provide custom prompt to the chain. This will override the existing default prompt used. See guide', + warning: + 'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above', + rows: 4, + placeholder: defaultPrompt, + additionalParams: true, + optional: true } ] } @@ -72,8 +105,9 @@ class SqlDatabaseChain_Chains implements INode { const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel const url = nodeData.inputs?.url + const customPrompt = nodeData.inputs?.customPrompt as string - const chain = await getSQLDBChain(databaseType, url, model) + const chain = await getSQLDBChain(databaseType, url, model, customPrompt) return chain } @@ -81,8 +115,9 @@ class SqlDatabaseChain_Chains implements INode { const databaseType = nodeData.inputs?.database as DatabaseType const model = nodeData.inputs?.model as BaseLanguageModel const url = nodeData.inputs?.url + const customPrompt = nodeData.inputs?.customPrompt as string - const chain = await getSQLDBChain(databaseType, url, model) + const chain = await getSQLDBChain(databaseType, url, model, customPrompt) const loggerHandler = new ConsoleCallbackHandler(options.logger) if (options.socketIO && options.socketIOClientId) { @@ -96,7 +131,7 @@ class SqlDatabaseChain_Chains implements INode { } } -const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel) => { +const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => { const datasource = new DataSource( databaseType === 'sqlite' ? { @@ -119,6 +154,14 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL verbose: process.env.DEBUG === 'true' ? true : false } + if (customPrompt) { + const options: PromptTemplateInput = { + template: customPrompt, + inputVariables: getInputVariables(customPrompt) + } + obj.prompt = new PromptTemplate(options) + } + const chain = new SqlDatabaseChain(obj) return chain } diff --git a/packages/server/marketplaces/chatflows/SQL DB Chain.json b/packages/server/marketplaces/chatflows/SQL DB Chain.json index e41c5f78..acec5432 100644 --- a/packages/server/marketplaces/chatflows/SQL DB Chain.json +++ b/packages/server/marketplaces/chatflows/SQL DB Chain.json @@ -157,17 +157,17 @@ }, { "width": 300, - "height": 423, + "height": 475, "id": "sqlDatabaseChain_0", "position": { - "x": 1229.0092429246013, - "y": 231.59431102290245 + "x": 1206.5244299447634, + "y": 201.04431101230608 }, "type": "customNode", "data": { "id": "sqlDatabaseChain_0", "label": "Sql Database Chain", - "version": 1, + "version": 2, "name": "sqlDatabaseChain", "type": "SqlDatabaseChain", "baseClasses": ["SqlDatabaseChain", "BaseChain", "Runnable"], @@ -205,6 +205,18 @@ "type": "string", "placeholder": "1270.0.0.1:5432/chinook", "id": "sqlDatabaseChain_0-input-url-string" + }, + { + "label": "Custom Prompt", + "name": "customPrompt", + "type": "string", + "description": "You can provide custom prompt to the chain. This will override the existing default prompt used. See guide", + "warning": "Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above", + "rows": 4, + "placeholder": "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nUse the following format:\n\nQuestion: \"Question here\"\nSQLQuery: \"SQL Query to run\"\nSQLResult: \"Result of the SQLQuery\"\nAnswer: \"Final answer here\"\n\nOnly use the tables listed below.\n\n{table_info}\n\nQuestion: {input}", + "additionalParams": true, + "optional": true, + "id": "sqlDatabaseChain_0-input-customPrompt-string" } ], "inputAnchors": [ @@ -218,7 +230,8 @@ "inputs": { "model": "{{chatOpenAI_0.data.instance}}", "database": "sqlite", - "url": "" + "url": "", + "customPrompt": "" }, "outputAnchors": [ { @@ -233,8 +246,8 @@ }, "selected": false, "positionAbsolute": { - "x": 1229.0092429246013, - "y": 231.59431102290245 + "x": 1206.5244299447634, + "y": 201.04431101230608 }, "dragging": false }