From 305495e32dcd458b5bc0a8d15dfc82212dffa26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nelson=20Arap=C3=A9?= Date: Tue, 14 Feb 2023 09:04:45 +0100 Subject: [PATCH] Refactor SdkBindindData factory methods (#193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP Signed-off-by: Nelson Arapé * Refactor SdkBindingData ans SdkType ... based on SdkLiteralType Still missing: * conversions * javadocs * check error messages Signed-off-by: Nelson Arapé * WIP for converting lists Signed-off-by: Nelson Arapé * WIP transform lists Signed-off-by: Lucía Pasarin * First test compiles Signed-off-by: Nelson Arapé * Now ready for review Signed-off-by: Nelson Arapé * Format Signed-off-by: Nelson Arapé * Fix invalid javadoc Signed-off-by: Nelson Arapé * Address first round of feedback Signed-off-by: Nelson Arapé * Rename SdkBindingDatas to SdkBindingDataFactory and change ofX methods Signed-off-by: Pablo Casares Crespo * Review changes Signed-off-by: Andres Gomez Ferrer * Fix verify Signed-off-by: Andres Gomez Ferrer * Fix integration test Signed-off-by: Andres Gomez Ferrer --------- Signed-off-by: Nelson Arapé Signed-off-by: Lucía Pasarin Signed-off-by: Pablo Casares Crespo Signed-off-by: Andres Gomez Ferrer Co-authored-by: Lucía Pasarin Co-authored-by: Pablo Casares Crespo Co-authored-by: Andres Gomez Ferrer --- .../flytekitscala/AddQuestionTask.scala | 8 +- .../DynamicFibonacciWorkflow.scala | 2 +- .../DynamicFibonacciWorkflowTask.scala | 6 +- .../examples/flytekitscala/GreetTask.scala | 9 +- .../examples/flytekitscala/SumTask.scala | 4 +- .../flytekitscala/WelcomeWorkflow.scala | 2 +- .../WorkflowWithRemoteLaunchPlan.scala | 6 +- .../WorkflowWithRemoteTask.scala | 6 +- .../org/flyte/examples/AddQuestionTask.java | 3 +- .../org/flyte/examples/AllInputsWorkflow.java | 21 +- .../org/flyte/examples/BatchLookUpTask.java | 3 +- .../examples/ConditionalGreetingWorkflow.java | 6 +- .../DynamicFibonacciWorkflowTask.java | 7 +- .../flyte/examples/FibonacciLaunchPlan.java | 3 +- .../java/org/flyte/examples/GreetTask.java | 3 +- .../examples/NodeMetadataExampleWorkflow.java | 5 +- .../org/flyte/examples/PhoneBookWorkflow.java | 5 +- .../main/java/org/flyte/examples/SumTask.java | 3 +- .../java/org/flyte/examples/WorkflowTest.java | 26 +- .../flytekit/jackson/JacksonSdkType.java | 34 +- .../flytekit/jackson/RootFormatVisitor.java | 9 + .../flytekit/jackson/VariableMapVisitor.java | 107 ++-- .../SdkBindingDataDeserializer.java | 69 ++- .../flytekit/jackson/JacksonSdkTypeTest.java | 73 +-- .../java/org/flyte/flytekit/Compiler.java | 2 +- .../org/flyte/flytekit/SdkBindingData.java | 582 ++++++------------ .../flyte/flytekit/SdkBindingDataFactory.java | 299 +++++++++ .../org/flyte/flytekit/SdkBranchNode.java | 13 +- .../org/flyte/flytekit/SdkConditions.java | 4 +- .../org/flyte/flytekit/SdkContainerTask.java | 9 +- .../flytekit/SdkDynamicWorkflowTask.java | 9 +- .../org/flyte/flytekit/SdkLaunchPlan.java | 3 +- .../org/flyte/flytekit/SdkLiteralType.java | 61 +- .../org/flyte/flytekit/SdkLiteralTypes.java | 287 +++++++-- .../flyte/flytekit/SdkRemoteLaunchPlan.java | 13 +- .../org/flyte/flytekit/SdkRemoteTask.java | 5 +- .../org/flyte/flytekit/SdkRunnableTask.java | 5 +- .../java/org/flyte/flytekit/SdkTaskNode.java | 21 +- .../main/java/org/flyte/flytekit/SdkType.java | 21 +- .../java/org/flyte/flytekit/SdkTypes.java | 5 + .../java/org/flyte/flytekit/SdkWorkflow.java | 8 +- .../flyte/flytekit/SdkWorkflowBuilder.java | 8 +- .../flyte/flytekit/WorkflowTemplateIdl.java | 2 +- .../flytekit/SdkBindingDataFactoryTest.java | 228 +++++++ .../flyte/flytekit/SdkBindingDataTest.java | 338 ---------- .../org/flyte/flytekit/SdkLaunchPlanTest.java | 49 +- .../flyte/flytekit/SdkLiteralTypesTest.java | 73 ++- .../flytekit/SdkRemoteLaunchPlanTest.java | 5 +- .../org/flyte/flytekit/SdkRemoteTaskTest.java | 6 +- .../org/flyte/flytekit/SdkTransformTest.java | 7 +- .../flytekit/SdkWorkflowBuilderTest.java | 20 +- .../flyte/flytekit/TestPairIntegerInput.java | 13 +- .../flytekit/TestUnaryBooleanOutput.java | 14 +- .../flyte/flytekit/TestUnaryIntegerInput.java | 15 +- .../flytekit/TestUnaryIntegerOutput.java | 14 +- .../flyte/localengine/LocalEngineTest.java | 10 +- .../CollatzConjectureStepWorkflow.java | 10 +- .../flyte/localengine/examples/ListTask.java | 3 +- .../localengine/examples/ListWorkflow.java | 21 +- .../flyte/localengine/examples/MapTask.java | 3 +- .../localengine/examples/MapWorkflow.java | 19 +- .../flyte/localengine/examples/SumTask.java | 3 +- flytekit-scala-tests/pom.xml | 5 - .../SdkBindingDataConvertersTest.scala | 347 +++++++++++ .../flytekitscala/SdkLiteralTypesTest.scala | 148 +++++ .../flytekitscala/SdkScalaTypeTest.scala | 153 +++-- .../flytekit/SdkBindingDataConverters.scala | 196 ------ .../flyte/flytekitscala/SdkBindingData.scala | 502 --------------- .../SdkBindingDataConverters.scala | 328 ++++++++++ .../flytekitscala/SdkBindingDataFactory.scala | 372 +++++++++++ .../flyte/flytekitscala/SdkLiteralTypes.scala | 283 +++++++++ .../flyte/flytekitscala/SdkScalaType.scala | 188 ++---- .../testing/FibonacciWorkflowTest.java | 24 +- .../flytekit/testing/IfElseWorkflowTest.java | 24 +- .../flyte/flytekit/testing/RemoteSumTask.java | 3 +- .../testing/RemoteVoidOutputTask.java | 3 +- .../testing/SdkTestingExecutorTest.java | 32 +- .../org/flyte/flytekit/testing/SumTask.java | 3 +- .../testing/TestingRunnableNodeTest.java | 7 +- .../integrationtests/BranchNodeWorkflow.java | 19 +- .../structs/BuildBqReference.java | 3 +- .../structs/MockLookupBqTask.java | 6 +- .../structs/MockPipelineWorkflow.java | 10 +- jflyte/pom.xml | 5 - pom.xml | 5 - 85 files changed, 3142 insertions(+), 2152 deletions(-) create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java create mode 100644 flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataFactoryTest.java delete mode 100644 flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataTest.java create mode 100644 flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkBindingDataConvertersTest.scala create mode 100644 flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala delete mode 100644 flytekit-scala_2.13/src/main/scala/org/flyte/flytekit/SdkBindingDataConverters.scala delete mode 100644 flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingData.scala create mode 100644 flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala create mode 100644 flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala create mode 100644 flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala index 96f6cad86..5e5f47dcc 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala @@ -16,9 +16,9 @@ */ package org.flyte.examples.flytekitscala -import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform} +import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask} import org.flyte.flytekitscala.SdkScalaType -import org.flyte.flytekitscala.SdkBindingData.ofString +import org.flyte.flytekitscala.SdkBindingDataFactory case class AddQuestionTaskInput(greeting: SdkBindingData[String]) case class AddQuestionTaskOutput(greeting: SdkBindingData[String]) @@ -44,5 +44,7 @@ class AddQuestionTask * the updated greeting message */ override def run(input: AddQuestionTaskInput): AddQuestionTaskOutput = - AddQuestionTaskOutput(ofString(s"${input.greeting.get} How are you?")) + AddQuestionTaskOutput( + SdkBindingDataFactory.of(s"${input.greeting.get} How are you?") + ) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala index eaabf7f10..a0fdea73b 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala @@ -16,7 +16,7 @@ */ package org.flyte.examples.flytekitscala -import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder} +import org.flyte.flytekit.SdkBindingData import org.flyte.flytekitscala.{ SdkScalaType, SdkScalaWorkflow, diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala index 5e84affb8..a7bdeb3dc 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala @@ -24,7 +24,7 @@ import org.flyte.flytekit.{ import org.flyte.flytekitscala.SdkScalaType import scala.annotation.tailrec -import org.flyte.flytekitscala.SdkBindingData._ +import org.flyte.flytekitscala.SdkBindingDataFactory._ case class DynamicFibonacciWorkflowTaskInput(n: SdkBindingData[Long]) case class DynamicFibonacciWorkflowTaskOutput(output: SdkBindingData[Long]) @@ -64,9 +64,9 @@ class DynamicFibonacciWorkflowTask require(input.n.get > 0, "n < 0") val value = if (input.n.get == 0) { - ofInteger(0) + of(0) } else { - fib(1, ofInteger(1), ofInteger(0)) + fib(1, of(1), of(0)) } DynamicFibonacciWorkflowTaskOutput(value) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala index 0c2c812cb..0370bea66 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala @@ -17,8 +17,11 @@ package org.flyte.examples.flytekitscala import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform} -import org.flyte.flytekitscala.{Description, SdkScalaType} -import org.flyte.flytekitscala.SdkBindingData._ +import org.flyte.flytekitscala.{ + Description, + SdkBindingDataFactory, + SdkScalaType +} case class GreetTaskInput( @Description("the name of the person to be greeted") @@ -47,5 +50,5 @@ class GreetTask * the welcome message */ override def run(input: GreetTaskInput): GreetTaskOutput = - GreetTaskOutput(ofString(s"Welcome, ${input.name.get()}!")) + GreetTaskOutput(SdkBindingDataFactory.of(s"Welcome, ${input.name.get()}!")) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala index 3b9ce7b9e..d1b480b91 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala @@ -18,7 +18,7 @@ package org.flyte.examples.flytekitscala import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform} import org.flyte.flytekitscala.{Description, SdkScalaType} -import org.flyte.flytekitscala.SdkBindingData._ +import org.flyte.flytekitscala.SdkBindingDataFactory._ case class SumTaskInput( @Description("First operand") @@ -39,7 +39,7 @@ class SumTask override def run(input: SumTaskInput): SumTaskOutput = { val result = input.a.get + input.b.get - SumTaskOutput(ofInteger(result)) + SumTaskOutput(of(result)) } override def isCached: Boolean = true diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala index 3183ac7dd..35c47b0b6 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala @@ -16,7 +16,7 @@ */ package org.flyte.examples.flytekitscala -import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder} +import org.flyte.flytekit.SdkBindingData import org.flyte.flytekitscala.{ SdkScalaType, SdkScalaWorkflow, diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteLaunchPlan.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteLaunchPlan.scala index f95548799..9ef17faa4 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteLaunchPlan.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteLaunchPlan.scala @@ -17,7 +17,7 @@ package org.flyte.examples.flytekitscala import org.flyte.flytekitscala.{ - SdkBindingData, + SdkBindingDataFactory, SdkScalaType, SdkScalaWorkflow, SdkScalaWorkflowBuilder @@ -33,8 +33,8 @@ class WorkflowWithRemoteLaunchPlan builder: SdkScalaWorkflowBuilder, input: RemoteLaunchPlanInput ): RemoteLaunchPlanOutput = { - val fib0 = SdkBindingData.ofInteger(0L) - val fib1 = SdkBindingData.ofInteger(1L) + val fib0 = SdkBindingDataFactory.of(0L) + val fib1 = SdkBindingDataFactory.of(1L) val fib5 = builder .apply( diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteTask.scala index 44450c75f..62fb6925b 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WorkflowWithRemoteTask.scala @@ -17,7 +17,7 @@ package org.flyte.examples.flytekitscala import org.flyte.flytekitscala.{ - SdkBindingData, + SdkBindingDataFactory, SdkScalaType, SdkScalaWorkflow, SdkScalaWorkflowBuilder @@ -33,8 +33,8 @@ class WorkflowWithRemoteTask builder: SdkScalaWorkflowBuilder, input: RemoteSumTaskInput ): RemoteSumTaskOutput = { - val a = SdkBindingData.ofInteger(10) - val b = SdkBindingData.ofInteger(12) + val a = SdkBindingDataFactory.of(10) + val b = SdkBindingDataFactory.of(12) val c = builder .apply(new RemoteSumTask().create, RemoteSumTaskInput(a, b)) diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java index 97f834693..d95f5f32f 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -75,6 +76,6 @@ public static Output create(SdkBindingData greeting) { @Override public Output run(Input input) { return Output.create( - SdkBindingData.ofString(String.format("%s How are you?", input.greeting().get()))); + SdkBindingDataFactory.of(String.format("%s How are you?", input.greeting().get()))); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 085d14446..38cc7d361 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -26,6 +26,7 @@ import java.util.Map; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; @@ -50,16 +51,16 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) "all-inputs", new AllInputsTask(), AllInputsTask.AutoAllInputsInput.create( - SdkBindingData.ofInteger(1L), - SdkBindingData.ofFloat(2), - SdkBindingData.ofString("test"), - SdkBindingData.ofBoolean(true), - SdkBindingData.ofDatetime(someInstant), - SdkBindingData.ofDuration(Duration.ofDays(1L)), - SdkBindingData.ofStringCollection(Arrays.asList("foo", "bar")), - SdkBindingData.ofStringMap(Map.of("test", "test")), - SdkBindingData.ofStringCollection(Collections.emptyList()), - SdkBindingData.ofIntegerMap(Collections.emptyMap()))); + SdkBindingDataFactory.of(1L), + SdkBindingDataFactory.of(2.00), + SdkBindingDataFactory.of("test"), + SdkBindingDataFactory.of(true), + SdkBindingDataFactory.of(someInstant), + SdkBindingDataFactory.of(Duration.ofDays(1L)), + SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")), + SdkBindingDataFactory.ofStringMap(Map.of("test", "test")), + SdkBindingDataFactory.ofStringCollection(Collections.emptyList()), + SdkBindingDataFactory.ofIntegerMap(Collections.emptyMap()))); AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java index bd0889a75..a186731c3 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.stream.Collectors; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -46,7 +47,7 @@ public Output run(Input input) { .map(key -> input.keyValues().get().get(key)) .collect(Collectors.toList()); - return Output.create(SdkBindingData.ofStringCollection(foundValues)); + return Output.create(SdkBindingDataFactory.ofStringCollection(foundValues)); } @AutoValue diff --git a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java index e30724f1c..c13ab3b0b 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java @@ -16,11 +16,11 @@ */ package org.flyte.examples; -import static org.flyte.flytekit.SdkBindingData.ofString; import static org.flyte.flytekit.SdkConditions.eq; import com.google.auto.service.AutoService; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkConditions; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -40,9 +40,9 @@ public GreetTask.Output expand(SdkWorkflowBuilder builder, GreetTask.Input input "decide", SdkConditions.when( "when-empty", - eq(input.name(), ofString("")), + eq(input.name(), SdkBindingDataFactory.of("")), new GreetTask(), - GreetTask.Input.create(ofString("World"))) + GreetTask.Input.create(SdkBindingDataFactory.of("World"))) .otherwise( "when-not-empty", new GreetTask(), GreetTask.Input.create(input.name()))) .getOutputs() diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java index ff264f1a0..ad917813d 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java @@ -21,6 +21,7 @@ import com.google.errorprone.annotations.Var; import org.flyte.examples.SumTask.SumInput; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkDynamicWorkflowTask; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -56,10 +57,10 @@ public Output run(SdkWorkflowBuilder builder, Input input) { if (input.n().get() < 0) { throw new IllegalArgumentException("n < 0"); } else if (input.n().get() == 0) { - return Output.create(SdkBindingData.ofInteger(0)); + return Output.create(SdkBindingDataFactory.of(0)); } else { - @Var SdkBindingData prev = SdkBindingData.ofInteger(0); - @Var SdkBindingData value = SdkBindingData.ofInteger(1); + @Var SdkBindingData prev = SdkBindingDataFactory.of(0); + @Var SdkBindingData value = SdkBindingDataFactory.of(1); for (int i = 2; i <= input.n().get(); i++) { SdkBindingData next = builder.apply("fib-" + i, new SumTask(), SumInput.create(value, prev)).getOutputs().c(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java index 01f9be47c..01c6b6bc4 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLaunchPlan; import org.flyte.flytekit.SdkLaunchPlanRegistry; import org.flyte.flytekit.SimpleSdkLaunchPlanRegistry; @@ -37,7 +38,7 @@ public FibonacciLaunchPlan() { .withName("FibonacciWorkflowLaunchPlan") .withFixedInputs( JacksonSdkType.of(Input.class), - Input.create(SdkBindingData.ofInteger(0), SdkBindingData.ofInteger(1)))); + Input.create(SdkBindingDataFactory.of(0), SdkBindingDataFactory.of(1)))); // Register launch plan with fixed inputs specified directly registerLaunchPlan( diff --git a/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java b/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java index 5ccc5c9d3..cf4425588 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -71,6 +72,6 @@ public static Output create(SdkBindingData greeting) { @Override public Output run(Input input) { return Output.create( - SdkBindingData.ofString(String.format("Welcome, %s!", input.name().get()))); + SdkBindingDataFactory.of(String.format("Welcome, %s!", input.name().get()))); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java index cff26d84d..3e665e06b 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java @@ -20,6 +20,7 @@ import com.google.auto.value.AutoValue; import java.time.Duration; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -50,8 +51,8 @@ public NodeMetadataExampleWorkflow() { @Override public Output expand(SdkWorkflowBuilder builder, Void noInput) { - SdkBindingData a = SdkBindingData.ofInteger(0); - SdkBindingData b = SdkBindingData.ofInteger(1); + SdkBindingData a = SdkBindingDataFactory.of(0); + SdkBindingData b = SdkBindingDataFactory.of(1); SdkBindingData c = builder diff --git a/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java index 834047f71..151f25d21 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -61,9 +62,9 @@ public PhoneBookWorkflow() { @Override public Output expand(SdkWorkflowBuilder builder, Void noInput) { - SdkBindingData> phoneBook = SdkBindingData.ofStringMap(PHONE_BOOK); + SdkBindingData> phoneBook = SdkBindingDataFactory.ofStringMap(PHONE_BOOK); - SdkBindingData> searchKeys = SdkBindingData.ofStringCollection(NAMES); + SdkBindingData> searchKeys = SdkBindingDataFactory.ofStringCollection(NAMES); SdkBindingData> phoneNumbers = builder diff --git a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java index 39da7e54d..4e373a55b 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.Description; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -55,7 +56,7 @@ public static SumOutput create(SdkBindingData c) { @Override public SumOutput run(SumInput input) { - return SumOutput.create(SdkBindingData.ofInteger(input.a().get() + input.b().get())); + return SumOutput.create(SdkBindingDataFactory.of(input.a().get() + input.b().get())); } @Override diff --git a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java index 1957468eb..b1c806b47 100644 --- a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java +++ b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java @@ -20,7 +20,7 @@ import org.flyte.examples.SumTask.SumInput; import org.flyte.examples.SumTask.SumOutput; -import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.jackson.JacksonSdkType; import org.flyte.flytekit.testing.SdkTestingExecutor; import org.junit.jupiter.api.Test; @@ -50,16 +50,16 @@ public void testMockTasks() { .withFixedInput("d", 4) .withTaskOutput( new SumTask(), - SumTask.SumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L)), - SumTask.SumOutput.create(SdkBindingData.ofInteger(0L))) + SumTask.SumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(2L)), + SumTask.SumOutput.create(SdkBindingDataFactory.of(0L))) .withTaskOutput( new SumTask(), - SumTask.SumInput.create(SdkBindingData.ofInteger(0L), SdkBindingData.ofInteger(3L)), - SumTask.SumOutput.create(SdkBindingData.ofInteger(0L))) + SumTask.SumInput.create(SdkBindingDataFactory.of(0L), SdkBindingDataFactory.of(3L)), + SumTask.SumOutput.create(SdkBindingDataFactory.of(0L))) .withTaskOutput( new SumTask(), - SumTask.SumInput.create(SdkBindingData.ofInteger(0L), SdkBindingData.ofInteger(4L)), - SumTask.SumOutput.create(SdkBindingData.ofInteger(42L))) + SumTask.SumInput.create(SdkBindingDataFactory.of(0L), SdkBindingDataFactory.of(4L)), + SumTask.SumOutput.create(SdkBindingDataFactory.of(42L))) .execute(); assertEquals(42L, result.getIntegerOutput("result")); @@ -79,20 +79,20 @@ public void testMockSubWorkflow() { new SubWorkflow(), JacksonSdkType.of(SubWorkflow.Input.class), SubWorkflow.Input.create( - SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L)), + SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(2L)), JacksonSdkType.of(SubWorkflow.Output.class), - SubWorkflow.Output.create(SdkBindingData.ofInteger(5L))) + SubWorkflow.Output.create(SdkBindingDataFactory.of(5L))) .withWorkflowOutput( new SubWorkflow(), JacksonSdkType.of(SubWorkflow.Input.class), SubWorkflow.Input.create( - SdkBindingData.ofInteger(5L), SdkBindingData.ofInteger(3L)), + SdkBindingDataFactory.of(5L), SdkBindingDataFactory.of(3L)), JacksonSdkType.of(SubWorkflow.Output.class), - SubWorkflow.Output.create(SdkBindingData.ofInteger(10L))) + SubWorkflow.Output.create(SdkBindingDataFactory.of(10L))) .withTaskOutput( new SumTask(), - SumInput.create(SdkBindingData.ofInteger(10L), SdkBindingData.ofInteger(4L)), - SumOutput.create(SdkBindingData.ofInteger(15L))) + SumInput.create(SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(4L)), + SumOutput.create(SdkBindingDataFactory.of(15L))) .execute(); assertEquals(15L, result.getIntegerOutput("result")); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java index b85841c08..5571cc9ad 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java @@ -16,6 +16,7 @@ */ package org.flyte.flytekit.jackson; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; import com.fasterxml.jackson.core.JsonParser; @@ -32,12 +33,12 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkLiteralType; import org.flyte.flytekit.SdkType; import org.flyte.flytekit.jackson.deserializers.CustomSdkBindingDataDeserializers; import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializer; @@ -49,12 +50,17 @@ public class JacksonSdkType extends SdkType { private final Class clazz; private final Map variableMap; private final Map membersMap; + private final Map> typesMap; private JacksonSdkType( - Class clazz, Map variableMap, Map membersMap) { - this.clazz = Objects.requireNonNull(clazz); - this.variableMap = Objects.requireNonNull(variableMap); - this.membersMap = Objects.requireNonNull(membersMap); + Class clazz, + Map variableMap, + Map membersMap, + Map> typesMap) { + this.clazz = requireNonNull(clazz); + this.variableMap = Map.copyOf(requireNonNull(variableMap)); + this.membersMap = Map.copyOf(requireNonNull(membersMap)); + this.typesMap = Map.copyOf(requireNonNull(typesMap)); } public static JacksonSdkType of(Class clazz) { @@ -78,7 +84,8 @@ public static JacksonSdkType of(Class clazz) { serializer.acceptJsonFormatVisitor( visitor, OBJECT_MAPPER.getTypeFactory().constructType(clazz)); - return new JacksonSdkType<>(clazz, visitor.getVariableMap(), visitor.getMembersMap()); + return new JacksonSdkType<>( + clazz, visitor.getVariableMap(), visitor.getMembersMap(), visitor.getTypesMap()); } catch (JsonMappingException e) { throw new IllegalArgumentException( String.format("Failed to find serializer for [%s]", clazz.getName()), e); @@ -125,6 +132,11 @@ public Map getVariableMap() { return variableMap; } + @Override + public Map> toLiteralTypes() { + return typesMap; + } + private Map getMembersMap() { return membersMap; } @@ -148,8 +160,8 @@ public T fromLiteralMap(Map value) { * Method used to create SdkBindingData output references/promises for SdkTransform outputs (e.g., * workflows and tasks outputs). We need to go from {@code Map} to object of * output class T. We leverage Jackson to help create the object of the output class T from the - * map. We use a the BindingMapSerializer to serialize only the keys of the map to JsonNode - * Instead of recreating SdkBindingData objects we pass the bindingMap to the + * map. We use the BindingMapSerializer to serialize only the keys of the map to JsonNode, instead + * of recreating SdkBindingData objects we pass the bindingMap to the * CustomSdkBindingDataDeserializers so it can get use the keys to retrieve the objects from the * map. We need to create a new object mapper to use a different deserializer for SdkBindingData * than the one used in other places. @@ -158,13 +170,11 @@ public T fromLiteralMap(Map value) { public T promiseFor(String nodeId) { try { Map> bindingMap = - getVariableMap().entrySet().stream() + typesMap.entrySet().stream() .collect( toMap( Map.Entry::getKey, - x -> - SdkBindingData.ofOutputReference( - nodeId, x.getKey(), x.getValue().literalType()))); + e -> SdkBindingData.promise(e.getValue(), nodeId, e.getKey()))); JsonNode tree = OBJECT_MAPPER.valueToTree(new JacksonBindingMap(bindingMap)); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java index b4142b5fa..b15b3d4ff 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor; import java.util.Map; import org.flyte.api.v1.Variable; +import org.flyte.flytekit.SdkLiteralType; class RootFormatVisitor extends JsonFormatVisitorWrapper.Base { @@ -53,4 +54,12 @@ public Map getMembersMap() { return builder.getMembersMap(); } + + public Map> getTypesMap() { + if (builder == null) { + throw new IllegalStateException("invariant failed: typesMap not set"); + } + + return builder.getTypesMap(); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 9185cad7d..414f138e3 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -18,70 +18,68 @@ import static java.util.Collections.unmodifiableMap; -import com.fasterxml.jackson.databind.BeanDescription; import com.fasterxml.jackson.databind.BeanProperty; import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.introspect.AnnotatedMember; -import com.fasterxml.jackson.databind.introspect.BeanPropertyDefinition; import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor; import java.time.Duration; import java.time.Instant; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import org.flyte.api.v1.Blob; -import org.flyte.api.v1.BlobType; -import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.SimpleType; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.SdkLiteralTypes; class VariableMapVisitor extends JsonObjectFormatVisitor.Base { - private static final Map, Class> PRIMITIVE_TO_WRAPPER; + private static final Map, Class> PRIMITIVE_TO_WRAPPER = + Map.of( + void.class, Void.class, + boolean.class, Boolean.class, + byte.class, Byte.class, + char.class, Character.class, + short.class, Short.class, + int.class, Integer.class, + long.class, Long.class, + float.class, Float.class, + double.class, Double.class); VariableMapVisitor(SerializerProvider provider) { super(provider); } - static { - Map, Class> map = new HashMap<>(); - map.put(void.class, Void.class); - map.put(boolean.class, Boolean.class); - map.put(byte.class, Byte.class); - map.put(char.class, Character.class); - map.put(short.class, Short.class); - map.put(int.class, Integer.class); - map.put(long.class, Long.class); - map.put(float.class, Float.class); - map.put(double.class, Double.class); - PRIMITIVE_TO_WRAPPER = unmodifiableMap(map); - } - - private final Map builder = new LinkedHashMap<>(); + private final Map builderVariables = new LinkedHashMap<>(); private final Map builderMembers = new LinkedHashMap<>(); + private final Map> builderTypes = new LinkedHashMap<>(); @Override public void property(BeanProperty prop) { JavaType handledType = getHandledType(prop); - LiteralType literalType = + String propName = prop.getName(); + AnnotatedMember member = prop.getMember(); + SdkLiteralType literalType = toLiteralType( handledType, /*rootLevel=*/ true, - prop.getName(), - prop.getMember().getMember().getDeclaringClass().getName()); + propName, + member.getMember().getDeclaringClass().getName()); - String description = getDescription(prop.getMember()); + String description = getDescription(member); Variable variable = - Variable.builder().description(description).literalType(literalType).build(); - - builderMembers.put(prop.getName(), prop.getMember()); - builder.put(prop.getName(), variable); + Variable.builder() + .description(description) + .literalType(literalType.getLiteralType()) + .build(); + + builderMembers.put(propName, member); + builderVariables.put(propName, variable); + builderTypes.put(propName, literalType); } private JavaType getHandledType(BeanProperty prop) { @@ -111,11 +109,15 @@ public void optionalProperty(BeanProperty prop) { } public Map getVariableMap() { - return unmodifiableMap(new HashMap<>(builder)); + return unmodifiableMap(builderVariables); } public Map getMembersMap() { - return unmodifiableMap(new HashMap<>(builderMembers)); + return unmodifiableMap(builderMembers); + } + + public Map> getTypesMap() { + return unmodifiableMap(builderTypes); } private String getDescription(AnnotatedMember member) { @@ -129,7 +131,7 @@ private String getDescription(AnnotatedMember member) { } @SuppressWarnings("AlreadyChecked") - private LiteralType toLiteralType( + private SdkLiteralType toLiteralType( JavaType javaType, boolean rootLevel, String propName, String declaringClassName) { Class type = javaType.getRawClass(); @@ -143,21 +145,21 @@ private LiteralType toLiteralType( + "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.", propName, declaringClassName, type)); } else if (isPrimitiveAssignableFrom(Long.class, type)) { - return LiteralTypes.INTEGER; + return SdkLiteralTypes.integers(); } else if (isPrimitiveAssignableFrom(Double.class, type)) { - return LiteralTypes.FLOAT; + return SdkLiteralTypes.floats(); } else if (String.class.equals(type) || javaType.isEnumType()) { - return LiteralTypes.STRING; + return SdkLiteralTypes.strings(); } else if (isPrimitiveAssignableFrom(Boolean.class, type)) { - return LiteralTypes.BOOLEAN; + return SdkLiteralTypes.booleans(); } else if (Instant.class.isAssignableFrom(type)) { - return LiteralTypes.DATETIME; + return SdkLiteralTypes.datetimes(); } else if (Duration.class.isAssignableFrom(type)) { - return LiteralTypes.DURATION; + return SdkLiteralTypes.durations(); } else if (List.class.isAssignableFrom(type)) { JavaType elementType = javaType.getBindings().getBoundType(0); - return LiteralType.ofCollectionType( + return SdkLiteralTypes.collections( toLiteralType(elementType, false, propName, declaringClassName)); } else if (Map.class.isAssignableFrom(type)) { JavaType keyType = javaType.getBindings().getBoundType(0); @@ -168,26 +170,11 @@ private LiteralType toLiteralType( "Only Map is supported, got [" + javaType.getGenericSignature() + "]"); } - return LiteralType.ofMapValueType( - toLiteralType(valueType, false, propName, declaringClassName)); - } else if (Blob.class.isAssignableFrom(type)) { - // TODO add annotation to specify dimensionality and format - BlobType blobType = - BlobType.builder().format("").dimensionality(BlobType.BlobDimensionality.SINGLE).build(); - - return LiteralType.ofBlobType(blobType); - } - - BeanDescription bean = getProvider().getConfig().introspect(javaType); - List properties = bean.findProperties(); - - if (properties.isEmpty()) { - // doesn't look like a bean, can be java.lang.Integer, or something else - throw new UnsupportedOperationException( - String.format("Unsupported type: [%s]", type.getName())); - } else { - return LiteralType.ofSimpleType(SimpleType.STRUCT); + return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName)); } + // TODO: Support blobs and structs + throw new UnsupportedOperationException( + String.format("Unsupported type: [%s]", type.getName())); } private static boolean isPrimitiveAssignableFrom(Class fromClass, Class toClass) { diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index e40ee86e4..e52626153 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -45,6 +45,9 @@ import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; +import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.SdkLiteralTypes; class SdkBindingDataDeserializer extends StdDeserializer> { private static final long serialVersionUID = 0L; @@ -84,17 +87,17 @@ private static SdkBindingData transformScalar(JsonNode t Primitive.Kind primitiveKind = Primitive.Kind.valueOf(tree.get("primitive").asText()); switch (primitiveKind) { case INTEGER_VALUE: - return SdkBindingData.ofInteger(tree.get(VALUE).longValue()); + return SdkBindingDataFactory.of(tree.get(VALUE).longValue()); case BOOLEAN_VALUE: - return SdkBindingData.ofBoolean(tree.get(VALUE).booleanValue()); + return SdkBindingDataFactory.of(tree.get(VALUE).booleanValue()); case STRING_VALUE: - return SdkBindingData.ofString(tree.get(VALUE).asText()); + return SdkBindingDataFactory.of(tree.get(VALUE).asText()); case DURATION: - return SdkBindingData.ofDuration(Duration.parse(tree.get(VALUE).asText())); + return SdkBindingDataFactory.of(Duration.parse(tree.get(VALUE).asText())); case DATETIME: - return SdkBindingData.ofDatetime(Instant.parse(tree.get(VALUE).asText())); + return SdkBindingDataFactory.of(Instant.parse(tree.get(VALUE).asText())); case FLOAT_VALUE: - return SdkBindingData.ofFloat(tree.get(VALUE).doubleValue()); + return SdkBindingDataFactory.of(tree.get(VALUE).doubleValue()); } throw new UnsupportedOperationException( "Type contains an unsupported primitive: " + primitiveKind); @@ -109,63 +112,79 @@ private static SdkBindingData transformScalar(JsonNode t @SuppressWarnings("unchecked") private SdkBindingData> transformCollection(JsonNode tree) { - LiteralType literalType = readLiteralType(tree.get(TYPE)); + SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); Iterator elements = tree.get(VALUE).elements(); - switch (literalType.getKind()) { + switch (literalType.getLiteralType().getKind()) { case SIMPLE_TYPE: case MAP_VALUE_TYPE: case COLLECTION_TYPE: - List> collection = - streamOf(elements).map(this::transform).collect(toList()); - return SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(literalType), (List>) collection); + List collection = + (List) + streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList()); + return SdkBindingDataFactory.of(literalType, collection); case SCHEMA_TYPE: case BLOB_TYPE: default: throw new UnsupportedOperationException( - "Type contains a collection of an supported literal type: " + literalType.getKind()); + "Type contains a collection of an supported literal type: " + literalType); } } @SuppressWarnings("unchecked") private SdkBindingData> transformMap(JsonNode tree) { - LiteralType literalType = readLiteralType(tree.get(TYPE)); + SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); JsonNode valueNode = tree.get(VALUE); List> entries = streamOf(valueNode.fieldNames()) .map(name -> Map.entry(name, valueNode.get(name))) .collect(toList()); - switch (literalType.getKind()) { + switch (literalType.getLiteralType().getKind()) { case SIMPLE_TYPE: case MAP_VALUE_TYPE: case COLLECTION_TYPE: - Map> bindingDataMap = + Map bindingDataMap = entries.stream() - .map( - entry -> - Map.entry(entry.getKey(), (SdkBindingData) transform(entry.getValue()))) + .map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - return SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(literalType), bindingDataMap); + return SdkBindingDataFactory.of(literalType, bindingDataMap); case SCHEMA_TYPE: case BLOB_TYPE: default: throw new UnsupportedOperationException( - "Type contains a map of an supported literal type: " + literalType.getKind()); + "Type contains a map of an supported literal type: " + literalType); } } - private LiteralType readLiteralType(JsonNode typeNode) { + private SdkLiteralType readLiteralType(JsonNode typeNode) { LiteralType.Kind kind = LiteralType.Kind.valueOf(typeNode.get(KIND).asText()); switch (kind) { case SIMPLE_TYPE: - return LiteralType.ofSimpleType(SimpleType.valueOf(typeNode.get(VALUE).asText())); + SimpleType simpleType = SimpleType.valueOf(typeNode.get(VALUE).asText()); + switch (simpleType) { + case INTEGER: + return SdkLiteralTypes.integers(); + case FLOAT: + return SdkLiteralTypes.floats(); + case STRING: + return SdkLiteralTypes.strings(); + case BOOLEAN: + return SdkLiteralTypes.booleans(); + case DATETIME: + return SdkLiteralTypes.datetimes(); + case DURATION: + return SdkLiteralTypes.durations(); + case STRUCT: + // not yet supported, fallthrough + } + throw new UnsupportedOperationException( + "Type contains a collection/map of an supported literal type: " + kind); case MAP_VALUE_TYPE: - return LiteralType.ofMapValueType(readLiteralType(typeNode.get(VALUE).get(TYPE))); + return SdkLiteralTypes.maps(readLiteralType(typeNode.get(VALUE).get(TYPE))); case COLLECTION_TYPE: - return LiteralType.ofCollectionType(readLiteralType(typeNode.get(VALUE).get(TYPE))); + return SdkLiteralTypes.collections(readLiteralType(typeNode.get(VALUE).get(TYPE))); case SCHEMA_TYPE: case BLOB_TYPE: diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index 5d971490d..28c9dca52 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -46,6 +46,8 @@ import org.flyte.api.v1.SimpleType; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; +import org.flyte.flytekit.SdkLiteralTypes; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -70,30 +72,18 @@ public static AutoValueInput createAutoValueInput( Map> ml, Map> mm) { return AutoValueInput.create( - SdkBindingData.ofInteger(i), - SdkBindingData.ofFloat(f), - SdkBindingData.ofString(s), - SdkBindingData.ofBoolean(b), - SdkBindingData.ofDatetime(t), - SdkBindingData.ofDuration(d), - SdkBindingData.ofStringCollection(l), - SdkBindingData.ofStringMap(m), - SdkBindingData.ofCollection( - ll, - LiteralType.ofCollectionType(LiteralType.ofCollectionType(LiteralTypes.STRING)), - SdkBindingData::ofStringCollection), - SdkBindingData.ofCollection( - lm, - LiteralType.ofCollectionType(LiteralType.ofMapValueType(LiteralTypes.STRING)), - SdkBindingData::ofStringMap), - SdkBindingData.ofMap( - ml, - LiteralType.ofMapValueType(LiteralType.ofCollectionType(LiteralTypes.STRING)), - SdkBindingData::ofStringCollection), - SdkBindingData.ofMap( - mm, - LiteralType.ofMapValueType(LiteralType.ofMapValueType(LiteralTypes.STRING)), - SdkBindingData::ofStringMap)); + SdkBindingDataFactory.of(i), + SdkBindingDataFactory.of(f), + SdkBindingDataFactory.of(s), + SdkBindingDataFactory.of(b), + SdkBindingDataFactory.of(t), + SdkBindingDataFactory.of(d), + SdkBindingDataFactory.ofStringCollection(l), + SdkBindingDataFactory.ofStringMap(m), + SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), + SdkBindingDataFactory.of(SdkLiteralTypes.maps(SdkLiteralTypes.strings()), lm), + SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ml), + SdkBindingDataFactory.of(SdkLiteralTypes.maps(SdkLiteralTypes.strings()), mm)); } @Test @@ -329,7 +319,7 @@ public void testToSdkBindingDataMap() { public void testToSdkBindingDataMapJsonProperties() { JsonPropertyClassInput input = new JsonPropertyClassInput( - SdkBindingData.ofString("test"), SdkBindingData.ofString("name")); + SdkBindingDataFactory.of("test"), SdkBindingDataFactory.of("name")); Map> sdkBindingDataMap = JacksonSdkType.of(JsonPropertyClassInput.class).toSdkBindingMap(input); @@ -340,10 +330,10 @@ public void testToSdkBindingDataMapJsonProperties() { } public static class JsonPropertyClassInput { - @JsonProperty SdkBindingData test; + @JsonProperty final SdkBindingData test; @JsonProperty("name") - SdkBindingData otherTest; + final SdkBindingData otherTest; @JsonCreator public JsonPropertyClassInput(SdkBindingData test, SdkBindingData otherTest) { @@ -355,7 +345,7 @@ public JsonPropertyClassInput(SdkBindingData test, SdkBindingData literalMap = JacksonSdkType.of(PojoInput.class).toLiteralMap(input); @@ -365,7 +355,7 @@ public void testPojoToLiteralMap() { @Test public void testPojoFromLiteralMap() { PojoInput expected = new PojoInput(); - expected.a = SdkBindingData.ofInteger(42); + expected.a = SdkBindingDataFactory.of(42); PojoInput pojoInput = JacksonSdkType.of(PojoInput.class) @@ -454,32 +444,32 @@ void testPromiseFor() { assertThat( autoValueInput.i(), - equalTo(SdkBindingData.ofOutputReference("node-id", "i", LiteralTypes.INTEGER))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.integers(), "node-id", "i"))); assertThat( autoValueInput.f(), - equalTo(SdkBindingData.ofOutputReference("node-id", "f", LiteralTypes.FLOAT))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.floats(), "node-id", "f"))); assertThat( autoValueInput.s(), - equalTo(SdkBindingData.ofOutputReference("node-id", "s", LiteralTypes.STRING))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.strings(), "node-id", "s"))); assertThat( autoValueInput.b(), - equalTo(SdkBindingData.ofOutputReference("node-id", "b", LiteralTypes.BOOLEAN))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.booleans(), "node-id", "b"))); assertThat( autoValueInput.t(), - equalTo(SdkBindingData.ofOutputReference("node-id", "t", LiteralTypes.DATETIME))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.datetimes(), "node-id", "t"))); assertThat( autoValueInput.d(), - equalTo(SdkBindingData.ofOutputReference("node-id", "d", LiteralTypes.DURATION))); + equalTo(SdkBindingData.promise(SdkLiteralTypes.durations(), "node-id", "d"))); assertThat( autoValueInput.l(), equalTo( - SdkBindingData.ofOutputReference( - "node-id", "l", LiteralType.ofCollectionType(LiteralTypes.STRING)))); + SdkBindingData.promise( + SdkLiteralTypes.collections(SdkLiteralTypes.strings()), "node-id", "l"))); assertThat( autoValueInput.m(), equalTo( - SdkBindingData.ofOutputReference( - "node-id", "m", LiteralType.ofMapValueType(LiteralTypes.STRING)))); + SdkBindingData.promise( + SdkLiteralTypes.maps(SdkLiteralTypes.strings()), "node-id", "m"))); } @Test @@ -711,9 +701,4 @@ private static Variable createVar(LiteralType literalType, String description) { private static Literal literalOf(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } - - // @SuppressWarnings({"unused"}) - // private static Literal literalOf(Blob blob) { - // return Literal.ofScalar(Scalar.ofBlob(blob)); - // } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/Compiler.java b/flytekit-java/src/main/java/org/flyte/flytekit/Compiler.java index 82ed0ff07..386bd76d3 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/Compiler.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/Compiler.java @@ -57,7 +57,7 @@ static List validateApply( continue; } - LiteralType actualType = input.type(); + LiteralType actualType = input.type().getLiteralType(); LiteralType expectedType = variable.literalType(); if (!actualType.equals(expectedType)) { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingData.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingData.java index 452c49d97..d4bc9c7ae 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingData.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingData.java @@ -17,495 +17,261 @@ package org.flyte.flytekit; import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; +import static java.util.stream.Collectors.toUnmodifiableList; +import static java.util.stream.Collectors.toUnmodifiableMap; +import static org.flyte.flytekit.SdkLiteralTypes.collections; +import static org.flyte.flytekit.SdkLiteralTypes.maps; import com.google.auto.value.AutoValue; -import java.time.Duration; -import java.time.Instant; -import java.time.LocalDate; -import java.time.ZoneOffset; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Function; -import java.util.stream.Collectors; -import javax.annotation.Nullable; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.OutputReference; -import org.flyte.api.v1.Primitive; -import org.flyte.api.v1.Scalar; -import org.flyte.api.v1.SimpleType; -/** Specifies either a simple value or a reference to another output. */ -@AutoValue +/** + * Specifies either a literal value or a promise that reference to the output of a node. The {@link + * SdkBindingDataFactory} class provides factory methods for {@link SdkBindingData} of different + * types. + */ public abstract class SdkBindingData { abstract BindingData idl(); - abstract LiteralType type(); - - @Nullable - abstract T value(); - - // TODO: it would be interesting to see if we can use java 9 modules to only expose this method - // to other modules in the sdk /** - * Creates a {@code SdkBindingData} based on its components; however it is not meant to be used by - * users directly, but users must use the higher level factory methods. + * Returns the {@link SdkLiteralType} type of this instance. * - * @param idl the api class equivalent to this - * @param type the SdkBindingData type - * @param value when {@code idl} is not a {@link BindingData.Kind#PROMISE} then value contains the - * simple value of this class, must be null otherwise - * @return A newly created SdkBindingData - * @param the java or scala type for the corresponding LiteralType, for example {@code - * Duration} for {@code LiteralType.ofSimpleType(SimpleType.DURATION)} + * @return the type of this instance. */ - public static SdkBindingData create(BindingData idl, LiteralType type, @Nullable T value) { - return new AutoValue_SdkBindingData<>(idl, type, value); - } + public abstract SdkLiteralType type(); /** - * Creates a {@code SdkBindingData} for a flyte integer ({@link Long} for java) with the given - * value. + * Returns the literal value contained by this data. * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} + * @return the literal value that this instance holds + * @throws IllegalArgumentException when this data is a promise for the output of a node */ - public static SdkBindingData ofInteger(long value) { - return ofPrimitive(Primitive::ofIntegerValue, value); - } + public abstract T get(); /** - * Creates a {@code SdkBindingData} for a flyte float ({@link Double} for java) with the given - * value. + * Returns a version of this {@code SdkBindingData} with a new type. * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} + * @param newType the {@link SdkLiteralType} type to be casted to + * @param castFunction function to apply to the value to be converted to the new type + * @return the type casted version of this instance + * @param the java or scala type for the corresponding to {@code newType} + * @throws UnsupportedOperationException if a cast cannot be performed over this instance. */ - public static SdkBindingData ofFloat(double value) { - return ofPrimitive(Primitive::ofFloatValue, value); - } + public abstract SdkBindingData as( + SdkLiteralType newType, Function castFunction); /** - * Creates a {@code SdkBindingData} for a flyte String with the given value. + * Creates a {@code SdkBindingData} for a literal value. * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} + * @param type the {@link SdkLiteralType} type + * @param value contains the simple value of this class + * @return A newly created SdkBindingData + * @param the java or scala type for the corresponding LiteralType, for example {@code + * Duration} for {@code LiteralType.ofSimpleType(SimpleType.DURATION)} */ - public static SdkBindingData ofString(String value) { - return ofPrimitive(Primitive::ofStringValue, value); + public static SdkBindingData literal(SdkLiteralType type, T value) { + return Literal.create(type, value); } /** - * Creates a {@code SdkBindingData} for a flyte boolean with the given value. + * Creates a {@code SdkBindingData} for a reference to (promise for) another output. * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} + * @param type the {@link SdkLiteralType} type + * @param nodeId which nodeId to reference + * @param var variable name to reference on the node id + * @return A newly created SdkBindingData + * @param the java or scala type for the corresponding LiteralType, for example {@code + * Duration} for {@code LiteralType.ofSimpleType(SimpleType.DURATION)} */ - public static SdkBindingData ofBoolean(boolean value) { - return ofPrimitive(Primitive::ofBooleanValue, value); + public static SdkBindingData promise(SdkLiteralType type, String nodeId, String var) { + return Promise.create(type, nodeId, var); } /** - * Creates a {@code SdkBindingData} for a flyte Datetime ({@link Instant} for java) with the given - * date at 00:00 on UTC. + * Creates a {@code SdkBindingData} for a collections of {@link SdkBindingData}. * - * @param year the year to represent, from {@code Year.MIN_VALUE} to {@code Year.MAX_VALUE} - * @param month the month-of-year to represent, from 1 (January) to 12 (December) - * @param day the day-of-month to represent, from 1 to 31 - * @return the new {@code SdkBindingData} + * @param elementType the {@link SdkLiteralType} of the elements of the collection. + * @param collection collections of {@link SdkBindingData}s + * @return A newly created SdkBindingData + * @param the java or scala type for the corresponding LiteralType, for example {@code + * Duration} for {@code LiteralType.ofSimpleType(SimpleType.DURATION)} */ - public static SdkBindingData ofDatetime(int year, int month, int day) { - Instant instant = LocalDate.of(year, month, day).atStartOfDay().toInstant(ZoneOffset.UTC); - return ofDatetime(instant); + public static SdkBindingData> bindingCollection( + SdkLiteralType elementType, List> collection) { + return BindingCollection.create(elementType, collection); } /** - * Creates a {@code SdkBindingData} for a flyte Datetime ({@link Instant} for java) with the given - * value. + * Creates a {@code SdkBindingData} for a map of {@link SdkBindingData}. * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} + * @param valuesType the {@link SdkLiteralType} of the elements of the collection. + * @param map map of {@link SdkBindingData}s + * @return A newly created SdkBindingData + * @param the java or scala type for the corresponding LiteralType, for example {@code + * Duration} for {@code LiteralType.ofSimpleType(SimpleType.DURATION)} */ - public static SdkBindingData ofDatetime(Instant value) { - return ofPrimitive(Primitive::ofDatetime, value); + public static SdkBindingData> bindingMap( + SdkLiteralType valuesType, Map> map) { + return BindingMap.create(valuesType, map); } - /** - * Creates a {@code SdkBindingData} for a flyte Duration for java with the given value. - * - * @param value the simple value for this data - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData ofDuration(Duration value) { - return ofPrimitive(Primitive::ofDuration, value); - } + @AutoValue + abstract static class Literal extends SdkBindingData { + abstract T value(); - private static SdkBindingData ofPrimitive(Function toPrimitive, T value) { - Primitive primitive = toPrimitive.apply(value); - BindingData bindingData = BindingData.ofScalar(Scalar.ofPrimitive(primitive)); - LiteralType literalType = LiteralType.ofSimpleType(getSimpleType(primitive.kind())); + private static Literal create(SdkLiteralType type, T value) { + return new AutoValue_SdkBindingData_Literal<>(type, value); + } - return create(bindingData, literalType, value); - } + @Override + BindingData idl() { + return type().toBindingData(value()); + } - // TODO: ofCollection and ofMap receive a literal type for itself, it would be simpler if they - // receive the element literal type instead - /** - * Creates a {@code SdkBindingData} for a flyte collection given a java {@code List} and a - * function to know how to convert each element form such list to a {@code SdkBindingData}. - * - * @param collection collection to represent on this data. - * @param literalType literal type for the whole collection. It must be a {@link - * LiteralType.Kind#COLLECTION_TYPE}. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofCollection( - List collection, LiteralType literalType, Function> mapper) { - return SdkBindingData.ofBindingCollection( - literalType, collection.stream().map(mapper).collect(Collectors.toList())); - } + @Override + public T get() { + return value(); + } - private static SdkBindingData> createCollection( - List collection, LiteralType literalType, Function bindingDataFn) { - return create( - BindingData.ofCollection( - collection.stream().map(bindingDataFn).collect(Collectors.toList())), - literalType, - collection); - } + @Override + public SdkBindingData as( + SdkLiteralType newType, Function castFunction) { + return create(newType, castFunction.apply(value())); + } - /** - * Creates a {@code SdkBindingData} for a flyte collection of string given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofStringCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.STRING)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(value)))); + @Override + public final String toString() { + return String.format("SdkBindingData{type=%s, value=%s}", type(), value()); + } } - /** - * Creates a {@code SdkBindingData} for a flyte collection of float given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofFloatCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.FLOAT)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofFloatValue(value)))); - } + @AutoValue + abstract static class Promise extends SdkBindingData { + abstract String nodeId(); - /** - * Creates a {@code SdkBindingData} for a flyte collection of integer given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofIntegerCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(value)))); - } + abstract String var(); - /** - * Creates a {@code SdkBindingData} for a flyte collection of boolean given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofBooleanCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.BOOLEAN)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofBooleanValue(value)))); - } + private static Promise create(SdkLiteralType type, String nodeId, String var) { + return new AutoValue_SdkBindingData_Promise<>(type, nodeId, var); + } - /** - * Creates a {@code SdkBindingData} for a flyte collection of Duration given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofDurationCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.DURATION)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofDuration(value)))); - } + @Override + BindingData idl() { + return BindingData.ofOutputReference( + OutputReference.builder().nodeId(nodeId()).var(var()).build()); + } - /** - * Creates a {@code SdkBindingData} for a flyte collection of datetime given a java {@code - * List}. - * - * @param collection collection to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofDatetimeCollection(List collection) { - return createCollection( - collection, - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.DATETIME)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofDatetime(value)))); - } + @Override + public T get() { + throw new IllegalArgumentException( + String.format( + "Value only available at workflow execution time: promise of %s[%s]", + nodeId(), var())); + } - /** - * Creates a {@code SdkBindingData} for a flyte map given a java {@code Map} and a - * function to know how to convert each entry values form such map to a {@code SdkBindingData}. - * - * @param map map to represent on this data. - * @param literalType literal type for the whole collection. It must be a {@link - * LiteralType.Kind#MAP_VALUE_TYPE}. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofMap( - Map map, LiteralType literalType, Function> bindingFunction) { - return SdkBindingData.ofBindingMap( - literalType, - map.entrySet().stream() - .map(e -> Map.entry(e.getKey(), bindingFunction.apply(e.getValue()))) - .collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); - } + @Override + public SdkBindingData as( + SdkLiteralType newType, Function castFunction) { + return create(newType, nodeId(), var()); + } - private static SdkBindingData> createMap( - Map map, LiteralType literalType, Function bindingDataFn) { - return create( - BindingData.ofMap( - map.entrySet().stream() - .map(entry -> Map.entry(entry.getKey(), bindingDataFn.apply(entry.getValue()))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))), - literalType, - map); + @Override + public final String toString() { + return String.format("SdkBindingData{type=%s, nodeIs=%s, var=%s}", type(), nodeId(), var()); + } } - /** - * Creates a {@code SdkBindingData} for a flyte map of string given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofStringMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.STRING)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(value)))); - } + @AutoValue + abstract static class BindingCollection extends SdkBindingData> { + abstract List> bindingCollection(); - /** - * Creates a {@code SdkBindingData} for a flyte map of float given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofFloatMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.FLOAT)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofFloatValue(value)))); - } + private static BindingCollection create( + SdkLiteralType elementType, List> bindingCollection) { + checkIncompatibleTypes(elementType, bindingCollection); + return new AutoValue_SdkBindingData_BindingCollection<>( + collections(elementType), bindingCollection); + } - /** - * Creates a {@code SdkBindingData} for a flyte map of integer given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofIntegerMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(value)))); - } + @Override + BindingData idl() { + return BindingData.ofCollection( + bindingCollection().stream().map(SdkBindingData::idl).collect(toUnmodifiableList())); + } - /** - * Creates a {@code SdkBindingData} for a flyte map of boolean given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofBooleanMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.BOOLEAN)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofBooleanValue(value)))); - } + @Override + public List get() { + return bindingCollection().stream().map(SdkBindingData::get).collect(toUnmodifiableList()); + } - /** - * Creates a {@code SdkBindingData} for a flyte map of duration given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofDurationMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.DURATION)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofDuration(value)))); - } + @Override + public SdkBindingData as( + SdkLiteralType newElementType, Function, NewT> castFunction) { + throw new UnsupportedOperationException( + "SdkBindingData of binding collections cannot be casted"); + } - /** - * Creates a {@code SdkBindingData} for a flyte map of datetime given a java {@code Map}. - * - * @param map map to represent on this data. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofDatetimeMap(Map map) { - return createMap( - map, - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.DATETIME)), - (value) -> BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofDatetime(value)))); + @Override + public final String toString() { + return String.format("SdkBindingData{type=%s, collection=%s}", type(), bindingCollection()); + } } - // TODO: ordering of parameters is inconsistent with other methods here - /** - * Creates a {@code SdkBindingData} for a flyte collection given a java {@code - * List>} and a literalType tp be used. - * - * @param elements collection to represent on this data. - * @param literalType literal type for the whole collection. It must be a {@link - * LiteralType.Kind#COLLECTION_TYPE}. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofBindingCollection( - LiteralType literalType, List> elements) { - List bindings = elements.stream().map(SdkBindingData::idl).collect(toList()); - BindingData bindingData = BindingData.ofCollection(bindings); + @AutoValue + public abstract static class BindingMap extends SdkBindingData> { + abstract Map> bindingMap(); + + private static BindingMap create( + SdkLiteralType valuesType, Map> bindingMap) { + checkIncompatibleTypes(valuesType, bindingMap.values()); + return new AutoValue_SdkBindingData_BindingMap<>(maps(valuesType), bindingMap); + } - checkIncompatibleTypes(literalType.collectionType(), elements); - boolean hasPromise = bindings.stream().anyMatch(SdkBindingData::isAPromise); - List unwrappedElements = - hasPromise ? null : elements.stream().map(SdkBindingData::get).collect(toList()); + @Override + BindingData idl() { + return BindingData.ofMap( + bindingMap().entrySet().stream() + .collect(toUnmodifiableMap(Map.Entry::getKey, e -> e.getValue().idl()))); + } - return SdkBindingData.create(bindingData, literalType, unwrappedElements); + @Override + public Map get() { + return bindingMap().entrySet().stream() + .collect(toUnmodifiableMap(Map.Entry::getKey, e -> e.getValue().get())); + } + + @Override + public SdkBindingData as( + SdkLiteralType newType, Function, NewT> castFunction) { + throw new UnsupportedOperationException("SdkBindingData of binding map cannot be casted"); + } + + @Override + public final String toString() { + return String.format("SdkBindingData{type=%s, map=%s}", type(), bindingMap()); + } } private static void checkIncompatibleTypes( - LiteralType literalType, List> elements) { + SdkLiteralType elementType, Collection> elements) { List incompatibleTypes = elements.stream() .map(SdkBindingData::type) + .filter(type -> !type.equals(elementType)) + .map(SdkLiteralType::getLiteralType) .distinct() - .filter(type -> !type.equals(literalType)) .collect(toList()); if (!incompatibleTypes.isEmpty()) { throw new IllegalArgumentException( String.format( "Type mismatch: expected all elements of type %s but found some elements of type: %s", - literalType, incompatibleTypes)); - } - } - - private static boolean isAPromise(BindingData bindingData) { - switch (bindingData.kind()) { - case SCALAR: - return false; - case PROMISE: - return true; - case COLLECTION: - return bindingData.collection().stream().anyMatch(SdkBindingData::isAPromise); - case MAP: - return bindingData.map().values().stream().anyMatch(SdkBindingData::isAPromise); - } - throw new IllegalArgumentException("BindingData.Kind not recognized: " + bindingData.kind()); - } - - /** - * Creates a {@code SdkBindingData} for a flyte map given a java {@code Map>} and a literalType tp be used. - * - * @param valueMap collection to represent on this data. - * @param literalType literal type for the whole map. It must be a {@link - * LiteralType.Kind#MAP_VALUE_TYPE}. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData> ofBindingMap( - LiteralType literalType, Map> valueMap) { - - Map bindings = - valueMap.entrySet().stream() - .map(e -> Map.entry(e.getKey(), e.getValue().idl())) - .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); - BindingData bindingData = BindingData.ofMap(bindings); - - boolean hasPromise = bindings.values().stream().anyMatch(SdkBindingData::isAPromise); - Map unwrappedElements = - hasPromise - ? null - : valueMap.entrySet().stream() - .map(e -> Map.entry(e.getKey(), e.getValue().get())) - .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); - - return SdkBindingData.create(bindingData, literalType, unwrappedElements); - } - - /** - * Creates a {@code SdkBindingData} for a flyte output reference. - * - * @param nodeId references to what node id this reference points to. - * @param nodeVar name of the output variable that this reference points to. - * @param type literal type of the referenced variable. - * @return the new {@code SdkBindingData} - */ - public static SdkBindingData ofOutputReference( - String nodeId, String nodeVar, LiteralType type) { - BindingData idl = - BindingData.ofOutputReference( - OutputReference.builder().nodeId(nodeId).var(nodeVar).build()); - // promises don't contain values yet - return create(idl, type, null); - } - - /** - * Returns the simple value contained by this data. - * - * @return the value that this simple data holds - * @throws IllegalArgumentException when this data is an output reference - */ - public T get() { - if (idl().kind() == BindingData.Kind.PROMISE) { - OutputReference promise = idl().promise(); - throw new IllegalArgumentException( - String.format( - "Value only available at workflow execution time: promise of %s[%s]", - promise.nodeId(), promise.var())); - } - - return value(); - } - - private static SimpleType getSimpleType(Primitive.Kind kind) { - switch (kind) { - case INTEGER_VALUE: - return SimpleType.INTEGER; - case FLOAT_VALUE: - return SimpleType.FLOAT; - case STRING_VALUE: - return SimpleType.STRING; - case BOOLEAN_VALUE: - return SimpleType.BOOLEAN; - case DATETIME: - return SimpleType.DATETIME; - case DURATION: - return SimpleType.DURATION; + elementType.getLiteralType(), incompatibleTypes)); } - - throw new AssertionError("Unexpected Primitive.Kind: " + kind); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java new file mode 100644 index 000000000..4a0ce54d2 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -0,0 +1,299 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static org.flyte.flytekit.SdkLiteralTypes.collections; +import static org.flyte.flytekit.SdkLiteralTypes.maps; +import static org.flyte.flytekit.SdkLiteralTypes.strings; + +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneOffset; +import java.util.List; +import java.util.Map; + +/** A utility class for creating {@link SdkBindingData} objects for different types. */ +public final class SdkBindingDataFactory { + + private SdkBindingDataFactory() { + // prevent instantiation + } + + /** + * Creates a {@code SdkBindingData} for a flyte integer ({@link Long} for java) with the given + * value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(long value) { + return SdkBindingData.literal(SdkLiteralTypes.integers(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte float ({@link Double} for java) with the given + * value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(double value) { + return SdkBindingData.literal(SdkLiteralTypes.floats(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte String with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(String value) { + return SdkBindingData.literal(strings(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte boolean with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(boolean value) { + return SdkBindingData.literal(SdkLiteralTypes.booleans(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte Datetime ({@link Instant} for java) with the given + * date at 00:00 on UTC. + * + * @param year the year to represent, from {@code Year.MIN_VALUE} to {@code Year.MAX_VALUE} + * @param month the month-of-year to represent, from 1 (January) to 12 (December) + * @param day the day-of-month to represent, from 1 to 31 + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(int year, int month, int day) { + Instant instant = LocalDate.of(year, month, day).atStartOfDay().toInstant(ZoneOffset.UTC); + return of(instant); + } + + /** + * Creates a {@code SdkBindingData} for a flyte Datetime ({@link Instant} for java) with the given + * value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(Instant value) { + return SdkBindingData.literal(SdkLiteralTypes.datetimes(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte Duration for java with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(Duration value) { + return SdkBindingData.literal(SdkLiteralTypes.durations(), value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection given a java {@code List} and the + * elements type. + * + * @param elementType a {@link SdkLiteralType} for the collection elements type. + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> of(SdkLiteralType elementType, List collection) { + return SdkBindingData.literal(collections(elementType), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of string given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofStringCollection(List collection) { + return of(strings(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of float given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofFloatCollection(List collection) { + return of(SdkLiteralTypes.floats(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of integer given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofIntegerCollection(List collection) { + return of(SdkLiteralTypes.integers(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of boolean given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofBooleanCollection(List collection) { + return of(SdkLiteralTypes.booleans(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of Duration given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofDurationCollection(List collection) { + return of(SdkLiteralTypes.durations(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection of datetime given a java {@code + * List}. + * + * @param collection collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofDatetimeCollection(List collection) { + return of(SdkLiteralTypes.datetimes(), collection); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map given a java {@code Map} and a + * function to know how to convert each entry values form such map to a {@code SdkBindingData}. + * + * @param valuesType literal type for the values of the map, keys are always strings. + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> of( + SdkLiteralType valuesType, Map map) { + return SdkBindingData.literal(maps(valuesType), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of string given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofStringMap(Map map) { + return of(strings(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of float given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofFloatMap(Map map) { + return of(SdkLiteralTypes.floats(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of integer given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofIntegerMap(Map map) { + return of(SdkLiteralTypes.integers(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of boolean given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofBooleanMap(Map map) { + return of(SdkLiteralTypes.booleans(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of duration given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofDurationMap(Map map) { + return of(SdkLiteralTypes.durations(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map of datetime given a java {@code Map}. + * + * @param map map to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofDatetimeMap(Map map) { + return of(SdkLiteralTypes.datetimes(), map); + } + + /** + * Creates a {@code SdkBindingData} for a flyte collection given a java {@code + * List>} and {@link SdkLiteralType} for types for the elements. + * + * @param elementType a {@link SdkLiteralType} expressing the types for the elements in the + * collection. + * @param elements collection to represent on this data. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofBindingCollection( + SdkLiteralType elementType, List> elements) { + return SdkBindingData.bindingCollection(elementType, elements); + } + + /** + * Creates a {@code SdkBindingData} for a flyte map given a java {@code Map>} and a {@link SdkLiteralType} for the values of the map. + * + * @param valueMap collection to represent on this data. + * @param valuesType a {@link SdkLiteralType} expressing the types for the values of the map. The + * keys are always String. LiteralType.Kind#MAP_VALUE_TYPE}. + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData> ofBindingMap( + SdkLiteralType valuesType, Map> valueMap) { + + return SdkBindingData.bindingMap(valuesType, valueMap); + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java index 28fc6acc3..1695a5c25 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java @@ -30,7 +30,6 @@ import org.flyte.api.v1.Binding; import org.flyte.api.v1.BranchNode; import org.flyte.api.v1.IfElseBlock; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Node; import org.flyte.api.v1.NodeError; @@ -38,7 +37,7 @@ public class SdkBranchNode extends SdkNode { private final String nodeId; private final SdkIfElseBlock ifElse; - private final Map outputTypes; + private final Map> outputTypes; private final List upstreamNodeIds; private final OutputT outputs; @@ -48,7 +47,7 @@ private SdkBranchNode( String nodeId, List upstreamNodeIds, SdkIfElseBlock ifElse, - Map outputTypes, + Map> outputTypes, OutputT outputs) { super(builder); @@ -72,8 +71,8 @@ public OutputT getOutputs() { return outputs; } - private SdkBindingData createOutput(Map.Entry entry) { - return SdkBindingData.ofOutputReference(nodeId, entry.getKey(), entry.getValue()); + private SdkBindingData createOutput(Map.Entry> entry) { + return SdkBindingData.promise(entry.getValue(), nodeId, entry.getKey()); } /** {@inheritDoc} */ @@ -121,7 +120,7 @@ static class Builder { private final List ifBlocks = new ArrayList<>(); private SdkNode elseNode; - private Map outputTypes; + private Map> outputTypes; Builder(SdkWorkflowBuilder builder, SdkType outputType) { this.builder = builder; @@ -134,7 +133,7 @@ Builder addCase(SdkConditionCase case_) { case_.then().apply(builder, case_.name(), List.of(), /*metadata=*/ null, Map.of()); Map> thatOutputs = sdkNode.getOutputBindings(); - Map thatOutputTypes = + Map> thatOutputTypes = thatOutputs.entrySet().stream() .collect(toUnmodifiableMap(Map.Entry::getKey, x -> x.getValue().type())); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java index 8fd3b99ae..168b5053e 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java @@ -153,7 +153,7 @@ public static SdkBooleanExpression lte(SdkBindingData left, SdkBindingDat public static SdkBooleanExpression isTrue(SdkBindingData data) { return ofComparison( SdkComparisonExpression.create( - data, SdkBindingData.ofBoolean(true), ComparisonExpression.Operator.EQ)); + data, SdkBindingDataFactory.of(true), ComparisonExpression.Operator.EQ)); } /** @@ -165,6 +165,6 @@ public static SdkBooleanExpression isTrue(SdkBindingData data) { public static SdkBooleanExpression isFalse(SdkBindingData data) { return ofComparison( SdkComparisonExpression.create( - data, SdkBindingData.ofBoolean(false), ComparisonExpression.Operator.EQ)); + data, SdkBindingDataFactory.of(false), ComparisonExpression.Operator.EQ)); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java index 22c01d7f8..056fd96a9 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java @@ -109,14 +109,7 @@ public SdkNode apply( } return new SdkTaskNode<>( - builder, - nodeId, - taskId, - upstreamNodeIds, - metadata, - inputs, - outputType.getVariableMap(), - outputType.promiseFor(nodeId)); + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); } /** Specifies container image. */ diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java index 4c162adf8..59f8c05c4 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java @@ -79,14 +79,7 @@ public SdkNode apply( } return new SdkTaskNode<>( - builder, - nodeId, - taskId, - upstreamNodeIds, - metadata, - inputs, - outputType.getVariableMap(), - outputType.promiseFor(nodeId)); + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); } public abstract OutputT run(SdkWorkflowBuilder builder, InputT input); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java index beb3cc801..d98b2127b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java @@ -96,7 +96,8 @@ public static SdkLaunchPlan of(SdkWorkflow workflow) { return builder() .name(workflow.getName()) .workflowName(workflow.getName()) - .workflowInputTypeMap(toWorkflowInputTypeMap(wfBuilder.getInputs(), SdkBindingData::type)) + .workflowInputTypeMap( + toWorkflowInputTypeMap(wfBuilder.getInputs(), in -> in.type().getLiteralType())) .build(); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralType.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralType.java index 8eae02ddc..4856736a5 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralType.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralType.java @@ -16,15 +16,70 @@ */ package org.flyte.flytekit; +import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; -// TODO: this class it is not used. We should remove it or even better use it in place of -// raw literal types in SdkBinding data -abstract class SdkLiteralType { +/** + * Bridge between the a Java type and a variable in Flyte. + * + * @param the Java native type to bridge. + */ +public abstract class SdkLiteralType { + /** + * Returns the {@link LiteralType} corresponding to this type. + * + * @return the literal type. + */ public abstract LiteralType getLiteralType(); + /** + * Coverts the value into a {@link Literal}. + * + * @param value value to convert. + * @return the literal. + */ public abstract Literal toLiteral(T value); + /** + * Coverts a {@link Literal} into a value. + * + * @param literal literal to convert. + * @return the value. + */ public abstract T fromLiteral(Literal literal); + + /** + * Coverts the value into a {@link BindingData}. + * + * @param value value to convert. + * @return the binding data. + */ + public abstract BindingData toBindingData(T value); + + /** + * {@inheritDoc} + * + *

Hashcode is computed based on {@link #getLiteralType()} + */ + @Override + public final int hashCode() { + return getLiteralType().hashCode(); + } + + /** + * {@inheritDoc} + * + *

Equals comparing only {@link #getLiteralType()}. Simplifies equality among the several + * implementation of this class. + */ + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (obj instanceof SdkLiteralType) { + return this.getLiteralType().equals(((SdkLiteralType) obj).getLiteralType()); + } + return false; + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java index ea996205e..d52131c18 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java @@ -16,56 +16,172 @@ */ package org.flyte.flytekit; -import static java.util.Collections.unmodifiableMap; import static java.util.stream.Collectors.toUnmodifiableList; +import static java.util.stream.Collectors.toUnmodifiableMap; import java.time.Duration; import java.time.Instant; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; -// TODO: this class it is not used. We should remove it or even better use it in place of -// raw literal types in SdkBinding data -class SdkLiteralTypes { +/** A utility class for creating {@link SdkLiteralType} objects for different types. */ +public class SdkLiteralTypes { + private SdkLiteralTypes() { + // prevent instantiation + } + + /** + * Returns a {@link SdkLiteralType} for the specified Java type. + * + *

+ *
{@code Long.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #integers()} + *
{@code Double.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #floats()} + *
{@code String.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #strings()} + *
{@code Boolean.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #booleans()} + *
{@code Instant.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #datetimes()} + *
{@code Duration.class} {@code ->} + *
{@code SdkLiteralType}, equivalent to {@link #durations()} + *
+ * + * @param clazz Java type used to decide what {@link SdkLiteralType} to return. + * @return the {@link SdkLiteralType} based on the java type + * @param type of the returned {@link SdkLiteralType}, matching the one specified. + */ + @SuppressWarnings("unchecked") + public static SdkLiteralType of(Class clazz) { + if (clazz.equals(Long.class)) { + return (SdkLiteralType) integers(); + } else if (clazz.equals(Double.class)) { + return (SdkLiteralType) floats(); + } else if (clazz.equals(String.class)) { + return (SdkLiteralType) strings(); + } else if (clazz.equals(Boolean.class)) { + return (SdkLiteralType) booleans(); + } else if (clazz.equals(Instant.class)) { + return (SdkLiteralType) datetimes(); + } else if (clazz.equals(Duration.class)) { + return (SdkLiteralType) durations(); + } + throw new IllegalArgumentException("Unsupported type: " + clazz); + } + + /** + * Returns a {@link SdkLiteralType} for a collection of the specified Java type. Equivalent to + * {code SdkLiteralTypes.collections(SdkLiteralTypes.of(elementsClass} + * + * @param elementsClass Java type used to decide what {@link SdkLiteralType} to return. + * @return the {@link SdkLiteralType} based on the java type + * @param type of the elements of the collections for the returned {@link SdkLiteralType}. + * @see SdkLiteralTypes#of + */ + public static SdkLiteralType> collections(Class elementsClass) { + return collections(of(elementsClass)); + } + + /** + * Returns a {@link SdkLiteralType} for a map of the specified Java type. Equivalent to {code + * SdkLiteralTypes.maps(SdkLiteralTypes.of(valuesType} + * + * @param valuesType Java type used to decide what {@link SdkLiteralType} to return. + * @return the {@link SdkLiteralType} based on the java type + * @param type of the values of the map for the returned {@link SdkLiteralType}. Key types are + * always {@code String}. + * @see SdkLiteralTypes#of + */ + public static SdkLiteralType> maps(Class valuesType) { + return maps(of(valuesType)); + } + + /** + * Returns a {@link SdkLiteralType} for flyte integers. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType integers() { return IntegerSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for flyte floats. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType floats() { return FloatSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for strings. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType strings() { return StringSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for booleans. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType booleans() { return BooleanSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for flyte date times. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType datetimes() { return DatetimeSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for durations. + * + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType durations() { return DurationSdkLiteralType.INSTANCE; } + /** + * Returns a {@link SdkLiteralType} for flyte collections. + * + * @param elementType the {@link SdkLiteralType} representing the types of the elements of the + * collection. + * @param the Java type of the elements of the collection. + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType> collections(SdkLiteralType elementType) { return new CollectionSdkLiteralType<>(elementType); } + /** + * Returns a {@link SdkLiteralType} for flyte maps. + * + * @param mapValueType the {@link SdkLiteralType} representing the types of the map's values. + * @param the Java type of the map's values, keys are always string. + * @return the {@link SdkLiteralType} + */ public static SdkLiteralType> maps(SdkLiteralType mapValueType) { return new MapSdkLiteralType<>(mapValueType); } - private static class IntegerSdkLiteralType extends SdkLiteralType { + private static class IntegerSdkLiteralType extends PrimitiveSdkLiteralType { private static final IntegerSdkLiteralType INSTANCE = new IntegerSdkLiteralType(); @Override @@ -74,17 +190,22 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(Long value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(value))); + public Primitive toPrimitive(Long value) { + return Primitive.ofIntegerValue(value); + } + + @Override + public Long fromPrimitive(Primitive primitive) { + return primitive.integerValue(); } @Override - public Long fromLiteral(Literal literal) { - return literal.scalar().primitive().integerValue(); + public String toString() { + return "integers"; } } - private static class FloatSdkLiteralType extends SdkLiteralType { + private static class FloatSdkLiteralType extends PrimitiveSdkLiteralType { private static final FloatSdkLiteralType INSTANCE = new FloatSdkLiteralType(); @Override @@ -93,17 +214,22 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(Double value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofFloatValue(value))); + public Primitive toPrimitive(Double value) { + return Primitive.ofFloatValue(value); } @Override - public Double fromLiteral(Literal literal) { - return literal.scalar().primitive().floatValue(); + public Double fromPrimitive(Primitive primitive) { + return primitive.floatValue(); + } + + @Override + public String toString() { + return "floats"; } } - private static class StringSdkLiteralType extends SdkLiteralType { + private static class StringSdkLiteralType extends PrimitiveSdkLiteralType { private static final StringSdkLiteralType INSTANCE = new StringSdkLiteralType(); @Override @@ -112,17 +238,22 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(String value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(value))); + public Primitive toPrimitive(String value) { + return Primitive.ofStringValue(value); + } + + @Override + public String fromPrimitive(Primitive primitive) { + return primitive.stringValue(); } @Override - public String fromLiteral(Literal literal) { - return literal.scalar().primitive().stringValue(); + public String toString() { + return "strings"; } } - private static class BooleanSdkLiteralType extends SdkLiteralType { + private static class BooleanSdkLiteralType extends PrimitiveSdkLiteralType { private static final BooleanSdkLiteralType INSTANCE = new BooleanSdkLiteralType(); @Override @@ -131,17 +262,22 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(Boolean value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofBooleanValue(value))); + public Primitive toPrimitive(Boolean value) { + return Primitive.ofBooleanValue(value); + } + + @Override + public Boolean fromPrimitive(Primitive primitive) { + return primitive.booleanValue(); } @Override - public Boolean fromLiteral(Literal literal) { - return literal.scalar().primitive().booleanValue(); + public String toString() { + return "booleans"; } } - private static class DatetimeSdkLiteralType extends SdkLiteralType { + private static class DatetimeSdkLiteralType extends PrimitiveSdkLiteralType { private static final DatetimeSdkLiteralType INSTANCE = new DatetimeSdkLiteralType(); @Override @@ -150,17 +286,22 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(Instant value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofDatetime(value))); + public Primitive toPrimitive(Instant value) { + return Primitive.ofDatetime(value); + } + + @Override + public Instant fromPrimitive(Primitive primitive) { + return primitive.datetime(); } @Override - public Instant fromLiteral(Literal literal) { - return literal.scalar().primitive().datetime(); + public String toString() { + return "datetimes"; } } - private static class DurationSdkLiteralType extends SdkLiteralType { + private static class DurationSdkLiteralType extends PrimitiveSdkLiteralType { private static final DurationSdkLiteralType INSTANCE = new DurationSdkLiteralType(); @Override @@ -169,13 +310,18 @@ public LiteralType getLiteralType() { } @Override - public Literal toLiteral(Duration value) { - return Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofDuration(value))); + public Primitive toPrimitive(Duration value) { + return Primitive.ofDuration(value); } @Override - public Duration fromLiteral(Literal literal) { - return literal.scalar().primitive().duration(); + public Duration fromPrimitive(Primitive primitive) { + return primitive.duration(); + } + + @Override + public String toString() { + return "durations"; } } @@ -205,40 +351,79 @@ public List fromLiteral(Literal literal) { .map(elementType::fromLiteral) .collect(toUnmodifiableList()); } + + @Override + public BindingData toBindingData(List value) { + return BindingData.ofCollection( + value.stream().map(elementType::toBindingData).collect(toUnmodifiableList())); + } + + @Override + public String toString() { + return "collections of [" + elementType + ']'; + } } private static class MapSdkLiteralType extends SdkLiteralType> { - private final SdkLiteralType mapKeyType; + private final SdkLiteralType valuesType; - private MapSdkLiteralType(SdkLiteralType mapKeyType) { - this.mapKeyType = mapKeyType; + private MapSdkLiteralType(SdkLiteralType valuesType) { + this.valuesType = valuesType; } @Override public LiteralType getLiteralType() { - return LiteralType.ofMapValueType(mapKeyType.getLiteralType()); + return LiteralType.ofMapValueType(valuesType.getLiteralType()); + } + + @Override + public Literal toLiteral(java.util.Map value) { + var map = + value.entrySet().stream() + .collect(toUnmodifiableMap(Entry::getKey, e -> valuesType.toLiteral(e.getValue()))); + + return Literal.ofMap(map); } @Override - public Literal toLiteral(Map value) { - Map map = new LinkedHashMap<>(); + public java.util.Map fromLiteral(Literal literal) { + return literal.map().entrySet().stream() + .collect(toUnmodifiableMap(Entry::getKey, e -> valuesType.fromLiteral(e.getValue()))); + } + + @Override + public BindingData toBindingData(java.util.Map value) { + return BindingData.ofMap( + value.entrySet().stream() + .collect( + toUnmodifiableMap(Entry::getKey, e -> valuesType.toBindingData(e.getValue())))); + } + + @Override + public String toString() { + return "map of [" + valuesType + ']'; + } + } - for (Map.Entry entry : value.entrySet()) { - map.put(entry.getKey(), mapKeyType.toLiteral(entry.getValue())); - } + private abstract static class PrimitiveSdkLiteralType extends SdkLiteralType { - return Literal.ofMap(unmodifiableMap(map)); + @Override + public final Literal toLiteral(T value) { + return Literal.ofScalar(Scalar.ofPrimitive(toPrimitive(value))); } + public abstract Primitive toPrimitive(T value); + @Override - public Map fromLiteral(Literal literal) { - Map map = new LinkedHashMap<>(); + public final T fromLiteral(Literal literal) { + return fromPrimitive(literal.scalar().primitive()); + } - for (Map.Entry entry : literal.map().entrySet()) { - map.put(entry.getKey(), mapKeyType.fromLiteral(entry.getValue())); - } + public abstract T fromPrimitive(Primitive primitive); - return unmodifiableMap(map); + @Override + public final BindingData toBindingData(T value) { + return BindingData.ofScalar(Scalar.ofPrimitive(toPrimitive(value))); } } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java index eed9861f7..82b1d542f 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java @@ -16,8 +16,6 @@ */ package org.flyte.flytekit; -import static java.util.stream.Collectors.toUnmodifiableMap; - import com.google.auto.value.AutoValue; import java.util.List; import java.util.Map; @@ -120,16 +118,9 @@ public SdkNode apply( throw new CompilerException(errors); } - Map> outputs = - outputs().getVariableMap().entrySet().stream() - .collect( - toUnmodifiableMap( - Map.Entry::getKey, - entry -> - SdkBindingData.ofOutputReference( - nodeId, entry.getKey(), entry.getValue().literalType()))); - + Map> outputs = outputs().promiseMapFor(nodeId); OutputT promise = getOutputType().promiseFor(nodeId); + return new SdkWorkflowNode<>( builder, nodeId, diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java index 22a8d6a38..7df8be867 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java @@ -21,7 +21,6 @@ import java.util.Map; import javax.annotation.Nullable; import org.flyte.api.v1.PartialTaskIdentifier; -import org.flyte.api.v1.Variable; /** Reference to a task deployed in flyte, a remote Task. */ @AutoValue @@ -115,10 +114,8 @@ public SdkNode apply( throw new CompilerException(errors); } - Map variableMap = outputs().getVariableMap(); - OutputT output = outputs().promiseFor(nodeId); return new SdkTaskNode<>( - builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, variableMap, output); + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, getOutputType()); } static Builder builder() { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java index 52893a822..e363b9472 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java @@ -21,7 +21,6 @@ import java.util.Map; import javax.annotation.Nullable; import org.flyte.api.v1.PartialTaskIdentifier; -import org.flyte.api.v1.Variable; /** Building block for tasks that execute Java code. */ public abstract class SdkRunnableTask extends SdkTransform @@ -127,10 +126,8 @@ public SdkNode apply( throw new CompilerException(errors); } - Map variableMap = outputType.getVariableMap(); - OutputT output = outputType.promiseFor(nodeId); return new SdkTaskNode<>( - builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, variableMap, output); + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); } public abstract OutputT run(InputT input); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java index a1d25de19..58352b5be 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java @@ -17,7 +17,6 @@ package org.flyte.flytekit; import static java.util.stream.Collectors.toUnmodifiableList; -import static java.util.stream.Collectors.toUnmodifiableMap; import java.util.List; import java.util.Map; @@ -26,7 +25,6 @@ import org.flyte.api.v1.Node; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.TaskNode; -import org.flyte.api.v1.Variable; /** Represent a {@link org.flyte.flytekit.SdkRunnableTask} in a workflow DAG. */ public class SdkTaskNode extends SdkNode { @@ -35,8 +33,7 @@ public class SdkTaskNode extends SdkNode { private final List upstreamNodeIds; @Nullable private final SdkNodeMetadata metadata; private final Map> inputs; - private final Map outputVars; - private final T outputs; + private final SdkType outputsType; SdkTaskNode( SdkWorkflowBuilder builder, @@ -45,8 +42,7 @@ public class SdkTaskNode extends SdkNode { List upstreamNodeIds, @Nullable SdkNodeMetadata metadata, Map> inputs, - Map outputVars, - T outputs) { + SdkType outputsType) { super(builder); this.nodeId = nodeId; @@ -54,26 +50,19 @@ public class SdkTaskNode extends SdkNode { this.upstreamNodeIds = upstreamNodeIds; this.metadata = metadata; this.inputs = inputs; - this.outputVars = outputVars; - this.outputs = outputs; + this.outputsType = outputsType; } /** {@inheritDoc} */ @Override public Map> getOutputBindings() { - return outputVars.entrySet().stream() - .collect( - toUnmodifiableMap( - Map.Entry::getKey, - entry -> - SdkBindingData.ofOutputReference( - nodeId, entry.getKey(), entry.getValue().literalType()))); + return outputsType.promiseMapFor(nodeId); } /** {@inheritDoc} */ @Override public T getOutputs() { - return outputs; + return outputsType.promiseFor(nodeId); } /** {@inheritDoc} */ diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java index 53d458b3e..61634e879 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java @@ -16,10 +16,11 @@ */ package org.flyte.flytekit; +import static java.util.stream.Collectors.toUnmodifiableMap; + import java.util.Map; import java.util.Set; import org.flyte.api.v1.Literal; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; /** @@ -47,14 +48,21 @@ public abstract class SdkType { public abstract T fromLiteralMap(Map value); /** - * Returns a value composed of {@link SdkBindingData#ofOutputReference(String, String, - * LiteralType)} for the supplied node is. + * Returns a value composed of {@link SdkBindingData#promise(SdkLiteralType, String, String)} for + * the supplied node is. * * @param nodeId the node id that the value is a promise for. * @return the value. */ public abstract T promiseFor(String nodeId); + public final Map> promiseMapFor(String nodeId) { + return toLiteralTypes().entrySet().stream() + .collect( + toUnmodifiableMap( + Map.Entry::getKey, e -> SdkBindingData.promise(e.getValue(), nodeId, e.getKey()))); + } + /** * Returns a variable map for the properties for {@link T}. * @@ -62,6 +70,13 @@ public abstract class SdkType { */ public abstract Map getVariableMap(); + /** + * Returns the {@link SdkLiteralType} map bay variable name corresponding to this type. + * + * @return the literal type. + */ + public abstract Map> toLiteralTypes(); + /** * Returns the names for the properties for {@link T}. * diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java index 5a5e66bf6..e5971cf67 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java @@ -58,6 +58,11 @@ public Map getVariableMap() { return Map.of(); } + @Override + public Map> toLiteralTypes() { + return Map.of(); + } + @Override public Map> toSdkBindingMap(Void value) { return Map.of(); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java index 0f66d65a9..51aac07fd 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java @@ -64,13 +64,14 @@ protected SdkWorkflow(SdkType inputType, SdkType outputType) { * @param builder workflow builder that this workflow expands into. \ */ public final void expand(SdkWorkflowBuilder builder) { + var literalTypes = inputType.toLiteralTypes(); inputType .getVariableMap() .forEach( (name, variable) -> - builder.inputOf( + builder.setInput( + literalTypes.get(name), name, - variable.literalType(), variable.description() == null ? "" : variable.description())); OutputT output = expand(builder, inputType.promiseFor(START_NODE_ID)); @@ -119,8 +120,7 @@ public SdkNode apply( .collect( Collectors.toMap( Map.Entry::getKey, - e -> - SdkBindingData.ofOutputReference(nodeId, e.getKey(), e.getValue().type()))); + e -> SdkBindingData.promise(e.getValue().type(), nodeId, e.getKey()))); var promise = getOutputType().promiseFor(nodeId); return new SdkWorkflowNode<>( diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java index 26e5709e2..1aeb929cc 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java @@ -26,7 +26,6 @@ import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.WorkflowTemplate; /** Builder used during {@link SdkWorkflow#expand(SdkWorkflowBuilder)}. */ @@ -134,14 +133,11 @@ protected SdkNode applyInternal( return sdkNode; } - SdkBindingData inputOf(String name, LiteralType literalType, String help) { - SdkBindingData bindingData = - SdkBindingData.ofOutputReference(START_NODE_ID, name, literalType); + void setInput(SdkLiteralType type, String name, String help) { + SdkBindingData bindingData = SdkBindingData.promise(type, START_NODE_ID, name); inputDescriptions.put(name, help); inputs.put(name, bindingData); - - return bindingData; } /** Returns the nodes by id. */ diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/WorkflowTemplateIdl.java b/flytekit-java/src/main/java/org/flyte/flytekit/WorkflowTemplateIdl.java index 22436add3..550792cdf 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/WorkflowTemplateIdl.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/WorkflowTemplateIdl.java @@ -72,7 +72,7 @@ private static Map toVariableMap( entry -> { Variable variable = Variable.builder() - .literalType(entry.getValue().type()) + .literalType(entry.getValue().type().getLiteralType()) .description(nameToDescription.apply(entry.getKey())) .build(); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataFactoryTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataFactoryTest.java new file mode 100644 index 000000000..207cb3f9f --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataFactoryTest.java @@ -0,0 +1,228 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static java.time.ZoneOffset.UTC; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static org.flyte.flytekit.SdkLiteralTypes.booleans; +import static org.flyte.flytekit.SdkLiteralTypes.collections; +import static org.flyte.flytekit.SdkLiteralTypes.datetimes; +import static org.flyte.flytekit.SdkLiteralTypes.durations; +import static org.flyte.flytekit.SdkLiteralTypes.floats; +import static org.flyte.flytekit.SdkLiteralTypes.integers; +import static org.flyte.flytekit.SdkLiteralTypes.maps; +import static org.flyte.flytekit.SdkLiteralTypes.strings; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +public class SdkBindingDataFactoryTest { + + @Test + public void testOfBindingCollection() { + List collection = List.of(42L, 1337L); + + SdkBindingData> output = SdkBindingDataFactory.ofIntegerCollection(collection); + + assertThat(output.get(), equalTo(collection)); + assertThat(output.type(), equalTo(collections(integers()))); + } + + @Test + public void testOfBindingCollection_empty() { + SdkBindingData> output = + SdkBindingDataFactory.ofBindingCollection(integers(), emptyList()); + + assertThat(output.get(), equalTo(List.of())); + assertThat(output.type(), equalTo(collections(integers()))); + } + + @Test + public void testOfStringCollection() { + List expectedValue = List.of("1", "2"); + + SdkBindingData> output = SdkBindingDataFactory.ofStringCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(strings()))); + } + + @Test + public void testOfFloatCollection() { + List expectedValue = List.of(1.1, 1.2); + + SdkBindingData> output = SdkBindingDataFactory.ofFloatCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(floats()))); + } + + @Test + public void testOfIntegerCollection() { + List expectedValue = List.of(1L, 2L); + + SdkBindingData> output = SdkBindingDataFactory.ofIntegerCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(integers()))); + } + + @Test + public void testOfBooleanCollection() { + List expectedValue = List.of(true, false); + + SdkBindingData> output = SdkBindingDataFactory.ofBooleanCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(booleans()))); + } + + @Test + public void testOfDurationCollection() { + List expectedValue = List.of(Duration.ofDays(1), Duration.ofDays(2)); + + SdkBindingData> output = + SdkBindingDataFactory.ofDurationCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(durations()))); + } + + @Test + public void testOfDatetimeCollection() { + Instant first = LocalDate.of(2022, 1, 16).atStartOfDay().toInstant(UTC); + Instant second = LocalDate.of(2022, 1, 17).atStartOfDay().toInstant(UTC); + + List expectedValue = List.of(first, second); + + SdkBindingData> output = + SdkBindingDataFactory.ofDatetimeCollection(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(collections(datetimes()))); + } + + @Test + public void testOfBindingMap() { + Map> input = + Map.of( + "a", SdkBindingDataFactory.of(42L), + "b", SdkBindingDataFactory.of(1337L)); + + SdkBindingData> output = + SdkBindingDataFactory.ofBindingMap(integers(), input); + + assertThat(output.get(), equalTo(Map.of("a", 42L, "b", 1337L))); + assertThat(output.type(), equalTo(maps(integers()))); + } + + @Test + public void testOfBindingMap_empty() { + SdkBindingData> output = + SdkBindingDataFactory.ofBindingMap(integers(), emptyMap()); + + assertThat(output.get(), equalTo(Map.of())); + assertThat(output.type(), equalTo(maps(integers()))); + } + + @Test + public void testOfStringMap() { + Map expectedValue = + Map.of( + "a", "1", + "b", "2"); + + SdkBindingData> output = SdkBindingDataFactory.ofStringMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(strings()))); + } + + @Test + public void testOfFloatMap() { + Map expectedValue = + Map.of( + "a", 1.1, + "b", 1.2); + + SdkBindingData> output = SdkBindingDataFactory.ofFloatMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(floats()))); + } + + @Test + public void testOfIntegerMap() { + Map expectedValue = Map.of("a", 1L, "b", 2L); + + SdkBindingData> output = SdkBindingDataFactory.ofIntegerMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(integers()))); + } + + @Test + public void testOfBooleanMap() { + Map expectedValue = + Map.of( + "a", true, + "b", false); + + SdkBindingData> output = SdkBindingDataFactory.ofBooleanMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(booleans()))); + } + + @Test + public void testOfDurationMap() { + Map expectedValue = + Map.of( + "a", Duration.ofDays(1), + "b", Duration.ofDays(2)); + + SdkBindingData> output = + SdkBindingDataFactory.ofDurationMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(durations()))); + } + + @Test + public void testOfDatetimeMap() { + Instant first = LocalDate.of(2022, 1, 16).atStartOfDay().toInstant(UTC); + Instant second = LocalDate.of(2022, 1, 17).atStartOfDay().toInstant(UTC); + + Map expectedValue = + Map.of( + "a", first, + "b", second); + + SdkBindingData> output = + SdkBindingDataFactory.ofDatetimeMap(expectedValue); + + assertThat(output.get(), equalTo(expectedValue)); + assertThat(output.type(), equalTo(maps(datetimes()))); + } +} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataTest.java deleted file mode 100644 index d6446cc98..000000000 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkBindingDataTest.java +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Copyright 2021 Flyte Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit; - -import static java.time.ZoneOffset.UTC; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; - -import java.time.Duration; -import java.time.Instant; -import java.time.LocalDate; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import org.flyte.api.v1.BindingData; -import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.Primitive; -import org.flyte.api.v1.Scalar; -import org.flyte.api.v1.SimpleType; -import org.junit.jupiter.api.Test; - -public class SdkBindingDataTest { - - @Test - public void testOfBindingCollection() { - List> input = - Arrays.asList(SdkBindingData.ofInteger(42L), SdkBindingData.ofInteger(1337L)); - - List expected = - Arrays.asList( - BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(42L))), - BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(1337L)))); - - SdkBindingData> output = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.INTEGER)), input); - - assertThat( - output, - equalTo( - SdkBindingData.create( - BindingData.ofCollection(expected), - LiteralType.ofCollectionType(LiteralTypes.INTEGER), - List.of(42L, 1337L)))); - } - - @Test - public void testOfBindingCollection_empty() { - List expectedValue = emptyList(); - SdkBindingData> expected = - SdkBindingData.create( - BindingData.ofCollection(expectedValue), - LiteralType.ofCollectionType(LiteralTypes.INTEGER), - emptyList()); - SdkBindingData> output = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - emptyList()); - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfStringCollection() { - List> input = - List.of(SdkBindingData.ofString("1"), SdkBindingData.ofString("2")); - - List expectedValue = List.of("1", "2"); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralTypes.STRING), input); - - SdkBindingData> output = SdkBindingData.ofStringCollection(expectedValue); - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfFloatCollection() { - List> input = - List.of(SdkBindingData.ofFloat(1.1), SdkBindingData.ofFloat(1.2)); - - List expectedValue = List.of(1.1, 1.2); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection(LiteralType.ofCollectionType(LiteralTypes.FLOAT), input); - - SdkBindingData> output = SdkBindingData.ofFloatCollection(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfIntegerCollection() { - List> input = - List.of(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L)); - - List expectedValue = List.of(1L, 2L); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralTypes.INTEGER), input); - - SdkBindingData> output = SdkBindingData.ofIntegerCollection(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfBooleanCollection() { - List> input = - List.of(SdkBindingData.ofBoolean(true), SdkBindingData.ofBoolean(false)); - - List expectedValue = List.of(true, false); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralTypes.BOOLEAN), input); - - SdkBindingData> output = SdkBindingData.ofBooleanCollection(expectedValue); - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfDurationCollection() { - List> input = - List.of( - SdkBindingData.ofDuration(Duration.ofDays(1)), - SdkBindingData.ofDuration(Duration.ofDays(2))); - - List expectedValue = List.of(Duration.ofDays(1), Duration.ofDays(2)); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralTypes.DURATION), input); - - SdkBindingData> output = SdkBindingData.ofDurationCollection(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfDatetimeCollection() { - Instant first = LocalDate.of(2022, 1, 16).atStartOfDay().toInstant(UTC); - Instant second = LocalDate.of(2022, 1, 17).atStartOfDay().toInstant(UTC); - - List> input = - List.of(SdkBindingData.ofDatetime(first), SdkBindingData.ofDatetime(second)); - - List expectedValue = List.of(first, second); - - SdkBindingData> expected = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralTypes.DATETIME), input); - - SdkBindingData> output = SdkBindingData.ofDatetimeCollection(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfBindingMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofInteger(42L), - "b", SdkBindingData.ofInteger(1337L)); - - Map expected = - Map.of( - "a", - BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(42L))), - "b", - BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(1337L)))); - - SdkBindingData> output = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.INTEGER), input); - - assertThat( - output, - equalTo( - SdkBindingData.create( - BindingData.ofMap(expected), - LiteralType.ofMapValueType(LiteralTypes.INTEGER), - Map.of("a", 42L, "b", 1337L)))); - } - - @Test - public void testOfBindingMap_empty() { - Map expectedValue = emptyMap(); - SdkBindingData> expected = - SdkBindingData.create( - BindingData.ofMap(expectedValue), - LiteralType.ofMapValueType(LiteralTypes.INTEGER), - emptyMap()); - - SdkBindingData> output = - SdkBindingData.ofBindingMap( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.INTEGER)), emptyMap()); - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfStringMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofString("1"), - "b", SdkBindingData.ofString("2")); - - Map expectedValue = - Map.of( - "a", "1", - "b", "2"); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.STRING), input); - - SdkBindingData> output = SdkBindingData.ofStringMap(expectedValue); - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfFloatMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofFloat(1.1), - "b", SdkBindingData.ofFloat(1.2)); - - Map expectedValue = - Map.of( - "a", 1.1, - "b", 1.2); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.FLOAT), input); - - SdkBindingData> output = SdkBindingData.ofFloatMap(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfIntegerMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofInteger(1L), - "b", SdkBindingData.ofInteger(2L)); - - Map expectedValue = - Map.of( - "a", 1L, - "b", 2L); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.INTEGER), input); - - SdkBindingData> output = SdkBindingData.ofIntegerMap(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfBooleanMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofBoolean(true), - "b", SdkBindingData.ofBoolean(false)); - - Map expectedValue = - Map.of( - "a", true, - "b", false); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.BOOLEAN), input); - - SdkBindingData> output = SdkBindingData.ofBooleanMap(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfDurationMap() { - Map> input = - Map.of( - "a", SdkBindingData.ofDuration(Duration.ofDays(1)), - "b", SdkBindingData.ofDuration(Duration.ofDays(2))); - - Map expectedValue = - Map.of( - "a", Duration.ofDays(1), - "b", Duration.ofDays(2)); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.DURATION), input); - - SdkBindingData> output = SdkBindingData.ofDurationMap(expectedValue); - - assertThat(output, equalTo(expected)); - } - - @Test - public void testOfDatetimeMap() { - Instant first = LocalDate.of(2022, 1, 16).atStartOfDay().toInstant(UTC); - Instant second = LocalDate.of(2022, 1, 17).atStartOfDay().toInstant(UTC); - - Map> input = - Map.of( - "a", SdkBindingData.ofDatetime(first), - "b", SdkBindingData.ofDatetime(second)); - - Map expectedValue = - Map.of( - "a", first, - "b", second); - - SdkBindingData> expected = - SdkBindingData.ofBindingMap(LiteralType.ofMapValueType(LiteralTypes.DATETIME), input); - - SdkBindingData> output = SdkBindingData.ofDatetimeMap(expectedValue); - - assertThat(output, equalTo(expected)); - } -} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java index d104b5704..83de100e2 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java @@ -97,7 +97,7 @@ void shouldAddFixedInputs() { Duration duration = Duration.ofSeconds(123); TestPairIntegerInput fixedInputs = - TestPairIntegerInput.create(SdkBindingData.ofInteger(456), SdkBindingData.ofInteger(789)); + TestPairIntegerInput.create(SdkBindingDataFactory.of(456), SdkBindingDataFactory.of(789)); SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()) @@ -128,7 +128,7 @@ void shouldAddDefaultInputs() { Duration duration = Duration.ofSeconds(123); TestPairIntegerInput fixedInputs = - TestPairIntegerInput.create(SdkBindingData.ofInteger(456), SdkBindingData.ofInteger(789)); + TestPairIntegerInput.create(SdkBindingDataFactory.of(456), SdkBindingDataFactory.of(789)); SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()) @@ -311,27 +311,27 @@ public Map toLiteralMap(TestWorkflowInput value) { @Override public TestWorkflowInput fromLiteralMap(Map value) { return create( - SdkBindingData.ofInteger(value.get(INTEGER).scalar().primitive().integerValue()), - SdkBindingData.ofFloat(value.get(FLOAT).scalar().primitive().floatValue()), - SdkBindingData.ofString(value.get(STRING).scalar().primitive().stringValue()), - SdkBindingData.ofBoolean(value.get(BOOLEAN).scalar().primitive().booleanValue()), - SdkBindingData.ofDatetime(value.get(DATETIME).scalar().primitive().datetime()), - SdkBindingData.ofDuration(value.get(DURATION).scalar().primitive().duration()), - SdkBindingData.ofInteger(value.get(A).scalar().primitive().integerValue()), - SdkBindingData.ofInteger(value.get(B).scalar().primitive().integerValue())); + SdkBindingDataFactory.of(value.get(INTEGER).scalar().primitive().integerValue()), + SdkBindingDataFactory.of(value.get(FLOAT).scalar().primitive().floatValue()), + SdkBindingDataFactory.of(value.get(STRING).scalar().primitive().stringValue()), + SdkBindingDataFactory.of(value.get(BOOLEAN).scalar().primitive().booleanValue()), + SdkBindingDataFactory.of(value.get(DATETIME).scalar().primitive().datetime()), + SdkBindingDataFactory.of(value.get(DURATION).scalar().primitive().duration()), + SdkBindingDataFactory.of(value.get(A).scalar().primitive().integerValue()), + SdkBindingDataFactory.of(value.get(B).scalar().primitive().integerValue())); } @Override public TestWorkflowInput promiseFor(String nodeId) { return create( - SdkBindingData.ofOutputReference(nodeId, INTEGER, LiteralTypes.INTEGER), - SdkBindingData.ofOutputReference(nodeId, FLOAT, LiteralTypes.FLOAT), - SdkBindingData.ofOutputReference(nodeId, STRING, LiteralTypes.STRING), - SdkBindingData.ofOutputReference(nodeId, BOOLEAN, LiteralTypes.BOOLEAN), - SdkBindingData.ofOutputReference(nodeId, DATETIME, LiteralTypes.DATETIME), - SdkBindingData.ofOutputReference(nodeId, DURATION, LiteralTypes.DURATION), - SdkBindingData.ofOutputReference(nodeId, A, LiteralTypes.INTEGER), - SdkBindingData.ofOutputReference(nodeId, B, LiteralTypes.INTEGER)); + SdkBindingData.promise(SdkLiteralTypes.integers(), nodeId, INTEGER), + SdkBindingData.promise(SdkLiteralTypes.floats(), nodeId, FLOAT), + SdkBindingData.promise(SdkLiteralTypes.strings(), nodeId, STRING), + SdkBindingData.promise(SdkLiteralTypes.booleans(), nodeId, BOOLEAN), + SdkBindingData.promise(SdkLiteralTypes.datetimes(), nodeId, DATETIME), + SdkBindingData.promise(SdkLiteralTypes.durations(), nodeId, DURATION), + SdkBindingData.promise(SdkLiteralTypes.integers(), nodeId, A), + SdkBindingData.promise(SdkLiteralTypes.integers(), nodeId, B)); } @Override @@ -347,6 +347,19 @@ public Map getVariableMap() { Map.entry(B, Variable.builder().literalType(LiteralTypes.INTEGER).build())); } + @Override + public Map> toLiteralTypes() { + return Map.ofEntries( + Map.entry(INTEGER, SdkLiteralTypes.integers()), + Map.entry(FLOAT, SdkLiteralTypes.floats()), + Map.entry(STRING, SdkLiteralTypes.strings()), + Map.entry(BOOLEAN, SdkLiteralTypes.booleans()), + Map.entry(DATETIME, SdkLiteralTypes.datetimes()), + Map.entry(DURATION, SdkLiteralTypes.durations()), + Map.entry(A, SdkLiteralTypes.integers()), + Map.entry(B, SdkLiteralTypes.integers())); + } + @Override public Map> toSdkBindingMap(TestWorkflowInput value) { return Map.ofEntries( diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLiteralTypesTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLiteralTypesTest.java index db0fb2a07..5bfeb4120 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLiteralTypesTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLiteralTypesTest.java @@ -16,69 +16,110 @@ */ package org.flyte.flytekit; +import static org.flyte.flytekit.SdkLiteralTypes.booleans; +import static org.flyte.flytekit.SdkLiteralTypes.collections; +import static org.flyte.flytekit.SdkLiteralTypes.datetimes; +import static org.flyte.flytekit.SdkLiteralTypes.durations; +import static org.flyte.flytekit.SdkLiteralTypes.floats; +import static org.flyte.flytekit.SdkLiteralTypes.integers; +import static org.flyte.flytekit.SdkLiteralTypes.maps; +import static org.flyte.flytekit.SdkLiteralTypes.strings; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.time.Duration; import java.time.Instant; import java.util.Arrays; -import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; public class SdkLiteralTypesTest { + @ParameterizedTest + @ArgumentsSource(TestOfProvider.class) + void testOf(SdkLiteralType expected, SdkLiteralType actual) { + assertEquals(expected, actual); + } + + static class TestOfProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of( + Arguments.of(integers(), SdkLiteralTypes.of(Long.class)), + Arguments.of(floats(), SdkLiteralTypes.of(Double.class)), + Arguments.of(strings(), SdkLiteralTypes.of(String.class)), + Arguments.of(booleans(), SdkLiteralTypes.of(Boolean.class)), + Arguments.of(datetimes(), SdkLiteralTypes.of(Instant.class)), + Arguments.of(durations(), SdkLiteralTypes.of(Duration.class)), + Arguments.of(collections(integers()), collections(Long.class)), + Arguments.of(collections(floats()), collections(Double.class)), + Arguments.of(collections(strings()), collections(String.class)), + Arguments.of(collections(booleans()), collections(Boolean.class)), + Arguments.of(collections(datetimes()), collections(Instant.class)), + Arguments.of(collections(durations()), collections(Duration.class)), + Arguments.of(maps(integers()), maps(Long.class)), + Arguments.of(maps(floats()), maps(Double.class)), + Arguments.of(maps(strings()), maps(String.class)), + Arguments.of(maps(booleans()), maps(Boolean.class)), + Arguments.of(maps(datetimes()), maps(Instant.class)), + Arguments.of(maps(durations()), maps(Duration.class))); + } + } + @Test public void testIntegers() { - assertThat(roundtrip(SdkLiteralTypes.integers(), 42L), equalTo(42L)); + assertThat(roundtrip(integers(), 42L), equalTo(42L)); } @Test public void testDoubles() { - assertThat(roundtrip(SdkLiteralTypes.floats(), 1337.0), equalTo(1337.0)); + assertThat(roundtrip(floats(), 1337.0), equalTo(1337.0)); } @Test public void testBoolean() { - assertThat(roundtrip(SdkLiteralTypes.booleans(), true), equalTo(true)); + assertThat(roundtrip(booleans(), true), equalTo(true)); } @Test public void testStrings() { - assertThat(roundtrip(SdkLiteralTypes.strings(), "forty-two"), equalTo("forty-two")); + assertThat(roundtrip(strings(), "forty-two"), equalTo("forty-two")); } @Test public void testDatetimes() { assertThat( - roundtrip(SdkLiteralTypes.datetimes(), Instant.ofEpochSecond(1337L)), + roundtrip(datetimes(), Instant.ofEpochSecond(1337L)), equalTo(Instant.ofEpochSecond(1337L))); } @Test public void testDuration() { assertThat( - roundtrip(SdkLiteralTypes.durations(), Duration.ofSeconds(1337L)), - equalTo(Duration.ofSeconds(1337L))); + roundtrip(durations(), Duration.ofSeconds(1337L)), equalTo(Duration.ofSeconds(1337L))); } @Test public void testCollectionOfIntegers() { assertThat( - roundtrip( - SdkLiteralTypes.collections(SdkLiteralTypes.integers()), Arrays.asList(42L, 1337L)), + roundtrip(collections(integers()), List.of(42L, 1337L)), equalTo(Arrays.asList(42L, 1337L))); } @Test public void testMapOfIntegers() { - Map map = new LinkedHashMap<>(); - map.put("a", 42L); - map.put("b", 1337L); + Map map = Map.of("a", 42L, "b", 1337L); - assertThat( - roundtrip(SdkLiteralTypes.maps(SdkLiteralTypes.integers()), map).entrySet(), - equalTo(map.entrySet())); + assertThat(roundtrip(maps(integers()), map).entrySet(), equalTo(map.entrySet())); } public static T roundtrip(SdkLiteralType literalType, T value) { diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java index a9636c733..57c470e56 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java @@ -37,7 +37,7 @@ public class SdkRemoteLaunchPlanTest { @Test void applyShouldReturnASdkWorkflowNode() { var inputs = - TestPairIntegerInput.create(SdkBindingData.ofInteger(1), SdkBindingData.ofInteger(2)); + TestPairIntegerInput.create(SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(2)); SdkRemoteLaunchPlan remoteLaunchPlan = new TestSdkRemoteLaunchPlan(); @@ -90,8 +90,7 @@ void applyShouldReturnASdkWorkflowNode() { is( singletonMap( "o", - SdkBindingData.ofOutputReference( - "some-node-id", "o", LiteralTypes.BOOLEAN))))); + SdkBindingData.promise(SdkLiteralTypes.booleans(), "some-node-id", "o"))))); } @SuppressWarnings("ExtendsAutoValue") diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java index ab8405452..80a09effc 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java @@ -38,7 +38,7 @@ class SdkRemoteTaskTest { @Test void applyShouldReturnASdkTaskNode() { var inputs = - TestPairIntegerInput.create(SdkBindingData.ofInteger(1), SdkBindingData.ofInteger(2)); + TestPairIntegerInput.create(SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(2)); SdkRemoteTask remoteTask = new TestSdkRemoteTask(); @@ -90,8 +90,8 @@ void applyShouldReturnASdkTaskNode() { is( singletonMap( "o", - SdkBindingData.ofOutputReference( - "lookup-endsong", "o", LiteralTypes.BOOLEAN))))); + SdkBindingData.promise( + SdkLiteralTypes.booleans(), "lookup-endsong", "o"))))); } @SuppressWarnings("ExtendsAutoValue") diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java index ed6c1a336..8dfebb773 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java @@ -45,7 +45,7 @@ void applyShouldPropagateCallToSubClasses() { var nodeId = "node"; var upstreamNodeIds = List.of("upstream-node"); var metadata = SdkNodeMetadata.builder().name("fancy-name").build(); - var in = SdkBindingData.ofInteger(1); + var in = SdkBindingDataFactory.of(1); var inputs = TestUnaryIntegerInput.create(in); var inputsBindings = Map.>of("in", in); @@ -158,6 +158,11 @@ public Map getVariableMap() { return Map.of(); } + @Override + public Map> toLiteralTypes() { + return Map.of(); + } + @Override public Map> toSdkBindingMap(Object value) { return Map.of(); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java index 7bbe8059d..12f3feab9 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java @@ -229,8 +229,8 @@ void testConditionalWorkflowIdl() { void testDuplicateNodeId() { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData a = SdkBindingData.ofInteger(10L); - SdkBindingData b = SdkBindingData.ofInteger(10L); + SdkBindingData a = SdkBindingDataFactory.of(10L); + SdkBindingData b = SdkBindingDataFactory.of(10L); TestPairIntegerInput input = TestPairIntegerInput.create(a, b); builder.apply("node-1", new MultiplicationTask(), input); @@ -252,8 +252,8 @@ void testUpstreamNode_withUpstreamNode( SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData a = SdkBindingData.ofInteger(10L); - SdkBindingData b = SdkBindingData.ofInteger(10L); + SdkBindingData a = SdkBindingDataFactory.of(10L); + SdkBindingData b = SdkBindingDataFactory.of(10L); TestPairIntegerInput input = TestPairIntegerInput.create(a, b); SdkNode el2 = builder.apply("el2", transform, input); @@ -310,8 +310,8 @@ void testNodeMetadataOverrides( SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData a = SdkBindingData.ofInteger(10L); - SdkBindingData b = SdkBindingData.ofInteger(10L); + SdkBindingData a = SdkBindingDataFactory.of(10L); + SdkBindingData b = SdkBindingDataFactory.of(10L); TestPairIntegerInput input = TestPairIntegerInput.create(a, b); @@ -337,8 +337,8 @@ void testNodeMetadataOverrides_duplicate( SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData a = SdkBindingData.ofInteger(10L); - SdkBindingData b = SdkBindingData.ofInteger(10L); + SdkBindingData a = SdkBindingDataFactory.of(10L); + SdkBindingData b = SdkBindingDataFactory.of(10L); TestPairIntegerInput input = TestPairIntegerInput.create(a, b); SdkNode el2 = builder.apply("el2", transform, input); @@ -392,7 +392,7 @@ protected Times4Workflow() { @Override public TestUnaryIntegerOutput expand(SdkWorkflowBuilder builder, TestUnaryIntegerInput input) { - SdkBindingData two = SdkBindingData.ofInteger(2L); + SdkBindingData two = SdkBindingDataFactory.of(2L); SdkBindingData out1 = builder @@ -419,7 +419,7 @@ private ConditionalWorkflow() { @Override public TestUnaryIntegerOutput expand(SdkWorkflowBuilder builder, TestUnaryIntegerInput input) { - SdkBindingData two = SdkBindingData.ofInteger(2L); + SdkBindingData two = SdkBindingDataFactory.of(2L); SdkNode out = builder.apply( diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java index e71850a8e..1b64892b1 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java @@ -46,15 +46,15 @@ public Map toLiteralMap(TestPairIntegerInput value) { @Override public TestPairIntegerInput fromLiteralMap(Map value) { return create( - SdkBindingData.ofInteger(value.get(A).scalar().primitive().integerValue()), - SdkBindingData.ofInteger(value.get(B).scalar().primitive().integerValue())); + SdkBindingDataFactory.of(value.get(A).scalar().primitive().integerValue()), + SdkBindingDataFactory.of(value.get(B).scalar().primitive().integerValue())); } @Override public TestPairIntegerInput promiseFor(String nodeId) { return create( - SdkBindingData.ofOutputReference(nodeId, A, LiteralTypes.INTEGER), - SdkBindingData.ofOutputReference(nodeId, B, LiteralTypes.INTEGER)); + SdkBindingData.promise(SdkLiteralTypes.integers(), nodeId, A), + SdkBindingData.promise(SdkLiteralTypes.integers(), nodeId, B)); } @Override @@ -64,6 +64,11 @@ public Map getVariableMap() { B, Variable.builder().literalType(LiteralTypes.INTEGER).build()); } + @Override + public Map> toLiteralTypes() { + return Map.of(A, SdkLiteralTypes.integers(), B, SdkLiteralTypes.integers()); + } + @Override public Map> toSdkBindingMap(TestPairIntegerInput value) { return Map.of(A, value.a(), B, value.b()); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java index 0fc661ea7..3f260fbdc 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java @@ -19,7 +19,6 @@ import com.google.auto.value.AutoValue; import java.util.Map; import org.flyte.api.v1.Literal; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; @AutoValue @@ -33,7 +32,7 @@ public static TestUnaryBooleanOutput create(SdkBindingData o) { public static class SdkType extends org.flyte.flytekit.SdkType { private static final String VAR = "o"; - private static final LiteralType LITERAL_TYPE = LiteralTypes.BOOLEAN; + private static final SdkLiteralType BOOLEANS = SdkLiteralTypes.booleans(); @Override public Map toLiteralMap(TestUnaryBooleanOutput value) { @@ -42,17 +41,22 @@ public Map toLiteralMap(TestUnaryBooleanOutput value) { @Override public TestUnaryBooleanOutput fromLiteralMap(Map value) { - return create(SdkBindingData.ofBoolean(value.get(VAR).scalar().primitive().booleanValue())); + return create(SdkBindingDataFactory.of(value.get(VAR).scalar().primitive().booleanValue())); } @Override public TestUnaryBooleanOutput promiseFor(String nodeId) { - return create(SdkBindingData.ofOutputReference(nodeId, VAR, LITERAL_TYPE)); + return create(SdkBindingData.promise(BOOLEANS, nodeId, VAR)); } @Override public Map getVariableMap() { - return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); + return Map.of(VAR, Variable.builder().literalType(BOOLEANS.getLiteralType()).build()); + } + + @Override + public Map> toLiteralTypes() { + return Map.of(VAR, BOOLEANS); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java index 71c877d96..4b805073e 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java @@ -19,7 +19,6 @@ import com.google.auto.value.AutoValue; import java.util.Map; import org.flyte.api.v1.Literal; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; @AutoValue @@ -34,7 +33,7 @@ public static TestUnaryIntegerInput create(SdkBindingData in) { public static class SdkType extends org.flyte.flytekit.SdkType { private static final String VAR = "in"; - private static final LiteralType LITERAL_TYPE = LiteralTypes.INTEGER; + private static final SdkLiteralType INTEGERS = SdkLiteralTypes.integers(); @Override public Map toLiteralMap(TestUnaryIntegerInput value) { @@ -43,17 +42,23 @@ public Map toLiteralMap(TestUnaryIntegerInput value) { @Override public TestUnaryIntegerInput fromLiteralMap(Map value) { - return create(SdkBindingData.ofInteger(value.get(VAR).scalar().primitive().integerValue())); + return create(SdkBindingDataFactory.of(value.get(VAR).scalar().primitive().integerValue())); } @Override public TestUnaryIntegerInput promiseFor(String nodeId) { - return create(SdkBindingData.ofOutputReference(nodeId, VAR, LITERAL_TYPE)); + return create(SdkBindingData.promise(INTEGERS, nodeId, VAR)); } @Override public Map getVariableMap() { - return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).description("").build()); + return Map.of( + VAR, Variable.builder().literalType(INTEGERS.getLiteralType()).description("").build()); + } + + @Override + public Map> toLiteralTypes() { + return Map.of(VAR, INTEGERS); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java index 0ae4b3028..67bd5a101 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java @@ -19,7 +19,6 @@ import com.google.auto.value.AutoValue; import java.util.Map; import org.flyte.api.v1.Literal; -import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; @AutoValue @@ -33,7 +32,7 @@ public static TestUnaryIntegerOutput create(SdkBindingData o) { public static class SdkType extends org.flyte.flytekit.SdkType { private static final String VAR = "o"; - private static final LiteralType LITERAL_TYPE = LiteralTypes.INTEGER; + private static final SdkLiteralType INTEGERS = SdkLiteralTypes.integers(); @Override public Map toLiteralMap(TestUnaryIntegerOutput value) { @@ -42,17 +41,22 @@ public Map toLiteralMap(TestUnaryIntegerOutput value) { @Override public TestUnaryIntegerOutput fromLiteralMap(Map value) { - return create(SdkBindingData.ofInteger(value.get(VAR).scalar().primitive().integerValue())); + return create(SdkBindingDataFactory.of(value.get(VAR).scalar().primitive().integerValue())); } @Override public TestUnaryIntegerOutput promiseFor(String nodeId) { - return create(SdkBindingData.ofOutputReference(nodeId, VAR, LITERAL_TYPE)); + return create(SdkBindingData.promise(INTEGERS, nodeId, VAR)); } @Override public Map getVariableMap() { - return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); + return Map.of(VAR, Variable.builder().literalType(INTEGERS.getLiteralType()).build()); + } + + @Override + public Map> toLiteralTypes() { + return Map.of(VAR, INTEGERS); } @Override diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java index aec4e4de7..3b8ced836 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java @@ -18,7 +18,6 @@ import static java.util.Collections.singletonMap; import static java.util.stream.Collectors.toMap; -import static org.flyte.flytekit.SdkBindingData.ofInteger; import static org.flyte.flytekit.SdkConditions.eq; import static org.flyte.flytekit.SdkConditions.when; import static org.flyte.localengine.TestingListener.ofCompleted; @@ -47,6 +46,7 @@ import org.flyte.api.v1.WorkflowTemplate; import org.flyte.api.v1.WorkflowTemplateRegistrar; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -562,8 +562,12 @@ public NoOpType expand(SdkWorkflowBuilder builder, NoOpType input) { builder .apply( "decide", - when("eq_1", eq(ofInteger(1L), x), new NoOp(), NoOpType.create(x)) - .when("eq_2", eq(ofInteger(2L), x), new NoOp(), NoOpType.create(x))) + when("eq_1", eq(SdkBindingDataFactory.of(1L), x), new NoOp(), NoOpType.create(x)) + .when( + "eq_2", + eq(SdkBindingDataFactory.of(2L), x), + new NoOp(), + NoOpType.create(x))) .getOutputs() .x(); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java index eb241838c..233c226a8 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java @@ -16,13 +16,13 @@ */ package org.flyte.localengine.examples; -import static org.flyte.flytekit.SdkBindingData.ofInteger; import static org.flyte.flytekit.SdkConditions.isTrue; import static org.flyte.flytekit.SdkConditions.when; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -64,7 +64,7 @@ public TestUnaryIntegerOutput expand(SdkWorkflowBuilder builder, Input input) { "was_even", isTrue(isOdd), new Divide(), - Divide.Input.create(input.x(), ofInteger(2L))) + Divide.Input.create(input.x(), SdkBindingDataFactory.of(2L))) .otherwise( "was_odd", new ThreeXPlusOne(), ThreeXPlusOne.Input.create(input.x()))) .getOutputs() @@ -83,7 +83,7 @@ public IsEvenTask() { @Override public IsEvenTask.Output run(IsEvenTask.Input input) { - return IsEvenTask.Output.create(SdkBindingData.ofBoolean(input.x().get() % 2 == 0)); + return IsEvenTask.Output.create(SdkBindingDataFactory.of(input.x().get() % 2 == 0)); } @AutoValue @@ -118,7 +118,7 @@ public Divide() { @Override public TestUnaryIntegerOutput run(Divide.Input input) { return TestUnaryIntegerOutput.create( - SdkBindingData.ofInteger(input.num().get() / input.den().get())); + SdkBindingDataFactory.of(input.num().get() / input.den().get())); } @AutoValue @@ -157,7 +157,7 @@ public ThreeXPlusOne() { @Override public TestUnaryIntegerOutput run(ThreeXPlusOne.Input input) { - return TestUnaryIntegerOutput.create(SdkBindingData.ofInteger(3 * input.x().get() + 1)); + return TestUnaryIntegerOutput.create(SdkBindingDataFactory.of(3 * input.x().get() + 1)); } @AutoValue diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java index 1f1dd0c87..75d6b65d3 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java @@ -20,6 +20,7 @@ import com.google.auto.value.AutoValue; import java.util.List; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -33,7 +34,7 @@ public ListTask() { @Override public Output run(Input input) { - return Output.create(SdkBindingData.ofIntegerCollection(input.list().get())); + return Output.create(SdkBindingDataFactory.ofIntegerCollection(input.list().get())); } @AutoValue diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java index 9827c4ce6..a93aa4d7d 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java @@ -16,13 +16,11 @@ */ package org.flyte.localengine.examples; -import static org.flyte.flytekit.SdkBindingData.ofInteger; - import com.google.auto.service.AutoService; import java.util.List; -import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; +import org.flyte.flytekit.SdkLiteralTypes; import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; @@ -38,14 +36,19 @@ public ListWorkflow() { @Override public ListTask.Output expand(SdkWorkflowBuilder builder, Void noInput) { SdkNode sum1 = - builder.apply("sum-1", new SumTask(), SumTask.Input.create(ofInteger(1), ofInteger(2))); + builder.apply( + "sum-1", + new SumTask(), + SumTask.Input.create(SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(2))); SdkNode sum2 = - builder.apply("sum-2", new SumTask(), SumTask.Input.create(ofInteger(3), ofInteger(4))); + builder.apply( + "sum-2", + new SumTask(), + SumTask.Input.create(SdkBindingDataFactory.of(3), SdkBindingDataFactory.of(4))); SdkBindingData> list = - SdkBindingData.ofBindingCollection( - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - List.of(sum1.getOutputs().o(), sum2.getOutputs().o())); + SdkBindingDataFactory.ofBindingCollection( + SdkLiteralTypes.integers(), List.of(sum1.getOutputs().o(), sum2.getOutputs().o())); SdkNode list1 = builder.apply("list-1", new ListTask(), ListTask.Input.create(list)); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java index 47d76801f..a46580703 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java @@ -20,6 +20,7 @@ import com.google.auto.value.AutoValue; import java.util.Map; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -50,7 +51,7 @@ public abstract static class Output { public abstract SdkBindingData> map(); public static Output create(Map map) { - return new AutoValue_MapTask_Output(SdkBindingData.ofIntegerMap(map)); + return new AutoValue_MapTask_Output(SdkBindingDataFactory.ofIntegerMap(map)); } } } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java index 73ff2ef5e..7af533f7b 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java @@ -16,14 +16,13 @@ */ package org.flyte.localengine.examples; -import static org.flyte.flytekit.SdkBindingData.ofInteger; +import static org.flyte.flytekit.SdkLiteralTypes.integers; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.util.Map; -import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; @@ -50,20 +49,24 @@ public static MapWorkflow.Output create(SdkBindingData> map) { public Output expand(SdkWorkflowBuilder builder, Void noInput) { SdkBindingData sum1 = builder - .apply("sum-1", new SumTask(), SumTask.Input.create(ofInteger(1), ofInteger(2))) + .apply( + "sum-1", + new SumTask(), + SumTask.Input.create(SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(2))) .getOutputs() .o(); SdkBindingData sum2 = builder - .apply("sum-2", new SumTask(), SumTask.Input.create(ofInteger(3), ofInteger(4))) + .apply( + "sum-2", + new SumTask(), + SumTask.Input.create(SdkBindingDataFactory.of(3), SdkBindingDataFactory.of(4))) .getOutputs() .o(); SdkBindingData> map = - SdkBindingData.ofBindingMap( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - Map.of("e", sum1, "f", sum2)); + SdkBindingDataFactory.ofBindingMap(integers(), Map.of("e", sum1, "f", sum2)); SdkNode map1 = builder.apply("map-1", new MapTask(), MapTask.Input.create(map)); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java index 7f37f1409..cb2ef230b 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; import org.flyte.localengine.examples.SumTask.Input; @@ -45,6 +46,6 @@ public static Input create(SdkBindingData a, SdkBindingData b) { @Override public TestUnaryIntegerOutput run(Input input) { return TestUnaryIntegerOutput.create( - SdkBindingData.ofInteger(input.a().get() + input.b().get())); + SdkBindingDataFactory.of(input.a().get() + input.b().get())); } } diff --git a/flytekit-scala-tests/pom.xml b/flytekit-scala-tests/pom.xml index dbb76e44b..32ab79501 100644 --- a/flytekit-scala-tests/pom.xml +++ b/flytekit-scala-tests/pom.xml @@ -63,11 +63,6 @@ junit-jupiter test - - org.junit.vintage - junit-vintage-engine - test - diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkBindingDataConvertersTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkBindingDataConvertersTest.scala new file mode 100644 index 000000000..ef4425eb5 --- /dev/null +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkBindingDataConvertersTest.scala @@ -0,0 +1,347 @@ +/* + * Copyright 2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekitscala + +import org.flyte.flytekit.{ + SdkBindingData, + SdkBindingDataFactory => JavaSBD, + SdkLiteralTypes => JavaSLT +} +import org.flyte.flytekitscala.SdkBindingDataConverters._ +import org.flyte.flytekitscala.{ + SdkBindingDataFactory => ScalaSBD, + SdkLiteralTypes => ScalaSLT +} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{ + Arguments, + ArgumentsProvider, + ArgumentsSource +} + +import java.time.ZoneOffset.UTC +import java.time.{Duration, Instant, LocalDate} +import java.util.stream.Stream +import java.{lang => j, util => ju} +import scala.collection.JavaConverters._ + +class SdkBindingDataConvertersTest { + + @ParameterizedTest + @ArgumentsSource(classOf[TestRoundTripConversionForScalarProvider]) + def testRoundTripConversionForScalars[JavaT, ScalaT]( + javaScalar: SdkBindingData[JavaT], + toScala: SdkBindingData[JavaT] => SdkBindingData[ScalaT], + scalaScalar: SdkBindingData[ScalaT], + toJava: SdkBindingData[ScalaT] => SdkBindingData[JavaT] + ): Unit = { + val scalaConverted = toScala(javaScalar) + val javaConverted = toJava(scalaScalar) + + assertEquals(javaScalar, javaConverted) + assertEquals(scalaScalar, scalaConverted) + } + + @ParameterizedTest + @ArgumentsSource(classOf[TestRoundTripConversionForCollectionsProvider]) + def testRoundTripConversionForCollections[JavaT, ScalaT]( + javaCollection: SdkBindingData[ju.List[JavaT]], + scalaCollection: SdkBindingData[List[ScalaT]] + ): Unit = { + val scalaConverted = toScalaList(javaCollection) + val javaConverted = toJavaList(scalaConverted) + + assertEquals(javaCollection, javaConverted) + assertEquals(scalaCollection, scalaConverted) + } + + @ParameterizedTest + @ArgumentsSource(classOf[TestRoundTripConversionForMapProvider]) + def testRoundTripConversionForMap[JavaT, ScalaT]( + javaMap: SdkBindingData[ju.Map[String, JavaT]], + scalaMap: SdkBindingData[Map[String, ScalaT]] + ): Unit = { + val scalaConverted = toScalaMap(javaMap) + val javaConverted = toJavaMap(scalaMap) + + assertEquals(javaMap, javaConverted) + assertEquals(scalaMap, scalaConverted) + } + + @Test + def testToScalaListForBindCollectionsShouldThrowException(): Unit = { + val javaLongList = ju.List.of( + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(1L)), + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(2L)), + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(3L)) + ) + val original = + SdkBindingData.bindingCollection(JavaSLT.integers(), javaLongList) + + val exception = assertThrows( + classOf[UnsupportedOperationException], + () => toScalaList(original) + ) + + assertEquals( + exception.getMessage, + "SdkBindingData of binding collections cannot be casted" + ) + } + + @Test + def testToJavaListForBindCollectionsShouldThrowException(): Unit = { + val scalaLongList = List( + SdkBindingData.literal(ScalaSLT.integers(), 1L), + SdkBindingData.literal(ScalaSLT.integers(), 2L), + SdkBindingData.literal(ScalaSLT.integers(), 3L) + ) + val original = + SdkBindingData.bindingCollection( + ScalaSLT.integers(), + scalaLongList.asJava + ) + + val exception = assertThrows( + classOf[UnsupportedOperationException], + () => toScalaList(original) + ) + + assertEquals( + exception.getMessage, + "SdkBindingData of binding collections cannot be casted" + ) + } + + @Test + def testToScalaListForBindMapsShouldThrowException(): Unit = { + val javaLongList = ju.Map.of( + "a", + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(1L)), + "b", + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(2L)), + "c", + SdkBindingData.literal(JavaSLT.integers(), j.Long.valueOf(3L)) + ) + val original = + SdkBindingData.bindingMap(JavaSLT.integers(), javaLongList) + + val exception = assertThrows( + classOf[UnsupportedOperationException], + () => toScalaMap(original) + ) + + assertEquals( + exception.getMessage, + "SdkBindingData of binding map cannot be casted" + ) + } + + @Test + def testToJavaListForBindMapsShouldThrowException(): Unit = { + val scalaLongList = Map( + "a" -> SdkBindingData.literal(ScalaSLT.integers(), 1L), + "b" -> SdkBindingData.literal(ScalaSLT.integers(), 2L), + "c" -> SdkBindingData.literal(ScalaSLT.integers(), 3L) + ) + val original = + SdkBindingData.bindingMap(ScalaSLT.integers(), scalaLongList.asJava) + + val exception = assertThrows( + classOf[UnsupportedOperationException], + () => toScalaMap(original) + ) + + assertEquals( + exception.getMessage, + "SdkBindingData of binding map cannot be casted" + ) + } +} + +class TestRoundTripConversionForScalarProvider extends ArgumentsProvider { + override def provideArguments( + context: ExtensionContext + ): Stream[_ <: Arguments] = { + Stream.of( + Arguments.of( + JavaSBD.of(j.Long.valueOf(1L)), + d => toScalaLong(d), + ScalaSBD.of(1L), + d => toJavaLong(d) + ), + Arguments.of( + JavaSBD.of(j.Double.valueOf(1.0)), + d => toScalaDouble(d), + ScalaSBD.of(1.0), + d => toJavaDouble(d) + ), + Arguments.of( + JavaSBD.of(j.Boolean.valueOf(true)), + d => toScalaBoolean(d), + ScalaSBD.of(true), + d => toJavaBoolean(d) + ) + ) + } +} + +class TestRoundTripConversionForCollectionsProvider extends ArgumentsProvider { + override def provideArguments( + context: ExtensionContext + ): Stream[_ <: Arguments] = { + val date1 = LocalDate.now().atStartOfDay(UTC).toInstant + val date2 = LocalDate.of(2023, 1, 1).atStartOfDay(UTC).toInstant + Stream.of( + Arguments.of( + JavaSBD.ofIntegerCollection(ju.List.of[j.Long](1L, 2L, 3L)), + ScalaSBD.ofIntegerCollection(List(1L, 2L, 3L)) + ), + Arguments.of( + JavaSBD.ofFloatCollection(ju.List.of[j.Double](1.0, 2.0, 3.0)), + ScalaSBD.ofFloatCollection(List(1.0, 2.0, 3.0)) + ), + Arguments.of( + JavaSBD.ofStringCollection(ju.List.of[j.String]("a", "b", "c")), + ScalaSBD.ofStringCollection(List("a", "b", "c")) + ), + Arguments.of( + JavaSBD.ofBooleanCollection(ju.List.of[j.Boolean](true, false, true)), + ScalaSBD.ofBooleanCollection(List(true, false, true)) + ), + Arguments.of( + JavaSBD.ofDatetimeCollection(ju.List.of[Instant](date1, date2)), + ScalaSBD.ofDatetimeCollection(List(date1, date2)) + ), + Arguments.of( + JavaSBD.ofDurationCollection( + ju.List.of[Duration](Duration.ZERO, Duration.ofSeconds(5)) + ), + ScalaSBD.ofDurationCollection( + List(Duration.ZERO, Duration.ofSeconds(5)) + ) + ), + Arguments.of( + JavaSBD.of( + JavaSLT.collections(JavaSLT.strings()), + ju.List.of(ju.List.of("frodo", "sam"), ju.List.of("harry", "ron")) + ), + ScalaSBD.of( + ScalaSLT.collections(ScalaSLT.strings()), + List(List("frodo", "sam"), List("harry", "ron")) + ) + ), + Arguments.of( + JavaSBD.of( + JavaSLT.maps(JavaSLT.strings()), + ju.List.of(ju.Map.of("frodo", "sam"), ju.Map.of("harry", "ron")) + ), + ScalaSBD.of( + ScalaSLT.maps(ScalaSLT.strings()), + List(Map("frodo" -> "sam"), Map("harry" -> "ron")) + ) + ) + ) + } +} + +class TestRoundTripConversionForMapProvider extends ArgumentsProvider { + override def provideArguments( + context: ExtensionContext + ): Stream[_ <: Arguments] = { + val date1 = LocalDate.now().atStartOfDay(UTC).toInstant + val date2 = LocalDate.of(2023, 1, 1).atStartOfDay(UTC).toInstant + Stream.of( + Arguments.of( + JavaSBD.ofIntegerMap( + ju.Map.of[String, j.Long]("a", 1L, "b", 2L, "c", 3L) + ), + ScalaSBD.ofIntegerMap(Map("a" -> 1L, "b" -> 2L, "c" -> 3L)) + ), + Arguments.of( + JavaSBD.ofFloatMap( + ju.Map.of[String, j.Double]("a", 1.0, "b", 2.0, "c", 3.0) + ), + ScalaSBD.ofFloatMap(Map("a" -> 1.0, "b" -> 2.0, "c" -> 3.0)) + ), + Arguments.of( + JavaSBD.ofStringMap( + ju.Map.of[String, j.String]("a", "a", "b", "b", "c", "c") + ), + ScalaSBD.ofStringMap(Map("a" -> "a", "b" -> "b", "c" -> "c")) + ), + Arguments.of( + JavaSBD.ofBooleanMap( + ju.Map.of[String, j.Boolean]("a", true, "b", false, "c", true) + ), + ScalaSBD.ofBooleanMap(Map("a" -> true, "b" -> false, "c" -> true)) + ), + Arguments.of( + JavaSBD.ofDatetimeMap( + ju.Map.of[String, Instant]("a", date1, "b", date2) + ), + ScalaSBD.ofDatetimeMap(Map("a" -> date1, "b" -> date2)) + ), + Arguments.of( + JavaSBD.ofDurationMap( + ju.Map.of[String, Duration]( + "a", + Duration.ZERO, + "b", + Duration.ofSeconds(5) + ) + ), + ScalaSBD.ofDurationMap( + Map("a" -> Duration.ZERO, "b" -> Duration.ofSeconds(5)) + ) + ), + Arguments.of( + JavaSBD.of( + JavaSLT.maps(JavaSLT.strings()), + ju.Map.of( + "lotr", + ju.Map.of("frodo", "sam"), + "hp", + ju.Map.of("harry", "ron") + ) + ), + ScalaSBD.of( + ScalaSLT.maps(ScalaSLT.strings()), + Map("lotr" -> Map("frodo" -> "sam"), "hp" -> Map("harry" -> "ron")) + ) + ), + Arguments.of( + JavaSBD.of( + JavaSLT.collections(JavaSLT.strings()), + ju.Map.of( + "lotr", + ju.List.of("frodo", "sam"), + "hp", + ju.List.of("harry", "ron") + ) + ), + ScalaSBD.of( + ScalaSLT.collections(ScalaSLT.strings()), + Map("lotr" -> List("frodo", "sam"), "hp" -> List("harry", "ron")) + ) + ) + ) + } +} diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala new file mode 100644 index 000000000..3ae9a5b9c --- /dev/null +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -0,0 +1,148 @@ +/* + * Copyright 2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekitscala + +import org.flyte.flytekit.SdkLiteralType +import org.flyte.flytekitscala.SdkLiteralTypes.{of, _} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{ + Arguments, + ArgumentsProvider, + ArgumentsSource +} + +import java.time.{Duration, Instant} +import java.util.stream.Stream +import scala.annotation.unused + +class SdkLiteralTypesTest { + + @ParameterizedTest + @ArgumentsSource(classOf[TestOfReturnsProperTypeProvider]) + def testOfReturnsProperType( + expected: SdkLiteralType[_], + actual: SdkLiteralType[_] + ): Unit = { + assertEquals(expected, actual) + } + + @ParameterizedTest(name = "{index} {0}") + @ArgumentsSource(classOf[testOfThrowExceptionsForUnsupportedTypesProvider]) + def testOfThrowExceptionsForUnsupportedTypes( + @unused reason: String, + create: () => SdkLiteralType[_] + ): Unit = { + assertThrows(classOf[IllegalArgumentException], () => create()) + } + +} + +class TestOfReturnsProperTypeProvider extends ArgumentsProvider { + override def provideArguments( + context: ExtensionContext + ): Stream[_ <: Arguments] = { + Stream.of( + Arguments.of(integers(), of[Long]()), + Arguments.of(floats(), of[Double]()), + Arguments.of(strings(), of[String]()), + Arguments.of(booleans(), of[Boolean]()), + Arguments.of(datetimes(), of[Instant]()), + Arguments.of(durations(), of[Duration]()), + Arguments.of(collections(integers()), of[List[Long]]()), + Arguments.of(collections(floats()), of[List[Double]]()), + Arguments.of(collections(strings()), of[List[String]]()), + Arguments.of(collections(booleans()), of[List[Boolean]]()), + Arguments.of(collections(datetimes()), of[List[Instant]]()), + Arguments.of(collections(durations()), of[List[Duration]]()), + Arguments.of(maps(integers()), of[Map[String, Long]]()), + Arguments.of(maps(floats()), of[Map[String, Double]]()), + Arguments.of(maps(strings()), of[Map[String, String]]()), + Arguments.of(maps(booleans()), of[Map[String, Boolean]]()), + Arguments.of(maps(datetimes()), of[Map[String, Instant]]()), + Arguments.of(maps(durations()), of[Map[String, Duration]]()), + Arguments + .of(collections(collections(integers())), of[List[List[Long]]]()), + Arguments + .of(collections(collections(floats())), of[List[List[Double]]]()), + Arguments + .of(collections(collections(strings())), of[List[List[String]]]()), + Arguments + .of(collections(collections(booleans())), of[List[List[Boolean]]]()), + Arguments + .of(collections(collections(datetimes())), of[List[List[Instant]]]()), + Arguments + .of(collections(collections(durations())), of[List[List[Duration]]]()), + Arguments + .of(maps(maps(integers())), of[Map[String, Map[String, Long]]]()), + Arguments + .of(maps(maps(floats())), of[Map[String, Map[String, Double]]]()), + Arguments + .of(maps(maps(strings())), of[Map[String, Map[String, String]]]()), + Arguments + .of(maps(maps(booleans())), of[Map[String, Map[String, Boolean]]]()), + Arguments + .of(maps(maps(datetimes())), of[Map[String, Map[String, Instant]]]()), + Arguments + .of(maps(maps(durations())), of[Map[String, Map[String, Duration]]]()), + Arguments + .of(maps(collections(integers())), of[Map[String, List[Long]]]()), + Arguments + .of(maps(collections(floats())), of[Map[String, List[Double]]]()), + Arguments + .of(maps(collections(strings())), of[Map[String, List[String]]]()), + Arguments + .of(maps(collections(booleans())), of[Map[String, List[Boolean]]]()), + Arguments + .of(maps(collections(datetimes())), of[Map[String, List[Instant]]]()), + Arguments + .of(maps(collections(durations())), of[Map[String, List[Duration]]]()), + Arguments + .of(collections(maps(integers())), of[List[Map[String, Long]]]()), + Arguments + .of(collections(maps(floats())), of[List[Map[String, Double]]]()), + Arguments + .of(collections(maps(strings())), of[List[Map[String, String]]]()), + Arguments + .of(collections(maps(booleans())), of[List[Map[String, Boolean]]]()), + Arguments + .of(collections(maps(datetimes())), of[List[Map[String, Instant]]]()), + Arguments.of( + collections(maps(durations())), + of[List[Map[String, Duration]]]() + ) + ) + } +} + +class testOfThrowExceptionsForUnsupportedTypesProvider + extends ArgumentsProvider { + override def provideArguments( + context: ExtensionContext + ): Stream[_ <: Arguments] = { + Stream.of( + Arguments + .of("java type, must use java factory", () => of[java.lang.Long]()), + Arguments.of("not a supported type", () => of[Object]()), + Arguments.of( + "triple nesting not supported in of", + () => of[List[List[List[Long]]]]() + ) + ) + } +} diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 6e1e4df34..1d1625d87 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -19,7 +19,6 @@ package org.flyte.flytekitscala import java.time.{Duration, Instant} import scala.jdk.CollectionConverters._ import org.flyte.api.v1.{ - BindingData, Literal, LiteralType, Primitive, @@ -27,11 +26,15 @@ import org.flyte.api.v1.{ SimpleType, Variable } -import org.flyte.flytekit.SdkBindingData -import org.flyte.flytekitscala.SdkBindingData._ -import org.junit.Assert.{assertEquals, assertThrows} -import org.junit.Test +import org.flyte.flytekit.{ + SdkBindingData, + SdkBindingDataFactory => SdkJavaBindingDataFactory +} +import org.flyte.flytekitscala.SdkBindingDataFactory +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test import org.flyte.examples.AllInputsTask.AutoAllInputsInput +import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings} class SdkScalaTypeTest { @@ -145,12 +148,12 @@ class SdkScalaTypeTest { val expected = ScalarInput( - string = ofString("string"), - integer = ofInteger(1337L), - float = ofFloat(42.0), - boolean = ofBoolean(true), - datetime = ofDateTime(Instant.ofEpochMilli(123456L)), - duration = ofDuration(Duration.ofSeconds(123, 456)) + string = SdkBindingDataFactory.of("string"), + integer = SdkBindingDataFactory.of(1337L), + float = SdkBindingDataFactory.of(42.0), + boolean = SdkBindingDataFactory.of(true), + datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) ) val output = SdkScalaType[ScalarInput].fromLiteralMap(input) @@ -162,12 +165,12 @@ class SdkScalaTypeTest { def testScalarToLiteralMap(): Unit = { val input = ScalarInput( - string = ofString("string"), - integer = ofInteger(1337L), - float = ofFloat(42.0), - boolean = ofBoolean(true), - datetime = ofDateTime(Instant.ofEpochMilli(123456L)), - duration = ofDuration(Duration.ofSeconds(123, 456)) + string = SdkBindingDataFactory.of("string"), + integer = SdkBindingDataFactory.of(1337L), + float = SdkBindingDataFactory.of(42.0), + boolean = SdkBindingDataFactory.of(true), + datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) ) val expected = Map( @@ -213,23 +216,23 @@ class SdkScalaTypeTest { @Test def testToSdkBindingMap(): Unit = { val input = ScalarInput( - string = ofString("string"), - integer = ofInteger(1337L), - float = ofFloat(42.0), - boolean = ofBoolean(true), - datetime = ofDateTime(Instant.ofEpochMilli(123456L)), - duration = ofDuration(Duration.ofSeconds(123, 456)) + string = SdkBindingDataFactory.of("string"), + integer = SdkBindingDataFactory.of(1337L), + float = SdkBindingDataFactory.of(42.0), + boolean = SdkBindingDataFactory.of(true), + datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), + duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) ) val output = SdkScalaType[ScalarInput].toSdkBindingMap(input) val expected = Map( - "string" -> ofString("string"), - "integer" -> ofInteger(1337L), - "float" -> ofFloat(42.0), - "boolean" -> ofBoolean(true), - "datetime" -> ofDateTime(Instant.ofEpochMilli(123456L)), - "duration" -> ofDuration(Duration.ofSeconds(123, 456)) + "string" -> SdkBindingDataFactory.of("string"), + "integer" -> SdkBindingDataFactory.of(1337L), + "float" -> SdkBindingDataFactory.of(42.0), + "boolean" -> SdkBindingDataFactory.of(true), + "datetime" -> SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)), + "duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)) ).asJava assertEquals(expected, output) @@ -267,14 +270,14 @@ class SdkScalaTypeTest { def testRoundTripFromAndToCaseClassWithCollections(): Unit = { val input = CollectionInput( - strings = ofCollection(List("foo", "bar")), - integers = ofCollection(List(1337L, 321L)), - floats = ofCollection(List(42.0, 3.14)), - booleans = ofCollection(List(true, false)), - datetimes = ofCollection( + strings = SdkBindingDataFactory.of(List("foo", "bar")), + integers = SdkBindingDataFactory.of(List(1337L, 321L)), + floats = SdkBindingDataFactory.of(List(42.0, 3.14)), + booleans = SdkBindingDataFactory.of(List(true, false)), + datetimes = SdkBindingDataFactory.of( List(Instant.ofEpochMilli(123456L), Instant.ofEpochMilli(321L)) ), - durations = ofCollection( + durations = SdkBindingDataFactory.of( List(Duration.ofSeconds(123, 456), Duration.ofSeconds(543, 21)) ) ) @@ -316,12 +319,14 @@ class SdkScalaTypeTest { def testRoundTripFromAndToCaseClassWithMaps(): Unit = { val input = MapInput( - stringMap = ofMap(Map("k1" -> "foo")), - integerMap = ofMap(Map("k2" -> 321L)), - floatMap = ofMap(Map("k3" -> 3.14)), - booleanMap = ofMap(Map("k4" -> false)), - datetimeMap = ofMap(Map("k5" -> Instant.ofEpochMilli(321L))), - durationMap = ofMap(Map("k6" -> Duration.ofSeconds(543, 21))) + stringMap = SdkBindingDataFactory.of(Map("k1" -> "foo")), + integerMap = SdkBindingDataFactory.of(Map("k2" -> 321L)), + floatMap = SdkBindingDataFactory.of(Map("k3" -> 3.14)), + booleanMap = SdkBindingDataFactory.of(Map("k4" -> false)), + datetimeMap = + SdkBindingDataFactory.of(Map("k5" -> Instant.ofEpochMilli(321L))), + durationMap = + SdkBindingDataFactory.of(Map("k6" -> Duration.ofSeconds(543, 21))) ) val output = SdkScalaType[MapInput].fromLiteralMap( @@ -355,7 +360,7 @@ class SdkScalaTypeTest { def testRoundTripFromAndToCaseClassWithListsOfMaps(): Unit = { val input = ComplexInput( - metadataList = ofCollection( + metadataList = SdkBindingDataFactory.of( List( Map("Frodo" -> "Baggins", "Sam" -> "Gamgee"), Map("Clark" -> "Kent", "Loise" -> "Lane") @@ -372,19 +377,21 @@ class SdkScalaTypeTest { @Test def testUseAutoValueAttrIntoScalaClass(): Unit = { - import org.flyte.flytekit.SdkBindingDataConverters._ + import SdkBindingDataConverters._ val input = AutoAllInputsInput.create( - SdkBindingData.ofInteger(2L), - SdkBindingData.ofFloat(2.0), - SdkBindingData.ofString("hello"), - SdkBindingData.ofBoolean(true), - SdkBindingData.ofDatetime(Instant.parse("2023-01-01T00:00:00Z")), - SdkBindingData.ofDuration(Duration.ZERO), - SdkBindingData.ofStringCollection(List("1", "2", "3").asJava), - SdkBindingData.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), - SdkBindingData.ofStringCollection(List.empty[String].asJava), - SdkBindingData.ofIntegerMap(Map.empty[String, java.lang.Long].asJava) + SdkJavaBindingDataFactory.of(2L), + SdkJavaBindingDataFactory.of(2.0), + SdkJavaBindingDataFactory.of("hello"), + SdkJavaBindingDataFactory.of(true), + SdkJavaBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), + SdkJavaBindingDataFactory.of(Duration.ZERO), + SdkJavaBindingDataFactory.ofStringCollection(List("1", "2", "3").asJava), + SdkJavaBindingDataFactory.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), + SdkJavaBindingDataFactory.ofStringCollection(List.empty[String].asJava), + SdkJavaBindingDataFactory.ofIntegerMap( + Map.empty[String, java.lang.Long].asJava + ) ) case class AutoAllInputsInputScala( @@ -414,16 +421,16 @@ class SdkScalaTypeTest { ) val expected = AutoAllInputsInputScala( - ofInteger(2L), - ofFloat(2.0), - ofString("hello"), - ofBoolean(true), - ofDateTime(Instant.parse("2023-01-01T00:00:00Z")), - ofDuration(Duration.ZERO), - ofCollection(List("1", "2", "3")), - ofMap(Map("a" -> "2", "b" -> "3")), - ofStringCollection(List.empty[String]), - ofIntegerMap(Map.empty[String, Long]) + SdkBindingDataFactory.of(2L), + SdkBindingDataFactory.of(2.0), + SdkBindingDataFactory.of("hello"), + SdkBindingDataFactory.of(true), + SdkBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), + SdkBindingDataFactory.of(Duration.ZERO), + SdkBindingDataFactory.of(List("1", "2", "3")), + SdkBindingDataFactory.of(Map("a" -> "2", "b" -> "3")), + SdkBindingDataFactory.ofStringCollection(List.empty[String]), + SdkBindingDataFactory.ofIntegerMap(Map.empty[String, Long]) ) assertEquals(expected, scalaClass) @@ -432,27 +439,19 @@ class SdkScalaTypeTest { @Test def testEmptyCollection(): Unit = { - val emptyList = ofStringCollection(List.empty[String]) - val expected = SdkBindingData.create( - BindingData.ofCollection(List.empty[BindingData].asJava), - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.STRING)), - List.empty[String] - ) + val emptyList = SdkBindingDataFactory.ofStringCollection(List.empty[String]) + val expected = + SdkBindingData.literal(collections(strings()), List.empty[String]) assertEquals(emptyList, expected) } @Test def testEmptyMap(): Unit = { - val emptyMap = ofStringMap(Map.empty[String, String]) - val expected = SdkBindingData.create( - BindingData.ofMap(Map.empty[String, BindingData].asJava), - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.STRING)), - Map.empty[String, String] - ) + val emptyMap = SdkBindingDataFactory.ofStringMap(Map.empty[String, String]) + val expected = + SdkBindingData.literal(maps(strings()), Map.empty[String, String]) assertEquals(emptyMap, expected) } - - // Typed[String] doesn't compile aka illtyped } diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekit/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekit/SdkBindingDataConverters.scala deleted file mode 100644 index 5169753bc..000000000 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekit/SdkBindingDataConverters.scala +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright 2021 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit - -import scala.collection.JavaConverters._ - -/** The [[SdkBindingDataConverters]] allows you to do java <-> scala conversions - * for [[SdkBindingData]] - */ -object SdkBindingDataConverters { - - /** Transform from java.lang.Long to scala Long. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toScalaLong( - sdkBindingData: SdkBindingData[java.lang.Long] - ): SdkBindingData[Long] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from scala Long to java.lang.Long. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toJavaLong( - sdkBindingData: SdkBindingData[Long] - ): SdkBindingData[java.lang.Long] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from java.lang.Boolean to scala Boolean. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toScalaBoolean( - sdkBindingData: SdkBindingData[java.lang.Boolean] - ): SdkBindingData[Boolean] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from scala Boolean to java.lang.Boolean. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toJavaBoolean( - sdkBindingData: SdkBindingData[Boolean] - ): SdkBindingData[java.lang.Boolean] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from scala Double to java.lang.Double. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toScalaDouble( - sdkBindingData: SdkBindingData[java.lang.Double] - ): SdkBindingData[Double] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from scala Double to java.lang.Double. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toJavaDouble( - sdkBindingData: SdkBindingData[Double] - ): SdkBindingData[java.lang.Double] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value() - ) - } - - /** Transform from java.util.List to scala List. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toScalaList[K, T]( - sdkBindingData: SdkBindingData[java.util.List[K]] - ): SdkBindingData[List[T]] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value().asScala.map(_.asInstanceOf[T]).toList - ) - } - - /** Transform from scala List to java.util.List. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toJavaList[K, T]( - sdkBindingData: SdkBindingData[List[K]] - ): SdkBindingData[java.util.List[T]] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value().toList.map(_.asInstanceOf[T]).asJava - ) - } - - /** Transform from scala Map to java.util.Map. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toScalaMap[K, T]( - sdkBindingData: SdkBindingData[java.util.Map[String, K]] - ): SdkBindingData[Map[String, T]] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value().asScala.mapValues(_.asInstanceOf[T]).toMap - ) - } - - /** Transform from scala Map to java.util.Map. - * - * @param sdkBindingData - * the value to transform - * @return - * the value transformed. - */ - def toJavaMap[K, T]( - sdkBindingData: SdkBindingData[Map[String, K]] - ): SdkBindingData[java.util.Map[String, T]] = { - SdkBindingData.create( - sdkBindingData.idl(), - sdkBindingData.`type`(), - sdkBindingData.value().mapValues(_.asInstanceOf[T]).toMap.asJava - ) - } - -} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingData.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingData.scala deleted file mode 100644 index 5ea83b45b..000000000 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingData.scala +++ /dev/null @@ -1,502 +0,0 @@ -/* - * Copyright 2021 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekitscala - -import org.flyte.api.v1.{ - BindingData, - LiteralType, - Primitive, - Scalar, - SimpleType -} -import org.flyte.flytekit.{SdkBindingData => SdkJavaBindinigData} - -import java.time.{Duration, Instant} -import scala.collection.JavaConverters._ - -/** Utility to create [[SdkBindingData]] using scala raw types. - */ -object SdkBindingData { - - /** Creates a [[SdkBindingData]] for a flyte string ([[String]] for scala) - * with the given value. - * - * @param string - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofString(string: String): SdkJavaBindinigData[String] = - createSdkBindingData(string) - - /** Creates a [[SdkBindingData]] for a flyte integer ([[Long]] for scala) with - * the given value. - * - * @param long - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofInteger(long: Long): SdkJavaBindinigData[Long] = - createSdkBindingData(long) - - /** Creates a [[SdkBindingData]] for a flyte float ([[Double]] for scala) with - * the given value. - * - * @param double - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofFloat(double: Double): SdkJavaBindinigData[Double] = - createSdkBindingData(double) - - /** Creates a [[SdkBindingData]] for a flyte boolean ([[Boolean]] for scala) - * with the given value. - * - * @param boolean - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofBoolean( - boolean: Boolean - ): SdkJavaBindinigData[Boolean] = - createSdkBindingData(boolean) - - /** Creates a [[SdkBindingData]] for a flyte instant ([[Instant]] for scala) - * with the given value. - * - * @param instant - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofDateTime(instant: Instant): SdkJavaBindinigData[Instant] = - createSdkBindingData(instant) - - /** Creates a [[SdkBindingData]] for a flyte duration ([[Duration]] for scala) - * with the given value. - * - * @param duration - * the simple value for this data - * @return - * the new {[[SdkBindingData]] - */ - def ofDuration( - duration: Duration - ): SdkJavaBindinigData[Duration] = createSdkBindingData(duration) - - /** Creates a [[SdkBindingData]] for a flyte collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofCollection[T]( - collection: List[T] - ): SdkJavaBindinigData[List[T]] = createSdkBindingData(collection) - - /** Creates a [[SdkBindingData]] for a flyte collection given a scala - * [[List]]. - * - * @param literalType - * literal type for the whole collection. It must be a - * [[LiteralType.Kind.COLLECTION_TYPE]]. - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofCollection[T]( - literalType: LiteralType, - collection: List[T] - ): SdkJavaBindinigData[List[T]] = - createSdkBindingData(collection, Option(literalType)) - - /** Creates a [[SdkBindingData]] for a flyte string collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofStringCollection( - collection: List[String] - ): SdkJavaBindinigData[List[String]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType( - LiteralType.ofSimpleType(SimpleType.STRING) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte integer collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofIntegerCollection( - collection: List[Long] - ): SdkJavaBindinigData[List[Long]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType( - LiteralType.ofSimpleType(SimpleType.INTEGER) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte boolean collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofBooleanCollection( - collection: List[Boolean] - ): SdkJavaBindinigData[List[Boolean]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType( - LiteralType.ofSimpleType(SimpleType.BOOLEAN) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte float collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofFloatCollection( - collection: List[Double] - ): SdkJavaBindinigData[List[Double]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.FLOAT)) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte datetime collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofInstantCollection( - collection: List[Instant] - ): SdkJavaBindinigData[List[Instant]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType( - LiteralType.ofSimpleType(SimpleType.DATETIME) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte duration collection given a scala - * [[List]]. - * - * @param collection - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofDurationCollection( - collection: List[Duration] - ): SdkJavaBindinigData[List[Duration]] = - createSdkBindingData( - collection, - Option( - LiteralType.ofCollectionType( - LiteralType.ofSimpleType(SimpleType.DURATION) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte map given a scala [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofMap[T]( - map: Map[String, T] - ): SdkJavaBindinigData[Map[String, T]] = createSdkBindingData(map) - - /** Creates a [[SdkBindingData]] for a flyte string map given a scala [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofStringMap( - map: Map[String, String] - ): SdkJavaBindinigData[Map[String, String]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.STRING)) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte long map given a scala [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofIntegerMap( - map: Map[String, Long] - ): SdkJavaBindinigData[Map[String, Long]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.INTEGER)) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte boolean map given a scala - * [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofBooleanMap( - map: Map[String, Boolean] - ): SdkJavaBindinigData[Map[String, Boolean]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.BOOLEAN)) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte double map given a scala [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofFloatMap( - map: Map[String, Double] - ): SdkJavaBindinigData[Map[String, Double]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.FLOAT)) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte instant map given a scala - * [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofInstantMap( - map: Map[String, Instant] - ): SdkJavaBindinigData[Map[String, Instant]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType( - LiteralType.ofSimpleType(SimpleType.DATETIME) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte duration map given a scala - * [[Map]]. - * - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofDurationMap( - map: Map[String, Duration] - ): SdkJavaBindinigData[Map[String, Duration]] = - createSdkBindingData( - map, - Option( - LiteralType.ofMapValueType( - LiteralType.ofSimpleType(SimpleType.DURATION) - ) - ) - ) - - /** Creates a [[SdkBindingData]] for a flyte duration map given a scala - * [[Map]]. - * - * @param literalType - * literal type for the whole collection. It must be a - * [[LiteralType.Kind.MAP_VALUE_TYPE]]. - * @param map - * collection to represent on this data. - * @return - * the new [[SdkBindingData]] - */ - def ofMap[T]( - literalType: LiteralType, - map: Map[String, T] - ): SdkJavaBindinigData[Map[String, T]] = - createSdkBindingData(map, Option(literalType)) - - private def toBindingData( - value: Any, - literalTypeOpt: Option[LiteralType] - ): (BindingData, LiteralType) = { - value match { - case string: String => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofStringValue(string)) - ), - LiteralType.ofSimpleType(SimpleType.STRING) - ) - case boolean: Boolean => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofBooleanValue(boolean)) - ), - LiteralType.ofSimpleType(SimpleType.BOOLEAN) - ) - case long: Long => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofIntegerValue(long)) - ), - LiteralType.ofSimpleType(SimpleType.INTEGER) - ) - case double: Double => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofFloatValue(double)) - ), - LiteralType.ofSimpleType(SimpleType.FLOAT) - ) - case instant: Instant => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofDatetime(instant)) - ), - LiteralType.ofSimpleType(SimpleType.DATETIME) - ) - case duration: Duration => - ( - BindingData.ofScalar( - Scalar.ofPrimitive(Primitive.ofDuration(duration)) - ), - LiteralType.ofSimpleType(SimpleType.DURATION) - ) - case list: Seq[_] => - val literalType = literalTypeOpt.getOrElse { - val (_, innerLiteralType) = toBindingData( - list.headOption.getOrElse( - throw new RuntimeException( - "Can't create binding for an empty list without knowing the type, use SdkBindingData.ofCollection(...)" - ) - ), - literalTypeOpt = None - ) - - LiteralType.ofCollectionType(innerLiteralType) - } - - ( - BindingData.ofCollection( - list - .map { innerValue => - val (bindingData, _) = toBindingData(innerValue, literalTypeOpt) - bindingData - } - .toList - .asJava - ), - literalType - ) - case map: Map[String, _] => - val literalType = literalTypeOpt.getOrElse { - val (_, innerLiteralType) = toBindingData( - map.headOption - .map(_._2) - .getOrElse( - throw new RuntimeException( - "Can't create binding for an empty map without knowing the type, use SdkBindingData.ofMap(...)" - ) - ), - literalTypeOpt = None - ) - - LiteralType.ofMapValueType(innerLiteralType) - } - ( - BindingData.ofMap( - map - .mapValues { innerValue => - val (bindingData, _) = toBindingData(innerValue, literalTypeOpt) - bindingData - } - .toMap - .asJava - ), - literalType - ) - case other => - throw new IllegalStateException( - s"${other.getClass.getSimpleName} class is not supported as SdkBindingData inner class" - ) - } - } - - private def createSdkBindingData[T]( - value: T, - literalTypeOpt: Option[LiteralType] = None - ): SdkJavaBindinigData[T] = { - val (bindingData, literalType) = toBindingData(value, literalTypeOpt) - SdkJavaBindinigData.create(bindingData, literalType, value) - } -} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala new file mode 100644 index 000000000..694cb0c4b --- /dev/null +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala @@ -0,0 +1,328 @@ +/* + * Copyright 2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekitscala + +import org.flyte.api.v1.{LiteralType, SimpleType} +import org.flyte.flytekit.{ + SdkBindingData, + SdkLiteralType, + SdkLiteralTypes => SdkJavaLiteralTypes +} +import org.flyte.flytekitscala.{SdkLiteralTypes => SdkScalaLiteralTypes} + +import java.{lang => j} +import java.{util => ju} +import java.util.{function => jf} +import scala.collection.JavaConverters._ + +/** The [[SdkBindingDataConverters]] allows you to do java <-> scala conversions + * for [[SdkBindingData]] + */ +object SdkBindingDataConverters { + + /** Transform from java.lang.Long to scala Long. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toScalaLong( + sdkBindingData: SdkBindingData[j.Long] + ): SdkBindingData[Long] = { + sdkBindingData.as(SdkScalaLiteralTypes.integers(), l => l) + } + + /** Transform from scala Long to java.lang.Long. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toJavaLong( + sdkBindingData: SdkBindingData[Long] + ): SdkBindingData[j.Long] = { + sdkBindingData.as(SdkJavaLiteralTypes.integers(), l => l) + } + + /** Transform from java.lang.Boolean to scala Boolean. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toScalaBoolean( + sdkBindingData: SdkBindingData[j.Boolean] + ): SdkBindingData[Boolean] = { + sdkBindingData.as(SdkScalaLiteralTypes.booleans(), b => b) + } + + /** Transform from scala Boolean to java.lang.Boolean. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toJavaBoolean( + sdkBindingData: SdkBindingData[Boolean] + ): SdkBindingData[j.Boolean] = { + sdkBindingData.as(SdkJavaLiteralTypes.booleans(), b => b) + } + + /** Transform from scala Double to java.lang.Double. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toScalaDouble( + sdkBindingData: SdkBindingData[j.Double] + ): SdkBindingData[Double] = { + sdkBindingData.as(SdkScalaLiteralTypes.floats(), f => f) + } + + /** Transform from scala Double to java.lang.Double. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toJavaDouble( + sdkBindingData: SdkBindingData[Double] + ): SdkBindingData[j.Double] = { + sdkBindingData.as(SdkJavaLiteralTypes.floats(), f => f) + } + + private case class TypeCastingResult( + convertedType: SdkLiteralType[_], + convFunction: jf.Function[Any, Any] + ) + + /** Transform from java.util.List to scala List. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toScalaList[JavaT, ScalaT]( + sdkBindingData: SdkBindingData[java.util.List[JavaT]] + ): SdkBindingData[List[ScalaT]] = { + val result = toScalaType(sdkBindingData.`type`().getLiteralType) + val elementType = + result.convertedType.asInstanceOf[SdkLiteralType[List[ScalaT]]] + val value = result.convFunction + .asInstanceOf[jf.Function[ju.List[JavaT], List[ScalaT]]] + + sdkBindingData.as(elementType, value) + } + + private def toScalaType(lt: LiteralType): TypeCastingResult = { + lt.getKind match { + case LiteralType.Kind.SIMPLE_TYPE => + lt.simpleType() match { + case SimpleType.FLOAT => + TypeCastingResult( + SdkScalaLiteralTypes.floats(), + (f: Any) => Double.unbox(f.asInstanceOf[j.Double]) + ) + case SimpleType.STRING => + TypeCastingResult( + SdkScalaLiteralTypes.strings(), + jf.Function.identity() + ) + case SimpleType.STRUCT => ??? // TODO not yet supported + case SimpleType.BOOLEAN => + TypeCastingResult( + SdkScalaLiteralTypes.booleans(), + (b: Any) => Boolean.unbox(b.asInstanceOf[j.Boolean]) + ) + case SimpleType.INTEGER => + TypeCastingResult( + SdkScalaLiteralTypes.integers(), + (i: Any) => Long.unbox(i.asInstanceOf[j.Long]) + ) + case SimpleType.DATETIME => + TypeCastingResult( + SdkScalaLiteralTypes.datetimes(), + jf.Function.identity() + ) + case SimpleType.DURATION => + TypeCastingResult( + SdkScalaLiteralTypes.durations(), + jf.Function.identity() + ) + } + case LiteralType.Kind.BLOB_TYPE => ??? // TODO not yet supported + case LiteralType.Kind.SCHEMA_TYPE => ??? // TODO not yet supported + case LiteralType.Kind.COLLECTION_TYPE => + val TypeCastingResult(convertedElementType, convFunction) = toScalaType( + lt.collectionType() + ) + TypeCastingResult( + SdkScalaLiteralTypes.collections(convertedElementType), + (l: Any) => + l.asInstanceOf[ju.List[_]] + .asScala + .map(e => convFunction.apply(e)) + .toList + ) + case LiteralType.Kind.MAP_VALUE_TYPE => + val TypeCastingResult(convertedElementType, convFunction) = toScalaType( + lt.mapValueType() + ) + TypeCastingResult( + SdkScalaLiteralTypes.maps(convertedElementType), + (m: Any) => + m.asInstanceOf[ju.Map[String, _]] + .asScala + .mapValues(e => convFunction.apply(e)) + .toMap + ) + } + } + + /** Transform from scala List to java.util.List. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toJavaList[ScalaT, JavaT]( + sdkBindingData: SdkBindingData[List[ScalaT]] + ): SdkBindingData[ju.List[JavaT]] = { + val result = toJavaType(sdkBindingData.`type`().getLiteralType) + val elementType = + result.convertedType.asInstanceOf[SdkLiteralType[ju.List[JavaT]]] + val value = result.convFunction + .asInstanceOf[jf.Function[List[ScalaT], ju.List[JavaT]]] + + sdkBindingData.as(elementType, value) + } + + private def toJavaType(lt: LiteralType): TypeCastingResult = { + lt.getKind match { + case LiteralType.Kind.SIMPLE_TYPE => + lt.simpleType() match { + case SimpleType.FLOAT => + TypeCastingResult( + SdkJavaLiteralTypes.floats(), + (f: Any) => j.Double.valueOf(f.asInstanceOf[Double]) + ) + case SimpleType.STRING => + TypeCastingResult( + SdkJavaLiteralTypes.strings(), + jf.Function.identity() + ) + case SimpleType.STRUCT => + ??? // TODO how to handle? do we support structs already? + case SimpleType.BOOLEAN => + TypeCastingResult( + SdkJavaLiteralTypes.booleans(), + (b: Any) => j.Boolean.valueOf(b.asInstanceOf[Boolean]) + ) + case SimpleType.INTEGER => + TypeCastingResult( + SdkJavaLiteralTypes.integers(), + (i: Any) => j.Long.valueOf(i.asInstanceOf[Long]) + ) + case SimpleType.DATETIME => + TypeCastingResult( + SdkJavaLiteralTypes.datetimes(), + jf.Function.identity() + ) + case SimpleType.DURATION => + TypeCastingResult( + SdkJavaLiteralTypes.durations(), + jf.Function.identity() + ) + } + case LiteralType.Kind.BLOB_TYPE => ??? // TODO do we support blob? + case LiteralType.Kind.SCHEMA_TYPE => + ??? // TODO do we support schema type? + case LiteralType.Kind.COLLECTION_TYPE => + val TypeCastingResult(convertedElementType, convFunction) = toJavaType( + lt.collectionType() + ) + TypeCastingResult( + SdkJavaLiteralTypes.collections(convertedElementType), + (l: Any) => + l.asInstanceOf[List[_]].map(e => convFunction.apply(e)).asJava + ) + case LiteralType.Kind.MAP_VALUE_TYPE => + val TypeCastingResult(convertedElementType, convFunction) = toJavaType( + lt.mapValueType() + ) + TypeCastingResult( + SdkJavaLiteralTypes.maps(convertedElementType), + (m: Any) => + m.asInstanceOf[Map[String, _]] + .mapValues(e => convFunction.apply(e)) + .toMap + .asJava + ) + } + } + + /** Transform from scala Map to java.util.Map. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toScalaMap[JavaT, ScalaT]( + sdkBindingData: SdkBindingData[java.util.Map[String, JavaT]] + ): SdkBindingData[Map[String, ScalaT]] = { + val literalType = toScalaType(sdkBindingData.`type`().getLiteralType) + val elementType = + literalType.convertedType.asInstanceOf[SdkLiteralType[ScalaT]] + val value = literalType.convFunction + .asInstanceOf[jf.Function[ju.Map[String, JavaT], Map[String, ScalaT]]] + sdkBindingData.as( + elementType.asInstanceOf[SdkLiteralType[Map[String, ScalaT]]], + value + ) + } + + /** Transform from scala Map to java.util.Map. + * + * @param sdkBindingData + * the value to transform + * @return + * the value transformed. + */ + def toJavaMap[ScalaT, JavaT]( + sdkBindingData: SdkBindingData[Map[String, ScalaT]] + ): SdkBindingData[java.util.Map[String, JavaT]] = { + val literalType = toJavaType(sdkBindingData.`type`().getLiteralType) + val elementType = + literalType.convertedType.asInstanceOf[SdkLiteralType[JavaT]] + val value = literalType.convFunction + .asInstanceOf[jf.Function[Map[String, ScalaT], ju.Map[String, JavaT]]] + sdkBindingData.as( + elementType.asInstanceOf[SdkLiteralType[ju.Map[String, JavaT]]], + value + ) + } +} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala new file mode 100644 index 000000000..1eb463e75 --- /dev/null +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala @@ -0,0 +1,372 @@ +/* + * Copyright 2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekitscala + +import org.flyte.flytekit.{SdkBindingData, SdkLiteralType} +import org.flyte.flytekitscala.SdkLiteralTypes._ + +import java.time.{Duration, Instant} + +/** Utility to create [[SdkBindingData]] using scala raw types. + */ +object SdkBindingDataFactory { + + /** Creates a [[SdkBindingData]] for a flyte string ([[String]] for scala) + * with the given value. + * + * @param string + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of(string: String): SdkBindingData[String] = + SdkBindingData.literal(strings(), string) + + /** Creates a [[SdkBindingData]] for a flyte integer ([[Long]] for scala) with + * the given value. + * + * @param long + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of(long: Long): SdkBindingData[Long] = + SdkBindingData.literal(integers(), long) + + /** Creates a [[SdkBindingData]] for a flyte float ([[Double]] for scala) with + * the given value. + * + * @param double + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of(double: Double): SdkBindingData[Double] = + SdkBindingData.literal(floats(), double) + + /** Creates a [[SdkBindingData]] for a flyte boolean ([[Boolean]] for scala) + * with the given value. + * + * @param boolean + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of( + boolean: Boolean + ): SdkBindingData[Boolean] = + SdkBindingData.literal(booleans(), boolean) + + /** Creates a [[SdkBindingData]] for a flyte instant ([[Instant]] for scala) + * with the given value. + * + * @param instant + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of(instant: Instant): SdkBindingData[Instant] = + SdkBindingData.literal(datetimes(), instant) + + /** Creates a [[SdkBindingData]] for a flyte duration ([[Duration]] for scala) + * with the given value. + * + * @param duration + * the simple value for this data + * @return + * the new {[[SdkBindingData]] + */ + def of( + duration: Duration + ): SdkBindingData[Duration] = SdkBindingData.literal(durations(), duration) + + /** Creates a [[SdkBindingData]] for a flyte collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def of[T]( + collection: List[T] + ): SdkBindingData[List[T]] = + SdkBindingData.literal( + toSdkLiteralType(collection).asInstanceOf[SdkLiteralType[List[T]]], + collection + ) + + /** Creates a [[SdkBindingData]] for a flyte collection given a scala + * [[List]]. + * + * @param elementLiteralType + * [[SdkLiteralType]] for elements of collection. + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def of[T]( + elementLiteralType: SdkLiteralType[T], + collection: List[T] + ): SdkBindingData[List[T]] = + SdkBindingData.literal( + collections(elementLiteralType), + collection + ) + + /** Creates a [[SdkBindingDataFactory]] for a flyte string collection given a + * scala [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofStringCollection( + collection: List[String] + ): SdkBindingData[List[String]] = + SdkBindingData.literal(collections(strings()), collection) + + /** Creates a [[SdkBindingData]] for a flyte integer collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingDataFactory]] + */ + def ofIntegerCollection( + collection: List[Long] + ): SdkBindingData[List[Long]] = + SdkBindingData.literal(collections(integers()), collection) + + /** Creates a [[SdkBindingData]] for a flyte boolean collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofBooleanCollection( + collection: List[Boolean] + ): SdkBindingData[List[Boolean]] = + SdkBindingData.literal(collections(booleans()), collection) + + /** Creates a [[SdkBindingData]] for a flyte float collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofFloatCollection( + collection: List[Double] + ): SdkBindingData[List[Double]] = + SdkBindingData.literal(collections(floats()), collection) + + /** Creates a [[SdkBindingData]] for a flyte datetime collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofDatetimeCollection( + collection: List[Instant] + ): SdkBindingData[List[Instant]] = + SdkBindingData.literal(collections(datetimes()), collection) + + /** Creates a [[SdkBindingData]] for a flyte duration collection given a scala + * [[List]]. + * + * @param collection + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofDurationCollection( + collection: List[Duration] + ): SdkBindingData[List[Duration]] = + SdkBindingData.literal(collections(durations()), collection) + + /** Creates a [[SdkBindingData]] for a flyte map given a scala [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def of[T](map: Map[String, T]): SdkBindingData[Map[String, T]] = + SdkBindingData.literal( + toSdkLiteralType(map).asInstanceOf[SdkLiteralType[Map[String, T]]], + map + ) + + /** Creates a [[SdkBindingData]] for a flyte string map given a scala [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingDataFactory]] + */ + def ofStringMap( + map: Map[String, String] + ): SdkBindingData[Map[String, String]] = + SdkBindingData.literal(maps(strings()), map) + + /** Creates a [[SdkBindingData]] for a flyte long map given a scala [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofIntegerMap(map: Map[String, Long]): SdkBindingData[Map[String, Long]] = + SdkBindingData.literal(maps(integers()), map) + + /** Creates a [[SdkBindingData]] for a flyte boolean map given a scala + * [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofBooleanMap( + map: Map[String, Boolean] + ): SdkBindingData[Map[String, Boolean]] = + SdkBindingData.literal(maps(booleans()), map) + + /** Creates a [[SdkBindingData]] for a flyte double map given a scala [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofFloatMap( + map: Map[String, Double] + ): SdkBindingData[Map[String, Double]] = + SdkBindingData.literal(maps(floats()), map) + + /** Creates a [[SdkBindingData]] for a flyte instant map given a scala + * [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofDatetimeMap( + map: Map[String, Instant] + ): SdkBindingData[Map[String, Instant]] = + SdkBindingData.literal(maps(datetimes()), map) + + /** Creates a [[SdkBindingData]] for a flyte duration map given a scala + * [[Map]]. + * + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def ofDurationMap( + map: Map[String, Duration] + ): SdkBindingData[Map[String, Duration]] = + SdkBindingData.literal(maps(durations()), map) + + /** Creates a [[SdkBindingData]] for a flyte duration map given a scala + * [[Map]]. + * + * @param valuesLiteralType + * [[SdkLiteralType]] type for the values of the map. + * @param map + * collection to represent on this data. + * @return + * the new [[SdkBindingData]] + */ + def of[T]( + valuesLiteralType: SdkLiteralType[T], + map: Map[String, T] + ): SdkBindingData[Map[String, T]] = + SdkBindingData.literal(maps(valuesLiteralType), map) + + private def toSdkLiteralType( + value: Any, + internalTypeOpt: Option[SdkLiteralType[_]] = Option.empty + ): SdkLiteralType[_] = { + value match { + case string: String => + strings() + case boolean: Boolean => + booleans() + case long: Long => + integers() + + case double: Double => + floats() + + case instant: Instant => + datetimes() + + case duration: Duration => + durations() + + case list: Seq[_] => + val internalType = internalTypeOpt.getOrElse { + toSdkLiteralType( + list.headOption.getOrElse( + throw new RuntimeException( + // TODO: check the error comment once we have settle with the name + "Can't create binding for an empty list without knowing the type, use SdkBindingData.ofCollection(...)" + ) + ) + ) + + } + collections(internalType) + + case map: Map[_, _] => + val internalType = internalTypeOpt.getOrElse { + val head = map.headOption.getOrElse( + throw new RuntimeException( + // TODO: check the error comment once we have settle with the name + "Can't create binding for an empty map without knowing the type, use SdkBindingData.ofMap(...)" + ) + ) + head._1 match { + case _: String => toSdkLiteralType(head._2) + case _ => + throw new RuntimeException( + "Can't create binding for a map with key type other than String." + ) + } + } + maps(internalType) + + case other => + throw new IllegalStateException( + s"${other.getClass.getSimpleName} class is not supported as SdkBindingData inner class" + ) + } + } + +} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala new file mode 100644 index 000000000..df34b0216 --- /dev/null +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -0,0 +1,283 @@ +/* + * Copyright 2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekitscala + +import org.flyte.api.v1._ +import org.flyte.flytekit.{ + SdkLiteralType, + SdkLiteralTypes => SdkJavaLiteralTypes +} + +import java.time.{Duration, Instant} +import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe.{TypeTag, typeOf} + +object SdkLiteralTypes { + + /** [[SdkLiteralType]] for the specified Scala type. + * + * | Scala type | Returned type | + * |:-------------|:--------------------------------------------------------------| + * | [[Long]] | {{{SdkLiteralType[Long]}}}, equivalent to [[integers()]] | + * | [[Double]] | {{{SdkLiteralType[Double]}}}, equivalent to [[floats()]] | + * | [[String]] | {{{SdkLiteralType[String]}}}, equivalent to [[strings]] | + * | [[Boolean]] | {{{SdkLiteralType[Boolean]}}}, equivalent to [[booleans()]] | + * | [[Instant]] | {{{SdkLiteralType[Instant]}}}, equivalent to [[datetimes()]] | + * | [[Duration]] | {{{SdkLiteralType[Duration]}}}, equivalent to [[durations()]] | + * @tparam T + * Scala type used to decide what [[SdkLiteralType]] to return. + * @return + * the [[SdkLiteralType]] based on the java type + */ + def of[T: TypeTag](): SdkLiteralType[T] = { + typeOf[T] match { + case t if t =:= typeOf[Long] => integers().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Double] => floats().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[String] => + strings().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Boolean] => + booleans().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Instant] => + datetimes().asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Duration] => + durations().asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[List[Long]] => + collections(integers()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Double]] => + collections(floats()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[String]] => + collections(strings()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Boolean]] => + collections(booleans()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Instant]] => + collections(datetimes()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Duration]] => + collections(durations()).asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[Map[String, Long]] => + maps(integers()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Double]] => + maps(floats()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, String]] => + maps(strings()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Boolean]] => + maps(booleans()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Instant]] => + maps(datetimes()).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Duration]] => + maps(durations()).asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[List[List[Long]]] => + collections(collections(integers())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[List[Double]]] => + collections(collections(floats())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[List[String]]] => + collections(collections(strings())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[List[Boolean]]] => + collections(collections(booleans())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[List[Instant]]] => + collections(collections(datetimes())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[List[Duration]]] => + collections(collections(durations())).asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[List[Map[String, Long]]] => + collections(maps(integers())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Map[String, Double]]] => + collections(maps(floats())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Map[String, String]]] => + collections(maps(strings())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Map[String, Boolean]]] => + collections(maps(booleans())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Map[String, Instant]]] => + collections(maps(datetimes())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[List[Map[String, Duration]]] => + collections(maps(durations())).asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[Map[String, Map[String, Long]]] => + maps(maps(integers())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Map[String, Double]]] => + maps(maps(floats())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Map[String, String]]] => + maps(maps(strings())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Map[String, Boolean]]] => + maps(maps(booleans())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Map[String, Instant]]] => + maps(maps(datetimes())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, Map[String, Duration]]] => + maps(maps(durations())).asInstanceOf[SdkLiteralType[T]] + + case t if t =:= typeOf[Map[String, List[Long]]] => + maps(collections(integers())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, List[Double]]] => + maps(collections(floats())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, List[String]]] => + maps(collections(strings())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, List[Boolean]]] => + maps(collections(booleans())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, List[Instant]]] => + maps(collections(datetimes())).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Map[String, List[Duration]]] => + maps(collections(durations())).asInstanceOf[SdkLiteralType[T]] + + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${typeOf[T]}") + } + } + + /** Returns a [[SdkLiteralType]] for flyte integers. + * + * @return + * the [[SdkLiteralType]] + */ + def integers(): SdkLiteralType[Long] = ScalaLiteralType[Long]( + LiteralType.ofSimpleType(SimpleType.INTEGER), + value => + Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(value))), + _.scalar().primitive().integerValue(), + v => BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(v))), + "integers" + ) + + /** Returns a [[SdkLiteralType]] for flyte floats. + * + * @return + * the [[SdkLiteralType]] + */ + def floats(): SdkLiteralType[Double] = ScalaLiteralType[Double]( + LiteralType.ofSimpleType(SimpleType.FLOAT), + value => + Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofFloatValue(value))), + _.scalar().primitive().floatValue(), + v => BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofFloatValue(v))), + "floats" + ) + + /** Returns a [[SdkLiteralType]] for string. + * + * @return + * the [[SdkLiteralType]] + */ + def strings(): SdkLiteralType[String] = SdkJavaLiteralTypes.strings() + + /** Returns a [[SdkLiteralType]] for booleans. + * + * @return + * the [[SdkLiteralType]] + */ + def booleans(): SdkLiteralType[Boolean] = ScalaLiteralType[Boolean]( + LiteralType.ofSimpleType(SimpleType.BOOLEAN), + value => + Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofBooleanValue(value))), + _.scalar().primitive().booleanValue(), + v => BindingData.ofScalar(Scalar.ofPrimitive(Primitive.ofBooleanValue(v))), + "booleans" + ) + + /** Returns a [[SdkLiteralType]] for flyte date times. + * + * @return + * the [[SdkLiteralType]] + */ + def datetimes(): SdkLiteralType[Instant] = SdkJavaLiteralTypes.datetimes() + + /** Returns a [[SdkLiteralType]] for durations. + * + * @return + * the [[SdkLiteralType]] + */ + def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations() + + /** Returns a [[SdkLiteralType]] for flyte collections. + * + * @param elementType + * the [[SdkLiteralType]] representing the types of the elements of the + * collection. + * @tparam T + * the Scala type of the elements of the collection. + * @return + * the [[SdkLiteralType]] + */ + def collections[T]( + elementType: SdkLiteralType[T] + ): SdkLiteralType[List[T]] = + new SdkLiteralType[List[T]] { + override def getLiteralType: LiteralType = + LiteralType.ofCollectionType(elementType.getLiteralType) + + override def toLiteral(values: List[T]): Literal = + Literal.ofCollection(values.map(elementType.toLiteral).asJava) + + override def fromLiteral(literal: Literal): List[T] = + literal.collection().asScala.map(elementType.fromLiteral).toList + + override def toBindingData(value: List[T]): BindingData = + BindingData.ofCollection(value.map(elementType.toBindingData).asJava) + + override def toString = s"collection of [$elementType]" + } + + /** Returns a [[SdkLiteralType]] for flyte maps. + * + * @param valuesType + * the [[SdkLiteralType]] representing the types of the map's values. + * @tparam T + * the Scala type of the map's values, keys are always string. + * @return + * the [[SdkLiteralType]] + */ + def maps[T](valuesType: SdkLiteralType[T]): SdkLiteralType[Map[String, T]] = + new SdkLiteralType[Map[String, T]] { + override def getLiteralType: LiteralType = + LiteralType.ofMapValueType(valuesType.getLiteralType) + + override def toLiteral(values: Map[String, T]): Literal = + Literal.ofMap(values.mapValues(valuesType.toLiteral).toMap.asJava) + + override def fromLiteral(literal: Literal): Map[String, T] = + literal.map().asScala.mapValues(valuesType.fromLiteral).toMap + + override def toBindingData(value: Map[String, T]): BindingData = { + BindingData.ofMap( + value.mapValues(valuesType.toBindingData).toMap.asJava + ) + } + + override def toString: String = s"map of [$valuesType]" + } +} + +private object ScalaLiteralType { + def apply[T]( + literalType: LiteralType, + to: T => Literal, + from: Literal => T, + toData: T => BindingData, + strRep: String + ): SdkLiteralType[T] = + new SdkLiteralType[T] { + override def getLiteralType: LiteralType = literalType + + override def toLiteral(value: T): Literal = to(value) + + override def fromLiteral(literal: Literal): T = from(literal) + + override def toBindingData(value: T): BindingData = toData(value) + + override def toString: String = strRep + } +} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 09346ad94..3b8b5aaf9 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -20,7 +20,12 @@ import java.time.{Duration, Instant} import java.{util => ju} import magnolia.{CaseClass, Magnolia, Param, SealedTrait} import org.flyte.api.v1._ -import org.flyte.flytekit.{SdkType, SdkBindingData => SdkJavaBindinigData} +import org.flyte.flytekit.{ + SdkBindingData, + SdkLiteralType, + SdkType, + SdkLiteralTypes => SdkJavaLiteralTypes +} import scala.annotation.implicitNotFound import scala.collection.JavaConverters._ @@ -38,27 +43,20 @@ sealed trait SdkScalaType[T] trait SdkScalaProductType[T] extends SdkType[T] with SdkScalaType[T] -trait SdkScalaLiteralType[T] extends SdkScalaType[T] { - def getLiteralType: LiteralType +trait SdkScalaLiteralType[T] extends SdkLiteralType[T] with SdkScalaType[T] - def toLiteral(value: T): Literal +case class DelegateLiteralType[T](delegate: SdkLiteralType[T]) + extends SdkScalaLiteralType[T] { + override def getLiteralType: LiteralType = delegate.getLiteralType - def fromLiteral(literal: Literal): T -} + override def toLiteral(value: T): Literal = delegate.toLiteral(value) -object SdkScalaLiteralType { - def apply[T]( - literalType: LiteralType, - to: T => Literal, - from: Literal => T - ): SdkScalaLiteralType[T] = - new SdkScalaLiteralType[T] { - override def getLiteralType: LiteralType = literalType + override def fromLiteral(literal: Literal): T = delegate.fromLiteral(literal) - override def toLiteral(value: T): Literal = to(value) + override def toBindingData(value: T): BindingData = + delegate.toBindingData(value) - override def fromLiteral(literal: Literal): T = from(literal) - } + override def toString: String = delegate.toString } /** Applied to a case classes fields to denote the description of such field @@ -151,17 +149,17 @@ object SdkScalaType { s"field ${param.label} not found in variable map" ) - SdkJavaBindinigData.ofOutputReference( + SdkBindingData.promise( + param.typeclass, nodeId, - param.label, - paramLiteralType.literalType() + param.label ) }) } override def toSdkBindingMap( value: T - ): ju.Map[String, SdkJavaBindinigData[_]] = { + ): ju.Map[String, SdkBindingData[_]] = { value match { case product: Product => value.getClass.getDeclaredFields @@ -169,7 +167,7 @@ object SdkScalaType { .zip(product.productIterator.toSeq) .toMap .mapValues { - case value: SdkJavaBindinigData[_] => value + case value: SdkBindingData[_] => value case _ => throw new IllegalStateException( s"All the fields of the case class ${value.getClass.getSimpleName} must be SdkBindingData[_]" @@ -183,12 +181,22 @@ object SdkScalaType { ) } } + + override def toLiteralTypes: ju.Map[String, SdkLiteralType[_]] = { + params + .map { case ParamsWithDesc(param, _) => + val value: SdkLiteralType[_] = param.typeclass + param.label -> value + } + .toMap[String, SdkLiteralType[_]] + .asJava + } } } implicit def sdkBindingLiteralType[T](implicit sdkLiteral: SdkScalaLiteralType[T] - ): SdkScalaLiteralType[SdkJavaBindinigData[T]] = { + ): SdkScalaLiteralType[SdkBindingData[T]] = { def toBindingData(literal: Literal): BindingData = { literal.kind() match { @@ -205,62 +213,37 @@ object SdkScalaType { } } - SdkScalaLiteralType[SdkJavaBindinigData[T]]( - sdkLiteral.getLiteralType, - value => sdkLiteral.toLiteral(value.get()), - literal => - SdkJavaBindinigData.create( - toBindingData(literal), - sdkLiteral.getLiteralType, - sdkLiteral.fromLiteral(literal) - ) - ) + new SdkScalaLiteralType[SdkBindingData[T]]() { + override def getLiteralType: LiteralType = sdkLiteral.getLiteralType + + override def toLiteral(value: SdkBindingData[T]): Literal = + sdkLiteral.toLiteral(value.get()) + + override def fromLiteral(literal: Literal): SdkBindingData[T] = + SdkBindingData.literal(sdkLiteral, sdkLiteral.fromLiteral(literal)) + + override def toBindingData(value: SdkBindingData[T]): BindingData = + sdkLiteral.toBindingData(value.get()) + } } implicit def stringLiteralType: SdkScalaLiteralType[String] = - SdkScalaLiteralType[String]( - LiteralType.ofSimpleType(SimpleType.STRING), - value => - Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofStringValue(value))), - _.scalar().primitive().stringValue() - ) + DelegateLiteralType(SdkLiteralTypes.strings()) implicit def longLiteralType: SdkScalaLiteralType[Long] = - SdkScalaLiteralType[Long]( - LiteralType.ofSimpleType(SimpleType.INTEGER), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofInteger(value))), - _.scalar().primitive().integerValue() - ) + DelegateLiteralType(SdkLiteralTypes.integers()) implicit def doubleLiteralType: SdkScalaLiteralType[Double] = - SdkScalaLiteralType[Double]( - LiteralType.ofSimpleType(SimpleType.FLOAT), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofFloat(value))), - literal => literal.scalar().primitive().floatValue() - ) + DelegateLiteralType(SdkLiteralTypes.floats()) implicit def booleanLiteralType: SdkScalaLiteralType[Boolean] = - SdkScalaLiteralType[Boolean]( - LiteralType.ofSimpleType(SimpleType.BOOLEAN), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofBoolean(value))), - _.scalar().primitive().booleanValue() - ) + DelegateLiteralType(SdkLiteralTypes.booleans()) implicit def instantLiteralType: SdkScalaLiteralType[Instant] = - SdkScalaLiteralType[Instant]( - LiteralType.ofSimpleType(SimpleType.DATETIME), - value => - Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofDatetime(value))), - _.scalar().primitive().datetime() - ) + DelegateLiteralType(SdkLiteralTypes.datetimes()) implicit def durationLiteralType: SdkScalaLiteralType[Duration] = - SdkScalaLiteralType[Duration]( - LiteralType.ofSimpleType(SimpleType.DURATION), - value => - Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofDuration(value))), - _.scalar().primitive().duration() - ) + DelegateLiteralType(SdkLiteralTypes.durations()) // TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData // This makes Scala dev mad when they are forced to use the java types instead of scala types @@ -268,75 +251,23 @@ object SdkScalaType { // So java and scala can have their own factory class/companion object using their own native types // In the meantime, we need to duplicate all the literal types to use also the java types implicit def javaLongLiteralType: SdkScalaLiteralType[java.lang.Long] = - SdkScalaLiteralType[java.lang.Long]( - LiteralType.ofSimpleType(SimpleType.INTEGER), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofInteger(value))), - _.scalar().primitive().integerValue() - ) + DelegateLiteralType(SdkJavaLiteralTypes.integers()) implicit def javaDoubleLiteralType: SdkScalaLiteralType[java.lang.Double] = - SdkScalaLiteralType[java.lang.Double]( - LiteralType.ofSimpleType(SimpleType.FLOAT), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofFloat(value))), - literal => literal.scalar().primitive().floatValue() - ) + DelegateLiteralType(SdkJavaLiteralTypes.floats()) implicit def javaBooleanLiteralType: SdkScalaLiteralType[java.lang.Boolean] = - SdkScalaLiteralType[java.lang.Boolean]( - LiteralType.ofSimpleType(SimpleType.BOOLEAN), - value => Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofBoolean(value))), - _.scalar().primitive().booleanValue() - ) + DelegateLiteralType(SdkJavaLiteralTypes.booleans()) implicit def collectionLiteralType[T](implicit sdkLiteral: SdkScalaLiteralType[T] - ): SdkScalaLiteralType[List[T]] = { - new SdkScalaLiteralType[List[T]] { - - override def getLiteralType: LiteralType = - LiteralType.ofCollectionType(sdkLiteral.getLiteralType) - - override def toLiteral(values: List[T]): Literal = { - Literal.ofCollection( - values - .map(value => sdkLiteral.toLiteral(value)) - .asJava - ) - } - - override def fromLiteral(literal: Literal): List[T] = - literal - .collection() - .asScala - .map(elem => sdkLiteral.fromLiteral(elem)) - .toList - } - } + ): SdkScalaLiteralType[List[T]] = + DelegateLiteralType(SdkLiteralTypes.collections(sdkLiteral)) implicit def mapLiteralType[T](implicit sdkLiteral: SdkScalaLiteralType[T] - ): SdkScalaLiteralType[Map[String, T]] = { - new SdkScalaLiteralType[Map[String, T]] { - - override def getLiteralType: LiteralType = - LiteralType.ofMapValueType(sdkLiteral.getLiteralType) - - override def toLiteral(values: Map[String, T]): Literal = { - Literal.ofMap( - values.map { case (key, value) => - key -> sdkLiteral.toLiteral(value) - }.asJava - ) - } - - override def fromLiteral(literal: Literal): Map[String, T] = - literal - .map() - .asScala - .map { case (key, value) => key -> sdkLiteral.fromLiteral(value) } - .toMap - } - } + ): SdkScalaLiteralType[Map[String, T]] = + DelegateLiteralType(SdkLiteralTypes.maps(sdkLiteral)) @implicitNotFound("Cannot derive SdkScalaType for sealed trait") sealed trait Dispatchable[T] @@ -365,6 +296,9 @@ private object SdkUnitType extends SdkScalaProductType[Unit] { override def toSdkBindingMap( value: Unit - ): ju.Map[String, SdkJavaBindinigData[_]] = - Map.empty[String, SdkJavaBindinigData[_]].asJava + ): ju.Map[String, SdkBindingData[_]] = + Map.empty[String, SdkBindingData[_]].asJava + + override def toLiteralTypes: ju.Map[String, SdkLiteralType[_]] = + Map.empty[String, SdkLiteralType[_]].asJava } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java index c6c53c0fb..2345c37bb 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java @@ -16,12 +16,12 @@ */ package org.flyte.flytekit.testing; -import static org.flyte.flytekit.SdkBindingData.ofInteger; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -53,7 +53,8 @@ public void testWithFixedInputs() { SdkTestingExecutor.of(new FibonacciWorkflow()) .withFixedInputs( JacksonSdkType.of(FibonacciWorkflowInputs.class), - FibonacciWorkflowInputs.create(ofInteger(1), ofInteger(1))) + FibonacciWorkflowInputs.create( + SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(1))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); @@ -70,8 +71,8 @@ public void testWithTaskOutput_runnableTask() { .withFixedInput("fib1", 1) .withTaskOutput( new SumTask(), - SumInput.create(ofInteger(3L), ofInteger(5L)), - SumOutput.create(ofInteger(42L))) + SumInput.create(SdkBindingDataFactory.of(3L), SdkBindingDataFactory.of(5L)), + SumOutput.create(SdkBindingDataFactory.of(42L))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); @@ -88,19 +89,19 @@ public void testWithTaskOutput_remoteTask() { .withFixedInput("fib1", 1) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(ofInteger(1L), ofInteger(1L)), + RemoteSumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(1L)), RemoteSumOutput.create(5L)) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(ofInteger(1L), ofInteger(5L)), + RemoteSumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(5L)), RemoteSumOutput.create(10L)) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(ofInteger(5L), ofInteger(10L)), + RemoteSumInput.create(SdkBindingDataFactory.of(5L), SdkBindingDataFactory.of(10L)), RemoteSumOutput.create(20L)) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(ofInteger(10L), ofInteger(20L)), + RemoteSumInput.create(SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(20L)), RemoteSumOutput.create(40L)) .execute(); @@ -118,12 +119,13 @@ public void testWithTask() { .withFixedInput("fib1", 1) .withTask( new SumTask(), - input -> SumOutput.create(ofInteger(input.a().get() * input.b().get()))) + input -> + SumOutput.create(SdkBindingDataFactory.of(input.a().get() * input.b().get()))) // can combine withTask and withTaskOutput .withTaskOutput( new SumTask(), - SumInput.create(ofInteger(1), ofInteger(1)), - SumOutput.create(ofInteger(2))) + SumInput.create(SdkBindingDataFactory.of(1), SdkBindingDataFactory.of(1)), + SumOutput.create(SdkBindingDataFactory.of(2))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java index 56a10eeb6..0b375b48a 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java @@ -16,18 +16,17 @@ */ package org.flyte.flytekit.testing; -import static org.flyte.flytekit.SdkBindingData.ofString; import static org.flyte.flytekit.SdkConditions.eq; import static org.flyte.flytekit.SdkConditions.gt; import static org.flyte.flytekit.SdkConditions.lt; import static org.flyte.flytekit.SdkConditions.when; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.in; import com.google.auto.value.AutoValue; import java.util.stream.Stream; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkCondition; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.SdkWorkflow; @@ -115,17 +114,19 @@ public ConstStringTask.Output expand( "c == d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a == b && c == d"))) + ConstStringTask.Input.create( + SdkBindingDataFactory.of("a == b && c == d"))) .when( "c > d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a == b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a == b && c > d"))) .when( "c < d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a == b && c < d")))) + ConstStringTask.Input.create( + SdkBindingDataFactory.of("a == b && c < d")))) .when( "a < b", lt(a, b), @@ -133,17 +134,17 @@ public ConstStringTask.Output expand( "c == d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a < b && c == d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c == d"))) .when( "c > d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a < b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c > d"))) .when( "c < d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a < b && c < d")))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c < d")))) .when( "a > b", gt(a, b), @@ -151,17 +152,18 @@ public ConstStringTask.Output expand( "c == d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a > b && c == d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a > b && c == d"))) .when( "c > d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a > b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a > b && c > d"))) .when( "c < d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(ofString("a > b && c < d")))); + ConstStringTask.Input.create( + SdkBindingDataFactory.of("a > b && c < d")))); SdkBindingData value = builder.apply("condition", condition).getOutputs().value(); diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java index 5479210aa..7c95f2a60 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java @@ -18,6 +18,7 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRemoteTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -48,7 +49,7 @@ public abstract static class RemoteSumOutput { public abstract SdkBindingData c(); public static RemoteSumOutput create(long c) { - return new AutoValue_RemoteSumTask_RemoteSumOutput(SdkBindingData.ofInteger(c)); + return new AutoValue_RemoteSumTask_RemoteSumOutput(SdkBindingDataFactory.of(c)); } } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteVoidOutputTask.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteVoidOutputTask.java index ab705f032..dd53d6220 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteVoidOutputTask.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteVoidOutputTask.java @@ -18,6 +18,7 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRemoteTask; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -38,7 +39,7 @@ public abstract static class Input { public abstract SdkBindingData ignore(); public static Input create(String ignore) { - return new AutoValue_RemoteVoidOutputTask_Input(/*ignore=*/ SdkBindingData.ofString(ignore)); + return new AutoValue_RemoteVoidOutputTask_Input(/*ignore=*/ SdkBindingDataFactory.of(ignore)); } } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java index 5015ac99c..e9b370eba 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java @@ -18,7 +18,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.in; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -26,6 +25,7 @@ import java.time.Duration; import java.time.Instant; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRemoteLaunchPlan; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; @@ -208,7 +208,7 @@ public Void expand(SdkWorkflowBuilder builder, Void noInput) { builder.apply( "sum", RemoteSumTask.create(), - RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); + RemoteSumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(2L))); return null; } }; @@ -233,7 +233,7 @@ public Void expand(SdkWorkflowBuilder builder, Void noInput) { builder.apply( "sum", RemoteSumTask.create(), - RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); + RemoteSumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(2L))); return null; } }; @@ -246,14 +246,14 @@ public Void expand(SdkWorkflowBuilder builder, Void noInput) { .withTaskOutput( RemoteSumTask.create(), RemoteSumInput.create( - SdkBindingData.ofInteger(10L), SdkBindingData.ofInteger(20L)), + SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(20L)), RemoteSumOutput.create(30L)) .execute()); assertThat( e.getMessage(), equalTo( - "Can't find input RemoteSumInput{a=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=1}}}, type=LiteralType{simpleType=INTEGER}, value=1}, b=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=2}}}, type=LiteralType{simpleType=INTEGER}, value=2}} for remote task [remote_sum_task] across known task inputs, use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask to provide a test double")); + "Can't find input RemoteSumInput{a=SdkBindingData{type=integers, value=1}, b=SdkBindingData{type=integers, value=2}} for remote task [remote_sum_task] across known task inputs, use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask to provide a test double")); } @Test @@ -285,9 +285,9 @@ public void withWorkflowOutput_successfullyMocksWhenTypeMatches() { .withWorkflowOutput( new SimpleSubWorkflow(), JacksonSdkType.of(TestUnaryIntegerIO.class), - TestUnaryIntegerIO.create(SdkBindingData.ofInteger(7)), + TestUnaryIntegerIO.create(SdkBindingDataFactory.of(7)), JacksonSdkType.of(TestUnaryIntegerIO.class), - TestUnaryIntegerIO.create(SdkBindingData.ofInteger(5))) + TestUnaryIntegerIO.create(SdkBindingDataFactory.of(5))) .execute(); assertThat(result.getIntegerOutput("integer"), equalTo(5L)); @@ -329,8 +329,8 @@ public TestUnaryIntegerIO expand(SdkWorkflowBuilder builder, SumLaunchPlanInput .withLaunchPlanOutput( launchplanRef, SumLaunchPlanInput.create( - SdkBindingData.ofInteger(3L), SdkBindingData.ofInteger(5L)), - SumLaunchPlanOutput.create(SdkBindingData.ofInteger(8L))) + SdkBindingDataFactory.of(3L), SdkBindingDataFactory.of(5L)), + SumLaunchPlanOutput.create(SdkBindingDataFactory.of(8L))) .execute(); assertThat(result.getIntegerOutput("integer"), equalTo(8L)); @@ -376,14 +376,14 @@ public TestUnaryIntegerIO expand(SdkWorkflowBuilder builder, SumLaunchPlanInput launchplanRef, // The stub values won't be matched, so exception iis throws SumLaunchPlanInput.create( - SdkBindingData.ofInteger(100000L), SdkBindingData.ofInteger(100000L)), - SumLaunchPlanOutput.create(SdkBindingData.ofInteger(8L))) + SdkBindingDataFactory.of(100000L), SdkBindingDataFactory.of(100000L)), + SumLaunchPlanOutput.create(SdkBindingDataFactory.of(8L))) .execute()); assertThat( ex.getMessage(), equalTo( - "Can't find input SumLaunchPlanInput{a=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=3}}}, type=LiteralType{simpleType=INTEGER}, value=3}, b=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=5}}}, type=LiteralType{simpleType=INTEGER}, value=5}} for remote launch plan [SumWorkflow] across known launch plan inputs, use SdkTestingExecutor#withLaunchPlanOutput or SdkTestingExecutor#withLaunchPlan to provide a test double")); + "Can't find input SumLaunchPlanInput{a=SdkBindingData{type=integers, value=3}, b=SdkBindingData{type=integers, value=5}} for remote launch plan [SumWorkflow] across known launch plan inputs, use SdkTestingExecutor#withLaunchPlanOutput or SdkTestingExecutor#withLaunchPlan to provide a test double")); } @Test @@ -423,7 +423,7 @@ public TestUnaryIntegerIO expand(SdkWorkflowBuilder builder, SumLaunchPlanInput launchplanRef, in -> SumLaunchPlanOutput.create( - SdkBindingData.ofInteger(in.a().get() + in.b().get()))) + SdkBindingDataFactory.of(in.a().get() + in.b().get()))) .execute(); assertThat(result.getIntegerOutput("integer"), equalTo(35L)); @@ -438,7 +438,7 @@ public Void expand(SdkWorkflowBuilder builder, Void noInput) { builder.apply( "sum", RemoteSumTask.create(), - RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); + RemoteSumInput.create(SdkBindingDataFactory.of(1L), SdkBindingDataFactory.of(2L))); return null; } }; @@ -451,14 +451,14 @@ public Void expand(SdkWorkflowBuilder builder, Void noInput) { .withTaskOutput( RemoteSumTask.create(), RemoteSumInput.create( - SdkBindingData.ofInteger(10L), SdkBindingData.ofInteger(20L)), + SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(20L)), RemoteSumOutput.create(30L)) .execute()); assertThat( e.getMessage(), equalTo( - "Can't find input RemoteSumInput{a=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=1}}}, type=LiteralType{simpleType=INTEGER}, value=1}, b=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=2}}}, type=LiteralType{simpleType=INTEGER}, value=2}} for remote task [remote_sum_task] across known task inputs, use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask to provide a test double")); + "Can't find input RemoteSumInput{a=SdkBindingData{type=integers, value=1}, b=SdkBindingData{type=integers, value=2}} for remote task [remote_sum_task] across known task inputs, use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask to provide a test double")); } public static class SimpleUberWorkflow diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java index f302c65bf..a9734c70f 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -52,6 +53,6 @@ public static SumOutput create(SdkBindingData c) { @Override public SumOutput run(SumInput input) { - return SumOutput.create(SdkBindingData.ofInteger(input.a().get() + input.b().get())); + return SumOutput.create(SdkBindingDataFactory.of(input.a().get() + input.b().get())); } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/TestingRunnableNodeTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/TestingRunnableNodeTest.java index a1b7950c4..6ca941f5d 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/TestingRunnableNodeTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/TestingRunnableNodeTest.java @@ -29,6 +29,7 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.jackson.JacksonSdkType; import org.junit.jupiter.api.Test; @@ -80,7 +81,7 @@ void testRun_notFound() { assertThat( ex.getMessage(), equalTo( - "Can't find input Input{in=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{stringValue=not in fixed outputs}}}, type=LiteralType{simpleType=STRING}, value=not in fixed outputs}} for remote test [TestTask] " + "Can't find input Input{in=SdkBindingData{type=strings, value=not in fixed outputs}} for remote test [TestTask] " + "across known test inputs, use a magic wang to provide a test double")); } @@ -132,7 +133,7 @@ abstract static class Input { abstract SdkBindingData in(); public static Input create(String in) { - return new AutoValue_TestingRunnableNodeTest_Input(SdkBindingData.ofString(in)); + return new AutoValue_TestingRunnableNodeTest_Input(SdkBindingDataFactory.of(in)); } } @@ -141,7 +142,7 @@ abstract static class Output { abstract SdkBindingData out(); public static Output create(Long out) { - return new AutoValue_TestingRunnableNodeTest_Output(SdkBindingData.ofInteger(out)); + return new AutoValue_TestingRunnableNodeTest_Output(SdkBindingDataFactory.of(out)); } } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java index 9470e6c02..cebec1df9 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java @@ -24,6 +24,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkCondition; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -73,17 +74,17 @@ public ConstStringTask.Output expand(SdkWorkflowBuilder builder, BranchNodeWorkf "c-equal-d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c == d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a == b && c == d"))) .when( "c-greater-d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a == b && c > d"))) .when( "c-less-d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c < d")))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a == b && c < d")))) .when( "a-less-b", lt(a, b), @@ -91,17 +92,17 @@ public ConstStringTask.Output expand(SdkWorkflowBuilder builder, BranchNodeWorkf "c-equal-d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c == d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c == d"))) .when( "c-greater-d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c > d"))) .when( "c-less-d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c < d")))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a < b && c < d")))) .when( "a-greater-b", gt(a, b), @@ -109,17 +110,17 @@ public ConstStringTask.Output expand(SdkWorkflowBuilder builder, BranchNodeWorkf "c-equal-d", eq(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c == d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a > b && c == d"))) .when( "c-greater-d", gt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c > d"))) + ConstStringTask.Input.create(SdkBindingDataFactory.of("a > b && c > d"))) .when( "c-less-d", lt(c, d), new ConstStringTask(), - ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c < d")))); + ConstStringTask.Input.create(SdkBindingDataFactory.of("a > b && c < d")))); SdkBindingData value = builder.apply("condition", condition).getOutputs().value(); diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java index 06f92e65c..cd168d62c 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java @@ -16,13 +16,12 @@ */ package org.flyte.integrationtests.structs; -import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; -@AutoService(SdkRunnableTask.class) +// @AutoService(SdkRunnableTask.class) public class BuildBqReference extends SdkRunnableTask { private static final long serialVersionUID = -489898361071672070L; diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java index 27c2747a3..00f12f9ae 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java @@ -16,13 +16,13 @@ */ package org.flyte.integrationtests.structs; -import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; -@AutoService(SdkRunnableTask.class) +// @AutoService(SdkRunnableTask.class) public class MockLookupBqTask extends SdkRunnableTask { private static final long serialVersionUID = 604843235716487166L; @@ -48,7 +48,7 @@ public abstract static class Output { public abstract SdkBindingData exists(); public static Output create(boolean exists) { - return new AutoValue_MockLookupBqTask_Output(SdkBindingData.ofBoolean(exists)); + return new AutoValue_MockLookupBqTask_Output(SdkBindingDataFactory.of(exists)); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java index 88401c0a4..7a379557b 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java @@ -16,11 +16,9 @@ */ package org.flyte.integrationtests.structs; -import static org.flyte.flytekit.SdkBindingData.ofBoolean; -import static org.flyte.flytekit.SdkBindingData.ofString; - import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -45,7 +43,9 @@ public Output expand(SdkWorkflowBuilder builder, Input input) { "build-ref", new BuildBqReference(), BuildBqReference.Input.create( - ofString("styx-1265"), ofString("styx-insights"), input.tableName())) + SdkBindingDataFactory.of("styx-1265"), + SdkBindingDataFactory.of("styx-insights"), + input.tableName())) .getOutputs() .ref(); SdkBindingData exists = @@ -53,7 +53,7 @@ public Output expand(SdkWorkflowBuilder builder, Input input) { .apply( "lookup", new MockLookupBqTask(), - MockLookupBqTask.Input.create(ref, ofBoolean(true))) + MockLookupBqTask.Input.create(ref, SdkBindingDataFactory.of(true))) .getOutputs() .exists(); return Output.create(exists); diff --git a/jflyte/pom.xml b/jflyte/pom.xml index 6d4191693..ef3a27bbe 100644 --- a/jflyte/pom.xml +++ b/jflyte/pom.xml @@ -148,11 +148,6 @@ junit-vintage-engine test - - org.junit.jupiter - junit-jupiter-params - test - org.hamcrest hamcrest diff --git a/pom.xml b/pom.xml index bb6a5fd0c..d0abbda1c 100644 --- a/pom.xml +++ b/pom.xml @@ -266,11 +266,6 @@ junit-vintage-engine ${junit.version} - - org.junit.jupiter - junit-jupiter-params - ${junit.version} - org.hamcrest hamcrest