mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-28 23:01:09 +03:00
change agent/chain with memory to use runnable
This commit is contained in:
+266
-114
@@ -1,20 +1,25 @@
|
||||
import { BaseLanguageModel } from 'langchain/base_language'
|
||||
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||
import { getBaseClasses, mapChatHistory } from '../../../src/utils'
|
||||
import { ConversationalRetrievalQAChain, QAChainParams } from 'langchain/chains'
|
||||
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
||||
import { BaseRetriever } from 'langchain/schema/retriever'
|
||||
import { BufferMemory, BufferMemoryInput } from 'langchain/memory'
|
||||
import { BufferMemoryInput } from 'langchain/memory'
|
||||
import { PromptTemplate } from 'langchain/prompts'
|
||||
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
|
||||
import {
|
||||
default_map_reduce_template,
|
||||
default_qa_template,
|
||||
qa_template,
|
||||
map_reduce_template,
|
||||
CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT,
|
||||
refine_question_template,
|
||||
refine_template
|
||||
} from './prompts'
|
||||
import { QA_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE } from './prompts'
|
||||
import { Runnable, RunnableSequence, RunnableMap, RunnableBranch, RunnableLambda } from 'langchain/schema/runnable'
|
||||
import { BaseMessage, HumanMessage, AIMessage } from 'langchain/schema'
|
||||
import { StringOutputParser } from 'langchain/schema/output_parser'
|
||||
import type { Document } from 'langchain/document'
|
||||
import { ChatPromptTemplate, MessagesPlaceholder } from 'langchain/prompts'
|
||||
import { applyPatch } from 'fast-json-patch'
|
||||
import { convertBaseMessagetoIMessage, getBaseClasses } from '../../../src/utils'
|
||||
import { ConsoleCallbackHandler, additionalCallbacks } from '../../../src/handler'
|
||||
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, MemoryMethods } from '../../../src/Interface'
|
||||
|
||||
/**
 * Input shape read by the retrieval runnables in this file:
 * the chat history and the user's question.
 */
type RetrievalChainInput = {
    chat_history: string
    question: string
}

/**
 * Run name assigned to the document-retrieval runnable. The same name is used
 * to select that runnable's log entry when reading streamed run logs.
 */
