Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle resource overrides in KF plugins #4467

Merged
merged 15 commits into from
Nov 30, 2023

This file was deleted.

123 changes: 123 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package flytek8s

import (
v1 "k8s.io/api/core/v1"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

// pluginTaskOverrides decorates a pluginsCore.TaskOverrides so that a plugin
// can substitute its own resource requirements and extended resources.
// Fields left nil fall through to the embedded TaskOverrides.
type pluginTaskOverrides struct {
pluginsCore.TaskOverrides
// resources, when non-nil, is what GetResources returns.
resources *v1.ResourceRequirements
// extendedResources, when non-nil, is what GetExtendedResources returns.
extendedResources *core.ExtendedResources
}

// GetResources returns the plugin-supplied resource requirements when one has
// been set, and otherwise defers to the embedded TaskOverrides.
func (to *pluginTaskOverrides) GetResources() *v1.ResourceRequirements {
if to.resources != nil {
return to.resources
}
return to.TaskOverrides.GetResources()

Check warning on line 20 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L16-L20

Added lines #L16 - L20 were not covered by tests
}

// GetExtendedResources returns the plugin-supplied extended resources when one
// has been set, and otherwise defers to the embedded TaskOverrides.
func (to *pluginTaskOverrides) GetExtendedResources() *core.ExtendedResources {
if to.extendedResources != nil {
return to.extendedResources
}
return to.TaskOverrides.GetExtendedResources()

Check warning on line 27 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L23-L27

Added lines #L23 - L27 were not covered by tests
}

// pluginTaskExecutionMetadata decorates a pluginsCore.TaskExecutionMetadata so
// that a plugin can override the interruptible flag and the task overrides.
// Fields left nil fall through to the embedded TaskExecutionMetadata.
type pluginTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
// interruptible is a pointer so that "not overridden" (nil) can be told
// apart from an explicit false.
interruptible *bool
// overrides, when non-nil, is what GetOverrides returns.
overrides *pluginTaskOverrides
}

// IsInterruptible returns the plugin-supplied interruptible flag when one has
// been set, and otherwise defers to the embedded TaskExecutionMetadata.
func (tm *pluginTaskExecutionMetadata) IsInterruptible() bool {
if tm.interruptible != nil {
return *tm.interruptible
}
return tm.TaskExecutionMetadata.IsInterruptible()

Check warning on line 40 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L36-L40

Added lines #L36 - L40 were not covered by tests
}

// GetOverrides returns the plugin-level overrides wrapper when one has been
// set, and otherwise defers to the embedded TaskExecutionMetadata.
func (tm *pluginTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides {
// The nil check happens on the concrete pointer before it is converted to
// the interface, so a nil *pluginTaskOverrides is never returned as a
// non-nil interface value.
if tm.overrides != nil {
return tm.overrides
}
return tm.TaskExecutionMetadata.GetOverrides()

Check warning on line 47 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L43-L47

Added lines #L43 - L47 were not covered by tests
}

// pluginTaskExecutionContext decorates a pluginsCore.TaskExecutionContext so
// that a plugin can present modified task execution metadata. All other
// methods are promoted from the embedded TaskExecutionContext.
type pluginTaskExecutionContext struct {
pluginsCore.TaskExecutionContext
// metadata, when non-nil, is what TaskExecutionMetadata returns.
metadata *pluginTaskExecutionMetadata
}

// TaskExecutionMetadata returns the plugin-level metadata wrapper when one has
// been set, and otherwise defers to the embedded TaskExecutionContext.
func (tc *pluginTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
if tc.metadata != nil {
return tc.metadata
}
return tc.TaskExecutionContext.TaskExecutionMetadata()

Check warning on line 59 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L55-L59

Added lines #L55 - L59 were not covered by tests
}

// PluginTaskExecutionContextOption is a functional option that mutates the
// context being assembled by NewPluginTaskExecutionContext.
type PluginTaskExecutionContextOption func(*pluginTaskExecutionContext)

// WithInterruptible returns an option that forces IsInterruptible on the
// wrapped task execution metadata to report v.
func WithInterruptible(v bool) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
// Lazily create the metadata wrapper if the context was built without one.
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
// v is a fresh copy per WithInterruptible call, so taking its address is safe.
tc.metadata.interruptible = &v

Check warning on line 71 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L64-L71

Added lines #L64 - L71 were not covered by tests
}
}

