Skip to content

Commit

Permalink
Add support for mocking sub-workflows in tests using SdkTypes (#123)
Browse files Browse the repository at this point in the history
* Add support for mocking sub-workflows

Signed-off-by: Michel Davit <[email protected]>

* Use SdkType to mock subworkflows in tests

I would like to hide the literals from the user interface

Signed-off-by: Nelson Arapé <[email protected]>

* Hide Literals from users when mocking subworkflows

Signed-off-by: Nelson Arapé <[email protected]>

* Add unit tests for mocking subworkflows

Signed-off-by: Nelson Arapé <[email protected]>

* Minor formatting

Signed-off-by: Nelson Arapé <[email protected]>

* Refactoring idlTemplates generation/verification

Signed-off-by: Nelson Arapé <[email protected]>

* Remove debug statements

Signed-off-by: Nelson Arapé <[email protected]>

* Remove unused import

Signed-off-by: Nelson Arapé <[email protected]>

* Propagates variable descriptions

Signed-off-by: Nelson Arapé <[email protected]>

* Replace iteration with stream().collect()

Signed-off-by: Nelson Arapé <[email protected]>

* Refactor TestingWorkflow

Signed-off-by: Nelson Arapé <[email protected]>

Co-authored-by: Michel Davit <[email protected]>
Co-authored-by: Nelson Arapé <[email protected]>
  • Loading branch information
3 people authored Aug 1, 2022
1 parent 37a9acf commit a5a7952
Show file tree
Hide file tree
Showing 8 changed files with 314 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.flyte.examples;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
Expand All @@ -31,4 +32,24 @@ public void expand(SdkWorkflowBuilder builder) {
SdkBindingData result = builder.apply("sum", SumTask.of(left, right)).getOutput("c");
builder.output("result", result);
}

@AutoValue
public abstract static class Input {
abstract long left();

abstract long right();

public static Input create(long left, long right) {
return new AutoValue_SubWorkflow_Input(left, right);
}
}

@AutoValue
public abstract static class Output {
abstract long result();

public static Output create(long result) {
return new AutoValue_SubWorkflow_Output(result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.flyte.examples.SumTask.SumInput;
import org.flyte.examples.SumTask.SumOutput;
import org.flyte.flytekit.jackson.JacksonSdkType;
import org.flyte.flytekit.testing.SdkTestingExecutor;
import org.junit.jupiter.api.Test;

Expand All @@ -37,7 +40,7 @@ public void testSubWorkflow() {
}

@Test
public void testMockSubWorkflow() {
public void testMockTasks() {
SdkTestingExecutor.Result result =
SdkTestingExecutor.of(new UberWorkflow())
.withFixedInput("a", 1)
Expand All @@ -54,4 +57,32 @@ public void testMockSubWorkflow() {

assertEquals(42L, result.getIntegerOutput("total"));
}

@Test
public void testMockSubWorkflow() {
SdkTestingExecutor.Result result =
SdkTestingExecutor.of(new UberWorkflow())
.withFixedInput("a", 1)
.withFixedInput("b", 2)
.withFixedInput("c", 3)
.withFixedInput("d", 4)
// Deliberately mock with absurd values to make sure that we are not picking the
// SumTask implementation
.withWorkflowOutput(
new SubWorkflow(),
JacksonSdkType.of(SubWorkflow.Input.class),
SubWorkflow.Input.create(1L, 2L),
JacksonSdkType.of(SubWorkflow.Output.class),
SubWorkflow.Output.create(5L))
.withWorkflowOutput(
new SubWorkflow(),
JacksonSdkType.of(SubWorkflow.Input.class),
SubWorkflow.Input.create(5L, 3L),
JacksonSdkType.of(SubWorkflow.Output.class),
SubWorkflow.Output.create(10L))
.withTaskOutput(new SumTask(), SumInput.create(10L, 4L), SumOutput.create(15L))
.execute();

assertEquals(15L, result.getIntegerOutput("total"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.flyte.api.v1.PartialWorkflowIdentifier;
import org.flyte.api.v1.Variable;
import org.flyte.api.v1.WorkflowNode;
import org.flyte.api.v1.WorkflowTemplate;

public abstract class SdkWorkflow extends SdkTransform {

Expand Down Expand Up @@ -69,4 +70,11 @@ public SdkNode apply(
return new SdkWorkflowNode(
builder, nodeId, upstreamNodeIds, metadata, workflowNode, inputs, outputs);
}

public WorkflowTemplate toIdlTemplate() {
SdkWorkflowBuilder builder = new SdkWorkflowBuilder();
this.expand(builder);

return builder.toIdlTemplate();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
*/
package org.flyte.flytekit.testing;

import static java.util.stream.Collectors.toMap;
import static org.flyte.api.v1.LiteralType.ofSimpleType;

import java.util.Map;
import java.util.stream.Collectors;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.SimpleType;
import org.flyte.api.v1.Variable;

class LiteralTypes {
static final LiteralType INTEGER = ofSimpleType(SimpleType.INTEGER);
Expand All @@ -29,6 +33,20 @@ class LiteralTypes {
static final LiteralType DATETIME = ofSimpleType(SimpleType.DATETIME);
static final LiteralType DURATION = ofSimpleType(SimpleType.DURATION);

static LiteralType from(Variable var) {
return var.literalType();
}

static Map<String, LiteralType> from(Map<String, Variable> vars) {
return vars.entrySet().stream().collect(toMap(Map.Entry::getKey, e -> from(e.getValue())));
}

static String toPrettyString(Map<String, LiteralType> literalTypes) {
return literalTypes.entrySet().stream()
.map(e -> String.format("%s=%s", e.getKey(), toPrettyString(e.getValue())))
.collect(Collectors.joining(", ", "{ ", " }"));
}

static String toPrettyString(LiteralType literalType) {
switch (literalType.getKind()) {
case SIMPLE_TYPE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Node;
import org.flyte.api.v1.TaskNode;
import org.flyte.api.v1.TypedInterface;
import org.flyte.api.v1.Variable;
import org.flyte.api.v1.WorkflowNode;
import org.flyte.api.v1.WorkflowTemplate;
import org.flyte.flytekit.SdkRemoteTask;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkType;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.localengine.LocalEngine;

@AutoValue
Expand Down Expand Up @@ -84,17 +84,11 @@ public static SdkTestingExecutor of(SdkWorkflow workflow, List<SdkRunnableTask<?

public static SdkTestingExecutor of(
SdkWorkflow workflow, List<SdkRunnableTask<?, ?>> tasks, List<SdkWorkflow> workflows) {
Map<String, TestingRunnableTask<?, ?>> fixedTasks = new HashMap<>();
for (SdkRunnableTask<?, ?> task : tasks) {
fixedTasks.put(task.getName(), TestingRunnableTask.create(task));
}
Map<String, TestingRunnableTask<?, ?>> fixedTasks =
tasks.stream().collect(toMap(SdkRunnableTask::getName, TestingRunnableTask::create));

Map<String, WorkflowTemplate> workflowTemplateMap = new HashMap<>();
for (SdkWorkflow w : workflows) {
SdkWorkflowBuilder builder = new SdkWorkflowBuilder();
w.expand(builder);
workflowTemplateMap.put(w.getName(), builder.toIdlTemplate());
}
Map<String, WorkflowTemplate> workflowTemplateMap =
workflows.stream().collect(toMap(SdkWorkflow::getName, SdkWorkflow::toIdlTemplate));

return SdkTestingExecutor.builder()
.workflow(workflow)
Expand Down Expand Up @@ -162,11 +156,8 @@ private Literal getOutput(String name, LiteralType expectedLiteralType) {
}

public Result execute() {
TestingSdkWorkflowBuilder builder =
new TestingSdkWorkflowBuilder(fixedInputMap(), fixedInputTypeMap());

workflow().expand(builder);
WorkflowTemplate workflowTemplate = builder.toIdlTemplate();
WorkflowTemplate workflowTemplate = workflow().toIdlTemplate();
checkInputsInFixedInputs(workflowTemplate);
checkFixedTransform(workflowTemplate);

Map<String, Literal> outputLiteralMap =
Expand All @@ -184,6 +175,31 @@ public Result execute() {
return Result.create(outputLiteralMap, outputLiteralTypeMap);
}

private void checkInputsInFixedInputs(WorkflowTemplate template) {
template
.interface_()
.inputs()
.forEach(
(inputName, inputVar) -> {
LiteralType inputType = inputVar.literalType();

LiteralType fixedInputType = fixedInputTypeMap().get(inputName);

checkArgument(
fixedInputType != null,
"Fixed input [%s] (of type %s) isn't defined, use SdkTestingExecutor#withFixedInput",
inputName,
LiteralTypes.toPrettyString(inputType));

checkArgument(
fixedInputType.equals(inputType),
"Fixed input [%s] (of type %s) doesn't match expected type %s",
inputName,
LiteralTypes.toPrettyString(fixedInputType),
LiteralTypes.toPrettyString(inputType));
});
}

private void checkFixedTransform(WorkflowTemplate template) {
for (Node node : template.nodes()) {
TaskNode taskNode = node.taskNode();
Expand Down Expand Up @@ -289,6 +305,47 @@ public <InputT, OutputT> SdkTestingExecutor withTask(
return toBuilder().putFixedTask(task.getName(), fixedTask.withRunFn(runFn)).build();
}

public <InputT, OutputT> SdkTestingExecutor withWorkflowOutput(
SdkWorkflow workflow,
SdkType<InputT> inputType,
InputT input,
SdkType<OutputT> outputType,
OutputT output) {
verifyInputOutputMatchesWorkflowInterface(workflow, inputType, outputType);

// fixed tasks
TestingRunnableTask<InputT, OutputT> fixedTask =
getFixedTaskOrDefault(workflow.getName(), inputType, outputType);

// replace workflow
SdkWorkflow mockWorkflow = new TestingWorkflow<>(inputType, outputType, output);

return toBuilder()
.putWorkflowTemplate(workflow.getName(), mockWorkflow.toIdlTemplate())
.putFixedTask(workflow.getName(), fixedTask.withFixedOutput(input, output))
.build();
}

private static <InputT, OutputT> void verifyInputOutputMatchesWorkflowInterface(
SdkWorkflow workflow, SdkType<InputT> inputType, SdkType<OutputT> outputType) {
TypedInterface intf = workflow.toIdlTemplate().interface_();

verifyVariablesMatches("Input", intf.inputs(), inputType.getVariableMap());
verifyVariablesMatches("Output", intf.outputs(), outputType.getVariableMap());
}

private static void verifyVariablesMatches(
String type, Map<String, Variable> actualVariables, Map<String, Variable> variables) {
if (!actualVariables.equals(variables)) {
throw new IllegalArgumentException(
String.format(
"%s type %s doesn't match expected type %s",
type,
LiteralTypes.toPrettyString(LiteralTypes.from(variables)),
LiteralTypes.toPrettyString(LiteralTypes.from(actualVariables))));
}
}

private <InputT, OutputT> TestingRunnableTask<InputT, OutputT> getFixedTaskOrDefault(
String name, SdkType<InputT> inputType, SdkType<OutputT> outputType) {
@SuppressWarnings({"unchecked"})
Expand Down Expand Up @@ -346,6 +403,13 @@ Builder putFixedTask(String name, TestingRunnableTask<?, ?> fn) {
return fixedTaskMap(newFixedTaskMap);
}

Builder putWorkflowTemplate(String name, WorkflowTemplate template) {
Map<String, WorkflowTemplate> newWorkflowTemplateMap = new HashMap<>(workflowTemplateMap());
newWorkflowTemplateMap.put(name, template);

return workflowTemplateMap(newWorkflowTemplateMap);
}

abstract SdkTestingExecutor build();
}
}

This file was deleted.

Loading

0 comments on commit a5a7952

Please sign in to comment.