diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 2eec0f8cf3..cfc2bfa010 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -82,6 +82,7 @@ require ( cloud.google.com/go/pubsub v1.34.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 049add4bbc..5b7d47b2a6 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -75,6 +75,8 @@ github.com/DataDog/datadog-go v3.4.1+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cfed68/go.mod h1:gMGUEe16aZh0QN941HgDjwrdjU4iTthPoz2/AtDRADE= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= diff --git a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go index b868612269..b3d01897f1 100644 --- a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go +++ b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go @@ -241,6 +241,28 @@ func GetExecutionRequest() *admin.ExecutionCreateRequest { } } +func GetExecutionRequestWithOffloadedInputs(inputParam string, literalValue *core.Literal) *admin.ExecutionCreateRequest { + execReq := GetExecutionRequest() + execReq.Inputs = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + } + return execReq +} + func GetSampleWorkflowSpecForTest() *admin.WorkflowSpec { return &admin.WorkflowSpec{ Template: &core.WorkflowTemplate{ diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator.go b/flyteadmin/pkg/manager/impl/validation/execution_validator.go index 0a21165c93..f7b385b8a8 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator.go @@ -100,7 +100,13 @@ func CheckAndFetchInputsForExecution( } executionInputMap[name] = expectedInput.GetDefault() } else { - inputType := validators.LiteralTypeForLiteral(executionInputMap[name]) + var inputType *core.LiteralType + switch executionInputMap[name].GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = executionInputMap[name].GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(executionInputMap[name]) + } if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType) } diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go index 1329dc6f96..7e5f991788 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go @@ -105,6 +105,40 @@ func TestGetExecutionInputs(t *testing.T) { assert.EqualValues(t, expectedMap, actualInputs) } +func TestGetExecutionWithOffloadedInputs(t *testing.T) { + execLiteral := &core.Literal{ + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + } + executionRequest := testutils.GetExecutionRequestWithOffloadedInputs("foo", execLiteral) + lpRequest := testutils.GetLaunchPlanRequest() + + actualInputs, err := CheckAndFetchInputsForExecution( + executionRequest.Inputs, + lpRequest.Spec.FixedInputs, + lpRequest.Spec.DefaultInputs, + ) + expectedMap := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": execLiteral, + "bar": coreutils.MustMakeLiteral("bar-value"), + }, + } + assert.Nil(t, err) + assert.NotNil(t, actualInputs) + assert.EqualValues(t, expectedMap.GetLiterals()["foo"], actualInputs.Literals["foo"]) + assert.EqualValues(t, expectedMap.GetLiterals()["bar"], actualInputs.Literals["bar"]) +} + func TestValidateExecInputsWrongType(t *testing.T) { executionRequest := testutils.GetExecutionRequest() lpRequest := testutils.GetLaunchPlanRequest() diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index 6c9bd2fdbb..894eaee435 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -234,7 +234,7 @@ func validateLiteralMap(inputMap *core.LiteralMap, fieldName string) error { if name == "" { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing key in %s", fieldName) } - if fixedInput == nil || fixedInput.GetValue() == nil { + if fixedInput.GetValue() == nil && fixedInput.GetOffloadedMetadata() == nil { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing valid literal in %s %s", fieldName, name) } if isDateTime(fixedInput) { diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 5d828f9e9b..f579049aff 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 + github.com/Masterminds/semver v1.5.0 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyte/flyteidl v0.0.0-00010101000000-000000000000 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 8bbdd06eba..07a92b902b 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -64,6 +64,8 @@ github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXf github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295/go.mod h1:e0aH495YLkrsIe9fhedd6aSR6fgU/qhKvtroi6y7G/M= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/aws/aws-sdk-go v1.44.2 h1:5VBk5r06bgxgRKVaUtm1/4NT/rtrnH2E4cnAYv5zgQc= diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index bcd1064e67..486ac35a16 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -554,6 +554,10 @@ func GetOutputsFile(outputDir DataReference) DataReference { return outputDir + "/outputs.pb" } +func GetOutputsLiteralMetadataFile(literalKey string, outputDir DataReference) DataReference { + return outputDir + DataReference(fmt.Sprintf("/%s_offloaded_metadata.pb", literalKey)) +} + func GetInputsFile(inputDir DataReference) DataReference { return inputDir + "/inputs.pb" } diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go index 2d967c560e..0976df669b 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go @@ -35,7 +35,13 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } - inputType := validators.LiteralTypeForLiteral(inputVal) + var inputType *core.LiteralType + switch inputVal.GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = inputVal.GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(inputVal) + } if !validators.AreTypesCastable(inputType, v.Type) { errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) continue diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index a0217e186a..2d61c94970 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -34,12 +34,16 @@ package config import ( + "context" + "fmt" "time" + "github.com/Masterminds/semver" "k8s.io/apimachinery/pkg/types" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/logger" ) //go:generate pflags Config --default-var=defaultConfig @@ -120,6 +124,14 @@ var ( EventVersion: 0, DefaultParallelismBehavior: ParallelismBehaviorUnlimited, }, + LiteralOffloadingConfig: LiteralOffloadingConfig{ + Enabled: false, // Default keep this disabled and we will followup when flytekit is released with the offloaded changes. + SupportedSDKVersions: map[string]string{ // The key is the SDK name (matches the supported SDK in core.RuntimeMetadata_RuntimeType) and the value is the minimum supported version + "FLYTE_SDK": "1.13.5", // Expected release number with flytekit support from this PR https://github.com/flyteorg/flytekit/pull/2685 + }, + MinSizeInMBForOffloading: 10, // 10 MB is the default size for offloading + MaxSizeInMBForOffloading: 1000, // 1 GB is the default size before failing fast. + }, } ) @@ -127,40 +139,79 @@ var ( // the base configuration to start propeller // NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables. type Config struct { - KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` - MasterURL string `json:"master"` - Workers int `json:"workers" pflag:",Number of threads to process workflows"` - WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` - DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` - LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` - ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` - MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` - DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` - Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` - MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` - MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` - EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` - MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` - MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` - GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` - LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` - PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` - MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` - EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."` - KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` - NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` - MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` - EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` - IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` - ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` - IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` - ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` - IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` - ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` - ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` - CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` - NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` - ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` + KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` + MasterURL string `json:"master"` + Workers int `json:"workers" pflag:",Number of threads to process workflows"` + WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` + DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` + LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` + ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` + MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` + DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` + Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` + MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` + EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` + MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` + MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` + GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` + LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` + PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` + MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` + EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."` + KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` + NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` + MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` + EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` + IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` + ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` + IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` + ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` + IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` + ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` + ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` + CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` + NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` + ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` + LiteralOffloadingConfig LiteralOffloadingConfig `json:"literalOffloadingConfig" pflag:",config used for literal offloading."` +} + +type LiteralOffloadingConfig struct { + Enabled bool + // Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals. + SupportedSDKVersions map[string]string + // Default, 10Mbs. Determines the size of a literal at which to trigger offloading + MinSizeInMBForOffloading int64 + // Fail fast threshold + MaxSizeInMBForOffloading int64 +} + +// IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config. +func (l LiteralOffloadingConfig) IsSupportedSDKVersion(sdk string, versionString string) bool { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + c, err := semver.NewConstraint(fmt.Sprintf(">= %s", leastSupportedVersion)) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version constraint %s", leastSupportedVersion) + return false + } + version, err := semver.NewVersion(versionString) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version %s", versionString) + return false + } + return c.Check(version) + } + return false +} + +// GetSupportedSDKVersion returns the least supported version for the provided SDK. +func (l LiteralOffloadingConfig) GetSupportedSDKVersion(sdk string) string { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + return leastSupportedVersion + } + return "" } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index 858fc8a8ba..b2e88e88e6 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -112,5 +112,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-config.event-version"), defaultConfig.ArrayNode.EventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "array-node-config.default-parallelism-behavior"), defaultConfig.ArrayNode.DefaultParallelismBehavior, "Default parallelism behavior for array nodes") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.Enabled"), defaultConfig.LiteralOffloadingConfig.Enabled, "") + cmdFlags.StringToString(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.SupportedSDKVersions"), defaultConfig.LiteralOffloadingConfig.SupportedSDKVersions, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MinSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MinSizeInMBForOffloading, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MaxSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MaxSizeInMBForOffloading, "") return cmdFlags } diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index 27e7b76efa..aadb24b36a 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -967,4 +967,60 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_literalOffloadingConfig.Enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.Enabled", testValue) + if vBool, err := cmdFlags.GetBool("literalOffloadingConfig.Enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LiteralOffloadingConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.SupportedSDKVersions", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "a=1,b=2" + + cmdFlags.Set("literalOffloadingConfig.SupportedSDKVersions", testValue) + if vStringToString, err := cmdFlags.GetStringToString("literalOffloadingConfig.SupportedSDKVersions"); err == nil { + testDecodeRaw_Config(t, vStringToString, &actual.LiteralOffloadingConfig.SupportedSDKVersions) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MinSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MinSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MinSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MinSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MaxSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MaxSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MaxSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MaxSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/controller/config/config_test.go b/flytepropeller/pkg/controller/config/config_test.go new file mode 100644 index 0000000000..afc9ed2fea --- /dev/null +++ b/flytepropeller/pkg/controller/config/config_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsSupportedSDKVersion(t *testing.T) { + t.Run("supported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.True(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) + + t.Run("unsupported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.15.0")) + }) + + t.Run("unsupported SDK", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("unknown", "0.16.0")) + }) + + t.Run("invalid version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "invalid")) + }) + + t.Run("invalid constraint", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "invalid", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) +} diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index c59aa9745d..39047e811d 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -436,14 +436,14 @@ func New(ctx context.Context, cfg *config.Config, kubeClientset kubernetes.Inter recoveryClient := recovery.NewClient(adminClient) nodeHandlerFactory, err := factory.NewHandlerFactory(ctx, launchPlanActor, launchPlanActor, - kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.LiteralOffloadingConfig, cfg.ClusterID, signalClient, scope) if err != nil { return nil, errors.Wrapf(err, "failed to create node handler factory") } nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, - catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) + catalogClient, recoveryClient, cfg.LiteralOffloadingConfig, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index a101ed5a30..5e9f910e14 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -45,6 +45,7 @@ var ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig gatherOutputsRequestChannel chan *gatherOutputsRequest metrics metrics nodeExecutionRequestChannel chan *nodeExecutionRequest @@ -498,7 +499,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // attempt best effort at initializing outputLiterals with output variable names. currently // only TaskNode and WorkflowNode contain node interfaces. outputLiterals := make(map[string]*idlcore.Literal) - switch arrayNode.GetSubNodeSpec().GetKind() { case v1alpha1.NodeKindTask: taskID := *arrayNode.GetSubNodeSpec().TaskRef @@ -547,6 +547,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength)) } + // only offload literal if config is enabled for this feature. + if a.literalOffloadingConfig.Enabled { + for outputLiteralKey, outputLiteral := range outputLiterals { + // if the size of the output Literal is > threshold then we write the literal to the offloaded store and populate the literal with its zero value and update the offloaded url + // use the OffloadLargeLiteralKey to create {OffloadLargeLiteralKey}_offloaded_metadata.pb file in the datastore. + // Update the url in the outputLiteral with the offloaded url and also update the size of the literal. + offloadedOutputFile := v1alpha1.GetOutputsLiteralMetadataFile(outputLiteralKey, nCtx.NodeStatus().GetOutputDir()) + if err := common.OffloadLargeLiteral(ctx, nCtx.DataStore(), offloadedOutputFile, outputLiteral, a.literalOffloadingConfig); err != nil { + return handler.UnknownTransition, err + } + } + } outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } @@ -649,7 +661,7 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) e } // New initializes a new arrayNodeHandler -func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { +func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, literalOffloadingConfig config.LiteralOffloadingConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) if err != nil { @@ -676,6 +688,7 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ eventConfig: deepCopiedEventConfig, + literalOffloadingConfig: literalOffloadingConfig, gatherOutputsRequestChannel: make(chan *gatherOutputsRequest), metrics: newMetrics(arrayScope), nodeExecutionRequestChannel: make(chan *nodeExecutionRequest), diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index 648d70e36c..cb2f2898a6 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -52,6 +52,8 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter adminClient := launchplan.NewFailFastLaunchPlanExecutor() enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} eventConfig := &config.EventConfig{ErrorOnAlreadyExists: true} + offloadingConfig := config.LiteralOffloadingConfig{Enabled: false} + literalOffloadingConfig := config.LiteralOffloadingConfig{Enabled: true, MinSizeInMBForOffloading: 1024, MaxSizeInMBForOffloading: 1024 * 1024} mockEventSink := eventmocks.NewMockEventSink() mockHandlerFactory := &mocks.HandlerFactory{} mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) @@ -62,11 +64,11 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter // create node executor nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient, - adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) + adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, offloadingConfig, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) assert.NoError(t, err) // return ArrayNodeHandler - arrayNodeHandler, err := New(nodeExecutor, eventConfig, scope) + arrayNodeHandler, err := New(nodeExecutor, eventConfig, literalOffloadingConfig, scope) if err != nil { return nil, err } diff --git a/flytepropeller/pkg/controller/nodes/common/utils.go b/flytepropeller/pkg/controller/nodes/common/utils.go index 04ddc5183d..89bb0afe2e 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils.go +++ b/flytepropeller/pkg/controller/nodes/common/utils.go @@ -2,17 +2,28 @@ package common import ( "context" + "fmt" "strconv" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/storage" ) -const maxUniqueIDLength = 20 +const ( + maxUniqueIDLength = 20 + MB = 1024 * 1024 // 1 MB in bytes (1 MiB) +) // GenerateUniqueID is the UniqueId of a node is unique within a given workflow execution. // In order to achieve that we track the lineage of the node. @@ -67,3 +78,88 @@ func GetTargetEntity(ctx context.Context, nCtx interfaces.NodeExecutionContext) } return targetEntity } + +// OffloadLargeLiteral offloads the large literal if meets the threshold conditions +func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, dataReference storage.DataReference, + toBeOffloaded *idlcore.Literal, literalOffloadingConfig config.LiteralOffloadingConfig) error { + literalSizeBytes := int64(proto.Size(toBeOffloaded)) + literalSizeMB := literalSizeBytes / MB + // check if the literal is large + if literalSizeMB >= literalOffloadingConfig.MaxSizeInMBForOffloading { + errString := fmt.Sprintf("Literal size [%d] MB is larger than the max size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MaxSizeInMBForOffloading) + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + if literalSizeMB < literalOffloadingConfig.MinSizeInMBForOffloading { + logger.Debugf(ctx, "Literal size [%d] MB is smaller than the min size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MinSizeInMBForOffloading) + return nil + } + + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + if inferredType == nil { + errString := "Failed to determine literal type for offloaded literal" + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + + // offload the literal + if err := datastore.WriteProtobuf(ctx, dataReference, storage.Options{}, toBeOffloaded); err != nil { + logger.Errorf(ctx, "Failed to offload literal at location [%s] with error [%s]", dataReference, err) + return err + } + + // update the literal with the offloaded URI, size and inferred type + toBeOffloaded.Value = &idlcore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlcore.LiteralOffloadedMetadata{ + Uri: dataReference.String(), + SizeBytes: uint64(literalSizeBytes), + InferredType: inferredType, + }, + } + logger.Infof(ctx, "Offloaded literal at location [%s] with size [%d] MB and inferred type [%s]", dataReference, literalSizeMB, inferredType) + return nil +} + +// CheckOffloadingCompat checks if the upstream and downstream nodes are compatible with the literal offloading feature and returns an error if not contained in phase info object +func CheckOffloadingCompat(ctx context.Context, nCtx interfaces.NodeExecutionContext, inputLiterals map[string]*core.Literal, node v1alpha1.ExecutableNode, literalOffloadingConfig config.LiteralOffloadingConfig) *handler.PhaseInfo { + consumesOffloadLiteral := false + for _, val := range inputLiterals { + if val != nil && val.GetOffloadedMetadata() != nil { + consumesOffloadLiteral = true + break + } + } + if !consumesOffloadLiteral { + return nil + } + var phaseInfo handler.PhaseInfo + + // Return early if the node is not of type NodeKindTask + if node.GetKind() != v1alpha1.NodeKindTask { + return nil + } + + // Process NodeKindTask + taskID := *node.GetTaskID() + taskNode, err := nCtx.ExecutionContext().GetTask(taskID) + if err != nil { + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "GetTaskIDFailure", err.Error(), nil) + return &phaseInfo + } + runtimeData := taskNode.CoreTask().GetMetadata().GetRuntime() + if !literalOffloadingConfig.IsSupportedSDKVersion(runtimeData.GetType().String(), runtimeData.GetVersion()) { + if !literalOffloadingConfig.Enabled { + errMsg := fmt.Sprintf("task [%s] is trying to consume offloaded literals but feature is not enabled", taskID) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingDisabled", errMsg, nil) + return &phaseInfo + } + leastSupportedVersion := literalOffloadingConfig.GetSupportedSDKVersion(runtimeData.GetType().String()) + errMsg := fmt.Sprintf("Literal offloading is not supported for this task as its registered with SDK version [%s] which is less than the least supported version [%s] for this feature", runtimeData.GetVersion(), leastSupportedVersion) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingNotSupported", errMsg, nil) + return &phaseInfo + } + + return nil +} diff --git a/flytepropeller/pkg/controller/nodes/common/utils_test.go b/flytepropeller/pkg/controller/nodes/common/utils_test.go index 9e451da69a..7d5ce1e372 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/common/utils_test.go @@ -1,11 +1,22 @@ package common import ( + "context" "testing" "github.com/stretchr/testify/assert" + idlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" + executorMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors/mocks" + nodeMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces/mocks" + "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/promutils/labeled" + "github.com/flyteorg/flyte/flytestdlib/storage" ) type ParentInfo struct { @@ -66,3 +77,177 @@ func TestCreateParentInfoNil(t *testing.T) { assert.Equal(t, uint32(1), parent.CurrentAttempt()) assert.True(t, parent.IsInDynamicChain()) } + +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + +func TestOffloadLargeLiteral(t *testing.T) { + t.Run("offload successful with valid size", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 1, + } + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Equal(t, "foo/bar", toBeOffloaded.GetOffloadedMetadata().GetUri()) + assert.Equal(t, uint64(6), toBeOffloaded.GetOffloadedMetadata().GetSizeBytes()) + assert.Equal(t, inferredType.GetSimple(), toBeOffloaded.GetOffloadedMetadata().InferredType.GetSimple()) + + }) + + t.Run("offload fails with size larger than max", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 0, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.Error(t, err) + }) + + t.Run("offload not attempted with size smaller than min", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 2, + MaxSizeInMBForOffloading: 3, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Nil(t, toBeOffloaded.GetOffloadedMetadata()) + }) +} + +func TestCheckOffloadingCompat(t *testing.T) { + ctx := context.Background() + nCtx := &nodeMocks.NodeExecutionContext{} + executionContext := &executorMocks.ExecutionContext{} + executableTask := &mocks.ExecutableTask{} + node := &mocks.ExecutableNode{} + node.OnGetKind().Return(v1alpha1.NodeKindTask) + nCtx.OnExecutionContext().Return(executionContext) + executionContext.OnGetTask("task1").Return(executableTask, nil) + executableTask.OnCoreTask().Return(&idlCore.TaskTemplate{ + Metadata: &idlCore.TaskMetadata{ + Runtime: &idlCore.RuntimeMetadata{ + Type: idlCore.RuntimeMetadata_FLYTE_SDK, + Version: "0.16.0", + }, + }, + }) + taskID := "task1" + node.OnGetTaskID().Return(&taskID) + t.Run("supported version success", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.16.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) + t.Run("unsupported version", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.17.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingNotSupported", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config disabled with offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: false, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingDisabled", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config enabled with no offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 47c91edc51..2c3103e4ad 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -491,6 +491,7 @@ type nodeExecutor struct { defaultExecutionDeadline time.Duration enqueueWorkflow v1alpha1.EnqueueWorkflow eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig interruptibleFailureThreshold int32 maxNodeRetriesForSystemFailures uint32 metrics *nodeMetrics @@ -764,6 +765,10 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur } if nodeInputs != nil { + p := common.CheckOffloadingCompat(ctx, nCtx, nodeInputs.Literals, node, c.literalOffloadingConfig) + if p != nil { + return *p, nil + } inputsFile := v1alpha1.GetInputsFile(dataDir) if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { c.metrics.InputsWriteFailure.Inc(ctx) @@ -1417,7 +1422,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + catalogClient catalog.Client, recoveryClient recovery.Client, literalOffloadingConfig config.LiteralOffloadingConfig, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. @@ -1469,6 +1474,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, interruptibleFailureThreshold: nodeConfig.InterruptibleFailureThreshold, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), metrics: metrics, diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index ea7da42112..7fc4c05992 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -69,7 +69,7 @@ func TestSetInputsForStartNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -116,7 +116,7 @@ func TestSetInputsForStartNode(t *testing.T) { failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -145,7 +145,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -156,7 +156,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -176,7 +176,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -281,7 +281,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -696,7 +696,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { nodeConfig := config.GetConfig().NodeConfig nodeConfig.EnableCRDebugMetadata = test.enableCRDebugMetadata execIface, err := NewExecutor(ctx, nodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -771,7 +771,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -885,7 +885,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -952,7 +952,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -983,7 +983,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1018,7 +1018,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1131,7 +1131,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1249,7 +1249,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) // Node not yet started @@ -1889,7 +1889,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -2666,7 +2666,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Cache(t *testing.T) { mockHandlerFactory.OnGetHandler(v1alpha1.NodeKindTask).Return(mockHandler, nil) nodeExecutor, err := NewExecutor(ctx, nodeConfig, dataStore, enqueueWorkflow, mockEventSink, adminClient, adminClient, rawOutputPrefix, fakeKubeClient, catalogClient, - recoveryClient, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) + recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) assert.NoError(t, err) return nodeExecutor diff --git a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go index 424bd15f10..72dcff5310 100644 --- a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go +++ b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go @@ -28,16 +28,17 @@ import ( type handlerFactory struct { handlers map[v1alpha1.NodeKind]interfaces.NodeHandler - workflowLauncher launchplan.Executor - launchPlanReader launchplan.Reader - kubeClient executors.Client - kubeClientset kubernetes.Interface - catalogClient catalog.Client - recoveryClient recovery.Client - eventConfig *config.EventConfig - clusterID string - signalClient service.SignalServiceClient - scope promutils.Scope + workflowLauncher launchplan.Executor + launchPlanReader launchplan.Reader + kubeClient executors.Client + kubeClientset kubernetes.Interface + catalogClient catalog.Client + recoveryClient recovery.Client + eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig + clusterID string + signalClient service.SignalServiceClient + scope promutils.Scope } func (f *handlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { @@ -54,7 +55,7 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se return err } - arrayHandler, err := array.New(executor, f.eventConfig, f.scope) + arrayHandler, err := array.New(executor, f.eventConfig, f.literalOffloadingConfig, f.scope) if err != nil { return err } @@ -79,18 +80,20 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, kubeClient executors.Client, kubeClientset kubernetes.Interface, catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, + literalOffloadingConfig config.LiteralOffloadingConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.HandlerFactory, error) { return &handlerFactory{ - workflowLauncher: workflowLauncher, - launchPlanReader: launchPlanReader, - kubeClient: kubeClient, - kubeClientset: kubeClientset, - catalogClient: catalogClient, - recoveryClient: recoveryClient, - eventConfig: eventConfig, - clusterID: clusterID, - signalClient: signalClient, - scope: scope, + workflowLauncher: workflowLauncher, + launchPlanReader: launchPlanReader, + kubeClient: kubeClient, + kubeClientset: kubeClientset, + catalogClient: catalogClient, + recoveryClient: recoveryClient, + eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, + clusterID: clusterID, + signalClient: signalClient, + scope: scope, }, nil } diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 85667b0e26..a3d028e94b 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -242,11 +242,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -328,11 +328,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -398,7 +398,7 @@ func BenchmarkWorkflowExecutor(b *testing.B) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() handlerFactory := &nodemocks.HandlerFactory{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, scope) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, scope) assert.NoError(b, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -512,7 +512,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -613,11 +613,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() assert.NoError(t, err) @@ -685,7 +685,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { diff --git a/go.mod b/go.mod index 8fd55ed61a..8c8053def6 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/NYTimes/gizmo v1.3.6 // indirect github.com/Shopify/sarama v1.26.4 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect diff --git a/go.sum b/go.sum index ae60f26800..68eebb1fde 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cf github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA=