Skip to content

Commit

Permalink
Address PR review 1
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <[email protected]>
Co-authored-by: edmondop <[email protected]>
Co-authored-by: tarat44 <[email protected]>
  • Loading branch information
3 people committed Feb 16, 2024
1 parent b5405dd commit 2078ee4
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 83 deletions.
24 changes: 12 additions & 12 deletions kubernetes_platform/python/kfp/kubernetes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.1.0"
__version__ = '1.1.0'

__all__ = [
"add_node_selector",
"add_pod_annotation",
"add_pod_label",
"add_tolerations",
"CreatePVC",
"DeletePVC",
"mount_pvc",
"set_image_pull_secrets",
"use_secret_as_env",
"use_secret_as_volume",
'add_node_selector',
'add_pod_annotation',
'add_pod_label',
'add_toleration',
'CreatePVC',
'DeletePVC',
'mount_pvc',
'set_image_pull_secrets',
'use_secret_as_env',
'use_secret_as_volume',
]

from kfp.kubernetes.image import set_image_pull_secrets
Expand All @@ -33,7 +33,7 @@
from kfp.kubernetes.pod_metadata import add_pod_label
from kfp.kubernetes.secret import use_secret_as_env
from kfp.kubernetes.secret import use_secret_as_volume
from kfp.kubernetes.toleration import add_tolerations
from kfp.kubernetes.toleration import add_toleration
from kfp.kubernetes.volume import CreatePVC
from kfp.kubernetes.volume import DeletePVC
from kfp.kubernetes.volume import mount_pvc
32 changes: 13 additions & 19 deletions kubernetes_platform/python/kfp/kubernetes/toleration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Kubeflow Authors
# Copyright 2024 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,41 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from google.protobuf import json_format
from kfp.dsl import PipelineTask
from kfp.kubernetes import common
from kfp.kubernetes import kubernetes_executor_config_pb2 as pb
from kubernetes.client import V1Toleration


def add_tolerations(
def add_toleration(
task: PipelineTask,
tolerations: List[V1Toleration],
toleration: V1Toleration,
) -> PipelineTask:
"""Add `tolerations<https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/>`_. to a task.
"""Add a `toleration<https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/>`_. to a task.
Args:
task: Pipeline task.
tolerations: A list of V1Tolerations defined using the Kubernetes Python Client.
toleration: A V1Tolerations defined using the Kubernetes Python Client.
Returns:
Task object with added tolerations.
"""
tolerations_pb = []
for toleration in tolerations:
tolerations_pb.append(
pb.Toleration(
key=toleration.key,
operator=toleration.operator,
value=toleration.value,
effect=toleration.effect,
toleration_seconds=toleration.toleration_seconds,
))

msg = common.get_existing_kubernetes_config_as_message(task)
msg.tolerations.extend(tolerations_pb)
msg.tolerations.append(
pb.Toleration(
key=toleration.key,
operator=toleration.operator,
value=toleration.value,
effect=toleration.effect,
toleration_seconds=toleration.toleration_seconds,
))
task.platform_config['kubernetes'] = json_format.MessageToDict(msg)

return task
20 changes: 9 additions & 11 deletions kubernetes_platform/python/test/snapshot/data/toleration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Kubeflow Authors
# Copyright 2024 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,17 +25,15 @@ def comp():
@dsl.pipeline
def my_pipeline():
task = comp()
kubernetes.add_tolerations(
kubernetes.add_toleration(
task,
[
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoExecute',
toleration_seconds=10,
),
],
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoExecute',
toleration_seconds=10,
),
)


Expand Down
77 changes: 36 additions & 41 deletions kubernetes_platform/python/test/unit/test_tolerations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Kubeflow Authors
# Copyright 2024 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -26,16 +26,14 @@ def test_add_one(self):
@dsl.pipeline
def my_pipeline():
task = comp()
kubernetes.add_tolerations(
kubernetes.add_toleration(
task,
[
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoSchedule',
),
],
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoSchedule',
),
)

compiler.Compiler().compile(
Expand Down Expand Up @@ -65,17 +63,15 @@ def test_add_one_with_toleration_seconds(self):
@dsl.pipeline
def my_pipeline():
task = comp()
kubernetes.add_tolerations(
kubernetes.add_toleration(
task,
[
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoExecute',
toleration_seconds=10,
),
],
V1Toleration(
key='key1',
operator='Equal',
value='value1',
effect='NoExecute',
toleration_seconds=10,
),
)

compiler.Compiler().compile(
Expand Down Expand Up @@ -106,20 +102,21 @@ def test_add_two(self):
@dsl.pipeline
def my_pipeline():
task = comp()
kubernetes.add_tolerations(
kubernetes.add_toleration(
task,
[
V1Toleration(
key='key1',
operator='Equal',
value='value1',
),
V1Toleration(
key='key2',
operator='Equal',
value='value2',
),
],
V1Toleration(
key='key1',
operator='Equal',
value='value1',
),
)
kubernetes.add_toleration(
task,
V1Toleration(
key='key2',
operator='Equal',
value='value2',
),
)

assert json_format.MessageToDict(my_pipeline.platform_spec) == {
Expand Down Expand Up @@ -154,15 +151,13 @@ def my_pipeline():
task = comp()
kubernetes.use_secret_as_volume(
task, secret_name='my-secret', mount_path='/mnt/my_vol')
kubernetes.add_tolerations(
kubernetes.add_toleration(
task,
[
V1Toleration(
key='key1',
operator='Equal',
value='value1',
),
],
V1Toleration(
key='key1',
operator='Equal',
value='value1',
),
)

assert json_format.MessageToDict(my_pipeline.platform_spec) == {
Expand Down

0 comments on commit 2078ee4

Please sign in to comment.