Skip to content

Commit

Permalink
Blob and struct support (#258)
Browse files Browse the repository at this point in the history
* Bring back Blob support

Signed-off-by: Hongxin Liang <[email protected]>

* Remove BlobTypeDescription

Signed-off-by: Hongxin Liang <[email protected]>

* Add blob back

Signed-off-by: Hongxin Liang <[email protected]>

* Blob in list and map

Signed-off-by: Hongxin Liang <[email protected]>

* Clean up

Signed-off-by: Hongxin Liang <[email protected]>

* Support struct (#259)

Signed-off-by: Hongxin Liang <[email protected]>

* Support struct in Scala layer (#262)

Signed-off-by: Hongxin Liang <[email protected]>

---------

Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Oct 17, 2023
1 parent 73c6c71 commit 6a26d59
Show file tree
Hide file tree
Showing 43 changed files with 1,003 additions and 354 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=2.5.2
version=3.7.14
runner.dialect=scala212source3

Original file line number Diff line number Diff line change
@@ -1 +1 @@
org.flyte.examples.flytekitscala.FibonacciLaunchPlan
org.flyte.examples.flytekitscala.LaunchPlanRegistry
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ org.flyte.examples.flytekitscala.SumTask
org.flyte.examples.flytekitscala.GreetTask
org.flyte.examples.flytekitscala.AddQuestionTask
org.flyte.examples.flytekitscala.NoInputsTask
org.flyte.examples.flytekitscala.NestedIOTask
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
org.flyte.examples.flytekitscala.FibonacciWorkflow
org.flyte.examples.flytekitscala.WelcomeWorkflow
org.flyte.examples.flytekitscala.NestedIOWorkflow
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ import org.flyte.flytekit.{SdkLaunchPlan, SimpleSdkLaunchPlanRegistry}
import org.flyte.flytekitscala.SdkScalaType

case class FibonacciLaunchPlanInput(fib0: Long, fib1: Long)
case class NestedIOLaunchPlanInput(name: String, generic: Nested)

class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
// Register default launch plans for all workflows
registerDefaultLaunchPlans()

Expand Down Expand Up @@ -53,4 +54,35 @@ class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
.withDefaultInput("fib0", 0L)
.withDefaultInput("fib1", 1L)
)

registerLaunchPlan(
SdkLaunchPlan
.of(new NestedIOWorkflow)
.withName("NestedIOWorkflowLaunchPlan")
.withDefaultInput(
SdkScalaType[NestedIOLaunchPlanInput],
NestedIOLaunchPlanInput(
"yo",
Nested(
boolean = true,
1.toByte,
2.toShort,
3,
4L,
5.toFloat,
6.toDouble,
"hello",
List("1", "2"),
List(NestedNested(7.toDouble, NestedNestedNested("world"))),
Map("1" -> "1", "2" -> "2"),
Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))),
Some(false),
None,
Some(List("3", "4")),
Some(Map("3" -> "3", "4" -> "4")),
NestedNested(7.toDouble, NestedNestedNested("world"))
)
)
)
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform}
import org.flyte.flytekitscala.{
Description,
SdkBindingDataFactory,
SdkScalaType
}

case class NestedNestedNested(string: String)
case class NestedNested(double: Double, nested: NestedNestedNested)
case class Nested(
boolean: Boolean,
byte: Byte,
short: Short,
int: Int,
long: Long,
float: Float,
double: Double,
string: String,
list: List[String],
listOfNested: List[NestedNested],
map: Map[String, String],
mapOfNested: Map[String, NestedNested],
optBoolean: Option[Boolean],
optByte: Option[Byte],
optList: Option[List[String]],
optMap: Option[Map[String, String]],
nested: NestedNested
)
case class NestedIOTaskInput(
@Description("the name of the person to be greeted")
name: SdkBindingData[String],
@Description("a nested input")
generic: SdkBindingData[Nested]
)
case class NestedIOTaskOutput(
@Description("the name of the person to be greeted")
name: SdkBindingData[String],
@Description("a nested input")
generic: SdkBindingData[Nested]
)

