Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
Signed-off-by: KevinGrantLee <[email protected]>
  • Loading branch information
KevinGrantLee committed Aug 13, 2024
1 parent faeb8e6 commit b70aef8
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 32 deletions.
31 changes: 18 additions & 13 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def build_task_spec_for_task(
if task._task_spec.retry_policy is not None:
pipeline_task_spec.retry_policy.CopyFrom(
task._task_spec.retry_policy.to_proto())

# Inject resource fields into inputs
if task.container_spec and task.container_spec.resources:
for key, val in task.container_spec.resources.__dict__.items():
Expand Down Expand Up @@ -613,22 +613,22 @@ def build_container_spec_for_task(
Returns:
A PipelineContainerSpec object for the task.
"""

def convert_to_placeholder(input_value: str) -> str:
"""Checks if input is a pipeline channel and if so, converts to
compiler injected input name."""
pipeline_channels = (
pipeline_channel.extract_pipeline_channels_from_any(input_value)
)
pipeline_channel.extract_pipeline_channels_from_any(input_value))
if pipeline_channels:
assert len(pipeline_channels) == 1
channel = pipeline_channels[0]
additional_input_name = (
compiler_utils.additional_input_name_for_pipeline_channel(
channel))
compiler_utils.additional_input_name_for_pipeline_channel(
channel))
additional_input_placeholder = placeholders.InputValuePlaceholder(
additional_input_name)._to_string()
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
additional_input_name)._to_string()
input_value = input_value.replace(channel.pattern,
additional_input_placeholder)
return input_value

container_spec = (
Expand All @@ -645,22 +645,27 @@ def convert_to_placeholder(input_value: str) -> str:
if task.container_spec.resources is not None:
if task.container_spec.resources.cpu_request is not None:
container_spec.resources.resource_cpu_request = (
convert_to_placeholder(task.container_spec.resources.cpu_request))
convert_to_placeholder(
task.container_spec.resources.cpu_request))
if task.container_spec.resources.cpu_limit is not None:
container_spec.resources.resource_cpu_limit = (
convert_to_placeholder(task.container_spec.resources.cpu_limit))
if task.container_spec.resources.memory_request is not None:
container_spec.resources.resource_memory_request = (
convert_to_placeholder(task.container_spec.resources.memory_request))
convert_to_placeholder(
task.container_spec.resources.memory_request))
if task.container_spec.resources.memory_limit is not None:
container_spec.resources.resource_memory_limit = (
convert_to_placeholder(task.container_spec.resources.memory_limit))
convert_to_placeholder(
task.container_spec.resources.memory_limit))
if task.container_spec.resources.accelerator_count is not None:
container_spec.resources.accelerator.CopyFrom(
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec
.ResourceSpec.AcceleratorConfig(
resource_type=convert_to_placeholder(task.container_spec.resources.accelerator_type),
resource_count=convert_to_placeholder(task.container_spec.resources.accelerator_count),
resource_type=convert_to_placeholder(
task.container_spec.resources.accelerator_type),
resource_count=convert_to_placeholder(
task.container_spec.resources.accelerator_count),
))

return container_spec
Expand Down
37 changes: 27 additions & 10 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from kfp.dsl import placeholders
from kfp.dsl import structures
from kfp.dsl import utils
from kfp.dsl import pipeline_channel
from kfp.dsl.types import type_utils
from kfp.local import pipeline_orchestrator
from kfp.pipeline_spec import pipeline_spec_pb2
Expand Down Expand Up @@ -348,7 +347,10 @@ def _validate_cpu_request_limit(self, cpu: str) -> str:
return cpu

@block_if_final()
def set_cpu_request(self, cpu: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_cpu_request(
self,
cpu: Union[str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets CPU request (minimum) for the task.
Args:
Expand All @@ -373,7 +375,10 @@ def set_cpu_request(self, cpu: Union[str, pipeline_channel.PipelineChannel]) ->
return self

@block_if_final()
def set_cpu_limit(self, cpu: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_cpu_limit(
self,
cpu: Union[str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets CPU limit (maximum) for the task.
Args:
Expand All @@ -398,7 +403,9 @@ def set_cpu_limit(self, cpu: Union[str, pipeline_channel.PipelineChannel]) -> 'P
return self

@block_if_final()
def set_accelerator_limit(self, limit: Union[int, str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_accelerator_limit(
self, limit: Union[int, str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets accelerator limit (maximum) for the task. Only applies if
accelerator type is also set via .set_accelerator_type().
Expand All @@ -414,8 +421,10 @@ def set_accelerator_limit(self, limit: Union[int, str, pipeline_channel.Pipeline
else:
if isinstance(limit, int):
limit = str(limit)
if isinstance(limit, str) and re.match(r'^0$|^1$|^2$|^4$|^8$|^16$', limit) is None:
raise ValueError(f'{"limit"!r} must be one of 0, 1, 2, 4, 8, 16.')
if isinstance(limit, str) and re.match(r'^0$|^1$|^2$|^4$|^8$|^16$',
limit) is None:
raise ValueError(
f'{"limit"!r} must be one of 0, 1, 2, 4, 8, 16.')

if self.container_spec.resources is not None:
self.container_spec.resources.accelerator_count = limit
Expand Down Expand Up @@ -462,15 +471,18 @@ def _validate_memory_request_limit(self, memory: str) -> str:
memory = str(memory)
else:
if re.match(r'^[0-9]+(E|Ei|P|Pi|T|Ti|G|Gi|M|Mi|K|Ki){0,1}$',
memory) is None:
memory) is None:
raise ValueError(
'Invalid memory string. Should be a number or a number '
'followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", '
'"Gi", "M", "Mi", "K", "Ki".')
return memory

@block_if_final()
def set_memory_request(self, memory: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_memory_request(
self,
memory: Union[str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets memory request (minimum) for the task.
Args:
Expand All @@ -494,7 +506,10 @@ def set_memory_request(self, memory: Union[str, pipeline_channel.PipelineChannel
return self

@block_if_final()
def set_memory_limit(self, memory: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_memory_limit(
self,
memory: Union[str,
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
"""Sets memory limit (maximum) for the task.
Args:
Expand Down Expand Up @@ -558,7 +573,9 @@ def add_node_selector_constraint(self, accelerator: str) -> 'PipelineTask':
return self.set_accelerator_type(accelerator)

@block_if_final()
def set_accelerator_type(self, accelerator: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
def set_accelerator_type(
self, accelerator: Union[str, pipeline_channel.PipelineChannel]
) -> 'PipelineTask':
"""Sets accelerator type to use when executing this task.
Args:
Expand Down
6 changes: 2 additions & 4 deletions sdk/python/kfp/dsl/pipeline_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ def test_set_caching_options(self):
'expected_cpu': '123.0m',
},
)
def test_set_valid_cpu_request_limit(self, cpu: str,
expected_cpu: str):
def test_set_valid_cpu_request_limit(self, cpu: str, expected_cpu: str):
task = pipeline_task.PipelineTask(
component_spec=structures.ComponentSpec.from_yaml_documents(
V2_YAML),
Expand All @@ -171,8 +170,7 @@ def test_set_valid_cpu_request_limit(self, cpu: str,
self.assertEqual(expected_cpu,
task.container_spec.resources.cpu_request)
task.set_cpu_limit(cpu)
self.assertEqual(expected_cpu,
task.container_spec.resources.cpu_limit)
self.assertEqual(expected_cpu, task.container_spec.resources.cpu_limit)

@parameterized.parameters(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


@dsl.component
def sum_numbers(a: int, b:int) -> int:
def sum_numbers(a: int, b: int) -> int:
return a + b


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


@dsl.component
def sum_numbers(a: int, b:int) -> int:
def sum_numbers(a: int, b: int) -> int:
return a + b


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ def accelerator_limit() -> str:


@dsl.component
def sum_numbers(a: int, b:int) -> int:
def sum_numbers(a: int, b: int) -> int:
return a + b


@dsl.pipeline
def pipeline(
):
def pipeline():
sum_numbers_task = sum_numbers(a=1, b=2)
sum_numbers_task.set_cpu_limit(cpu_limit().output)
sum_numbers_task.set_memory_limit(memory_limit().output)
Expand Down

0 comments on commit b70aef8

Please sign in to comment.