// WithResources returns an option that makes GetResources on the wrapped task
// overrides report r. The pointer is stored as-is (not copied). Passing nil
// leaves the underlying overrides' resources in effect, because the getter
// only honors a non-nil value.
func WithResources(r *v1.ResourceRequirements) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
// Lazily create the metadata and overrides wrappers if they are missing.
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.resources = r

Check warning on line 87 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L75-L87

Added lines #L75 - L87 were not covered by tests
}
}

// WithExtendedResources returns an option that makes GetExtendedResources on
// the wrapped task overrides report er. The pointer is stored as-is (not
// copied). Passing nil leaves the underlying overrides' extended resources in
// effect, because the getter only honors a non-nil value.
func WithExtendedResources(er *core.ExtendedResources) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
// Lazily create the metadata and overrides wrappers if they are missing.
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.extendedResources = er

Check warning on line 103 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L91-L103

Added lines #L91 - L103 were not covered by tests
}
}

// NewPluginTaskExecutionContext wraps tc in a context whose task execution
// metadata and task overrides can be selectively replaced by the given
// options. With no options the result reports the same metadata values as tc.
// Options are applied in the order given, so a later option overwrites an
// earlier one that sets the same field.
func NewPluginTaskExecutionContext(tc pluginsCore.TaskExecutionContext, options ...PluginTaskExecutionContextOption) pluginsCore.TaskExecutionContext {
tm := tc.TaskExecutionMetadata()
to := tm.GetOverrides()
// Both wrapper layers are created up front, capturing tc's metadata and
// overrides once at construction time.
ctx := &pluginTaskExecutionContext{
TaskExecutionContext: tc,
metadata: &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tm,
overrides: &pluginTaskOverrides{
TaskOverrides: to,
},
},
}
for _, o := range options {
o(ctx)
}
return ctx

Check warning on line 122 in flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go#L107-L122

Added lines #L107 - L122 were not covered by tests
}
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
if err != nil {
return nil, err
}
nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx)
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"
"sort"
"time"
Expand All @@ -15,8 +16,10 @@
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
)

const (
Expand All @@ -25,12 +28,6 @@
PytorchTaskType = "pytorch"
)

type ReplicaEntry struct {
PodSpec *v1.PodSpec
ReplicaNum int32
RestartPolicy commonOp.RestartPolicy
}

// ExtractCurrentCondition will return the first job condition for tensorflow/pytorch
func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
Expand Down Expand Up @@ -254,27 +251,101 @@
}

// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function
// updates the image, resources and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error {
// updates the image and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, args []string) error {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
podSpec.Containers[idx].Image = image
}
if resources != nil {
// if resources requests and limits both not set, we will not override the resources
if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
resources, err := flytek8s.ToK8sResourceRequirements(resources)
if err != nil {
return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
podSpec.Containers[idx].Resources = *resources
}
}
if len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return nil
}

// ToReplicaSpec builds a ReplicaSpec whose pod template comes from
// flytek8s.ToK8sPodSpec(ctx, taskCtx). OverridePrimaryContainerName is applied
// with primaryContainerName, and the template's annotations and labels are
// combined (via utils.UnionMaps) with the K8s plugin config defaults and the
// task execution metadata.
//
// The returned spec has Replicas set to 0 and RestartPolicy set to
// RestartPolicyNever; callers are expected to adjust these as needed.
// A BadTaskSpecification error is returned if the pod spec cannot be built.
func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string) (*commonOp.ReplicaSpec, error) {
podSpec, objectMeta, oldPrimaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

Check warning on line 273 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L269-L273

Added lines #L269 - L273 were not covered by tests

OverridePrimaryContainerName(podSpec, oldPrimaryContainerName, primaryContainerName)

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

// Placeholder replica count; a local is needed because Replicas is a pointer field.
replicas := int32(0)
return &commonOp.ReplicaSpec{
Replicas: &replicas,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}, nil

Check warning on line 289 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L275-L289

Added lines #L275 - L289 were not covered by tests
}

// kfDistributedReplicaSpec is the set of accessors that
// ToReplicaSpecWithOverrides reads from a per-replica specification:
// replica count, container image, resource requests/limits and restart policy.
type kfDistributedReplicaSpec interface {
GetReplicas() int32
GetImage() string
GetResources() *core.Resources
GetRestartPolicy() kfplugins.RestartPolicy
}

// allowsCommandOverride is an optional interface that a replica specification
// may implement. ToReplicaSpecWithOverrides detects it with a type assertion
// and, when present, passes the returned command to OverrideContainerSpec as
// the primary container's arguments.
type allowsCommandOverride interface {
GetCommand() []string
}

func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) {
taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{}
// Master should always run as non-interruptible
if isMaster {
taskCtxOptions = append(taskCtxOptions, flytek8s.WithInterruptible(false))
Copy link
Contributor Author

@jeevb jeevb Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that "master" pods will be tagged as non-interruptible. Currently this is set for PyTorch masters and MPI launchers.

}
if rs != nil && rs.GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources())
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources))

Check warning on line 314 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L303-L314

Added lines #L303 - L314 were not covered by tests
}
newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...)
replicaSpec, err := ToReplicaSpec(ctx, newTaskCtx, primaryContainerName)
if err != nil {
return nil, err
}

Check warning on line 320 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L316-L320

Added lines #L316 - L320 were not covered by tests

// Master should have a single replica
if isMaster {
replicas := int32(1)
replicaSpec.Replicas = &replicas
}

Check warning on line 326 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L323-L326

Added lines #L323 - L326 were not covered by tests

if rs != nil {
var command []string
if v, ok := rs.(allowsCommandOverride); ok {
command = v.GetCommand()
}
if err := OverrideContainerSpec(
&replicaSpec.Template.Spec,
primaryContainerName,
rs.GetImage(),
command,
); err != nil {
return nil, err
}

Check warning on line 340 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L328-L340

Added lines #L328 - L340 were not covered by tests

replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy())

if !isMaster {
replicas := rs.GetReplicas()
replicaSpec.Replicas = &replicas
}

Check warning on line 347 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L342-L347

Added lines #L342 - L347 were not covered by tests
}

return replicaSpec, nil

Check warning on line 350 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L350

Added line #L350 was not covered by tests
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ func dummyPodSpec() v1.PodSpec {
return v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Image: "dummy-image",
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"cpu": resource.MustParse("2"),
Expand Down Expand Up @@ -270,50 +271,21 @@ func TestOverrideContainerSpec(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(
&podSpec, "primary container", "testing-image",
&core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
},
},
[]string{"python", "-m", "run.py"},
)
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, "testing-image", podSpec.Containers[0].Image)
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m")))
assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args)
}

func TestOverrideContainerSpecEmptyFields(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1")))
assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi")))
}

func TestOverrideContainerNilResources(t *testing.T) {
podSpec := dummyPodSpec()
podSpecCopy := podSpec.DeepCopy()

err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{})
err := OverrideContainerSpec(&podSpec, "primary container", "", []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, podSpec.Containers[0].Resources.Limits, podSpecCopy.Containers[0].Resources.Limits)
assert.Equal(t, podSpec.Containers[0].Resources.Requests, podSpecCopy.Containers[0].Resources.Requests)
assert.Equal(t, "dummy-image", podSpec.Containers[0].Image)
assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args)
}

func dummyTaskContext() pluginsCore.TaskExecutionContext {
Expand Down
Loading
Loading