/** Example Flyte task that takes a name as the input and outputs a simple
* greeting message.
*/
class NestedIOTask
extends SdkRunnableTask[
NestedIOTaskInput,
NestedIOTaskOutput
](
SdkScalaType[NestedIOTaskInput],
SdkScalaType[NestedIOTaskOutput]
) {

/** Defines task behavior. This task takes a name as the input, wraps it in a
* welcome message, and outputs the message.
*
* @param input
* the name of the person to be greeted
* @return
* the welcome message
*/
override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
NestedIOTaskOutput(
input.name,
input.generic
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2020-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekitscala.{
SdkScalaType,
SdkScalaWorkflow,
SdkScalaWorkflowBuilder
}

class NestedIOWorkflow
extends SdkScalaWorkflow[NestedIOTaskInput, Unit](
SdkScalaType[NestedIOTaskInput],
SdkScalaType.unit
) {

override def expand(
builder: SdkScalaWorkflowBuilder,
input: NestedIOTaskInput
): Unit = {
builder.apply(new NestedIOTask(), input)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.Instant;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.jackson.JacksonSdkType;
Expand All @@ -34,8 +35,20 @@ public AllInputsTask() {
JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class));
}

@AutoValue
public abstract static class Nested {
public abstract String hello();

public abstract String world();

public static Nested create(String hello, String world) {
return new AutoValue_AllInputsTask_Nested(hello, world);
}
}

@AutoValue
public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Long> i();

public abstract SdkBindingData<Double> f();
Expand All @@ -48,8 +61,9 @@ public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -66,13 +80,14 @@ public static AutoAllInputsInput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
// Blob blob,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsInput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -91,8 +106,9 @@ public abstract static class AutoAllInputsOutput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -109,12 +125,14 @@ public static AutoAllInputsOutput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsOutput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}

Expand All @@ -127,6 +145,8 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) {
input.b(),
input.t(),
input.d(),
input.blob(),
input.generic(),
input.l(),
input.m(),
input.emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.examples.AllInputsTask.Nested;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkLiteralType;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
Expand All @@ -57,6 +63,20 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
SdkBindingDataFactory.of(true),
SdkBindingDataFactory.of(someInstant),
SdkBindingDataFactory.of(Duration.ofDays(1L)),
SdkBindingDataFactory.of(
Blob.builder()
.uri("file://test/test.csv")
.metadata(
BlobMetadata.builder()
.type(
BlobType.builder()
.format("")
.dimensionality(BlobDimensionality.SINGLE)
.build())
.build())
.build()),
SdkBindingDataFactory.of(
JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")),
SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")),
SdkBindingDataFactory.ofStringMap(Map.of("test", "test")),
SdkBindingDataFactory.ofStringCollection(Collections.emptyList()),
Expand All @@ -71,6 +91,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
outputs.b(),
outputs.t(),
outputs.d(),
outputs.blob(),
outputs.generic(),
outputs.l(),
outputs.m(),
outputs.emptyList(),
Expand All @@ -92,8 +114,9 @@ public abstract static class AllInputsWorkflowOutput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<Nested> generic();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -110,12 +133,14 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<Nested> generic,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap);
}
}
}
4 changes: 4 additions & 0 deletions flytekit-jackson/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-parameter-names</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.SimpleType;
import org.flyte.flytekit.SdkLiteralType;
import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper;

/**
* Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link
Expand Down Expand Up @@ -102,7 +103,8 @@ public Literal toLiteral(T value) {
var tree = OBJECT_MAPPER.valueToTree(value);

try {
return OBJECT_MAPPER.treeToValue(tree, Literal.class);
return Literal.ofScalar(
Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap()));
} catch (IOException e) {
throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e);
}
Expand Down
Loading

0 comments on commit 6a26d59

Please sign in to comment.