-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(components): Support dynamic machine type parameters in CustomTrainingJobOp #10883
Changes from all commits
a987706
77da3a3
002998a
21d2bd8
30e436e
d255fbb
19d2887
32f701a
9f462c4
973c66d
eb8a587
1c70801
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -772,3 +772,34 @@ def get_dependencies( | |
dependencies[downstream_names[0]].add(upstream_names[0]) | ||
|
||
return dependencies | ||
|
||
|
||
def recursive_replace_placeholders(data: Union[Dict, List], old_value: str,
                                   new_value: str) -> Union[Dict, List]:
    """Recursively replaces placeholders in a nested dict/list object.

    This method is used to replace PipelineChannel objects with input parameter
    placeholders in a nested object like worker_pool_specs for custom jobs.
    The name reflects that purpose, although the matching itself works on
    arbitrary leaf values.

    Dictionaries and lists are rebuilt rather than modified in place, and
    dictionary keys are never rewritten. A PipelineChannel leaf is converted
    to its string form before matching (and stays a string even when it does
    not match). String leaves have every occurrence of old_value substituted,
    so a placeholder embedded in a longer string is replaced too, mirroring
    what str.replace does for plain string inputs. Any other leaf is replaced
    only when it compares equal to old_value.

    Args:
        data: A nested object that can contain dictionaries and/or lists.
            Leaf values are accepted as well, since the function recurses
            into them.
        old_value: The value that will be replaced.
        new_value: The value to replace the old value with.

    Returns:
        A copy of data with all occurrences of old_value replaced by
        new_value.
    """
    if isinstance(data, dict):
        return {
            key: recursive_replace_placeholders(value, old_value, new_value)
            for key, value in data.items()
        }
    if isinstance(data, list):
        return [
            recursive_replace_placeholders(item, old_value, new_value)
            for item in data
        ]

    if not isinstance(data, str):
        if isinstance(data, pipeline_channel.PipelineChannel):
            # Channels are matched through their string pattern.
            data = str(data)
        else:
            # Non-string leaves (int, float, bool, None, ...) can only be
            # replaced as a whole.
            return new_value if data == old_value else data

    # Guard against an empty old_value: str.replace('', x) would insert x
    # between every character, so fall back to whole-value matching.
    if old_value and old_value in data:
        return data.replace(old_value, new_value)
    return new_value if data == old_value else data
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -238,12 +238,14 @@ def build_task_spec_for_task( | |
input_name].component_input_parameter = ( | ||
component_input_parameter) | ||
|
||
elif isinstance(input_value, str): | ||
# Handle extra input due to string concat | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIRC, this chunk of code is only applicable for string typed inputs, why merging the code and expand it to other input types? Also it's a bit hard to read the diff between the deleted code and the extracted. Can you try make the changes in place without refactoring, and see if it's actually necessary to expand the logic to non-string typed inputs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found that this block of code could be reused for handling PipelineChannels inside of So I could un-refactor and duplicate the logic; I do have a slight preference for this refactoring but can go either way. wdyt? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This name of the extracted method isn't accurate--the code does more than placeholder manipulation but also component input expansion. if isinstance(input_value, str):
# shared code
pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.string_value = input_value
elif isinstance(input_value, (int, float, bool, dict, list)):
if isinstance(input_value, (dict, list)):
# shared code
pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.CopyFrom(
to_protobuf_value(input_value))
else:
raise You can achieve the same goal, and even more code reuse, without extracting a shared method by: if not isinstance(input_value, (str, dict, list, int, float, bool)):
raise
if isinstance(input_value, (str, dict, list)):
# shared code
pipeline_task_spec.inputs.parameters[
input_name].runtime_value.constant.CopyFrom(
to_protobuf_value(input_value)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aside from the refactoring part, I wonder what's the case for dict and list? In case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If It looks like this with PipelineChannel objects There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation. So |
||
elif isinstance(input_value, (str, int, float, bool, dict, list)): | ||
pipeline_channels = ( | ||
pipeline_channel.extract_pipeline_channels_from_any(input_value) | ||
) | ||
for channel in pipeline_channels: | ||
# NOTE: case like this p3 = print_and_return_str(s='Project = {}'.format(project)) | ||
# triggers this code | ||
|
||
# value contains PipelineChannel placeholders which needs to be | ||
# replaced. And the input needs to be added to the task spec. | ||
|
||
|
@@ -265,8 +267,14 @@ def build_task_spec_for_task( | |
|
||
additional_input_placeholder = placeholders.InputValuePlaceholder( | ||
additional_input_name)._to_string() | ||
input_value = input_value.replace(channel.pattern, | ||
additional_input_placeholder) | ||
|
||
if isinstance(input_value, str): | ||
input_value = input_value.replace( | ||
channel.pattern, additional_input_placeholder) | ||
else: | ||
input_value = compiler_utils.recursive_replace_placeholders( | ||
input_value, channel.pattern, | ||
additional_input_placeholder) | ||
|
||
if channel.task_name: | ||
# Value is produced by an upstream task. | ||
|
@@ -299,11 +307,6 @@ def build_task_spec_for_task( | |
additional_input_name].component_input_parameter = ( | ||
component_input_parameter) | ||
|
||
pipeline_task_spec.inputs.parameters[ | ||
input_name].runtime_value.constant.string_value = input_value | ||
|
||
elif isinstance(input_value, (str, int, float, bool, dict, list)): | ||
|
||
pipeline_task_spec.inputs.parameters[ | ||
input_name].runtime_value.constant.CopyFrom( | ||
to_protobuf_value(input_value)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import google_cloud_pipeline_components.v1.custom_job as custom_job | ||
from kfp import dsl | ||
|
||
|
||
# Stand-in for a coin flip: it always returns 'heads', which is the value the | ||
# pipeline's dsl.Condition compares the task output against. | ||
@dsl.component | ||
def flip_biased_coin_op() -> str: | ||
"""Flip a coin and output heads.""" | ||
return 'heads' | ||
|
||
|
||
# Produces, as a task output, the value the pipeline passes as | ||
# machine_spec['machine_type'] of the custom training job. | ||
@dsl.component | ||
def machine_type() -> str: | ||
return 'n1-standard-4' | ||
|
||
|
||
# Produces, as a task output, the value the pipeline passes as | ||
# machine_spec['accelerator_type'] of the custom training job. | ||
@dsl.component | ||
def accelerator_type() -> str: | ||
return 'NVIDIA_TESLA_P4' | ||
|
||
|
||
# Produces, as a task output, the integer the pipeline passes as | ||
# machine_spec['accelerator_count'] of the custom training job. | ||
@dsl.component | ||
def accelerator_count() -> int: | ||
return 1 | ||
|
||
|
||
# Pipeline that launches a CustomTrainingJobOp whose machine_spec fields are
# outputs of upstream tasks rather than constants. The job only runs when the
# coin-flip task outputs 'heads'.
@dsl.pipeline
def pipeline(
    project: str,
    location: str,
    encryption_spec_key_name: str = '',
):
    # Upstream tasks; caching is turned off for the coin flip only.
    coin_flip_task = flip_biased_coin_op().set_caching_options(False)
    machine_type_task = machine_type()
    accelerator_type_task = accelerator_type()
    accelerator_count_task = accelerator_count()

    container_spec = {
        'image_uri': 'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0',
        'command': ['echo'],
        'args': ['foo'],
    }
    # Every value here is a task output, nested two levels deep inside the
    # worker_pool_specs argument.
    machine_spec = {
        'machine_type': machine_type_task.output,
        'accelerator_type': accelerator_type_task.output,
        'accelerator_count': accelerator_count_task.output,
    }
    worker_pool_spec = {
        'container_spec': container_spec,
        'machine_spec': machine_spec,
        'replica_count': 1,
    }

    with dsl.Condition(coin_flip_task.output == 'heads'):
        custom_job.CustomTrainingJobOp(
            display_name='add-numbers',
            worker_pool_specs=[worker_pool_spec],
            project=project,
            location=location,
            encryption_spec_key_name=encryption_spec_key_name,
        )
|
||
|
||
if __name__ == '__main__':
    # Compile the pipeline to a YAML file that sits next to this script.
    from kfp import compiler

    yaml_path = __file__.replace('.py', '.yaml')
    pipeline_compiler = compiler.Compiler()
    pipeline_compiler.compile(pipeline_func=pipeline, package_path=yaml_path)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pipeline channel placeholders -> input parameter placeholder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.