Skip to content

Commit

Permalink
Add cache support for Sdk{Runnable|Container}Tasks (#122)
Browse files Browse the repository at this point in the history
* Add support for caching SdkRunnableTask

Signed-off-by: kiarash rezahanjani <[email protected]>

* Rename cache properties in SdkRunnableTask

to make them more consistent with other properties; they follow
the JavaBeans naming convention

Signed-off-by: Nelson Arapé <[email protected]>

* Rename TaskTemplate properties to follow proto

cache -> discoverable
cacheVersion -> discoveryVersion

Signed-off-by: Nelson Arapé <[email protected]>

* (Sdk)RunnableTask.getCacheVersion return null by default

Null is a more natural default value in Java than
an empty string. Make the null <-> "" conversion at
the ProtoUtil boundary

Signed-off-by: Nelson Arapé <[email protected]>

* Add unit tests for "" to null conversions in TaskTemplate

Signed-off-by: Nelson Arapé <[email protected]>

* Add cache to ContainerTask

By introducing a common super interface to RunnableTask and ContainerTask

Signed-off-by: Nelson Arapé <[email protected]>

* Add examples of cache utilization

Signed-off-by: Nelson Arapé <[email protected]>

* Simplify empty string handling in ProtoUtil

Signed-off-by: Nelson Arapé <[email protected]>

* Make test about "" <-> null more explicit

Signed-off-by: Nelson Arapé <[email protected]>

* Add test that the default cache settings leave the cache disabled

Signed-off-by: Nelson Arapé <[email protected]>

Co-authored-by: kiarash rezahanjani <[email protected]>
Co-authored-by: Nelson Arapé <[email protected]>
  • Loading branch information
3 people authored Jul 28, 2022
1 parent 8aa31ab commit 37a9acf
Show file tree
Hide file tree
Showing 17 changed files with 432 additions and 52 deletions.
18 changes: 2 additions & 16 deletions flytekit-api/src/main/java/org/flyte/api/v1/ContainerTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@
*/
package org.flyte.api.v1;

import static java.util.Collections.emptyMap;

import java.util.List;

/** Building block for tasks that execute arbitrary containers. */
public interface ContainerTask {

/** Specifies task name. */
String getName();
public interface ContainerTask extends Task {

/** Specifies container image. */
String getImage();
Expand All @@ -38,22 +33,13 @@ public interface ContainerTask {
/** Specifies container environment variables. */
List<KeyValuePair> getEnv();

@Override
default String getType() {
return "raw-container";
}

TypedInterface getInterface();

/** Specifies container resource requests. */
default Resources getResources() {
return Resources.builder().build();
}

/** Specifies task retry policy. */
RetryStrategy getRetries();

/** Specifies custom container parameters. */
default Struct getCustom() {
return Struct.of(emptyMap());
}
}
10 changes: 3 additions & 7 deletions flytekit-api/src/main/java/org/flyte/api/v1/RunnableTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
import java.util.Map;

/** Building block for tasks that execute Java code. */
public interface RunnableTask {

String getName();
public interface RunnableTask extends Task {

@Override
default String getType() {
// FIXME default only for backwards-compatibility, remove in 0.3.x
return "java-task";
}

@Override
default Struct getCustom() {
// FIXME default only for backwards-compatibility, remove in 0.3.x
return Struct.of(emptyMap());
Expand All @@ -40,9 +40,5 @@ default Resources getResources() {
return Resources.builder().build();
}

TypedInterface getInterface();

Map<String, Literal> run(Map<String, Literal> inputs);

RetryStrategy getRetries();
}
61 changes: 61 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/Task.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

import static java.util.Collections.emptyMap;

/**
* Common super interface for all tasks.
*
* <p>Declares a task's name, type, typed interface, custom data, retry policy and cache settings.
* Caching is disabled by default: {@link #isCached()} and {@link #isCacheSerializable()} return
* {@code false} and {@link #getCacheVersion()} returns {@code null}.
*/
public interface Task {

/** Specifies task name. */
String getName();

/** Specifies the task type identifier. */
String getType();

/** Specifies the task interface: inputs/outputs. */
TypedInterface getInterface();

/** Specifies custom data about the task. Defaults to a struct built from an empty map. */
default Struct getCustom() {
return Struct.of(emptyMap());
}

/** Specifies task retry policy. */
RetryStrategy getRetries();

/**
* Indicates whether the system should attempt to lookup this task's output to avoid duplication
* of work.
*
* <p>Defaults to {@code false}.
*/
default boolean isCached() {
return false;
}

/**
* Indicates a logical version to apply to this task for the purpose of cache.
*
* <p>Defaults to {@code null}, i.e. no cache version is specified.
*/
default String getCacheVersion() {
return null;
}

/**
* Indicates whether the system should attempt to execute cached instances in serial to avoid
* duplicate work.
*
* <p>Defaults to {@code false}.
*/
default boolean isCacheSerializable() {
return false;
}
}
13 changes: 13 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public abstract class TaskTemplate {

public abstract Struct custom();

public abstract boolean discoverable();

@Nullable
public abstract String discoveryVersion();

public abstract boolean cacheSerializable();

public abstract Builder toBuilder();

public static Builder builder() {
Expand All @@ -56,6 +63,12 @@ public abstract static class Builder {

public abstract Builder custom(Struct custom);

public abstract Builder discoverable(boolean discoverable);

public abstract Builder discoveryVersion(String discoveryVersion);

public abstract Builder cacheSerializable(boolean cacheSerializable);

public abstract TaskTemplate build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class SumTask
override def run(input: SumTaskInput): SumTaskOutput = {
SumTaskOutput(input.a + input.b)
}

override def isCached: Boolean = true

override def getCacheVersion: String = "1"

override def isCacheSerializable: Boolean = true
}

object SumTask {
Expand Down
15 changes: 15 additions & 0 deletions flytekit-examples/src/main/java/org/flyte/examples/SumTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,19 @@ public static SumOutput create(long c) {
public SumOutput run(SumInput input) {
return SumOutput.create(input.a() + input.b());
}

@Override
public boolean isCached() {
return true;
}

@Override
public String getCacheVersion() {
return "1";
}

@Override
public boolean isCacheSerializable() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,25 @@ public List<String> getCommand() {
public Map<String, String> getEnv() {
return emptyMap();
}

/**
* Indicates whether the system should attempt to lookup this task's output to avoid duplication
* of work.
*/
public boolean isCached() {
return false;
}

/**
* Indicates a logical version to apply to this task for the purpose of cache.
*
* <p>This implementation returns {@code null}.
*/
public String getCacheVersion() {
return null;
}

/**
* Indicates whether the system should attempt to execute cached instances in serial to avoid
* duplicate work.
*/
public boolean isCacheSerializable() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ public List<KeyValuePair> getEnv() {
public Resources getResources() {
return sdkTask.getResources().toIdl();
}

@Override
public boolean isCached() {
return sdkTask.isCached();
}

@Override
public String getCacheVersion() {
return sdkTask.getCacheVersion();
}

@Override
public boolean isCacheSerializable() {
return sdkTask.isCacheSerializable();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ public int getRetries() {
return 0;
}

/**
* Indicates whether the system should attempt to lookup this task's output to avoid duplication
* of work.
*/
public boolean isCached() {
return false;
}

/**
* Indicates a logical version to apply to this task for the purpose of cache.
*
* <p>This implementation returns {@code null}.
*/
public String getCacheVersion() {
return null;
}

/**
* Indicates whether the system should attempt to execute cached instances in serial to avoid
* duplicate work.
*/
public boolean isCacheSerializable() {
return false;
}

@Override
public SdkNode apply(
SdkWorkflowBuilder builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ public RetryStrategy getRetries() {
return RetryStrategy.builder().retries(sdkTask.getRetries()).build();
}

@Override
public boolean isCached() {
return sdkTask.isCached();
}

@Override
public String getCacheVersion() {
return sdkTask.getCacheVersion();
}

@Override
public boolean isCacheSerializable() {
return sdkTask.isCacheSerializable();
}

@Override
public String getName() {
return sdkTask.getName();
Expand Down
15 changes: 15 additions & 0 deletions jflyte/src/main/java/org/flyte/jflyte/ExecuteLocalLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,20 @@ public Struct getCustom() {
public RetryStrategy getRetries() {
return runnableTask.getRetries();
}

@Override
public boolean isCached() {
return runnableTask.isCached();
}

@Override
public String getCacheVersion() {
return runnableTask.getCacheVersion();
}

@Override
public boolean isCacheSerializable() {
return runnableTask.isCacheSerializable();
}
}
}
35 changes: 21 additions & 14 deletions jflyte/src/main/java/org/flyte/jflyte/ProjectClosure.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.flyte.api.v1.RunnableTask;
import org.flyte.api.v1.RunnableTaskRegistrar;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.Task;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TaskTemplate;
import org.flyte.api.v1.WorkflowIdentifier;
Expand Down Expand Up @@ -451,13 +452,7 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String
.resources(resources)
.build();

return TaskTemplate.builder()
.container(container)
.interface_(task.getInterface())
.retries(task.getRetries())
.type(task.getType())
.custom(task.getCustom())
.build();
return createTaskTemplate(task, container);
}

@VisibleForTesting
Expand All @@ -472,13 +467,25 @@ static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) {
.resources(resources)
.build();

return TaskTemplate.builder()
.container(container)
.interface_(task.getInterface())
.retries(task.getRetries())
.type(task.getType())
.custom(task.getCustom())
.build();
return createTaskTemplate(task, container);
}

/**
 * Builds the {@link TaskTemplate} for a task running in the given container.
 *
 * <p>Shared by {@code createTaskTemplateForRunnableTask} and {@code
 * createTaskTemplateForContainerTask}. Copies the interface, retries, type, custom data and cache
 * flags from the task. The discovery version is only set on the builder when the task reports a
 * non-null cache version; otherwise it is left unset.
 *
 * @param task task providing the template metadata
 * @param container container the task runs in
 * @return the assembled task template
 */
private static TaskTemplate createTaskTemplate(Task task, Container container) {
  TaskTemplate.Builder templateBuilder =
      TaskTemplate.builder()
          .container(container)
          .interface_(task.getInterface())
          .retries(task.getRetries())
          .type(task.getType())
          .custom(task.getCustom())
          .discoverable(task.isCached())
          .cacheSerializable(task.isCacheSerializable());

  // Read the version once so the null check and the value handed to the builder always agree.
  String cacheVersion = task.getCacheVersion();
  if (cacheVersion != null) {
    templateBuilder.discoveryVersion(cacheVersion);
  }

  return templateBuilder.build();
}

private static Optional<KeyValuePair> javaToolOptionsEnv(Resources resources) {
Expand Down
Loading

0 comments on commit 37a9acf

Please sign in to comment.