From 886dd845bfdb7fffb8c2dc0ec20eb42eb40e4e17 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Mon, 20 Nov 2023 17:48:00 -0800 Subject: [PATCH 01/11] wip - mpi tests failing Signed-off-by: Jeev B --- .../flytek8s/non_interruptible.go | 35 ----- .../flytek8s/plugin_exec_context.go | 123 ++++++++++++++++ .../go/tasks/plugins/k8s/dask/dask.go | 2 +- .../k8s/kfoperators/common/common_operator.go | 40 +++-- .../common/common_operator_test.go | 40 +---- .../tasks/plugins/k8s/kfoperators/mpi/mpi.go | 138 ++++++++---------- .../go/tasks/plugins/k8s/spark/spark.go | 2 +- 7 files changed, 223 insertions(+), 157 deletions(-) delete mode 100644 flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go create mode 100644 flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go deleted file mode 100644 index d2f5042cf8..0000000000 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go +++ /dev/null @@ -1,35 +0,0 @@ -package flytek8s - -import ( - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" -) - -// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false -// This is useful as the runner and the scheduler pods should never be interruptible -type NonInterruptibleTaskExecutionMetadata struct { - pluginsCore.TaskExecutionMetadata -} - -func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { - return false -} - -// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is -// non-interruptible -type NonInterruptibleTaskExecutionContext struct { - pluginsCore.TaskExecutionContext - metadata NonInterruptibleTaskExecutionMetadata -} - -func (n NonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { - return n.metadata -} - -func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContext) NonInterruptibleTaskExecutionContext { - return NonInterruptibleTaskExecutionContext{ - TaskExecutionContext: ctx, - metadata: NonInterruptibleTaskExecutionMetadata{ - ctx.TaskExecutionMetadata(), - }, - } -} diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go new file mode 100644 index 0000000000..d09bcdb363 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go @@ -0,0 +1,123 @@ +package flytek8s + +import ( + v1 "k8s.io/api/core/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" +) + +type pluginTaskOverrides struct { + pluginsCore.TaskOverrides + resources *v1.ResourceRequirements + extendedResources *core.ExtendedResources +} + +func (to *pluginTaskOverrides) GetResources() *v1.ResourceRequirements { + if to.resources != nil { + return to.resources + } + return to.TaskOverrides.GetResources() +} + +func (to *pluginTaskOverrides) GetExtendedResources() *core.ExtendedResources { + if to.extendedResources != nil { + return to.extendedResources + } + return to.TaskOverrides.GetExtendedResources() +} + +type pluginTaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata + interruptible *bool + overrides *pluginTaskOverrides +} + +func (tm *pluginTaskExecutionMetadata) IsInterruptible() bool { 
+ if tm.interruptible != nil { + return *tm.interruptible + } + return tm.TaskExecutionMetadata.IsInterruptible() +} + +func (tm *pluginTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides { + if tm.overrides != nil { + return tm.overrides + } + return tm.TaskExecutionMetadata.GetOverrides() +} + +type pluginTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadata *pluginTaskExecutionMetadata +} + +func (tc *pluginTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + if tc.metadata != nil { + return tc.metadata + } + return tc.TaskExecutionContext.TaskExecutionMetadata() +} + +type PluginTaskExecutionContextOption func(*pluginTaskExecutionContext) + +func WithInterruptible(v bool) PluginTaskExecutionContextOption { + return func(tc *pluginTaskExecutionContext) { + if tc.metadata == nil { + tc.metadata = &pluginTaskExecutionMetadata{ + TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(), + } + } + tc.metadata.interruptible = &v + } +} + +func WithResources(r *v1.ResourceRequirements) PluginTaskExecutionContextOption { + return func(tc *pluginTaskExecutionContext) { + if tc.metadata == nil { + tc.metadata = &pluginTaskExecutionMetadata{ + TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(), + } + } + if tc.metadata.overrides == nil { + tc.metadata.overrides = &pluginTaskOverrides{ + TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(), + } + } + tc.metadata.overrides.resources = r + } +} + +func WithExtendedResources(er *core.ExtendedResources) PluginTaskExecutionContextOption { + return func(tc *pluginTaskExecutionContext) { + if tc.metadata == nil { + tc.metadata = &pluginTaskExecutionMetadata{ + TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(), + } + } + if tc.metadata.overrides == nil { + tc.metadata.overrides = &pluginTaskOverrides{ + TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(), + } + } + tc.metadata.overrides.extendedResources = er + } +} + +func NewPluginTaskExecutionContext(tc pluginsCore.TaskExecutionContext, options ...PluginTaskExecutionContextOption) pluginsCore.TaskExecutionContext { + tm := tc.TaskExecutionMetadata() + to := tm.GetOverrides() + ctx := &pluginTaskExecutionContext{ + TaskExecutionContext: tc, + metadata: &pluginTaskExecutionMetadata{ + TaskExecutionMetadata: tm, + overrides: &pluginTaskOverrides{ + TaskOverrides: to, + }, + }, + } + for _, o := range options { + o(ctx) + } + return ctx +} diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index eb27aec3ce..8257f00341 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -75,7 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC if err != nil { return nil, err } - nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) + nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false)) nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 594767b4b4..973c83b397 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ 
b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "sort" "time" @@ -15,7 +16,9 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" ) @@ -254,23 +257,13 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res } // OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function -// updates the image, resources and command arguments of the container that matches the given containerName. -func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error { +// updates the image and command arguments of the container that matches the given containerName. +func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, args []string) error { for idx, c := range podSpec.Containers { if c.Name == containerName { if image != "" { podSpec.Containers[idx].Image = image } - if resources != nil { - // if resources requests and limits both not set, we will not override the resources - if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 { - resources, err := flytek8s.ToK8sResourceRequirements(resources) - if err != nil { - return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) - } - podSpec.Containers[idx].Resources = *resources - } - } if len(args) != 0 { podSpec.Containers[idx].Args = args } @@ -278,3 +271,26 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri } return nil } + +func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string) (*commonOp.ReplicaSpec, error) { + podSpec, objectMeta, oldPrimaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + OverridePrimaryContainerName(podSpec, oldPrimaryContainerName, primaryContainerName) + + cfg := config.GetK8sPluginConfig() + objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) + objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) + + replicas := int32(1) + return &commonOp.ReplicaSpec{ + Replicas: &replicas, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, nil +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index ae77d8c94d..1c33594997 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -220,8 +220,9 @@ func dummyPodSpec() v1.PodSpec { 
return v1.PodSpec{ Containers: []v1.Container{ { - Name: "primary container", - Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, + Name: "primary container", + Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, + Image: "dummy-image", Resources: v1.ResourceRequirements{ Limits: v1.ResourceList{ "cpu": resource.MustParse("2"), @@ -270,50 +271,21 @@ func TestOverrideContainerSpec(t *testing.T) { podSpec := dummyPodSpec() err := OverrideContainerSpec( &podSpec, "primary container", "testing-image", - &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - }, - }, []string{"python", "-m", "run.py"}, ) assert.NoError(t, err) assert.Equal(t, 2, len(podSpec.Containers)) assert.Equal(t, "testing-image", podSpec.Containers[0].Image) - assert.NotNil(t, podSpec.Containers[0].Resources.Limits) - assert.NotNil(t, podSpec.Containers[0].Resources.Requests) - // verify resources not overridden if empty resources - assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m"))) - assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m"))) assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args) } func TestOverrideContainerSpecEmptyFields(t *testing.T) { podSpec := dummyPodSpec() - err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{}) - assert.NoError(t, err) - assert.Equal(t, 2, len(podSpec.Containers)) - assert.NotNil(t, podSpec.Containers[0].Resources.Limits) - assert.NotNil(t, podSpec.Containers[0].Resources.Requests) - // verify resources not overridden if empty resources - assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1"))) - assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi"))) - assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2"))) - assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi"))) -} - -func TestOverrideContainerNilResources(t *testing.T) { - podSpec := dummyPodSpec() - podSpecCopy := podSpec.DeepCopy() - - err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{}) + err := OverrideContainerSpec(&podSpec, "primary container", "", []string{}) assert.NoError(t, err) assert.Equal(t, 2, len(podSpec.Containers)) - assert.Equal(t, podSpec.Containers[0].Resources.Limits, podSpecCopy.Containers[0].Resources.Limits) - assert.Equal(t, podSpec.Containers[0].Resources.Requests, podSpecCopy.Containers[0].Resources.Requests) + assert.Equal(t, "dummy-image", podSpec.Containers[0].Image) + assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args) } func dummyTaskContext() pluginsCore.TaskExecutionContext { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 
1358422ec7..a1b6407141 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -8,7 +8,6 @@ import ( commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client" @@ -19,7 +18,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -48,6 +46,44 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas }, nil } +func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedMPITrainingReplicaSpec, isLauncher bool) (*commonOp.ReplicaSpec, error) { + taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} + if isLauncher { + taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) + } + if rs != nil && rs.GetResources() != nil { + resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) + } + newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) + replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.MPIJobDefaultContainerName) + if err != nil { + return nil, err + } + if rs != nil { + err = common.OverrideContainerSpec( + &replicaSpec.Template.Spec, + kubeflowv1.MPIJobDefaultContainerName, + rs.GetImage(), + rs.GetCommand(), + ) + if err != nil { + return nil, err + } + replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) + + if !isLauncher { + replicas := rs.GetReplicas() + replicaSpec.Replicas = &replicas + } + } + + return replicaSpec, nil +} + // Defines a func to create the full resource object that will be posted to k8s. 
func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -58,25 +94,11 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) - } - common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName) - - var launcherReplica = common.ReplicaEntry{ - ReplicaNum: int32(1), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - } - var workerReplica = common.ReplicaEntry{ - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - } slots := int32(1) runPolicy := commonOp.RunPolicy{} + var launcherReplicaSpec, workerReplicaSpec *commonOp.ReplicaSpec + if taskTemplate.TaskTypeVersion == 0 { mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{} err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs) @@ -84,8 +106,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - workerReplica.ReplicaNum = mpiTaskExtraArgs.GetNumWorkers() - launcherReplica.ReplicaNum = mpiTaskExtraArgs.GetNumLauncherReplicas() + replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.MPIJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) + } + workerReplicaSpec = replicaSpec.DeepCopy() + workerReplicas := mpiTaskExtraArgs.GetNumWorkers() + workerReplicaSpec.Replicas = &workerReplicas + launcherReplicaSpec = replicaSpec.DeepCopy() + launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas() + launcherReplicaSpec.Replicas = &launcherReplicas slots = mpiTaskExtraArgs.GetSlots() // V1 requires passing worker command as template config parameter @@ -95,10 +125,10 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu workerSpecCommand = strings.Split(val, " ") } - for k := range workerReplica.PodSpec.Containers { - if workerReplica.PodSpec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName { - workerReplica.PodSpec.Containers[k].Args = workerSpecCommand - workerReplica.PodSpec.Containers[k].Command = []string{} + for k := range workerReplicaSpec.Template.Spec.Containers { + if workerReplicaSpec.Template.Spec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName { + workerReplicaSpec.Template.Spec.Containers[k].Args = workerSpecCommand + workerReplicaSpec.Template.Spec.Containers[k].Command = []string{} } } @@ -110,36 +140,14 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas() - if launcherReplicaSpec != nil { - // flyte commands will be passed as args to the container - err = common.OverrideContainerSpec( - launcherReplica.PodSpec, - kubeflowv1.MPIJobDefaultContainerName, - 
launcherReplicaSpec.GetImage(), - launcherReplicaSpec.GetResources(), - launcherReplicaSpec.GetCommand(), - ) - if err != nil { - return nil, err - } - launcherReplica.RestartPolicy = common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()) + launcherReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetLauncherReplicas(), true) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create launcher replica spec: [%v]", err.Error()) } - workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas() - if workerReplicaSpec != nil { - err = common.OverrideContainerSpec( - workerReplica.PodSpec, - kubeflowv1.MPIJobDefaultContainerName, - workerReplicaSpec.GetImage(), - workerReplicaSpec.GetResources(), - workerReplicaSpec.GetCommand(), - ) - if err != nil { - return nil, err - } - workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) - workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() + workerReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetWorkerReplicas(), false) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error()) } if kfMPITaskExtraArgs.GetRunPolicy() != nil { @@ -151,37 +159,19 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if workerReplica.ReplicaNum == 0 { + if *workerReplicaSpec.Replicas <= 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - if launcherReplica.ReplicaNum == 0 { + if *launcherReplicaSpec.Replicas <= 0 { return nil, fmt.Errorf("number of launch worker should be more then 0") } - cfg := config.GetK8sPluginConfig() - objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - jobSpec := kubeflowv1.MPIJobSpec{ SlotsPerWorker: &slots, RunPolicy: runPolicy, MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.MPIJobReplicaTypeLauncher: { - Replicas: &launcherReplica.ReplicaNum, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *launcherReplica.PodSpec, - }, - RestartPolicy: launcherReplica.RestartPolicy, - }, - kubeflowv1.MPIJobReplicaTypeWorker: { - Replicas: &workerReplica.ReplicaNum, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *workerReplica.PodSpec, - }, - RestartPolicy: workerReplica.RestartPolicy, - }, + kubeflowv1.MPIJobReplicaTypeLauncher: launcherReplicaSpec, + kubeflowv1.MPIJobReplicaTypeWorker: workerReplicaSpec, }, } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index e5fd14478a..defcb275fe 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -173,7 +173,7 @@ type driverSpec struct { func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) { // Spark driver pods should always run as non-interruptible - nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) + nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, 
flytek8s.WithInterruptible(false)) podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { return nil, err From e1e1838f2cd00e4bb00cc5b09d336eb0b93430d0 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Mon, 20 Nov 2023 17:54:35 -0800 Subject: [PATCH 02/11] get mpi tests passing Signed-off-by: Jeev B --- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 7b6e1f5611..f283dece94 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -604,9 +604,11 @@ func TestBuildResourceMPIV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, }, }, Command: launcherCommand, @@ -616,9 +618,11 @@ func TestBuildResourceMPIV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, }, }, Command: workerCommand, @@ -628,19 +632,23 @@ func TestBuildResourceMPIV1(t *testing.T) { launcherResourceRequirements := &corev1.ResourceRequirements{ Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("250Mi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("500Mi"), }, } workerResourceRequirements := &corev1.ResourceRequirements{ Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), }, } @@ -673,9 +681,11 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, }, }, Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, @@ -685,10 +695,12 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { workerResourceRequirements := &corev1.ResourceRequirements{ Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), }, } From 28ebc38ff6189eb9cf9003e0cbe3ce1f41e2d911 Mon Sep 17 00:00:00 2001 From: 
Jeev B Date: Tue, 21 Nov 2023 13:48:50 -0800 Subject: [PATCH 03/11] Update pytorch and tensorflow plugins Signed-off-by: Jeev B --- .../k8s/kfoperators/common/common_operator.go | 4 +- .../tasks/plugins/k8s/kfoperators/mpi/mpi.go | 14 +- .../k8s/kfoperators/pytorch/pytorch.go | 136 ++++++------- .../k8s/kfoperators/pytorch/pytorch_test.go | 30 ++- .../k8s/kfoperators/tensorflow/tensorflow.go | 180 +++++++----------- .../kfoperators/tensorflow/tensorflow_test.go | 16 +- 6 files changed, 183 insertions(+), 197 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 973c83b397..9eb98e687d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -18,8 +18,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const ( @@ -284,7 +284,7 @@ func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - replicas := int32(1) + replicas := int32(0) return &commonOp.ReplicaSpec{ Replicas: &replicas, Template: v1.PodTemplateSpec{ diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index a1b6407141..6d89fce341 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -48,6 +48,7 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedMPITrainingReplicaSpec, isLauncher bool) (*commonOp.ReplicaSpec, error) { taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} + // Launcher should always run as non-interruptible if isLauncher { taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) } @@ -63,16 +64,23 @@ func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExe if err != nil { return nil, err } + + // Launcher should have a single replica + if isLauncher { + replicas := int32(1) + replicaSpec.Replicas = &replicas + } + if rs != nil { - err = common.OverrideContainerSpec( + if err := common.OverrideContainerSpec( &replicaSpec.Template.Spec, kubeflowv1.MPIJobDefaultContainerName, rs.GetImage(), rs.GetCommand(), - ) - if err != nil { + ); err != nil { return nil, err } + replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) if !isLauncher { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index e844c05d0f..8a7ea407ed 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ 
b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -7,7 +7,6 @@ import ( commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client" @@ -18,7 +17,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -45,6 +43,52 @@ func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, }, nil } +func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedPyTorchTrainingReplicaSpec, isMaster bool) (*commonOp.ReplicaSpec, error) { + taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} + // Master should always run as non-interruptible + if isMaster { + taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) + } + if rs != nil && rs.GetResources() != nil { + resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) + } + newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) + replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.PytorchJobDefaultContainerName) + if err != nil { + return nil, err + } + + // Master should have a single replica + if isMaster { + replicas := int32(1) + replicaSpec.Replicas = &replicas + } + + if rs != nil { + if err := common.OverrideContainerSpec( + &replicaSpec.Template.Spec, + kubeflowv1.PytorchJobDefaultContainerName, + rs.GetImage(), + nil, + ); err != nil { + return nil, err + } + + replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) + + if !isMaster { + replicas := rs.GetReplicas() + replicaSpec.Replicas = &replicas + } + } + + return replicaSpec, nil +} + // Defines a func to create the full resource object that will be posted to k8s. 
func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -55,25 +99,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) - } - common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) - - var masterReplica = common.ReplicaEntry{ - ReplicaNum: int32(1), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - } - var workerReplica = common.ReplicaEntry{ - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - } runPolicy := commonOp.RunPolicy{} var elasticPolicy *kubeflowv1.ElasticPolicy + var masterReplicaSpec, workerReplicaSpec *commonOp.ReplicaSpec + if taskTemplate.TaskTypeVersion == 0 { pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{} @@ -82,7 +112,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - workerReplica.ReplicaNum = pytorchTaskExtraArgs.GetWorkers() + replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.PytorchJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) + } + masterReplicaSpec = replicaSpec.DeepCopy() + masterReplicas := int32(1) + masterReplicaSpec.Replicas = &masterReplicas + workerReplicaSpec = replicaSpec.DeepCopy() + workerReplicas := pytorchTaskExtraArgs.GetWorkers() + workerReplicaSpec.Replicas = &workerReplicas + // Set elastic config elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() if elasticConfig != nil { @@ -96,37 +136,14 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - // Replace specs of master replica, master should always have 1 replica - masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() - if masterReplicaSpec != nil { - err := common.OverrideContainerSpec( - masterReplica.PodSpec, - kubeflowv1.PytorchJobDefaultContainerName, - masterReplicaSpec.GetImage(), - masterReplicaSpec.GetResources(), - nil, - ) - if err != nil { - return nil, err - } - masterReplica.RestartPolicy = common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy()) + masterReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetMasterReplicas(), true) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create master replica spec: [%v]", err.Error()) } - // Replace specs of worker replica - workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() - if workerReplicaSpec != nil { - err := common.OverrideContainerSpec( - workerReplica.PodSpec, - kubeflowv1.PytorchJobDefaultContainerName, - workerReplicaSpec.GetImage(), - workerReplicaSpec.GetResources(), - nil, - ) - if err != nil { - return nil, err - } - 
workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) - workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() + workerReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetWorkerReplicas(), false) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error()) } if kfPytorchTaskExtraArgs.GetRunPolicy() != nil { @@ -142,31 +159,14 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if workerReplica.ReplicaNum == 0 { + if *workerReplicaSpec.Replicas == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - cfg := config.GetK8sPluginConfig() - objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - jobSpec := kubeflowv1.PyTorchJobSpec{ PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeMaster: { - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *masterReplica.PodSpec, - }, - RestartPolicy: masterReplica.RestartPolicy, - }, - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workerReplica.ReplicaNum, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *workerReplica.PodSpec, - }, - RestartPolicy: workerReplica.RestartPolicy, - }, + kubeflowv1.PyTorchJobReplicaTypeMaster: masterReplicaSpec, + kubeflowv1.PyTorchJobReplicaTypeWorker: workerReplicaSpec, }, RunPolicy: runPolicy, } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 9d95ecc61b..7ed46667ec 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -663,9 +663,11 @@ func TestBuildResourcePytorchV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, }, }, RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, @@ -675,9 +677,11 @@ func TestBuildResourcePytorchV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, }, }, }, @@ -685,19 +689,23 @@ func TestBuildResourcePytorchV1(t *testing.T) { masterResourceRequirements := &corev1.ResourceRequirements{ Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("250Mi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("500Mi"), }, } workerResourceRequirements := &corev1.ResourceRequirements{ Requests: 
corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), }, } @@ -714,7 +722,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { assert.True(t, ok) assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) assert.Equal(t, testImageMaster, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) @@ -757,7 +765,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) assert.True(t, ok) assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy) assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit) assert.Equal(t, int64(1000), *pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) @@ -771,9 +779,11 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, }, }, }, @@ -794,10 +804,12 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { workerResourceRequirements := &corev1.ResourceRequirements{ Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), }, } @@ -814,7 +826,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { assert.True(t, ok) assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go 
b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 8db340d37e..4ee9c1fa95 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -7,7 +7,6 @@ import ( commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client" @@ -18,7 +17,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -45,6 +43,40 @@ func (tensorflowOperatorResourceHandler) BuildIdentityResource(ctx context.Conte }, nil } +func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedTensorflowTrainingReplicaSpec) (*commonOp.ReplicaSpec, error) { + taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} + if rs != nil && rs.GetResources() != nil { + resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) + } + newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) + replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.TFJobDefaultContainerName) + if err != nil { + return nil, err + } + + if rs != nil { + if err := common.OverrideContainerSpec( + &replicaSpec.Template.Spec, + kubeflowv1.TFJobDefaultContainerName, + rs.GetImage(), + nil, + ); err != nil { + return nil, err + } + + replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) + + replicas := rs.GetReplicas() + replicaSpec.Replicas = &replicas + } + + return replicaSpec, nil +} + // Defines a func to create the full resource object that will be posted to k8s. 
func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -55,34 +87,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) - } - common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.TFJobDefaultContainerName) - - replicaSpecMap := map[commonOp.ReplicaType]*common.ReplicaEntry{ - kubeflowv1.TFJobReplicaTypeChief: { - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.TFJobReplicaTypeWorker: { - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.TFJobReplicaTypePS: { - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.TFJobReplicaTypeEval: { - ReplicaNum: int32(0), - PodSpec: podSpec.DeepCopy(), - RestartPolicy: commonOp.RestartPolicyNever, - }, - } + replicaSpecMap := make(map[commonOp.ReplicaType]*commonOp.ReplicaSpec) runPolicy := commonOp.RunPolicy{} if taskTemplate.TaskTypeVersion == 0 { @@ -93,10 +98,25 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = tensorflowTaskExtraArgs.GetChiefReplicas() - replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = tensorflowTaskExtraArgs.GetWorkers() - replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = tensorflowTaskExtraArgs.GetPsReplicas() - replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = tensorflowTaskExtraArgs.GetEvaluatorReplicas() + replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.TFJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) + } + + replicaNumMap := map[commonOp.ReplicaType]int32{ + kubeflowv1.TFJobReplicaTypeChief: tensorflowTaskExtraArgs.GetChiefReplicas(), + kubeflowv1.TFJobReplicaTypeWorker: tensorflowTaskExtraArgs.GetWorkers(), + kubeflowv1.TFJobReplicaTypePS: tensorflowTaskExtraArgs.GetPsReplicas(), + kubeflowv1.TFJobReplicaTypeEval: tensorflowTaskExtraArgs.GetEvaluatorReplicas(), + } + for t, r := range replicaNumMap { + rs := replicaSpec.DeepCopy() + replicas := r + if replicas > 0 { + rs.Replicas = &replicas + replicaSpecMap[t] = rs + } + } } else if taskTemplate.TaskTypeVersion == 1 { kfTensorflowTaskExtraArgs := kfplugins.DistributedTensorflowTrainingTask{} @@ -106,68 +126,20 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - chiefReplicaSpec := kfTensorflowTaskExtraArgs.GetChiefReplicas() - if chiefReplicaSpec != nil { - err := common.OverrideContainerSpec( - replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].PodSpec, - kubeflowv1.TFJobDefaultContainerName, - 
chiefReplicaSpec.GetImage(), - chiefReplicaSpec.GetResources(), - nil, - ) - if err != nil { - return nil, err - } - replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = common.ParseRestartPolicy(chiefReplicaSpec.GetRestartPolicy()) - replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = chiefReplicaSpec.GetReplicas() + replicaSpecCfgMap := map[commonOp.ReplicaType]*kfplugins.DistributedTensorflowTrainingReplicaSpec{ + kubeflowv1.TFJobReplicaTypeChief: kfTensorflowTaskExtraArgs.GetChiefReplicas(), + kubeflowv1.TFJobReplicaTypeWorker: kfTensorflowTaskExtraArgs.GetWorkerReplicas(), + kubeflowv1.TFJobReplicaTypePS: kfTensorflowTaskExtraArgs.GetPsReplicas(), + kubeflowv1.TFJobReplicaTypeEval: kfTensorflowTaskExtraArgs.GetEvaluatorReplicas(), } - - workerReplicaSpec := kfTensorflowTaskExtraArgs.GetWorkerReplicas() - if workerReplicaSpec != nil { - err := common.OverrideContainerSpec( - replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].PodSpec, - kubeflowv1.TFJobDefaultContainerName, - workerReplicaSpec.GetImage(), - workerReplicaSpec.GetResources(), - nil, - ) + for t, cfg := range replicaSpecCfgMap { + rs, err := toReplicaSpecWithOverrides(ctx, taskCtx, cfg) if err != nil { - return nil, err + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) } - replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) - replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = workerReplicaSpec.GetReplicas() - } - - psReplicaSpec := kfTensorflowTaskExtraArgs.GetPsReplicas() - if psReplicaSpec != nil { - err := common.OverrideContainerSpec( - replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].PodSpec, - kubeflowv1.TFJobDefaultContainerName, - psReplicaSpec.GetImage(), - psReplicaSpec.GetResources(), - nil, - ) - if err != nil { - return nil, err + if rs != nil && *rs.Replicas > 0 { + replicaSpecMap[t] = rs } - replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = common.ParseRestartPolicy(psReplicaSpec.GetRestartPolicy()) - replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() - } - - evaluatorReplicaSpec := kfTensorflowTaskExtraArgs.GetEvaluatorReplicas() - if evaluatorReplicaSpec != nil { - err := common.OverrideContainerSpec( - replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].PodSpec, - kubeflowv1.TFJobDefaultContainerName, - evaluatorReplicaSpec.GetImage(), - evaluatorReplicaSpec.GetResources(), - nil, - ) - if err != nil { - return nil, err - } - replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].RestartPolicy = common.ParseRestartPolicy(evaluatorReplicaSpec.GetRestartPolicy()) - replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = evaluatorReplicaSpec.GetReplicas() } if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil { @@ -179,33 +151,15 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum == 0 { + if v, ok := replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker]; !ok || *v.Replicas == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - cfg := config.GetK8sPluginConfig() - objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, 
objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - jobSpec := kubeflowv1.TFJobSpec{ - TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, + TFReplicaSpecs: replicaSpecMap, + RunPolicy: runPolicy, } - for replicaType, replicaEntry := range replicaSpecMap { - if replicaEntry.ReplicaNum > 0 { - jobSpec.TFReplicaSpecs[replicaType] = &commonOp.ReplicaSpec{ - Replicas: &replicaEntry.ReplicaNum, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *replicaEntry.PodSpec, - }, - RestartPolicy: replicaEntry.RestartPolicy, - } - } - } - - jobSpec.RunPolicy = runPolicy - job := &kubeflowv1.TFJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.TFJobKind, diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index bcabdaa87f..8f2a841a64 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -641,10 +641,12 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, {Name: core.Resources_GPU, Value: "1"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, {Name: core.Resources_GPU, Value: "1"}, }, }, @@ -654,9 +656,11 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, }, }, }, @@ -695,19 +699,23 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { kubeflowv1.TFJobReplicaTypeWorker: { Requests: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), }, Limits: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), }, }, kubeflowv1.TFJobReplicaTypePS: { Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), }, Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), }, }, kubeflowv1.TFJobReplicaTypeEval: { @@ -761,10 +769,12 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { Resources: &core.Resources{ Requests: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, {Name: core.Resources_GPU, Value: "1"}, }, Limits: []*core.Resources_ResourceEntry{ {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, {Name: core.Resources_GPU, Value: "1"}, }, }, @@ -775,10 +785,12 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { kubeflowv1.TFJobReplicaTypeWorker: { Requests: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: 
resource.MustParse("1Gi"), flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), }, Limits: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), }, }, From b4587c238de3a2c9bbd67dbfb3bf08c5420a5f27 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 21 Nov 2023 14:27:07 -0800 Subject: [PATCH 04/11] cleanup Signed-off-by: Jeev B --- .../k8s/kfoperators/common/common_operator.go | 67 +++++++++++++++++-- .../tasks/plugins/k8s/kfoperators/mpi/mpi.go | 51 +------------- .../k8s/kfoperators/pytorch/pytorch.go | 51 +------------- .../k8s/kfoperators/tensorflow/tensorflow.go | 37 +--------- 4 files changed, 66 insertions(+), 140 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 9eb98e687d..9ab7ab17ef 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -28,12 +28,6 @@ const ( PytorchTaskType = "pytorch" ) -type ReplicaEntry struct { - PodSpec *v1.PodSpec - ReplicaNum int32 - RestartPolicy commonOp.RestartPolicy -} - // ExtractCurrentCondition will return the first job condition for tensorflow/pytorch func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { if jobConditions != nil { @@ -294,3 +288,64 @@ func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext RestartPolicy: commonOp.RestartPolicyNever, }, nil } + +type kfDistributedReplicaSpec interface { + GetReplicas() int32 + GetImage() string + GetResources() *core.Resources + GetRestartPolicy() kfplugins.RestartPolicy +} + +type allowsCommandOverride interface { + GetCommand() []string +} + +func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) { + taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} + // Master should always run as non-interruptible + if isMaster { + taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) + } + if rs != nil && rs.GetResources() != nil { + resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) + } + newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) 
+ replicaSpec, err := ToReplicaSpec(ctx, newTaskCtx, primaryContainerName) + if err != nil { + return nil, err + } + + // Master should have a single replica + if isMaster { + replicas := int32(1) + replicaSpec.Replicas = &replicas + } + + if rs != nil { + var command []string + if v, ok := rs.(allowsCommandOverride); ok { + command = v.GetCommand() + } + if err := OverrideContainerSpec( + &replicaSpec.Template.Spec, + primaryContainerName, + rs.GetImage(), + command, + ); err != nil { + return nil, err + } + + replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy()) + + if !isMaster { + replicas := rs.GetReplicas() + replicaSpec.Replicas = &replicas + } + } + + return replicaSpec, nil +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 6d89fce341..8ce64c2e9c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -17,7 +17,6 @@ import ( flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -46,52 +45,6 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas }, nil } -func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedMPITrainingReplicaSpec, isLauncher bool) (*commonOp.ReplicaSpec, error) { - taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} - // Launcher should always run as non-interruptible - if isLauncher { - taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) - } - if rs != nil && rs.GetResources() != nil { - resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) - } - taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) - } - newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) - replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.MPIJobDefaultContainerName) - if err != nil { - return nil, err - } - - // Launcher should have a single replica - if isLauncher { - replicas := int32(1) - replicaSpec.Replicas = &replicas - } - - if rs != nil { - if err := common.OverrideContainerSpec( - &replicaSpec.Template.Spec, - kubeflowv1.MPIJobDefaultContainerName, - rs.GetImage(), - rs.GetCommand(), - ); err != nil { - return nil, err - } - - replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) - - if !isLauncher { - replicas := rs.GetReplicas() - replicaSpec.Replicas = &replicas - } - } - - return replicaSpec, nil -} - // Defines a func to create the full resource object that will be posted to k8s. 
func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -148,12 +101,12 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - launcherReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetLauncherReplicas(), true) + launcherReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetLauncherReplicas(), kubeflowv1.MPIJobDefaultContainerName, true) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create launcher replica spec: [%v]", err.Error()) } - workerReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetWorkerReplicas(), false) + workerReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetWorkerReplicas(), kubeflowv1.MPIJobDefaultContainerName, false) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error()) } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 8a7ea407ed..f9bd6cb01e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -16,7 +16,6 @@ import ( flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -43,52 +42,6 @@ func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, }, nil } -func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedPyTorchTrainingReplicaSpec, isMaster bool) (*commonOp.ReplicaSpec, error) { - taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} - // Master should always run as non-interruptible - if isMaster { - taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) - } - if rs != nil && rs.GetResources() != nil { - resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) - } - taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) - } - newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) 
- replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.PytorchJobDefaultContainerName) - if err != nil { - return nil, err - } - - // Master should have a single replica - if isMaster { - replicas := int32(1) - replicaSpec.Replicas = &replicas - } - - if rs != nil { - if err := common.OverrideContainerSpec( - &replicaSpec.Template.Spec, - kubeflowv1.PytorchJobDefaultContainerName, - rs.GetImage(), - nil, - ); err != nil { - return nil, err - } - - replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) - - if !isMaster { - replicas := rs.GetReplicas() - replicaSpec.Replicas = &replicas - } - } - - return replicaSpec, nil -} - // Defines a func to create the full resource object that will be posted to k8s. func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -136,12 +89,12 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - masterReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetMasterReplicas(), true) + masterReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetMasterReplicas(), kubeflowv1.PytorchJobDefaultContainerName, true) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create master replica spec: [%v]", err.Error()) } - workerReplicaSpec, err = toReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetWorkerReplicas(), false) + workerReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetWorkerReplicas(), kubeflowv1.PytorchJobDefaultContainerName, false) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error()) } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 4ee9c1fa95..6e810bb4a1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -16,7 +16,6 @@ import ( flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" @@ -43,40 +42,6 @@ func (tensorflowOperatorResourceHandler) BuildIdentityResource(ctx context.Conte }, nil } -func toReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs *kfplugins.DistributedTensorflowTrainingReplicaSpec) (*commonOp.ReplicaSpec, error) { - taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} - if rs != nil && rs.GetResources() != nil { - resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, 
err.Error()) - } - taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources)) - } - newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...) - replicaSpec, err := common.ToReplicaSpec(ctx, newTaskCtx, kubeflowv1.TFJobDefaultContainerName) - if err != nil { - return nil, err - } - - if rs != nil { - if err := common.OverrideContainerSpec( - &replicaSpec.Template.Spec, - kubeflowv1.TFJobDefaultContainerName, - rs.GetImage(), - nil, - ); err != nil { - return nil, err - } - - replicaSpec.RestartPolicy = common.ParseRestartPolicy(rs.GetRestartPolicy()) - - replicas := rs.GetReplicas() - replicaSpec.Replicas = &replicas - } - - return replicaSpec, nil -} - // Defines a func to create the full resource object that will be posted to k8s. func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -133,7 +98,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobReplicaTypeEval: kfTensorflowTaskExtraArgs.GetEvaluatorReplicas(), } for t, cfg := range replicaSpecCfgMap { - rs, err := toReplicaSpecWithOverrides(ctx, taskCtx, cfg) + rs, err := common.ToReplicaSpecWithOverrides(ctx, taskCtx, cfg, kubeflowv1.TFJobDefaultContainerName, false) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) } From 34ff26f1a163c70f69d4bfcf9901a91faaaacc1a Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 21 Nov 2023 14:29:31 -0800 Subject: [PATCH 05/11] cleanup Signed-off-by: Jeev B --- .../go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 2 +- .../go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index f9bd6cb01e..e04c15838c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -112,7 +112,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if *workerReplicaSpec.Replicas == 0 { + if *workerReplicaSpec.Replicas <= 0 { return nil, fmt.Errorf("number of worker should be more then 0") } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 6e810bb4a1..b6dcc8e999 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -116,7 +116,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if v, ok := replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker]; !ok || *v.Replicas == 0 { + if v, ok := replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker]; !ok || *v.Replicas <= 0 { return nil, fmt.Errorf("number of worker should be more then 0") } From 236846e725d3b721f67989f52a38adfd6bbacadb Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 21 Nov 2023 14:31:59 -0800 Subject: [PATCH 06/11] cleanup Signed-off-by: Jeev B --- 
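
Note: the hunk below makes a zero launcher count from GetNumLauncherReplicas()
default to a single launcher instead of failing the launcher-count validation,
since TaskTypeVersion 0 task configs may leave the launcher count unset. A
minimal sketch of the defaulting being applied (clampLauncherReplicas is a
hypothetical name; the patch inlines this check):

	// clampLauncherReplicas illustrates the normalization below: an unset or
	// non-positive launcher count falls back to exactly one launcher.
	func clampLauncherReplicas(n int32) int32 {
		if n < 1 {
			return 1
		}
		return n
	}
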
 .../go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 11 ++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
index 8ce64c2e9c..32aa26c556 100644
--- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
+++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
@@ -71,12 +71,17 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
 		if err != nil {
 			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())
 		}
-		workerReplicaSpec = replicaSpec.DeepCopy()
-		workerReplicas := mpiTaskExtraArgs.GetNumWorkers()
-		workerReplicaSpec.Replicas = &workerReplicas
 		launcherReplicaSpec = replicaSpec.DeepCopy()
+		// TODO (jeev): Is this even a valid configuration? Can there be more than 1
+		// launcher? TaskTypeVersion 1 does not support overriding this value.
		launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
+		if launcherReplicas < 1 {
+			launcherReplicas = 1
+		}
 		launcherReplicaSpec.Replicas = &launcherReplicas
+		workerReplicaSpec = replicaSpec.DeepCopy()
+		workerReplicas := mpiTaskExtraArgs.GetNumWorkers()
+		workerReplicaSpec.Replicas = &workerReplicas
 		slots = mpiTaskExtraArgs.GetSlots()

 		// V1 requires passing worker command as template config parameter

From f684b0df61699f7df5b5405b84530c7a657114ae Mon Sep 17 00:00:00 2001
From: Jeev B
Date: Tue, 21 Nov 2023 14:40:47 -0800
Subject: [PATCH 07/11] cleanup

Signed-off-by: Jeev B
---
 .../plugins/k8s/kfoperators/tensorflow/tensorflow.go | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
index b6dcc8e999..cb9c6d5000 100644
--- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
+++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
@@ -98,13 +98,16 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
 		kubeflowv1.TFJobReplicaTypeEval:   kfTensorflowTaskExtraArgs.GetEvaluatorReplicas(),
 	}
 	for t, cfg := range replicaSpecCfgMap {
+		// Short-circuit if the replica set has no replicas to avoid unnecessarily
+		// generating pod specs.
+		if cfg.GetReplicas() <= 0 {
+			continue
+		}
 		rs, err := common.ToReplicaSpecWithOverrides(ctx, taskCtx, cfg, kubeflowv1.TFJobDefaultContainerName, false)
 		if err != nil {
 			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())
 		}
-		if rs != nil && *rs.Replicas > 0 {
-			replicaSpecMap[t] = rs
-		}
+		replicaSpecMap[t] = rs
 	}

 	if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil {

From cfc2f24befd6130ab297512ada2479a7675a6199 Mon Sep 17 00:00:00 2001
From: Jeev B
Date: Tue, 21 Nov 2023 14:47:14 -0800
Subject: [PATCH 08/11] fixes

Signed-off-by: Jeev B
---
 .../go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
index f283dece94..fc570bf27d 100644
--- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
@@ -339,12 +339,6 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) {
 	_, err := mpiResourceHandler.BuildResource(context.TODO(),
dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) assert.Error(t, err) - mpiObj = dummyMPICustomObj(1, 0, 1) - taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj) - - _, err = mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) - assert.Error(t, err) - mpiObj = dummyMPICustomObj(1, 1, 1) taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj) @@ -561,8 +555,8 @@ func TestReplicaCounts(t *testing.T) { contains []mpiOp.ReplicaType notContains []mpiOp.ReplicaType }{ - {"NoWorkers", 0, 1, true, nil, nil}, - {"NoLaunchers", 1, 0, true, nil, nil}, + {"NoWorkers", 1, 0, true, nil, nil}, + {"Minimum One Launcher", 0, 1, false, []mpiOp.ReplicaType{kubeflowv1.MPIJobReplicaTypeLauncher, kubeflowv1.MPIJobReplicaTypeWorker}, []mpiOp.ReplicaType{}}, {"Works", 1, 1, false, []mpiOp.ReplicaType{kubeflowv1.MPIJobReplicaTypeLauncher, kubeflowv1.MPIJobReplicaTypeWorker}, []mpiOp.ReplicaType{}}, } { t.Run(test.name, func(t *testing.T) { From ca1a839d4c0c4994f66dd0b4faef622fa9686381 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 21 Nov 2023 15:24:00 -0800 Subject: [PATCH 09/11] Add tests for resource tolerations Signed-off-by: Jeev B --- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 60 ++++++++++++++++++ .../k8s/kfoperators/pytorch/pytorch_test.go | 60 ++++++++++++++++++ .../kfoperators/tensorflow/tensorflow_test.go | 61 +++++++++++++++++++ 3 files changed, 181 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index fc570bf27d..1cb6e9d826 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -716,3 +716,63 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) } + +func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { + gpuToleration := corev1.Toleration{ + Key: "nvidia.com/gpu", + Value: "present", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + } + assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ + GpuResourceName: flytek8s.ResourceNvidiaGPU, + ResourceTolerations: map[corev1.ResourceName][]corev1.Toleration{ + flytek8s.ResourceNvidiaGPU: {gpuToleration}, + }, + })) + + taskConfig := &kfplugins.DistributedMPITrainingTask{ + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + } + + mpiResourceHandler 
:= mpiOperatorResourceHandler{} + + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + + assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 7ed46667ec..f0e215f262 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -840,6 +840,66 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { assert.Nil(t, pytorchJob.Spec.ElasticPolicy) } +func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { + gpuToleration := corev1.Toleration{ + Key: "nvidia.com/gpu", + Value: "present", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + } + assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ + GpuResourceName: flytek8s.ResourceNvidiaGPU, + ResourceTolerations: map[corev1.ResourceName][]corev1.Toleration{ + flytek8s.ResourceNvidiaGPU: {gpuToleration}, + }, + })) + + taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + } + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.NotContains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) +} + func TestBuildResourcePytorchV1WithElastic(t *testing.T) { taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go 
b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 8f2a841a64..485e73b74a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -824,3 +824,64 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { assert.True(t, hasContainerWithDefaultTensorFlowName) } } + +func TestBuildResourceTensorFlowV1ResourceTolerations(t *testing.T) { + gpuToleration := corev1.Toleration{ + Key: "nvidia.com/gpu", + Value: "present", + Operator: corev1.TolerationOpEqual, + Effect: corev1.TaintEffectNoSchedule, + } + assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ + GpuResourceName: flytek8s.ResourceNvidiaGPU, + ResourceTolerations: map[corev1.ResourceName][]corev1.Toleration{ + flytek8s.ResourceNvidiaGPU: {gpuToleration}, + }, + })) + + taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + } + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + + assert.NotContains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) +} From 1d29e0246675225d02c0f5120ed1277e1d01a560 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Wed, 22 Nov 2023 16:33:20 -0800 Subject: [PATCH 10/11] PR comments Signed-off-by: Jeev B --- flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 4 ++-- .../go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 2 +- .../go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 32aa26c556..826e83b671 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -126,10 +126,10 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu } if *workerReplicaSpec.Replicas <= 0 { - return nil, fmt.Errorf("number of worker should be more then 0") + return nil, 
fmt.Errorf("number of workers must be greater than 0") } if *launcherReplicaSpec.Replicas <= 0 { - return nil, fmt.Errorf("number of launch worker should be more then 0") + return nil, fmt.Errorf("number of launchers must be greater than 0") } jobSpec := kubeflowv1.MPIJobSpec{ diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index e04c15838c..81c8e16cd5 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -113,7 +113,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx } if *workerReplicaSpec.Replicas <= 0 { - return nil, fmt.Errorf("number of worker should be more then 0") + return nil, fmt.Errorf("number of workers must be greater than 0") } jobSpec := kubeflowv1.PyTorchJobSpec{ diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index cb9c6d5000..eae5622128 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -120,7 +120,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task } if v, ok := replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker]; !ok || *v.Replicas <= 0 { - return nil, fmt.Errorf("number of worker should be more then 0") + return nil, fmt.Errorf("number of workers must be greater than 0") } jobSpec := kubeflowv1.TFJobSpec{ From 2116ec277e5b8f94d4674352a58681775561fad5 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Thu, 30 Nov 2023 08:54:40 -0800 Subject: [PATCH 11/11] Revert forcing non-interruptibility on kfoperator masters Signed-off-by: Jeev B --- .../tasks/plugins/k8s/kfoperators/common/common_operator.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 9ab7ab17ef..74021df2be 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -302,10 +302,6 @@ type allowsCommandOverride interface { func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) { taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} - // Master should always run as non-interruptible - if isMaster { - taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false)) - } if rs != nil && rs.GetResources() != nil { resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) if err != nil {