diff --git a/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts new file mode 100644 index 00000000..54013ac5 --- /dev/null +++ b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts @@ -0,0 +1,286 @@ +import { flatten } from 'lodash' +import { BaseMessage } from '@langchain/core/messages' +import { ChainValues } from '@langchain/core/utils/types' +import { RunnableSequence } from '@langchain/core/runnables' +import { BaseChatModel } from '@langchain/core/language_models/chat_models' +import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts' +import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools' +import { getBaseClasses } from '../../../src/utils' +import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser' +import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface' +import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' +import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents' +import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation' +import { formatResponse } from '../../outputparsers/OutputParserHelpers' +import type { Document } from '@langchain/core/documents' +import { BaseRetriever } from '@langchain/core/retrievers' +import { RESPONSE_TEMPLATE } from '../../chains/ConversationalRetrievalQAChain/prompts' +import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils' + +class ConversationalRetrievalToolAgent_Agents implements INode { + label: string + name: string + author: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + sessionId?: string + badge?: string + + constructor(fields?: { sessionId?: string }) { + this.label = 'Conversational Retrieval Tool Agent' + this.name = 'conversationalRetrievalToolAgent' + this.author = 'niztal(falkor)' + this.version = 1.0 + this.type = 'AgentExecutor' + this.category = 'Agents' + this.icon = 'toolAgent.png' + this.description = `Agent that calls a vector store retrieval and uses Function Calling to pick the tools and args to call` + this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)] + this.badge = 'NEW' + this.inputs = [ + { + label: 'Tools', + name: 'tools', + type: 'Tool', + list: true + }, + { + label: 'Memory', + name: 'memory', + type: 'BaseChatMemory' + }, + { + label: 'Tool Calling Chat Model', + name: 'model', + type: 'BaseChatModel', + description: + 'Only compatible with models that are capable of function calling. ChatOpenAI, ChatMistral, ChatAnthropic, ChatVertexAI' + }, + { + label: 'System Message', + name: 'systemMessage', + type: 'string', + description: 'Taking the rephrased question, search for answer from the provided context', + warning: 'Prompt must include input variable: {context}', + rows: 4, + additionalParams: true, + optional: true, + default: RESPONSE_TEMPLATE + }, + { + label: 'Input Moderation', + description: 'Detect text that could generate harmful output and prevent it from being sent to the language model', + name: 'inputModeration', + type: 'Moderation', + optional: true, + list: true + }, + { + label: 'Max Iterations', + name: 'maxIterations', + type: 'number', + optional: true, + additionalParams: true + }, + { + label: 'Vector Store Retriever', + name: 'vectorStoreRetriever', + type: 'BaseRetriever' + } + ] + this.sessionId = fields?.sessionId + } + + async init(nodeData: INodeData, input: string, options: ICommonObject): Promise { + return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input }) + } + + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + const memory = nodeData.inputs?.memory as FlowiseMemory + const moderations = nodeData.inputs?.inputModeration as Moderation[] + + const isStreamable = options.socketIO && options.socketIOClientId + + if (moderations && moderations.length > 0) { + try { + // Use the output of the moderation chain as input for the OpenAI Function Agent + input = await checkInputs(moderations, input) + } catch (e) { + await new Promise((resolve) => setTimeout(resolve, 500)) + if (isStreamable) + streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + return formatResponse(e.message) + } + } + + const executor = await prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input }) + + const loggerHandler = new ConsoleCallbackHandler(options.logger) + const callbacks = await additionalCallbacks(nodeData, options) + + let res: ChainValues = {} + let sourceDocuments: ICommonObject[] = [] + let usedTools: IUsedTool[] = [] + + if (isStreamable) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) + if (res.sourceDocuments) { + options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) + sourceDocuments = res.sourceDocuments + } + if (res.usedTools) { + options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) + usedTools = res.usedTools + } + } else { + res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) + if (res.sourceDocuments) { + sourceDocuments = res.sourceDocuments + } + if (res.usedTools) { + usedTools = res.usedTools + } + } + + let output = res?.output as string + + // Claude 3 Opus tends to spit out .. as well, discard that in final output + const regexPattern: RegExp = /[\s\S]*?<\/thinking>/ + const matches: RegExpMatchArray | null = output.match(regexPattern) + if (matches) { + for (const match of matches) { + output = output.replace(match, '') + } + } + + await memory.addChatMessages( + [ + { + text: input, + type: 'userMessage' + }, + { + text: output, + type: 'apiMessage' + } + ], + this.sessionId + ) + + let finalRes = res?.output + + if (sourceDocuments.length || usedTools.length) { + const finalRes: ICommonObject = { text: output } + if (sourceDocuments.length) { + finalRes.sourceDocuments = flatten(sourceDocuments) + } + if (usedTools.length) { + finalRes.usedTools = usedTools + } + return finalRes + } + + return finalRes + } +} + +const formatDocs = (docs: Document[]) => { + return docs.map((doc, i) => `${doc.pageContent}`).join('\n') +} + +const prepareAgent = async ( + nodeData: INodeData, + options: ICommonObject, + flowObj: { sessionId?: string; chatId?: string; input?: string } +) => { + const model = nodeData.inputs?.model as BaseChatModel + const maxIterations = nodeData.inputs?.maxIterations as string + const memory = nodeData.inputs?.memory as FlowiseMemory + const systemMessage = nodeData.inputs?.systemMessage as string + let tools = nodeData.inputs?.tools + tools = flatten(tools) + const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history' + const inputKey = memory.inputKey ? memory.inputKey : 'input' + const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever + + const prompt = ChatPromptTemplate.fromMessages([ + ['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`], + new MessagesPlaceholder(memoryKey), + ['human', `{${inputKey}}`], + new MessagesPlaceholder('agent_scratchpad') + ]) + + if (llmSupportsVision(model)) { + const visionChatModel = model as IVisionChatModal + const messageContent = await addImagesToMessages(nodeData, options, model.multiModalOption) + + if (messageContent?.length) { + visionChatModel.setVisionModel() + + // Pop the `agent_scratchpad` MessagePlaceHolder + let messagePlaceholder = prompt.promptMessages.pop() as MessagesPlaceholder + if (prompt.promptMessages.at(-1) instanceof HumanMessagePromptTemplate) { + const lastMessage = prompt.promptMessages.pop() as HumanMessagePromptTemplate + const template = (lastMessage.prompt as PromptTemplate).template as string + const msg = HumanMessagePromptTemplate.fromTemplate([ + ...messageContent, + { + text: template + } + ]) + msg.inputVariables = lastMessage.inputVariables + prompt.promptMessages.push(msg) + } + + // Add the `agent_scratchpad` MessagePlaceHolder back + prompt.promptMessages.push(messagePlaceholder) + } else { + visionChatModel.revertToOriginalModel() + } + } + + if (model.bindTools === undefined) { + throw new Error(`This agent requires that the "bindTools()" method be implemented on the input model.`) + } + + const modelWithTools = model.bindTools(tools) + + const runnableAgent = RunnableSequence.from([ + { + [inputKey]: (i: { input: string; steps: ToolsAgentStep[] }) => i.input, + agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => formatToOpenAIToolMessages(i.steps), + [memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => { + const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[] + return messages ?? [] + }, + context: async (i: { input: string; chatHistory?: string }) => { + const relevantDocs = await vectorStoreRetriever.invoke(i.input) + const formattedDocs = formatDocs(relevantDocs) + return formattedDocs + } + }, + prompt, + modelWithTools, + new ToolCallingAgentOutputParser() + ]) + + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + tools, + sessionId: flowObj?.sessionId, + chatId: flowObj?.chatId, + input: flowObj?.input, + verbose: process.env.DEBUG === 'true' ? true : false, + maxIterations: maxIterations ? parseFloat(maxIterations) : undefined + }) + + return executor +} + +module.exports = { nodeClass: ConversationalRetrievalToolAgent_Agents } diff --git a/packages/components/nodes/agents/ConversationalRetrievalToolAgent/toolAgent.png b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/toolAgent.png new file mode 100644 index 00000000..7bf44339 Binary files /dev/null and b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/toolAgent.png differ diff --git a/packages/server/src/utils/index.ts b/packages/server/src/utils/index.ts index 2aa430f0..1fc3be60 100644 --- a/packages/server/src/utils/index.ts +++ b/packages/server/src/utils/index.ts @@ -1228,6 +1228,7 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod 'conversationalRetrievalAgent', 'openAIToolAgent', 'toolAgent', + 'conversationalRetrievalToolAgent', 'openAIToolAgentLlamaIndex' ] isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)