diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 12184d187847..a150cb40d875 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -480,6 +480,28 @@ func extendPodSpecPatch( podSpec.NodeSelector = kubernetesExecutorConfig.GetNodeSelector().GetLabels() } + if tolerations := kubernetesExecutorConfig.GetTolerations(); tolerations != nil { + var k8sTolerations []k8score.Toleration + + glog.Infof("Tolerations passed: %+v", tolerations) + + for _, toleration := range tolerations { + if toleration != nil { + k8sToleration := k8score.Toleration{ + Key: toleration.Key, + Operator: k8score.TolerationOperator(toleration.Operator), + Value: toleration.Value, + Effect: k8score.TaintEffect(toleration.Effect), + TolerationSeconds: toleration.TolerationSeconds, + } + + k8sTolerations = append(k8sTolerations, k8sToleration) + } + } + + podSpec.Tolerations = k8sTolerations + } + // Get secret mount information for _, secretAsVolume := range kubernetesExecutorConfig.GetSecretAsVolume() { secretVolume := k8score.Volume{ diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index ff950cda13cb..acf8d2ed3562 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -671,3 +671,87 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { }) } } + +func Test_extendPodSpecPatch_Tolerations(t *testing.T) { + tests := []struct { + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + expected *k8score.PodSpec + }{ + { + "Valid - toleration", + &kubernetesplatform.KubernetesExecutorConfig{ + Tolerations: []*kubernetesplatform.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + Tolerations: []k8score.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + TolerationSeconds: nil, + }, + }, + }, + }, + { + "Valid - no tolerations", + &kubernetesplatform.KubernetesExecutorConfig{}, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + }, + { + "Valid - only pass operator", + &kubernetesplatform.KubernetesExecutorConfig{ + Tolerations: []*kubernetesplatform.Toleration{ + { + Operator: "Contains", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + Tolerations: []k8score.Toleration{ + { + Operator: "Contains", + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &k8score.PodSpec{Containers: []k8score.Container{ + { + Name: "main", + }, + }} + err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + assert.Nil(t, err) + assert.NotNil(t, got) + assert.Equal(t, tt.expected, got) + }) + } +} diff --git a/kubernetes_platform/python/kfp/kubernetes/__init__.py b/kubernetes_platform/python/kfp/kubernetes/__init__.py index 322bf7a305ba..9197148fe085 100644 --- a/kubernetes_platform/python/kfp/kubernetes/__init__.py +++ b/kubernetes_platform/python/kfp/kubernetes/__init__.py @@ -23,15 +23,17 @@ 'add_node_selector', 'add_pod_label', 'add_pod_annotation', - 'set_image_pull_secrets' + 'set_image_pull_secrets', + 'add_tolerations', ] -from kfp.kubernetes.pod_metadata import add_pod_label -from kfp.kubernetes.pod_metadata import add_pod_annotation +from kfp.kubernetes.image import set_image_pull_secrets from kfp.kubernetes.node_selector import add_node_selector +from kfp.kubernetes.pod_metadata import add_pod_annotation +from kfp.kubernetes.pod_metadata import add_pod_label from kfp.kubernetes.secret import use_secret_as_env from kfp.kubernetes.secret import use_secret_as_volume +from kfp.kubernetes.toleration import add_tolerations from kfp.kubernetes.volume import CreatePVC from kfp.kubernetes.volume import DeletePVC from kfp.kubernetes.volume import mount_pvc -from kfp.kubernetes.image import set_image_pull_secrets diff --git a/kubernetes_platform/python/kfp/kubernetes/toleration.py b/kubernetes_platform/python/kfp/kubernetes/toleration.py new file mode 100644 index 000000000000..700d4c3ca9be --- /dev/null +++ b/kubernetes_platform/python/kfp/kubernetes/toleration.py @@ -0,0 +1,52 @@ +# Copyright 2023 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from google.protobuf import json_format +from kfp.dsl import PipelineTask +from kfp.kubernetes import common +from kfp.kubernetes import kubernetes_executor_config_pb2 as pb +from kubernetes.client import V1Toleration + + +def add_tolerations( + task: PipelineTask, + tolerations: List[V1Toleration], +) -> PipelineTask: + """Add `tolerations`_. to a task. + + Args: + task: Pipeline task. + tolerations: A list of V1Tolerations defined using the Kubernetes Python Client. + + Returns: + Task object with added tolerations. + """ + tolerations_pb = [] + for toleration in tolerations: + tolerations_pb.append( + pb.Toleration( + key=toleration.key, + operator=toleration.operator, + value=toleration.value, + effect=toleration.effect, + toleration_seconds=toleration.toleration_seconds, + )) + + msg = common.get_existing_kubernetes_config_as_message(task) + msg.tolerations.extend(tolerations_pb) + task.platform_config['kubernetes'] = json_format.MessageToDict(msg) + + return task diff --git a/kubernetes_platform/python/test/snapshot/data/toleration.py b/kubernetes_platform/python/test/snapshot/data/toleration.py new file mode 100644 index 000000000000..d3015d7903c0 --- /dev/null +++ b/kubernetes_platform/python/test/snapshot/data/toleration.py @@ -0,0 +1,45 @@ +# Copyright 2023 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kfp import dsl +from kfp import kubernetes +from kubernetes.client import V1Toleration + + +@dsl.component +def comp(): + pass + + +@dsl.pipeline +def my_pipeline(): + task = comp() + kubernetes.add_tolerations( + task, + [ + V1Toleration( + key='key1', + operator='Equal', + value='value1', + effect='NoExecute', + toleration_seconds=10, + ), + ], + ) + + +if __name__ == '__main__': + from kfp import compiler + + compiler.Compiler().compile(my_pipeline, __file__.replace('.py', '.yaml')) diff --git a/kubernetes_platform/python/test/snapshot/data/toleration.yaml b/kubernetes_platform/python/test/snapshot/data/toleration.yaml new file mode 100644 index 000000000000..f8f23798c61e --- /dev/null +++ b/kubernetes_platform/python/test/snapshot/data/toleration.yaml @@ -0,0 +1,61 @@ +# PIPELINE DEFINITION +# Name: my-pipeline +components: + comp-comp: + executorLabel: exec-comp +deploymentSpec: + executors: + exec-comp: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - comp + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.6.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef comp():\n pass\n\n" + image: python:3.7 +pipelineInfo: + name: my-pipeline +root: + dag: + tasks: + comp: + cachingOptions: + enableCache: true + componentRef: + name: comp-comp + taskInfo: + name: comp +schemaVersion: 2.1.0 +sdkVersion: kfp-2.6.0 +--- +platforms: + kubernetes: + deploymentSpec: + executors: + exec-comp: + tolerations: + - effect: NoExecute + key: key1 + operator: Equal + tolerationSeconds: '10' + value: value1 diff --git a/kubernetes_platform/python/test/unit/test_tolerations.py b/kubernetes_platform/python/test/unit/test_tolerations.py new file mode 100644 index 000000000000..a2d47a9118db --- /dev/null +++ b/kubernetes_platform/python/test/unit/test_tolerations.py @@ -0,0 +1,193 @@ +# Copyright 2023 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.protobuf import json_format +from kfp import compiler +from kfp import dsl +from kfp import kubernetes +from kubernetes.client import V1Toleration + + +class TestTolerations: + + def test_add_one(self): + + @dsl.pipeline + def my_pipeline(): + task = comp() + kubernetes.add_tolerations( + task, + [ + V1Toleration( + key='key1', + operator='Equal', + value='value1', + effect='NoSchedule', + ), + ], + ) + + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path='my_pipeline.yaml') + + assert json_format.MessageToDict(my_pipeline.platform_spec) == { + 'platforms': { + 'kubernetes': { + 'deploymentSpec': { + 'executors': { + 'exec-comp': { + 'tolerations': [{ + 'key': 'key1', + 'operator': 'Equal', + 'value': 'value1', + 'effect': 'NoSchedule', + }] + } + } + } + } + } + } + + def test_add_one_with_toleration_seconds(self): + + @dsl.pipeline + def my_pipeline(): + task = comp() + kubernetes.add_tolerations( + task, + [ + V1Toleration( + key='key1', + operator='Equal', + value='value1', + effect='NoExecute', + toleration_seconds=10, + ), + ], + ) + + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path='my_pipeline.yaml') + + assert json_format.MessageToDict(my_pipeline.platform_spec) == { + 'platforms': { + 'kubernetes': { + 'deploymentSpec': { + 'executors': { + 'exec-comp': { + 'tolerations': [{ + 'key': 'key1', + 'operator': 'Equal', + 'value': 'value1', + 'effect': 'NoExecute', + 'tolerationSeconds': '10', + }] + } + } + } + } + } + } + + def test_add_two(self): + + @dsl.pipeline + def my_pipeline(): + task = comp() + kubernetes.add_tolerations( + task, + [ + V1Toleration( + key='key1', + operator='Equal', + value='value1', + ), + V1Toleration( + key='key2', + operator='Equal', + value='value2', + ), + ], + ) + + assert json_format.MessageToDict(my_pipeline.platform_spec) == { + 'platforms': { + 'kubernetes': { + 'deploymentSpec': { + 'executors': { + 'exec-comp': { + 'tolerations': [ + { + 'key': 'key1', + 'operator': 'Equal', + 'value': 'value1', + }, + { + 'key': 'key2', + 'operator': 'Equal', + 'value': 'value2', + }, + ] + } + } + } + } + } + } + + def test_respects_other_configuration(self): + + @dsl.pipeline + def my_pipeline(): + task = comp() + kubernetes.use_secret_as_volume( + task, secret_name='my-secret', mount_path='/mnt/my_vol') + kubernetes.add_tolerations( + task, + [ + V1Toleration( + key='key1', + operator='Equal', + value='value1', + ), + ], + ) + + assert json_format.MessageToDict(my_pipeline.platform_spec) == { + 'platforms': { + 'kubernetes': { + 'deploymentSpec': { + 'executors': { + 'exec-comp': { + 'tolerations': [{ + 'key': 'key1', + 'operator': 'Equal', + 'value': 'value1', + },], + 'secretAsVolume': [{ + 'secretName': 'my-secret', + 'mountPath': '/mnt/my_vol', + },], + }, + } + } + } + } + } + + +@dsl.component +def comp(): + pass