Optimize getEndingNodes (#2133)

This commit is contained in:
YISH
2024-04-19 09:12:24 +08:00
committed by GitHub
parent 6bd8aaefc8
commit b7e4fc9517
3 changed files with 69 additions and 57 deletions
+48 -2
View File
@@ -41,6 +41,8 @@ import { Assistant } from '../database/entities/Assistant'
import { DataSource } from 'typeorm'
import { CachePool } from '../CachePool'
import { Variable } from '../database/entities/Variable'
import { InternalFlowiseError } from '../errors/internalFlowiseError'
import { StatusCodes } from 'http-status-codes'
const QUESTION_VAR_PREFIX = 'question'
const CHAT_HISTORY_VAR_PREFIX = 'chat_history'
@@ -224,8 +226,13 @@ export const getAllConnectedNodes = (graph: INodeDirectedGraph, startNodeId: str
* Get ending node and check if flow is valid
* @param {INodeDependencies} nodeDependencies
* @param {INodeDirectedGraph} graph
* @param {IReactFlowNode[]} allNodes
*/
export const getEndingNodes = (nodeDependencies: INodeDependencies, graph: INodeDirectedGraph) => {
export const getEndingNodes = (
nodeDependencies: INodeDependencies,
graph: INodeDirectedGraph,
allNodes: IReactFlowNode[]
): IReactFlowNode[] => {
const endingNodeIds: string[] = []
Object.keys(graph).forEach((nodeId) => {
if (Object.keys(nodeDependencies).length === 1) {
@@ -234,7 +241,46 @@ export const getEndingNodes = (nodeDependencies: INodeDependencies, graph: INode
endingNodeIds.push(nodeId)
}
})
return endingNodeIds
let endingNodes = allNodes.filter((nd) => endingNodeIds.includes(nd.id))
// If there are multiple endingnodes, the failed ones will be automatically ignored.
// And only ensure that at least one can pass the verification.
const verifiedEndingNodes: typeof endingNodes = []
let error: InternalFlowiseError | null = null
for (const endingNode of endingNodes) {
const endingNodeData = endingNode.data
if (!endingNodeData) {
error = new InternalFlowiseError(StatusCodes.INTERNAL_SERVER_ERROR, `Ending node ${endingNode.id} data not found`)
continue
}
const isEndingNode = endingNodeData?.outputs?.output === 'EndingNode'
if (!isEndingNode) {
if (
endingNodeData &&
endingNodeData.category !== 'Chains' &&
endingNodeData.category !== 'Agents' &&
endingNodeData.category !== 'Engine'
) {
error = new InternalFlowiseError(StatusCodes.INTERNAL_SERVER_ERROR, `Ending node must be either a Chain or Agent`)
continue
}
}
verifiedEndingNodes.push(endingNode)
}
if (verifiedEndingNodes.length > 0) {
return verifiedEndingNodes
}
if (endingNodes.length === 0 || error === null) {
error = new InternalFlowiseError(StatusCodes.INTERNAL_SERVER_ERROR, `Ending nodes not found`)
}
throw error
}
/**