Merge pull request #1644 from FlowiseAI/feature/Retriever-Tool-Source-Documents

Feature/Return Source Documents to retriever tool
This commit is contained in:
Henry Heng
2024-01-30 16:28:32 +00:00
committed by GitHub
6 changed files with 86 additions and 14 deletions
@@ -64,7 +64,7 @@ class OpenAIFunctionAgent_Agents implements INode {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory) return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
} }
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const memory = nodeData.inputs?.memory as FlowiseMemory const memory = nodeData.inputs?.memory as FlowiseMemory
const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory) const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
@@ -72,12 +72,20 @@ class OpenAIFunctionAgent_Agents implements INode {
const callbacks = await additionalCallbacks(nodeData, options) const callbacks = await additionalCallbacks(nodeData, options)
let res: ChainValues = {} let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = []
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
if (res.sourceDocuments) {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments
}
} }
await memory.addChatMessages( await memory.addChatMessages(
@@ -94,7 +102,7 @@ class OpenAIFunctionAgent_Agents implements INode {
this.sessionId this.sessionId
) )
return res?.output return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
} }
} }
@@ -29,16 +29,17 @@ class CustomListOutputParser implements INode {
label: 'Length', label: 'Length',
name: 'length', name: 'length',
type: 'number', type: 'number',
default: 5,
step: 1, step: 1,
description: 'Number of values to return' description: 'Number of values to return',
optional: true
}, },
{ {
label: 'Separator', label: 'Separator',
name: 'separator', name: 'separator',
type: 'string', type: 'string',
description: 'Separator between values', description: 'Separator between values',
default: ',' default: ',',
optional: true
}, },
{ {
label: 'Autofix', label: 'Autofix',
@@ -54,10 +55,11 @@ class CustomListOutputParser implements INode {
const separator = nodeData.inputs?.separator as string const separator = nodeData.inputs?.separator as string
const lengthStr = nodeData.inputs?.length as string const lengthStr = nodeData.inputs?.length as string
const autoFix = nodeData.inputs?.autofixParser as boolean const autoFix = nodeData.inputs?.autofixParser as boolean
let length = 5
if (lengthStr) length = parseInt(lengthStr, 10)
const parser = new LangchainCustomListOutputParser({ length: length, separator: separator }) const parser = new LangchainCustomListOutputParser({
length: lengthStr ? parseInt(lengthStr, 10) : undefined,
separator: separator
})
Object.defineProperty(parser, 'autoFix', { Object.defineProperty(parser, 'autoFix', {
enumerable: true, enumerable: true,
configurable: true, configurable: true,
@@ -1,8 +1,11 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface' import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils' import { getBaseClasses } from '../../../src/utils'
import { DynamicTool } from 'langchain/tools' import { DynamicTool } from 'langchain/tools'
import { createRetrieverTool } from 'langchain/agents/toolkits' import { DynamicStructuredTool } from '@langchain/core/tools'
import { CallbackManagerForToolRun } from '@langchain/core/callbacks/manager'
import { BaseRetriever } from 'langchain/schema/retriever' import { BaseRetriever } from 'langchain/schema/retriever'
import { z } from 'zod'
import { SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
class Retriever_Tools implements INode { class Retriever_Tools implements INode {
label: string label: string
@@ -19,7 +22,7 @@ class Retriever_Tools implements INode {
constructor() { constructor() {
this.label = 'Retriever Tool' this.label = 'Retriever Tool'
this.name = 'retrieverTool' this.name = 'retrieverTool'
this.version = 1.0 this.version = 2.0
this.type = 'RetrieverTool' this.type = 'RetrieverTool'
this.icon = 'retrievertool.svg' this.icon = 'retrievertool.svg'
this.category = 'Tools' this.category = 'Tools'
@@ -44,6 +47,12 @@ class Retriever_Tools implements INode {
label: 'Retriever', label: 'Retriever',
name: 'retriever', name: 'retriever',
type: 'BaseRetriever' type: 'BaseRetriever'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
} }
] ]
} }
@@ -52,12 +61,25 @@ class Retriever_Tools implements INode {
const name = nodeData.inputs?.name as string const name = nodeData.inputs?.name as string
const description = nodeData.inputs?.description as string const description = nodeData.inputs?.description as string
const retriever = nodeData.inputs?.retriever as BaseRetriever const retriever = nodeData.inputs?.retriever as BaseRetriever
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const tool = createRetrieverTool(retriever, { const input = {
name, name,
description description
}
/**
 * Tool callback: looks up `input` in the retriever and returns the matching
 * documents as text.
 *
 * The page contents of all returned documents are joined with a blank line.
 * When `returnSourceDocuments` is truthy, `SOURCE_DOCUMENTS_PREFIX` followed by
 * the JSON-serialised documents is appended to that text.
 *
 * @param args - Structured tool input; `args.input` is the query string.
 * @param runManager - Optional tool-run callback manager; its `'retriever'`
 * child is passed on to the retriever call.
 * @returns The joined page contents, optionally followed by the marker and the
 * serialised documents.
 */
const func = async ({ input }: { input: string }, runManager?: CallbackManagerForToolRun) => {
    const docs = await retriever.getRelevantDocuments(input, runManager?.getChild('retriever'))
    const content = docs.map((doc) => doc.pageContent).join('\n\n')

    // Serialise only when the caller asked for source documents: this avoids
    // needless work and keeps a non-serialisable document from failing a call
    // whose JSON payload would have been discarded anyway.
    if (!returnSourceDocuments) return content

    return content + SOURCE_DOCUMENTS_PREFIX + JSON.stringify(docs)
}
const schema = z.object({
input: z.string().describe('query to look up in retriever')
}) })
const tool = new DynamicStructuredTool({ ...input, func, schema })
return tool return tool
} }
} }
+24 -1
View File
@@ -1,5 +1,6 @@
import { flatten } from 'lodash'
import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents' import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema' import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
import { OutputParserException } from 'langchain/schema/output_parser' import { OutputParserException } from 'langchain/schema/output_parser'
import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks' import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks'
import { ToolInputParsingException, Tool } from '@langchain/core/tools' import { ToolInputParsingException, Tool } from '@langchain/core/tools'
@@ -7,6 +8,11 @@ import { Runnable } from 'langchain/schema/runnable'
import { BaseChain, SerializedLLMChain } from 'langchain/chains' import { BaseChain, SerializedLLMChain } from 'langchain/chains'
import { Serializable } from '@langchain/core/load/serializable' import { Serializable } from '@langchain/core/load/serializable'
/**
 * Marker string placed between a tool's textual output and a JSON payload of
 * source documents appended to it. Code that receives such an output can split
 * on this marker to separate the text from the serialised documents.
 */
export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
/**
 * Shape of the step that ends an agent run, declared locally in this module.
 */
type AgentFinish = {
// Values produced by the agent for this final step, keyed by name.
returnValues: Record<string, any>
// NOTE(review): presumably the agent's raw log text for this step — confirm against the agent implementation.
log: string
}
type AgentExecutorOutput = ChainValues type AgentExecutorOutput = ChainValues
interface AgentExecutorIteratorInput { interface AgentExecutorIteratorInput {
@@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
const steps: AgentStep[] = [] const steps: AgentStep[] = []
let iterations = 0 let iterations = 0
let sourceDocuments: Array<Document> = []
const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => { const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
const { returnValues } = finishStep const { returnValues } = finishStep
const additional = await this.agent.prepareForOutput(returnValues, steps) const additional = await this.agent.prepareForOutput(returnValues, steps)
if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)
if (this.returnIntermediateSteps) { if (this.returnIntermediateSteps) {
return { ...returnValues, intermediateSteps: steps, ...additional } return { ...returnValues, intermediateSteps: steps, ...additional }
@@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
return { action, observation: observation ?? '' } return { action, observation: observation ?? '' }
} }
} }
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
const docs = observationArray[1]
try {
const parsedDocs = JSON.parse(docs)
sourceDocuments.push(parsedDocs)
} catch (e) {
console.error('Error parsing source documents from tool')
}
}
return { action, observation: observation ?? '' } return { action, observation: observation ?? '' }
}) })
) )
@@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
chatId: this.chatId, chatId: this.chatId,
input: this.input input: this.input
}) })
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
}
} catch (e) { } catch (e) {
if (e instanceof ToolInputParsingException) { if (e instanceof ToolInputParsingException) {
if (this.handleParsingErrors === true) { if (this.handleParsingErrors === true) {
@@ -217,6 +217,13 @@
"rows": 3, "rows": 3,
"placeholder": "Searches and returns documents regarding the state-of-the-union.", "placeholder": "Searches and returns documents regarding the state-of-the-union.",
"id": "retrieverTool_0-input-description-string" "id": "retrieverTool_0-input-description-string"
},
{
"label": "Return Source Documents",
"name": "returnSourceDocuments",
"type": "boolean",
"optional": true,
"id": "retrieverTool_0-input-returnSourceDocuments-boolean"
} }
], ],
"inputAnchors": [ "inputAnchors": [
@@ -230,7 +237,8 @@
"inputs": { "inputs": {
"name": "search_website", "name": "search_website",
"description": "Searches and return documents regarding Jane - a culinary institution that offers top quality coffee, pastries, breakfast, lunch, and a variety of baked goods. They have multiple locations, including Jane on Fillmore, Jane on Larkin, Jane the Bakery, Toy Boat By Jane, and Little Jane on Grant. They emphasize healthy eating with a focus on flavor and quality ingredients. They bake everything in-house and work with local suppliers to source ingredients directly from farmers. They also offer catering services and delivery options.", "description": "Searches and return documents regarding Jane - a culinary institution that offers top quality coffee, pastries, breakfast, lunch, and a variety of baked goods. They have multiple locations, including Jane on Fillmore, Jane on Larkin, Jane the Bakery, Toy Boat By Jane, and Little Jane on Grant. They emphasize healthy eating with a focus on flavor and quality ingredients. They bake everything in-house and work with local suppliers to source ingredients directly from farmers. They also offer catering services and delivery options.",
"retriever": "{{pinecone_0.data.instance}}" "retriever": "{{pinecone_0.data.instance}}",
"returnSourceDocuments": true
}, },
"outputAnchors": [ "outputAnchors": [
{ {
+10 -1
View File
@@ -473,6 +473,8 @@ export class App {
const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id)) const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))
let isStreaming = false let isStreaming = false
let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')
for (const endingNode of endingNodes) { for (const endingNode of endingNodes) {
const endingNodeData = endingNode.data const endingNodeData = endingNode.data
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`) if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
@@ -488,7 +490,8 @@ export class App {
isStreaming = isEndingNode ? false : isFlowValidForStream(nodes, endingNodeData) isStreaming = isEndingNode ? false : isFlowValidForStream(nodes, endingNodeData)
} }
const obj = { isStreaming } // If a custom function ending node exists, the flow can never be streamed
const obj = { isStreaming: isEndingNodeExists ? false : isStreaming }
return res.json(obj) return res.json(obj)
}) })
@@ -1677,6 +1680,9 @@ export class App {
if (!endingNodeIds.length) return res.status(500).send(`Ending nodes not found`) if (!endingNodeIds.length) return res.status(500).send(`Ending nodes not found`)
const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id)) const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))
let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')
for (const endingNode of endingNodes) { for (const endingNode of endingNodes) {
const endingNodeData = endingNode.data const endingNodeData = endingNode.data
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`) if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
@@ -1704,6 +1710,9 @@ export class App {
isStreamValid = isFlowValidForStream(nodes, endingNodeData) isStreamValid = isFlowValidForStream(nodes, endingNodeData)
} }
// If a custom function ending node exists, the flow can never be streamed
isStreamValid = isEndingNodeExists ? false : isStreamValid
let chatHistory: IMessage[] = incomingInput.history ?? [] let chatHistory: IMessage[] = incomingInput.history ?? []
// When {{chat_history}} is used in Prompt Template, fetch the chat conversations from memory node // When {{chat_history}} is used in Prompt Template, fetch the chat conversations from memory node