- add refine chain type custom system message

- allow all memory type to conver_re_qa_chain
- fix startsWith undefined where override config file is not passed
- ability to generate new session id when sharing chatbot
This commit is contained in:
Henry
2023-07-24 19:11:39 +01:00
parent e237a5f18c
commit bcf3946efb
10 changed files with 373 additions and 240 deletions
@@ -1,9 +1,9 @@
import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { ConversationalRetrievalQAChain, QAChainParams } from 'langchain/chains'
import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { BaseChatMemory, BufferMemory, ChatMessageHistory, BufferMemoryInput } from 'langchain/memory'
import { PromptTemplate } from 'langchain/prompts'
import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
import {
@@ -11,7 +11,9 @@ import {
default_qa_template,
qa_template,
map_reduce_template,
CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT,
refine_question_template,
refine_template
} from './prompts'
class ConversationalRetrievalQAChain_Chains implements INode {
@@ -46,9 +48,9 @@ class ConversationalRetrievalQAChain_Chains implements INode {
{
label: 'Memory',
name: 'memory',
type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
type: 'BaseMemory',
optional: true,
description: 'If no memory connected, default BufferMemory will be used'
description: 'If left empty, a default BufferMemory will be used'
},
{
label: 'Return Source Documents',
@@ -115,28 +117,43 @@ class ConversationalRetrievalQAChain_Chains implements INode {
combinePrompt: PromptTemplate.fromTemplate(
systemMessagePrompt ? `${systemMessagePrompt}\n${map_reduce_template}` : default_map_reduce_template
)
}
} as QAChainParams
} else if (chainOption === 'refine') {
// TODO: Add custom system message
const qprompt = new PromptTemplate({
inputVariables: ['context', 'question'],
template: refine_question_template(systemMessagePrompt)
})
const rprompt = new PromptTemplate({
inputVariables: ['context', 'question', 'existing_answer'],
template: refine_template
})
obj.qaChainOptions = {
type: 'refine',
questionPrompt: qprompt,
refinePrompt: rprompt
} as QAChainParams
} else {
obj.qaChainOptions = {
type: 'stuff',
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
}
} as QAChainParams
}
if (memory) {
memory.inputKey = 'question'
memory.outputKey = 'text'
memory.memoryKey = 'chat_history'
if (chainOption === 'refine') memory.outputKey = 'output_text'
else memory.outputKey = 'text'
obj.memory = memory
} else {
obj.memory = new BufferMemory({
const fields: BufferMemoryInput = {
memoryKey: 'chat_history',
inputKey: 'question',
outputKey: 'text',
returnMessages: true
})
}
if (chainOption === 'refine') fields.outputKey = 'output_text'
else fields.outputKey = 'text'
obj.memory = new BufferMemory(fields)
}
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
@@ -147,6 +164,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const chain = nodeData.instance as ConversationalRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const memory = nodeData.inputs?.memory
const chainOption = nodeData.inputs?.chainOption as string
let model = nodeData.inputs?.model
@@ -176,8 +194,22 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const loggerHandler = new ConsoleCallbackHandler(options.logger)
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
const handler = new CustomChainHandler(
options.socketIO,
options.socketIOClientId,
chainOption === 'refine' ? 4 : undefined,
returnSourceDocuments
)
const res = await chain.call(obj, [loggerHandler, handler])
if (chainOption === 'refine') {
if (res.output_text && res.sourceDocuments) {
return {
text: res.output_text,
sourceDocuments: res.sourceDocuments
}
}
return res?.output_text
}
if (res.text && res.sourceDocuments) return res
return res?.text
} else {