add return source documens to retriever tool

This commit is contained in:
Henry
2024-01-30 14:12:55 +00:00
parent 18c9c1cc51
commit 4d6881b506
6 changed files with 86 additions and 14 deletions
+24 -1
View File
@@ -1,5 +1,6 @@
import { flatten } from 'lodash'
import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
import { OutputParserException } from 'langchain/schema/output_parser'
import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks'
import { ToolInputParsingException, Tool } from '@langchain/core/tools'
@@ -7,6 +8,11 @@ import { Runnable } from 'langchain/schema/runnable'
import { BaseChain, SerializedLLMChain } from 'langchain/chains'
import { Serializable } from '@langchain/core/load/serializable'
export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
type AgentFinish = {
returnValues: Record<string, any>
log: string
}
type AgentExecutorOutput = ChainValues
interface AgentExecutorIteratorInput {
@@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
const steps: AgentStep[] = []
let iterations = 0
let sourceDocuments: Array<Document> = []
const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
const { returnValues } = finishStep
const additional = await this.agent.prepareForOutput(returnValues, steps)
if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)
if (this.returnIntermediateSteps) {
return { ...returnValues, intermediateSteps: steps, ...additional }
@@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
return { action, observation: observation ?? '' }
}
}
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
const docs = observationArray[1]
try {
const parsedDocs = JSON.parse(docs)
sourceDocuments.push(parsedDocs)
} catch (e) {
console.error('Error parsing source documents from tool')
}
}
return { action, observation: observation ?? '' }
})
)
@@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
chatId: this.chatId,
input: this.input
})
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
}
} catch (e) {
if (e instanceof ToolInputParsingException) {
if (this.handleParsingErrors === true) {