diff --git a/flytepropeller/pkg/controller/executors/failure_node_lookup.go b/flytepropeller/pkg/controller/executors/failure_node_lookup.go index e169f6fc27..15777d5582 100644 --- a/flytepropeller/pkg/controller/executors/failure_node_lookup.go +++ b/flytepropeller/pkg/controller/executors/failure_node_lookup.go @@ -6,15 +6,23 @@ import ( ) type FailureNodeLookup struct { - NodeSpec *v1alpha1.NodeSpec - NodeStatus v1alpha1.ExecutableNodeStatus + NodeSpec *v1alpha1.NodeSpec + NodeStatus v1alpha1.ExecutableNodeStatus + StartNode v1alpha1.ExecutableNode + StartNodeStatus v1alpha1.ExecutableNodeStatus } func (f FailureNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + if nodeID == v1alpha1.StartNodeID { + return f.StartNode, true + } return f.NodeSpec, true } func (f FailureNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { + if id == v1alpha1.StartNodeID { + return f.StartNodeStatus + } return f.NodeStatus } @@ -26,9 +34,13 @@ func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, erro return nil, nil } -func NewFailureNodeLookup(nodeSpec *v1alpha1.NodeSpec, nodeStatus v1alpha1.ExecutableNodeStatus) NodeLookup { +func NewFailureNodeLookup(nodeSpec *v1alpha1.NodeSpec, startNode v1alpha1.ExecutableNode, nodeStatusGetter v1alpha1.NodeStatusGetter) NodeLookup { + startNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), v1alpha1.StartNodeID) + errNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), nodeSpec.GetID()) return FailureNodeLookup{ - NodeSpec: nodeSpec, - NodeStatus: nodeStatus, + NodeSpec: nodeSpec, + NodeStatus: errNodeStatus, + StartNode: startNode, + StartNodeStatus: startNodeStatus, } } diff --git a/flytepropeller/pkg/controller/workflow/executor.go b/flytepropeller/pkg/controller/workflow/executor.go index d92c0cd939..f782c30d0a 100644 --- a/flytepropeller/pkg/controller/workflow/executor.go +++ b/flytepropeller/pkg/controller/workflow/executor.go @@ -174,8 +174,7 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) // TODO: GetNodeExecutionStatus doesn't work. How do we get the error node status from CRD - status := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, errorNode.GetID()) - failureNodeLookup := executors.NewFailureNodeLookup(errorNode.(*v1alpha1.NodeSpec), status) + failureNodeLookup := executors.NewFailureNodeLookup(errorNode.(*v1alpha1.NodeSpec), w.GetNode(v1alpha1.StartNodeID), w.GetExecutionStatus()) state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, failureNodeLookup, errorNode) logger.Infof(ctx, "FailureNode [%v] finished with state [%v]", errorNode, state) logger.Infof(ctx, "FailureNode [%v] finished with error [%v]", errorNode, err)