const sourceRunnableName = 'FindDocs'
|
||||
|
||||
class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
label: string
|
||||
@@ -26,11 +31,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
baseClasses: string[]
|
||||
description: string
|
||||
inputs: INodeParams[]
|
||||
sessionId?: string
|
||||
|
||||
constructor() {
|
||||
constructor(fields?: { sessionId?: string }) {
|
||||
this.label = 'Conversational Retrieval QA Chain'
|
||||
this.name = 'conversationalRetrievalQAChain'
|
||||
this.version = 1.0
|
||||
this.version = 2.0
|
||||
this.type = 'ConversationalRetrievalQAChain'
|
||||
this.icon = 'qa.svg'
|
||||
this.category = 'Chains'
|
||||
@@ -38,9 +44,9 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
this.baseClasses = [this.type, ...getBaseClasses(ConversationalRetrievalQAChain)]
|
||||
this.inputs = [
|
||||
{
|
||||
label: 'Language Model',
|
||||
label: 'Chat Model',
|
||||
name: 'model',
|
||||
type: 'BaseLanguageModel'
|
||||
type: 'BaseChatModel'
|
||||
},
|
||||
{
|
||||
label: 'Vector Store Retriever',
|
||||
@@ -60,6 +66,29 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
type: 'boolean',
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
label: 'Rephrase Prompt',
|
||||
name: 'rephrasePrompt',
|
||||
type: 'string',
|
||||
description: 'Using previous chat history, rephrase question into a standalone question',
|
||||
warning: 'Prompt must include input variables: {chat_history} and {question}',
|
||||
rows: 4,
|
||||
additionalParams: true,
|
||||
optional: true,
|
||||
default: REPHRASE_TEMPLATE
|
||||
},
|
||||
{
|
||||
label: 'Response Prompt',
|
||||
name: 'responsePrompt',
|
||||
type: 'string',
|
||||
description: 'Taking the rephrased question, search for answer from the provided context',
|
||||
warning: 'Prompt must include input variable: {context}',
|
||||
rows: 4,
|
||||
additionalParams: true,
|
||||
optional: true,
|
||||
default: RESPONSE_TEMPLATE
|
||||
}
|
||||
/** Deprecated
|
||||
{
|
||||
label: 'System Message',
|
||||
name: 'systemMessagePrompt',
|
||||
@@ -70,6 +99,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
placeholder:
|
||||
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
|
||||
},
|
||||
// TODO: create standalone chains for these 3 modes as they are not compatible with memory
|
||||
{
|
||||
label: 'Chain Option',
|
||||
name: 'chainOption',
|
||||
@@ -95,124 +125,246 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||
additionalParams: true,
|
||||
optional: true
|
||||
}
|
||||
*/
|
||||
]
|
||||
this.sessionId = fields?.sessionId
|
||||
}
|
||||
|
||||
async init(nodeData: INodeData): Promise<any> {
|
||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
|
||||
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
||||
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||
const chainOption = nodeData.inputs?.chainOption as string
|
||||
const externalMemory = nodeData.inputs?.memory
|
||||
const rephrasePrompt = nodeData.inputs?.rephrasePrompt as string
|
||||
const responsePrompt = nodeData.inputs?.responsePrompt as string
|
||||
|
||||
const obj: any = {
|
||||
verbose: process.env.DEBUG === 'true' ? true : false,
|
||||
questionGeneratorChainOptions: {
|
||||
template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
|
||||
}
|
||||
let customResponsePrompt = responsePrompt
|
||||
// If the deprecated systemMessagePrompt still exists
|
||||
if (systemMessagePrompt) {
|
||||
customResponsePrompt = `${systemMessagePrompt}\n${QA_TEMPLATE}`
|
||||
}
|
||||
|
||||
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
|
||||
|
||||
if (chainOption === 'map_reduce') {
|
||||
obj.qaChainOptions = {
|
||||
type: 'map_reduce',
|
||||
combinePrompt: PromptTemplate.fromTemplate(
|
||||
systemMessagePrompt ? `${systemMessagePrompt}\n${map_reduce_template}` : default_map_reduce_template
|
||||
)
|
||||
} as QAChainParams
|
||||
} else if (chainOption === 'refine') {
|
||||
const qprompt = new PromptTemplate({
|
||||
inputVariables: ['context', 'question'],
|
||||
template: refine_question_template(systemMessagePrompt)
|
||||
})
|
||||
const rprompt = new PromptTemplate({
|
||||
inputVariables: ['context', 'question', 'existing_answer'],
|
||||
template: refine_template
|
||||
})
|
||||
obj.qaChainOptions = {
|
||||
type: 'refine',
|
||||
questionPrompt: qprompt,
|
||||
refinePrompt: rprompt
|
||||
} as QAChainParams
|
||||
} else {
|
||||
obj.qaChainOptions = {
|
||||
type: 'stuff',
|
||||
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
|
||||
} as QAChainParams
|
||||
}
|
||||
|
||||
if (externalMemory) {
|
||||
externalMemory.memoryKey = 'chat_history'
|
||||
externalMemory.inputKey = 'question'
|
||||
externalMemory.outputKey = 'text'
|
||||
externalMemory.returnMessages = true
|
||||
if (chainOption === 'refine') externalMemory.outputKey = 'output_text'
|
||||
obj.memory = externalMemory
|
||||
} else {
|
||||
const fields: BufferMemoryInput = {
|
||||
memoryKey: 'chat_history',
|
||||
inputKey: 'question',
|
||||
outputKey: 'text',
|
||||
returnMessages: true
|
||||
}
|
||||
if (chainOption === 'refine') fields.outputKey = 'output_text'
|
||||
obj.memory = new BufferMemory(fields)
|
||||
}
|
||||
|
||||
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
|
||||
return chain
|
||||
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)
|
||||
return answerChain
|
||||
}
|
||||
|
||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
|
||||
const chain = nodeData.instance as ConversationalRetrievalQAChain
|
||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||
const externalMemory = nodeData.inputs?.memory
|
||||
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
|
||||
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
||||
const rephrasePrompt = nodeData.inputs?.rephrasePrompt as string
|
||||
const responsePrompt = nodeData.inputs?.responsePrompt as string
|
||||
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||
const chainOption = nodeData.inputs?.chainOption as string
|
||||
|
||||
let model = nodeData.inputs?.model
|
||||
|
||||
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
|
||||
model.streaming = false
|
||||
chain.questionGeneratorChain.llm = model
|
||||
|
||||
const obj = { question: input }
|
||||
|
||||
if (options && options.chatHistory && chain.memory) {
|
||||
const chatHistoryClassName = (chain.memory as any).chatHistory.constructor.name
|
||||
// Only replace when its In-Memory
|
||||
if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') {
|
||||
;(chain.memory as any).chatHistory = mapChatHistory(options)
|
||||
}
|
||||
let customResponsePrompt = responsePrompt
|
||||
// If the deprecated systemMessagePrompt still exists
|
||||
if (systemMessagePrompt) {
|
||||
customResponsePrompt = `${systemMessagePrompt}\n${QA_TEMPLATE}`
|
||||
}
|
||||
|
||||
let memory: FlowiseMemory | undefined = externalMemory
|
||||
if (!memory) {
|
||||
memory = new BufferMemory({
|
||||
returnMessages: true,
|
||||
memoryKey: 'chat_history',
|
||||
inputKey: 'input'
|
||||
})
|
||||
}
|
||||
|
||||
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)
|
||||
|
||||
const history = ((await memory.getChatMessages(this.sessionId, false, options.chatHistory)) as IMessage[]) ?? []
|
||||
|
||||
const loggerHandler = new ConsoleCallbackHandler(options.logger)
|
||||
const callbacks = await additionalCallbacks(nodeData, options)
|
||||
|
||||
if (options.socketIO && options.socketIOClientId) {
|
||||
const handler = new CustomChainHandler(
|
||||
options.socketIO,
|
||||
options.socketIOClientId,
|
||||
chainOption === 'refine' ? 4 : undefined,
|
||||
returnSourceDocuments
|
||||
)
|
||||
const res = await chain.call(obj, [loggerHandler, handler, ...callbacks])
|
||||
if (chainOption === 'refine') {
|
||||
if (res.output_text && res.sourceDocuments) {
|
||||
return {
|
||||
text: res.output_text,
|
||||
sourceDocuments: res.sourceDocuments
|
||||
}
|
||||
}
|
||||
return res?.output_text
|
||||
const stream = answerChain.streamLog(
|
||||
{ question: input, chat_history: history },
|
||||
{ callbacks: [loggerHandler, ...callbacks] },
|
||||
{
|
||||
includeNames: [sourceRunnableName]
|
||||
}
|
||||
)
|
||||
|
||||
let streamedResponse: Record<string, any> = {}
|
||||
let sourceDocuments: ICommonObject[] = []
|
||||
let text = ''
|
||||
let isStreamingStarted = false
|
||||
const isStreamingEnabled = options.socketIO && options.socketIOClientId
|
||||
|
||||
for await (const chunk of stream) {
|
||||
streamedResponse = applyPatch(streamedResponse, chunk.ops).newDocument
|
||||
|
||||
if (streamedResponse.final_output) {
|
||||
text = streamedResponse.final_output?.output
|
||||
if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('end')
|
||||
if (Array.isArray(streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output)) {
|
||||
sourceDocuments = streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output
|
||||
if (isStreamingEnabled && returnSourceDocuments)
|
||||
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', sourceDocuments)
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
Array.isArray(streamedResponse?.streamed_output) &&
|
||||
streamedResponse?.streamed_output.length &&
|
||||
!streamedResponse.final_output
|
||||
) {
|
||||
const token = streamedResponse.streamed_output[streamedResponse.streamed_output.length - 1]
|
||||
|
||||
if (!isStreamingStarted) {
|
||||
isStreamingStarted = true
|
||||
if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('start', token)
|
||||
}
|
||||
if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('token', token)
|
||||
}
|
||||
if (res.text && res.sourceDocuments) return res
|
||||
return res?.text
|
||||
} else {
|
||||
const res = await chain.call(obj, [loggerHandler, ...callbacks])
|
||||
if (res.text && res.sourceDocuments) return res
|
||||
return res?.text
|
||||
}
|
||||
|
||||
await memory.addChatMessages(
|
||||
[
|
||||
{
|
||||
text: input,
|
||||
type: 'userMessage'
|
||||
},
|
||||
{
|
||||
text: text,
|
||||
type: 'apiMessage'
|
||||
}
|
||||
],
|
||||
this.sessionId
|
||||
)
|
||||
|
||||
if (returnSourceDocuments) return { text, sourceDocuments }
|
||||
else return { text }
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the document-retrieval step of the chain.
 *
 * When `chat_history` is non-empty, the LLM first rewrites the question using
 * `rephrasePrompt` and the rewritten text is piped into the retriever.
 * Otherwise the raw `question` is piped into the retriever directly.
 *
 * @param llm - model used to condense the question
 * @param retriever - runnable that receives the (possibly rephrased) question string
 * @param rephrasePrompt - template text for the condense-question prompt
 * @returns a branch runnable whose run name is `sourceRunnableName`
 */
const createRetrieverChain = (llm: BaseLanguageModel, retriever: Runnable, rephrasePrompt: string) => {
    const CONDENSE_QUESTION_PROMPT = PromptTemplate.fromTemplate(rephrasePrompt)
    const condenseQuestionChain = RunnableSequence.from([CONDENSE_QUESTION_PROMPT, llm, new StringOutputParser()]).withConfig({
        runName: 'CondenseQuestion'
    })

    // Branch predicate: true only when there is prior conversation to take into account
    const hasHistoryCheckFn = RunnableLambda.from((input: RetrievalChainInput) => input.chat_history.length > 0).withConfig({
        runName: 'HasChatHistoryCheck'
    })

    // With history: rephrase into a standalone question, then retrieve
    const conversationChain = condenseQuestionChain.pipe(retriever).withConfig({
        runName: 'RetrievalChainWithHistory'
    })

    // Small speed/accuracy optimization: no need to rephrase the first question
    // since there shouldn't be any meta-references to prior chat history
    const basicRetrievalChain = RunnableLambda.from((input: RetrievalChainInput) => input.question)
        .withConfig({
            runName: 'Itemgetter:question'
        })
        .pipe(retriever)
        .withConfig({ runName: 'RetrievalChainWithNoHistory' })

    return RunnableBranch.from([[hasHistoryCheckFn, conversationChain], basicRetrievalChain]).withConfig({ runName: sourceRunnableName })
}
|
||||
|
||||
/**
 * Renders retrieved documents as a newline-separated list of
 * `<doc id='N'>…</doc>` elements, where N is the document's index in `docs`.
 *
 * @param docs - documents whose `pageContent` is embedded in the output
 * @returns the concatenated string; empty string when `docs` is empty
 */
const formatDocs = (docs: Document[]) => {
    return docs.map((doc, i) => `<doc id='${i}'>${doc.pageContent}</doc>`).join('\n')
}
|
||||
|
||||
/**
 * Flattens a list of messages into plain text, one `type: content` line per message.
 *
 * @param history - messages to render; each line uses `_getType()` as the prefix
 * @returns the newline-joined transcript; empty string when `history` is empty
 */
const formatChatHistoryAsString = (history: BaseMessage[]) => {
    return history.map((message) => `${message._getType()}: ${message.content}`).join('\n')
}
|
||||
|
||||
/**
 * Converts the `chat_history` entries of the chain input into message objects.
 *
 * Entries of type `userMessage` become `HumanMessage`, entries of type
 * `apiMessage` become `AIMessage`; entries of any other type are dropped.
 *
 * @param input - chain input; `input.chat_history` is read and treated as empty when falsy
 * @returns the converted messages in their original order
 */
const serializeHistory = (input: any) => {
    const chatHistory: IMessage[] = input.chat_history || []
    const convertedChatHistory: BaseMessage[] = []
    for (const message of chatHistory) {
        if (message.type === 'userMessage') {
            convertedChatHistory.push(new HumanMessage({ content: message.message }))
        } else if (message.type === 'apiMessage') {
            convertedChatHistory.push(new AIMessage({ content: message.message }))
        }
    }
    return convertedChatHistory
}
|
||||
|
||||
/**
 * Assembles the full conversational retrieval QA pipeline as a runnable sequence.
 *
 * Stages, in order:
 *  1. Pass `question` through and convert `chat_history` with `serializeHistory`.
 *  2. Build `context` by formatting the history as a string, running the retriever
 *     chain, and formatting the returned documents; `question` and `chat_history`
 *     are passed through alongside it.
 *  3. Fill the chat prompt (system `responsePrompt`, the `chat_history` messages,
 *     then the human `{question}`), call the model and parse the output to a string.
 *
 * @param llm - model used for both question condensing and answer generation
 * @param retriever - runnable that returns documents for a question string
 * @param rephrasePrompt - condense-question template; defaults to `REPHRASE_TEMPLATE`
 * @param responsePrompt - system prompt for the answer step; defaults to `RESPONSE_TEMPLATE`
 * @returns the composed runnable sequence
 */
const createChain = (
    llm: BaseLanguageModel,
    retriever: Runnable,
    rephrasePrompt = REPHRASE_TEMPLATE,
    responsePrompt = RESPONSE_TEMPLATE
) => {
    const retrieverChain = createRetrieverChain(llm, retriever, rephrasePrompt)

    const context = RunnableMap.from({
        context: RunnableSequence.from([
            // The retriever chain works on a string history, so flatten the messages first
            ({ question, chat_history }) => ({
                question,
                chat_history: formatChatHistoryAsString(chat_history)
            }),
            retrieverChain,
            RunnableLambda.from(formatDocs).withConfig({
                runName: 'FormatDocumentChunks'
            })
        ]),
        question: RunnableLambda.from((input: RetrievalChainInput) => input.question).withConfig({
            runName: 'Itemgetter:question'
        }),
        // NOTE(review): annotated as RetrievalChainInput (string history), but at this stage the
        // value is the message array produced by the SerializeHistory step below
        chat_history: RunnableLambda.from((input: RetrievalChainInput) => input.chat_history).withConfig({
            runName: 'Itemgetter:chat_history'
        })
    }).withConfig({ tags: ['RetrieveDocs'] })

    const prompt = ChatPromptTemplate.fromMessages([
        ['system', responsePrompt],
        new MessagesPlaceholder('chat_history'),
        ['human', `{question}`]
    ])

    const responseSynthesizerChain = RunnableSequence.from([prompt, llm, new StringOutputParser()]).withConfig({
        tags: ['GenerateResponse']
    })

    const conversationalQAChain = RunnableSequence.from([
        {
            question: RunnableLambda.from((input: RetrievalChainInput) => input.question).withConfig({
                runName: 'Itemgetter:question'
            }),
            chat_history: RunnableLambda.from(serializeHistory).withConfig({
                runName: 'SerializeHistory'
            })
        },
        context,
        responseSynthesizerChain
    ])

    return conversationalQAChain
}
|
||||
|
||||
/**
 * Memory implementation whose contents are rebuilt from the caller-supplied
 * history on every read instead of being accumulated through writes.
 */
class BufferMemory extends FlowiseMemory implements MemoryMethods {
    constructor(fields: BufferMemoryInput) {
        super(fields)
    }

    /**
     * Replaces the stored chat history with `prevHistory` and returns it.
     *
     * @param _ - unused
     * @param returnBaseMessages - when true, return the loaded messages as-is;
     *   otherwise convert them with `convertBaseMessagetoIMessage`
     * @param prevHistory - messages to load; only `userMessage` and `apiMessage` entries are added
     * @returns the messages read back from memory under `memoryKey` (or `chat_history`)
     */
    async getChatMessages(_?: string, returnBaseMessages = false, prevHistory: IMessage[] = []): Promise<IMessage[] | BaseMessage[]> {
        // Start from a clean slate so repeated calls do not duplicate messages
        await this.chatHistory.clear()

        // The default parameter only covers `undefined`; also tolerate an explicit null
        for (const msg of prevHistory ?? []) {
            if (msg.type === 'userMessage') await this.chatHistory.addUserMessage(msg.message)
            else if (msg.type === 'apiMessage') await this.chatHistory.addAIChatMessage(msg.message)
        }

        const memoryResult = await this.loadMemoryVariables({})
        const baseMessages = memoryResult[this.memoryKey ?? 'chat_history']
        return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
    }

    /** Intentionally a no-op: messages are loaded in `getChatMessages()` instead. */
    async addChatMessages(): Promise<void> {
        // adding chat messages will be done on the fly in getChatMessages()
        return
    }

    /** Clears this memory by delegating to `clear()`. */
    async clearChatMessages(): Promise<void> {
        await this.clear()
    }
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user