diff --git a/flytepropeller/pkg/controller/executors/failure_node_lookup.go b/flytepropeller/pkg/controller/executors/failure_node_lookup.go index 15777d5582..a517ba7ead 100644 --- a/flytepropeller/pkg/controller/executors/failure_node_lookup.go +++ b/flytepropeller/pkg/controller/executors/failure_node_lookup.go @@ -6,41 +6,38 @@ import ( ) type FailureNodeLookup struct { - NodeSpec *v1alpha1.NodeSpec - NodeStatus v1alpha1.ExecutableNodeStatus - StartNode v1alpha1.ExecutableNode - StartNodeStatus v1alpha1.ExecutableNodeStatus + NodeLookup + FailureNode v1alpha1.ExecutableNode + FailureNodeStatus v1alpha1.ExecutableNodeStatus } func (f FailureNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { if nodeID == v1alpha1.StartNodeID { - return f.StartNode, true + return f.NodeLookup.GetNode(nodeID) } - return f.NodeSpec, true + return f.FailureNode, true } func (f FailureNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { if id == v1alpha1.StartNodeID { - return f.StartNodeStatus + return f.NodeLookup.GetNodeExecutionStatus(ctx, id) } - return f.NodeStatus + return f.FailureNodeStatus } func (f FailureNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + // The upstream node of the failure node is always the start node return []v1alpha1.NodeID{v1alpha1.StartNodeID}, nil } func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { - return nil, nil + return []v1alpha1.NodeID{v1alpha1.EndNodeID}, nil } -func NewFailureNodeLookup(nodeSpec *v1alpha1.NodeSpec, startNode v1alpha1.ExecutableNode, nodeStatusGetter v1alpha1.NodeStatusGetter) NodeLookup { - startNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), v1alpha1.StartNodeID) - errNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), nodeSpec.GetID()) +func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus) NodeLookup { return FailureNodeLookup{ - NodeSpec: nodeSpec, - NodeStatus: errNodeStatus, - StartNode: startNode, - StartNodeStatus: startNodeStatus, + NodeLookup: nodeLookup, + FailureNode: failureNode, + FailureNodeStatus: failureNodeStatus, } } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go index 10e48358dd..6ea22b9637 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go @@ -143,12 +143,18 @@ func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx interfaces.No func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error - if subworkflow.GetOnFailureNode() != nil { + failureNode := subworkflow.GetOnFailureNode() + if failureNode != nil { execContext, err := s.getExecutionContextForDownstream(nCtx) if err != nil { return handler.UnknownTransition, err } - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode()) + subNodeLookup := nCtx.ContextualNodeLookup() + // TODO: NodeStatus() is deprecated, how do we get the status of the failure node? + failureNodeStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, failureNode.GetID()) + failureNodeLookup := executors.NewFailureNodeLookup(subNodeLookup, failureNode, failureNodeStatus) + + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, failureNodeLookup, failureNodeLookup, failureNode) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } diff --git a/flytepropeller/pkg/controller/workflow/executor.go b/flytepropeller/pkg/controller/workflow/executor.go index 4a48a3d7f0..2b0b850905 100644 --- a/flytepropeller/pkg/controller/workflow/executor.go +++ b/flytepropeller/pkg/controller/workflow/executor.go @@ -169,18 +169,14 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { execErr := executionErrorOrDefault(w.GetExecutionStatus().GetExecutionError(), w.GetExecutionStatus().GetMessage()) - errorNode := w.GetOnFailureNode() - logger.Infof(ctx, "Handling FailureNode [%v]", errorNode) + failureNode := w.GetOnFailureNode() + logger.Infof(ctx, "Handling FailureNode [%v]", failureNode.GetID()) execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - // TODO: GetNodeExecutionStatus doesn't work. How do we get the error node status from CRD - startNode, _ := w.GetNode(v1alpha1.StartNodeID) - failureNodeLookup := executors.NewFailureNodeLookup(errorNode.(*v1alpha1.NodeSpec), startNode, w.GetExecutionStatus()) - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, failureNodeLookup, failureNodeLookup, errorNode) - logger.Infof(ctx, "FailureNode [%v] finished with state [%v]", errorNode, state) - logger.Infof(ctx, "FailureNode [%v] finished with error [%v]", errorNode, err) + failureNodeStatus := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, failureNode.GetID()) + failureNodeLookup := executors.NewFailureNodeLookup(w, failureNode, failureNodeStatus) + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, failureNodeLookup, failureNodeLookup, failureNode) if err != nil { - logger.Infof(ctx, "test") return StatusFailureNode(execErr), err } @@ -202,8 +198,6 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl return StatusFailureNode(execErr), nil } - logger.Infof(ctx, "test2") - // If the failure node finished executing, transition to failed. return StatusFailed(execErr), nil }