mirror of
https://github.com/farcasclaudiu/Flowise.git
synced 2026-06-28 19:00:59 +03:00
Feature/Add Neo4j GraphRag support (#3686)
* added: Neo4j database connectivity, Neo4j credentials, supports the usage of the GraphCypherQaChain node and modifies the FewShotPromptTemplate node to handle variables from the prefix field. * Merge branch 'main' of github.com:FlowiseAI/Flowise into feature/graphragsupport * revert pnpm-lock.yaml * add: neo4j package * Refactor GraphCypherQAChain: Update version to 1.0, remove memory input, and enhance prompt handling - Changed version from 2.0 to 1.0. - Removed the 'Memory' input parameter from the GraphCypherQAChain. - Made 'cypherPrompt' optional and improved error handling for prompt validation. - Updated the 'init' and 'run' methods to streamline input processing and response handling. - Enhanced streaming response logic based on the 'returnDirect' flag. * Refactor GraphCypherQAChain: Simplify imports and update init method signature - Consolidated import statements for better readability. - Removed the 'input' and 'options' parameters from the 'init' method, streamlining its signature to only accept 'nodeData'. * add output, format final response, fix optional inputs --------- Co-authored-by: Henry <hzj94@hotmail.com>
This commit is contained in:
committed by
GitHub
parent
93f3a5d98a
commit
a7c1ab881c
@@ -0,0 +1,256 @@
|
||||
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
|
||||
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
|
||||
import { getBaseClasses } from '../../../src/utils'
|
||||
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
|
||||
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
|
||||
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
|
||||
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
|
||||
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
|
||||
|
||||
class GraphCypherQA_Chain implements INode {
|
||||
label: string
|
||||
name: string
|
||||
version: number
|
||||
type: string
|
||||
icon: string
|
||||
category: string
|
||||
description: string
|
||||
baseClasses: string[]
|
||||
inputs: INodeParams[]
|
||||
sessionId?: string
|
||||
outputs: INodeOutputsValue[]
|
||||
|
||||
constructor(fields?: { sessionId?: string }) {
|
||||
this.label = 'Graph Cypher QA Chain'
|
||||
this.name = 'graphCypherQAChain'
|
||||
this.version = 1.0
|
||||
this.type = 'GraphCypherQAChain'
|
||||
this.icon = 'graphqa.svg'
|
||||
this.category = 'Chains'
|
||||
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
|
||||
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
|
||||
this.sessionId = fields?.sessionId
|
||||
this.inputs = [
|
||||
{
|
||||
label: 'Language Model',
|
||||
name: 'model',
|
||||
type: 'BaseLanguageModel',
|
||||
description: 'Model for generating Cypher queries and answers.'
|
||||
},
|
||||
{
|
||||
label: 'Neo4j Graph',
|
||||
name: 'graph',
|
||||
type: 'Neo4j'
|
||||
},
|
||||
{
|
||||
label: 'Cypher Generation Prompt',
|
||||
name: 'cypherPrompt',
|
||||
optional: true,
|
||||
type: 'BasePromptTemplate',
|
||||
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables'
|
||||
},
|
||||
{
|
||||
label: 'Cypher Generation Model',
|
||||
name: 'cypherModel',
|
||||
optional: true,
|
||||
type: 'BaseLanguageModel',
|
||||
description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
|
||||
},
|
||||
{
|
||||
label: 'QA Prompt',
|
||||
name: 'qaPrompt',
|
||||
optional: true,
|
||||
type: 'BasePromptTemplate',
|
||||
description: 'Prompt template for generating answers. Must include {context} and {question} variables'
|
||||
},
|
||||
{
|
||||
label: 'QA Model',
|
||||
name: 'qaModel',
|
||||
optional: true,
|
||||
type: 'BaseLanguageModel',
|
||||
description: 'Model for generating answers. If not provided, the main model will be used.'
|
||||
},
|
||||
{
|
||||
label: 'Input Moderation',
|
||||
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
|
||||
name: 'inputModeration',
|
||||
type: 'Moderation',
|
||||
optional: true,
|
||||
list: true
|
||||
},
|
||||
{
|
||||
label: 'Return Direct',
|
||||
name: 'returnDirect',
|
||||
type: 'boolean',
|
||||
default: false,
|
||||
optional: true,
|
||||
description: 'If true, return the raw query results instead of using the QA chain'
|
||||
}
|
||||
]
|
||||
this.outputs = [
|
||||
{
|
||||
label: 'Graph Cypher QA Chain',
|
||||
name: 'graphCypherQAChain',
|
||||
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
|
||||
},
|
||||
{
|
||||
label: 'Output Prediction',
|
||||
name: 'outputPrediction',
|
||||
baseClasses: ['string', 'json']
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
|
||||
const model = nodeData.inputs?.model
|
||||
const cypherModel = nodeData.inputs?.cypherModel
|
||||
const qaModel = nodeData.inputs?.qaModel
|
||||
const graph = nodeData.inputs?.graph
|
||||
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
|
||||
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
|
||||
const returnDirect = nodeData.inputs?.returnDirect as boolean
|
||||
const output = nodeData.outputs?.output as string
|
||||
|
||||
// Handle prompt values if they exist
|
||||
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
|
||||
let qaPromptTemplate: PromptTemplate | undefined
|
||||
|
||||
if (cypherPrompt) {
|
||||
if (cypherPrompt instanceof PromptTemplate) {
|
||||
cypherPromptTemplate = new PromptTemplate({
|
||||
template: cypherPrompt.template as string,
|
||||
inputVariables: cypherPrompt.inputVariables
|
||||
})
|
||||
if (!qaPrompt) {
|
||||
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
|
||||
}
|
||||
} else if (cypherPrompt instanceof FewShotPromptTemplate) {
|
||||
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
|
||||
cypherPromptTemplate = new FewShotPromptTemplate({
|
||||
examples: cypherPrompt.examples,
|
||||
examplePrompt: examplePrompt,
|
||||
inputVariables: cypherPrompt.inputVariables,
|
||||
prefix: cypherPrompt.prefix,
|
||||
suffix: cypherPrompt.suffix,
|
||||
exampleSeparator: cypherPrompt.exampleSeparator,
|
||||
templateFormat: cypherPrompt.templateFormat
|
||||
})
|
||||
} else {
|
||||
cypherPromptTemplate = cypherPrompt as PromptTemplate
|
||||
}
|
||||
}
|
||||
|
||||
if (qaPrompt instanceof PromptTemplate) {
|
||||
qaPromptTemplate = new PromptTemplate({
|
||||
template: qaPrompt.template as string,
|
||||
inputVariables: qaPrompt.inputVariables
|
||||
})
|
||||
}
|
||||
|
||||
if ((!cypherModel || !qaModel) && !model) {
|
||||
throw new Error('Language Model is required when Cypher Model or QA Model are not provided')
|
||||
}
|
||||
|
||||
// Validate required variables in prompts
|
||||
if (
|
||||
cypherPromptTemplate &&
|
||||
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question'))
|
||||
) {
|
||||
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
|
||||
}
|
||||
|
||||
const fromLLMInput: FromLLMInput = {
|
||||
llm: model,
|
||||
graph,
|
||||
returnDirect
|
||||
}
|
||||
|
||||
if (cypherModel && cypherPromptTemplate) {
|
||||
fromLLMInput['cypherLLM'] = cypherModel
|
||||
fromLLMInput['cypherPrompt'] = cypherPromptTemplate
|
||||
}
|
||||
|
||||
if (qaModel && qaPromptTemplate) {
|
||||
fromLLMInput['qaLLM'] = qaModel
|
||||
fromLLMInput['qaPrompt'] = qaPromptTemplate
|
||||
}
|
||||
|
||||
const chain = GraphCypherQAChain.fromLLM(fromLLMInput)
|
||||
|
||||
if (output === this.name) {
|
||||
return chain
|
||||
} else if (output === 'outputPrediction') {
|
||||
nodeData.instance = chain
|
||||
return await this.run(nodeData, input, options)
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
|
||||
const chain = nodeData.instance as GraphCypherQAChain
|
||||
const moderations = nodeData.inputs?.inputModeration as Moderation[]
|
||||
const returnDirect = nodeData.inputs?.returnDirect as boolean
|
||||
|
||||
const shouldStreamResponse = options.shouldStreamResponse
|
||||
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
|
||||
const chatId = options.chatId
|
||||
|
||||
// Handle input moderation if configured
|
||||
if (moderations && moderations.length > 0) {
|
||||
try {
|
||||
input = await checkInputs(moderations, input)
|
||||
} catch (e) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
if (shouldStreamResponse) {
|
||||
streamResponse(sseStreamer, chatId, e.message)
|
||||
}
|
||||
return formatResponse(e.message)
|
||||
}
|
||||
}
|
||||
|
||||
const obj = {
|
||||
query: input
|
||||
}
|
||||
|
||||
const loggerHandler = new ConsoleCallbackHandler(options.logger)
|
||||
const callbackHandlers = await additionalCallbacks(nodeData, options)
|
||||
let callbacks = [loggerHandler, ...callbackHandlers]
|
||||
|
||||
if (process.env.DEBUG === 'true') {
|
||||
callbacks.push(new LCConsoleCallbackHandler())
|
||||
}
|
||||
|
||||
try {
|
||||
let response
|
||||
if (shouldStreamResponse) {
|
||||
if (returnDirect) {
|
||||
response = await chain.invoke(obj, { callbacks })
|
||||
let result = response?.result
|
||||
if (typeof result === 'object') {
|
||||
result = '```json\n' + JSON.stringify(result, null, 2)
|
||||
}
|
||||
if (result && typeof result === 'string') {
|
||||
streamResponse(sseStreamer, chatId, result)
|
||||
}
|
||||
} else {
|
||||
const handler = new CustomChainHandler(sseStreamer, chatId, 2)
|
||||
callbacks.push(handler)
|
||||
response = await chain.invoke(obj, { callbacks })
|
||||
}
|
||||
} else {
|
||||
response = await chain.invoke(obj, { callbacks })
|
||||
}
|
||||
|
||||
return formatResponse(response?.result)
|
||||
} catch (error) {
|
||||
console.error('Error in GraphCypherQAChain:', error)
|
||||
if (shouldStreamResponse) {
|
||||
streamResponse(sseStreamer, chatId, error.message)
|
||||
}
|
||||
return formatResponse(`Error: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { nodeClass: GraphCypherQA_Chain }
|
||||
@@ -0,0 +1,22 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg">
|
||||
<g stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<!-- Graph Nodes and Edges -->
|
||||
<circle fill="#4CAF50" cx="6" cy="6" r="2.5"/>
|
||||
<circle fill="#4CAF50" cx="18" cy="6" r="2.5"/>
|
||||
<circle fill="#4CAF50" cx="6" cy="18" r="2.5"/>
|
||||
<circle fill="#4CAF50" cx="18" cy="18" r="2.5"/>
|
||||
|
||||
<!-- Graph Connections -->
|
||||
<line x1="6" y1="6" x2="18" y2="6" stroke="#4CAF50" stroke-width="1.5"/>
|
||||
<line x1="6" y1="6" x2="6" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
|
||||
<line x1="18" y1="6" x2="18" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
|
||||
<line x1="6" y1="18" x2="18" y2="18" stroke="#4CAF50" stroke-width="1.5"/>
|
||||
|
||||
<!-- Question Mark -->
|
||||
<path d="M12,8 C13.1045695,8 14,8.8954305 14,10 C14,10.7403567 13.5978014,11.3866184 13,11.7324555 L13,13 C13,13.5522847 12.5522847,14 12,14 C11.4477153,14 11,13.5522847 11,13 L11,11 C11,10.4477153 11.4477153,10 12,10 C12.5522847,10 13,10.4477153 13,11 C13,11.5522847 12.5522847,12 12,12"
|
||||
fill="#2196F3"
|
||||
fill-rule="nonzero"/>
|
||||
<circle fill="#2196F3" cx="12" cy="16" r="1"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
Reference in New Issue
Block a user