update mrkl agents

This commit is contained in:
Henry
2024-02-20 18:23:39 +08:00
parent d1fdd8b3bd
commit 15afb8a2dd
4 changed files with 345 additions and 105 deletions
+121 -3
View File
@@ -3,12 +3,23 @@ import { ChainValues } from '@langchain/core/utils/types'
import { AgentStep, AgentAction } from '@langchain/core/agents'
import { BaseMessage, FunctionMessage, AIMessage } from '@langchain/core/messages'
import { OutputParserException } from '@langchain/core/output_parsers'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { CallbackManager, CallbackManagerForChainRun, Callbacks } from '@langchain/core/callbacks/manager'
import { ToolInputParsingException, Tool } from '@langchain/core/tools'
import { Runnable } from '@langchain/core/runnables'
import { ToolInputParsingException, Tool, StructuredToolInterface } from '@langchain/core/tools'
import { Runnable, RunnableSequence, RunnablePassthrough } from '@langchain/core/runnables'
import { Serializable } from '@langchain/core/load/serializable'
import { renderTemplate } from '@langchain/core/prompts'
import { BaseChain, SerializedLLMChain } from 'langchain/chains'
import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
import {
CreateReactAgentParams,
AgentExecutorInput,
AgentActionOutputParser,
BaseSingleActionAgent,
BaseMultiActionAgent,
RunnableAgent,
StoppingMethod
} from 'langchain/agents'
import { formatLogToString } from 'langchain/agents/format_scratchpad/log'
export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
type AgentFinish = {
@@ -647,3 +658,110 @@ export const formatAgentSteps = (steps: AgentStep[]): BaseMessage[] =>
return [new AIMessage(action.log)]
}
})
const renderTextDescription = (tools: StructuredToolInterface[]): string => {
return tools.map((tool) => `${tool.name}: ${tool.description}`).join('\n')
}
export const createReactAgent = async ({ llm, tools, prompt }: CreateReactAgentParams) => {
const missingVariables = ['tools', 'tool_names', 'agent_scratchpad'].filter((v) => !prompt.inputVariables.includes(v))
if (missingVariables.length > 0) {
throw new Error(`Provided prompt is missing required input variables: ${JSON.stringify(missingVariables)}`)
}
const toolNames = tools.map((tool) => tool.name)
const partialedPrompt = await prompt.partial({
tools: renderTextDescription(tools),
tool_names: toolNames.join(', ')
})
// TODO: Add .bind to core runnable interface.
const llmWithStop = (llm as BaseLanguageModel).bind({
stop: ['\nObservation:']
})
const agent = RunnableSequence.from([
RunnablePassthrough.assign({
//@ts-ignore
agent_scratchpad: (input: { steps: AgentStep[] }) => formatLogToString(input.steps)
}),
partialedPrompt,
llmWithStop,
new ReActSingleInputOutputParser({
toolNames
})
])
return agent
}
class ReActSingleInputOutputParser extends AgentActionOutputParser {
lc_namespace = ['langchain', 'agents', 'react']
private toolNames: string[]
private FINAL_ANSWER_ACTION = 'Final Answer:'
private FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = 'Parsing LLM output produced both a final answer and a parse-able action:'
private FORMAT_INSTRUCTIONS = `Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question`
constructor(fields: { toolNames: string[] }) {
super(...arguments)
this.toolNames = fields.toolNames
}
/**
* Parses the given text into an AgentAction or AgentFinish object. If an
* output fixing parser is defined, uses it to parse the text.
* @param text Text to parse.
* @returns Promise that resolves to an AgentAction or AgentFinish object.
*/
async parse(text: string): Promise<AgentAction | AgentFinish> {
const includesAnswer = text.includes(this.FINAL_ANSWER_ACTION)
const regex = /Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)/
const actionMatch = text.match(regex)
if (actionMatch) {
if (includesAnswer) {
throw new Error(`${this.FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: ${text}`)
}
const action = actionMatch[1]
const actionInput = actionMatch[2]
const toolInput = actionInput.trim().replace(/"/g, '')
return {
tool: action,
toolInput,
log: text
}
}
if (includesAnswer) {
const finalAnswerText = text.split(this.FINAL_ANSWER_ACTION)[1].trim()
return {
returnValues: {
output: finalAnswerText
},
log: text
}
}
// Instead of throwing Error, we return a AgentFinish object
return { returnValues: { output: text }, log: text }
}
/**
* Returns the format instructions as a string. If the 'raw' option is
* true, returns the raw FORMAT_INSTRUCTIONS.
* @param options Options for getting the format instructions.
* @returns Format instructions as a string.
*/
getFormatInstructions(): string {
return renderTemplate(this.FORMAT_INSTRUCTIONS, 'f-string', {
tool_names: this.toolNames.join(', ')
})
}
}