Skip to content

Commit

Permalink
Fixes nested pipelines with multiple artifact or parameter outputs.
Browse files: browse the repository at this point in the history
  • Loading branch information
boarder7395 authored and droctothorpe committed Sep 26, 2024
1 parent 217ff0d commit 2f0af61
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,9 +760,10 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
// Handle writing output parameters to MLMD.
outputParameters := opts.Component.GetDag().GetOutputs().GetParameters()
glog.V(4).Info("outputParameters: ", outputParameters)
for _, value := range outputParameters {
outputParameterKey := value.GetValueFromParameter().OutputParameterKey
producerSubTask := value.GetValueFromParameter().ProducerSubtask
ecfg.OutputParameters = make(map[string]*structpb.Value)
for name, value := range outputParameters {
outputParameterKey := value.GetValueFromParameter().GetOutputParameterKey()
producerSubTask := value.GetValueFromParameter().GetProducerSubtask()
glog.V(4).Info("outputParameterKey: ", outputParameterKey)
glog.V(4).Info("producerSubtask: ", producerSubTask)

Expand All @@ -773,9 +774,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E

outputParameterStruct, _ := structpb.NewValue(outputParameterMap)

ecfg.OutputParameters = map[string]*structpb.Value{
value.GetValueFromParameter().OutputParameterKey: outputParameterStruct,
}
ecfg.OutputParameters[name] = outputParameterStruct
}

// Handle writing output artifacts to MLMD.
Expand Down Expand Up @@ -1198,6 +1197,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,

// Handle artifacts.
for name, artifactSpec := range task.GetInputs().GetArtifacts() {
glog.V(4).Infof("inputs: %#v", task.GetInputs())
glog.V(4).Infof("artifacts: %#v", task.GetInputs().GetArtifacts())
artifactError := func(err error) error {
return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err)
}
Expand Down Expand Up @@ -1269,6 +1270,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error {
// The producer is the task that produces the output that we need to
// consume.
producer := tasks[taskOutput.GetProducerTask()]
outputParameterKey := taskOutput.GetOutputParameterKey()
glog.V(4).Info("producer: ", producer)
currentTask := producer
currentSubTaskMaybeDAG := true
Expand All @@ -1286,13 +1288,14 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error {
// corresponding producer sub-task, reassign currentTask,
// and iterate through this loop again.
var outputParametersMap map[string]string
b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON()
b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON()
if err != nil {
return err
}
json.Unmarshal(b, &outputParametersMap)
glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap)
subTaskName := outputParametersMap["producer_subtask"]
outputParameterKey = outputParametersMap["output_parameter_key"]
glog.V(4).Infof(
"Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.",
currentTask.TaskName(),
Expand All @@ -1302,7 +1305,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error {
// Reassign sub-task before running through the loop again.
currentTask = tasks[subTaskName]
} else {
cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()]
cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey]
// Exit the loop.
currentSubTaskMaybeDAG = false
}
Expand All @@ -1329,6 +1332,7 @@ type resolveUpstreamArtifactsConfig struct {
// straightforward, or DAGs, in which case, we need to traverse the graph until
// we arrive at a component/container (since there can be n nested DAGs).
func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error {
glog.V(4).Infof("artifactSpec: %#v", cfg.artifactSpec)
taskOutput := cfg.artifactSpec.GetTaskOutputArtifact()
if taskOutput.GetProducerTask() == "" {
return cfg.artifactError(fmt.Errorf("producer task is empty"))
Expand All @@ -1349,7 +1353,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error {
}
glog.V(4).Info("producer: ", producer)
currentTask := producer
var outputArtifactKey string
var outputArtifactKey string = taskOutput.GetOutputArtifactKey()
currentSubTaskMaybeDAG := true
// Continue looping until we reach a sub-task that is NOT a DAG.
for currentSubTaskMaybeDAG {
Expand All @@ -1365,12 +1369,17 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error {
return err
}
glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts)
artifactSelectors := outputArtifacts["Output"].GetArtifactSelectors()
// TODO: Add support for multiple output artifacts.
subTaskName := artifactSelectors[0].ProducerSubtask
outputArtifactKey = artifactSelectors[0].OutputArtifactKey
glog.V(4).Info("subTaskName: ", subTaskName)
glog.V(4).Info("outputArtifactKey: ", outputArtifactKey)
// Adding support for multiple output artifacts
var subTaskName string
value := outputArtifacts[outputArtifactKey].GetArtifactSelectors()

for _, v := range value {
glog.V(4).Infof("v: %v", v)
glog.V(4).Infof("v.ProducerSubtask: %v", v.ProducerSubtask)
glog.V(4).Infof("v.OutputArtifactKey: %v", v.OutputArtifactKey)
subTaskName = v.ProducerSubtask
outputArtifactKey = v.OutputArtifactKey
}
// If the sub-task is a DAG, reassign currentTask and run
// through the loop again.
currentTask = tasks[subTaskName]
Expand Down

0 comments on commit 2f0af61

Please sign in to comment.