Merge pull request #274 from FlowiseAI/feature/ConversationalRetrievalQAChain_Prompt

Feature/Return SourceDocuments
This commit is contained in:
Henry Heng
2023-06-09 09:59:24 +01:00
committed by GitHub
11 changed files with 371 additions and 91 deletions
@@ -2,7 +2,9 @@ import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { BaseRetriever } from 'langchain/schema'
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { PromptTemplate } from 'langchain/prompts'
const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -47,6 +49,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
},
{
label: 'System Message',
name: 'systemMessagePrompt',
@@ -56,6 +64,31 @@ class ConversationalRetrievalQAChain_Chains implements INode {
optional: true,
placeholder:
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
},
{
label: 'Chain Option',
name: 'chainOption',
type: 'options',
options: [
{
label: 'MapReduceDocumentsChain',
name: 'map_reduce',
description:
'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
},
{
label: 'RefineDocumentsChain',
name: 'refine',
description: 'Suitable for QA tasks over a large number of documents.'
},
{
label: 'StuffDocumentsChain',
name: 'stuff',
description: 'Suitable for QA tasks over a small number of documents.'
}
],
additionalParams: true,
optional: true
}
]
}
@@ -64,44 +97,64 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const model = nodeData.inputs?.model as BaseLanguageModel
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const chainOption = nodeData.inputs?.chainOption as string
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, {
const obj: any = {
verbose: process.env.DEBUG === 'true' ? true : false,
qaTemplate: systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template
})
qaChainOptions: {
type: 'stuff',
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
},
memory: new BufferMemory({
memoryKey: 'chat_history',
inputKey: 'question',
outputKey: 'text',
returnMessages: true
})
}
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
return chain
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const chain = nodeData.instance as ConversationalRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
let model = nodeData.inputs?.model
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
model.streaming = false
chain.questionGeneratorChain.llm = model
let chatHistory = ''
const obj = { question: input }
if (options && options.chatHistory) {
if (chain.memory && options && options.chatHistory) {
const chatHistory = []
const histories: IMessage[] = options.chatHistory
chatHistory = histories
.map((item) => {
return item.message
})
.join('')
}
const memory = chain.memory as BaseChatMemory
const obj = {
question: input,
chat_history: chatHistory ? chatHistory : []
for (const message of histories) {
if (message.type === 'apiMessage') {
chatHistory.push(new AIChatMessage(message.message))
} else if (message.type === 'userMessage') {
chatHistory.push(new HumanChatMessage(message.message))
}
}
memory.chatHistory = new ChatMessageHistory(chatHistory)
chain.memory = memory
}
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
const res = await chain.call(obj, [handler])
if (res.text && res.sourceDocuments) return res
return res?.text
} else {
const res = await chain.call(obj)
if (res.text && res.sourceDocuments) return res
return res?.text
}
}
+1 -1
View File
@@ -31,7 +31,7 @@
"faiss-node": "^0.2.1",
"form-data": "^4.0.0",
"graphql": "^16.6.0",
"langchain": "^0.0.84",
"langchain": "^0.0.91",
"linkifyjs": "^4.1.1",
"mammoth": "^1.5.1",
"moment": "^2.29.3",
+1 -1
View File
@@ -75,7 +75,7 @@ export interface INode extends INodeProperties {
inputs?: INodeParams[]
output?: INodeOutputsValue[]
init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string | ICommonObject>
}
export interface INodeData extends INodeProperties {
+10 -1
View File
@@ -4,6 +4,7 @@ import * as fs from 'fs'
import * as path from 'path'
import { BaseCallbackHandler } from 'langchain/callbacks'
import { Server } from 'socket.io'
import { ChainValues } from 'langchain/dist/schema'
export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
@@ -208,12 +209,14 @@ export class CustomChainHandler extends BaseCallbackHandler {
socketIO: Server
socketIOClientId = ''
skipK = 0 // Skip streaming for first K numbers of handleLLMStart
returnSourceDocuments = false
constructor(socketIO: Server, socketIOClientId: string, skipK?: number) {
constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) {
super()
this.socketIO = socketIO
this.socketIOClientId = socketIOClientId
this.skipK = skipK ?? this.skipK
this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments
}
handleLLMStart() {
@@ -233,6 +236,12 @@ export class CustomChainHandler extends BaseCallbackHandler {
handleLLMEnd() {
this.socketIO.to(this.socketIOClientId).emit('end')
}
handleChainEnd(outputs: ChainValues): void | Promise<void> {
if (this.returnSourceDocuments) {
this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments)
}
}
}
export const returnJSONStr = (jsonStr: string): string => {