feat(components): Support dynamic machine type parameters in CustomTrainingJobOp #10883

Merged
merged 12 commits · Jun 13, 2024
1 change: 1 addition & 0 deletions sdk/RELEASE.md
@@ -1,6 +1,7 @@
# Current Version (in development)

## Features
* Support dynamic machine type parameters in CustomTrainingJobOp. [\#10883](https://github.com/kubeflow/pipelines/pull/10883)

## Breaking changes
* Drop support for Python 3.7 since it has reached end-of-life. [\#10750](https://github.com/kubeflow/pipelines/pull/10750)
31 changes: 31 additions & 0 deletions sdk/python/kfp/compiler/compiler_utils.py
@@ -772,3 +772,34 @@ def get_dependencies(
dependencies[downstream_names[0]].add(upstream_names[0])

return dependencies


def recursive_replace_placeholders(data: Union[Dict, List], old_value: str,
                                   new_value: str) -> Union[Dict, List]:
    """Recursively replaces values in a nested dict/list object.

    This method is used to replace PipelineChannel objects with input parameter
    placeholders in a nested object like worker_pool_specs for custom jobs.
Member: pipeline channel placeholders -> input parameter placeholder

Contributor Author: Done.

    Args:
        data: A nested object that can contain dictionaries and/or lists.
        old_value: The value that will be replaced.
        new_value: The value to replace the old value with.

    Returns:
        A copy of data with all occurrences of old_value replaced by new_value.
    """
    if isinstance(data, dict):
        return {
            k: recursive_replace_placeholders(v, old_value, new_value)
            for k, v in data.items()
        }
    elif isinstance(data, list):
        return [
            recursive_replace_placeholders(i, old_value, new_value)
            for i in data
        ]
    else:
        if isinstance(data, pipeline_channel.PipelineChannel):
Member: This method explicitly replaces a placeholder from one representation with another; it's not for replacing arbitrary values, so the method name should reflect its purpose.

Contributor Author: Although here I'm just using this method for replacing placeholders, it can be used for arbitrary values as well, so I left the method name general.

Contributor Author: Renamed the method to recursive_replace_placeholders.

            data = str(data)
        return new_value if data == old_value else data
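For orientation, a minimal usage sketch of this helper; the placeholder strings are borrowed from the unit test below:

from kfp.compiler import compiler_utils

spec = {
    'machine_spec': {
        # A stringified PipelineChannel pattern, as seen during compilation.
        'machine_type': '{{channel:task=machine-type;name=Output;type=String;}}',
        'accelerator_count': 1,
    }
}
replaced = compiler_utils.recursive_replace_placeholders(
    spec, '{{channel:task=machine-type;name=Output;type=String;}}',
    '{{$.inputs.parameters[pipelinechannel--machine-type-Output]}}')
# Only the matching leaf changes; accelerator_count passes through untouched.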
61 changes: 61 additions & 0 deletions sdk/python/kfp/compiler/compiler_utils_test.py
@@ -53,6 +53,67 @@ def test_additional_input_name_for_pipeline_channel(self, channel,
            expected,
            compiler_utils.additional_input_name_for_pipeline_channel(channel))

    @parameterized.parameters(
        {
            'data': [{
                'container_spec': {
                    'image_uri':
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0',
                    'command': ['echo'],
                    'args': ['foo']
                },
                'machine_spec': {
                    'machine_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='machine-type'),
                    'accelerator_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='accelerator-type'),
                    'accelerator_count': 1
                },
                'replica_count': 1
            }],
            'old_value':
                '{{channel:task=machine-type;name=Output;type=String;}}',
            'new_value':
                '{{$.inputs.parameters[pipelinechannel--machine-type-Output]}}',
            'expected': [{
                'container_spec': {
                    'image_uri':
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0',
                    'command': ['echo'],
                    'args': ['foo']
                },
                'machine_spec': {
                    'machine_type':
                        '{{$.inputs.parameters[pipelinechannel--machine-type-Output]}}',
                    'accelerator_type':
                        pipeline_channel.PipelineParameterChannel(
                            name='Output',
                            channel_type='String',
                            task_name='accelerator-type'),
                    'accelerator_count': 1
                },
                'replica_count': 1
            }],
        },)
    def test_recursive_replace_placeholders(self, data, old_value, new_value,
                                            expected):
        self.assertEqual(
            expected,
            compiler_utils.recursive_replace_placeholders(
                data, old_value, new_value))


if __name__ == '__main__':
    unittest.main()
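To run just the new test case locally, an invocation along these lines should work (hedged; it assumes the repository's standard layout and that unittest's -k filter matches the parameterized test name):

python sdk/python/kfp/compiler/compiler_utils_test.py -k recursive_replace_placeholders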
21 changes: 12 additions & 9 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -238,12 +238,14 @@ def build_task_spec_for_task(
input_name].component_input_parameter = (
component_input_parameter)

elif isinstance(input_value, str):
# Handle extra input due to string concat
Member: IIRC, this chunk of code is only applicable to string-typed inputs, so why merge the code and expand it to other input types? Also, it's a bit hard to read the diff between the deleted code and the extracted code. Can you try making the changes in place without refactoring, and see whether it's actually necessary to expand the logic to non-string-typed inputs?

Contributor Author: I found that this block of code could be reused for handling PipelineChannels inside worker_pool_specs in addition to handling string-typed inputs. Instead of copying the ~50 lines of code, I thought it'd be better to refactor the logic into a separate function, def replace_and_inject_placeholders().

So I could un-refactor and duplicate the logic; I do have a slight preference for this refactoring but can go either way. wdyt?

Member: The name of the extracted method isn't accurate: the code does more than placeholder manipulation; it also performs component input expansion. The branch logic now reads like this:

if isinstance(input_value, str):
    # shared code
    pipeline_task_spec.inputs.parameters[
        input_name].runtime_value.constant.string_value = input_value
elif isinstance(input_value, (int, float, bool, dict, list)):
    if isinstance(input_value, (dict, list)):
        # shared code
    pipeline_task_spec.inputs.parameters[
        input_name].runtime_value.constant.CopyFrom(
            to_protobuf_value(input_value))
else:
    raise

You can achieve the same goal, and even more code reuse, without extracting a shared method:

if not isinstance(input_value, (str, dict, list, int, float, bool)):
    raise

if isinstance(input_value, (str, dict, list)):
    # shared code

pipeline_task_spec.inputs.parameters[
    input_name].runtime_value.constant.CopyFrom(
        to_protobuf_value(input_value))

Member: Aside from the refactoring part, I wonder what the use case for dict and list is. When CustomTrainingJobOp is used, what's the input_value here?

Contributor Author (@KevinGrantLee, Jun 12, 2024): If CustomTrainingJobOp is used, then we pass worker_pool_specs into input_value.

With PipelineChannel objects it looks like this:

input_value = [{'container_spec': {'image_uri': 'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0', 'command': ['echo'], 'args': ['foo']}, 'machine_spec': {'machine_type': {{channel:task=machine-type;name=Output;type=String;}}, 'accelerator_type': {{channel:task=accelerator-type;name=Output;type=String;}}, 'accelerator_count': {{channel:task=accelerator-count;name=Output;type=Integer;}}}, 'replica_count': 1}]

Member (@chensun, Jun 12, 2024): Thanks for the explanation. So input_value would be of type list in this case. Including dict in the same code path is just for future use cases and not strictly necessary at the moment, right? I'm fine with including dict now.
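For context, a hedged sketch of that extraction step on such a payload; the PipelineParameterChannel constructor arguments mirror the unit test above:

from kfp.dsl import pipeline_channel

machine_type = pipeline_channel.PipelineParameterChannel(
    name='Output', channel_type='String', task_name='machine-type')
input_value = [{
    'machine_spec': {'machine_type': machine_type, 'accelerator_count': 1},
    'replica_count': 1,
}]
# Collects the embedded channel objects from the nested list/dict payload.
channels = pipeline_channel.extract_pipeline_channels_from_any(input_value)
# channels == [machine_type]; each channel's .pattern is the
# '{{channel:task=...;name=Output;type=String;}}' string replaced during compilation.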

elif isinstance(input_value, (str, int, float, bool, dict, list)):
pipeline_channels = (
pipeline_channel.extract_pipeline_channels_from_any(input_value)
)
for channel in pipeline_channels:
# NOTE: a case like p3 = print_and_return_str(s='Project = {}'.format(project))
# triggers this code path.

# value contains PipelineChannel placeholders which needs to be
# replaced. And the input needs to be added to the task spec.

@@ -265,8 +267,14 @@

additional_input_placeholder = placeholders.InputValuePlaceholder(
additional_input_name)._to_string()
input_value = input_value.replace(channel.pattern,
additional_input_placeholder)

if isinstance(input_value, str):
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
else:
input_value = compiler_utils.recursive_replace_placeholders(
input_value, channel.pattern,
additional_input_placeholder)

if channel.task_name:
# Value is produced by an upstream task.
@@ -299,11 +307,6 @@
additional_input_name].component_input_parameter = (
component_input_parameter)

pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.string_value = input_value

elif isinstance(input_value, (str, int, float, bool, dict, list)):

pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.CopyFrom(
to_protobuf_value(input_value))
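To make the fragmented hunks above easier to follow, here is an illustrative paraphrase of the consolidated replacement logic after this change; inject_placeholders and placeholder_for are hypothetical names standing in for the surrounding loop and the InputValuePlaceholder construction, not the literal merged code:

from kfp.compiler import compiler_utils

def inject_placeholders(input_value, channels, placeholder_for):
    # Strings are patched with str.replace; dicts/lists (e.g. worker_pool_specs)
    # are walked by recursive_replace_placeholders.
    for channel in channels:
        if isinstance(input_value, str):
            input_value = input_value.replace(channel.pattern,
                                              placeholder_for(channel))
        else:
            input_value = compiler_utils.recursive_replace_placeholders(
                input_value, channel.pattern, placeholder_for(channel))
    return input_value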
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/pipeline_channel.py
@@ -586,7 +586,7 @@ def extract_pipeline_channels_from_string(


def extract_pipeline_channels_from_any(
payload: Union[PipelineChannel, str, list, tuple, dict]
payload: Union[PipelineChannel, str, int, float, bool, list, tuple, dict]
) -> List[PipelineChannel]:
"""Recursively extract PipelineChannels from any object or list of objects.

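The broadened Union now admits scalar payloads too; under the recursive contract described in the docstring, scalars simply contribute no channels. A small hedged sketch of the expected behavior:

from kfp.dsl import pipeline_channel

# Scalars cannot contain channels, so extraction should yield an empty list.
assert pipeline_channel.extract_pipeline_channels_from_any(42) == []
assert pipeline_channel.extract_pipeline_channels_from_any(1.5) == []
assert pipeline_channel.extract_pipeline_channels_from_any(True) == []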
@@ -0,0 +1,64 @@
import google_cloud_pipeline_components.v1.custom_job as custom_job
from kfp import dsl


@dsl.component
def flip_biased_coin_op() -> str:
    """Flip a coin and output heads."""
    return 'heads'


@dsl.component
def machine_type() -> str:
    return 'n1-standard-4'


@dsl.component
def accelerator_type() -> str:
    return 'NVIDIA_TESLA_P4'


@dsl.component
def accelerator_count() -> int:
    return 1


@dsl.pipeline
def pipeline(
    project: str,
    location: str,
    encryption_spec_key_name: str = '',
):
    flip1 = flip_biased_coin_op().set_caching_options(False)
    machine_type_task = machine_type()
    accelerator_type_task = accelerator_type()
    accelerator_count_task = accelerator_count()

    with dsl.Condition(flip1.output == 'heads'):
        custom_job.CustomTrainingJobOp(
            display_name='add-numbers',
            worker_pool_specs=[{
                'container_spec': {
                    'image_uri': (
                        'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0'
                    ),
                    'command': ['echo'],
                    'args': ['foo'],
                },
                'machine_spec': {
                    'machine_type': machine_type_task.output,
                    'accelerator_type': accelerator_type_task.output,
                    'accelerator_count': accelerator_count_task.output,
                },
                'replica_count': 1,
            }],
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
        )


if __name__ == '__main__':
    from kfp import compiler
    compiler.Compiler().compile(
        pipeline_func=pipeline, package_path=__file__.replace('.py', '.yaml'))
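As a quick sanity check (a hedged suggestion, not part of this PR), one can compile the sample and confirm that the channel patterns were rewritten into input parameter placeholders:

from kfp import compiler

compiler.Compiler().compile(pipeline_func=pipeline, package_path='pipeline.yaml')
with open('pipeline.yaml') as f:
    for line in f:
        # Expect '{{$.inputs.parameters[...]}}' placeholders carrying the
        # 'pipelinechannel--' prefix, not raw '{{channel:...}}' patterns.
        if 'pipelinechannel--' in line:
            print(line.rstrip())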