diff --git a/.scalafmt.conf b/.scalafmt.conf index 6d6fd4e2c..971a38a84 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,3 +1,3 @@ -version=2.5.2 +version=3.7.14 runner.dialect=scala212source3 diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry index 7cc5459dd..acd3fc633 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry @@ -1 +1 @@ -org.flyte.examples.flytekitscala.FibonacciLaunchPlan +org.flyte.examples.flytekitscala.LaunchPlanRegistry diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask index 0fc19c133..508e6cb51 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask @@ -3,3 +3,4 @@ org.flyte.examples.flytekitscala.SumTask org.flyte.examples.flytekitscala.GreetTask org.flyte.examples.flytekitscala.AddQuestionTask org.flyte.examples.flytekitscala.NoInputsTask +org.flyte.examples.flytekitscala.NestedIOTask diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow index 9b9ca9038..844fdc040 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow @@ -1,2 +1,3 @@ org.flyte.examples.flytekitscala.FibonacciWorkflow org.flyte.examples.flytekitscala.WelcomeWorkflow +org.flyte.examples.flytekitscala.NestedIOWorkflow diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala similarity index 63% rename from flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala rename to flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala index 4fd16f3b0..df5c3b438 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala @@ -20,8 +20,9 @@ import org.flyte.flytekit.{SdkLaunchPlan, SimpleSdkLaunchPlanRegistry} import org.flyte.flytekitscala.SdkScalaType case class FibonacciLaunchPlanInput(fib0: Long, fib1: Long) +case class NestedIOLaunchPlanInput(name: String, generic: Nested) -class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry { +class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry { // Register default launch plans for all workflows registerDefaultLaunchPlans() @@ -53,4 +54,35 @@ class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry { .withDefaultInput("fib0", 0L) .withDefaultInput("fib1", 1L) ) + + registerLaunchPlan( + SdkLaunchPlan + .of(new NestedIOWorkflow) + .withName("NestedIOWorkflowLaunchPlan") + .withDefaultInput( + SdkScalaType[NestedIOLaunchPlanInput], + NestedIOLaunchPlanInput( + "yo", + Nested( + boolean = true, + 1.toByte, + 2.toShort, + 3, + 4L, + 5.toFloat, + 6.toDouble, + "hello", + List("1", "2"), + List(NestedNested(7.toDouble, NestedNestedNested("world"))), + Map("1" -> "1", "2" -> "2"), + Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))), + Some(false), + None, + Some(List("3", "4")), + Some(Map("3" -> "3", "4" -> "4")), + NestedNested(7.toDouble, NestedNestedNested("world")) + ) + ) + ) + ) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala new file mode 100644 index 000000000..ef4d61245 --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform} +import org.flyte.flytekitscala.{ + Description, + SdkBindingDataFactory, + SdkScalaType +} + +case class NestedNestedNested(string: String) +case class NestedNested(double: Double, nested: NestedNestedNested) +case class Nested( + boolean: Boolean, + byte: Byte, + short: Short, + int: Int, + long: Long, + float: Float, + double: Double, + string: String, + list: List[String], + listOfNested: List[NestedNested], + map: Map[String, String], + mapOfNested: Map[String, NestedNested], + optBoolean: Option[Boolean], + optByte: Option[Byte], + optList: Option[List[String]], + optMap: Option[Map[String, String]], + nested: NestedNested +) +case class NestedIOTaskInput( + @Description("the name of the person to be greeted") + name: SdkBindingData[String], + @Description("a nested input") + generic: SdkBindingData[Nested] +) +case class NestedIOTaskOutput( + @Description("the name of the person to be greeted") + name: SdkBindingData[String], + @Description("a nested input") + generic: SdkBindingData[Nested] +) + +/** Example Flyte task that takes a name as the input and outputs a simple + * greeting message. + */ +class NestedIOTask + extends SdkRunnableTask[ + NestedIOTaskInput, + NestedIOTaskOutput + ]( + SdkScalaType[NestedIOTaskInput], + SdkScalaType[NestedIOTaskOutput] + ) { + + /** Defines task behavior. This task takes a name as the input, wraps it in a + * welcome message, and outputs the message. + * + * @param input + * the name of the person to be greeted + * @return + * the welcome message + */ + override def run(input: NestedIOTaskInput): NestedIOTaskOutput = + NestedIOTaskOutput( + input.name, + input.generic + ) +} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala new file mode 100644 index 000000000..dfe996650 --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2020-2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekitscala.{ + SdkScalaType, + SdkScalaWorkflow, + SdkScalaWorkflowBuilder +} + +class NestedIOWorkflow + extends SdkScalaWorkflow[NestedIOTaskInput, Unit]( + SdkScalaType[NestedIOTaskInput], + SdkScalaType.unit + ) { + + override def expand( + builder: SdkScalaWorkflowBuilder, + input: NestedIOTaskInput + ): Unit = { + builder.apply(new NestedIOTask(), input) + } +} diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index da50d076b..5075ae8cb 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -34,8 +35,20 @@ public AllInputsTask() { JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class)); } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public abstract String world(); + + public static Nested create(String hello, String world) { + return new AutoValue_AllInputsTask_Nested(hello, world); + } + } + @AutoValue public abstract static class AutoAllInputsInput { + public abstract SdkBindingData i(); public abstract SdkBindingData f(); @@ -48,8 +61,9 @@ public abstract static class AutoAllInputsInput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + public abstract SdkBindingData blob(); + + public abstract SdkBindingData generic(); public abstract SdkBindingData> l(); @@ -66,13 +80,14 @@ public static AutoAllInputsInput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, - // Blob blob, + SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsInput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -91,8 +106,9 @@ public abstract static class AutoAllInputsOutput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + public abstract SdkBindingData blob(); + + public abstract SdkBindingData generic(); public abstract SdkBindingData> l(); @@ -109,12 +125,14 @@ public static AutoAllInputsOutput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, + SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsOutput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -127,6 +145,8 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) { input.b(), input.t(), input.d(), + input.blob(), + input.generic(), input.l(), input.m(), input.emptyList(), diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 26e888172..8bd9acc31 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -24,13 +24,19 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; +import org.flyte.examples.AllInputsTask.Nested; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) @@ -57,6 +63,20 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) SdkBindingDataFactory.of(true), SdkBindingDataFactory.of(someInstant), SdkBindingDataFactory.of(Duration.ofDays(1L)), + SdkBindingDataFactory.of( + Blob.builder() + .uri("file://test/test.csv") + .metadata( + BlobMetadata.builder() + .type( + BlobType.builder() + .format("") + .dimensionality(BlobDimensionality.SINGLE) + .build()) + .build()) + .build()), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")), SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")), SdkBindingDataFactory.ofStringMap(Map.of("test", "test")), SdkBindingDataFactory.ofStringCollection(Collections.emptyList()), @@ -71,6 +91,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) outputs.b(), outputs.t(), outputs.d(), + outputs.blob(), + outputs.generic(), outputs.l(), outputs.m(), outputs.emptyList(), @@ -92,8 +114,9 @@ public abstract static class AllInputsWorkflowOutput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + public abstract SdkBindingData blob(); + + public abstract SdkBindingData generic(); public abstract SdkBindingData> l(); @@ -110,12 +133,14 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, + SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput( - i, f, s, b, t, d, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } } diff --git a/flytekit-jackson/pom.xml b/flytekit-jackson/pom.xml index 8ddba736e..3688b163b 100644 --- a/flytekit-jackson/pom.xml +++ b/flytekit-jackson/pom.xml @@ -47,6 +47,10 @@ com.fasterxml.jackson.datatype jackson-datatype-jsr310 + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + com.fasterxml.jackson.module jackson-module-parameter-names diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java index 0be5ba34f..969fa64bd 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java @@ -33,6 +33,7 @@ import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; /** * Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link @@ -102,7 +103,8 @@ public Literal toLiteral(T value) { var tree = OBJECT_MAPPER.valueToTree(value); try { - return OBJECT_MAPPER.treeToValue(tree, Literal.class); + return Literal.ofScalar( + Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap())); } catch (IOException e) { throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/ObjectMapperUtils.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/ObjectMapperUtils.java index fa21cba8d..7af4e2284 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/ObjectMapperUtils.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/ObjectMapperUtils.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; import com.google.errorprone.annotations.Var; @@ -36,6 +37,7 @@ static ObjectMapper createObjectMapper(Module... modules) { return objectMapper .registerModule(new JavaTimeModule()) + .registerModule(new Jdk8Module()) .registerModule(new ParameterNamesModule()); } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java index 861a1c640..4ec2d158d 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java @@ -20,8 +20,8 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.module.SimpleDeserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.api.v1.Literal; -import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; import org.flyte.flytekit.jackson.serializers.StructSerializer; class SdkLiteralTypeModule extends Module { @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); var deserializers = new SimpleDeserializers(); - deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer()); + deserializers.addDeserializer(StructWrapper.class, new StructDeserializer()); context.addDeserializers(deserializers); // append with the lowest priority to use as fallback, if builtin annotations aren't present diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java index 17f71c25a..aa25ff45e 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.deser.Deserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers; import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers; import org.flyte.flytekit.jackson.serializers.BindingMapSerializers; import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers; @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); context.addSerializers(new LiteralMapSerializers()); - context.addDeserializers(new LiteralMapDeserializers()); context.addSerializers(new BindingMapSerializers()); context.addDeserializers(sdkbindingDeserializers); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 183e00444..c565898be 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -30,6 +30,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkLiteralType; @@ -63,11 +66,7 @@ public void property(BeanProperty prop) { String propName = prop.getName(); AnnotatedMember member = prop.getMember(); SdkLiteralType literalType = - toLiteralType( - handledType, - /*rootLevel=*/ true, - propName, - member.getMember().getDeclaringClass().getName()); + toLiteralType(handledType, /* rootLevel= */ true, propName, member); String description = getDescription(member); @@ -132,18 +131,17 @@ private String getDescription(AnnotatedMember member) { @SuppressWarnings("AlreadyChecked") private SdkLiteralType toLiteralType( - JavaType javaType, boolean rootLevel, String propName, String declaringClassName) { + JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) { Class type = javaType.getRawClass(); if (SdkBindingData.class.isAssignableFrom(type)) { - return toLiteralType( - javaType.getBindings().getBoundType(0), false, propName, declaringClassName); + return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member); } else if (rootLevel) { throw new UnsupportedOperationException( String.format( "Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. " + "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.", - propName, declaringClassName, type)); + propName, member.getMember().getDeclaringClass().getName(), type)); } else if (isPrimitiveAssignableFrom(Long.class, type)) { return SdkLiteralTypes.integers(); } else if (isPrimitiveAssignableFrom(Double.class, type)) { @@ -159,8 +157,7 @@ private SdkLiteralType toLiteralType( } else if (List.class.isAssignableFrom(type)) { JavaType elementType = javaType.getBindings().getBoundType(0); - return SdkLiteralTypes.collections( - toLiteralType(elementType, false, propName, declaringClassName)); + return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member)); } else if (Map.class.isAssignableFrom(type)) { JavaType keyType = javaType.getBindings().getBoundType(0); JavaType valueType = javaType.getBindings().getBoundType(1); @@ -170,11 +167,20 @@ private SdkLiteralType toLiteralType( "Only Map is supported, got [" + javaType.getGenericSignature() + "]"); } - return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName)); + return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member)); + } else if (Blob.class.isAssignableFrom(type)) { + // fixme: create blob type from annotation, or rethink how we could offer the offloaded data + // feature + // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype + return SdkLiteralTypes.blobs( + BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build()); + } + try { + return JacksonSdkLiteralType.of(type); + } catch (Exception e) { + throw new UnsupportedOperationException( + String.format("Unsupported type: [%s]", type.getName()), e); } - // TODO: Support blobs and structs - throw new UnsupportedOperationException( - String.format("Unsupported type: [%s]", type.getName())); } private static boolean isPrimitiveAssignableFrom(Class fromClass, Class toClass) { diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java deleted file mode 100644 index f3015c3da..000000000 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2020-2023 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit.jackson.deserializers; - -import com.fasterxml.jackson.databind.BeanDescription; -import com.fasterxml.jackson.databind.DeserializationConfig; -import com.fasterxml.jackson.databind.JavaType; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.deser.Deserializers; -import java.util.Map; -import org.flyte.api.v1.LiteralType; -import org.flyte.flytekit.jackson.JacksonLiteralMap; - -public class LiteralMapDeserializers extends Deserializers.Base { - - @Override - public JsonDeserializer findBeanDeserializer( - JavaType type, DeserializationConfig config, BeanDescription beanDesc) { - if (type.getRawClass().equals(JacksonLiteralMap.class)) { - Map literalTypeMap = type.getValueHandler(); - - return new LiteralMapDeserializer(literalTypeMap); - } - - return null; - } -} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index e99acdd8a..b860ef053 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -24,11 +24,14 @@ import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.BeanProperty; import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.ContextualDeserializer; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; -import java.io.Serializable; import java.time.Duration; import java.time.Instant; import java.util.Iterator; @@ -39,40 +42,56 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; +import org.flyte.api.v1.BlobType; +import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLiteralType; import org.flyte.flytekit.SdkLiteralTypes; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; -class SdkBindingDataDeserializer extends StdDeserializer> { +class SdkBindingDataDeserializer extends StdDeserializer> + implements ContextualDeserializer { private static final long serialVersionUID = 0L; + private final JavaType type; + public SdkBindingDataDeserializer() { + this(null); + } + + private SdkBindingDataDeserializer(JavaType type) { super(SdkBindingData.class); + + this.type = type; } @Override public SdkBindingData deserialize( JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode tree = jsonParser.readValueAsTree(); - return transform(tree); + return transform(tree, deserializationContext, type); } - private SdkBindingData transform(JsonNode tree) { + private SdkBindingData transform( + JsonNode tree, DeserializationContext deserializationContext, JavaType type) { Literal.Kind literalKind = Literal.Kind.valueOf(tree.get(LITERAL).asText()); switch (literalKind) { case SCALAR: - return transformScalar(tree); + return transformScalar(tree, deserializationContext, type); case COLLECTION: - return transformCollection(tree); + return transformCollection(tree, deserializationContext, type); case MAP: - return transformMap(tree); + return transformMap(tree, deserializationContext, type); default: throw new UnsupportedOperationException( @@ -80,7 +99,8 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private SdkBindingData transformScalar( + JsonNode tree, DeserializationContext deserializationContext, JavaType type) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -102,16 +122,60 @@ private static SdkBindingData transformScalar(JsonNode t throw new UnsupportedOperationException( "Type contains an unsupported primitive: " + primitiveKind); - case GENERIC: case BLOB: + return transformBlob(tree); + + case GENERIC: + return transformGeneric(tree, deserializationContext, scalarKind, type); + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); } } + private static SdkBindingData transformBlob(JsonNode tree) { + JsonNode value = tree.get(VALUE); + String uri = value.get("uri").asText(); + JsonNode type = value.get("metadata").get("type"); + String format = type.get("format").asText(); + BlobDimensionality dimensionality = + BlobDimensionality.valueOf(type.get("dimensionality").asText()); + return SdkBindingDataFactory.of( + Blob.builder() + .uri(uri) + .metadata( + BlobMetadata.builder() + .type(BlobType.builder().format(format).dimensionality(dimensionality).build()) + .build()) + .build()); + } + + private SdkBindingData transformGeneric( + JsonNode tree, + DeserializationContext deserializationContext, + Kind scalarKind, + JavaType type) { + JsonParser jsonParser = tree.get(VALUE).traverse(); + try { + jsonParser.nextToken(); + Object object = + deserializationContext + .findNonContextualValueDeserializer(type) + .deserialize(jsonParser, deserializationContext); + @SuppressWarnings("unchecked") + SdkLiteralType jacksonSdkLiteralType = + (SdkLiteralType) JacksonSdkLiteralType.of(type.getRawClass()); + return SdkBindingData.literal(jacksonSdkLiteralType, object); + } catch (IOException e) { + throw new UnsupportedOperationException( + "Type contains an unsupported generic: " + scalarKind, e); + } + } + @SuppressWarnings("unchecked") - private SdkBindingData> transformCollection(JsonNode tree) { + private SdkBindingData> transformCollection( + JsonNode tree, DeserializationContext deserializationContext, JavaType type) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); Iterator elements = tree.get(VALUE).elements(); @@ -119,13 +183,18 @@ private SdkBindingData> transformCollection(JsonNode tree) { case SIMPLE_TYPE: case MAP_VALUE_TYPE: case COLLECTION_TYPE: + case BLOB_TYPE: + JavaType realJavaType = + literalType instanceof JacksonSdkLiteralType ? type.getContentType() : type; List collection = (List) - streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList()); + streamOf(elements) + .map((JsonNode tree1) -> transform(tree1, deserializationContext, realJavaType)) + .map(SdkBindingData::get) + .collect(toList()); return SdkBindingDataFactory.of(literalType, collection); case SCHEMA_TYPE: - case BLOB_TYPE: default: throw new UnsupportedOperationException( "Type contains a collection of an supported literal type: " + literalType); @@ -133,7 +202,8 @@ private SdkBindingData> transformCollection(JsonNode tree) { } @SuppressWarnings("unchecked") - private SdkBindingData> transformMap(JsonNode tree) { + private SdkBindingData> transformMap( + JsonNode tree, DeserializationContext deserializationContext, JavaType type) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); JsonNode valueNode = tree.get(VALUE); List> entries = @@ -144,14 +214,22 @@ private SdkBindingData> transformMap(JsonNode tree) { case SIMPLE_TYPE: case MAP_VALUE_TYPE: case COLLECTION_TYPE: + case BLOB_TYPE: + JavaType realJavaType = + literalType instanceof JacksonSdkLiteralType ? type.getContentType() : type; Map bindingDataMap = entries.stream() - .map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get())) + .map( + entry -> + Map.entry( + entry.getKey(), + (T) + transform(entry.getValue(), deserializationContext, realJavaType) + .get())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); return SdkBindingDataFactory.of(literalType, bindingDataMap); case SCHEMA_TYPE: - case BLOB_TYPE: default: throw new UnsupportedOperationException( "Type contains a map of an supported literal type: " + literalType); @@ -177,7 +255,7 @@ private SdkLiteralType readLiteralType(JsonNode typeNode) { case DURATION: return SdkLiteralTypes.durations(); case STRUCT: - // not yet supported, fallthrough + return JacksonSdkLiteralType.of(type.getContentType().getRawClass()); } throw new UnsupportedOperationException( "Type contains a collection/map of an supported literal type: " + kind); @@ -185,9 +263,14 @@ private SdkLiteralType readLiteralType(JsonNode typeNode) { return SdkLiteralTypes.maps(readLiteralType(typeNode.get(VALUE).get(TYPE))); case COLLECTION_TYPE: return SdkLiteralTypes.collections(readLiteralType(typeNode.get(VALUE).get(TYPE))); - - case SCHEMA_TYPE: case BLOB_TYPE: + return SdkLiteralTypes.blobs( + BlobType.builder() + .format(typeNode.get(VALUE).get("format").asText()) + .dimensionality( + BlobDimensionality.valueOf(typeNode.get(VALUE).get("dimensionality").asText())) + .build()); + case SCHEMA_TYPE: default: throw new UnsupportedOperationException( "Type contains a collection/map of an supported literal type: " + kind); @@ -198,4 +281,9 @@ private Stream streamOf(Iterator nodes) { return StreamSupport.stream( Spliterators.spliteratorUnknownSize(nodes, Spliterator.ORDERED), false); } + + @Override + public JsonDeserializer createContextual(DeserializationContext ctxt, BeanProperty property) { + return new SdkBindingDataDeserializer(property.getType().containedType(0)); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java similarity index 70% rename from flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java rename to flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java index 0c17f55d5..88f673f80 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java @@ -29,23 +29,35 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.flyte.api.v1.Literal; -import org.flyte.api.v1.Scalar; import org.flyte.api.v1.Struct; import org.flyte.api.v1.Struct.Value; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; -public class LiteralStructDeserializer extends StdDeserializer { +public class StructDeserializer extends StdDeserializer { private static final long serialVersionUID = -6835948754469626304L; - public LiteralStructDeserializer() { - super(Literal.class); + // we cannot use Struct directly because it is an auto-value class so this deserializer will not + // be used by Jackson + public static class StructWrapper { + + private final Struct struct; + + public StructWrapper(Struct struct) { + this.struct = struct; + } + + public Struct unwrap() { + return struct; + } } - @Override - public Literal deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public StructDeserializer() { + super(StructWrapper.class); + } - Struct generic = readValueAsStruct(p); - return Literal.ofScalar(Scalar.ofGeneric(generic)); + @Override + public StructWrapper deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + return new StructWrapper(readValueAsStruct(p)); } private static Struct readValueAsStruct(JsonParser p) throws IOException { @@ -67,7 +79,7 @@ private static Struct readValueAsStruct(JsonParser p) throws IOException { return Struct.of(unmodifiableMap(fields)); } - private static Struct.Value readValueAsStructValue(JsonParser p) throws IOException { + private static Value readValueAsStructValue(JsonParser p) throws IOException { switch (p.currentToken()) { case START_ARRAY: p.nextToken(); @@ -75,38 +87,38 @@ private static Struct.Value readValueAsStructValue(JsonParser p) throws IOExcept List valuesList = new ArrayList<>(); while (p.currentToken() != JsonToken.END_ARRAY) { - Struct.Value value = readValueAsStructValue(p); + Value value = readValueAsStructValue(p); p.nextToken(); valuesList.add(value); } - return Struct.Value.ofListValue(unmodifiableList(valuesList)); + return Value.ofListValue(unmodifiableList(valuesList)); case START_OBJECT: Struct struct = readValueAsStruct(p); - return Struct.Value.ofStructValue(struct); + return Value.ofStructValue(struct); case VALUE_STRING: String stringValue = p.readValueAs(String.class); - return Struct.Value.ofStringValue(stringValue); + return Value.ofStringValue(stringValue); case VALUE_NUMBER_FLOAT: case VALUE_NUMBER_INT: Double doubleValue = p.readValueAs(Double.class); - return Struct.Value.ofNumberValue(doubleValue); + return Value.ofNumberValue(doubleValue); case VALUE_NULL: - return Struct.Value.ofNullValue(); + return Value.ofNullValue(); case VALUE_FALSE: - return Struct.Value.ofBoolValue(false); + return Value.ofBoolValue(false); case VALUE_TRUE: - return Struct.Value.ofBoolValue(true); + return Value.ofBoolValue(true); case FIELD_NAME: case NOT_AVAILABLE: diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java index 282376109..7862b6f26 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BlobSerializer.java @@ -16,7 +16,7 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; @@ -24,7 +24,7 @@ import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; public class BlobSerializer extends ScalarSerializer { public BlobSerializer( @@ -38,8 +38,8 @@ public BlobSerializer( @Override void serializeScalar() throws IOException { - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.BLOB); + gen.writeObject(Kind.BLOB); + gen.writeFieldName(VALUE); serializerProvider .findValueSerializer(Blob.class) .serialize(value.scalar().blob(), gen, serializerProvider); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java index 12ec69e18..5c73535c7 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java @@ -16,20 +16,15 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.LITERAL; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_TYPE; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_VALUE; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; import java.io.IOException; -import java.util.Map; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Struct; public class GenericSerializer extends ScalarSerializer { public GenericSerializer( @@ -48,85 +43,7 @@ public GenericSerializer( @Override public void serializeScalar() throws IOException { gen.writeObject(Scalar.Kind.GENERIC); - for (Map.Entry entry : value.scalar().generic().fields().entrySet()) { - gen.writeFieldName(entry.getKey()); - serializeStructValue(entry.getValue()); - } - } - - private void serializeStructValue(Struct.Value value) throws IOException { - if (!value.kind().equals(Struct.Value.Kind.LIST_VALUE) - && !value.kind().equals(Struct.Value.Kind.NULL_VALUE)) { - gen.writeStartObject(); - gen.writeFieldName(LITERAL); - gen.writeObject(Literal.Kind.SCALAR); - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.GENERIC); - } - - if (isSimpleType(value.kind())) { - gen.writeFieldName(STRUCT_TYPE); - } - switch (value.kind()) { - case BOOL_VALUE: - writeSimpleType( - Struct.Value.Kind.BOOL_VALUE, - value, - (generator, v) -> generator.writeBoolean(v.boolValue())); - return; - - case LIST_VALUE: - throw new RuntimeException("not supported list inside the struct"); - - case NUMBER_VALUE: - writeSimpleType( - Struct.Value.Kind.NUMBER_VALUE, - value, - (generator, v) -> generator.writeNumber(v.numberValue())); - return; - - case STRING_VALUE: - writeSimpleType( - Struct.Value.Kind.STRING_VALUE, - value, - (generator, v) -> generator.writeString(v.stringValue())); - return; - - case STRUCT_VALUE: - value.structValue().fields().forEach((k, v) -> writeStructValue(gen, k, v)); - gen.writeEndObject(); - return; - - case NULL_VALUE: - gen.writeNull(); - } - } - - private void writeStructValue(JsonGenerator gen, String k, Struct.Value v) { - try { - gen.writeFieldName(k); - serializeStructValue(v); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private boolean isSimpleType(Struct.Value.Kind kind) { - return kind.equals(Struct.Value.Kind.BOOL_VALUE) - || kind.equals(Struct.Value.Kind.NUMBER_VALUE) - || kind.equals(Struct.Value.Kind.STRING_VALUE); - } - - private void writeSimpleType( - Struct.Value.Kind kind, Struct.Value structValue, WriteGenericFunction writeTypeFunction) - throws IOException { - gen.writeObject(kind); - gen.writeFieldName(STRUCT_VALUE); - writeTypeFunction.write(gen, structValue); - gen.writeEndObject(); - } - - interface WriteGenericFunction { - void write(JsonGenerator gen, Struct.Value value) throws IOException; + gen.writeFieldName(VALUE); + new StructSerializer().serialize(value.scalar().generic(), gen, serializerProvider); } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralTypeSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralTypeSerializer.java index 1cbecb7eb..75457d10f 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralTypeSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralTypeSerializer.java @@ -49,8 +49,11 @@ static void serialize(LiteralType literalType, JsonGenerator gen) throws IOExcep serialize(literalType.mapValueType(), gen); gen.writeEndObject(); break; - case SCHEMA_TYPE: case BLOB_TYPE: + // {type: {kind: blob, value: {format: string, dimensionality: string}}}} + gen.writeObject(literalType.blobType()); + break; + case SCHEMA_TYPE: throw new IllegalArgumentException( String.format("Unsupported LiteralType.Kind: [%s]", literalType.getKind())); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java index 86af1b5fc..4267bc532 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/SdkBindingDataSerializationProtocol.java @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol { public static final String TYPE = "type"; public static final String KIND = "kind"; public static final String PRIMITIVE = "primitive"; + public static final String BLOB = "blob"; } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkLiteralTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkLiteralTypeTest.java index ee8926537..d4efedb9b 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkLiteralTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkLiteralTypeTest.java @@ -26,6 +26,7 @@ import java.time.Instant; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; import javax.annotation.Nullable; import org.flyte.api.v1.BindingData; @@ -79,7 +80,14 @@ public static Stream typeValueLiteralProvider() { arguments( SomeType.class, SomeType.create( - 1, 2.0, "3", true, null, List.of("4", "5", "6"), EmbeddedType.create(7, 8)), + 1, + 2.0, + "3", + true, + null, + List.of("4", "5", "6"), + EmbeddedType.create(7, 8), + Optional.empty()), Literal.ofScalar( Scalar.ofGeneric( Struct.of( @@ -105,7 +113,9 @@ public static Stream typeValueLiteralProvider() { Struct.of( Map.of( "a", Value.ofNumberValue(7), - "b", Value.ofNumberValue(8))))))))), + "b", Value.ofNumberValue(8)))), + "optionalS", + Value.ofNullValue()))))), arguments( TypeWithMap.class, TypeWithMap.create(Map.of("x", 1L, "y", 2L)), @@ -147,7 +157,14 @@ public static Stream typeValueBindingProvider() { arguments( SomeType.class, SomeType.create( - 1, 2.0, "3", true, null, List.of("4", "5", "6"), EmbeddedType.create(7, 8)), + 1, + 2.0, + "3", + true, + null, + List.of("4", "5", "6"), + EmbeddedType.create(7, 8), + Optional.of("hello")), BindingData.ofScalar( Scalar.ofGeneric( Struct.of( @@ -173,7 +190,9 @@ public static Stream typeValueBindingProvider() { Struct.of( Map.of( "a", Value.ofNumberValue(7), - "b", Value.ofNumberValue(8))))))))), + "b", Value.ofNumberValue(8)))), + "optionalS", + Value.ofStringValue("hello")))))), arguments( TypeWithMap.class, TypeWithMap.create(Map.of("a", 1L, "b", 2L)), @@ -274,6 +293,8 @@ abstract static class SomeType { abstract EmbeddedType subTest(); + abstract Optional optionalS(); + public static SomeType create( long i, double f, @@ -281,8 +302,10 @@ public static SomeType create( boolean b, String null_, List list, - EmbeddedType subTest) { - return new AutoValue_JacksonSdkLiteralTypeTest_SomeType(i, f, s, b, null_, list, subTest); + EmbeddedType subTest, + Optional optionalS) { + return new AutoValue_JacksonSdkLiteralTypeTest_SomeType( + i, f, s, b, null_, list, subTest, optionalS); } } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index b4b6ce995..a1f969961 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -23,7 +23,6 @@ import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -37,13 +36,17 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import javax.annotation.Nullable; +import java.util.Optional; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Struct.Value; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -53,10 +56,15 @@ public class JacksonSdkTypeTest { - @SuppressWarnings("UnusedVariable") - static final BlobType BLOB_TYPE = + private static final BlobType BLOB_TYPE = BlobType.builder().format("").dimensionality(BlobType.BlobDimensionality.SINGLE).build(); + private static final Blob BLOB = + Blob.builder() + .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) + .uri("file://test") + .build(); + public static AutoValueInput createAutoValueInput( long i, double f, @@ -64,9 +72,14 @@ public static AutoValueInput createAutoValueInput( boolean b, Instant t, Duration d, - // Blob blob, + Blob blob, + Nested generic, List l, + List lb, + List lg, Map m, + Map mb, + Map mg, List> ll, List> lm, Map> ml, @@ -78,8 +91,14 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(b), SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), + SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), generic), SdkBindingDataFactory.ofStringCollection(l), + SdkBindingDataFactory.of(SdkLiteralTypes.blobs(BLOB_TYPE), lb), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), lg), SdkBindingDataFactory.ofStringMap(m), + SdkBindingDataFactory.of(SdkLiteralTypes.blobs(BLOB_TYPE), mb), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), mg), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), SdkBindingDataFactory.of(SdkLiteralTypes.maps(SdkLiteralTypes.strings()), lm), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ml), @@ -98,7 +117,8 @@ public void testVariableMap() { hasEntry("b", createVar(SimpleType.BOOLEAN)), hasEntry("t", createVar(SimpleType.DATETIME)), hasEntry("d", createVar(SimpleType.DURATION)), - // hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("generic", createVar(LiteralType.ofSimpleType(SimpleType.STRUCT))), hasEntry( "l", createVar(LiteralType.ofCollectionType(ofSimpleType(SimpleType.STRING)))), hasEntry( @@ -119,11 +139,6 @@ public void testVariableMap() { void testFromLiteralMap() { Instant datetime = Instant.ofEpochSecond(12, 34); Duration duration = Duration.ofSeconds(56, 78); - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); Map literalMap = new HashMap<>(); literalMap.put("i", literalOf(Primitive.ofIntegerValue(123L))); literalMap.put("f", literalOf(Primitive.ofFloatValue(123.0))); @@ -131,9 +146,29 @@ void testFromLiteralMap() { literalMap.put("b", literalOf(Primitive.ofBooleanValue(true))); literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); - // literalMap.put("blob", literalOf(blob)); + literalMap.put("blob", literalOf(BLOB)); + literalMap.put( + "generic", + literalOf( + Struct.of( + Map.of( + "hello", + Value.ofStringValue("hello"), + "world", + Value.ofStringValue("world"))))); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); + literalMap.put("lb", Literal.ofCollection(List.of(literalOf(BLOB)))); + literalMap.put( + "lg", + Literal.ofCollection( + List.of(literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); + literalMap.put("mb", Literal.ofMap(Map.of("blob", literalOf(BLOB)))); + literalMap.put( + "mg", + Literal.ofMap( + Map.of( + "generic", literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))))); literalMap.put( "ll", Literal.ofCollection( @@ -159,9 +194,9 @@ void testFromLiteralMap() { Literal.ofMap( Map.of( "math", - Literal.ofMap( - Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), - "pokemon", Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); + Literal.ofMap(Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), + "pokemon", + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))); AutoValueInput input = JacksonSdkType.of(AutoValueInput.class).fromLiteralMap(literalMap); @@ -175,9 +210,14 @@ void testFromLiteralMap() { /* b= */ true, /* t= */ datetime, /* d= */ duration, - /// * blob= */ blob, + /* blob= */ BLOB, + /* generic= */ Nested.create("hello", "world"), /* l= */ List.of("123"), + /* lb= */ List.of(BLOB), + /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), + /* mb= */ Map.of("blob", BLOB), + /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), /* ml= */ Map.of("frodo", List.of("baggins", "bolson")), @@ -194,11 +234,6 @@ private static Literal stringLiteralOf(String string) { @Test void testToLiteralMap() { - // Blob blob = - // Blob.builder() - // .metadata(BlobMetadata.builder().type(BLOB_TYPE).build()) - // .uri("file://test") - // .build(); Map literalMap = JacksonSdkType.of(AutoValueInput.class) .toLiteralMap( @@ -209,9 +244,14 @@ void testToLiteralMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ BLOB, + /* generic= */ Nested.create("hello"), /* l= */ List.of("foo"), + /* lb= */ List.of(BLOB), + /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), + /* mb= */ Map.of("blob", BLOB), + /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), /* ml= */ Map.of("frodo", List.of("baggins", "bolson")), @@ -271,9 +311,8 @@ void testToLiteralMap() { Map.of( "pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), "pokemon", - Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))) - // hasEntry("blob", literalOf(blob)) - ))); + Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))), + hasEntry("blob", literalOf(BLOB))))); } @Test @@ -286,9 +325,14 @@ public void testToSdkBindingDataMap() { /* b= */ false, /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), - /// * blob= */ blob, + /* blob= */ BLOB, + /* generic= */ Nested.create("hello"), /* l= */ List.of("foo"), + /* lb= */ List.of(BLOB), + /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), + /* mb= */ Map.of("blob", BLOB), + /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), /* ml= */ Map.of("frodo", List.of("baggins", "bolson")), @@ -305,8 +349,14 @@ public void testToSdkBindingDataMap() { expected.put("b", input.b()); expected.put("t", input.t()); expected.put("d", input.d()); + expected.put("blob", input.blob()); + expected.put("generic", input.generic()); expected.put("l", input.l()); + expected.put("lb", input.lb()); + expected.put("lg", input.lg()); expected.put("m", input.m()); + expected.put("mb", input.mb()); + expected.put("mg", input.mg()); expected.put("ll", input.ll()); expected.put("lm", input.lm()); expected.put("ml", input.ml()); @@ -374,30 +424,6 @@ public void testPojoVariableMap() { assertThat(variableMap, equalTo(Map.of("a", expected))); } - @Disabled("Not supported struct with the strongly types implementation.") - public void testStructRoundtrip() { - fail(); - // StructInput input = - // StructInput.create( - // null - // // StructValueInput.create( - // // /* stringValue= */ "nested-string", - // // /* boolValue= */ false, - // // /* listValue= */ Arrays.asList(1L, 2L, 3L), - // // /* structValue= */ StructValueInput.create( - // // /* stringValue= */ "nested-string", - // // /* boolValue= */ false, - // // /* listValue= */ Arrays.asList(1L, 2L, 3L), - // // /* structValue= */ null, - // // /* numberValue= */ 42.0), - // // /* numberValue= */ 42.0) - // ); - // - // SdkType sdkType = JacksonSdkType.of(StructInput.class); - // Map literalMap = sdkType.toLiteralMap(input); - // assertThat(sdkType.fromLiteralMap(literalMap), equalTo(input)); - } - @Disabled("Not supported customType & customEnum with the strongly types implementation.") public void testConverterToLiteralMap() { InputWithCustomType input = InputWithCustomType.create(CustomType.ONE, CustomEnum.TWO); @@ -520,6 +546,21 @@ public static AutoValueDeprecatedInput create(long i) { } } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public abstract Optional world(); + + public static AutoValue_JacksonSdkTypeTest_Nested create(String hello) { + return new AutoValue_JacksonSdkTypeTest_Nested(hello, Optional.empty()); + } + + public static AutoValue_JacksonSdkTypeTest_Nested create(String hello, String world) { + return new AutoValue_JacksonSdkTypeTest_Nested(hello, Optional.of(world)); + } + } + @AutoValue public abstract static class AutoValueInput { @@ -536,13 +577,22 @@ public abstract static class AutoValueInput { public abstract SdkBindingData d(); - // TODO add blobs to sdkbinding data - // public abstract SdkBindingData blob(); + public abstract SdkBindingData blob(); + + public abstract SdkBindingData generic(); public abstract SdkBindingData> l(); + public abstract SdkBindingData> lb(); + + public abstract SdkBindingData> lg(); + public abstract SdkBindingData> m(); + public abstract SdkBindingData> mb(); + + public abstract SdkBindingData> mg(); + public abstract SdkBindingData>> ll(); public abstract SdkBindingData>> lm(); @@ -558,48 +608,20 @@ public static AutoValueInput create( SdkBindingData b, SdkBindingData t, SdkBindingData d, - // Blob blob, + SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, + SdkBindingData> lb, + SdkBindingData> lg, SdkBindingData> m, + SdkBindingData> mb, + SdkBindingData> mg, SdkBindingData>> ll, SdkBindingData>> lm, SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, l, m, ll, lm, ml, mm); - } - } - - @AutoValue - public abstract static class StructInput { - public abstract SdkBindingData structLevel1(); - - public static StructInput create(SdkBindingData structValue) { - return new AutoValue_JacksonSdkTypeTest_StructInput(structValue); - } - } - - @AutoValue - public abstract static class StructValueInput { - public abstract String stringValue(); - - public abstract boolean boolValue(); - - public abstract List listValue(); - - @Nullable - public abstract StructValueInput structLevel(); - - public abstract double numberValue(); - - public static StructValueInput create( - String stringValue, - boolean boolValue, - List listValue, - StructValueInput structValue, - Double numberValue) { - return new AutoValue_JacksonSdkTypeTest_StructValueInput( - stringValue, boolValue, listValue, structValue, numberValue); + i, f, s, b, t, d, blob, generic, l, lb, lg, m, mb, mg, ll, lm, ml, mm); } } @@ -701,4 +723,12 @@ private static Variable createVar(LiteralType literalType, String description) { private static Literal literalOf(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } + + private static Literal literalOf(Blob blob) { + return Literal.ofScalar(Scalar.ofBlob(blob)); + } + + private static Literal literalOf(Struct generic) { + return Literal.ofScalar(Scalar.ofGeneric(generic)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java index 1c18ee674..f9fb9162b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/Literals.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.time.Instant; +import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; @@ -49,6 +50,10 @@ static Literal ofDuration(Duration value) { return ofPrimitive(Primitive.ofDuration(value)); } + static Literal ofBlob(Blob value) { + return Literal.ofScalar(Scalar.ofBlob(value)); + } + private static Literal ofPrimitive(Primitive primitive) { return Literal.ofScalar(Scalar.ofPrimitive(primitive)); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index 757e5bf10..aa7217ec8 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -26,6 +26,7 @@ import java.time.ZoneOffset; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Blob; /** A utility class for creating {@link SdkBindingData} objects for different types. */ public final class SdkBindingDataFactory { @@ -123,6 +124,27 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + /** + * Creates a {@code SdkBindingData} for a flyte type with the given value. + * + * @param type the flyte type + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(SdkLiteralType type, T value) { + return SdkBindingData.literal(type, value); + } + + /** + * Creates a {@code SdkBindingData} for a flyte Blob with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(Blob value) { + return SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata().type()), value); + } + /** * Creates a {@code SdkBindingData} for a flyte collection of string given a java {@code * List}. diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java index f09073ea9..af1b77292 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java @@ -25,6 +25,8 @@ import java.util.Map; import java.util.Map.Entry; import org.flyte.api.v1.BindingData; +import org.flyte.api.v1.Blob; +import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; @@ -181,6 +183,15 @@ public static SdkLiteralType> maps(SdkLiteralType mapValue return new MapSdkLiteralType<>(mapValueType); } + /** + * Returns a {@link SdkLiteralType} for blobs. + * + * @return the {@link SdkLiteralType} + */ + public static SdkLiteralType blobs(BlobType blobType) { + return new BlobSdkLiteralType(blobType); + } + private static class IntegerSdkLiteralType extends PrimitiveSdkLiteralType { private static final IntegerSdkLiteralType INSTANCE = new IntegerSdkLiteralType(); @@ -205,6 +216,39 @@ public String toString() { } } + private static class BlobSdkLiteralType extends SdkLiteralType { + private final BlobType blobType; + + public BlobSdkLiteralType(BlobType blobType) { + this.blobType = blobType; + } + + @Override + public LiteralType getLiteralType() { + return LiteralType.ofBlobType(blobType); + } + + @Override + public Literal toLiteral(Blob value) { + return Literals.ofBlob(value); + } + + @Override + public Blob fromLiteral(Literal literal) { + return literal.scalar().blob(); + } + + @Override + public BindingData toBindingData(Blob value) { + return BindingData.ofScalar(Scalar.ofBlob(value)); + } + + @Override + public String toString() { + return "blobs"; + } + } + private static class FloatSdkLiteralType extends PrimitiveSdkLiteralType { private static final FloatSdkLiteralType INSTANCE = new FloatSdkLiteralType(); diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala index 2c2203c06..7e5ed0f4e 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -16,6 +16,8 @@ */ package org.flyte.flytekitscala +import org.flyte.api.v1.{Blob, BlobType} +import org.flyte.api.v1.BlobType.BlobDimensionality import org.flyte.flytekit.SdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{of, _} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 377833815..a3cbaf863 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -16,9 +16,14 @@ */ package org.flyte.flytekitscala +import org.flyte.api.v1.BlobType.BlobDimensionality + import java.time.{Duration, Instant} import scala.jdk.CollectionConverters._ import org.flyte.api.v1.{ + Blob, + BlobMetadata, + BlobType, Literal, LiteralType, Primitive, @@ -33,7 +38,8 @@ import org.flyte.flytekit.{ import org.flyte.flytekitscala.SdkBindingDataFactory import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test -import org.flyte.examples.AllInputsTask.AutoAllInputsInput +import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested} +import org.flyte.flytekit.jackson.JacksonSdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings} class SdkScalaTypeTest { @@ -379,6 +385,23 @@ class SdkScalaTypeTest { def testUseAutoValueAttrIntoScalaClass(): Unit = { import SdkBindingDataConverters._ + val blob = Blob + .builder() + .uri("file://test/test.csv") + .metadata( + BlobMetadata + .builder() + .`type`( + BlobType + .builder() + .format("csv") + .dimensionality(BlobDimensionality.MULTIPART) + .build() + ) + .build() + ) + .build() + val input = AutoAllInputsInput.create( SdkJavaBindingDataFactory.of(2L), SdkJavaBindingDataFactory.of(2.0), @@ -386,6 +409,11 @@ class SdkScalaTypeTest { SdkJavaBindingDataFactory.of(true), SdkJavaBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkJavaBindingDataFactory.of(Duration.ZERO), + SdkJavaBindingDataFactory.of(blob), + SdkJavaBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkJavaBindingDataFactory.ofStringCollection(List("1", "2", "3").asJava), SdkJavaBindingDataFactory.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), SdkJavaBindingDataFactory.ofStringCollection(List.empty[String].asJava), @@ -401,6 +429,8 @@ class SdkScalaTypeTest { boolean: SdkBindingData[Boolean], instant: SdkBindingData[Instant], duration: SdkBindingData[Duration], + blob: SdkBindingData[Blob], + generic: SdkBindingData[Nested], list: SdkBindingData[List[String]], map: SdkBindingData[Map[String, String]], emptyList: SdkBindingData[List[String]], @@ -414,6 +444,8 @@ class SdkScalaTypeTest { toScalaBoolean(input.b()), input.t(), input.d(), + input.blob(), + input.generic(), toScalaList(input.l()), toScalaMap(input.m()), toScalaList(input.emptyList()), @@ -427,6 +459,11 @@ class SdkScalaTypeTest { SdkBindingDataFactory.of(true), SdkBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkBindingDataFactory.of(Duration.ZERO), + SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkBindingDataFactory.of(List("1", "2", "3")), SdkBindingDataFactory.of(Map("a" -> "2", "b" -> "3")), SdkBindingDataFactory.ofStringCollection(List.empty[String]), diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala index 2c4923d55..0ed989a20 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala @@ -172,7 +172,11 @@ object SdkBindingDataConverters { jf.Function.identity() ) } - case LiteralType.Kind.BLOB_TYPE => ??? // TODO not yet supported + case LiteralType.Kind.BLOB_TYPE => + TypeCastingResult( + SdkScalaLiteralTypes.blobs(lt.blobType()), + jf.Function.identity() + ) case LiteralType.Kind.SCHEMA_TYPE => ??? // TODO not yet supported case LiteralType.Kind.COLLECTION_TYPE => val TypeCastingResult(convertedElementType, convFunction) = toScalaType( @@ -257,7 +261,11 @@ object SdkBindingDataConverters { jf.Function.identity() ) } - case LiteralType.Kind.BLOB_TYPE => ??? // TODO do we support blob? + case LiteralType.Kind.BLOB_TYPE => + TypeCastingResult( + SdkJavaLiteralTypes.blobs(lt.blobType()), + jf.Function.identity() + ) case LiteralType.Kind.SCHEMA_TYPE => ??? // TODO do we support schema type? case LiteralType.Kind.COLLECTION_TYPE => diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala index 40c2e4a14..857238ee4 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala @@ -16,6 +16,7 @@ */ package org.flyte.flytekitscala +import org.flyte.api.v1.Blob import org.flyte.flytekit.{ BindingCollection, BindingMap, @@ -134,6 +135,28 @@ object SdkBindingDataFactory { collection ) + /** Creates a [[SdkBindingData]] for a flyte Blob with the given value. + * + * @param value + * the simple value for this data + * @return + * the new [[SdkBindingData]] + */ + def of(value: Blob): SdkBindingData[Blob] = + SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata.`type`), value) + + /** Creates a [[SdkBindingData]] for a flyte type with the given value. + * + * @param type + * the flyte type + * @param value + * the simple value for this data + * @return + * the new [[SdkBindingData]] + */ + def of[T](`type`: SdkLiteralType[T], value: T): SdkBindingData[T] = + SdkBindingData.literal(`type`, value) + /** Creates a [[SdkBindingDataFactory]] for a flyte string collection given a * scala [[List]]. * diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 596639929..6d53bceae 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -24,7 +24,18 @@ import org.flyte.flytekit.{ import java.time.{Duration, Instant} import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe.{TypeTag, typeOf} +import scala.reflect.api.{Mirror, TypeCreator, Universe} +import scala.reflect.runtime.universe +import scala.reflect.{ClassTag, classTag} +import scala.reflect.runtime.universe.{ + NoPrefix, + Symbol, + Type, + TypeTag, + runtimeMirror, + termNames, + typeOf +} object SdkLiteralTypes { @@ -202,6 +213,193 @@ object SdkLiteralTypes { */ def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations() + /** Returns a [[SdkLiteralType]] for products. + * @return + * the [[SdkLiteralType]] + */ + def generics[T <: Product: TypeTag: ClassTag](): SdkLiteralType[T] = { + ScalaLiteralType[T]( + LiteralType.ofSimpleType(SimpleType.STRUCT), + (value: T) => Literal.ofScalar(Scalar.ofGeneric(toStruct(value))), + (x: Literal) => toProduct(x.scalar().generic()), + (v: T) => BindingData.ofScalar(Scalar.ofGeneric(toStruct(v))), + "generics" + ) + } + + private def toStruct(product: Product): Struct = { + def productToMap(product: Product): Map[String, Any] = { + // by spec getDeclaredFields is not ordered but in practice it works fine + // it's a lot better since Scala 2.13 because productElementNames was introduced + // (product.productElementNames zip product.productIterator).toMap + product.getClass.getDeclaredFields + .map(_.getName) + .zip(product.productIterator.toList) + .toMap + } + + def mapToStruct(map: Map[String, Any]): Struct = { + val fields = map.map({ case (key, value) => + (key, anyToStructValue(value)) + }) + Struct.of(fields.asJava) + } + + def anyToStructValue(value: Any): Struct.Value = { + def anyToStructureValue0(value: Any): Struct.Value = { + value match { + case s: String => Struct.Value.ofStringValue(s) + case n @ (_: Byte | _: Short | _: Int | _: Long | _: Float | + _: Double) => + Struct.Value.ofNumberValue(n.toString.toDouble) + case b: Boolean => Struct.Value.ofBoolValue(b) + case l: List[Any] => + Struct.Value.ofListValue(l.map(anyToStructValue).asJava) + case m: Map[_, _] => + Struct.Value.ofStructValue( + mapToStruct(m.asInstanceOf[Map[String, Any]]) + ) + case null => Struct.Value.ofNullValue() + case p: Product => + Struct.Value.ofStructValue(mapToStruct(productToMap(p))) + case _ => + throw new IllegalArgumentException( + s"Unsupported type: ${value.getClass}" + ) + } + } + + value match { + case Some(v) => anyToStructureValue0(v) + case None => Struct.Value.ofNullValue() + case _ => anyToStructureValue0(value) + } + } + + mapToStruct(productToMap(product)) + } + + private def toProduct[T <: Product: TypeTag: ClassTag]( + struct: Struct + ): T = { + def structToMap(struct: Struct): Map[String, Any] = { + struct + .fields() + .asScala + .map({ case (key, value) => + (key, structValueToAny(value)) + }) + .toMap + } + + def mapToProduct[S <: Product: TypeTag: ClassTag]( + map: Map[String, Any] + ): S = { + val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader) + + def valueToParamValue(value: Any, param: Symbol): Any = { + def valueToParamValue0(value: Any, param: Symbol): Any = { + if (param.typeSignature =:= typeOf[Byte]) { + value.asInstanceOf[Double].toByte + } else if (param.typeSignature =:= typeOf[Short]) { + value.asInstanceOf[Double].toShort + } else if (param.typeSignature =:= typeOf[Int]) { + value.asInstanceOf[Double].toInt + } else if (param.typeSignature =:= typeOf[Long]) { + value.asInstanceOf[Double].toLong + } else if (param.typeSignature =:= typeOf[Float]) { + value.asInstanceOf[Double].toFloat + } else if (param.typeSignature <:< typeOf[Product]) { + val typeTag = createTypeTag(param.typeSignature) + val classTag = ClassTag( + typeTag.mirror.runtimeClass(param.typeSignature) + ) + mapToProduct(value.asInstanceOf[Map[String, Any]])( + typeTag, + classTag + ) + } else { + value + } + } + + if (param.typeSignature <:< typeOf[Option[Any]]) { + Some( + valueToParamValue0( + value, + param.typeSignature.dealias.typeArgs.head.typeSymbol + ) + ) + } else { + valueToParamValue0(value, param) + } + } + + def createTypeTag[U <: Product](tpe: Type): TypeTag[U] = { + val typSym = mirror.staticClass(tpe.typeSymbol.fullName) + // note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime + val typeRef = + universe.internal.typeRef(NoPrefix, typSym, List.empty) + + TypeTag( + mirror, + new TypeCreator { + override def apply[V <: Universe with Singleton]( + m: Mirror[V] + ): V#Type = { + assert( + m == mirror, + s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m." + ) + typeRef.asInstanceOf[V#Type] + } + } + ) + } + + val clazz = typeOf[S].typeSymbol.asClass + val classMirror = mirror.reflectClass(clazz) + val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod + val constructorMirror = classMirror.reflectConstructor(constructor) + + val constructorArgs = + constructor.paramLists.flatten.map((param: Symbol) => { + val paramName = param.name.toString + val value = map.getOrElse( + paramName, + throw new IllegalArgumentException( + s"Map is missing required parameter named $paramName" + ) + ) + valueToParamValue(value, param) + }) + + constructorMirror(constructorArgs: _*).asInstanceOf[S] + } + + def structValueToAny(value: Struct.Value): Any = { + value.kind() match { + case Struct.Value.Kind.STRING_VALUE => value.stringValue() + case Struct.Value.Kind.NUMBER_VALUE => value.numberValue() + case Struct.Value.Kind.BOOL_VALUE => value.boolValue() + case Struct.Value.Kind.LIST_VALUE => + value.listValue().asScala.map(structValueToAny).toList + case Struct.Value.Kind.STRUCT_VALUE => structToMap(value.structValue()) + case Struct.Value.Kind.NULL_VALUE => None + } + } + + mapToProduct[T](structToMap(struct)) + } + + /** Returns a [[SdkLiteralType]] for blob. + * + * @return + * the [[SdkLiteralType]] + */ + def blobs(blobType: BlobType): SdkLiteralType[Blob] = + SdkJavaLiteralTypes.blobs(blobType) + /** Returns a [[SdkLiteralType]] for flyte collections. * * @param elementType diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 19e190348..b94f3718e 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -19,6 +19,7 @@ package org.flyte.flytekitscala import java.time.{Duration, Instant} import java.{util => ju} import magnolia.{CaseClass, Magnolia, Param, SealedTrait} +import org.flyte.api.v1.BlobType.BlobDimensionality import org.flyte.api.v1._ import org.flyte.flytekit.{ SdkBindingData, @@ -29,6 +30,8 @@ import org.flyte.flytekit.{ import scala.annotation.implicitNotFound import scala.collection.JavaConverters._ +import scala.reflect.{ClassTag, classTag} +import scala.reflect.runtime.universe.{TypeTag, typeOf} /** Type class to map between Flyte `Variable` and `Literal` and Scala case * classes. @@ -230,6 +233,28 @@ object SdkScalaType { implicit def durationLiteralType: SdkScalaLiteralType[Duration] = DelegateLiteralType(SdkLiteralTypes.durations()) + // more specific matching to fail the usage of SdkBindingData[Option[_]] + implicit def optionLiteralType: SdkScalaLiteralType[Option[_]] = ??? + + // fixme: using Product is just an approximation for case class because Product + // is also super class of, for example, Option and Tuple + implicit def productLiteralType[T <: Product: TypeTag: ClassTag] + : SdkScalaLiteralType[T] = + DelegateLiteralType(SdkLiteralTypes.generics()) + + // fixme: create blob type from annotation, or rethink how we could offer the offloaded data feature + // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype + implicit def blobLiteralType: SdkScalaLiteralType[Blob] = + DelegateLiteralType( + SdkLiteralTypes.blobs( + BlobType + .builder() + .format("") + .dimensionality(BlobDimensionality.SINGLE) + .build() + ) + ) + // TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData // This makes Scala dev mad when they are forced to use the java types instead of scala types // We need to think what to do, maybe move the factory methods out of SdkDataBinding into their own class diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 278035e50..2e843bf0c 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -51,6 +51,11 @@ auto-service-annotations provided + + org.flyte + flytekit-api + provided + diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java index 0da75a2c3..f5fac552f 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java @@ -20,13 +20,26 @@ @AutoValue public abstract class BQReference { + @AutoValue + public abstract static class Nested { + public abstract String project(); + + public abstract String dataset(); + + public abstract String tableName(); + } + public abstract String project(); public abstract String dataset(); public abstract String tableName(); + // this is only to test nested nested auto-value + public abstract Nested nested(); + public static BQReference create(String project, String dataset, String tableName) { - return new AutoValue_BQReference(project, dataset, tableName); + return new AutoValue_BQReference( + project, dataset, tableName, new AutoValue_BQReference_Nested(project, dataset, tableName)); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java index 0e78c9774..23fee27ad 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java @@ -16,12 +16,15 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class BuildBqReference extends SdkRunnableTask { private static final long serialVersionUID = -489898361071672070L; @@ -35,7 +38,10 @@ public BuildBqReference() { @Override public Output run(Input input) { return Output.create( - BQReference.create(input.project().get(), input.dataset().get(), input.tableName().get())); + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(BQReference.class), + BQReference.create( + input.project().get(), input.dataset().get(), input.tableName().get()))); } @AutoValue @@ -58,11 +64,8 @@ public static Input create( public abstract static class Output { abstract SdkBindingData ref(); - public static Output create(BQReference ref) { - // TODO We need a way to generate SdkBindings of generic autovalues like BQReference - // that would be mapped to sdkStructs. JacksonSdkType of nested autovalues are mapped as - // structs - return null; + public static Output create(SdkBindingData ref) { + return new AutoValue_BuildBqReference_Output(ref); } } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java index 9e82df7ca..3b3fb8a8b 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java @@ -16,13 +16,15 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class MockLookupBqTask extends SdkRunnableTask { private static final long serialVersionUID = 604843235716487166L; @@ -39,21 +41,25 @@ public abstract static class Input { public static Input create( SdkBindingData ref, SdkBindingData checkIfExists) { - return null; // TODO + return new AutoValue_MockLookupBqTask_Input(ref, checkIfExists); } } @AutoValue public abstract static class Output { + public abstract SdkBindingData ref(); + public abstract SdkBindingData exists(); - public static Output create(boolean exists) { - return new AutoValue_MockLookupBqTask_Output(SdkBindingDataFactory.of(exists)); + public static Output create(BQReference ref, boolean exists) { + return new AutoValue_MockLookupBqTask_Output( + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(BQReference.class), ref), + SdkBindingDataFactory.of(exists)); } } @Override public Output run(Input input) { - return Output.create(input.ref().get().tableName().contains("table-exists")); + return Output.create(input.ref().get(), input.ref().get().tableName().contains("table-exists")); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java index d1b565e77..5a5c6ccca 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java @@ -16,6 +16,7 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -23,10 +24,7 @@ import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; -// This workflow relays on SdkBinding that should be serialized -// as Struct. By going to typed inputs and outputs, we have de-scoped the support -// of structs. -// @AutoService(SdkWorkflow.class) +@AutoService(SdkWorkflow.class) public class MockPipelineWorkflow extends SdkWorkflow { public MockPipelineWorkflow() { diff --git a/integration-tests/src/test/java/org/flyte/AdditionalIT.java b/integration-tests/src/test/java/org/flyte/AdditionalIT.java index 3c9914312..5355ddb71 100644 --- a/integration-tests/src/test/java/org/flyte/AdditionalIT.java +++ b/integration-tests/src/test/java/org/flyte/AdditionalIT.java @@ -16,26 +16,21 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import flyteidl.core.Literals; +import flyteidl.core.Literals.LiteralMap; import org.flyte.utils.Literal; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class AdditionalIT { - @BeforeAll - public static void beforeAll() { - CLIENT.registerWorkflows("integration-tests/target/lib"); - } - +class AdditionalIT extends Fixtures { @ParameterizedTest @CsvSource({ "0,0,0,0,a == b && c == d", @@ -48,7 +43,7 @@ public static void beforeAll() { "0,1,0,1,a < b && c < d", "1,0,0,1,a > b && c < d", }) - public void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) { + void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.integrationtests.BranchNodeWorkflow", @@ -67,8 +62,7 @@ public void testBranchNodeWorkflow(long a, long b, long c, long d, String expect "table-exists,true", "non-existent,false", }) - @Disabled("Not supporting struct with the strongly typed implementation.") - public void testStructs(String name, boolean expected) { + void testStructs(String name, boolean expected) { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.integrationtests.structs.MockPipelineWorkflow", @@ -76,4 +70,12 @@ public void testStructs(String name, boolean expected) { assertThat(output, equalTo(Literal.ofBooleanMap(ImmutableMap.of("exists", expected)))); } + + @Test + void testStructsScala() { + Literals.LiteralMap output = + CLIENT.createExecution("NestedIOWorkflowLaunchPlan", STAGING_DOMAIN); + + assertThat(output, equalTo(LiteralMap.getDefaultInstance())); + } } diff --git a/integration-tests/src/test/java/org/flyte/FlyteContainer.java b/integration-tests/src/test/java/org/flyte/Fixtures.java similarity index 60% rename from integration-tests/src/test/java/org/flyte/FlyteContainer.java rename to integration-tests/src/test/java/org/flyte/Fixtures.java index a95fe7d76..d80bec77d 100644 --- a/integration-tests/src/test/java/org/flyte/FlyteContainer.java +++ b/integration-tests/src/test/java/org/flyte/Fixtures.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 Flyte Authors. + * Copyright 2023 Flyte Authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,16 @@ */ package org.flyte; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; + import org.flyte.utils.FlyteSandboxClient; -public class FlyteContainer { - static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create(); +class Fixtures { + protected static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create(); + + static { + CLIENT.registerWorkflows("integration-tests/target/lib"); + CLIENT.registerWorkflows("flytekit-examples/target/lib"); + CLIENT.registerWorkflows("flytekit-examples-scala/target/lib", STAGING_DOMAIN); + } } diff --git a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java index cbce85dee..5597c9e7f 100644 --- a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java +++ b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java @@ -16,31 +16,19 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; -import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import static org.flyte.utils.Literal.ofIntegerMap; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import flyteidl.core.Literals; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class JavaExamplesIT { - private static final String CLASSPATH_EXAMPLES = "flytekit-examples/target/lib"; - private static final String CLASSPATH_EXAMPLES_SCALA = "flytekit-examples-scala/target/lib"; - - @BeforeAll - public static void beforeAll() { - CLIENT.registerWorkflows(CLASSPATH_EXAMPLES); - CLIENT.registerWorkflows(CLASSPATH_EXAMPLES_SCALA, STAGING_DOMAIN); - } - +class JavaExamplesIT extends Fixtures { @Test - public void testSumTask() { + void testSumTask() { Literals.LiteralMap output = CLIENT.createTaskExecution( "org.flyte.examples.SumTask", @@ -53,7 +41,7 @@ public void testSumTask() { } @Test - public void testFibonacciWorkflow() { + void testFibonacciWorkflow() { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.examples.FibonacciWorkflow", @@ -66,7 +54,7 @@ public void testFibonacciWorkflow() { } @Test - public void testDynamicFibonacciWorkflow() { + void testDynamicFibonacciWorkflow() { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.examples.DynamicFibonacciWorkflow", ofIntegerMap(ImmutableMap.of("n", 2L))); diff --git a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java index 9d9caac47..888416893 100644 --- a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java +++ b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java @@ -16,7 +16,6 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -29,13 +28,13 @@ import org.junit.jupiter.api.io.TempDir; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class SerializeJavaIT { +class SerializeJavaIT extends Fixtures { private static final String CLASSPATH = "flytekit-examples/target/lib"; @TempDir Path managed; @Test - public void testSerializeWorkflows() { + void testSerializeWorkflows() { try { File current = new File("target/protos"); File tempDir = managed.resolve(current.getAbsolutePath()).toFile(); diff --git a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java index 2acdf083f..1e91e0b7f 100644 --- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java +++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java @@ -22,7 +22,9 @@ import flyteidl.admin.ExecutionOuterClass; import flyteidl.core.Execution; import flyteidl.core.IdentifierOuterClass; +import flyteidl.core.IdentifierOuterClass.ResourceType; import flyteidl.core.Literals; +import flyteidl.core.Literals.LiteralMap; import flyteidl.service.AdminServiceGrpc; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -55,6 +57,19 @@ public static FlyteSandboxClient create() { return new FlyteSandboxClient(version, stub); } + public Literals.LiteralMap createExecution(String name, String domain) { + return createExecution( + IdentifierOuterClass.Identifier.newBuilder() + .setResourceType(ResourceType.LAUNCH_PLAN) + .setDomain(domain) + .setProject(PROJECT) + .setName(name) + .setVersion(version) + .build(), + LiteralMap.getDefaultInstance(), + domain); + } + public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap inputs) { return createExecution( IdentifierOuterClass.Identifier.newBuilder() @@ -64,7 +79,8 @@ public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap .setName(name) .setVersion(version) .build(), - inputs); + inputs, + DEVELOPMENT_DOMAIN); } public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inputs) { @@ -76,15 +92,16 @@ public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inpu .setName(name) .setVersion(version) .build(), - inputs); + inputs, + DEVELOPMENT_DOMAIN); } private Literals.LiteralMap createExecution( - IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs) { + IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs, String domain) { ExecutionOuterClass.ExecutionCreateResponse response = stub.createExecution( ExecutionOuterClass.ExecutionCreateRequest.newBuilder() - .setDomain(DEVELOPMENT_DOMAIN) + .setDomain(domain) .setProject(PROJECT) .setInputs(inputs) .setSpec(ExecutionOuterClass.ExecutionSpec.newBuilder().setLaunchPlan(id).build())