From 8857530f290e7b19caaae389d0f810d3a338978c Mon Sep 17 00:00:00 2001 From: Henry Date: Sun, 29 Oct 2023 10:27:04 +0000 Subject: [PATCH] return JSON output in the chat --- .../nodes/chains/LLMChain/LLMChain.ts | 29 ++++++++++--------- .../outputparsers/OutputParserHelpers.ts | 13 ++------- .../vectorstores/Milvus/Milvus_Upsert.ts | 2 +- packages/components/src/handler.ts | 5 +++- packages/server/src/index.ts | 5 +++- .../ui/src/views/chatmessage/ChatMessage.js | 6 ++++ 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 7ec777be..7d450825 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -4,7 +4,7 @@ import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { BaseOutputParser } from 'langchain/schema/output_parser' -import { injectOutputParser } from '../../outputparsers/OutputParserHelpers' +import { formatResponse, injectOutputParser } from '../../outputparsers/OutputParserHelpers' import { BaseLLMOutputParser } from 'langchain/schema/output_parser' import { OutputFixingParser } from 'langchain/output_parsers' @@ -98,7 +98,7 @@ class LLMChain_Chains implements INode { verbose: process.env.DEBUG === 'true' }) const inputVariables = chain.prompt.inputVariables as string[] // ["product"] - const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData) + const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, this.outputParser) // eslint-disable-next-line no-console console.log('\x1b[92m\x1b[1m\n*****OUTPUT PREDICTION*****\n\x1b[0m\x1b[0m') // eslint-disable-next-line no-console @@ -112,7 +112,7 @@ class LLMChain_Chains implements INode { } } - async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"] const chain = nodeData.instance as LLMChain let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject @@ -121,7 +121,7 @@ class LLMChain_Chains implements INode { this.outputParser = outputParser } promptValues = injectOutputParser(this.outputParser, chain, promptValues) - const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData) + const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, this.outputParser) // eslint-disable-next-line no-console console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m') // eslint-disable-next-line no-console @@ -136,7 +136,8 @@ const runPrediction = async ( input: string, promptValuesRaw: ICommonObject | undefined, options: ICommonObject, - nodeData: INodeData + nodeData: INodeData, + outputParser: BaseOutputParser ) => { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) @@ -166,12 +167,12 @@ const runPrediction = async ( // All inputVariables have fixed values specified const options = { ...promptValues } if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + const handler = new CustomChainHandler(socketIO, socketIOClientId, undefined, undefined, outputParser ? true : undefined) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return res?.text + return formatResponse(res?.text) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return res?.text + return formatResponse(res?.text) } } else if (seen.length === 1) { // If one inputVariable is not specify, use input (user's question) as value @@ -182,24 +183,24 @@ const runPrediction = async ( [lastValue]: input } if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + const handler = new CustomChainHandler(socketIO, socketIOClientId, undefined, undefined, outputParser ? true : undefined) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return res?.text + return formatResponse(res?.text) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return res?.text + return formatResponse(res?.text) } } else { throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) } } else { if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + const handler = new CustomChainHandler(socketIO, socketIOClientId, undefined, undefined, outputParser ? true : undefined) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) - return res + return formatResponse(res) } else { const res = await chain.run(input, [loggerHandler, ...callbacks]) - return res + return formatResponse(res) } } } diff --git a/packages/components/nodes/outputparsers/OutputParserHelpers.ts b/packages/components/nodes/outputparsers/OutputParserHelpers.ts index 87a59170..63dd4838 100644 --- a/packages/components/nodes/outputparsers/OutputParserHelpers.ts +++ b/packages/components/nodes/outputparsers/OutputParserHelpers.ts @@ -6,16 +6,9 @@ import { ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessag export const CATEGORY = 'Output Parser (Experimental)' -export const applyOutputParser = async (response: string, outputParser: BaseOutputParser | undefined): Promise => { - if (outputParser) { - const parsedResponse = await outputParser.parse(response) - // eslint-disable-next-line no-console - console.log('**** parsedResponse ****', parsedResponse) - if (typeof parsedResponse === 'object') { - return JSON.stringify(parsedResponse) - } else { - return parsedResponse as string - } +export const formatResponse = (response: string | object): string | object => { + if (typeof response === 'object') { + return { json: response } } return response } diff --git a/packages/components/nodes/vectorstores/Milvus/Milvus_Upsert.ts b/packages/components/nodes/vectorstores/Milvus/Milvus_Upsert.ts index ca69cb39..40afe9a4 100644 --- a/packages/components/nodes/vectorstores/Milvus/Milvus_Upsert.ts +++ b/packages/components/nodes/vectorstores/Milvus/Milvus_Upsert.ts @@ -252,7 +252,7 @@ class MilvusUpsert extends Milvus { collection_name: this.collectionName }) - if (descIndexResp.status.error_code === ErrorCode.INDEX_NOT_EXIST) { + if (descIndexResp.status.error_code === ErrorCode.IndexNotExist) { const resp = await this.client.createIndex({ collection_name: this.collectionName, field_name: this.vectorField, diff --git a/packages/components/src/handler.ts b/packages/components/src/handler.ts index 37075342..c0ee67e4 100644 --- a/packages/components/src/handler.ts +++ b/packages/components/src/handler.ts @@ -152,13 +152,15 @@ export class CustomChainHandler extends BaseCallbackHandler { skipK = 0 // Skip streaming for first K numbers of handleLLMStart returnSourceDocuments = false cachedResponse = true + isOutputParser = false - constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) { + constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean, isOutputParser?: boolean) { super() this.socketIO = socketIO this.socketIOClientId = socketIOClientId this.skipK = skipK ?? this.skipK this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments + this.isOutputParser = isOutputParser ?? this.isOutputParser } handleLLMStart() { @@ -171,6 +173,7 @@ export class CustomChainHandler extends BaseCallbackHandler { if (!this.isLLMStarted) { this.isLLMStarted = true this.socketIO.to(this.socketIOClientId).emit('start', token) + if (this.isOutputParser) this.socketIO.to(this.socketIOClientId).emit('token', '```json') } this.socketIO.to(this.socketIOClientId).emit('token', token) } diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 9d3f7052..1afb8396 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -980,7 +980,7 @@ export class App { if (nodeToExecuteData.instance) checkMemorySessionId(nodeToExecuteData.instance, chatId) - const result = isStreamValid + let result = isStreamValid ? await nodeInstance.run(nodeToExecuteData, incomingInput.question, { chatHistory: incomingInput.history, socketIO, @@ -998,7 +998,10 @@ export class App { analytic: chatflow.analytic }) + result = typeof result === 'string' ? { text: result } : result + logger.debug(`[server]: Finished running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`) + return res.json(result) } catch (e: any) { logger.error('[server]: Error:', e) diff --git a/packages/ui/src/views/chatmessage/ChatMessage.js b/packages/ui/src/views/chatmessage/ChatMessage.js index 3e967541..d199259b 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.js +++ b/packages/ui/src/views/chatmessage/ChatMessage.js @@ -165,6 +165,12 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { ]) } addChatMessage(data.text, 'apiMessage', data.sourceDocuments) + } else if (typeof data === 'object' && data.json) { + const text = '```json' + JSON.stringify(data.json, null, 2) + if (!isChatFlowAvailableToStream) { + setMessages((prevMessages) => [...prevMessages, { message: text, type: 'apiMessage' }]) + } + addChatMessage(text, 'apiMessage') } else { if (!isChatFlowAvailableToStream) { setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }])