Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle resource overrides in KF plugins #4467

Merged
merged 15 commits into from
Nov 30, 2023

This file was deleted.

123 changes: 123 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package flytek8s

import (
v1 "k8s.io/api/core/v1"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

type pluginTaskOverrides struct {
pluginsCore.TaskOverrides
resources *v1.ResourceRequirements
extendedResources *core.ExtendedResources
}

func (to *pluginTaskOverrides) GetResources() *v1.ResourceRequirements {
if to.resources != nil {
return to.resources
}
return to.TaskOverrides.GetResources()
}

func (to *pluginTaskOverrides) GetExtendedResources() *core.ExtendedResources {
if to.extendedResources != nil {
return to.extendedResources
}
return to.TaskOverrides.GetExtendedResources()
}

type pluginTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
interruptible *bool
overrides *pluginTaskOverrides
}

func (tm *pluginTaskExecutionMetadata) IsInterruptible() bool {
if tm.interruptible != nil {
return *tm.interruptible
}
return tm.TaskExecutionMetadata.IsInterruptible()
}

func (tm *pluginTaskExecutionMetadata) GetOverrides() pluginsCore.TaskOverrides {
if tm.overrides != nil {
return tm.overrides
}
return tm.TaskExecutionMetadata.GetOverrides()
}

type pluginTaskExecutionContext struct {
pluginsCore.TaskExecutionContext
metadata *pluginTaskExecutionMetadata
}

func (tc *pluginTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata {
if tc.metadata != nil {
return tc.metadata
}
return tc.TaskExecutionContext.TaskExecutionMetadata()
}

type PluginTaskExecutionContextOption func(*pluginTaskExecutionContext)

func WithInterruptible(v bool) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
tc.metadata.interruptible = &v
}
}

func WithResources(r *v1.ResourceRequirements) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.resources = r
}
}

func WithExtendedResources(er *core.ExtendedResources) PluginTaskExecutionContextOption {
return func(tc *pluginTaskExecutionContext) {
if tc.metadata == nil {
tc.metadata = &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tc.TaskExecutionContext.TaskExecutionMetadata(),
}
}
if tc.metadata.overrides == nil {
tc.metadata.overrides = &pluginTaskOverrides{
TaskOverrides: tc.metadata.TaskExecutionMetadata.GetOverrides(),
}
}
tc.metadata.overrides.extendedResources = er
}
}

func NewPluginTaskExecutionContext(tc pluginsCore.TaskExecutionContext, options ...PluginTaskExecutionContextOption) pluginsCore.TaskExecutionContext {
tm := tc.TaskExecutionMetadata()
to := tm.GetOverrides()
ctx := &pluginTaskExecutionContext{
TaskExecutionContext: tc,
metadata: &pluginTaskExecutionMetadata{
TaskExecutionMetadata: tm,
overrides: &pluginTaskOverrides{
TaskOverrides: to,
},
},
}
for _, o := range options {
o(ctx)
}
return ctx
}
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
if err != nil {
return nil, err
}
nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx)
nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false))
nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx)
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"
"sort"
"time"
Expand All @@ -15,8 +16,10 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
)

const (
Expand All @@ -25,12 +28,6 @@ const (
PytorchTaskType = "pytorch"
)

type ReplicaEntry struct {
PodSpec *v1.PodSpec
ReplicaNum int32
RestartPolicy commonOp.RestartPolicy
}

// ExtractCurrentCondition will return the first job condition for tensorflow/pytorch
func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
Expand Down Expand Up @@ -254,27 +251,97 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res
}

// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function
// updates the image, resources and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error {
// updates the image and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, args []string) error {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
podSpec.Containers[idx].Image = image
}
if resources != nil {
// if resources requests and limits both not set, we will not override the resources
if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
resources, err := flytek8s.ToK8sResourceRequirements(resources)
if err != nil {
return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
podSpec.Containers[idx].Resources = *resources
}
}
if len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return nil
}

func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string) (*commonOp.ReplicaSpec, error) {
podSpec, objectMeta, oldPrimaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

OverridePrimaryContainerName(podSpec, oldPrimaryContainerName, primaryContainerName)

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

replicas := int32(0)
return &commonOp.ReplicaSpec{
Replicas: &replicas,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}, nil
}

type kfDistributedReplicaSpec interface {
GetReplicas() int32
GetImage() string
GetResources() *core.Resources
GetRestartPolicy() kfplugins.RestartPolicy
}

type allowsCommandOverride interface {
GetCommand() []string
}

func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) {
taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{}
if rs != nil && rs.GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources())
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources))
}
newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...)
replicaSpec, err := ToReplicaSpec(ctx, newTaskCtx, primaryContainerName)
if err != nil {
return nil, err
}

// Master should have a single replica
if isMaster {
replicas := int32(1)
replicaSpec.Replicas = &replicas
}

if rs != nil {
var command []string
if v, ok := rs.(allowsCommandOverride); ok {
command = v.GetCommand()
}
if err := OverrideContainerSpec(
&replicaSpec.Template.Spec,
primaryContainerName,
rs.GetImage(),
command,
); err != nil {
return nil, err
}

replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy())

if !isMaster {
replicas := rs.GetReplicas()
replicaSpec.Replicas = &replicas
}
}

return replicaSpec, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ func dummyPodSpec() v1.PodSpec {
return v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Image: "dummy-image",
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"cpu": resource.MustParse("2"),
Expand Down Expand Up @@ -270,50 +271,21 @@ func TestOverrideContainerSpec(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(
&podSpec, "primary container", "testing-image",
&core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
},
},
[]string{"python", "-m", "run.py"},
)
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, "testing-image", podSpec.Containers[0].Image)
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m")))
assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args)
}

func TestOverrideContainerSpecEmptyFields(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1")))
assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi")))
}

func TestOverrideContainerNilResources(t *testing.T) {
podSpec := dummyPodSpec()
podSpecCopy := podSpec.DeepCopy()

err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{})
err := OverrideContainerSpec(&podSpec, "primary container", "", []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, podSpec.Containers[0].Resources.Limits, podSpecCopy.Containers[0].Resources.Limits)
assert.Equal(t, podSpec.Containers[0].Resources.Requests, podSpecCopy.Containers[0].Resources.Requests)
assert.Equal(t, "dummy-image", podSpec.Containers[0].Image)
assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args)
}

func dummyTaskContext() pluginsCore.TaskExecutionContext {
Expand Down
Loading
Loading