diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 63994b13..0544365a 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -3,6 +3,8 @@ import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils' import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' +import { BaseOutputParser } from 'langchain/schema/output_parser' +import { ChatPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' class LLMChain_Chains implements INode { label: string @@ -19,7 +21,7 @@ class LLMChain_Chains implements INode { constructor() { this.label = 'LLM Chain' this.name = 'llmChain' - this.version = 1.0 + this.version = 2.0 this.type = 'LLMChain' this.icon = 'chain.svg' this.category = 'Chains' @@ -36,6 +38,12 @@ class LLMChain_Chains implements INode { name: 'prompt', type: 'BasePromptTemplate' }, + { + label: 'Output Parser', + name: 'outputParser', + type: 'BaseLLMOutputParser', + optional: true + }, { label: 'Chain Name', name: 'chainName', @@ -87,8 +95,35 @@ class LLMChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"] const chain = nodeData.instance as LLMChain - const promptValues = nodeData.inputs?.prompt.promptValues as ICommonObject - const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData) + let promptValues = nodeData.inputs?.prompt.promptValues as ICommonObject + const outputParser = nodeData.inputs?.outputParser as BaseOutputParser + if (outputParser && chain.prompt) { + const formatInstructions = outputParser.getFormatInstructions() + chain.prompt.inputVariables.push('format_instructions') + if (chain.prompt instanceof PromptTemplate) { + let pt = chain.prompt + pt.template = pt.template + '\n{format_instructions}' + chain.prompt.partialVariables = { format_instructions: formatInstructions } + // eslint-disable-next-line no-console + console.log('prompt :: ', chain.prompt) + } else if (chain.prompt instanceof ChatPromptTemplate) { + let pt = chain.prompt + pt.promptMessages.forEach((msg) => { + if (msg instanceof SystemMessagePromptTemplate) { + ;(msg.prompt as any).partialVariables = { format_instructions: outputParser.getFormatInstructions() } + ;(msg.prompt as any).template = ((msg.prompt as any).template + '\n{format_instructions}') as string + // eslint-disable-next-line no-console + console.log(msg) + } + }) + //pt.template = pt.template + '\n{format_instructions}' + } + + promptValues = { ...promptValues, format_instructions: outputParser.getFormatInstructions() } + // eslint-disable-next-line no-console + console.log('promptValues :: ', promptValues) + } + const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, outputParser) // eslint-disable-next-line no-console console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m') // eslint-disable-next-line no-console @@ -103,7 +138,8 @@ const runPrediction = async ( input: string, promptValuesRaw: ICommonObject, options: ICommonObject, - nodeData: INodeData + nodeData: INodeData, + outputParser: BaseOutputParser | undefined = undefined ) => { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) @@ -135,10 +171,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return res?.text + return runOutputParser(res?.text, outputParser) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return res?.text + return runOutputParser(res?.text, outputParser) } } else if (seen.length === 1) { // If one inputVariable is not specify, use input (user's question) as value @@ -151,10 +187,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return res?.text + return runOutputParser(res?.text, outputParser) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return res?.text + return runOutputParser(res?.text, outputParser) } } else { throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) @@ -163,12 +199,26 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) - return res + return runOutputParser(res, outputParser) } else { const res = await chain.run(input, [loggerHandler, ...callbacks]) - return res + return runOutputParser(res, outputParser) } } } +const runOutputParser = async (response: string, outputParser: BaseOutputParser | undefined): Promise => { + if (outputParser) { + const parsedResponse = await outputParser.parse(response) + // eslint-disable-next-line no-console + console.log('**** parsedResponse ****', parsedResponse) + if (typeof parsedResponse === 'object') { + return JSON.stringify(parsedResponse) + } else { + return parsedResponse as string + } + } + return response +} + module.exports = { nodeClass: LLMChain_Chains } diff --git a/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts b/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts new file mode 100644 index 00000000..04911fb8 --- /dev/null +++ b/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts @@ -0,0 +1,35 @@ +import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' +import { BaseOutputParser } from 'langchain/schema/output_parser' +import { CommaSeparatedListOutputParser } from 'langchain/output_parsers' + +class CSVListOutputParser implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + credential: INodeParams + + constructor() { + this.label = 'CSV Output Parser' + this.name = 'csvOutputParser' + this.version = 1.0 + this.type = 'CSVListOutputParser' + this.description = 'Parse the output of an LLM call as a comma-separated list of values' + this.icon = 'csv.png' + this.category = 'Output Parser' + this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] + this.inputs = [] + } + + // eslint-disable-next-line unused-imports/no-unused-vars + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { + return new CommaSeparatedListOutputParser() + } +} + +module.exports = { nodeClass: CSVListOutputParser } diff --git a/packages/components/nodes/outputparsers/csvlist/csv.png b/packages/components/nodes/outputparsers/csvlist/csv.png new file mode 100644 index 00000000..41b84e16 Binary files /dev/null and b/packages/components/nodes/outputparsers/csvlist/csv.png differ diff --git a/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts b/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts new file mode 100644 index 00000000..db05117e --- /dev/null +++ b/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts @@ -0,0 +1,55 @@ +import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' +import { BaseOutputParser } from 'langchain/schema/output_parser' +import { CustomListOutputParser as LangchainCustomListOutputParser } from 'langchain/output_parsers' + +class CustomListOutputParser implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + credential: INodeParams + + constructor() { + this.label = 'Custom List Output Parser' + this.name = 'customListOutputParser' + this.version = 1.0 + this.type = 'CustomListOutputParser' + this.description = 'Parse the output of an LLM call as a list of values.' + this.icon = 'list.png' + this.category = 'Output Parser' + this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] + this.inputs = [ + { + label: 'Length', + name: 'length', + type: 'number', + default: 5, + step: 1, + description: 'Number of values to return' + }, + { + label: 'Separator', + name: 'separator', + type: 'string', + description: 'Separator between values', + default: ',' + } + ] + } + + // eslint-disable-next-line unused-imports/no-unused-vars + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { + const separator = nodeData.inputs?.separator as string + const lengthStr = nodeData.inputs?.length as string + let length = 5 + if (lengthStr) length = parseInt(lengthStr, 10) + return new LangchainCustomListOutputParser({ length: length, separator: separator }) + } +} + +module.exports = { nodeClass: CustomListOutputParser } diff --git a/packages/components/nodes/outputparsers/customlist/list.png b/packages/components/nodes/outputparsers/customlist/list.png new file mode 100644 index 00000000..acb4e5d6 Binary files /dev/null and b/packages/components/nodes/outputparsers/customlist/list.png differ diff --git a/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts new file mode 100644 index 00000000..e935e5fb --- /dev/null +++ b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts @@ -0,0 +1,77 @@ +import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' +import { BaseOutputParser } from 'langchain/schema/output_parser' +import { StructuredOutputParser as LangchainStructuredOutputParser } from 'langchain/output_parsers' + +class StructuredOutputParser implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + credential: INodeParams + + constructor() { + this.label = 'Structured Output Parser' + this.name = 'structuredOutputParser' + this.version = 1.0 + this.type = 'StructuredOutputParser' + this.description = 'Parse the output of an LLM call into a given (JSON) structure.' + this.icon = 'structure.png' + this.category = 'Output Parser' + this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] + this.inputs = [ + { + label: 'Structure Type', + name: 'structureType', + type: 'options', + options: [ + { + label: 'Names And Descriptions', + name: 'fromNamesAndDescriptions' + }, + { + label: 'Zod Schema', + name: 'fromZodSchema' + } + ], + default: 'fromNamesAndDescriptions' + }, + { + label: 'Structure', + name: 'structure', + type: 'string', + rows: 4, + placeholder: + '{' + + ' answer: "answer to the question",\n' + + ' source: "source used to answer the question, should be a website.",\n' + + '}' + } + ] + } + + // eslint-disable-next-line unused-imports/no-unused-vars + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { + const structureType = nodeData.inputs?.structureType as string + const structure = nodeData.inputs?.structure as string + let parsedStructure: any | undefined = undefined + if (structure) { + try { + parsedStructure = JSON.parse(structure) + } catch (exception) { + throw new Error('Invalid JSON in StructuredOutputParser: ' + exception) + } + } + if (structureType === 'fromZodSchema') { + return LangchainStructuredOutputParser.fromZodSchema(parsedStructure) + } else { + return LangchainStructuredOutputParser.fromNamesAndDescriptions(parsedStructure) + } + } +} + +module.exports = { nodeClass: StructuredOutputParser } diff --git a/packages/components/nodes/outputparsers/structured/structure.png b/packages/components/nodes/outputparsers/structured/structure.png new file mode 100644 index 00000000..c56b2dd7 Binary files /dev/null and b/packages/components/nodes/outputparsers/structured/structure.png differ