update query engine tool

This commit is contained in:
Henry
2024-02-14 18:40:35 +08:00
parent e2df5e9e01
commit 778e024c02
5 changed files with 581 additions and 468 deletions
@@ -33,13 +33,13 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
constructor(fields?: { sessionId?: string }) {
this.label = 'Sub Question Query Engine'
this.name = 'subQuestionQueryEngine'
this.version = 1.0
this.version = 2.0
this.type = 'SubQuestionQueryEngine'
this.icon = 'subQueryEngine.svg'
this.category = 'Engine'
this.description =
'Breaks complex query into sub questions for each relevant data source, then gather all the intermediate reponses and synthesizes a final response'
this.baseClasses = [this.type]
this.baseClasses = [this.type, 'BaseQueryEngine']
this.tags = ['LlamaIndex']
this.inputs = [
{
@@ -76,85 +76,13 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
this.sessionId = fields?.sessionId
}
async init(): Promise<any> {
return null
async init(nodeData: INodeData): Promise<any> {
return prepareEngine(nodeData)
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const embeddings = nodeData.inputs?.embeddings as BaseEmbedding
const model = nodeData.inputs?.model
const serviceContext = serviceContextFromDefaults({
llm: model,
embedModel: embeddings
})
let queryEngineTools = nodeData.inputs?.queryEngineTools as QueryEngineTool[]
queryEngineTools = flatten(queryEngineTools)
let queryEngine = SubQuestionQueryEngine.fromDefaults({
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer
if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(serviceContext),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
}
}
const queryEngine = prepareEngine(nodeData)
let text = ''
let sourceDocuments: ICommonObject[] = []
@@ -190,4 +118,82 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
}
}
const prepareEngine = (nodeData: INodeData) => {
const embeddings = nodeData.inputs?.embeddings as BaseEmbedding
const model = nodeData.inputs?.model
const serviceContext = serviceContextFromDefaults({
llm: model,
embedModel: embeddings
})
let queryEngineTools = nodeData.inputs?.queryEngineTools as QueryEngineTool[]
queryEngineTools = flatten(queryEngineTools)
let queryEngine = SubQuestionQueryEngine.fromDefaults({
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer
if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(serviceContext),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
}
}
return queryEngine
}
module.exports = { nodeClass: SubQuestionQueryEngine_LlamaIndex }