diff --git a/components/google-cloud/google_cloud_pipeline_components/aiplatform/__init__.py b/components/google-cloud/google_cloud_pipeline_components/aiplatform/__init__.py index 965b011c3b7..1746bc13ff5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/aiplatform/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/aiplatform/__init__.py @@ -52,67 +52,84 @@ TimeSeriesDatasetCreateOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'dataset/create_time_series_dataset/component.yaml')) + 'dataset/create_time_series_dataset/component.yaml', + ) +) ImageDatasetCreateOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/create_image_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/create_image_dataset/component.yaml' + ) +) TabularDatasetCreateOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'dataset/create_tabular_dataset/component.yaml')) + 'dataset/create_tabular_dataset/component.yaml', + ) +) TextDatasetCreateOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/create_text_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/create_text_dataset/component.yaml' + ) +) VideoDatasetCreateOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/create_video_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/create_video_dataset/component.yaml' + ) +) ImageDatasetExportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/export_image_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/export_image_dataset/component.yaml' + ) +) TabularDatasetExportDataOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'dataset/export_tabular_dataset/component.yaml')) + 'dataset/export_tabular_dataset/component.yaml', + ) +) TimeSeriesDatasetExportDataOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'dataset/export_time_series_dataset/component.yaml')) + 'dataset/export_time_series_dataset/component.yaml', + ) +) TextDatasetExportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/export_text_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/export_text_dataset/component.yaml' + ) +) VideoDatasetExportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/export_video_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/export_video_dataset/component.yaml' + ) +) ImageDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/import_image_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/import_image_dataset/component.yaml' + ) +) TextDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/import_text_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/import_text_dataset/component.yaml' + ) +) VideoDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'dataset/import_video_dataset/component.yaml')) + os.path.dirname(__file__), 'dataset/import_video_dataset/component.yaml' + ) +) CustomContainerTrainingJobRunOp = utils.convert_method_to_component( aiplatform_sdk.CustomContainerTrainingJob, @@ -127,55 +144,74 @@ AutoMLImageTrainingJobRunOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 
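For orientation, each op loaded above is a regular KFP component factory. A minimal usage sketch (illustrative only — op names come from this module, parameter names follow the Vertex AI SDK, and the DSL import may be `from kfp import dsl` depending on the KFP SDK version):

```python
from kfp.v2 import dsl
from google_cloud_pipeline_components import aiplatform as gcc_aip


@dsl.pipeline(name='image-dataset-and-automl-training')
def pipeline(project: str, location: str = 'us-central1'):
    # Create a managed image dataset, then train an AutoML model on it.
    dataset_op = gcc_aip.ImageDatasetCreateOp(
        project=project,
        location=location,
        display_name='example-dataset',
    )
    training_op = gcc_aip.AutoMLImageTrainingJobRunOp(
        project=project,
        location=location,
        display_name='example-training-job',
        dataset=dataset_op.outputs['dataset'],
    )
```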
'automl_training_job/automl_image_training_job/component.yaml')) + 'automl_training_job/automl_image_training_job/component.yaml', + ) +) AutoMLTextTrainingJobRunOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'automl_training_job/automl_text_training_job/component.yaml')) + 'automl_training_job/automl_text_training_job/component.yaml', + ) +) AutoMLTabularTrainingJobRunOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'automl_training_job/automl_tabular_training_job/component.yaml')) + 'automl_training_job/automl_tabular_training_job/component.yaml', + ) +) AutoMLForecastingTrainingJobRunOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'automl_training_job/automl_forecasting_training_job/component.yaml')) + 'automl_training_job/automl_forecasting_training_job/component.yaml', + ) +) AutoMLVideoTrainingJobRunOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'automl_training_job/automl_video_training_job/component.yaml')) + 'automl_training_job/automl_video_training_job/component.yaml', + ) +) ModelDeleteOp = load_component_from_file( - os.path.join( - os.path.dirname(__file__), 'model/delete_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'model/delete_model/component.yaml') +) ModelExportOp = load_component_from_file( - os.path.join( - os.path.dirname(__file__), 'model/export_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'model/export_model/component.yaml') +) ModelDeployOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'endpoint/deploy_model/component.yaml')) + os.path.dirname(__file__), 'endpoint/deploy_model/component.yaml' + ) +) ModelUndeployOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'endpoint/undeploy_model/component.yaml')) + os.path.dirname(__file__), 'endpoint/undeploy_model/component.yaml' + ) +) ModelBatchPredictOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'batch_predict_job/component.yaml')) + os.path.join(os.path.dirname(__file__), 'batch_predict_job/component.yaml') +) ModelUploadOp = load_component_from_file( - os.path.join( - os.path.dirname(__file__), 'model/upload_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'model/upload_model/component.yaml') +) EndpointCreateOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'endpoint/create_endpoint/component.yaml')) + os.path.dirname(__file__), 'endpoint/create_endpoint/component.yaml' + ) +) EndpointDeleteOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'endpoint/delete_endpoint/component.yaml')) + os.path.dirname(__file__), 'endpoint/delete_endpoint/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/aiplatform/utils.py b/components/google-cloud/google_cloud_pipeline_components/aiplatform/utils.py index 1253d5c4e22..14bbd2ef703 100644 --- a/components/google-cloud/google_cloud_pipeline_components/aiplatform/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/aiplatform/utils.py @@ -32,26 +32,31 @@ # Container image that is used for component containers # TODO tie the container version to sdk release version instead of latest -DEFAULT_CONTAINER_IMAGE = 'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1' +DEFAULT_CONTAINER_IMAGE = ( + 'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1' +) # map of MB SDK type to Metadata type RESOURCE_TO_METADATA_TYPE = { 
aiplatform.datasets.dataset._Dataset: 'google.VertexDataset', # pylint: disable=protected-access aiplatform.Model: 'google.VertexModel', aiplatform.Endpoint: 'google.VertexEndpoint', - aiplatform.BatchPredictionJob: 'google.VertexBatchPredictionJob' + aiplatform.BatchPredictionJob: 'google.VertexBatchPredictionJob', } PROTO_PLUS_CLASS_TYPES = { - aiplatform_v1beta1.types.explanation_metadata.ExplanationMetadata: - 'ExplanationMetadata', - aiplatform_v1beta1.types.explanation.ExplanationParameters: - 'ExplanationParameters', + aiplatform_v1beta1.types.explanation_metadata.ExplanationMetadata: ( + 'ExplanationMetadata' + ), + aiplatform_v1beta1.types.explanation.ExplanationParameters: ( + 'ExplanationParameters' + ), } def get_forward_reference( - annotation: Any) -> Optional[aiplatform.base.VertexAiResourceNoun]: + annotation: Any, +) -> Optional[aiplatform.base.VertexAiResourceNoun]: """Resolves forward references to AiPlatform Class.""" def get_aiplatform_class_by_name(_annotation): @@ -66,6 +71,7 @@ def get_aiplatform_class_by_name(_annotation): try: # Python 3.7+ from typing import ForwardRef + if isinstance(annotation, ForwardRef): annotation = annotation.__forward_arg__ ai_platform_class = get_aiplatform_class_by_name(annotation) @@ -79,21 +85,24 @@ def get_aiplatform_class_by_name(_annotation): # This is the Union of all typed datasets. # Relying on the annotation defined in the SDK # as additional typed Datasets may be added in the future. -dataset_annotation = inspect.signature( - aiplatform.CustomTrainingJob.run).parameters['dataset'].annotation +dataset_annotation = ( + inspect.signature(aiplatform.CustomTrainingJob.run) + .parameters['dataset'] + .annotation +) def resolve_annotation(annotation: Any) -> Any: """Resolves annotation type against a MB SDK type. - Use this for Optional, Union, Forward References + Use this for Optional, Union, Forward References - Args: - annotation: Annotation to resolve + Args: + annotation: Annotation to resolve - Returns: - Direct annotation - """ + Returns: + Direct annotation + """ # handle forward reference string @@ -131,12 +140,12 @@ def resolve_annotation(annotation: Any) -> Any: def is_serializable_to_json(annotation: Any) -> bool: """Checks if the type is serializable. - Args: - annotation: parameter annotation + Args: + annotation: parameter annotation - Returns: - True if serializable to json. - """ + Returns: + True if serializable to json. + """ serializable_types = (dict, list, collections.abc.Sequence, Dict, Sequence) return getattr(annotation, '__origin__', None) in serializable_types @@ -144,12 +153,12 @@ def is_serializable_to_json(annotation: Any) -> bool: def is_mb_sdk_resource_noun_type(mb_sdk_type: Any) -> bool: """Determines if type passed in should be a metadata type. - Args: - mb_sdk_type: Type to check + Args: + mb_sdk_type: Type to check - Returns: - True if this is a resource noun - """ + Returns: + True if this is a resource noun + """ if inspect.isclass(mb_sdk_type): return issubclass(mb_sdk_type, aiplatform.base.VertexAiResourceNoun) return False @@ -158,12 +167,12 @@ def is_mb_sdk_resource_noun_type(mb_sdk_type: Any) -> bool: def get_proto_plus_class(annotation: Any) -> Optional[Callable]: """Get Proto Plus Class for this annotation. 
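These annotation helpers can be exercised directly; a quick sketch of the expected results (the `Optional` unwrapping follows the docstring above, and the `is_serializable_to_json` results follow from the `__origin__` check):

```python
from typing import Dict, List, Optional

from google.cloud import aiplatform
from google_cloud_pipeline_components.aiplatform import utils

# Optional/Union annotations resolve to the underlying SDK class.
assert utils.resolve_annotation(Optional[aiplatform.Model]) is aiplatform.Model

# Only generic container annotations are treated as JSON-serializable params.
assert utils.is_serializable_to_json(Dict[str, int])
assert utils.is_serializable_to_json(List[str])
assert not utils.is_serializable_to_json(str)
```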
- Args: - annotation: parameter annotation + Args: + annotation: parameter annotation - Returns: - Proto Plus Class for annotation type - """ + Returns: + Proto Plus Class for annotation type + """ if annotation in PROTO_PLUS_CLASS_TYPES: return annotation @@ -171,12 +180,12 @@ def get_proto_plus_class(annotation: Any) -> Optional[Callable]: def get_proto_plus_serializer(annotation: Any) -> Optional[Callable]: """Get a serializer for objects to pass them as strings. - Args: - annotation: Parameter annotation + Args: + annotation: Parameter annotation - Returns: - serializer for that annotation type if it's a Proto Plus class - """ + Returns: + serializer for that annotation type if it's a Proto Plus class + """ proto_plus_class = get_proto_plus_class(annotation) if proto_plus_class: return proto_plus_class.to_json @@ -185,12 +194,12 @@ def get_proto_plus_serializer(annotation: Any) -> Optional[Callable]: def get_serializer(annotation: Any) -> Optional[Callable]: """Get a serializer for objects to pass them as strings. - Args: - annotation: Parameter annotation + Args: + annotation: Parameter annotation - Returns: - serializer for that annotation type - """ + Returns: + serializer for that annotation type + """ proto_plus_serializer = get_proto_plus_serializer(annotation) if proto_plus_serializer: return proto_plus_serializer @@ -199,17 +208,19 @@ def get_serializer(annotation: Any) -> Optional[Callable]: return json.dumps -def get_proto_plus_deserializer(annotation: Any) -> Optional[Callable[..., str]]: +def get_proto_plus_deserializer( + annotation: Any, +) -> Optional[Callable[..., str]]: """Get deserializer for objects to pass them as strings. - Remote runner will deserialize. + Remote runner will deserialize. - Args: - annotation: parameter annotation + Args: + annotation: parameter annotation - Returns: - deserializer for annotation type if it's a Proto Plus class - """ + Returns: + deserializer for annotation type if it's a Proto Plus class + """ proto_plus_class = get_proto_plus_class(annotation) if proto_plus_class: return proto_plus_class.from_json @@ -218,16 +229,15 @@ def get_proto_plus_deserializer(annotation: Any) -> Optional[Callable[..., str]] def get_deserializer(annotation: Any) -> Optional[Callable[..., str]]: """Get deserializer for objects to pass them as strings. - Remote runner will deserialize. + Remote runner will deserialize. - Args: - annotation: parameter annotation + Args: + annotation: parameter annotation - Returns: - deserializer for annotation type + Returns: + deserializer for annotation type """ - proto_plus_deserializer = get_proto_plus_deserializer( - annotation) + proto_plus_deserializer = get_proto_plus_deserializer(annotation) if proto_plus_deserializer: return proto_plus_deserializer @@ -236,13 +246,14 @@ def get_deserializer(annotation: Any) -> Optional[Callable[..., str]]: def map_resource_to_metadata_type( - mb_sdk_type: aiplatform.base.VertexAiResourceNoun) -> Tuple[str, str]: + mb_sdk_type: aiplatform.base.VertexAiResourceNoun, +) -> Tuple[str, str]: """Maps an MB SDK type to Metadata type. - Returns: - Tuple of component parameter name and metadata type. - ie aiplatform.Model -> "model", "google.VertexModel" - """ + Returns: + Tuple of component parameter name and metadata type. 
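In isolation, the serializer lookup behaves as follows: Proto Plus explanation types get their own `to_json`, generic containers fall back to `json.dumps`, and anything else yields no serializer (sketch, assuming the module is importable as `utils`):

```python
import json
from typing import Dict

from google_cloud_pipeline_components.aiplatform import utils

serializer = utils.get_serializer(Dict[str, int])
assert serializer is json.dumps
assert serializer({'path_count': 10}) == '{"path_count": 10}'

# Annotations that are neither Proto Plus classes nor containers get None.
assert utils.get_serializer(int) is None
```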
+ ie aiplatform.Model -> "model", "google.VertexModel" + """ # type should always be in this map if is_mb_sdk_resource_noun_type(mb_sdk_type): @@ -279,9 +290,11 @@ def should_be_metadata_type(mb_sdk_type: Any) -> bool: def is_resource_name_parameter_name(param_name: str) -> bool: """Determines if the mb_sdk parameter is a resource name.""" - return param_name not in NOT_RESOURCE_NAME_PARAMETER_NAMES and \ - not param_name.endswith('encryption_spec_key_name') and \ - param_name.endswith('_name') + return ( + param_name not in NOT_RESOURCE_NAME_PARAMETER_NAMES + and not param_name.endswith('encryption_spec_key_name') + and param_name.endswith('_name') + ) # These parameters are filtered from MB SDK methods @@ -292,22 +305,22 @@ def filter_signature( signature: inspect.Signature, is_init_signature: bool = False, self_type: Optional[aiplatform.base.VertexAiResourceNoun] = None, - component_param_name_to_mb_sdk_param_name: Dict[str, str] = None + component_param_name_to_mb_sdk_param_name: Dict[str, str] = None, ) -> inspect.Signature: """Removes unused params from signature. - Args: - signature (inspect.Signature): Model Builder SDK Method Signature. - is_init_signature (bool): is this constructor signature - self_type (aiplatform.base.VertexAiResourceNoun): This is used to - replace *_name str fields with resource name type. - component_param_name_to_mb_sdk_param_name dict[str, str]: Mapping to - keep track of param names changed to make them component friendly( ie: - model_name -> model) - - Returns: - Signature appropriate for component creation. - """ + Args: + signature (inspect.Signature): Model Builder SDK Method Signature. + is_init_signature (bool): is this constructor signature + self_type (aiplatform.base.VertexAiResourceNoun): This is used to replace + *_name str fields with resource name type. + component_param_name_to_mb_sdk_param_name dict[str, str]: Mapping to + keep track of param names changed to make them component friendly( ie: + model_name -> model) + + Returns: + Signature appropriate for component creation. + """ new_params = [] for param in signature.parameters.values(): if param.name not in PARAMS_TO_REMOVE: @@ -315,32 +328,36 @@ def filter_signature( # to enforce metadata entry # ie: model_name -> model if is_init_signature and is_resource_name_parameter_name(param.name): - new_name = param.name[:-len('_name')] + new_name = param.name[: -len('_name')] new_params.append( inspect.Parameter( name=new_name, kind=param.kind, default=param.default, - annotation=self_type)) + annotation=self_type, + ) + ) component_param_name_to_mb_sdk_param_name[new_name] = param.name else: new_params.append(param) return inspect.Signature( - parameters=new_params, return_annotation=signature.return_annotation) + parameters=new_params, return_annotation=signature.return_annotation + ) -def signatures_union(init_sig: inspect.Signature, - method_sig: inspect.Signature) -> inspect.Signature: +def signatures_union( + init_sig: inspect.Signature, method_sig: inspect.Signature +) -> inspect.Signature: """Returns a Union of the constructor and method signature. 
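The `*_name` heuristic drives the renaming in `filter_signature` (`model_name` becomes a `model` resource input). A few concrete cases, based on the checks above (`NOT_RESOURCE_NAME_PARAMETER_NAMES` is defined elsewhere in this module):

```python
from google_cloud_pipeline_components.aiplatform import utils

# Treated as a resource reference; filter_signature renames it to 'model'.
assert utils.is_resource_name_parameter_name('model_name')

# Excluded: encryption keys and parameters that do not end in '_name'.
assert not utils.is_resource_name_parameter_name('encryption_spec_key_name')
assert not utils.is_resource_name_parameter_name('project')
```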
- Args: - init_sig (inspect.Signature): Constructor signature - method_sig (inspect.Signature): Method signature + Args: + init_sig (inspect.Signature): Constructor signature + method_sig (inspect.Signature): Method signature - Returns: - A Union of the the two Signatures as a single Signature - """ + Returns: + A Union of the the two Signatures as a single Signature + """ def key(param): # all params are keyword or positional @@ -350,10 +367,12 @@ def key(param): return 1 params = list(init_sig.parameters.values()) + list( - method_sig.parameters.values()) + method_sig.parameters.values() + ) params.sort(key=key) return inspect.Signature( - parameters=params, return_annotation=method_sig.return_annotation) + parameters=params, return_annotation=method_sig.return_annotation + ) def filter_docstring_args( @@ -363,14 +382,14 @@ def filter_docstring_args( ) -> Dict[str, str]: """Removes unused params from docstring Args section. - Args: - signature (inspect.Signature): Model Builder SDK Method Signature. - docstring (str): Model Builder SDK Method docstring from method.__doc__ - is_init_signature (bool): is this constructor signature + Args: + signature (inspect.Signature): Model Builder SDK Method Signature. + docstring (str): Model Builder SDK Method docstring from method.__doc__ + is_init_signature (bool): is this constructor signature - Returns: - Dictionary of Arg names as keys and descriptions as values. - """ + Returns: + Dictionary of Arg names as keys and descriptions as values. + """ try: parsed_docstring = docstring_parser.parse(docstring) except ValueError: @@ -384,7 +403,7 @@ def filter_docstring_args( # change resource name signatures to resource types # to match new param.names ie: model_name -> model if is_init_signature and is_resource_name_parameter_name(param.name): - new_arg_name = param.name[:-len('_name')] + new_arg_name = param.name[: -len('_name')] # check if there was an arg description for this parameter. if args_dict.get(param.name): @@ -392,20 +411,23 @@ def filter_docstring_args( return new_args_dict -def generate_docstring(args_dict: Dict[str, str], signature: inspect.Signature, - method_docstring: str) -> str: +def generate_docstring( + args_dict: Dict[str, str], + signature: inspect.Signature, + method_docstring: str, +) -> str: """Generates a new doc string using args_dict provided. - Args: - args_dict (Dict[str, str]): A dictionary of Arg names as keys and - descriptions as values. - signature (inspect.Signature): Method Signature of the converted method. - method_docstring (str): Model Builder SDK Method docstring from - method.__doc__ + Args: + args_dict (Dict[str, str]): A dictionary of Arg names as keys and + descriptions as values. + signature (inspect.Signature): Method Signature of the converted method. + method_docstring (str): Model Builder SDK Method docstring from + method.__doc__ - Returns: - A doc string for converted method. - """ + Returns: + A doc string for converted method. 
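Because `inspect.Signature` requires parameters without defaults to precede those with defaults, the sort key above has to place required parameters first. A small sketch of the merged result under that assumption (stand-in functions, not SDK methods):

```python
import inspect

from google_cloud_pipeline_components.aiplatform import utils


def init(project, staging_bucket=None):  # stand-in for a constructor
    ...


def run(dataset, replica_count=1):  # stand-in for a method
    ...


merged = utils.signatures_union(
    inspect.signature(init), inspect.signature(run)
)
# Required parameters from both signatures come first, defaulted ones last.
assert list(merged.parameters) == [
    'project', 'dataset', 'staging_bucket', 'replica_count'
]
```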
+ """ try: parsed_docstring = docstring_parser.parse(method_docstring) except ValueError: @@ -424,7 +446,8 @@ def generate_docstring(args_dict: Dict[str, str], signature: inspect.Signature, if parsed_docstring.returns: formated_return = parsed_docstring.returns.description.replace( - '\n', '\n ') + '\n', '\n ' + ) doc += 'Returns:\n' doc += f' {formated_return}\n' @@ -437,64 +460,65 @@ def generate_docstring(args_dict: Dict[str, str], signature: inspect.Signature, return doc -def convert_method_to_component(cls: aiplatform.base.VertexAiResourceNoun, - method: Callable) -> Callable: +def convert_method_to_component( + cls: aiplatform.base.VertexAiResourceNoun, method: Callable +) -> Callable: """Converts a MB SDK Method to a Component wrapper. - The wrapper enforces the correct signature w.r.t the MB SDK. The signature - is also available to inspect. + The wrapper enforces the correct signature w.r.t the MB SDK. The signature + is also available to inspect. - For example: + For example: - aiplatform.Model.deploy is converted to ModelDeployOp + aiplatform.Model.deploy is converted to ModelDeployOp - Which can be called: - model_deploy_step = ModelDeployOp( - project=project, # Pipeline parameter - endpoint=endpoint_create_step.outputs['endpoint'], - model=model_upload_step.outputs['model'], - deployed_model_display_name='my-deployed-model', - machine_type='n1-standard-4', - ) - - Generates and invokes the following Component: - - name: Model-deploy - inputs: - - {name: project, type: String} - - {name: endpoint, type: Artifact} - - {name: model, type: Model} - outputs: - - {name: endpoint, type: Artifact} - implementation: - container: - image: gcr.io/sashaproject-1/mb_sdk_component:latest - command: - - python3 - - remote_runner.py - - --cls_name=Model - - --method_name=deploy - - --method.deployed_model_display_name=my-deployed-model - - --method.machine_type=n1-standard-4 - args: - - --resource_name_output_artifact_path - - {outputPath: endpoint} - - --init.project - - {inputValue: project} - - --method.endpoint - - {inputPath: endpoint} - - --init.model_name - - {inputPath: model} - - - Args: - method (Callable): A MB SDK Method - should_serialize_init (bool): Whether to also include the constructor - params in the component + Which can be called: + model_deploy_step = ModelDeployOp( + project=project, # Pipeline parameter + endpoint=endpoint_create_step.outputs['endpoint'], + model=model_upload_step.outputs['model'], + deployed_model_display_name='my-deployed-model', + machine_type='n1-standard-4', + ) - Returns: - A Component wrapper that accepts the MB SDK params and returns a Task. 
- """ + Generates and invokes the following Component: + + name: Model-deploy + inputs: + - {name: project, type: String} + - {name: endpoint, type: Artifact} + - {name: model, type: Model} + outputs: + - {name: endpoint, type: Artifact} + implementation: + container: + image: gcr.io/sashaproject-1/mb_sdk_component:latest + command: + - python3 + - remote_runner.py + - --cls_name=Model + - --method_name=deploy + - --method.deployed_model_display_name=my-deployed-model + - --method.machine_type=n1-standard-4 + args: + - --resource_name_output_artifact_path + - {outputPath: endpoint} + - --init.project + - {inputValue: project} + - --method.endpoint + - {inputPath: endpoint} + - --init.model_name + - {inputPath: model} + + + Args: + method (Callable): A MB SDK Method + should_serialize_init (bool): Whether to also include the constructor + params in the component + + Returns: + A Component wrapper that accepts the MB SDK params and returns a Task. + """ method_name = method.__name__ method_signature = inspect.signature(method) @@ -515,12 +539,15 @@ def convert_method_to_component(cls: aiplatform.base.VertexAiResourceNoun, init_signature, is_init_signature=True, self_type=cls, - component_param_name_to_mb_sdk_param_name=component_param_name_to_mb_sdk_param_name + component_param_name_to_mb_sdk_param_name=component_param_name_to_mb_sdk_param_name, ) # use this to partition args to method or constructor - init_arg_names = set( - init_signature.parameters.keys()) if should_serialize_init else set([]) + init_arg_names = ( + set(init_signature.parameters.keys()) + if should_serialize_init + else set([]) + ) # determines outputs for this component output_type = resolve_annotation(method_signature.return_annotation) @@ -528,20 +555,25 @@ def convert_method_to_component(cls: aiplatform.base.VertexAiResourceNoun, output_args = [] if output_type: output_metadata_name, output_metadata_type = map_resource_to_metadata_type( - output_type) + output_type + ) try: output_spec = structures.OutputSpec( name=output_metadata_name, type=output_metadata_type, ) output_uri_placeholder = structures.OutputUriPlaceholder( - output_name=output_metadata_name) + output_name=output_metadata_name + ) except TypeError: from kfp.components import placeholders + output_spec = structures.OutputSpec( - type=(output_metadata_type + '@0.0.1')) + type=(output_metadata_type + '@0.0.1') + ) output_uri_placeholder = placeholders.OutputUriPlaceholder( - output_name=output_metadata_name) + output_name=output_metadata_name + ) output_specs.append(output_spec) @@ -555,13 +587,13 @@ def convert_method_to_component(cls: aiplatform.base.VertexAiResourceNoun, def make_args(args_to_serialize: Dict[str, Dict[str, Any]]) -> List[str]: """Takes the args dictionary and returns command-line args. 
- Args: - args_to_serialize: Dictionary of format {'init': {'param_name_1': - param_1}, {'method'}: {'param_name_2': param_name_2}} + Args: + args_to_serialize: Dictionary of format {'init': {'param_name_1': + param_1}, {'method'}: {'param_name_2': param_name_2}} - Returns: - Serialized args compatible with Component YAML - """ + Returns: + Serialized args compatible with Component YAML + """ additional_args = [] for key, args in args_to_serialize.items(): for arg_key, value in args.items(): @@ -607,7 +639,8 @@ def component_yaml_generator(**kwargs): # if we serialize we need to include the argument as input # perhaps, another option is to embed in yaml as json serialized list component_param_name = component_param_name_to_mb_sdk_param_name.get( - key, key) + key, key + ) component_param_type = None if isinstance(value, kfp.dsl._pipeline_param.PipelineParam) or serializer: if is_mb_sdk_resource_noun_type(param_type): @@ -633,11 +666,12 @@ def component_yaml_generator(**kwargs): structures.InputSpec( name=key, type=component_param_type, - )) + ) + ) input_args.append(f'--{prefix_key}.{component_param_name}') if is_mb_sdk_resource_noun_type(param_type): input_args.append( - f'{{{{$.inputs.artifacts[\'{key}\'].metadata[\'resourceName\']}}}}' + f"{{{{$.inputs.artifacts['{key}'].metadata['resourceName']}}}}" ) else: input_args.append(structures.InputValuePlaceholder(input_name=key)) @@ -670,15 +704,19 @@ def component_yaml_generator(**kwargs): method_name, ], args=make_args(serialized_args) + output_args + input_args, - ))) + ) + ), + ) component_path = tempfile.mktemp() component_spec.save(component_path) return components.load_component_from_file(component_path)(**input_kwargs) - component_yaml_generator.__signature__ = signatures_union( - init_signature, - method_signature) if should_serialize_init else method_signature + component_yaml_generator.__signature__ = ( + signatures_union(init_signature, method_signature) + if should_serialize_init + else method_signature + ) # Create a docstring based on the new signature. new_args_dict = {} @@ -686,17 +724,22 @@ def component_yaml_generator(**kwargs): filter_docstring_args( signature=method_signature, docstring=inspect.getdoc(method), - is_init_signature=False)) + is_init_signature=False, + ) + ) if should_serialize_init: new_args_dict.update( filter_docstring_args( signature=init_signature, docstring=inspect.getdoc(init_method), - is_init_signature=True)) + is_init_signature=True, + ) + ) component_yaml_generator.__doc__ = generate_docstring( args_dict=new_args_dict, signature=component_yaml_generator.__signature__, - method_docstring=inspect.getdoc(method)) + method_docstring=inspect.getdoc(method), + ) # TODO Possibly rename method diff --git a/components/google-cloud/google_cloud_pipeline_components/container/aiplatform/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/aiplatform/remote_runner.py index 7fcb6bfb42f..8d5c3faade1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/aiplatform/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/aiplatform/remote_runner.py @@ -46,12 +46,12 @@ def split_args(kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Splits args into constructor and method args. 
- Args: - kwargs: kwargs with parameter names preprended with init or method + Args: + kwargs: kwargs with parameter names preprended with init or method - Returns: - constructor kwargs, method kwargs - """ + Returns: + constructor kwargs, method kwargs + """ init_args = {} method_args = {} @@ -68,8 +68,9 @@ def write_to_artifact(executor_input, text): """Write output to local artifact and metadata path (uses GCSFuse).""" output_artifacts = {} - for name, artifacts in executor_input.get('outputs', {}).get('artifacts', - {}).items(): + for name, artifacts in ( + executor_input.get('outputs', {}).get('artifacts', {}).items() + ): artifacts_list = artifacts.get('artifacts') if artifacts_list: output_artifacts[name] = artifacts_list[0] @@ -95,7 +96,8 @@ def write_to_artifact(executor_input, text): elif text.startswith(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']): uri_with_prefix = text.replace( RESOURCE_PREFIX['google_cloud_storage_gcs_fuse'], - RESOURCE_PREFIX.get('google_cloud_storage')) + RESOURCE_PREFIX.get('google_cloud_storage'), + ) # "bq://": For BigQuery resources. elif text.startswith(RESOURCE_PREFIX.get('bigquery')): @@ -106,14 +108,15 @@ def write_to_artifact(executor_input, text): runtime_artifact = { 'name': artifact.get('name'), 'uri': uri_with_prefix, - 'metadata': metadata + 'metadata': metadata, } artifacts_list = {'artifacts': [runtime_artifact]} executor_output['artifacts'][name] = artifacts_list os.makedirs( - os.path.dirname(executor_input['outputs']['outputFile']), exist_ok=True) + os.path.dirname(executor_input['outputs']['outputFile']), exist_ok=True + ) with open(executor_input['outputs']['outputFile'], 'w') as f: f.write(json.dumps(executor_output)) @@ -121,14 +124,15 @@ def write_to_artifact(executor_input, text): def resolve_input_args(value, type_to_resolve): """If this is an input from Pipelines, read it directly from gcs.""" if inspect.isclass(type_to_resolve) and issubclass( - type_to_resolve, aiplatform.base.VertexAiResourceNoun): + type_to_resolve, aiplatform.base.VertexAiResourceNoun + ): # Remove '/gcs/' prefix before attempting to remove `aiplatform` prefix if value.startswith(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']): - value = value[len(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']):] + value = value[len(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']) :] # Remove `aiplatform` prefix from resource name if value.startswith(RESOURCE_PREFIX.get('aiplatform')): prefix_str = f"{RESOURCE_PREFIX['aiplatform']}{AIPLATFORM_API_VERSION}/" - value = value[len(prefix_str):] + value = value[len(prefix_str) :] # No action needed for Google Cloud Storage prefix. # No action needed for BigQuery resource names. @@ -141,11 +145,11 @@ def resolve_init_args(key, value): # Remove '/gcs/' prefix before attempting to remove `aiplatform` prefix if value.startswith(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']): # not a resource noun, remove the /gcs/ prefix - value = value[len(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']):] + value = value[len(RESOURCE_PREFIX['google_cloud_storage_gcs_fuse']) :] # Remove `aiplatform` prefix from resource name if value.startswith(RESOURCE_PREFIX.get('aiplatform')): prefix_str = f"{RESOURCE_PREFIX['aiplatform']}{AIPLATFORM_API_VERSION}/" - value = value[len(prefix_str):] + value = value[len(prefix_str) :] # No action needed for Google Cloud Storage prefix. # No action needed for BigQuery resource names. 
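The prefix handling above relies on a small constant map; the values below are inferred from the comments in this file (the real definitions live next to `AIPLATFORM_API_VERSION`). With them, the GCSFuse-to-`gs://` rewrite in `write_to_artifact` reduces to a single `replace`:

```python
# Assumed values, mirroring the '/gcs/', 'gs://', 'bq://' and 'aiplatform'
# prefixes referenced in the surrounding comments.
RESOURCE_PREFIX = {
    'aiplatform': 'aiplatform://',
    'google_cloud_storage': 'gs://',
    'google_cloud_storage_gcs_fuse': '/gcs/',
    'bigquery': 'bq://',
}

text = '/gcs/my-bucket/model/saved_model.pb'
uri_with_prefix = text.replace(
    RESOURCE_PREFIX['google_cloud_storage_gcs_fuse'],
    RESOURCE_PREFIX['google_cloud_storage'],
)
assert uri_with_prefix == 'gs://my-bucket/model/saved_model.pb'
```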
@@ -183,9 +187,9 @@ def cast(value: str, annotation_type: Type[T]) -> T: return annotation_type(value) -def prepare_parameters(kwargs: Dict[str, Any], - method: Callable, - is_init: bool = False): +def prepare_parameters( + kwargs: Dict[str, Any], method: Callable, is_init: bool = False +): """Prepares parameters passed into components before calling SDK. 1. Determines the annotation type that should used with the parameter @@ -203,8 +207,11 @@ def prepare_parameters(kwargs: Dict[str, Any], if key in kwargs: value = kwargs[key] param_type = utils.resolve_annotation(param.annotation) - value = resolve_init_args(key, value) if is_init else resolve_input_args( - value, param_type) + value = ( + resolve_init_args(key, value) + if is_init + else resolve_input_args(value, param_type) + ) deserializer = utils.get_deserializer(param_type) if deserializer: value = deserializer(value) @@ -231,14 +238,22 @@ def prepare_parameters(kwargs: Dict[str, Any], def attach_system_labels(method_args, cls_name, method_name): """Add or append the system labels to the labels arg.""" - if cls_name in [ - 'ImageDataset', 'TabularDataset', 'TextDataset', 'TimeSeriesDataset', - 'VideoDataset' - ] and method_name == 'create': + if ( + cls_name + in [ + 'ImageDataset', + 'TabularDataset', + 'TextDataset', + 'TimeSeriesDataset', + 'VideoDataset', + ] + and method_name == 'create' + ): method_args['labels'] = json.dumps( gcp_labels_util.attach_system_labels( - json.loads(method_args['labels']) if 'labels' in - method_args else {})) + json.loads(method_args['labels']) if 'labels' in method_args else {} + ) + ) return method_args @@ -258,9 +273,11 @@ def runner(cls_name, method_name, executor_input, kwargs): prepare_parameters(serialized_args[METHOD_KEY], method, is_init=False) with execution_context.ExecutionContext( - on_cancel=getattr(obj, 'cancel', None)): + on_cancel=getattr(obj, 'cancel', None) + ): print( - f'method:{method} is being called with parameters {serialized_args[METHOD_KEY]}' + f'method:{method} is being called with parameters' + f' {serialized_args[METHOD_KEY]}' ) output = method(**serialized_args[METHOD_KEY]) print('resource_name: %s', obj.resource_name) @@ -283,7 +300,6 @@ def main(): key_value = None for arg in unknown_args: - print(arg) # Remove whitespace from arg. diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/dataflow/flex_template/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/dataflow/flex_template/remote_runner.py index 027c797269c..af40bf7a71e 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/dataflow/flex_template/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/dataflow/flex_template/remote_runner.py @@ -35,7 +35,7 @@ _CONNECTION_ERROR_RETRY_LIMIT = 5 -_CONNECTION_RETRY_BACKOFF_FACTOR = 2. 
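`attach_system_labels` above only rewrites the `labels` argument of the typed dataset `create` methods and passes everything else through unchanged. A sketch of the effect (module path as in this diff; the exact labels injected by `gcp_labels_util` depend on the runtime environment):

```python
from google_cloud_pipeline_components.container.aiplatform import remote_runner

# 'labels' arrives JSON-encoded from the component's command-line args.
method_args = {'display_name': 'my-dataset', 'labels': '{"team": "ml"}'}
method_args = remote_runner.attach_system_labels(
    method_args, 'ImageDataset', 'create'
)
# method_args['labels'] is still a JSON string, now holding the user labels
# merged with the system labels added by gcp_labels_util.

# Non-dataset classes and non-create methods are returned untouched.
assert remote_runner.attach_system_labels({'x': 1}, 'Model', 'deploy') == {'x': 1}
```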
+_CONNECTION_RETRY_BACKOFF_FACTOR = 2.0 _DATAFLOW_URI_PREFIX = 'https://dataflow.googleapis.com/v1b3' _DATAFLOW_JOB_URI_TEMPLATE = rf'({_DATAFLOW_URI_PREFIX}/projects/(?P.*)/locations/(?P.*)/jobs/(?P.*))' @@ -44,7 +44,9 @@ def insert_system_labels_into_payload(payload): job_spec = json.loads(payload) try: - labels = job_spec['launch_parameter']['environment']['additional_user_labels'] + labels = job_spec['launch_parameter']['environment'][ + 'additional_user_labels' + ] except KeyError: labels = {} @@ -52,7 +54,10 @@ def insert_system_labels_into_payload(payload): job_spec['launch_parameter'] = {} if 'environment' not in job_spec['launch_parameter'].keys(): job_spec['launch_parameter']['environment'] = {} - if 'additional_user_labels' not in job_spec['launch_parameter']['environment'].keys(): + if ( + 'additional_user_labels' + not in job_spec['launch_parameter']['environment'].keys() + ): job_spec['launch_parameter']['environment']['additional_user_labels'] = {} labels = gcp_labels_util.attach_system_labels(labels) @@ -60,7 +65,7 @@ def insert_system_labels_into_payload(payload): return json.dumps(job_spec) -class DataflowFlexTemplateRemoteRunner(): +class DataflowFlexTemplateRemoteRunner: """Common module for creating Dataproc Flex Template jobs.""" def __init__( @@ -84,13 +89,13 @@ def _get_session(self) -> Session: total=_CONNECTION_ERROR_RETRY_LIMIT, status_forcelist=[429, 503], backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, - allowed_methods=['GET', 'POST'] + allowed_methods=['GET', 'POST'], ) adapter = HTTPAdapter(max_retries=retry) session = requests.Session() session.headers.update({ 'Content-Type': 'application/json', - 'User-Agent': 'google-cloud-pipeline-components' + 'User-Agent': 'google-cloud-pipeline-components', }) session.mount('https://', adapter) return session @@ -110,7 +115,7 @@ def _post_resource(self, url: str, post_data: str) -> Dict[str, Any]: """ if not self._creds.valid: self._creds.refresh(google.auth.transport.requests.Request()) - headers = {'Authorization': 'Bearer '+ self._creds.token} + headers = {'Authorization': 'Bearer ' + self._creds.token} result = self._session.post(url=url, data=post_data, headers=headers) json_data = {} @@ -120,19 +125,24 @@ def _post_resource(self, url: str, post_data: str) -> Dict[str, Any]: return json_data except requests.exceptions.HTTPError as err: try: - err_msg = ('Dataflow service returned HTTP status {} from POST: {}. Status: {}, Message: {}' - .format(err.response.status_code, - err.request.url, - json_data['error']['status'], - json_data['error']['message'])) + err_msg = ( + 'Dataflow service returned HTTP status {} from POST: {}. Status:' + ' {}, Message: {}'.format( + err.response.status_code, + err.request.url, + json_data['error']['status'], + json_data['error']['message'], + ) + ) except (KeyError, TypeError): err_msg = err.response.text # Raise RuntimeError with the error returned from the Dataflow service. # Suppress HTTPError as it provides no actionable feedback. raise RuntimeError(err_msg) from None except json.decoder.JSONDecodeError as err: - raise RuntimeError('Failed to decode JSON from response:\n{}' - .format(err.doc)) from err + raise RuntimeError( + 'Failed to decode JSON from response:\n{}'.format(err.doc) + ) from err def check_if_job_exists(self) -> Union[Dict[str, Any], None]: """Check if a Dataflow job already exists. @@ -145,16 +155,21 @@ def check_if_job_exists(self) -> Union[Dict[str, Any], None]: Raises: ValueError: Job resource uri format is invalid. 
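`insert_system_labels_into_payload` performs the analogous label injection for the Flex Template request body, creating any missing levels of `launch_parameter.environment.additional_user_labels` first. For example (module path as in this diff; system label values elided because they are environment dependent):

```python
import json

from google_cloud_pipeline_components.container.experimental.dataflow.flex_template import (
    remote_runner as dataflow_remote_runner,
)

payload = json.dumps({'launch_parameter': {'jobName': 'my-flex-job'}})
patched = json.loads(
    dataflow_remote_runner.insert_system_labels_into_payload(payload)
)

# The nested keys now exist and hold the merged user + system labels.
assert 'additional_user_labels' in patched['launch_parameter']['environment']
```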
""" - if path.exists(self._gcp_resources) and os.stat(self._gcp_resources).st_size != 0: + if ( + path.exists(self._gcp_resources) + and os.stat(self._gcp_resources).st_size != 0 + ): with open(self._gcp_resources) as f: serialized_gcp_resources = f.read() - job_resources = json_format.Parse(serialized_gcp_resources, - gcp_resources_pb2.GcpResources()) + job_resources = json_format.Parse( + serialized_gcp_resources, gcp_resources_pb2.GcpResources() + ) # Resources should only contain one item. if len(job_resources.resources) != 1: raise ValueError( - f'gcp_resources should contain one resource, found {len(job_resources.resources)}' + 'gcp_resources should contain one resource, found' + f' {len(job_resources.resources)}' ) # Validate the format of the Job resource uri. @@ -165,10 +180,12 @@ def check_if_job_exists(self) -> Union[Dict[str, Any], None]: matched_location = match.group('location') matched_job_id = match.group('job') except AttributeError as err: - raise ValueError('Invalid Resource uri: {}. Expect: {}.'.format( - job_resources.resources[0].resource_uri, - 'https://dataflow.googleapis.com/v1b3/projects/[projectId]/locations/[location]/jobs/[jobId]' - )) from err + raise ValueError( + 'Invalid Resource uri: {}. Expect: {}.'.format( + job_resources.resources[0].resource_uri, + 'https://dataflow.googleapis.com/v1b3/projects/[projectId]/locations/[location]/jobs/[jobId]', + ) + ) from err # Return the Job resource uri. return job_resources.resources[0].resource_uri @@ -203,8 +220,8 @@ def create_job( job_uri = f"{_DATAFLOW_URI_PREFIX}/projects/{job['projectId']}/locations/{job['location']}/jobs/{job['id']}" except KeyError as err: raise RuntimeError( - f'Dataflow Flex Template launch failed. ' - f'Cannot determine the job resource uri from the response:\n' + 'Dataflow Flex Template launch failed. 
' + 'Cannot determine the job resource uri from the response:\n' f'{response}' ) from err @@ -242,18 +259,22 @@ def launch_flex_template( """ try: job_spec = json_util.recursive_remove_empty( - json.loads(insert_system_labels_into_payload(payload), strict=False)) + json.loads(insert_system_labels_into_payload(payload), strict=False) + ) except json.decoder.JSONDecodeError as err: - raise RuntimeError('Failed to decode JSON from payload: {}' - .format(err.doc)) from err + raise RuntimeError( + 'Failed to decode JSON from payload: {}'.format(err.doc) + ) from err - remote_runner = DataflowFlexTemplateRemoteRunner(type, project, location, - gcp_resources) + remote_runner = DataflowFlexTemplateRemoteRunner( + type, project, location, gcp_resources + ) if not remote_runner.check_if_job_exists(): if 'launch_parameter' not in job_spec.keys(): job_spec['launch_parameter'] = {} if 'job_name' not in job_spec['launch_parameter'].keys(): now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') - job_spec['launch_parameter']['job_name'] = '-'.join([type.lower(), now, - uuid.uuid4().hex[:8]]) + job_spec['launch_parameter']['job_name'] = '-'.join( + [type.lower(), now, uuid.uuid4().hex[:8]] + ) remote_runner.create_job(type, job_spec) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/evaluation/import_model_evaluation.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/evaluation/import_model_evaluation.py index b3b3ebe4822..7d2581a780b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/evaluation/import_model_evaluation.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/evaluation/import_model_evaluation.py @@ -29,85 +29,85 @@ from typing import Any, Dict PROBLEM_TYPE_TO_SCHEMA_URI = { - 'classification': - 'gs://google-cloud-aiplatform/schema/modelevaluation/classification_metrics_1.0.0.yaml', - 'regression': - 'gs://google-cloud-aiplatform/schema/modelevaluation/regression_metrics_1.0.0.yaml', - 'forecasting': - 'gs://google-cloud-aiplatform/schema/modelevaluation/forecasting_metrics_1.0.0.yaml', + 'classification': 'gs://google-cloud-aiplatform/schema/modelevaluation/classification_metrics_1.0.0.yaml', + 'regression': 'gs://google-cloud-aiplatform/schema/modelevaluation/regression_metrics_1.0.0.yaml', + 'forecasting': 'gs://google-cloud-aiplatform/schema/modelevaluation/forecasting_metrics_1.0.0.yaml', } MODEL_EVALUATION_RESOURCE_TYPE = 'ModelEvaluation' MODEL_EVALUATION_SLICE_RESOURCE_TYPE = 'ModelEvaluationSlice' SLICE_BATCH_IMPORT_LIMIT = 50 + def _make_parent_dirs_and_return_path(file_path: str): os.makedirs(os.path.dirname(file_path), exist_ok=True) return file_path parser = argparse.ArgumentParser( - prog='Vertex Model Service evaluation importer', description='') -parser.add_argument( - '--metrics', - dest='metrics', - type=str, - default=None) + prog='Vertex Model Service evaluation importer', description='' +) +parser.add_argument('--metrics', dest='metrics', type=str, default=None) parser.add_argument( '--classification_metrics', dest='classification_metrics', type=str, - default=None) + default=None, +) parser.add_argument( - '--forecasting_metrics', - dest='forecasting_metrics', - type=str, - default=None) + '--forecasting_metrics', dest='forecasting_metrics', type=str, default=None +) parser.add_argument( - '--regression_metrics', - dest='regression_metrics', - type=str, - default=None) + '--regression_metrics', 
dest='regression_metrics', type=str, default=None +) parser.add_argument( '--feature_attributions', dest='feature_attributions', type=str, - default=None) + default=None, +) parser.add_argument( - '--metrics_explanation', dest='metrics_explanation', type=str, default=None) + '--metrics_explanation', dest='metrics_explanation', type=str, default=None +) parser.add_argument('--explanation', dest='explanation', type=str, default=None) parser.add_argument( - '--problem_type', - dest='problem_type', - type=str, - default=None) + '--problem_type', dest='problem_type', type=str, default=None +) parser.add_argument( - '--display_name', nargs='?', dest='display_name', type=str, default=None) + '--display_name', nargs='?', dest='display_name', type=str, default=None +) parser.add_argument( - '--pipeline_job_id', dest='pipeline_job_id', type=str, default=None) + '--pipeline_job_id', dest='pipeline_job_id', type=str, default=None +) parser.add_argument( '--pipeline_job_resource_name', dest='pipeline_job_resource_name', type=str, - default=None) + default=None, +) parser.add_argument( - '--dataset_path', nargs='?', dest='dataset_path', type=str, default=None) + '--dataset_path', nargs='?', dest='dataset_path', type=str, default=None +) parser.add_argument( - '--dataset_paths', nargs='?', dest='dataset_paths', type=str, default='[]') + '--dataset_paths', nargs='?', dest='dataset_paths', type=str, default='[]' +) parser.add_argument( - '--dataset_type', nargs='?', dest='dataset_type', type=str, default=None) + '--dataset_type', nargs='?', dest='dataset_type', type=str, default=None +) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, +) parser.add_argument( '--gcp_resources', dest='gcp_resources', type=_make_parent_dirs_and_return_path, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, +) def main(argv): @@ -131,48 +131,73 @@ def main(argv): metrics_file_path = parsed_args.metrics problem_type = parsed_args.problem_type - metrics_file_path = metrics_file_path if not metrics_file_path.startswith( - 'gs://') else '/gcs' + metrics_file_path[4:] + metrics_file_path = ( + metrics_file_path + if not metrics_file_path.startswith('gs://') + else '/gcs' + metrics_file_path[4:] + ) schema_uri = PROBLEM_TYPE_TO_SCHEMA_URI.get(problem_type) with open(metrics_file_path) as metrics_file: - all_sliced_metrics = [{ - **one_slice, 'metrics': - to_value(next(iter(one_slice['metrics'].values()))) - } for one_slice in json.loads(metrics_file.read())['slicedMetrics']] + all_sliced_metrics = [ + { + **one_slice, + 'metrics': to_value(next(iter(one_slice['metrics'].values()))), + } + for one_slice in json.loads(metrics_file.read())['slicedMetrics'] + ] overall_slice = all_sliced_metrics[0] sliced_metrics = all_sliced_metrics[1:] model_evaluation = { 'metrics': overall_slice['metrics'], - 'metrics_schema_uri': schema_uri + 'metrics_schema_uri': schema_uri, } - if parsed_args.explanation and parsed_args.explanation == "{{$.inputs.artifacts['explanation'].metadata['explanation_gcs_path']}}": + if ( + parsed_args.explanation + and parsed_args.explanation + == "{{$.inputs.artifacts['explanation'].metadata['explanation_gcs_path']}}" + ): # metrics_explanation must contain explanation_gcs_path when provided. 
- logging.error( - '"explanation" must contain explanations when provided.') + logging.error('"explanation" must contain explanations when provided.') sys.exit(13) elif parsed_args.feature_attributions: - explanation_file_name = parsed_args.feature_attributions if not parsed_args.feature_attributions.startswith( - 'gs://') else '/gcs' + parsed_args.feature_attributions[4:] + explanation_file_name = ( + parsed_args.feature_attributions + if not parsed_args.feature_attributions.startswith('gs://') + else '/gcs' + parsed_args.feature_attributions[4:] + ) elif parsed_args.explanation: - explanation_file_name = parsed_args.explanation if not parsed_args.explanation.startswith( - 'gs://') else '/gcs' + parsed_args.explanation[4:] - elif parsed_args.metrics_explanation and parsed_args.metrics_explanation != "{{$.inputs.artifacts['metrics'].metadata['explanation_gcs_path']}}": - explanation_file_name = parsed_args.metrics_explanation if not parsed_args.metrics_explanation.startswith( - 'gs://') else '/gcs' + parsed_args.metrics_explanation[4:] + explanation_file_name = ( + parsed_args.explanation + if not parsed_args.explanation.startswith('gs://') + else '/gcs' + parsed_args.explanation[4:] + ) + elif ( + parsed_args.metrics_explanation + and parsed_args.metrics_explanation + != "{{$.inputs.artifacts['metrics'].metadata['explanation_gcs_path']}}" + ): + explanation_file_name = ( + parsed_args.metrics_explanation + if not parsed_args.metrics_explanation.startswith('gs://') + else '/gcs' + parsed_args.metrics_explanation[4:] + ) else: explanation_file_name = None if explanation_file_name: with open(explanation_file_name) as explanation_file: model_evaluation['model_explanation'] = { - 'mean_attributions': [{ - 'feature_attributions': - to_value( - json.loads(explanation_file.read())['explanation'] - ['attributions'][0]['featureAttributions']) - }] + 'mean_attributions': [ + { + 'feature_attributions': to_value( + json.loads(explanation_file.read())['explanation'][ + 'attributions' + ][0]['featureAttributions'] + ) + } + ] } if parsed_args.display_name: @@ -181,7 +206,8 @@ def main(argv): try: dataset_paths = json.loads(parsed_args.dataset_paths) if not isinstance(dataset_paths, list) or not all( - isinstance(el, str) for el in dataset_paths): + isinstance(el, str) for el in dataset_paths + ): dataset_paths = [] except ValueError: dataset_paths = [] @@ -190,22 +216,26 @@ def main(argv): dataset_paths.append(parsed_args.dataset_path) metadata = { - key: value for key, value in { + key: value + for key, value in { 'pipeline_job_id': parsed_args.pipeline_job_id, 'pipeline_job_resource_name': parsed_args.pipeline_job_resource_name, 'evaluation_dataset_type': parsed_args.dataset_type, 'evaluation_dataset_path': dataset_paths or None, - }.items() if value is not None + }.items() + if value is not None } if metadata: model_evaluation['metadata'] = to_value(metadata) client = aiplatform.gapic.ModelServiceClient( client_info=gapic_v1.client_info.ClientInfo( - user_agent='google-cloud-pipeline-components'), + user_agent='google-cloud-pipeline-components' + ), client_options={ 'api_endpoint': api_endpoint, - }) + }, + ) import_model_evaluation_response = client.import_model_evaluation( parent=parsed_args.model_name, model_evaluation=model_evaluation, @@ -225,17 +255,21 @@ def main(argv): slice_resource_names = [] # BatchImportModelEvaluationSlices has a size limit of 50 slices. 
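The importer reads metrics and explanation files through the GCSFuse mount, so `gs://` URIs are rewritten to `/gcs/` paths before opening. The inline rewrite used above, pulled out for clarity (the helper name is illustrative):

```python
def to_gcsfuse_path(path: str) -> str:
    """Mirrors the inline expression above: gs://bucket/obj -> /gcs/bucket/obj."""
    return path if not path.startswith('gs://') else '/gcs' + path[4:]


assert to_gcsfuse_path('gs://bucket/eval/metrics.json') == '/gcs/bucket/eval/metrics.json'
assert to_gcsfuse_path('/tmp/metrics.json') == '/tmp/metrics.json'
```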
for i in range(0, len(sliced_metrics), SLICE_BATCH_IMPORT_LIMIT): - slice_resource_names.extend(client.batch_import_model_evaluation_slices( - parent=model_evaluation_name, - model_evaluation_slices=[ - { - 'metrics': one_slice['metrics'], - 'metrics_schema_uri': schema_uri, - 'slice_': to_slice(one_slice['singleOutputSlicingSpec']), - } - for one_slice in sliced_metrics[i:i+SLICE_BATCH_IMPORT_LIMIT] - ], - ).imported_model_evaluation_slices) + slice_resource_names.extend( + client.batch_import_model_evaluation_slices( + parent=model_evaluation_name, + model_evaluation_slices=[ + { + 'metrics': one_slice['metrics'], + 'metrics_schema_uri': schema_uri, + 'slice_': to_slice(one_slice['singleOutputSlicingSpec']), + } + for one_slice in sliced_metrics[ + i : i + SLICE_BATCH_IMPORT_LIMIT + ] + ], + ).imported_model_evaluation_slices + ) for slice_resource in slice_resource_names: slice_mlmd_resource = resources.resources.add() @@ -276,10 +310,12 @@ def to_value(value): return Value(string_value=value) elif isinstance(value, dict): return Value( - struct_value=Struct(fields={k: to_value(v) for k, v in value.items()})) + struct_value=Struct(fields={k: to_value(v) for k, v in value.items()}) + ) elif isinstance(value, list): return Value( - list_value=ListValue(values=[to_value(item) for item in value])) + list_value=ListValue(values=[to_value(item) for item in value]) + ) else: raise ValueError('Unsupported data type: {}'.format(type(value))) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/notebooks/executor.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/notebooks/executor.py index 009b92bd69b..496edf25ed6 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/notebooks/executor.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/notebooks/executor.py @@ -56,12 +56,14 @@ def _check_prefix(s, prefix='gs://'): # Does not provide defaults to match with the API's design. # Default values are set in component.yaml. 
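`to_value` recursively converts plain Python metrics into `google.protobuf.Value` messages, which is what the evaluation payloads above carry. A small example (module path as in this diff; the scalar branches for bools, numbers, and strings sit in the unchanged part of the function):

```python
from google.protobuf.struct_pb2 import Value

from google_cloud_pipeline_components.container.experimental.evaluation import (
    import_model_evaluation,
)

value = import_model_evaluation.to_value(
    {'auRoc': 0.92, 'confidenceMetrics': [{'confidenceThreshold': 0.5}]}
)
assert isinstance(value, Value)
assert value.struct_value.fields['auRoc'].number_value == 0.92
```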
- if (not getattr(args, 'master_type', None) or - not getattr(args, 'input_notebook_file', None) or - not getattr(args, 'container_image_uri', None) or - not getattr(args, 'output_notebook_folder', None) or - not getattr(args, 'job_type', None) or - not getattr(args, 'kernel_spec', None)): + if ( + not getattr(args, 'master_type', None) + or not getattr(args, 'input_notebook_file', None) + or not getattr(args, 'container_image_uri', None) + or not getattr(args, 'output_notebook_folder', None) + or not getattr(args, 'job_type', None) + or not getattr(args, 'kernel_spec', None) + ): raise AttributeError('Missing a required argument for the API.') betpl = {} @@ -91,10 +93,9 @@ def _check_prefix(s, prefix='gs://'): return body -def build_response(state='', - output_notebook_file='', - gcp_resources='', - error=''): +def build_response( + state='', output_notebook_file='', gcp_resources='', error='' +): return (state, output_notebook_file, gcp_resources, error) @@ -125,14 +126,16 @@ def execute_notebook(args): """Executes a notebook.""" client_info = gapic_v1.client_info.ClientInfo( - user_agent='google-cloud-pipeline-components',) + user_agent='google-cloud-pipeline-components', + ) client_notebooks = notebooks.NotebookServiceClient(client_info=client_info) client_vertexai_jobs = vertex_ai_beta.JobServiceClient( client_options={ 'api_endpoint': f'{args.location}-aiplatform.googleapis.com' }, - client_info=client_info) + client_info=client_info, + ) execution_parent = f'projects/{args.project}/locations/{args.location}' execution_fullname = f'{execution_parent}/executions/{args.execution_id}' @@ -144,15 +147,18 @@ def execute_notebook(args): _ = client_notebooks.create_execution( parent=execution_parent, execution_id=args.execution_id, - execution=execution_template) - gcp_resources = json.dumps({ - 'resources': [{ - 'resourceType': - 'type.googleapis.com/google.cloud.notebooks.v1.Execution', - 'resourceUri': - execution_fullname - },] - }) + execution=execution_template, + ) + gcp_resources = json.dumps( + { + 'resources': [ + { + 'resourceType': 'type.googleapis.com/google.cloud.notebooks.v1.Execution', + 'resourceUri': execution_fullname, + }, + ] + } + ) # pylint: disable=broad-except except Exception as e: response = build_response(error=f'create_execution() failed: {e}') @@ -174,7 +180,8 @@ def execute_notebook(args): return build_response( state=Execution.State(execution.state).name, output_notebook_file=execution.output_notebook_file, - gcp_resources=gcp_resources) + gcp_resources=gcp_resources, + ) # Waits for execution to finish. 
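The `gcp_resources` string assembled here follows the shared GCPC convention: a JSON document listing the proto type and the full resource name of what was created. Decoded, it looks like this (project, location, and execution id are placeholders):

```python
import json

execution_fullname = (
    'projects/my-project/locations/us-central1/executions/my-execution-id'
)
gcp_resources = json.dumps({
    'resources': [{
        'resourceType': 'type.googleapis.com/google.cloud.notebooks.v1.Execution',
        'resourceUri': execution_fullname,
    }]
})
```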
print('Blocking pipeline...') @@ -188,7 +195,8 @@ def execute_notebook(args): # pylint: disable=broad-except except Exception as e: response = build_response( - error=f'get_execution() for blocking pipeline failed: {e}') + error=f'get_execution() for blocking pipeline failed: {e}' + ) handle_error(args.fail_pipeline, response) return response @@ -209,7 +217,8 @@ def execute_notebook(args): while True: try: custom_job = client_vertexai_jobs.get_custom_job( - name=execution_job_uri) + name=execution_job_uri + ) # pylint: disable=broad-except except Exception as e: response = build_response(error=f'get_custom_job() failed: {e}') @@ -225,7 +234,8 @@ def execute_notebook(args): custom_job_error = getattr(custom_job, 'error', None) if custom_job_error: response = build_response( - error=f'Error {custom_job_error.code}: {custom_job_error.message}') + error=f'Error {custom_job_error.code}: {custom_job_error.message}' + ) handle_error(args.fail_pipeline, (None, response)) return response @@ -233,137 +243,161 @@ def execute_notebook(args): # had a problem. The previous loop was in hope to find the error message, # we didn't have any so we return the execution state as the message. response = build_response( - error=f'Execution finished: {Execution.State(execution_state).name}') + error=f'Execution finished: {Execution.State(execution_state).name}' + ) handle_error(args.fail_pipeline, (None, response)) return response return build_response( state=Execution.State(execution_state).name, output_notebook_file=execution.output_notebook_file, - gcp_resources=gcp_resources) + gcp_resources=gcp_resources, + ) def main(): def _deserialize_bool(s) -> bool: # pylint: disable=g-import-not-at-top from distutils import util + return util.strtobool(s) == 1 def _serialize_str(str_value: str) -> str: if not isinstance(str_value, str): raise TypeError( - f'Value "{str_value}" has type "{type(str_value)}" instead of str.') + f'Value "{str_value}" has type "{type(str_value)}" instead of str.' 
+ ) return str_value # pylint: disable=g-import-not-at-top import argparse + parser = argparse.ArgumentParser( prog='Notebooks Executor', - description='Executes a notebook using the Notebooks Executor API.') + description='Executes a notebook using the Notebooks Executor API.', + ) parser.add_argument( '--project', dest='project', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--input_notebook_file', dest='input_notebook_file', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--output_notebook_folder', dest='output_notebook_folder', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--execution_id', dest='execution_id', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--location', dest='location', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--master_type', dest='master_type', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--container_image_uri', dest='container_image_uri', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--accelerator_type', dest='accelerator_type', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--accelerator_core_count', dest='accelerator_core_count', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--labels', dest='labels', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--params_yaml_file', dest='params_yaml_file', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--parameters', dest='parameters', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--service_account', dest='service_account', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_type', dest='job_type', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--kernel_spec', dest='kernel_spec', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--block_pipeline', dest='block_pipeline', type=_deserialize_bool, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--fail_pipeline', dest='fail_pipeline', type=_deserialize_bool, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( - '----output-paths', dest='_output_paths', type=str, nargs=4) + '----output-paths', dest='_output_paths', type=str, nargs=4 + ) args, _ = parser.parse_known_args() parsed_args = vars(parser.parse_args()) @@ -379,6 +413,7 @@ def _serialize_str(str_value: str) -> str: ] import os + for idx, output_file in enumerate(output_files): try: os.makedirs(os.path.dirname(output_file)) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/sklearn/train_test_split_jsonl.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/sklearn/train_test_split_jsonl.py index 7f58f038893..e1e6bfd749c 100644 --- 
a/components/google-cloud/google_cloud_pipeline_components/container/experimental/sklearn/train_test_split_jsonl.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/sklearn/train_test_split_jsonl.py @@ -27,7 +27,7 @@ def split_dataset_into_train_and_validation( validation_data_path: str, input_data_path: str, validation_split: float = 0.2, - random_seed: int = 0 + random_seed: int = 0, ) -> None: """Split JSON(L) Data into training and validation data. @@ -42,13 +42,11 @@ def split_dataset_into_train_and_validation( data = pd.read_json(input_data_path, lines=True, orient='records') - train, test = train_test_split(data, test_size=validation_split, - random_state=random_seed, - shuffle=True) - train.to_json(path_or_buf=training_data_path, - lines=True, orient='records') - test.to_json(path_or_buf=validation_data_path, - lines=True, orient='records') + train, test = train_test_split( + data, test_size=validation_split, random_state=random_seed, shuffle=True + ) + train.to_json(path_or_buf=training_data_path, lines=True, orient='records') + test.to_json(path_or_buf=validation_data_path, lines=True, orient='records') def _parse_args(args) -> Dict[str, Any]: @@ -56,42 +54,42 @@ def _parse_args(args) -> Dict[str, Any]: Args: args: A list of arguments. + Returns: A tuple containing an argparse.Namespace class instance holding parsed args, and a list containing all unknown args. """ parser = argparse.ArgumentParser( - prog='Text classification data processing', description='') - parser.add_argument('--training-data-path', - type=str, - required=True, - default=argparse.SUPPRESS) - parser.add_argument('--validation-data-path', - type=str, - required=True, - default=argparse.SUPPRESS) - parser.add_argument('--input-data-path', - type=str, - required=True, - default=argparse.SUPPRESS) - parser.add_argument('--validation-split', - type=float, - required=False, - default=0.2) - parser.add_argument('--random-seed', - type=int, - required=False, - default=0) + prog='Text classification data processing', description='' + ) + parser.add_argument( + '--training-data-path', type=str, required=True, default=argparse.SUPPRESS + ) + parser.add_argument( + '--validation-data-path', + type=str, + required=True, + default=argparse.SUPPRESS, + ) + parser.add_argument( + '--input-data-path', type=str, required=True, default=argparse.SUPPRESS + ) + parser.add_argument( + '--validation-split', type=float, required=False, default=0.2 + ) + parser.add_argument('--random-seed', type=int, required=False, default=0) parsed_args, _ = parser.parse_known_args(args) # Creating the directory where the output file is created. The parent # directory does not exist when building container components. 
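The split itself reduces to pandas plus scikit-learn; a minimal sketch of the same calls with hypothetical local file paths (the directory-creation step described in the comment above continues right after this sketch):

    import pandas as pd
    from sklearn.model_selection import train_test_split

    # hypothetical paths standing in for the component arguments
    data = pd.read_json('input.jsonl', lines=True, orient='records')
    train, test = train_test_split(
        data, test_size=0.2, random_state=0, shuffle=True
    )
    train.to_json(path_or_buf='training.jsonl', lines=True, orient='records')
    test.to_json(path_or_buf='validation.jsonl', lines=True, orient='records')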
pathlib.Path(parsed_args.training_data_path).parent.mkdir( - parents=True, exist_ok=True) + parents=True, exist_ok=True + ) pathlib.Path(parsed_args.validation_data_path).parent.mkdir( - parents=True, exist_ok=True) + parents=True, exist_ok=True + ) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/tensorboard/tensorboard_experiment_creator.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/tensorboard/tensorboard_experiment_creator.py index 30f343fc3e0..cd686bc7f68 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/tensorboard/tensorboard_experiment_creator.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/tensorboard/tensorboard_experiment_creator.py @@ -25,53 +25,66 @@ _RESOURCE_TYPE = 'TensorboardExperiment' _TENSORBOARD_RESOURCE_NAME_REGEX = re.compile( - r'^projects\/[0-9]+\/locations\/[a-z0-9-]+\/tensorboards\/[0-9]+$') + r'^projects\/[0-9]+\/locations\/[a-z0-9-]+\/tensorboards\/[0-9]+$' +) def _make_parent_dirs_and_return_path(file_path: str): os.makedirs(os.path.dirname(file_path), exist_ok=True) return file_path + def main(argv): parser = argparse.ArgumentParser( - prog='Vertex Tensorboard Experiment creator', description='') + prog='Vertex Tensorboard Experiment creator', description='' + ) parser.add_argument( '--tensorboard_resource_name', dest='tensorboard_resource_name', type=str, - default=None) + default=None, + ) parser.add_argument( '--tensorboard_experiment_id', dest='tensorboard_experiment_id', type=str, - default=None) + default=None, + ) parser.add_argument( '--tensorboard_experiment_display_name', dest='tensorboard_experiment_display_name', type=str, - default=None) + default=None, + ) parser.add_argument( '--tensorboard_experiment_description', dest='tensorboard_experiment_description', type=str, - default=None) + default=None, + ) parser.add_argument( '--tensorboard_experiment_labels', dest='tensorboard_experiment_labels', type=dict, - default=None) + default=None, + ) parser.add_argument( '--gcp_resources', dest='gcp_resources', type=_make_parent_dirs_and_return_path, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(argv) tensorboard_resource_name = parsed_args.tensorboard_resource_name tensorboard_experiment_id = parsed_args.tensorboard_experiment_id - tensorboard_experiment_display_name = parsed_args.tensorboard_experiment_display_name - tensorboard_experiment_description = parsed_args.tensorboard_experiment_description + tensorboard_experiment_display_name = ( + parsed_args.tensorboard_experiment_display_name + ) + tensorboard_experiment_description = ( + parsed_args.tensorboard_experiment_description + ) tensorboard_experiment_labels = parsed_args.tensorboard_experiment_labels if tensorboard_resource_name is None: @@ -79,15 +92,23 @@ def main(argv): if not _TENSORBOARD_RESOURCE_NAME_REGEX.match(tensorboard_resource_name): raise ValueError( - r'Invalid Tensorboard Resource Name: %s. Tensorboard Resource Name must be like projects/\{project_number\}/locations/\{location}\/tensorboards/\{tensorboard_ID\}' - % tensorboard_resource_name) + r'Invalid Tensorboard Resource Name: %s. 
Tensorboard Resource Name must' + r' be like' + r' projects/\{project_number\}/locations/\{location}\/tensorboards/\{tensorboard_ID\}' + % tensorboard_resource_name + ) try: tensorboard_instance = Tensorboard(tensorboard_resource_name) except exceptions.NotFound as err: raise RuntimeError( - "Tensorboard Insance: {tensorboard_instance} doesn't exist. Please create a Tensorboard or using a existing one." - .format(tensorboard_instance=tensorboard_resource_name)) from err - tensorboard_experiment_resource_name = tensorboard_resource_name + '/experiments/' + tensorboard_experiment_id + "Tensorboard Insance: {tensorboard_instance} doesn't exist. Please" + ' create a Tensorboard or using a existing one.'.format( + tensorboard_instance=tensorboard_resource_name + ) + ) from err + tensorboard_experiment_resource_name = ( + tensorboard_resource_name + '/experiments/' + tensorboard_experiment_id + ) try: tensorboard_experiment = TensorboardExperiment.create( @@ -95,15 +116,19 @@ def main(argv): tensorboard_name=tensorboard_resource_name, display_name=tensorboard_experiment_display_name, description=tensorboard_experiment_description, - labels=tensorboard_experiment_labels) + labels=tensorboard_experiment_labels, + ) except exceptions.AlreadyExists: tensorboard_experiment = TensorboardExperiment( - tensorboard_experiment_resource_name) + tensorboard_experiment_resource_name + ) resources = GcpResources() tensorboard_experiment_resource = resources.resources.add() tensorboard_experiment_resource.resource_type = _RESOURCE_TYPE - tensorboard_experiment_resource.resource_uri = tensorboard_experiment_resource_name + tensorboard_experiment_resource.resource_uri = ( + tensorboard_experiment_resource_name + ) with open(parsed_args.gcp_resources, 'w') as f: f.write(json_format.MessageToJson(resources)) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/vertex_notification_email/executor.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/vertex_notification_email/executor.py index 92f3231eecf..adac80c40a0 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/vertex_notification_email/executor.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/vertex_notification_email/executor.py @@ -20,5 +20,6 @@ def main(): The notification email component works only on Vertex Pipelines. This function raises an exception when this component is used on Kubeflow Pipelines. """ - raise NotImplementedError('The notification email component is supported ' - 'only on Vertex Pipelines.') + raise NotImplementedError( + 'The notification email component is supported only on Vertex Pipelines.' + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/utils/execution_context.py b/components/google-cloud/google_cloud_pipeline_components/container/utils/execution_context.py index c8dc0499401..c336cb8033b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/utils/execution_context.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/utils/execution_context.py @@ -20,10 +20,10 @@ class ExecutionContext: """Execution context for running inside Google Cloud Pipeline Components. - The base class is aware of the GCPC environment and can cascade - a pipeline cancel event to the operation through ``on_cancel`` handler. - Args: - on_cancel: optional, function to handle KFP cancel event. 
+ The base class is aware of the GCPC environment and can cascade + a pipeline cancel event to the operation through ``on_cancel`` handler. + Args: + on_cancel: optional, function to handle KFP cancel event. """ def __init__(self, on_cancel=None): @@ -33,8 +33,9 @@ def __init__(self, on_cancel=None): def __enter__(self): logging.info('Adding signal handler') - self._original_sigterm_handler = signal.signal(signal.SIGTERM, - self._exit_gracefully) + self._original_sigterm_handler = signal.signal( + signal.SIGTERM, self._exit_gracefully + ) return self def __exit__(self, *_): diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/launcher.py index 44431ad7801..bfd87809a35 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/launcher.py @@ -31,7 +31,8 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/remote_runner.py index 63e224cfdb5..d3945b2ce88 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/batch_prediction_job/remote_runner.py @@ -36,11 +36,14 @@ def sanitize_job_spec(job_spec): """If the job_spec contains explanation metadata, convert to ExplanationMetadata for the job client to recognize.""" - if ('explanation_spec' in job_spec) and ('metadata' - in job_spec['explanation_spec']): - job_spec['explanation_spec'][ - 'metadata'] = explain.ExplanationMetadata.from_json( - json.dumps(job_spec['explanation_spec']['metadata'])) + if ('explanation_spec' in job_spec) and ( + 'metadata' in job_spec['explanation_spec'] + ): + job_spec['explanation_spec']['metadata'] = ( + explain.ExplanationMetadata.from_json( + json.dumps(job_spec['explanation_spec']['metadata']) + ) + ) return job_spec @@ -52,7 +55,8 @@ def create_batch_prediction_job_with_client(job_client, parent, job_spec): 'Creating Batch Prediction job with sanitized job spec: %s', job_spec ) create_batch_prediction_job_fn = job_client.create_batch_prediction_job( - parent=parent, batch_prediction_job=job_spec) + parent=parent, batch_prediction_job=job_spec + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return create_batch_prediction_job_fn @@ -63,7 +67,8 @@ def get_batch_prediction_job_with_client(job_client, job_name): try: get_batch_prediction_job_fn = job_client.get_batch_prediction_job( name=job_name, - retry=retry.Retry(deadline=_BATCH_PREDICTION_RETRY_DEADLINE_SECONDS)) + retry=retry.Retry(deadline=_BATCH_PREDICTION_RETRY_DEADLINE_SECONDS), + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return get_batch_prediction_job_fn @@ -71,23 +76,33 @@ def get_batch_prediction_job_with_client(job_client, job_name): def insert_artifact_into_payload(executor_input, 
payload): job_spec = json.loads(payload) - artifact = json.loads(executor_input).get('inputs', {}).get( - 'artifacts', {}).get(UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME, - {}).get('artifacts') + artifact = ( + json.loads(executor_input) + .get('inputs', {}) + .get('artifacts', {}) + .get(UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME, {}) + .get('artifacts') + ) if artifact: - job_spec[ - UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME] = json_util.camel_case_to_snake_case_recursive( - artifact[0].get('metadata', {})) - job_spec[UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME][ - 'artifact_uri'] = artifact[0].get('uri') + job_spec[UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME] = ( + json_util.camel_case_to_snake_case_recursive( + artifact[0].get('metadata', {}) + ) + ) + job_spec[UNMANAGED_CONTAINER_MODEL_ARTIFACT_NAME]['artifact_uri'] = ( + artifact[0].get('uri') + ) return json.dumps(job_spec) + def insert_system_labels_into_payload(payload): job_spec = json.loads(payload) job_spec[LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - job_spec[LABELS_PAYLOAD_KEY] if LABELS_PAYLOAD_KEY in job_spec else {}) + job_spec[LABELS_PAYLOAD_KEY] if LABELS_PAYLOAD_KEY in job_spec else {} + ) return json.dumps(job_spec) + def create_batch_prediction_job( type, project, @@ -115,8 +130,9 @@ def create_batch_prediction_job( Also retry on ConnectionError up to job_remote_runner._CONNECTION_ERROR_RETRY_LIMIT times during the poll. """ - remote_runner = job_remote_runner.JobRemoteRunner(type, project, location, - gcp_resources) + remote_runner = job_remote_runner.JobRemoteRunner( + type, project, location, gcp_resources + ) try: # Create batch prediction job if it does not exist @@ -136,40 +152,53 @@ def create_batch_prediction_job( # Poll batch prediction job status until "JobState.JOB_STATE_SUCCEEDED" get_job_response = remote_runner.poll_job( - get_batch_prediction_job_with_client, job_name) + get_batch_prediction_job_with_client, job_name + ) vertex_uri_prefix = f'https://{location}-aiplatform.googleapis.com/v1/' vertex_batch_predict_job_artifact = VertexBatchPredictionJob( - 'batchpredictionjob', vertex_uri_prefix + get_job_response.name, + 'batchpredictionjob', + vertex_uri_prefix + get_job_response.name, get_job_response.name, get_job_response.output_info.bigquery_output_table, get_job_response.output_info.bigquery_output_dataset, - get_job_response.output_info.gcs_output_directory) + get_job_response.output_info.gcs_output_directory, + ) output_artifacts = [vertex_batch_predict_job_artifact] # Output the BQTable artifact if get_job_response.output_info.bigquery_output_dataset: bq_dataset_pattern = re.compile(_BQ_DATASET_TEMPLATE) match = bq_dataset_pattern.match( - get_job_response.output_info.bigquery_output_dataset) + get_job_response.output_info.bigquery_output_dataset + ) try: project = match.group('project') dataset = match.group('dataset') bigquery_output_table_artifact = BQTable( - 'bigquery_output_table', project, dataset, - get_job_response.output_info.bigquery_output_table) + 'bigquery_output_table', + project, + dataset, + get_job_response.output_info.bigquery_output_table, + ) output_artifacts.append(bigquery_output_table_artifact) except AttributeError as err: error_util.exit_with_internal_error( - 'Invalid BQ dataset address from batch prediction output: {}. Expect: {}.' - .format(get_job_response.output_info.bigquery_output_dataset, - 'bq://[project_id].[dataset_id]')) + 'Invalid BQ dataset address from batch prediction output: {}.' 
+ ' Expect: {}.'.format( + get_job_response.output_info.bigquery_output_dataset, + 'bq://[project_id].[dataset_id]', + ) + ) # Output the GCS path via system.Artifact if get_job_response.output_info.gcs_output_directory: output_artifacts.append( - dsl.Artifact('gcs_output_directory', - get_job_response.output_info.gcs_output_directory)) + dsl.Artifact( + 'gcs_output_directory', + get_job_response.output_info.gcs_output_directory, + ) + ) artifact_util.update_output_artifacts(executor_input, output_artifacts) except (ConnectionError, RuntimeError) as err: diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/launcher.py index 42cc9a623bd..35a6344e989 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/launcher.py @@ -31,13 +31,15 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/remote_runner.py index fdb05819c62..d22cb04675d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/create_model/remote_runner.py @@ -68,9 +68,14 @@ def bigquery_create_model_job( creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: - job_uri = bigquery_util.create_query_job(project, location, payload, - job_configuration_query_override, - creds, gcp_resources) + job_uri = bigquery_util.create_query_job( + project, + location, + payload, + job_configuration_query_override, + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
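A few hunks up, the batch-prediction runner extracts the project and dataset from the BigQuery output URI with named regex groups; the real _BQ_DATASET_TEMPLATE is defined elsewhere in that module, so the pattern below is only an assumption matching the documented bq://[project_id].[dataset_id] form (the poll call noted above follows right after this sketch):

    import re

    # assumed pattern; only the group names 'project' and 'dataset' come from the code above
    bq_dataset_pattern = re.compile(r'^bq://(?P<project>[^.]+)\.(?P<dataset>.+)$')
    match = bq_dataset_pattern.match('bq://my-project.my_dataset')
    print(match.group('project'), match.group('dataset'))  # my-project my_dataset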
job = bigquery_util.poll_job(job_uri, creds) @@ -80,10 +85,14 @@ def bigquery_create_model_job( query_result = job['statistics']['query'] - if 'statementType' not in query_result or query_result[ - 'statementType'] != 'CREATE_MODEL' or 'ddlTargetTable' not in query_result: + if ( + 'statementType' not in query_result + or query_result['statementType'] != 'CREATE_MODEL' + or 'ddlTargetTable' not in query_result + ): raise RuntimeError( - 'Unexpected create model result: {}'.format(query_result)) + 'Unexpected create model result: {}'.format(query_result) + ) projectId = query_result['ddlTargetTable']['projectId'] datasetId = query_result['ddlTargetTable']['datasetId'] diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/launcher.py index e1a14044f57..321ab6d11e4 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/launcher.py @@ -30,44 +30,51 @@ def _parse_args(args): dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--contamination', dest='contamination', type=float, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--anomaly_prob_threshold', dest='anomaly_prob_threshold', type=float, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--executor_input', dest='executor_input', type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/remote_runner.py index b9315e49d5f..4cfce9d8564 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/detect_anomalies_model/remote_runner.py @@ -90,34 +90,53 @@ def bigquery_detect_anomalies_model_job( executor_input: A json serialized pipeline executor input. 
""" settings_field_sql_list = [] - if contamination is not None and contamination >= 0.0 and contamination <= 0.5: + if ( + contamination is not None + and contamination >= 0.0 + and contamination <= 0.5 + ): settings_field_sql_list.append('%s AS contamination' % contamination) - if anomaly_prob_threshold is not None and anomaly_prob_threshold > 0.0 and anomaly_prob_threshold < 1.0: - settings_field_sql_list.append('%s AS anomaly_prob_threshold' % - anomaly_prob_threshold) + if ( + anomaly_prob_threshold is not None + and anomaly_prob_threshold > 0.0 + and anomaly_prob_threshold < 1.0 + ): + settings_field_sql_list.append( + '%s AS anomaly_prob_threshold' % anomaly_prob_threshold + ) settings_field_sql = ','.join(settings_field_sql_list) job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) + job_configuration_query_override, strict=False + ) if query_statement or table_name: - input_data_sql = ('TABLE %s' % - bigquery_util.back_quoted_if_needed(table_name) - if table_name else '(%s)' % query_statement) + input_data_sql = ( + 'TABLE %s' % bigquery_util.back_quoted_if_needed(table_name) + if table_name + else '(%s)' % query_statement + ) settings_sql = ' STRUCT(%s), ' % settings_field_sql - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.DETECT_ANOMALIES(MODEL `%s`, %s%s)' % ( - model_name, settings_sql, input_data_sql) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.DETECT_ANOMALIES(MODEL `%s`, %s%s)' + % (model_name, settings_sql, input_data_sql) + ) else: settings_sql = ' STRUCT(%s)' % settings_field_sql - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.DETECT_ANOMALIES(MODEL `%s`, %s)' % ( - model_name, settings_sql) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.DETECT_ANOMALIES(MODEL `%s`, %s)' + % (model_name, settings_sql) + ) # TODO(mingge): check if model is a valid BigQuery model resource. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/launcher.py index 79c04affede..5b9a0646472 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/remote_runner.py index 7d282b7fd76..8623227e305 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/drop_model/remote_runner.py @@ -66,16 +66,23 @@ def bigquery_drop_model_job( executor_input: A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json['query'] = ('DROP MODEL %s') % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = 'DROP MODEL %s' % ( + bigquery_util.back_quoted_if_needed(model_name) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/launcher.py index e2594673577..2fb86ff8923 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/launcher.py @@ -31,40 +31,46 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--threshold', dest='threshold', type=float, # threshold is only needed for BigQuery tvf model job component. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/remote_runner.py index 3d22197867f..fc901ffb426 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/evaluate_model/remote_runner.py @@ -82,12 +82,14 @@ def bigquery_evaluate_model_job( if query_statement and table_name: raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery evaluation model job.') + 'populated for BigQuery evaluation model job.' + ) input_data_sql = '' if table_name: input_data_sql = ', TABLE %s' % bigquery_util.back_quoted_if_needed( - table_name) + table_name + ) if query_statement: input_data_sql = ', (%s)' % query_statement @@ -96,29 +98,46 @@ def bigquery_evaluate_model_job( threshold_sql = ', STRUCT(%s AS threshold)' % threshold job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.EVALUATE(MODEL %s%s%s)' % ( - bigquery_util.back_quoted_if_needed(model_name), input_data_sql, - threshold_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.EVALUATE(MODEL %s%s%s)' + % ( + bigquery_util.back_quoted_if_needed(model_name), + input_data_sql, + threshold_sql, + ) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
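To make the string assembly above concrete, here is a sketch of the ML.EVALUATE query with made-up values, using literal backticks where the code calls bigquery_util.back_quoted_if_needed (the poll call noted above follows right after this sketch):

    # made-up values for illustration
    model_name = 'my-project.my_dataset.my_model'
    table_name = 'my-project.my_dataset.eval_table'
    threshold = 0.6

    input_data_sql = ', TABLE `%s`' % table_name
    threshold_sql = ', STRUCT(%s AS threshold)' % threshold
    query = 'SELECT * FROM ML.EVALUATE(MODEL `%s`%s%s)' % (
        model_name, input_data_sql, threshold_sql
    )
    print(query)
    # SELECT * FROM ML.EVALUATE(MODEL `my-project.my_dataset.my_model`, TABLE `my-project.my_dataset.eval_table`, STRUCT(0.6 AS threshold))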
job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'evaluation_metrics', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'evaluation_metrics', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/launcher.py index ef979e2558a..2f72637a83c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/launcher.py @@ -31,31 +31,36 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--horizon', dest='horizon', type=int, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--confidence_level', dest='confidence_level', type=float, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/remote_runner.py index ebdf85ff36b..1007430fc40 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_forecast_model/remote_runner.py @@ -65,19 +65,30 @@ def bigquery_explain_forecast_model_job( if horizon is not None and horizon > 0: settings_field_sql_list.append('%s AS horizon' % horizon) - if confidence_level is not None and confidence_level >= 0.0 and confidence_level < 1.0: + if ( + confidence_level is not None + and confidence_level >= 0.0 + and confidence_level < 1.0 + ): settings_field_sql_list.append('%s AS confidence_level' % confidence_level) settings_field_sql = ','.join(settings_field_sql_list) settings_sql = ', STRUCT(%s)' % settings_field_sql job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * 
FROM ML.EXPLAIN_FORECAST(MODEL %s %s)' % ( - bigquery_util.back_quoted_if_needed(model_name), settings_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.EXPLAIN_FORECAST(MODEL %s %s)' + % (bigquery_util.back_quoted_if_needed(model_name), settings_sql) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/launcher.py index 6179364f47b..1ae7cf6b2e1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/launcher.py @@ -31,52 +31,60 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--top_k_features', dest='top_k_features', type=int, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--num_integral_steps', dest='num_integral_steps', type=int, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--threshold', dest='threshold', type=float, # threshold is only needed for BigQuery tvf model job component. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/remote_runner.py index 39fb8dadb1e..c1e581f4732 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/explain_predict_model/remote_runner.py @@ -94,9 +94,13 @@ def bigquery_explain_predict_model_job( if not (not query_statement) ^ (not table_name): raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery explain predict model job.') - input_data_sql = ('TABLE %s' % bigquery_util.back_quoted_if_needed(table_name) - if table_name else '(%s)' % query_statement) + 'populated for BigQuery explain predict model job.' + ) + input_data_sql = ( + 'TABLE %s' % bigquery_util.back_quoted_if_needed(table_name) + if table_name + else '(%s)' % query_statement + ) settings_field_sql_list = [] if top_k_features is not None and top_k_features > 0: @@ -106,20 +110,28 @@ def bigquery_explain_predict_model_job( settings_field_sql_list.append('%s AS threshold' % threshold) if num_integral_steps is not None and num_integral_steps > 0: - settings_field_sql_list.append('%s AS num_integral_steps' % - num_integral_steps) + settings_field_sql_list.append( + '%s AS num_integral_steps' % num_integral_steps + ) settings_field_sql = ','.join(settings_field_sql_list) settings_sql = ', STRUCT(%s)' % settings_field_sql job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `%s`, %s%s)' % ( - model_name, input_data_sql, settings_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `%s`, %s%s)' + % (model_name, input_data_sql, settings_sql) + ) # TODO(mingge): check if model is a valid BigQuery model resource. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/launcher.py index d39fd2ce16f..182db290e80 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/launcher.py @@ -31,25 +31,29 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_destination_path', dest='model_destination_path', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--exported_model_path', dest='exported_model_path', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/remote_runner.py index 0dbc9762d6a..5911ee59c84 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/export_model/remote_runner.py @@ -28,7 +28,7 @@ def _get_model(model_reference, creds): creds.refresh(google.auth.transport.requests.Request()) headers = { 'Content-type': 'application/json', - 'Authorization': 'Bearer ' + creds.token + 'Authorization': 'Bearer ' + creds.token, } model_uri = 'https://bigquery.googleapis.com/bigquery/v2/projects/{projectId}/datasets/{datasetId}/models/{modelId}'.format( projectId=model_reference['projectId'], @@ -90,7 +90,8 @@ def bigquery_export_model_job( model_name_split = model_name.split('.') if len(model_name_split) != 3: raise ValueError( - 'The model name must be in the format "projectId.datasetId.modelId"') + 'The model name must be in the format "projectId.datasetId.modelId"' + ) model_reference = { 'projectId': model_name_split[0], 'datasetId': model_name_split[1], @@ -100,16 +101,20 @@ def bigquery_export_model_job( model = _get_model(model_reference, creds) if not model or 'modelType' not in model: raise ValueError( - 'Cannot get model resource. The model name must be in the format "projectId.datasetId.modelId" ' + 'Cannot get model resource. The model name must be in the format' + ' "projectId.datasetId.modelId" ' ) job_request_json = json.loads(payload, strict=False) - job_request_json['configuration']['query'][ - 'query'] = f'EXPORT MODEL {bigquery_util.back_quoted_if_needed(model_name)} OPTIONS(URI="{model_destination_path}",add_serving_default_signature={True})' + job_request_json['configuration']['query']['query'] = ( + 'EXPORT MODEL' + f' {bigquery_util.back_quoted_if_needed(model_name)} OPTIONS(URI="{model_destination_path}",add_serving_default_signature={True})' + ) job_request_json['configuration']['query']['useLegacySql'] = False - job_uri = bigquery_util.create_job(project, location, job_request_json, - creds, gcp_resources) + job_uri = bigquery_util.create_job( + project, location, job_request_json, creds, gcp_resources + ) # Poll bigquery job status until finished. 
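The export runner above validates that the model name has three dot-separated parts and then issues an EXPORT MODEL statement; a compact sketch with made-up names, again with literal backticks standing in for back_quoted_if_needed (the poll step noted above continues right after this sketch):

    # made-up values for illustration
    model_name = 'my-project.my_dataset.my_model'
    model_destination_path = 'gs://my-bucket/exported_model'

    model_name_split = model_name.split('.')
    if len(model_name_split) != 3:
        raise ValueError(
            'The model name must be in the format "projectId.datasetId.modelId"'
        )
    model_reference = {
        'projectId': model_name_split[0],
        'datasetId': model_name_split[1],
        'modelId': model_name_split[2],
    }
    query = (
        'EXPORT MODEL'
        f' `{model_name}` OPTIONS(URI="{model_destination_path}",add_serving_default_signature={True})'
    )
    print(model_reference)
    print(query)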
job = bigquery_util.poll_job(job_uri, creds) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/launcher.py index f9b7f69eb08..78913c87e8f 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/remote_runner.py index dc27792fd5d..3b5cf079af6 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/feature_importance/remote_runner.py @@ -59,28 +59,42 @@ def bigquery_ml_feature_importance_job( https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-predict#predict_model_name """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.FEATURE_IMPORTANCE(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.FEATURE_IMPORTANCE(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
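Each of these runners ends with the same "create the job, then poll until finished" shape; bigquery_util.poll_job is implemented elsewhere, so the loop below is only a conceptual sketch of that pattern, not the library code (the actual poll call continues right after this sketch):

    import time

    def poll_until_done(get_job, interval_seconds=30):
        # get_job: a zero-argument callable returning the current job resource as a dict
        while True:
            job = get_job()
            if job.get('status', {}).get('state') == 'DONE':
                return job
            time.sleep(interval_seconds)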
job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'feature_importance', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'feature_importance', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/launcher.py index c79a02d7994..6340a6aa805 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/launcher.py @@ -31,31 +31,36 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--horizon', dest='horizon', type=int, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--confidence_level', dest='confidence_level', type=float, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/remote_runner.py index ef9678b0864..b2ac14ef55b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/forecast_model/remote_runner.py @@ -64,19 +64,30 @@ def bigquery_forecast_model_job( if horizon is not None and horizon > 0: settings_field_sql_list.append('%s AS horizon' % horizon) - if confidence_level is not None and confidence_level >= 0.0 and confidence_level < 1.0: + if ( + confidence_level is not None + and confidence_level >= 0.0 + and confidence_level < 1.0 + ): settings_field_sql_list.append('%s AS confidence_level' % confidence_level) settings_field_sql = ','.join(settings_field_sql_list) settings_sql = ', STRUCT(%s)' % settings_field_sql job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.FORECAST(MODEL %s %s)' % ( - 
bigquery_util.back_quoted_if_needed(model_name), settings_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.FORECAST(MODEL %s %s)' + % (bigquery_util.back_quoted_if_needed(model_name), settings_sql) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/launcher.py index 6f5159e2c99..9f36b6d03f2 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/launcher.py @@ -31,25 +31,29 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--class_level_explain', dest='class_level_explain', type=bool, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/remote_runner.py index 938da1047f9..0268385659a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/global_explain/remote_runner.py @@ -60,12 +60,19 @@ def bigquery_ml_global_explain_job( https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-predict#predict_model_name """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) + job_configuration_query_override, strict=False + ) job_configuration_query_override_json['query'] = ( 'SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL %s, STRUCT(TRUE AS ' - 'class_level_explain))') % ( - bigquery_util.back_quoted_if_needed(model_name)) + 'class_level_explain))' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/launcher.py index de27eaa5c5f..9d9c09622ce 100644 --- 
a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/remote_runner.py index 3258ac6c555..bdb60a0674f 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_advanced_weights/remote_runner.py @@ -66,28 +66,42 @@ def bigquery_ml_advanced_weights_job( executor_input:A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.ADVANCED_WEIGHTS(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.ADVANCED_WEIGHTS(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
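The ML.FORECAST and ML.EXPLAIN_FORECAST runners above accumulate optional settings clauses before formatting the query; a compact sketch with made-up values, once more using literal backticks in place of back_quoted_if_needed (the poll call noted above follows right after this sketch):

    # made-up values for illustration
    model_name = 'my-project.my_dataset.arima_model'
    horizon = 30
    confidence_level = 0.9

    settings_field_sql_list = []
    if horizon is not None and horizon > 0:
        settings_field_sql_list.append('%s AS horizon' % horizon)
    if confidence_level is not None and 0.0 <= confidence_level < 1.0:
        settings_field_sql_list.append('%s AS confidence_level' % confidence_level)
    settings_sql = ', STRUCT(%s)' % ','.join(settings_field_sql_list)
    query = 'SELECT * FROM ML.FORECAST(MODEL `%s` %s)' % (model_name, settings_sql)
    print(query)
    # SELECT * FROM ML.FORECAST(MODEL `my-project.my_dataset.arima_model` , STRUCT(30 AS horizon,0.9 AS confidence_level))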
job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'advanced_weights', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'advanced_weights', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/launcher.py index ef040cf86e3..7ddfd24b8ed 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/remote_runner.py index 7cd9713b42e..f748e893064 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_coefficients/remote_runner.py @@ -54,28 +54,42 @@ def bigquery_ml_arima_coefficients( executor_input:A json serialized pipeline executor input. 
""" job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'arima_coefficients', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'arima_coefficients', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/launcher.py index 3869316b778..10628adbdd1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/launcher.py @@ -31,25 +31,29 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--show_all_candidate_models', dest='show_all_candidate_models', type=bool, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/remote_runner.py index c08013e374a..1066ded3b24 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_arima_evaluate/remote_runner.py @@ -71,32 +71,50 @@ def bigquery_ml_arima_evaluate_job( executor_input:A json serialized pipeline executor input. """ if show_all_candidate_models: - show_all_candidate_models_sql = ', STRUCT(%s AS show_all_candidate_models)' % show_all_candidate_models + show_all_candidate_models_sql = ( + ', STRUCT(%s AS show_all_candidate_models)' % show_all_candidate_models + ) job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.ARIMA_EVALUATE(MODEL %s%s)' % ( + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.ARIMA_EVALUATE(MODEL %s%s)' + % ( bigquery_util.back_quoted_if_needed(model_name), - show_all_candidate_models_sql) + show_all_candidate_models_sql, + ) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'arima_evaluation_metrics', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'arima_evaluation_metrics', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/launcher.py index 9e8f914a8dd..3645a93bf27 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/launcher.py @@ -31,25 +31,29 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--standardize', dest='standardize', type=bool, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/remote_runner.py index 6ec3377474b..8d0309e5bea 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_centroids/remote_runner.py @@ -72,18 +72,24 @@ def bigquery_ml_centroids_job( executor_input: A json serialized pipeline executor input. 
""" job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) + job_configuration_query_override, strict=False + ) job_configuration_query_override_json['query'] = ( - 'SELECT * FROM ML.CENTROIDS(MODEL %s, STRUCT(%s AS ' - 'standardize))') % (bigquery_util.back_quoted_if_needed(model_name), - standardize) + 'SELECT * FROM ML.CENTROIDS(MODEL %s, STRUCT(%s AS standardize))' + % (bigquery_util.back_quoted_if_needed(model_name), standardize) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) @@ -93,12 +99,19 @@ def bigquery_ml_centroids_job( # For ML.CENTROIDS job, the output only contains one row per feature per # centroid, which should be very small. Thus, we allow users to directly get # the result without writing into a BQ table. - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'centroids', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'centroids', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/launcher.py index 5f6fe942994..3a49ee72892 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/launcher.py @@ -31,40 +31,46 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--threshold', dest='threshold', type=float, # threshold is only needed for BigQuery tvf model job component. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/remote_runner.py index ba31170cd62..2d8d868d559 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_confusion_matrix/remote_runner.py @@ -80,12 +80,14 @@ def bigquery_ml_confusion_matrix_job( if not (not query_statement) ^ (not table_name): raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery confusion matrix job.') + 'populated for BigQuery confusion matrix job.' + ) input_data_sql = '' if table_name: input_data_sql = ', TABLE %s' % bigquery_util.back_quoted_if_needed( - table_name) + table_name + ) if query_statement: input_data_sql = ', (%s)' % query_statement @@ -94,16 +96,26 @@ def bigquery_ml_confusion_matrix_job( threshold_sql = ', STRUCT(%s AS threshold)' % threshold job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.CONFUSION_MATRIX(MODEL %s%s%s)' % ( - bigquery_util.back_quoted_if_needed(model_name), input_data_sql, - threshold_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.CONFUSION_MATRIX(MODEL %s%s%s)' + % ( + bigquery_util.back_quoted_if_needed(model_name), + input_data_sql, + threshold_sql, + ) + ) # For ML confusion matrix job, as the returned results is the same as the # number of input, which can be very large. In this case we would like to ask # users to insert a destination table into the job config. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/launcher.py index 2d027b2a6d6..5cd82ee72d2 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/remote_runner.py index 7145ff37d70..45368fcd509 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_feature_info/remote_runner.py @@ -66,28 +66,42 @@ def bigquery_ml_feature_info_job( executor_input:A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.FEATURE_INFO(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.FEATURE_INFO(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'feature_info', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'feature_info', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/launcher.py index 408ddb689de..34d2ef23829 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/remote_runner.py index f3628eba89f..0c06d7fc3ec 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_component_info/remote_runner.py @@ -68,12 +68,19 @@ def bigquery_ml_principal_component_info_job( """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/launcher.py index a6cdd3e6098..0446c12596d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/remote_runner.py index 3a41a55d0ab..68425df7f13 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_principal_components/remote_runner.py @@ -68,12 +68,19 @@ def bigquery_ml_principal_components_job( """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/launcher.py index 0fda7549f2d..0b06aff1aea 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/launcher.py @@ -31,33 +31,38 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. 
required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/remote_runner.py index 96a2c57b10e..0ecbd5406ff 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_recommend/remote_runner.py @@ -77,22 +77,31 @@ def bigquery_ml_recommend_job( if query_statement and table_name: raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery ML Recommend job.') + 'populated for BigQuery ML Recommend job.' + ) input_data_sql = '' if table_name: input_data_sql = ', TABLE %s' % bigquery_util.back_quoted_if_needed( - table_name) + table_name + ) if query_statement: input_data_sql = ', (%s)' % query_statement job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.RECOMMEND(MODEL %s%s)' % ( - bigquery_util.back_quoted_if_needed(model_name), input_data_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.RECOMMEND(MODEL %s%s)' + % (bigquery_util.back_quoted_if_needed(model_name), input_data_sql) + ) return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/launcher.py index 835827dde0e..ee7f5364f36 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/launcher.py @@ -31,33 +31,38 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. 
required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/remote_runner.py index 4bd74f56cb3..0f090024fa6 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_reconstruction_loss/remote_runner.py @@ -76,25 +76,34 @@ def bigquery_ml_reconstruction_loss_job( if not (not query_statement) ^ (not table_name): raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery ML Reconstruction Loss job.') + 'populated for BigQuery ML Reconstruction Loss job.' + ) input_data_sql = '' if table_name: input_data_sql = ', TABLE %s' % bigquery_util.back_quoted_if_needed( - table_name) + table_name + ) if query_statement: input_data_sql = ', (%s)' % query_statement job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.RECONSTRUCTION_LOSS(MODEL %s%s)' % ( - bigquery_util.back_quoted_if_needed(model_name), input_data_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.RECONSTRUCTION_LOSS(MODEL %s%s)' + % (bigquery_util.back_quoted_if_needed(model_name), input_data_sql) + ) # For ML reconstruction loss job, as the returned results is the same as the # number of input, which can be very large. In this case we would like to ask # users to insert a destination table into the job config. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/launcher.py index 4a582694ffa..d059a0bed29 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/launcher.py @@ -31,40 +31,46 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. 
required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--thresholds', dest='thresholds', type=str, # thresholds is only needed for BigQuery tvf model job component. required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/remote_runner.py index cccafaceb1a..616fdca58a1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_roc_curve/remote_runner.py @@ -80,12 +80,14 @@ def bigquery_ml_roc_curve_job( if not (not query_statement) ^ (not table_name): raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery roc curve job.') + 'populated for BigQuery roc curve job.' + ) input_data_sql = '' if table_name: input_data_sql = ', TABLE %s' % bigquery_util.back_quoted_if_needed( - table_name) + table_name + ) if query_statement: input_data_sql = ', (%s)' % query_statement @@ -94,16 +96,26 @@ def bigquery_ml_roc_curve_job( thresholds_sql = ', GENERATE_ARRAY(%s)' % thresholds job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.ROC_CURVE(MODEL %s%s%s)' % ( - bigquery_util.back_quoted_if_needed(model_name), input_data_sql, - thresholds_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.ROC_CURVE(MODEL %s%s%s)' + % ( + bigquery_util.back_quoted_if_needed(model_name), + input_data_sql, + thresholds_sql, + ) + ) # For ML roc curve job, as the returned results is the same as the # number of input, which can be very large. In this case we would like to ask # users to insert a destination table into the job config. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/launcher.py index 13eb3e34304..41900c368cb 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/remote_runner.py index 4b2a40f1573..2434126079b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_training_info/remote_runner.py @@ -66,28 +66,42 @@ def bigquery_ml_training_info_job( executor_input: A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.TRAINING_INFO(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.TRAINING_INFO(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job %s', job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'ml_training_info', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'ml_training_info', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/launcher.py index e16bc5e166f..19c7c63f571 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/remote_runner.py index ff6821f4f1d..e9026b83491 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_trial_info/remote_runner.py @@ -66,17 +66,24 @@ def bigquery_ml_trial_info_job( executor_input: A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.TRIAL_INFO(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.TRIAL_INFO(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. job = bigquery_util.poll_job(job_uri, creds) @@ -86,12 +93,19 @@ def bigquery_ml_trial_info_job( # For ML Trial Info job, as the returned results only contains num_trials # rows, which should be very small. In this case we allow users to directly # get the result without writing into a BQ table. 
- query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'trial_info', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'trial_info', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/launcher.py index ebf9d5963da..dd20a10003a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/launcher.py @@ -31,19 +31,22 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/remote_runner.py index 4f672dc301a..439d94b223a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/ml_weights/remote_runner.py @@ -66,28 +66,42 @@ def bigquery_ml_weights_job( executor_input:A json serialized pipeline executor input. """ job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.WEIGHTS(MODEL %s)' % ( - bigquery_util.back_quoted_if_needed(model_name)) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.WEIGHTS(MODEL %s)' + % (bigquery_util.back_quoted_if_needed(model_name)) + ) creds, _ = google.auth.default() job_uri = bigquery_util.check_if_job_exists(gcp_resources) if job_uri is None: job_uri = bigquery_util.create_query_job( - project, location, payload, - json.dumps(job_configuration_query_override_json), creds, gcp_resources) + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
job = bigquery_util.poll_job(job_uri, creds) logging.info('Getting query result for job ' + job['id']) _, job_id = job['id'].split('.') - query_results = bigquery_util.get_query_results(project, job_id, location, - creds) + query_results = bigquery_util.get_query_results( + project, job_id, location, creds + ) artifact_util.update_output_artifact( - executor_input, 'weights', '', { - bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA], - bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: - query_results[bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS] - }) + executor_input, + 'weights', + '', + { + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_SCHEMA + ], + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS: query_results[ + bigquery_util.ARTIFACT_PROPERTY_KEY_ROWS + ], + }, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/launcher.py index 26d785726a0..c64d3b68999 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/launcher.py @@ -31,40 +31,46 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--model_name', dest='model_name', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--table_name', dest='table_name', type=str, # table_name is only needed for BigQuery tvf model job component. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--query_statement', dest='query_statement', type=str, # query_statement is only needed for BigQuery predict model job component. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--threshold', dest='threshold', type=float, # threshold is only needed for BigQuery tvf model job component. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/remote_runner.py index 6a57dcfa689..80cb4331314 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/predict_model/remote_runner.py @@ -82,22 +82,33 @@ def bigquery_predict_model_job( if not (not query_statement) ^ (not table_name): raise ValueError( 'One and only one of query_statement and table_name should be ' - 'populated for BigQuery predict model job.') - input_data_sql = ('TABLE %s' % bigquery_util.back_quoted_if_needed(table_name) - if table_name else '(%s)' % query_statement) + 'populated for BigQuery predict model job.' 
+ ) + input_data_sql = ( + 'TABLE %s' % bigquery_util.back_quoted_if_needed(table_name) + if table_name + else '(%s)' % query_statement + ) threshold_sql = '' if threshold is not None and threshold > 0.0 and threshold < 1.0: threshold_sql = ', STRUCT(%s AS threshold)' % threshold job_configuration_query_override_json = json.loads( - job_configuration_query_override, strict=False) - job_configuration_query_override_json[ - 'query'] = 'SELECT * FROM ML.PREDICT(MODEL `%s`, %s%s)' % ( - model_name, input_data_sql, threshold_sql) + job_configuration_query_override, strict=False + ) + job_configuration_query_override_json['query'] = ( + 'SELECT * FROM ML.PREDICT(MODEL `%s`, %s%s)' + % (model_name, input_data_sql, threshold_sql) + ) # TODO(mingge): check if model is a valid BigQuery model resource. return bigquery_util.bigquery_query_job( - type, project, location, payload, - json.dumps(job_configuration_query_override_json), gcp_resources, - executor_input) + type, + project, + location, + payload, + json.dumps(job_configuration_query_override_json), + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/launcher.py index ecf4ae49d67..85c7e92a50a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/launcher.py @@ -31,13 +31,15 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--job_configuration_query_override', dest='job_configuration_query_override', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/remote_runner.py index 2c6bfad5c0b..8c934a0e70e 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/query_job/remote_runner.py @@ -62,6 +62,12 @@ def bigquery_query_job( gcp_resources: File path for storing `gcp_resources` output parameter. executor_input: A json serialized pipeline executor input. 
""" - return bigquery_util.bigquery_query_job(type, project, location, payload, - job_configuration_query_override, - gcp_resources, executor_input) + return bigquery_util.bigquery_query_job( + type, + project, + location, + payload, + job_configuration_query_override, + gcp_resources, + executor_input, + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/utils/bigquery_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/utils/bigquery_util.py index 6cfbc1739c6..ef13579a696 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/utils/bigquery_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/bigquery/utils/bigquery_util.py @@ -44,12 +44,16 @@ def insert_system_labels_into_payload(job_request_json): if JOB_CONFIGURATION_KEY not in job_request_json: job_request_json[JOB_CONFIGURATION_KEY] = {} - job_request_json[JOB_CONFIGURATION_KEY][ - LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( + job_request_json[JOB_CONFIGURATION_KEY][LABELS_PAYLOAD_KEY] = ( + gcp_labels_util.attach_system_labels( job_request_json[JOB_CONFIGURATION_KEY][LABELS_PAYLOAD_KEY] - if LABELS_PAYLOAD_KEY in job_request_json else {}) + if LABELS_PAYLOAD_KEY in job_request_json + else {} + ) + ) return job_request_json + def back_quoted_if_needed(resource_name) -> str: """Enclose resource name with ` if it's not yet.""" if not resource_name or resource_name.startswith('`'): @@ -69,12 +73,14 @@ def check_if_job_exists(gcp_resources: str) -> Optional[str]: if path.exists(gcp_resources) and os.stat(gcp_resources).st_size != 0: with open(gcp_resources) as f: serialized_gcp_resources = f.read() - job_resources = json_format.Parse(serialized_gcp_resources, - gcp_resources_pb2.GcpResources()) + job_resources = json_format.Parse( + serialized_gcp_resources, gcp_resources_pb2.GcpResources() + ) # Resources should only contain one item. if len(job_resources.resources) != 1: raise ValueError( - f'gcp_resources should contain one resource, found {len(job_resources.resources)}' + 'gcp_resources should contain one resource, found' + f' {len(job_resources.resources)}' ) # Validate the format of the resource uri. job_name_pattern = re.compile(_BQ_JOB_NAME_TEMPLATE) @@ -83,18 +89,21 @@ def check_if_job_exists(gcp_resources: str) -> Optional[str]: project = match.group('project') job = match.group('job') except AttributeError as err: - raise ValueError('Invalid bigquery job uri: {}. Expect: {}.'.format( - job_resources.resources[0].resource_uri, - 'https://www.googleapis.com/bigquery/v2/projects/[projectId]/jobs/[jobId]?location=[location]' - )) + raise ValueError( + 'Invalid bigquery job uri: {}. Expect: {}.'.format( + job_resources.resources[0].resource_uri, + 'https://www.googleapis.com/bigquery/v2/projects/[projectId]/jobs/[jobId]?location=[location]', + ) + ) return job_resources.resources[0].resource_uri else: return None -def create_job(project, location, job_request_json, creds, - gcp_resources) -> str: +def create_job( + project, location, job_request_json, creds, gcp_resources +) -> str: """Create a new BigQuery job. 
@@ -122,15 +131,18 @@ def create_job(project, location, job_request_json, creds, headers = { 'Content-type': 'application/json', 'Authorization': 'Bearer ' + creds.token, - 'User-Agent': 'google-cloud-pipeline-components' + 'User-Agent': 'google-cloud-pipeline-components', } - insert_job_url = f'https://www.googleapis.com/bigquery/v2/projects/{project}/jobs' + insert_job_url = ( + f'https://www.googleapis.com/bigquery/v2/projects/{project}/jobs' + ) job = requests.post( - url=insert_job_url, data=json.dumps(job_request_json), - headers=headers).json() + url=insert_job_url, data=json.dumps(job_request_json), headers=headers + ).json() if 'selfLink' not in job: raise RuntimeError( - f'BigQquery Job failed. Cannot retrieve the job name. Request:{job_request_json}; Response: {job}.' + 'BigQuery Job failed. Cannot retrieve the job name.' + f' Request:{job_request_json}; Response: {job}.' ) # Write the bigquey job uri to gcp resource. @@ -145,9 +157,14 @@ def create_job(project, location, job_request_json, creds, return job_uri -def create_query_job(project, location, payload, - job_configuration_query_override, creds, - gcp_resources) -> str: +def create_query_job( + project, + location, + payload, + job_configuration_query_override, + creds, + gcp_resources, +) -> str: """Create a new BigQuery query job. @@ -167,9 +184,11 @@ def create_query_job(project, location, payload, The URI of the BigQuery Job. """ job_request_json = insert_system_labels_into_payload( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) job_configuration_query_override_json = json_util.recursive_remove_empty( - json.loads(job_configuration_query_override, strict=False)) + json.loads(job_configuration_query_override, strict=False) + ) # Overrides json request with the value in job_configuration_query_override for key, value in job_configuration_query_override_json.items(): @@ -179,9 +198,10 @@ def create_query_job(project, location, payload, job_request_json = json_util.recursive_remove_empty(job_request_json) # Always uses standard SQL instead of legacy SQL. - if 'useLegacySql' in job_request_json['configuration'][ - 'query'] and job_request_json['configuration']['query'][ - 'useLegacySql'] == True: + if ( + 'useLegacySql' in job_request_json['configuration']['query'] + and job_request_json['configuration']['query']['useLegacySql'] == True + ): raise ValueError('Legacy SQL is not supported. 
Use standard SQL instead.') job_request_json['configuration']['query']['useLegacySql'] = False @@ -198,29 +218,35 @@ def _send_cancel_request(job_uri, creds): # Bigquery cancel API: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/cancel response = requests.post( - url=f'{job_uri.split("?")[0]}/cancel', data='', headers=headers) + url=f'{job_uri.split("?")[0]}/cancel', data='', headers=headers + ) logging.info('Cancel response: %s', response) def poll_job(job_uri, creds) -> dict: """Poll the bigquery job till it reaches a final state.""" with execution_context.ExecutionContext( - on_cancel=lambda: _send_cancel_request(job_uri, creds)): + on_cancel=lambda: _send_cancel_request(job_uri, creds) + ): job = {} - while ('status' not in job) or ('state' not in job['status']) or ( - job['status']['state'].lower() != 'done'): + while ( + ('status' not in job) + or ('state' not in job['status']) + or (job['status']['state'].lower() != 'done') + ): time.sleep(_POLLING_INTERVAL_IN_SECONDS) logging.info('The job is running...') if not creds.valid: creds.refresh(google.auth.transport.requests.Request()) headers = { 'Content-type': 'application/json', - 'Authorization': 'Bearer ' + creds.token + 'Authorization': 'Bearer ' + creds.token, } job = requests.get(job_uri, headers=headers).json() if 'status' in job and 'errorResult' in job['status']: raise RuntimeError( - f'The BigQuery job {job_uri} failed. Error: {job["status"]}') + f'The BigQuery job {job_uri} failed. Error: {job["status"]}' + ) logging.info('BigQuery Job completed successfully. Job: %s.', job) return job @@ -231,7 +257,7 @@ def get_query_results(project_id, job_id, location, creds): creds.refresh(google.auth.transport.requests.Request()) headers = { 'Content-type': 'application/json', - 'Authorization': 'Bearer ' + creds.token + 'Authorization': 'Bearer ' + creds.token, } query_results_uri = 'https://bigquery.googleapis.com/bigquery/v2/projects/{projectId}/queries/{jobId}'.format( projectId=project_id, @@ -286,9 +312,14 @@ def bigquery_query_job( creds, _ = google.auth.default() job_uri = check_if_job_exists(gcp_resources) if job_uri is None: - job_uri = create_query_job(project, location, payload, - job_configuration_query_override, creds, - gcp_resources) + job_uri = create_query_job( + project, + location, + payload, + job_configuration_query_override, + creds, + gcp_resources, + ) # Poll bigquery job status until finished. 
job = poll_job(job_uri, creds) @@ -298,6 +329,7 @@ def bigquery_query_job( projectId = job['configuration']['query']['destinationTable']['projectId'] datasetId = job['configuration']['query']['destinationTable']['datasetId'] tableId = job['configuration']['query']['destinationTable']['tableId'] - bq_table_artifact = BQTable('destination_table', projectId, datasetId, - tableId) + bq_table_artifact = BQTable( + 'destination_table', projectId, datasetId, tableId + ) artifact_util.update_output_artifacts(executor_input, [bq_table_artifact]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/custom_job/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/custom_job/remote_runner.py index 71427813b95..3798670f4b9 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/custom_job/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/custom_job/remote_runner.py @@ -27,7 +27,8 @@ def insert_system_labels_into_payload(payload): job_spec = json.loads(payload) job_spec[LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - job_spec[LABELS_PAYLOAD_KEY] if LABELS_PAYLOAD_KEY in job_spec else {}) + job_spec[LABELS_PAYLOAD_KEY] if LABELS_PAYLOAD_KEY in job_spec else {} + ) return json.dumps(job_spec) @@ -35,7 +36,8 @@ def create_custom_job_with_client(job_client, parent, job_spec): create_custom_job_fn = None try: create_custom_job_fn = job_client.create_custom_job( - parent=parent, custom_job=job_spec) + parent=parent, custom_job=job_spec + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return create_custom_job_fn @@ -46,7 +48,8 @@ def get_custom_job_with_client(job_client, job_name): try: get_custom_job_fn = job_client.get_custom_job( name=job_name, - retry=retry.Retry(deadline=_CUSTOM_JOB_RETRY_DEADLINE_SECONDS)) + retry=retry.Retry(deadline=_CUSTOM_JOB_RETRY_DEADLINE_SECONDS), + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return get_custom_job_fn @@ -75,8 +78,9 @@ def create_custom_job( Also retry on ConnectionError up to job_remote_runner._CONNECTION_ERROR_RETRY_LIMIT times during the poll. """ - remote_runner = job_remote_runner.JobRemoteRunner(type, project, location, - gcp_resources) + remote_runner = job_remote_runner.JobRemoteRunner( + type, project, location, gcp_resources + ) try: # Create custom job if it does not exist @@ -84,7 +88,8 @@ def create_custom_job( if job_name is None: job_name = remote_runner.create_job( create_custom_job_with_client, - insert_system_labels_into_payload(payload)) + insert_system_labels_into_payload(payload), + ) # Poll custom job status until "JobState.JOB_STATE_SUCCEEDED" remote_runner.poll_job(get_custom_job_with_client, job_name) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_launcher.py index 0dac7602dfc..d902a9ee87a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_launcher.py @@ -36,49 +36,53 @@ def _parse_args(args) -> Dict[str, Any]: and a list containing all unknonw args. 
""" parser = argparse.ArgumentParser( - prog='Dataflow python job Pipelines service launcher', description='') + prog='Dataflow python job Pipelines service launcher', description='' + ) parser.add_argument( '--project', dest='project', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--location', dest='location', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--python_module_path', dest='python_module_path', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--temp_location', dest='temp_location', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--requirements_file_path', dest='requirements_file_path', type=str, required=False, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( - '--args', - dest='args', - type=str, - required=False, - default=argparse.SUPPRESS) + '--args', dest='args', type=str, required=False, default=argparse.SUPPRESS + ) parser.add_argument( '--gcp_resources', dest='gcp_resources', type=_make_parent_dirs_and_return_path, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_python_job_remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_python_job_remote_runner.py index e75f2f34d5b..4496c475c7a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_python_job_remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataflow/dataflow_python_job_remote_runner.py @@ -30,11 +30,15 @@ logging.basicConfig(level=logging.INFO) # Job ID pattern for Dataflow jobs -_DATAFLOW_JOB_ID_PATTERN = br'.*console.cloud.google.com/dataflow/jobs/(?P[a-z|0-9|A-Z|\-|\_]+)/(?P[a-z|0-9|A-Z|\-|\_]+).*' +_DATAFLOW_JOB_ID_PATTERN = rb'.*console.cloud.google.com/dataflow/jobs/(?P[a-z|0-9|A-Z|\-|\_]+)/(?P[a-z|0-9|A-Z|\-|\_]+).*' # Args, if provided, that should be staged locally. -_ARGS_FILES_TO_STAGE = ('--requirements_file', '--setup_file', '--sdk_location', - '--extra_package') +_ARGS_FILES_TO_STAGE = ( + '--requirements_file', + '--setup_file', + '--sdk_location', + '--extra_package', +) LABELS_ARG_KEY = '--labels' @@ -47,13 +51,16 @@ def get_system_label_args_list(): system_labels_args_list.append('{}={}'.format(k, v)) return system_labels_args_list -def create_python_job(python_module_path: str, - project: str, - gcp_resources: str, - location: str, - temp_location: str, - requirements_file_path: str = '', - args: Optional[str] = '[]'): + +def create_python_job( + python_module_path: str, + project: str, + gcp_resources: str, + location: str, + temp_location: str, + requirements_file_path: str = '', + args: Optional[str] = '[]', +): """Creates a Dataflow python job. Args: @@ -68,7 +75,6 @@ def create_python_job(python_module_path: str, '--requirements_file' or '--setup_file' to configure the workers however the path provided needs to be a GCS path. - Returns: And instance of GCPResouces proto with the dataflow Job ID which is stored in gcp_resources path. 
@@ -91,8 +97,9 @@ def create_python_job(python_module_path: str, args_list[idx + 1] = stage_file(args_list[idx + 1]) logging.info('Staging %s at %s locally.', param, args_list[idx + 1]) - cmd = prepare_cmd(project, location, python_file_path, args_list, - temp_location) + cmd = prepare_cmd( + project, location, python_file_path, args_list, temp_location + ) sub_process = Process(cmd) for line in sub_process.read_lines(): logging.info('DataflowRunner output: %s', line) @@ -110,16 +117,23 @@ def create_python_job(python_module_path: str, break if not job_id: raise RuntimeError( - 'No dataflow job was found when running the python file.') + 'No dataflow job was found when running the python file.' + ) def prepare_cmd(project_id, region, python_file_path, args, temp_location): dataflow_args = [ - '--runner', 'DataflowRunner', '--project', project_id, '--region', region, - '--temp_location', temp_location + '--runner', + 'DataflowRunner', + '--project', + project_id, + '--region', + region, + '--temp_location', + temp_location, ] - return (['python3', '-u', python_file_path] + dataflow_args + args) + return ['python3', '-u', python_file_path] + dataflow_args + args def extract_job_id_and_location(line): @@ -127,8 +141,10 @@ def extract_job_id_and_location(line): job_id_pattern = re.compile(_DATAFLOW_JOB_ID_PATTERN) matched_job_id = job_id_pattern.search(line or '') if matched_job_id: - return (matched_job_id.group('job_id').decode(), - matched_job_id.group('location').decode()) + return ( + matched_job_id.group('job_id').decode(), + matched_job_id.group('location').decode(), + ) return (None, None) @@ -185,8 +201,9 @@ def download_blob(source_blob_path, destination_file_path): with open(destination_file_path, 'wb+') as f: blob.download_to_file(f) - logging.info('Blob %s downloaded to %s.', source_blob_path, - destination_file_path) + logging.info( + 'Blob %s downloaded to %s.', source_blob_path, destination_file_path + ) class Process: @@ -199,7 +216,8 @@ def __init__(self, cmd): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, - shell=False) + shell=False, + ) def read_lines(self): # stdout will end with empty bytes when process exits. 
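Note (illustrative, not part of the patch): `extract_job_id_and_location` above scrapes the Dataflow job ID and region out of the console URL that `DataflowRunner` prints to stdout. A self-contained sketch of the same extraction; the group names follow the `.group('job_id')`/`.group('location')` lookups in the diff, and the character class is approximated:

```python
# Sketch only: the byte-string pattern mirrors _DATAFLOW_JOB_ID_PATTERN; the
# sample log line below is fabricated for the example.
import re
from typing import Optional, Tuple

_JOB_ID_PATTERN = re.compile(
    rb'.*console\.cloud\.google\.com/dataflow/jobs/'
    rb'(?P<location>[a-zA-Z0-9\-\_]+)/(?P<job_id>[a-zA-Z0-9\-\_]+).*')


def extract_job_id_and_location(
    line: bytes) -> Tuple[Optional[str], Optional[str]]:
  """Returns (job_id, location) parsed from a runner log line, or (None, None)."""
  match = _JOB_ID_PATTERN.search(line or b'')
  if match:
    return match.group('job_id').decode(), match.group('location').decode()
  return (None, None)


print(extract_job_id_and_location(
    b'INFO: https://console.cloud.google.com/dataflow/jobs/'
    b'us-central1/2023-01-01_00_00_00-123?project=p'))
# ('2023-01-01_00_00_00-123', 'us-central1')
```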
diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_pyspark_batch/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_pyspark_batch/launcher.py index 6bc9de5f046..3856e8d3746 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_pyspark_batch/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_pyspark_batch/launcher.py @@ -30,7 +30,8 @@ def _parse_args(args): dest='batch_id', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_batch/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_batch/launcher.py index 5cef88e4f66..326cfc0a973 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_batch/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_batch/launcher.py @@ -30,7 +30,8 @@ def _parse_args(args): dest='batch_id', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_r_batch/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_r_batch/launcher.py index 6dd30f12013..4b72fe4fab7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_r_batch/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_r_batch/launcher.py @@ -30,7 +30,8 @@ def _parse_args(args): dest='batch_id', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_sql_batch/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_sql_batch/launcher.py index 70299c78476..6e869ccbbd0 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_sql_batch/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/create_spark_sql_batch/launcher.py @@ -30,7 +30,8 @@ def _parse_args(args): dest='batch_id', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/utils/dataproc_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/utils/dataproc_util.py index 048873e4f3f..2f320c91059 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/utils/dataproc_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/dataproc/utils/dataproc_util.py @@ -38,7 +38,7 @@ _POLL_INTERVAL_SECONDS = 20 _CONNECTION_ERROR_RETRY_LIMIT = 5 -_CONNECTION_RETRY_BACKOFF_FACTOR = 2. 
+_CONNECTION_RETRY_BACKOFF_FACTOR = 2.0 _LABELS_PAYLOAD_KEY = 'labels' _DATAPROC_URI_PREFIX = 'https://dataproc.googleapis.com/v1' @@ -48,10 +48,12 @@ def insert_system_labels_into_payload(payload): job_spec = json.loads(payload) job_spec[_LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - job_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in job_spec else {}) + job_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in job_spec else {} + ) return json.dumps(job_spec) -class DataprocBatchRemoteRunner(): + +class DataprocBatchRemoteRunner: """Common module for creating and polling Dataproc Serverless Batch workloads.""" def __init__( @@ -75,13 +77,13 @@ def _get_session(self) -> Session: total=_CONNECTION_ERROR_RETRY_LIMIT, status_forcelist=[429, 503], backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, - method_whitelist=['GET', 'POST'] + method_whitelist=['GET', 'POST'], ) adapter = HTTPAdapter(max_retries=retry) session = requests.Session() session.headers.update({ 'Content-Type': 'application/json', - 'User-Agent': 'google-cloud-pipeline-components' + 'User-Agent': 'google-cloud-pipeline-components', }) session.mount('https://', adapter) return session @@ -100,7 +102,7 @@ def _get_resource(self, url: str) -> Dict[str, Any]: """ if not self._creds.valid: self._creds.refresh(google.auth.transport.requests.Request()) - headers = {'Authorization': 'Bearer '+ self._creds.token} + headers = {'Authorization': 'Bearer ' + self._creds.token} result = self._session.get(url, headers=headers) json_data = {} @@ -110,18 +112,22 @@ def _get_resource(self, url: str) -> Dict[str, Any]: return json_data except requests.exceptions.HTTPError as err: try: - err_msg = ('Error {} returned from GET: {}. Status: {}, Message: {}' - .format(err.response.status_code, - err.request.url, - json_data['error']['status'], - json_data['error']['message'])) + err_msg = ( + 'Error {} returned from GET: {}. Status: {}, Message: {}'.format( + err.response.status_code, + err.request.url, + json_data['error']['status'], + json_data['error']['message'], + ) + ) except (KeyError, TypeError): err_msg = err.response.text raise RuntimeError(err_msg) from err except json.decoder.JSONDecodeError as err: - raise RuntimeError('Failed to decode JSON from response:\n{}' - .format(err.doc)) from err + raise RuntimeError( + 'Failed to decode JSON from response:\n{}'.format(err.doc) + ) from err def _post_resource(self, url: str, post_data: str) -> Dict[str, Any]: """POST a http request. @@ -138,7 +144,7 @@ def _post_resource(self, url: str, post_data: str) -> Dict[str, Any]: """ if not self._creds.valid: self._creds.refresh(google.auth.transport.requests.Request()) - headers = {'Authorization': 'Bearer '+ self._creds.token} + headers = {'Authorization': 'Bearer ' + self._creds.token} result = self._session.post(url=url, data=post_data, headers=headers) json_data = {} @@ -148,17 +154,21 @@ def _post_resource(self, url: str, post_data: str) -> Dict[str, Any]: return json_data except requests.exceptions.HTTPError as err: try: - err_msg = ('Error {} returned from POST: {}. Status: {}, Message: {}' - .format(err.response.status_code, - err.request.url, - json_data['error']['status'], - json_data['error']['message'])) + err_msg = ( + 'Error {} returned from POST: {}. 
Status: {}, Message: {}'.format( + err.response.status_code, + err.request.url, + json_data['error']['status'], + json_data['error']['message'], + ) + ) except (KeyError, TypeError): err_msg = err.response.text raise RuntimeError(err_msg) from err except json.decoder.JSONDecodeError as err: - raise RuntimeError('Failed to decode JSON from response:\n{}' - .format(err.doc)) from err + raise RuntimeError( + 'Failed to decode JSON from response:\n{}'.format(err.doc) + ) from err def _cancel_batch(self, lro_name) -> None: """Cancels a Dataproc batch workload.""" @@ -173,29 +183,38 @@ def check_if_operation_exists(self) -> Union[Dict[str, Any], None]: """Check if a Dataproc Batch operation already exists. Returns: - Dict of the long-running Operation resource if it exists. For more details, see: + Dict of the long-running Operation resource if it exists. For more + details, see: https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.operations#Operation None if the Operation resource does not exist. Raises: ValueError: Operation resource uri format is invalid. """ - if path.exists(self._gcp_resources) and os.stat(self._gcp_resources).st_size != 0: + if ( + path.exists(self._gcp_resources) + and os.stat(self._gcp_resources).st_size != 0 + ): with open(self._gcp_resources) as f: serialized_gcp_resources = f.read() - job_resources = json_format.Parse(serialized_gcp_resources, - gcp_resources_pb2.GcpResources()) + job_resources = json_format.Parse( + serialized_gcp_resources, gcp_resources_pb2.GcpResources() + ) # Job resources should contain a DataprocBatch and DataprocLro resource. if len(job_resources.resources) != 2: raise ValueError( - f'gcp_resources should contain 2 resources, found {len(job_resources.resources)}.' + 'gcp_resources should contain 2 resources, found' + f' {len(job_resources.resources)}.' ) - if (['DataprocBatch', 'DataprocLro'] != - sorted([r.resource_type for r in job_resources.resources])): - raise ValueError('gcp_resources should contain a' - 'DataprocLro resource and a DataprocBatch resource') + if ['DataprocBatch', 'DataprocLro'] != sorted( + [r.resource_type for r in job_resources.resources] + ): + raise ValueError( + 'gcp_resources should contain a' + 'DataprocLro resource and a DataprocBatch resource' + ) for resource in job_resources.resources: if resource.resource_type == 'DataprocLro': @@ -207,10 +226,12 @@ def check_if_operation_exists(self) -> Union[Dict[str, Any], None]: matched_region = match.group('region') matched_operation_id = match.group('operation') except AttributeError as err: - raise ValueError('Invalid Resource uri: {}. Expect: {}.'.format( - resource.resource_uri, - 'https://dataproc.googleapis.com/v1/projects/[projectId]/regions/[region]/operations/[operationId]' - )) from err + raise ValueError( + 'Invalid Resource uri: {}. Expect: {}.'.format( + resource.resource_uri, + 'https://dataproc.googleapis.com/v1/projects/[projectId]/regions/[region]/operations/[operationId]', + ) + ) from err # Get the long-running Operation resource. 
lro = self._get_resource(resource.resource_uri) @@ -240,15 +261,15 @@ def wait_for_batch( lro_name = lro['name'] lro_uri = f'{_DATAPROC_URI_PREFIX}/{lro_name}' with execution_context.ExecutionContext( - on_cancel=lambda: self._cancel_batch(lro_name)): + on_cancel=lambda: self._cancel_batch(lro_name) + ): while ('done' not in lro) or (not lro['done']): time.sleep(poll_interval_seconds) lro = self._get_resource(lro_uri) logging.info('Polled operation: %s', lro_name) if 'error' in lro and lro['error']['code']: - raise RuntimeError( - 'Operation failed. Error: {}'.format(lro['error'])) + raise RuntimeError('Operation failed. Error: {}'.format(lro['error'])) else: logging.info('Operation complete: %s', lro) return lro @@ -292,6 +313,7 @@ def create_batch( return lro + def create_batch( type: str, project: str, @@ -317,12 +339,16 @@ def create_batch( """ try: batch_request = json_util.recursive_remove_empty( - json.loads(insert_system_labels_into_payload(payload), strict=False)) + json.loads(insert_system_labels_into_payload(payload), strict=False) + ) except json.decoder.JSONDecodeError as err: - raise RuntimeError('Failed to decode JSON from payload: {}' - .format(err.doc)) from err + raise RuntimeError( + 'Failed to decode JSON from payload: {}'.format(err.doc) + ) from err - remote_runner = DataprocBatchRemoteRunner(type, project, location, gcp_resources) + remote_runner = DataprocBatchRemoteRunner( + type, project, location, gcp_resources + ) lro = remote_runner.check_if_operation_exists() if not lro: if not batch_id: diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/launcher.py index 29d75954f0b..70ca2e7c6f4 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/launcher.py @@ -31,7 +31,8 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/remote_runner.py index 743843421b4..eddf133642e 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/create_endpoint/remote_runner.py @@ -33,20 +33,27 @@ def create_endpoint( """Create endpoint and poll the LongRunningOperator till it reaches a final state.""" api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' - create_endpoint_url = f'{vertex_uri_prefix}projects/{project}/locations/{location}/endpoints' + create_endpoint_url = ( + f'{vertex_uri_prefix}projects/{project}/locations/{location}/endpoints' + ) endpoint_spec = json.loads(payload, strict=False) endpoint_spec[_LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - endpoint_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in - endpoint_spec else {}) + endpoint_spec[_LABELS_PAYLOAD_KEY] + if _LABELS_PAYLOAD_KEY in endpoint_spec + else {} + ) create_endpoint_request = json_util.recursive_remove_empty(endpoint_spec) remote_runner = lro_remote_runner.LroRemoteRunner(location) create_endpoint_lro = remote_runner.create_lro( - create_endpoint_url, json.dumps(create_endpoint_request), gcp_resources) + create_endpoint_url, json.dumps(create_endpoint_request), gcp_resources + ) create_endpoint_lro = remote_runner.poll_lro(lro=create_endpoint_lro) endpoint_resource_name = create_endpoint_lro['response']['name'] - vertex_endpoint = VertexEndpoint('endpoint', - vertex_uri_prefix + endpoint_resource_name, - endpoint_resource_name) + vertex_endpoint = VertexEndpoint( + 'endpoint', + vertex_uri_prefix + endpoint_resource_name, + endpoint_resource_name, + ) artifact_util.update_output_artifacts(executor_input, [vertex_endpoint]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/delete_endpoint/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/delete_endpoint/remote_runner.py index 22d2f4bd1de..30b687e7047 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/delete_endpoint/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/delete_endpoint/remote_runner.py @@ -20,6 +20,7 @@ _ENDPOINT_NAME_TEMPLATE = r'(projects/(?P.*)/locations/(?P.*)/endpoints/(?P.*))' + def delete_endpoint( type, project, @@ -27,13 +28,14 @@ def delete_endpoint( payload, gcp_resources, ): + """Delete endpoint and poll the LongRunningOperator till it reaches a final + + state. """ - Delete endpoint and poll the LongRunningOperator till it reaches a final - state. - """ # TODO(IronPan) temporarily remove the empty fields from the spec delete_endpoint_request = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) endpoint_name = delete_endpoint_request['endpoint'] uri_pattern = re.compile(_ENDPOINT_NAME_TEMPLATE) @@ -42,16 +44,21 @@ def delete_endpoint( location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. 
- raise ValueError('Invalid endpoint name: {}. Expect: {}.'.format( - endpoint_name, - 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]')) - delete_endpoint_url = (f'https://{location}-aiplatform.googleapis.com/v1/' - f'{endpoint_name}') + raise ValueError( + 'Invalid endpoint name: {}. Expect: {}.'.format( + endpoint_name, + 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]', + ) + ) + delete_endpoint_url = ( + f'https://{location}-aiplatform.googleapis.com/v1/{endpoint_name}' + ) try: remote_runner = lro_remote_runner.LroRemoteRunner(location) - delete_endpoint_lro = remote_runner.create_lro(delete_endpoint_url, '', - gcp_resources, 'delete') + delete_endpoint_lro = remote_runner.create_lro( + delete_endpoint_url, '', gcp_resources, 'delete' + ) delete_endpoint_lro = remote_runner.poll_lro(lro=delete_endpoint_lro) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/deploy_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/deploy_model/remote_runner.py index 9f90bb90f80..b64ba576ca1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/deploy_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/deploy_model/remote_runner.py @@ -20,6 +20,7 @@ _ENDPOINT_NAME_TEMPLATE = r'(projects/(?P.*)/locations/(?P.*)/endpoints/(?P.*))' + def deploy_model( type, project, @@ -30,7 +31,8 @@ def deploy_model( """Deploy model and poll the LongRunningOperator till it reaches a final state.""" # TODO(IronPan) temporarily remove the empty fields from the spec deploy_model_request = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) endpoint_name = deploy_model_request['endpoint'] uri_pattern = re.compile(_ENDPOINT_NAME_TEMPLATE) @@ -39,9 +41,12 @@ def deploy_model( location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid endpoint name: {}. Expect: {}.'.format( - endpoint_name, - 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]')) + raise ValueError( + 'Invalid endpoint name: {}. 
Expect: {}.'.format( + endpoint_name, + 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]', + ) + ) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' deploy_model_url = f'{vertex_uri_prefix}{endpoint_name}:deployModel' @@ -49,7 +54,8 @@ def deploy_model( try: remote_runner = lro_remote_runner.LroRemoteRunner(location) deploy_model_lro = remote_runner.create_lro( - deploy_model_url, json.dumps(deploy_model_request), gcp_resources) + deploy_model_url, json.dumps(deploy_model_request), gcp_resources + ) deploy_model_lro = remote_runner.poll_lro(lro=deploy_model_lro) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/undeploy_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/undeploy_model/remote_runner.py index ed7c9564767..65e79e3958d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/undeploy_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/endpoint/undeploy_model/remote_runner.py @@ -31,12 +31,11 @@ def undeploy_model( payload, gcp_resources, ): - """ - Undeploy a model from the endpoint and poll the LongRunningOperator till it reaches a final state. - """ + """Undeploy a model from the endpoint and poll the LongRunningOperator till it reaches a final state.""" # TODO(IronPan) temporarily remove the empty fields from the spec undeploy_model_request = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) endpoint_name = undeploy_model_request['endpoint'] # Get the endpoint where the model is deployed to @@ -46,9 +45,12 @@ def undeploy_model( location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid endpoint name: {}. Expect: {}.'.format( - endpoint_name, - 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]')) + raise ValueError( + 'Invalid endpoint name: {}. Expect: {}.'.format( + endpoint_name, + 'projects/[project_id]/locations/[location]/endpoints/[endpoint_id]', + ) + ) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' @@ -69,8 +71,9 @@ def undeploy_model( if not deployed_model_id: # TODO(ruifang) propagate the error. - raise ValueError('Model {} not found at endpoint {}.'.format( - model_name, endpoint_name)) + raise ValueError( + 'Model {} not found at endpoint {}.'.format(model_name, endpoint_name) + ) # Undeploy the model undeploy_model_lro_request = { @@ -78,7 +81,8 @@ def undeploy_model( } if 'traffic_split' in undeploy_model_request: undeploy_model_lro_request['traffic_split'] = undeploy_model_request[ - 'traffic_split'] + 'traffic_split' + ] model_uri_pattern = re.compile(_MODEL_NAME_TEMPLATE) match = model_uri_pattern.match(model_name) @@ -86,18 +90,24 @@ def undeploy_model( location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid model name: {}. Expect: {}.'.format( - model_name, - 'projects/[project_id]/locations/[location]/models/[model_id]')) + raise ValueError( + 'Invalid model name: {}. 
Expect: {}.'.format( + model_name, + 'projects/[project_id]/locations/[location]/models/[model_id]', + ) + ) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' undeploy_model_url = f'{vertex_uri_prefix}{endpoint_name}:undeployModel' undeploy_model_remote_runner = lro_remote_runner.LroRemoteRunner(location) undeploy_model_lro = undeploy_model_remote_runner.create_lro( - undeploy_model_url, json.dumps(undeploy_model_lro_request), - gcp_resources) + undeploy_model_url, + json.dumps(undeploy_model_lro_request), + gcp_resources, + ) undeploy_model_lro = undeploy_model_remote_runner.poll_lro( - lro=undeploy_model_lro) + lro=undeploy_model_lro + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/job_remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/job_remote_runner.py index d294417f56b..f8bf40a93cd 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/job_remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/job_remote_runner.py @@ -55,7 +55,7 @@ _JOB_USER_ERROR_CODES = ( 3, # INVALID_ARGUMENT 5, # NOT_FOUND - 7, # PERMISSION_DENIED + 7, # PERMISSION_DENIED 6, # ALREADY_EXISTS 9, # FAILED_PRECONDITION 11, # OUT_OF_RANGE @@ -63,7 +63,7 @@ ) -class JobRemoteRunner(): +class JobRemoteRunner: """Common module for creating and poll jobs on the Vertex Platform.""" def __init__(self, job_type, project, location, gcp_resources): @@ -76,28 +76,36 @@ def __init__(self, job_type, project, location, gcp_resources): 'api_endpoint': location + '-aiplatform.googleapis.com' } self.client_info = gapic_v1.client_info.ClientInfo( - user_agent='google-cloud-pipeline-components') + user_agent='google-cloud-pipeline-components' + ) self.job_client = aiplatform.gapic.JobServiceClient( - client_options=self.client_options, client_info=self.client_info) + client_options=self.client_options, client_info=self.client_info + ) self.job_uri_prefix = f"https://{self.client_options['api_endpoint']}/v1/" self.poll_job_name = '' def check_if_job_exists(self) -> Optional[str]: """Check if the job already exists.""" - if path.exists( - self.gcp_resources) and os.stat(self.gcp_resources).st_size != 0: + if ( + path.exists(self.gcp_resources) + and os.stat(self.gcp_resources).st_size != 0 + ): with open(self.gcp_resources) as f: serialized_gcp_resources = f.read() - job_resources = json_format.Parse(serialized_gcp_resources, - GcpResources()) + job_resources = json_format.Parse( + serialized_gcp_resources, GcpResources() + ) # Resources should only contain one item. if len(job_resources.resources) != 1: raise ValueError( - f'gcp_resources should contain one resource, found {len(job_resources.resources)}' + 'gcp_resources should contain one resource, found' + f' {len(job_resources.resources)}' ) - job_name_group = re.findall(f'{self.job_uri_prefix}(.*)', - job_resources.resources[0].resource_uri) + job_name_group = re.findall( + f'{self.job_uri_prefix}(.*)', + job_resources.resources[0].resource_uri, + ) if not job_name_group or not job_name_group[0]: raise ValueError( @@ -105,8 +113,11 @@ def check_if_job_exists(self) -> Optional[str]: ) job_name = job_name_group[0] - logging.info('%s name already exists: %s. 
Continue polling the status', - self.job_type, job_name) + logging.info( + '%s name already exists: %s. Continue polling the status', + self.job_type, + job_name, + ) return job_name else: return None @@ -116,7 +127,8 @@ def create_job(self, create_job_fn, payload) -> str: parent = f'projects/{self.project}/locations/{self.location}' # TODO(kevinbnaughton) remove empty fields from the spec temporarily. job_spec = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) create_job_response = create_job_fn(self.job_client, parent, job_spec) job_name = create_job_response.name @@ -134,7 +146,8 @@ def create_job(self, create_job_fn, payload) -> str: def poll_job(self, get_job_fn, job_name: str): """Poll the job status.""" with execution_context.ExecutionContext( - on_cancel=lambda: self.send_cancel_request(job_name)): + on_cancel=lambda: self.send_cancel_request(job_name) + ): retry_count = 0 while True: try: @@ -145,43 +158,63 @@ def poll_job(self, get_job_fn, job_name: str): retry_count += 1 if retry_count < _CONNECTION_ERROR_RETRY_LIMIT: logging.warning( - 'ConnectionError (%s) encountered when polling job: %s. Trying to ' - 'recreate the API client.', err, job_name) + ( + 'ConnectionError (%s) encountered when polling job: %s.' + ' Trying to recreate the API client.' + ), + err, + job_name, + ) # Recreate the Python API client. self.job_client = aiplatform.gapic.JobServiceClient( - self.client_options, self.client_info) + self.client_options, self.client_info + ) else: # TODO(ruifang) propagate the error. """Exit with an internal error code.""" error_util.exit_with_internal_error( 'Request failed after %s retries.'.format( - _CONNECTION_ERROR_RETRY_LIMIT)) + _CONNECTION_ERROR_RETRY_LIMIT + ) + ) if get_job_response.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED: - logging.info('Get%s response state =%s', self.job_type, - get_job_response.state) + logging.info( + 'Get%s response state =%s', self.job_type, get_job_response.state + ) return get_job_response elif get_job_response.state in _JOB_ERROR_STATES: # TODO(ruifang) propagate the error. if get_job_response.error.code in _JOB_USER_ERROR_CODES: raise ValueError( 'Job failed with value error in error state: {}.'.format( - get_job_response.state)) + get_job_response.state + ) + ) else: - raise RuntimeError('Job failed with error state: {}.'.format( - get_job_response.state)) + raise RuntimeError( + 'Job failed with error state: {}.'.format( + get_job_response.state + ) + ) else: logging.info( - 'Job %s is in a non-final state %s.' - ' Waiting for %s seconds for next poll.', job_name, - get_job_response.state, _POLLING_INTERVAL_IN_SECONDS) + ( + 'Job %s is in a non-final state %s.' + ' Waiting for %s seconds for next poll.' 
+ ), + job_name, + get_job_response.state, + _POLLING_INTERVAL_IN_SECONDS, + ) time.sleep(_POLLING_INTERVAL_IN_SECONDS) def send_cancel_request(self, job_name: str): if not job_name: return creds, _ = google.auth.default( - scopes=['https://www.googleapis.com/auth/cloud-platform']) + scopes=['https://www.googleapis.com/auth/cloud-platform'] + ) if not creds.valid: creds.refresh(google.auth.transport.requests.Request()) headers = { @@ -193,4 +226,5 @@ def send_cancel_request(self, job_name: str): # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs/cancel # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs/cancel requests.post( - url=f'{self.job_uri_prefix}{job_name}:cancel', data='', headers=headers) + url=f'{self.job_uri_prefix}{job_name}:cancel', data='', headers=headers + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/lro_remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/lro_remote_runner.py index 60c67415eba..3086994fdcf 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/lro_remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/lro_remote_runner.py @@ -35,7 +35,8 @@ 409, # Conflict ) -class LroRemoteRunner(): + +class LroRemoteRunner: """Common module for creating and poll LRO.""" def __init__(self, location) -> None: @@ -44,11 +45,13 @@ def __init__(self, location) -> None: self.creds, _ = google.auth.default() self.poll_lro_name = '' - def request(self, - request_url: str, - request_body: str, - http_request: str = 'post', - user_agent: str = 'google-cloud-pipeline-components') -> Any: + def request( + self, + request_url: str, + request_body: str, + http_request: str = 'post', + user_agent: str = 'google-cloud-pipeline-components', + ) -> Any: """Call the HTTP request""" if not self.creds.valid: self.creds.refresh(google.auth.transport.requests.Request()) @@ -61,23 +64,28 @@ def request(self, http_request_fn = getattr(requests, http_request) response = http_request_fn( - url=request_url, data=request_body, headers=headers).json() + url=request_url, data=request_body, headers=headers + ).json() if 'error' in response and response['error']['code']: if response['error']['code'] in _LRO_USER_ERROR_CODES: - raise ValueError('Failed to create the resource. Error: {}'.format( - response['error'])) + raise ValueError( + 'Failed to create the resource. Error: {}'.format(response['error']) + ) else: - raise RuntimeError('Failed to create the resource. Error: {}'.format( - response['error'])) + raise RuntimeError( + 'Failed to create the resource. 
Error: {}'.format(response['error']) + ) return response - def create_lro(self, - create_url: str, - request_body: str, - gcp_resources: str, - http_request: str = 'post') -> Any: + def create_lro( + self, + create_url: str, + request_body: str, + gcp_resources: str, + http_request: str = 'post', + ) -> Any: """call the create API and get a LRO""" # Currently we don't check if operation already exists and continue from there @@ -92,7 +100,8 @@ def create_lro(self, lro = self.request( request_url=create_url, request_body=request_body, - http_request=http_request) + http_request=http_request, + ) lro_name = lro['name'] get_operation_uri = f'{self.vertex_uri_prefix}{lro_name}' @@ -111,7 +120,8 @@ def poll_lro(self, lro: Any) -> Any: """Poll the LRO till it reaches a final state.""" lro_name = lro['name'] with execution_context.ExecutionContext( - on_cancel=lambda: self.send_cancel_request(lro_name)): + on_cancel=lambda: self.send_cancel_request(lro_name) + ): request_url = f'{self.vertex_uri_prefix}{lro_name}' while ('done' not in lro) or (not lro['done']): time.sleep(_POLLING_INTERVAL_IN_SECONDS) @@ -120,7 +130,8 @@ def poll_lro(self, lro: Any) -> Any: request_url=request_url, request_body='', http_request='get', - user_agent='') + user_agent='', + ) logging.info('Create resource complete. %s.', lro) return lro diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/artifact_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/artifact_util.py index dbf959f2c21..8c63b1bfae8 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/artifact_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/artifact_util.py @@ -18,16 +18,18 @@ from kfp import dsl -def update_output_artifact(executor_input: str, - target_artifact_name: str, - uri: str, - metadata: dict = {}): +def update_output_artifact( + executor_input: str, + target_artifact_name: str, + uri: str, + metadata: dict = {}, +): """Updates the output artifact with the new uri and metadata.""" executor_input_json = json.loads(executor_input) executor_output = {'artifacts': {}} - for name, artifacts in executor_input_json.get('outputs', - {}).get('artifacts', - {}).items(): + for name, artifacts in ( + executor_input_json.get('outputs', {}).get('artifacts', {}).items() + ): artifacts_list = artifacts.get('artifacts') if name == target_artifact_name and artifacts_list: updated_runtime_artifact = artifacts_list[0] @@ -40,7 +42,8 @@ def update_output_artifact(executor_input: str, # update the output artifacts. os.makedirs( os.path.dirname(executor_input_json['outputs']['outputFile']), - exist_ok=True) + exist_ok=True, + ) with open(executor_input_json['outputs']['outputFile'], 'w') as f: f.write(json.dumps(executor_output)) @@ -54,19 +57,20 @@ def update_output_artifacts(executor_input: str, artifacts: list): # This assumes that no other output artifact exists. 
for artifact in artifacts: if artifact.name in output_artifacts.keys(): - # Converts the artifact into executor output artifact - # https://github.com/kubeflow/pipelines/blob/master/api/v2alpha1/pipeline_spec.proto#L878 - artifacts_list = output_artifacts[artifact.name].get('artifacts') - if artifacts_list: - updated_runtime_artifact = artifacts_list[0] - updated_runtime_artifact['uri'] = artifact.uri - updated_runtime_artifact['metadata'] = artifact.metadata - artifacts_list = {'artifacts': [updated_runtime_artifact]} - executor_output['artifacts'][artifact.name] = artifacts_list + # Converts the artifact into executor output artifact + # https://github.com/kubeflow/pipelines/blob/master/api/v2alpha1/pipeline_spec.proto#L878 + artifacts_list = output_artifacts[artifact.name].get('artifacts') + if artifacts_list: + updated_runtime_artifact = artifacts_list[0] + updated_runtime_artifact['uri'] = artifact.uri + updated_runtime_artifact['metadata'] = artifact.metadata + artifacts_list = {'artifacts': [updated_runtime_artifact]} + executor_output['artifacts'][artifact.name] = artifacts_list # update the output artifacts. os.makedirs( os.path.dirname(executor_input_json['outputs']['outputFile']), - exist_ok=True) + exist_ok=True, + ) with open(executor_input_json['outputs']['outputFile'], 'w') as f: f.write(json.dumps(executor_output)) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/gcp_labels_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/gcp_labels_util.py index 76c1d3c21a9..1404f9d7b14 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/gcp_labels_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/gcp_labels_util.py @@ -31,7 +31,7 @@ def attach_system_labels(existing_labels=None): Args: existing_labels: Optional[dict[str,str]]. If provided, will combine with the - system labels read from the environmental variable. + system labels read from the environmental variable. Returns: Optional[dict[str,str]] The combined labels, or None if existing_labels is diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/json_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/json_util.py index bd401e20149..758bc1a7e12 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/json_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/json_util.py @@ -41,8 +41,12 @@ def __remove_empty(j): res = [] for i in j: # Don't remove empty primitive types. Only remove other empty types. - if isinstance(i, int) or isinstance(i, float) or isinstance( - i, str) or isinstance(i, bool): + if ( + isinstance(i, int) + or isinstance(i, float) + or isinstance(i, str) + or isinstance(i, bool) + ): res.append(i) elif __remove_empty(i): res.append(__remove_empty(i)) @@ -62,11 +66,13 @@ def recursive_remove_empty(j): # Handle special case where an empty "explanation_spec" "metadata" "outputs" # should not be removed. Introduced for b/245453693. 
temp_explanation_spec_metadata_outputs = None - if ('explanation_spec' - in j) and ('metadata' in j['explanation_spec'] and - 'outputs' in j['explanation_spec']['metadata']): + if ('explanation_spec' in j) and ( + 'metadata' in j['explanation_spec'] + and 'outputs' in j['explanation_spec']['metadata'] + ): temp_explanation_spec_metadata_outputs = j['explanation_spec']['metadata'][ - 'outputs'] + 'outputs' + ] needs_update = True while needs_update: @@ -80,6 +86,7 @@ def recursive_remove_empty(j): if 'metadata' not in j['explanation_spec']: j['explanation_spec']['metadata'] = {} j['explanation_spec']['metadata'][ - 'outputs'] = temp_explanation_spec_metadata_outputs + 'outputs' + ] = temp_explanation_spec_metadata_outputs return j diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/parser_util.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/parser_util.py index a4d960c3ccd..a1f2e0f48a4 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/parser_util.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/gcp_launcher/utils/parser_util.py @@ -25,33 +25,39 @@ def _make_parent_dirs_and_return_path(file_path: str): def parse_default_args(args): """Parse default command line arguments.""" parser = argparse.ArgumentParser( - prog='Vertex Pipelines service launcher', description='') + prog='Vertex Pipelines service launcher', description='' + ) parser.add_argument( - '--type', dest='type', type=str, required=True, default=argparse.SUPPRESS) + '--type', dest='type', type=str, required=True, default=argparse.SUPPRESS + ) parser.add_argument( '--project', dest='project', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--location', dest='location', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--payload', dest='payload', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( '--gcp_resources', dest='gcp_resources', type=_make_parent_dirs_and_return_path, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return (parser, parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/launcher.py index e15e91286bb..ba028ec5d38 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/launcher.py @@ -29,7 +29,8 @@ def _parse_args(args): dest='execution_metrics', type=str, required=False, - default=None) + default=None, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) @@ -54,7 +55,8 @@ def main(argv): job_type = parsed_args['type'] if job_type not in [ - 'HyperparameterTuningJob', 'HyperparameterTuningJobWithMetrics' + 'HyperparameterTuningJob', + 'HyperparameterTuningJobWithMetrics', ]: raise ValueError('Incorrect job type: ' + job_type) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/remote_runner.py 
b/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/remote_runner.py index 9890bb7f5cf..25ffd1541d7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/hyperparameter_tuning_job/remote_runner.py @@ -28,8 +28,11 @@ def create_hyperparameter_tuning_job_with_client(job_client, parent, job_spec): create_hyperparameter_tuning_job_fn = None try: - create_hyperparameter_tuning_job_fn = job_client.create_hyperparameter_tuning_job( - parent=parent, hyperparameter_tuning_job=job_spec) + create_hyperparameter_tuning_job_fn = ( + job_client.create_hyperparameter_tuning_job( + parent=parent, hyperparameter_tuning_job=job_spec + ) + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return create_hyperparameter_tuning_job_fn @@ -41,7 +44,9 @@ def get_hyperparameter_tuning_job_with_client(job_client, job_name): get_hyperparameter_tuning_job_fn = job_client.get_hyperparameter_tuning_job( name=job_name, retry=retry.Retry( - deadline=_HYPERPARAMETER_TUNING_JOB_RETRY_DEADLINE_SECONDS)) + deadline=_HYPERPARAMETER_TUNING_JOB_RETRY_DEADLINE_SECONDS + ), + ) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) return get_hyperparameter_tuning_job_fn @@ -71,29 +76,34 @@ def create_hyperparameter_tuning_job( Also retry on ConnectionError up to job_remote_runner._CONNECTION_ERROR_RETRY_LIMIT times during the poll. """ - remote_runner = job_remote_runner.JobRemoteRunner(type, project, location, - gcp_resources) + remote_runner = job_remote_runner.JobRemoteRunner( + type, project, location, gcp_resources + ) job_spec = json.loads(payload) job_spec[_LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - job_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in job_spec else {}) + job_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in job_spec else {} + ) try: # Create HP Tuning job if it does not exist job_name = remote_runner.check_if_job_exists() if job_name is None: job_name = remote_runner.create_job( - create_hyperparameter_tuning_job_with_client, json.dumps(job_spec)) + create_hyperparameter_tuning_job_with_client, json.dumps(job_spec) + ) # Poll HP Tuning job status until "JobState.JOB_STATE_SUCCEEDED" get_job_response = remote_runner.poll_job( - get_hyperparameter_tuning_job_with_client, job_name) + get_hyperparameter_tuning_job_with_client, job_name + ) if type == 'HyperparameterTuningJobWithMetrics': completed_trials = [ - t for t in get_job_response.trials + t + for t in get_job_response.trials if t.state == gca_study.Trial.State.SUCCEEDED ] execution_metrics_dict = { 'success_trials_count': len(completed_trials), - 'total_trials_count': len(get_job_response.trials) + 'total_trials_count': len(get_job_response.trials), } with open(execution_metrics, 'w') as f: f.write(json.dumps(execution_metrics_dict, sort_keys=True)) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/launcher.py index 0a7ed581a13..e4e2750b9da 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/launcher.py @@ -30,7 +30,8 @@ def 
_parse_args(args): dest='executor_input', type=str, required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/remote_runner.py index c38240cdfef..cae90bad333 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/infra_validation_job/remote_runner.py @@ -25,13 +25,21 @@ def construct_infra_validation_job_payload(executor_input, payload): """Construct infra validation payload for CustomJob.""" # Extract artifact uri and prediction server uri - artifact = json.loads(executor_input).get('inputs', {}).get( - 'artifacts', {}).get(ARTIFACT_PROPERTY_KEY_UNMANAGED_CONTAINER_MODEL, - {}).get('artifacts') + artifact = ( + json.loads(executor_input) + .get('inputs', {}) + .get('artifacts', {}) + .get(ARTIFACT_PROPERTY_KEY_UNMANAGED_CONTAINER_MODEL, {}) + .get('artifacts') + ) if artifact: model_artifact_path = artifact[0].get('uri') - prediction_server_image_uri = artifact[0].get('metadata', {}).get( - API_KEY_CONTAINER_SPEC, {}).get('imageUri', '') + prediction_server_image_uri = ( + artifact[0] + .get('metadata', {}) + .get(API_KEY_CONTAINER_SPEC, {}) + .get('imageUri', '') + ) else: raise ValueError('unmanaged_container_model not found in executor_input.') @@ -41,10 +49,12 @@ def construct_infra_validation_job_payload(executor_input, payload): # extract infra validation example path infra_validation_example_path = payload_json.get( - 'infra_validation_example_path') + 'infra_validation_example_path' + ) if infra_validation_example_path: args.extend( - ['--infra_validation_example_path', infra_validation_example_path]) + ['--infra_validation_example_path', infra_validation_example_path] + ) env_variables = [{'name': 'INFRA_VALIDATION_MODE', 'value': '1'}] @@ -60,7 +70,7 @@ def construct_infra_validation_job_payload(executor_input, payload): 'container_spec': { 'image_uri': prediction_server_image_uri, 'args': args, - 'env': env_variables + 'env': env_variables, }, }] }, @@ -70,12 +80,7 @@ def construct_infra_validation_job_payload(executor_input, payload): def create_infra_validation_job( - type, - project, - location, - gcp_resources, - executor_input, - payload + type, project, location, gcp_resources, executor_input, payload ): """Create and poll infra validation job status till it reaches a final state. 
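Note (illustrative, not part of the patch): `construct_infra_validation_job_payload` above pulls the model artifact URI and the serving image out of `executor_input` before assembling the CustomJob worker pool spec. A minimal sketch of that lookup, assuming `ARTIFACT_PROPERTY_KEY_UNMANAGED_CONTAINER_MODEL` and `API_KEY_CONTAINER_SPEC` resolve to `'unmanaged_container_model'` and `'containerSpec'`:

```python
# Sketch only: key names are assumptions for the constants referenced above.
import json
from typing import Tuple


def extract_model_and_image(executor_input: str) -> Tuple[str, str]:
  """Returns (model_artifact_uri, prediction_server_image_uri)."""
  artifacts = (
      json.loads(executor_input)
      .get('inputs', {})
      .get('artifacts', {})
      .get('unmanaged_container_model', {})
      .get('artifacts')
  )
  if not artifacts:
    raise ValueError('unmanaged_container_model not found in executor_input.')
  model_uri = artifacts[0].get('uri')
  image_uri = (
      artifacts[0].get('metadata', {}).get('containerSpec', {}).get('imageUri', ''))
  return model_uri, image_uri
```

The extracted image URI becomes the `container_spec.image_uri` of the single worker pool, with `INFRA_VALIDATION_MODE=1` set in its environment, as shown in the hunk above.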
@@ -96,9 +101,10 @@ def create_infra_validation_job( """ try: infra_validation_job_payload = construct_infra_validation_job_payload( - executor_input, payload) - create_custom_job(type, project, location, infra_validation_job_payload, - gcp_resources) + executor_input, payload + ) + create_custom_job( + type, project, location, infra_validation_job_payload, gcp_resources + ) except (ConnectionError, RuntimeError, ValueError) as err: error_util.exit_with_internal_error(err.args[0]) - diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/delete_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/delete_model/remote_runner.py index fd8e38f3e58..4798fbc6305 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/delete_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/delete_model/remote_runner.py @@ -31,7 +31,8 @@ def delete_model( """Delete model and poll the LongRunningOperator till it reaches a final state.""" # TODO(IronPan) temporarily remove the empty fields from the spec delete_model_request = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) model_name = delete_model_request['model'] uri_pattern = re.compile(_MODEL_NAME_TEMPLATE) @@ -40,17 +41,21 @@ def delete_model( location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid model name: {}. Expect: {}.'.format( - model_name, - 'projects/[project_id]/locations/[location]/models/[model_id]')) + raise ValueError( + 'Invalid model name: {}. Expect: {}.'.format( + model_name, + 'projects/[project_id]/locations/[location]/models/[model_id]', + ) + ) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' delete_model_url = f'{vertex_uri_prefix}{model_name}' try: remote_runner = lro_remote_runner.LroRemoteRunner(location) - delete_model_lro = remote_runner.create_lro(delete_model_url, '', - gcp_resources, 'delete') + delete_model_lro = remote_runner.create_lro( + delete_model_url, '', gcp_resources, 'delete' + ) delete_model_lro = remote_runner.poll_lro(lro=delete_model_lro) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/launcher.py index 29915f035c0..f0b571d218b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/launcher.py @@ -30,7 +30,8 @@ def _parse_args(args): type=str, # output_info is only needed for ExportModel component. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/remote_runner.py index c4dc626be52..40b881f790a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/export_model/remote_runner.py @@ -26,7 +26,8 @@ def export_model(type, project, location, payload, gcp_resources, output_info): """Export model and poll the LongRunningOperator till it reaches a final state.""" # TODO(IronPan) temporarily remove the empty fields from the spec export_model_request = json_util.recursive_remove_empty( - json.loads(payload, strict=False)) + json.loads(payload, strict=False) + ) model_name = export_model_request['name'] uri_pattern = re.compile(_MODEL_NAME_TEMPLATE) @@ -35,9 +36,12 @@ def export_model(type, project, location, payload, gcp_resources, output_info): location = match.group('location') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid model name: {}. Expect: {}.'.format( - model_name, - 'projects/[project_id]/locations/[location]/models/[model_id]')) + raise ValueError( + 'Invalid model name: {}. Expect: {}.'.format( + model_name, + 'projects/[project_id]/locations/[location]/models/[model_id]', + ) + ) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' @@ -46,7 +50,8 @@ def export_model(type, project, location, payload, gcp_resources, output_info): try: remote_runner = lro_remote_runner.LroRemoteRunner(location) export_model_lro = remote_runner.create_lro( - export_model_url, json.dumps(export_model_request), gcp_resources) + export_model_url, json.dumps(export_model_request), gcp_resources + ) export_model_lro = remote_runner.poll_lro(lro=export_model_lro) output_info_content = export_model_lro['metadata']['outputInfo'] with open(output_info, 'w') as f: diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/launcher.py index fe314e66960..b589f00891c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/launcher.py @@ -30,12 +30,11 @@ def _parse_args(args): type=str, # executor_input is only needed for components that emit output artifacts. 
required=True, - default=argparse.SUPPRESS) + default=argparse.SUPPRESS, + ) parser.add_argument( - '--parent_model_name', - dest='parent_model_name', - type=str, - default=None) + '--parent_model_name', dest='parent_model_name', type=str, default=None + ) parsed_args, _ = parser.parse_known_args(args) return vars(parsed_args) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/remote_runner.py index 0ab59aa5b4d..34e3975e713 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/upload_model/remote_runner.py @@ -31,16 +31,24 @@ def append_unmanaged_model_artifact_into_payload(executor_input, model_spec): - artifact = json.loads(executor_input).get('inputs', {}).get( - 'artifacts', {}).get(ARTIFACT_PROPERTY_KEY_UNMANAGED_CONTAINER_MODEL, - {}).get('artifacts') + artifact = ( + json.loads(executor_input) + .get('inputs', {}) + .get('artifacts', {}) + .get(ARTIFACT_PROPERTY_KEY_UNMANAGED_CONTAINER_MODEL, {}) + .get('artifacts') + ) if artifact: - model_spec[ - API_KEY_PREDICT_SCHEMATA] = json_util.camel_case_to_snake_case_recursive( - artifact[0].get('metadata', {}).get('predictSchemata', {})) - model_spec[ - API_KEY_CONTAINER_SPEC] = json_util.camel_case_to_snake_case_recursive( - artifact[0].get('metadata', {}).get('containerSpec', {})) + model_spec[API_KEY_PREDICT_SCHEMATA] = ( + json_util.camel_case_to_snake_case_recursive( + artifact[0].get('metadata', {}).get('predictSchemata', {}) + ) + ) + model_spec[API_KEY_CONTAINER_SPEC] = ( + json_util.camel_case_to_snake_case_recursive( + artifact[0].get('metadata', {}).get('containerSpec', {}) + ) + ) model_spec[API_KEY_ARTIFACT_URI] = artifact[0].get('uri') return model_spec @@ -60,36 +68,46 @@ def upload_model( upload_model_url = f'{vertex_uri_prefix}projects/{project}/locations/{location}/models:upload' model_spec = json.loads(payload, strict=False) model_spec[_LABELS_PAYLOAD_KEY] = gcp_labels_util.attach_system_labels( - model_spec[_LABELS_PAYLOAD_KEY] if _LABELS_PAYLOAD_KEY in - model_spec else {}) + model_spec[_LABELS_PAYLOAD_KEY] + if _LABELS_PAYLOAD_KEY in model_spec + else {} + ) upload_model_request = { # TODO(IronPan) temporarily remove the empty fields from the spec - 'model': - json_util.recursive_remove_empty( - append_unmanaged_model_artifact_into_payload( - executor_input, model_spec)) + 'model': json_util.recursive_remove_empty( + append_unmanaged_model_artifact_into_payload( + executor_input, model_spec + ) + ) } if parent_model_name: upload_model_request['parent_model'] = parent_model_name.rsplit('@', 1)[0] # Add explanation_spec details back into the request if metadata is non-empty, as sklearn/xgboost input features can be empty. 
- if (('explanation_spec' in model_spec) and - ('metadata' in model_spec['explanation_spec']) and - model_spec['explanation_spec']['metadata']): + if ( + ('explanation_spec' in model_spec) + and ('metadata' in model_spec['explanation_spec']) + and model_spec['explanation_spec']['metadata'] + ): upload_model_request['model']['explanation_spec']['metadata'] = model_spec[ - 'explanation_spec']['metadata'] + 'explanation_spec' + ]['metadata'] try: remote_runner = lro_remote_runner.LroRemoteRunner(location) upload_model_lro = remote_runner.create_lro( - upload_model_url, json.dumps(upload_model_request), gcp_resources) + upload_model_url, json.dumps(upload_model_request), gcp_resources + ) upload_model_lro = remote_runner.poll_lro(lro=upload_model_lro) model_resource_name = upload_model_lro['response']['model'] if 'model_version_id' in upload_model_lro['response']: - model_resource_name += f'@{upload_model_lro["response"]["model_version_id"]}' + model_resource_name += ( + f'@{upload_model_lro["response"]["model_version_id"]}' + ) - vertex_model = VertexModel('model', vertex_uri_prefix + model_resource_name, - model_resource_name) + vertex_model = VertexModel( + 'model', vertex_uri_prefix + model_resource_name, model_resource_name + ) artifact_util.update_output_artifacts(executor_input, [vertex_model]) except (ConnectionError, RuntimeError) as err: error_util.exit_with_internal_error(err.args[0]) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/vertex_notification_email/executor.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/vertex_notification_email/executor.py index 6aadc7e7c82..49fe87e4e3a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/vertex_notification_email/executor.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/vertex_notification_email/executor.py @@ -20,8 +20,9 @@ def main(): The notification email component works only on Vertex Pipelines. This function raises an exception when this component is used on Kubeflow Pipelines. """ - raise NotImplementedError('The notification email component is supported ' - 'only on Vertex Pipelines.') + raise NotImplementedError( + 'The notification email component is supported only on Vertex Pipelines.' + ) if __name__ == '__main__': diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/wait_gcp_resources/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/wait_gcp_resources/remote_runner.py index 4d8b14763b4..6e70b2a4605 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/wait_gcp_resources/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/wait_gcp_resources/remote_runner.py @@ -30,8 +30,11 @@ _JOB_SUCCESSFUL_STATES = ['JOB_STATE_DONE'] _JOB_CANCELLED_STATE = 'JOB_STATE_CANCELLED' _JOB_FAILED_STATES = [ - 'JOB_STATE_STOPPED', 'JOB_STATE_FAILED', _JOB_CANCELLED_STATE, - 'JOB_STATE_UPDATED', 'JOB_STATE_DRAINED' + 'JOB_STATE_STOPPED', + 'JOB_STATE_FAILED', + _JOB_CANCELLED_STATE, + 'JOB_STATE_UPDATED', + 'JOB_STATE_DRAINED', ] _JOB_TERMINATED_STATES = _JOB_SUCCESSFUL_STATES + _JOB_FAILED_STATES _DATAFLOW_URI_TEMPLATE = r'(https://dataflow.googleapis.com/v1b3/projects/(?P.*)/locations/(?P.*)/jobs/(?P.*))' @@ -44,19 +47,19 @@ def wait_gcp_resources( payload, gcp_resources, ): - """ - Poll the gcp resources till it reaches a final state. 
- """ + """Poll the gcp resources till it reaches a final state.""" input_gcp_resources = Parse(payload, GcpResources()) if len(input_gcp_resources.resources) != 1: raise ValueError( - 'Invalid payload: %s. Wait component support waiting on only one resource at this moment.' - % payload) + 'Invalid payload: %s. Wait component support waiting on only one' + ' resource at this moment.' % payload + ) if input_gcp_resources.resources[0].resource_type != 'DataflowJob': raise ValueError( - 'Invalid payload: %s. Wait component only support waiting on Dataflow job at this moment.' - % payload) + 'Invalid payload: %s. Wait component only support waiting on Dataflow' + ' job at this moment.' % payload + ) dataflow_job_uri = input_gcp_resources.resources[0].resource_uri uri_pattern = re.compile(_DATAFLOW_URI_TEMPLATE) @@ -68,10 +71,12 @@ def wait_gcp_resources( job_id = match.group('jobid') except AttributeError as err: # TODO(ruifang) propagate the error. - raise ValueError('Invalid dataflow resource URI: {}. Expect: {}.'.format( - dataflow_job_uri, - 'https://dataflow.googleapis.com/v1b3/projects/[project_id]/locations/[location]/jobs/[job_id]' - )) + raise ValueError( + 'Invalid dataflow resource URI: {}. Expect: {}.'.format( + dataflow_job_uri, + 'https://dataflow.googleapis.com/v1b3/projects/[project_id]/locations/[location]/jobs/[job_id]', + ) + ) # Propagate the GCP resources as the output of the wait component with open(gcp_resources, 'w') as f: @@ -83,63 +88,100 @@ def wait_gcp_resources( project, job_id, location, - )): + ) + ): # Poll the job status retry_count = 0 while True: try: df_client = discovery.build('dataflow', 'v1b3', cache_discovery=False) - job = df_client.projects().locations().jobs().get( - projectId=project, jobId=job_id, location=location, - view=None).execute() + job = ( + df_client.projects() + .locations() + .jobs() + .get(projectId=project, jobId=job_id, location=location, view=None) + .execute() + ) retry_count = 0 except ConnectionError as err: retry_count += 1 if retry_count <= _CONNECTION_ERROR_RETRY_LIMIT: logging.warning( - 'ConnectionError (%s) encountered when polling job: %s. Retrying.', - err, job_id) + ( + 'ConnectionError (%s) encountered when polling job: %s.' + ' Retrying.' + ), + err, + job_id, + ) else: error_util.exit_with_internal_error( 'Request failed after %s retries.'.format( - _CONNECTION_ERROR_RETRY_LIMIT)) + _CONNECTION_ERROR_RETRY_LIMIT + ) + ) job_state = job.get('currentState', None) # Write the job details as gcp_resources if job_state in _JOB_SUCCESSFUL_STATES: - logging.info('GetDataflowJob response state =%s. Job completed', - job_state) + logging.info( + 'GetDataflowJob response state =%s. Job completed', job_state + ) return elif job_state in _JOB_TERMINATED_STATES: # TODO(ruifang) propagate the error. - raise RuntimeError('Job {} failed with error state: {}.'.format( - job_id, job_state)) + raise RuntimeError( + 'Job {} failed with error state: {}.'.format(job_id, job_state) + ) else: logging.info( - 'Job %s is in a non-final state %s. Waiting for %s seconds for next poll.', - job_id, job_state, _POLLING_INTERVAL_IN_SECONDS) + ( + 'Job %s is in a non-final state %s. Waiting for %s seconds for' + ' next poll.' 
+ ), + job_id, + job_state, + _POLLING_INTERVAL_IN_SECONDS, + ) time.sleep(_POLLING_INTERVAL_IN_SECONDS) def _send_cancel_request(project, job_id, location): - logging.info('dataflow_cancelling_job_params: %s, %s, %s', project, job_id, - location) + logging.info( + 'dataflow_cancelling_job_params: %s, %s, %s', project, job_id, location + ) df_client = discovery.build('dataflow', 'v1b3', cache_discovery=False) - job = df_client.projects().locations().jobs().get( - projectId=project, jobId=job_id, location=location, view=None).execute() + job = ( + df_client.projects() + .locations() + .jobs() + .get(projectId=project, jobId=job_id, location=location, view=None) + .execute() + ) # Dataflow cancel API: # https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline#stopping_a_job logging.info('Sending Dataflow cancel request') job['requestedState'] = _JOB_CANCELLED_STATE logging.info('dataflow_cancelling_job: %s', job) - job = df_client.projects().locations().jobs().update( - projectId=project, - jobId=job_id, - location=location, - body=job, - ).execute() + job = ( + df_client.projects() + .locations() + .jobs() + .update( + projectId=project, + jobId=job_id, + location=location, + body=job, + ) + .execute() + ) logging.info('dataflow_cancelled_job: %s', job) - job = df_client.projects().locations().jobs().get( - projectId=project, jobId=job_id, location=location, view=None).execute() + job = ( + df_client.projects() + .locations() + .jobs() + .get(projectId=project, jobId=job_id, location=location, view=None) + .execute() + ) logging.info('dataflow_cancelled_job: %s', job) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/__init__.py index f8796e343ed..93a41195a7d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/__init__.py @@ -26,13 +26,17 @@ ] ProphetTrainerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'prophet_trainer.yaml')) + os.path.join(os.path.dirname(__file__), 'prophet_trainer.yaml') +) ForecastingStage1TunerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'forecasting_stage_1_tuner.yaml')) + os.path.join(os.path.dirname(__file__), 'forecasting_stage_1_tuner.yaml') +) ForecastingEnsembleOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'forecasting_ensemble.yaml')) + os.path.join(os.path.dirname(__file__), 'forecasting_ensemble.yaml') +) ForecastingStage2TunerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'forecasting_stage_2_tuner.yaml')) + os.path.join(os.path.dirname(__file__), 'forecasting_stage_2_tuner.yaml') +) ModelEvaluationForecastingOp = load_component_from_file( os.path.join(os.path.dirname(__file__), 'model_evaluation_forecasting.yaml') ) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/utils.py b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/utils.py index 076a26633fb..6b139763d65 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/forecasting/utils.py @@ -272,7 +272,8 @@ def get_prophet_train_pipeline_and_parameters( 
'dataflow_use_public_ips': dataflow_use_public_ips, } pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), 'prophet_trainer_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'prophet_trainer_pipeline.yaml' + ) return pipeline_definition_path, parameter_values @@ -332,7 +333,8 @@ def get_prophet_prediction_pipeline_and_parameters( 'max_num_workers': max_num_workers, } pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), 'prophet_predict_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'prophet_predict_pipeline.yaml' + ) return pipeline_definition_path, parameter_values diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/__init__.py index 69635787278..4cf8bc8eaf3 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/__init__.py @@ -35,47 +35,66 @@ 'TrainingConfiguratorAndValidatorOp', 'TrainingConfiguratorAndValidatorOp', 'XGBoostHyperparameterTuningJobOp', - 'XGBoostTrainerOp' + 'XGBoostTrainerOp', ] CvTrainerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'cv_trainer.yaml')) + os.path.join(os.path.dirname(__file__), 'cv_trainer.yaml') +) InfraValidatorOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'infra_validator.yaml')) + os.path.join(os.path.dirname(__file__), 'infra_validator.yaml') +) Stage1TunerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'stage_1_tuner.yaml')) + os.path.join(os.path.dirname(__file__), 'stage_1_tuner.yaml') +) EnsembleOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ensemble.yaml')) + os.path.join(os.path.dirname(__file__), 'ensemble.yaml') +) StatsAndExampleGenOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'stats_and_example_gen.yaml')) + os.path.join(os.path.dirname(__file__), 'stats_and_example_gen.yaml') +) FeatureSelectionOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'feature_selection.yaml')) + os.path.join(os.path.dirname(__file__), 'feature_selection.yaml') +) TransformOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'transform.yaml')) + os.path.join(os.path.dirname(__file__), 'transform.yaml') +) FeatureTransformEngineOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'feature_transform_engine.yaml')) + os.path.join(os.path.dirname(__file__), 'feature_transform_engine.yaml') +) SplitMaterializedDataOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'split_materialized_data.yaml')) + os.path.join(os.path.dirname(__file__), 'split_materialized_data.yaml') +) FinalizerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'finalizer.yaml')) + os.path.join(os.path.dirname(__file__), 'finalizer.yaml') +) WideAndDeepTrainerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'wide_and_deep_trainer.yaml')) + os.path.join(os.path.dirname(__file__), 'wide_and_deep_trainer.yaml') +) WideAndDeepHyperparameterTuningJobOp = load_component_from_file( os.path.join( os.path.dirname(__file__), - 'wide_and_deep_hyperparameter_tuning_job.yaml')) + 'wide_and_deep_hyperparameter_tuning_job.yaml', + ) +) TabNetHyperparameterTuningJobOp = 
load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'tabnet_hyperparameter_tuning_job.yaml')) + os.path.dirname(__file__), 'tabnet_hyperparameter_tuning_job.yaml' + ) +) TabNetTrainerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'tabnet_trainer.yaml')) + os.path.join(os.path.dirname(__file__), 'tabnet_trainer.yaml') +) TrainingConfiguratorAndValidatorOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'training_configurator_and_validator.yaml')) + os.path.dirname(__file__), 'training_configurator_and_validator.yaml' + ) +) XGBoostTrainerOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'xgboost_trainer.yaml')) + os.path.join(os.path.dirname(__file__), 'xgboost_trainer.yaml') +) XGBoostHyperparameterTuningJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'xgboost_hyperparameter_tuning_job.yaml')) + os.path.dirname(__file__), 'xgboost_hyperparameter_tuning_job.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/utils.py b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/utils.py index 00894063bd4..9b5f45b020c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/automl/tabular/utils.py @@ -23,11 +23,12 @@ _EVALUATION_DATAFLOW_DISK_SIZE_GB = 50 -def _update_parameters(parameter_values: Dict[str, Any], new_params: Dict[str, - Any]): - parameter_values.update({ - param: value for param, value in new_params.items() if value is not None - }) +def _update_parameters( + parameter_values: Dict[str, Any], new_params: Dict[str, Any] +): + parameter_values.update( + {param: value for param, value in new_params.items() if value is not None} + ) def _get_default_pipeline_params( @@ -87,171 +88,116 @@ def _get_default_pipeline_params( distill_batch_predict_max_replica_count: Optional[int] = None, stage_1_tuning_result_artifact_uri: Optional[str] = None, quantiles: Optional[List[float]] = None, - enable_probabilistic_inference: bool = False + enable_probabilistic_inference: bool = False, ) -> Dict[str, Any]: """Get the AutoML Tabular v1 default training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The path to a GCS file containing the transformations to + transformations: The path to a GCS file containing the transformations to apply. 
- train_budget_milli_node_hours: - The train budget of creating this model, + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_num_parallel_trials: - Number of parallel trails for stage 1. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. - predefined_split_key: - The predefined_split column name. - timestamp_split_key: - The timestamp_split column name. - stratified_split_key: - The stratified_split column name. - training_fraction: - The training fraction. - validation_fraction: - The validation fraction. - test_fraction: - float = The test fraction. - weight_column: - The weight column name. - study_spec_parameters_override: - The list for overriding study spec. The list + stage_1_num_parallel_trials: Number of parallel trails for stage 1. + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. + predefined_split_key: The predefined_split column name. + timestamp_split_key: The timestamp_split column name. + stratified_split_key: The stratified_split column name. + training_fraction: The training fraction. + validation_fraction: The validation fraction. + test_fraction: float = The test fraction. + weight_column: The weight column name. + study_spec_parameters_override: The list for overriding study spec. The list should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/study.proto#L181. - optimization_objective_recall_value: - Required when optimization_objective is + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - stage_1_tuner_worker_pool_specs_override: - The dictionary for overriding. + stage_1_tuner_worker_pool_specs_override: The dictionary for overriding. stage 1 tuner worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. 
- stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. - additional_experiments: - Use this field to config private preview features. - dataflow_service_account: - Custom service account to run dataflow jobs. - max_selected_features: - number of features to select for training, - apply_feature_selection_tuning: - tuning feature selection rate if true. - run_evaluation: - Whether to run evaluation in the training pipeline. - evaluation_batch_predict_machine_type: - The prediction server machine type + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. + dataflow_service_account: Custom service account to run dataflow jobs. + max_selected_features: number of features to select for training, + apply_feature_selection_tuning: tuning feature selection rate if true. + run_evaluation: Whether to run evaluation in the training pipeline. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_batch_explain_machine_type: - The prediction server machine type for batch explain components during - evaluation. - evaluation_batch_explain_starting_replica_count: - The initial number of prediction server for batch explain components - during evaluation. - evaluation_batch_explain_max_replica_count: - The max number of prediction server for batch explain components during - evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_batch_explain_machine_type: The prediction server machine type + for batch explain components during evaluation. 
+ evaluation_batch_explain_starting_replica_count: The initial number of + prediction server for batch explain components during evaluation. + evaluation_batch_explain_max_replica_count: The max number of prediction + server for batch explain components during evaluation. + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - run_distillation: - Whether to run distill in the training pipeline. - distill_batch_predict_machine_type: - The prediction server machine type for + run_distillation: Whether to run distill in the training pipeline. + distill_batch_predict_machine_type: The prediction server machine type for batch predict component in the model distillation. - distill_batch_predict_starting_replica_count: - The initial number of + distill_batch_predict_starting_replica_count: The initial number of prediction server for batch predict component in the model distillation. - distill_batch_predict_max_replica_count: - The max number of prediction server + distill_batch_predict_max_replica_count: The max number of prediction server for batch predict component in the model distillation. - stage_1_tuning_result_artifact_uri: - The stage 1 tuning result artifact GCS URI. - quantiles: - Quantiles to use for probabilistic inference. Up to 5 quantiles are - allowed of values between 0 and 1, exclusive. Represents the quantiles to - use for that objective. Quantiles must be unique. - enable_probabilistic_inference: - If probabilistic inference is enabled, the model will fit a - distribution that captures the uncertainty of a prediction. At inference - time, the predictive distribution is used to make a point prediction that - minimizes the optimization objective. For example, the mean of a - predictive distribution is the point prediction that minimizes RMSE loss. - If quantiles are specified, then the quantiles of the distribution are - also returned. + stage_1_tuning_result_artifact_uri: The stage 1 tuning result artifact GCS + URI. + quantiles: Quantiles to use for probabilistic inference. Up to 5 quantiles + are allowed of values between 0 and 1, exclusive. Represents the quantiles + to use for that objective. Quantiles must be unique. + enable_probabilistic_inference: If probabilistic inference is enabled, the + model will fit a distribution that captures the uncertainty of a + prediction. At inference time, the predictive distribution is used to make + a point prediction that minimizes the optimization objective. For example, + the mean of a predictive distribution is the point prediction that + minimizes RMSE loss. If quantiles are specified, then the quantiles of the + distribution are also returned. Returns: Tuple of pipeline_definiton_path and parameter_values. 
@@ -269,92 +215,67 @@ def _get_default_pipeline_params( parameter_values = {} parameters = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'prediction_type': - prediction_type, - 'data_source_csv_filenames': - data_source_csv_filenames, - 'data_source_bigquery_table_path': - data_source_bigquery_table_path, - 'predefined_split_key': - predefined_split_key, - 'timestamp_split_key': - timestamp_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'optimization_objective': - optimization_objective, - 'transformations': - transformations, - 'train_budget_milli_node_hours': - train_budget_milli_node_hours, - 'stage_1_num_parallel_trials': - stage_1_num_parallel_trials, - 'stage_2_num_parallel_trials': - stage_2_num_parallel_trials, - 'stage_2_num_selected_trials': - stage_2_num_selected_trials, - 'weight_column': - weight_column, - 'optimization_objective_recall_value': - optimization_objective_recall_value, - 'optimization_objective_precision_value': - optimization_objective_precision_value, - 'study_spec_parameters_override': - study_spec_parameters_override, - 'stage_1_tuner_worker_pool_specs_override': - stage_1_tuner_worker_pool_specs_override, - 'cv_trainer_worker_pool_specs_override': - cv_trainer_worker_pool_specs_override, - 'export_additional_model_without_custom_ops': - export_additional_model_without_custom_ops, - 'stats_and_example_gen_dataflow_machine_type': - stats_and_example_gen_dataflow_machine_type, - 'stats_and_example_gen_dataflow_max_num_workers': - stats_and_example_gen_dataflow_max_num_workers, - 'stats_and_example_gen_dataflow_disk_size_gb': - stats_and_example_gen_dataflow_disk_size_gb, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'dataflow_service_account': - dataflow_service_account, - 'encryption_spec_key_name': - encryption_spec_key_name, - 'additional_experiments': - additional_experiments, - 'max_selected_features': - max_selected_features, - 'stage_1_tuning_result_artifact_uri': - stage_1_tuning_result_artifact_uri, - 'quantiles': - quantiles, - 'enable_probabilistic_inference': - enable_probabilistic_inference, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'prediction_type': prediction_type, + 'data_source_csv_filenames': data_source_csv_filenames, + 'data_source_bigquery_table_path': data_source_bigquery_table_path, + 'predefined_split_key': predefined_split_key, + 'timestamp_split_key': timestamp_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'optimization_objective': optimization_objective, + 'transformations': transformations, + 'train_budget_milli_node_hours': train_budget_milli_node_hours, + 'stage_1_num_parallel_trials': stage_1_num_parallel_trials, + 'stage_2_num_parallel_trials': stage_2_num_parallel_trials, + 'stage_2_num_selected_trials': stage_2_num_selected_trials, + 'weight_column': weight_column, + 'optimization_objective_recall_value': ( + 
optimization_objective_recall_value + ), + 'optimization_objective_precision_value': ( + optimization_objective_precision_value + ), + 'study_spec_parameters_override': study_spec_parameters_override, + 'stage_1_tuner_worker_pool_specs_override': ( + stage_1_tuner_worker_pool_specs_override + ), + 'cv_trainer_worker_pool_specs_override': ( + cv_trainer_worker_pool_specs_override + ), + 'export_additional_model_without_custom_ops': ( + export_additional_model_without_custom_ops + ), + 'stats_and_example_gen_dataflow_machine_type': ( + stats_and_example_gen_dataflow_machine_type + ), + 'stats_and_example_gen_dataflow_max_num_workers': ( + stats_and_example_gen_dataflow_max_num_workers + ), + 'stats_and_example_gen_dataflow_disk_size_gb': ( + stats_and_example_gen_dataflow_disk_size_gb + ), + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'dataflow_service_account': dataflow_service_account, + 'encryption_spec_key_name': encryption_spec_key_name, + 'additional_experiments': additional_experiments, + 'max_selected_features': max_selected_features, + 'stage_1_tuning_result_artifact_uri': stage_1_tuning_result_artifact_uri, + 'quantiles': quantiles, + 'enable_probabilistic_inference': enable_probabilistic_inference, } - parameter_values.update({ - param: value for param, value in parameters.items() if value is not None - }) + parameter_values.update( + {param: value for param, value in parameters.items() if value is not None} + ) if apply_feature_selection_tuning: parameter_values.update({ @@ -363,50 +284,61 @@ def _get_default_pipeline_params( if run_evaluation: eval_parameters = { - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_batch_explain_machine_type': - evaluation_batch_explain_machine_type, - 'evaluation_batch_explain_starting_replica_count': - evaluation_batch_explain_starting_replica_count, - 'evaluation_batch_explain_max_replica_count': - evaluation_batch_explain_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'run_evaluation': - run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_batch_explain_machine_type': ( + evaluation_batch_explain_machine_type + ), + 'evaluation_batch_explain_starting_replica_count': ( + evaluation_batch_explain_starting_replica_count + ), + 'evaluation_batch_explain_max_replica_count': ( + evaluation_batch_explain_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + 
evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'run_evaluation': run_evaluation, } - parameter_values.update({ - param: value - for param, value in eval_parameters.items() - if value is not None - }) + parameter_values.update( + { + param: value + for param, value in eval_parameters.items() + if value is not None + } + ) if run_distillation: distillation_parameters = { - 'distill_batch_predict_machine_type': - distill_batch_predict_machine_type, - 'distill_batch_predict_starting_replica_count': - distill_batch_predict_starting_replica_count, - 'distill_batch_predict_max_replica_count': - distill_batch_predict_max_replica_count, - 'run_distillation': - run_distillation, + 'distill_batch_predict_machine_type': ( + distill_batch_predict_machine_type + ), + 'distill_batch_predict_starting_replica_count': ( + distill_batch_predict_starting_replica_count + ), + 'distill_batch_predict_max_replica_count': ( + distill_batch_predict_max_replica_count + ), + 'run_distillation': run_distillation, } - parameter_values.update({ - param: value - for param, value in distillation_parameters.items() - if value is not None - }) + parameter_values.update( + { + param: value + for param, value in distillation_parameters.items() + if value is not None + } + ) return parameter_values @@ -465,167 +397,114 @@ def get_automl_tabular_pipeline_and_parameters( distill_batch_predict_max_replica_count: Optional[int] = None, stage_1_tuning_result_artifact_uri: Optional[str] = None, quantiles: Optional[List[float]] = None, - enable_probabilistic_inference: bool = False + enable_probabilistic_inference: bool = False, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular v1 default training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The path to a GCS file containing the transformations to + transformations: The path to a GCS file containing the transformations to apply. - train_budget_milli_node_hours: - The train budget of creating this model, + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_num_parallel_trials: - Number of parallel trails for stage 1. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - data_source_csv_filenames: - The CSV data source. 
- data_source_bigquery_table_path: - The BigQuery data source. - predefined_split_key: - The predefined_split column name. - timestamp_split_key: - The timestamp_split column name. - stratified_split_key: - The stratified_split column name. - training_fraction: - The training fraction. - validation_fraction: - The validation fraction. - test_fraction: - float = The test fraction. - weight_column: - The weight column name. - study_spec_parameters_override: - The list for overriding study spec. The list + stage_1_num_parallel_trials: Number of parallel trails for stage 1. + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. + predefined_split_key: The predefined_split column name. + timestamp_split_key: The timestamp_split column name. + stratified_split_key: The stratified_split column name. + training_fraction: The training fraction. + validation_fraction: The validation fraction. + test_fraction: float = The test fraction. + weight_column: The weight column name. + study_spec_parameters_override: The list for overriding study spec. The list should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/study.proto#L181. - optimization_objective_recall_value: - Required when optimization_objective is + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - stage_1_tuner_worker_pool_specs_override: - The dictionary for overriding. + stage_1_tuner_worker_pool_specs_override: The dictionary for overriding. stage 1 tuner worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. - stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. 
- transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. - additional_experiments: - Use this field to config private preview features. - dataflow_service_account: - Custom service account to run dataflow jobs. - run_evaluation: - Whether to run evaluation in the training pipeline. - evaluation_batch_predict_machine_type: - The prediction server machine type + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. + dataflow_service_account: Custom service account to run dataflow jobs. + run_evaluation: Whether to run evaluation in the training pipeline. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_batch_explain_machine_type: - The prediction server machine type for batch explain components during - evaluation. - evaluation_batch_explain_starting_replica_count: - The initial number of prediction server for batch explain components - during evaluation. - evaluation_batch_explain_max_replica_count: - The max number of prediction server for batch explain components during - evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_batch_explain_machine_type: The prediction server machine type + for batch explain components during evaluation. + evaluation_batch_explain_starting_replica_count: The initial number of + prediction server for batch explain components during evaluation. + evaluation_batch_explain_max_replica_count: The max number of prediction + server for batch explain components during evaluation. + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. 
- run_distillation: - Whether to run distill in the training pipeline. - distill_batch_predict_machine_type: - The prediction server machine type for + run_distillation: Whether to run distill in the training pipeline. + distill_batch_predict_machine_type: The prediction server machine type for batch predict component in the model distillation. - distill_batch_predict_starting_replica_count: - The initial number of + distill_batch_predict_starting_replica_count: The initial number of prediction server for batch predict component in the model distillation. - distill_batch_predict_max_replica_count: - The max number of prediction server + distill_batch_predict_max_replica_count: The max number of prediction server for batch predict component in the model distillation. - stage_1_tuning_result_artifact_uri: - The stage 1 tuning result artifact GCS URI. - quantiles: - Quantiles to use for probabilistic inference. Up to 5 quantiles are - allowed of values between 0 and 1, exclusive. Represents the quantiles to - use for that objective. Quantiles must be unique. - enable_probabilistic_inference: - If probabilistic inference is enabled, the model will fit a - distribution that captures the uncertainty of a prediction. At inference - time, the predictive distribution is used to make a point prediction that - minimizes the optimization objective. For example, the mean of a - predictive distribution is the point prediction that minimizes RMSE loss. - If quantiles are specified, then the quantiles of the distribution are - also returned. + stage_1_tuning_result_artifact_uri: The stage 1 tuning result artifact GCS + URI. + quantiles: Quantiles to use for probabilistic inference. Up to 5 quantiles + are allowed of values between 0 and 1, exclusive. Represents the quantiles + to use for that objective. Quantiles must be unique. + enable_probabilistic_inference: If probabilistic inference is enabled, the + model will fit a distribution that captures the uncertainty of a + prediction. At inference time, the predictive distribution is used to make + a point prediction that minimizes the optimization objective. For example, + the mean of a predictive distribution is the point prediction that + minimizes RMSE loss. If quantiles are specified, then the quantiles of the + distribution are also returned. Returns: Tuple of pipeline_definiton_path and parameter_values. @@ -689,7 +568,8 @@ def get_automl_tabular_pipeline_and_parameters( ) pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), 'automl_tabular_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'automl_tabular_pipeline.yaml' + ) return pipeline_definition_path, parameter_values @@ -747,156 +627,103 @@ def get_automl_tabular_feature_selection_pipeline_and_parameters( run_distillation: bool = False, distill_batch_predict_machine_type: Optional[str] = None, distill_batch_predict_starting_replica_count: Optional[int] = None, - distill_batch_predict_max_replica_count: Optional[int] = None + distill_batch_predict_max_replica_count: Optional[int] = None, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular v1 default training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. 
+ location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The path to a GCS file containing the transformations to + transformations: The path to a GCS file containing the transformations to apply. - train_budget_milli_node_hours: - The train budget of creating this model, + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_num_parallel_trials: - Number of parallel trails for stage 1. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. - predefined_split_key: - The predefined_split column name. - timestamp_split_key: - The timestamp_split column name. - stratified_split_key: - The stratified_split column name. - training_fraction: - The training fraction. - validation_fraction: - The validation fraction. - test_fraction: - float = The test fraction. - weight_column: - The weight column name. - study_spec_parameters_override: - The list for overriding study spec. The list + stage_1_num_parallel_trials: Number of parallel trails for stage 1. + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. + predefined_split_key: The predefined_split column name. + timestamp_split_key: The timestamp_split column name. + stratified_split_key: The stratified_split column name. + training_fraction: The training fraction. + validation_fraction: The validation fraction. + test_fraction: float = The test fraction. + weight_column: The weight column name. + study_spec_parameters_override: The list for overriding study spec. The list should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/study.proto#L181. - optimization_objective_recall_value: - Required when optimization_objective is + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - stage_1_tuner_worker_pool_specs_override: - The dictionary for overriding. + stage_1_tuner_worker_pool_specs_override: The dictionary for overriding. stage 1 tuner worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. 
- cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. - stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. - additional_experiments: - Use this field to config private preview features. - dataflow_service_account: - Custom service account to run dataflow jobs. - run_evaluation: - Whether to run evaluation in the training pipeline. - evaluation_batch_predict_machine_type: - The prediction server machine type + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. + dataflow_service_account: Custom service account to run dataflow jobs. + run_evaluation: Whether to run evaluation in the training pipeline. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_batch_explain_machine_type: - The prediction server machine type for batch explain components during - evaluation. 
- evaluation_batch_explain_starting_replica_count: - The initial number of prediction server for batch explain components - during evaluation. - evaluation_batch_explain_max_replica_count: - The max number of prediction server for batch explain components during - evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_batch_explain_machine_type: The prediction server machine type + for batch explain components during evaluation. + evaluation_batch_explain_starting_replica_count: The initial number of + prediction server for batch explain components during evaluation. + evaluation_batch_explain_max_replica_count: The max number of prediction + server for batch explain components during evaluation. + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - max_selected_features: - number of features to select for training, - apply_feature_selection_tuning: - tuning feature selection rate if true. - run_distillation: - Whether to run distill in the training pipeline. - distill_batch_predict_machine_type: - The prediction server machine type for + max_selected_features: number of features to select for training, + apply_feature_selection_tuning: tuning feature selection rate if true. + run_distillation: Whether to run distill in the training pipeline. + distill_batch_predict_machine_type: The prediction server machine type for batch predict component in the model distillation. - distill_batch_predict_starting_replica_count: - The initial number of + distill_batch_predict_starting_replica_count: The initial number of prediction server for batch predict component in the model distillation. - distill_batch_predict_max_replica_count: - The max number of prediction server + distill_batch_predict_max_replica_count: The max number of prediction server for batch predict component in the model distillation. Returns: @@ -956,12 +783,13 @@ def get_automl_tabular_feature_selection_pipeline_and_parameters( run_distillation=run_distillation, distill_batch_predict_machine_type=distill_batch_predict_machine_type, distill_batch_predict_starting_replica_count=distill_batch_predict_starting_replica_count, - distill_batch_predict_max_replica_count=distill_batch_predict_max_replica_count + distill_batch_predict_max_replica_count=distill_batch_predict_max_replica_count, ) pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'automl_tabular_feature_selection_pipeline.yaml') + 'automl_tabular_feature_selection_pipeline.yaml', + ) return pipeline_definition_path, parameter_values @@ -973,8 +801,7 @@ def input_dictionary_to_parameter(input_dict: Optional[Dict[str, Any]]) -> str: JSON argument's quote must be manually escaped using this function. Args: - input_dict: - The input json dictionary. + input_dict: The input json dictionary. Returns: The encoded string used for parameter. 
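The quote escaping that input_dictionary_to_parameter's docstring describes can be illustrated with a small sketch. This is only an illustration of the double-encoding idea, not necessarily the library's exact implementation, and the sample dict is a placeholder:

import json
from typing import Any, Dict, Optional


def dict_to_escaped_parameter(input_dict: Optional[Dict[str, Any]]) -> str:
  # Illustration only: json.dumps twice, so the outer call backslash-escapes
  # the quotes produced by the inner call; stripping the surrounding quotes
  # leaves the escaped JSON payload suitable for a string-typed parameter.
  if not input_dict:
    return ''
  return json.dumps(json.dumps(input_dict))[1:-1]


print(dict_to_escaped_parameter({'auto': {'column_name': 'feature_1'}}))
# -> {\"auto\": {\"column_name\": \"feature_1\"}}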
@@ -1015,96 +842,66 @@ def get_skip_evaluation_pipeline_and_parameters( dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, encryption_spec_key_name: str = '', - additional_experiments: Optional[Dict[str, Any]] = None + additional_experiments: Optional[Dict[str, Any]] = None, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular training pipeline that skips evaluation. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column_name: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column_name: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The transformations to apply. - split_spec: - The split spec. - data_source: - The data source. - train_budget_milli_node_hours: - The train budget of creating this model, + transformations: The transformations to apply. + split_spec: The split spec. + data_source: The data source. + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_num_parallel_trials: - Number of parallel trails for stage 1. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - weight_column_name: - The weight column name. - study_spec_override: - The dictionary for overriding study spec. The + stage_1_num_parallel_trials: Number of parallel trails for stage 1. + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + weight_column_name: The weight column name. + study_spec_override: The dictionary for overriding study spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/study.proto#L181. - optimization_objective_recall_value: - Required when optimization_objective is + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - stage_1_tuner_worker_pool_specs_override: - The dictionary for overriding. + stage_1_tuner_worker_pool_specs_override: The dictionary for overriding. stage 1 tuner worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. 
- cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. - stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. - additional_experiments: - Use this field to config private preview features. + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. Returns: Tuple of pipeline_definiton_path and parameter_values. 
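Each of these get_*_pipeline_and_parameters helpers returns a (pipeline_definition_path, parameter_values) tuple that can be handed to the Vertex AI SDK. The sketch below is hedged: the project, region, bucket, column name, the transformation/split/data-source specs, and the helper's import path are placeholder assumptions for illustration, not values taken from this change.

# Assumed import path for the helper; adjust to wherever this utils module is packaged.
from google.cloud import aiplatform
from google_cloud_pipeline_components.experimental.automl.tabular import utils

template_path, parameter_values = utils.get_skip_evaluation_pipeline_and_parameters(
    project='my-project',                      # placeholder
    location='us-central1',                    # placeholder
    root_dir='gs://my-bucket/pipeline_root',   # placeholder
    target_column_name='label',                # placeholder
    prediction_type='classification',
    optimization_objective='maximize-au-roc',
    # Placeholder specs; the docstring above describes these arguments.
    transformations={'auto': [{'column_name': 'feature_1'}]},
    split_spec={'fraction_split': {'training_fraction': 0.8,
                                   'validation_fraction': 0.1,
                                   'test_fraction': 0.1}},
    data_source={'csv_data_source': {'csv_filenames': ['gs://my-bucket/train.csv']}},
    train_budget_milli_node_hours=1000,
)

# Submit the compiled pipeline definition with the generated parameter values.
aiplatform.init(project='my-project', location='us-central1')
job = aiplatform.PipelineJob(
    display_name='automl-tabular-skip-evaluation',
    template_path=template_path,
    parameter_values=parameter_values,
    pipeline_root='gs://my-bucket/pipeline_root',
)
job.submit()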
@@ -1141,7 +938,8 @@ def get_skip_evaluation_pipeline_and_parameters( encryption_spec_key_name=encryption_spec_key_name, additional_experiments=additional_experiments, run_evaluation=False, - run_distillation=False) + run_distillation=False, + ) def get_default_pipeline_and_parameters( @@ -1177,141 +975,95 @@ def get_default_pipeline_and_parameters( additional_experiments: Optional[Dict[str, Any]] = None, dataflow_service_account: str = '', run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, run_distillation: bool = False, distill_batch_predict_machine_type: str = 'n1-standard-16', distill_batch_predict_starting_replica_count: int = 25, - distill_batch_predict_max_replica_count: int = 25 + distill_batch_predict_max_replica_count: int = 25, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular default training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column_name: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column_name: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The transformations to apply. - split_spec: - The split spec. - data_source: - The data source. - train_budget_milli_node_hours: - The train budget of creating this model, + transformations: The transformations to apply. + split_spec: The split spec. + data_source: The data source. + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_num_parallel_trials: - Number of parallel trails for stage 1. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - weight_column_name: - The weight column name. - study_spec_override: - The dictionary for overriding study spec. 
The + stage_1_num_parallel_trials: Number of parallel trails for stage 1. + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + weight_column_name: The weight column name. + study_spec_override: The dictionary for overriding study spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/study.proto#L181. - optimization_objective_recall_value: - Required when optimization_objective is + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - stage_1_tuner_worker_pool_specs_override: - The dictionary for overriding. + stage_1_tuner_worker_pool_specs_override: The dictionary for overriding. stage 1 tuner worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. - stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. 
- additional_experiments: - Use this field to config private preview features. - dataflow_service_account: - Custom service account to run dataflow jobs. - run_evaluation: - Whether to run evaluation in the training pipeline. - evaluation_batch_predict_machine_type: - The prediction server machine type + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. + dataflow_service_account: Custom service account to run dataflow jobs. + run_evaluation: Whether to run evaluation in the training pipeline. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - run_distillation: - Whether to run distill in the training pipeline. - distill_batch_predict_machine_type: - The prediction server machine type for + run_distillation: Whether to run distill in the training pipeline. + distill_batch_predict_machine_type: The prediction server machine type for batch predict component in the model distillation. - distill_batch_predict_starting_replica_count: - The initial number of + distill_batch_predict_starting_replica_count: The initial number of prediction server for batch predict component in the model distillation. - distill_batch_predict_max_replica_count: - The max number of prediction server + distill_batch_predict_max_replica_count: The max number of prediction server for batch predict component in the model distillation. Returns: @@ -1319,7 +1071,8 @@ def get_default_pipeline_and_parameters( """ warnings.warn( 'This method is deprecated,' - ' please use get_automl_tabular_pipeline_and_parameters instead.') + ' please use get_automl_tabular_pipeline_and_parameters instead.' + ) if stage_1_num_parallel_trials <= 0: stage_1_num_parallel_trials = _DEFAULT_NUM_PARALLEL_TRAILS @@ -1331,22 +1084,25 @@ def get_default_pipeline_and_parameters( multiplier = stage_1_num_parallel_trials * hours / 500.0 stage_1_single_run_max_secs = int(math.sqrt(multiplier) * 2400.0) phase_2_rounds = int( - math.sqrt(multiplier) * 100 / stage_2_num_parallel_trials + 0.5) + math.sqrt(multiplier) * 100 / stage_2_num_parallel_trials + 0.5 + ) if phase_2_rounds < 1: phase_2_rounds = 1 # All of magic number "1.3" above is because the trial doesn't always finish # in time_per_trial. 1.3 is an empirical safety margin here. 
- stage_1_deadline_secs = int(hours * 3600.0 - 1.3 * - stage_1_single_run_max_secs * phase_2_rounds) + stage_1_deadline_secs = int( + hours * 3600.0 - 1.3 * stage_1_single_run_max_secs * phase_2_rounds + ) if stage_1_deadline_secs < hours * 3600.0 * 0.5: stage_1_deadline_secs = int(hours * 3600.0 * 0.5) # Phase 1 deadline is the same as phase 2 deadline in this case. Phase 2 # can't finish in time after the deadline is cut, so adjust the time per # trial to meet the deadline. - stage_1_single_run_max_secs = int(stage_1_deadline_secs / - (1.3 * phase_2_rounds)) + stage_1_single_run_max_secs = int( + stage_1_deadline_secs / (1.3 * phase_2_rounds) + ) reduce_search_space_mode = 'minimal' if multiplier > 2: @@ -1358,137 +1114,133 @@ def get_default_pipeline_and_parameters( # _NUM_FOLDS, which should be equal to phase_2_rounds * # stage_2_num_parallel_trials. Use this information to calculate # stage_1_num_selected_trials: - stage_1_num_selected_trials = int(phase_2_rounds * - stage_2_num_parallel_trials / _NUM_FOLDS) + stage_1_num_selected_trials = int( + phase_2_rounds * stage_2_num_parallel_trials / _NUM_FOLDS + ) stage_1_deadline_hours = stage_1_deadline_secs / 3600.0 stage_2_deadline_hours = hours - stage_1_deadline_hours stage_2_single_run_max_secs = stage_1_single_run_max_secs parameter_values = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column_name': - target_column_name, - 'prediction_type': - prediction_type, - 'optimization_objective': - optimization_objective, - 'transformations': - input_dictionary_to_parameter(transformations), - 'split_spec': - input_dictionary_to_parameter(split_spec), - 'data_source': - input_dictionary_to_parameter(data_source), - 'stage_1_deadline_hours': - stage_1_deadline_hours, - 'stage_1_num_parallel_trials': - stage_1_num_parallel_trials, - 'stage_1_num_selected_trials': - stage_1_num_selected_trials, - 'stage_1_single_run_max_secs': - stage_1_single_run_max_secs, - 'reduce_search_space_mode': - reduce_search_space_mode, - 'stage_2_deadline_hours': - stage_2_deadline_hours, - 'stage_2_num_parallel_trials': - stage_2_num_parallel_trials, - 'stage_2_num_selected_trials': - stage_2_num_selected_trials, - 'stage_2_single_run_max_secs': - stage_2_single_run_max_secs, - 'weight_column_name': - weight_column_name, - 'optimization_objective_recall_value': - optimization_objective_recall_value, - 'optimization_objective_precision_value': - optimization_objective_precision_value, - 'study_spec_override': - input_dictionary_to_parameter(study_spec_override), - 'stage_1_tuner_worker_pool_specs_override': - input_dictionary_to_parameter(stage_1_tuner_worker_pool_specs_override - ), - 'cv_trainer_worker_pool_specs_override': - input_dictionary_to_parameter(cv_trainer_worker_pool_specs_override), - 'export_additional_model_without_custom_ops': - export_additional_model_without_custom_ops, - 'stats_and_example_gen_dataflow_machine_type': - stats_and_example_gen_dataflow_machine_type, - 'stats_and_example_gen_dataflow_max_num_workers': - stats_and_example_gen_dataflow_max_num_workers, - 'stats_and_example_gen_dataflow_disk_size_gb': - stats_and_example_gen_dataflow_disk_size_gb, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 
'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column_name': target_column_name, + 'prediction_type': prediction_type, + 'optimization_objective': optimization_objective, + 'transformations': input_dictionary_to_parameter(transformations), + 'split_spec': input_dictionary_to_parameter(split_spec), + 'data_source': input_dictionary_to_parameter(data_source), + 'stage_1_deadline_hours': stage_1_deadline_hours, + 'stage_1_num_parallel_trials': stage_1_num_parallel_trials, + 'stage_1_num_selected_trials': stage_1_num_selected_trials, + 'stage_1_single_run_max_secs': stage_1_single_run_max_secs, + 'reduce_search_space_mode': reduce_search_space_mode, + 'stage_2_deadline_hours': stage_2_deadline_hours, + 'stage_2_num_parallel_trials': stage_2_num_parallel_trials, + 'stage_2_num_selected_trials': stage_2_num_selected_trials, + 'stage_2_single_run_max_secs': stage_2_single_run_max_secs, + 'weight_column_name': weight_column_name, + 'optimization_objective_recall_value': ( + optimization_objective_recall_value + ), + 'optimization_objective_precision_value': ( + optimization_objective_precision_value + ), + 'study_spec_override': input_dictionary_to_parameter(study_spec_override), + 'stage_1_tuner_worker_pool_specs_override': input_dictionary_to_parameter( + stage_1_tuner_worker_pool_specs_override + ), + 'cv_trainer_worker_pool_specs_override': input_dictionary_to_parameter( + cv_trainer_worker_pool_specs_override + ), + 'export_additional_model_without_custom_ops': ( + export_additional_model_without_custom_ops + ), + 'stats_and_example_gen_dataflow_machine_type': ( + stats_and_example_gen_dataflow_machine_type + ), + 'stats_and_example_gen_dataflow_max_num_workers': ( + stats_and_example_gen_dataflow_max_num_workers + ), + 'stats_and_example_gen_dataflow_disk_size_gb': ( + stats_and_example_gen_dataflow_disk_size_gb + ), + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } if additional_experiments: - parameter_values.update({ - 'additional_experiments': - input_dictionary_to_parameter(additional_experiments) - }) + parameter_values.update( + { + 'additional_experiments': input_dictionary_to_parameter( + additional_experiments + ) + } + ) if run_evaluation: parameter_values.update({ - 'dataflow_service_account': - dataflow_service_account, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'run_evaluation': - run_evaluation, + 'dataflow_service_account': dataflow_service_account, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + 
evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'run_evaluation': run_evaluation, }) if run_distillation: # All of magic number "1.3" above is because the trial doesn't always finish # in time_per_trial. 1.3 is an empirical safety margin here. - distill_stage_1_deadline_hours = math.ceil( - float(_DISTILL_TOTAL_TRIALS) / - parameter_values['stage_1_num_parallel_trials'] - ) * parameter_values['stage_1_single_run_max_secs'] * 1.3 / 3600.0 + distill_stage_1_deadline_hours = ( + math.ceil( + float(_DISTILL_TOTAL_TRIALS) + / parameter_values['stage_1_num_parallel_trials'] + ) + * parameter_values['stage_1_single_run_max_secs'] + * 1.3 + / 3600.0 + ) parameter_values.update({ - 'distill_stage_1_deadline_hours': - distill_stage_1_deadline_hours, - 'distill_batch_predict_machine_type': - distill_batch_predict_machine_type, - 'distill_batch_predict_starting_replica_count': - distill_batch_predict_starting_replica_count, - 'distill_batch_predict_max_replica_count': - distill_batch_predict_max_replica_count, - 'run_distillation': - run_distillation, + 'distill_stage_1_deadline_hours': distill_stage_1_deadline_hours, + 'distill_batch_predict_machine_type': ( + distill_batch_predict_machine_type + ), + 'distill_batch_predict_starting_replica_count': ( + distill_batch_predict_starting_replica_count + ), + 'distill_batch_predict_max_replica_count': ( + distill_batch_predict_max_replica_count + ), + 'run_distillation': run_distillation, }) pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'deprecated/default_pipeline.json') + 'deprecated/default_pipeline.json', + ) return pipeline_definition_path, parameter_values def get_feature_selection_pipeline_and_parameters( - project: str, location: str, root_dir: str, target_column: str, - algorithm: str, prediction_type: str, + project: str, + location: str, + root_dir: str, + target_column: str, + algorithm: str, + prediction_type: str, data_source_csv_filenames: Optional[str] = None, data_source_bigquery_table_path: Optional[str] = None, max_selected_features: Optional[int] = None, @@ -1497,48 +1249,34 @@ def get_feature_selection_pipeline_and_parameters( dataflow_disk_size_gb: int = 40, dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, - dataflow_service_account: str = ''): + dataflow_service_account: str = '', +): """Get the feature selection pipeline that generates feature ranking and selected features. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - algorithm: - Algorithm to select features, default to be AMI. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + algorithm: Algorithm to select features, default to be AMI. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". 
- data_source_csv_filenames: - A string that represents a list of comma + data_source_csv_filenames: A string that represents a list of comma separated CSV filenames. - data_source_bigquery_table_path: - The BigQuery table path. - max_selected_features: - number of features to be selected. - dataflow_machine_type: - The dataflow machine type for + data_source_bigquery_table_path: The BigQuery table path. + max_selected_features: number of features to be selected. + dataflow_machine_type: The dataflow machine type for feature_selection + component. + dataflow_max_num_workers: The max number of Dataflow workers for + feature_selection component. + dataflow_disk_size_gb: Dataflow worker's disk size in GB for feature_selection component. - dataflow_max_num_workers: - The max number of Dataflow - workers for feature_selection component. - dataflow_disk_size_gb: - Dataflow worker's disk size in - GB for feature_selection component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - dataflow_service_account: - Custom service account to run dataflow jobs. + dataflow_service_account: Custom service account to run dataflow jobs. Returns: Tuple of pipeline_definition_path and parameter_values. @@ -1555,29 +1293,25 @@ def get_feature_selection_pipeline_and_parameters( 'data_source_csv_filenames': data_source_csv_filenames, 'data_source_bigquery_table_path': data_source_bigquery_table_path, 'max_selected_features': max_selected_features, - 'dataflow_machine_type': - dataflow_machine_type, - 'dataflow_max_num_workers': - dataflow_max_num_workers, - 'dataflow_disk_size_gb': - dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, + 'dataflow_machine_type': dataflow_machine_type, + 'dataflow_max_num_workers': dataflow_max_num_workers, + 'dataflow_disk_size_gb': dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, } - parameter_values.update({ - param: value - for param, value in data_source_parameters.items() - if value is not None - }) + parameter_values.update( + { + param: value + for param, value in data_source_parameters.items() + if value is not None + } + ) pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), - 'feature_selection_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'feature_selection_pipeline.yaml' + ) return pipeline_definition_path, parameter_values @@ -1628,133 +1362,88 @@ def get_skip_architecture_search_pipeline_and_parameters( evaluation_dataflow_machine_type: Optional[str] = None, evaluation_dataflow_starting_num_workers: Optional[int] = None, evaluation_dataflow_max_num_workers: Optional[int] = None, - evaluation_dataflow_disk_size_gb: Optional[int] = None + evaluation_dataflow_disk_size_gb: Optional[int] = None, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular training pipeline that skips architecture search. 
Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - optimization_objective: - For binary classification, "maximize-au-roc", + optimization_objective: For binary classification, "maximize-au-roc", "minimize-log-loss", "maximize-au-prc", "maximize-precision-at-recall", or "maximize-recall-at-precision". For multi class classification, "minimize-log-loss". For regression, "minimize-rmse", "minimize-mae", or "minimize-rmsle". - transformations: - The transformations to apply. - train_budget_milli_node_hours: - The train budget of creating this model, + transformations: The transformations to apply. + train_budget_milli_node_hours: The train budget of creating this model, expressed in milli node hours i.e. 1,000 value in this field means 1 node hour. - stage_1_tuning_result_artifact_uri: - The stage 1 tuning result artifact GCS + stage_1_tuning_result_artifact_uri: The stage 1 tuning result artifact GCS URI. - stage_2_num_parallel_trials: - Number of parallel trails for stage 2. - stage_2_num_selected_trials: - Number of selected trials for stage 2. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. - predefined_split_key: - The predefined_split column name. - timestamp_split_key: - The timestamp_split column name. - stratified_split_key: - The stratified_split column name. - training_fraction: - The training fraction. - validation_fraction: - The validation fraction. - test_fraction: - float = The test fraction. - weight_column: - The weight column name. - optimization_objective_recall_value: - Required when optimization_objective is + stage_2_num_parallel_trials: Number of parallel trails for stage 2. + stage_2_num_selected_trials: Number of selected trials for stage 2. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. + predefined_split_key: The predefined_split column name. + timestamp_split_key: The timestamp_split column name. + stratified_split_key: The stratified_split column name. + training_fraction: The training fraction. + validation_fraction: The validation fraction. + test_fraction: float = The test fraction. + weight_column: The weight column name. + optimization_objective_recall_value: Required when optimization_objective is "maximize-precision-at-recall". Must be between 0 and 1, inclusive. - optimization_objective_precision_value: - Required when optimization_objective + optimization_objective_precision_value: Required when optimization_objective is "maximize-recall-at-precision". Must be between 0 and 1, inclusive. - cv_trainer_worker_pool_specs_override: - The dictionary for overriding stage + cv_trainer_worker_pool_specs_override: The dictionary for overriding stage cv trainer worker pool spec. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. 
- export_additional_model_without_custom_ops: - Whether to export additional + export_additional_model_without_custom_ops: Whether to export additional model without custom TensorFlow operators. - stats_and_example_gen_dataflow_machine_type: - The dataflow machine type for + stats_and_example_gen_dataflow_machine_type: The dataflow machine type for stats_and_example_gen component. - stats_and_example_gen_dataflow_max_num_workers: - The max number of Dataflow + stats_and_example_gen_dataflow_max_num_workers: The max number of Dataflow workers for stats_and_example_gen component. - stats_and_example_gen_dataflow_disk_size_gb: - Dataflow worker's disk size in + stats_and_example_gen_dataflow_disk_size_gb: Dataflow worker's disk size in GB for stats_and_example_gen component. - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. - additional_experiments: - Use this field to config private preview features. - dataflow_service_account: - Custom service account to run dataflow jobs. - run_evaluation: - Whether to run evaluation in the training pipeline. - evaluation_batch_predict_machine_type: - The prediction server machine type + encryption_spec_key_name: The KMS key name. + additional_experiments: Use this field to config private preview features. + dataflow_service_account: Custom service account to run dataflow jobs. + run_evaluation: Whether to run evaluation in the training pipeline. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_batch_explain_machine_type: - The prediction server machine type for batch explain components during - evaluation. - evaluation_batch_explain_starting_replica_count: - The initial number of prediction server for batch explain components - during evaluation. - evaluation_batch_explain_max_replica_count: - The max number of prediction server for batch explain components during - evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_batch_explain_machine_type: The prediction server machine type + for batch explain components during evaluation. 
+ evaluation_batch_explain_starting_replica_count: The initial number of + prediction server for batch explain components during evaluation. + evaluation_batch_explain_max_replica_count: The max number of prediction + server for batch explain components during evaluation. + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. Returns: @@ -1816,7 +1505,7 @@ def get_skip_architecture_search_pipeline_and_parameters( distill_batch_predict_max_replica_count=None, stage_1_tuning_result_artifact_uri=stage_1_tuning_result_artifact_uri, quantiles=[], - enable_probabilistic_inference=False + enable_probabilistic_inference=False, ) @@ -1853,7 +1542,7 @@ def get_distill_skip_evaluation_pipeline_and_parameters( additional_experiments: Optional[Dict[str, Any]] = None, distill_batch_predict_machine_type: str = 'n1-standard-16', distill_batch_predict_starting_replica_count: int = 25, - distill_batch_predict_max_replica_count: int = 25 + distill_batch_predict_max_replica_count: int = 25, ) -> Tuple[str, Dict[str, Any]]: """Get the AutoML Tabular training pipeline that distill and skips evaluation. @@ -1924,7 +1613,8 @@ def get_distill_skip_evaluation_pipeline_and_parameters( Tuple of pipeline_definiton_path and parameter_values. """ warnings.warn( - 'Depreciated. Please use get_automl_tabular_pipeline_and_parameters.') + 'Depreciated. Please use get_automl_tabular_pipeline_and_parameters.' 
+ ) return get_default_pipeline_and_parameters( project=project, @@ -1961,7 +1651,8 @@ def get_distill_skip_evaluation_pipeline_and_parameters( distill_batch_predict_starting_replica_count=distill_batch_predict_starting_replica_count, distill_batch_predict_max_replica_count=distill_batch_predict_max_replica_count, run_evaluation=False, - run_distillation=True) + run_distillation=True, + ) def get_wide_and_deep_trainer_pipeline_and_parameters( @@ -1973,8 +1664,9 @@ def get_wide_and_deep_trainer_pipeline_and_parameters( learning_rate: float, dnn_learning_rate: float, transform_config: Optional[str] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: bool = False, feature_selection_algorithm: Optional[str] = None, @@ -2022,192 +1714,132 @@ def get_wide_and_deep_trainer_pipeline_and_parameters( transform_dataflow_disk_size_gb: int = 40, worker_pool_specs_override: Optional[Dict[str, Any]] = None, run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_starting_num_workers: - int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_starting_num_workers: int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, dataflow_service_account: str = '', dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, - encryption_spec_key_name: str = '') -> Tuple[str, Dict[str, Any]]: + encryption_spec_key_name: str = '', +) -> Tuple[str, Dict[str, Any]]: """Get the Wide & Deep training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. 'classification' or 'regression'. - learning_rate: - The learning rate used by the linear optimizer. - dnn_learning_rate: - The learning rate for training the deep part of the + learning_rate: The learning rate used by the linear optimizer. + dnn_learning_rate: The learning rate for training the deep part of the model. - transform_config: - Path to v1 TF transformation configuration. 
- dataset_level_custom_transformation_definitions: - Dataset-level custom transformation definitions in string format. - dataset_level_transformations: - Dataset-level transformation configuration in string format. + transform_config: Path to v1 TF transformation configuration. + dataset_level_custom_transformation_definitions: Dataset-level custom + transformation definitions in string format. + dataset_level_transformations: Dataset-level transformation configuration in + string format. run_feature_selection: Whether to enable feature selection. feature_selection_algorithm: Feature selection algorithm. max_selected_features: Maximum number of features to select. - predefined_split_key: - Predefined split key. - stratified_split_key: - Stratified split key. - training_fraction: - Training fraction. - validation_fraction: - Validation fraction. - test_fraction: - Test fraction. - tf_auto_transform_features: - List of auto transform features in the comma-separated string format. - tf_custom_transformation_definitions: - TF custom transformation definitions in string format. - tf_transformations_path: - Path to TF transformation configuration. - optimizer_type: - The type of optimizer to use. Choices are "adam", "ftrl" and + predefined_split_key: Predefined split key. + stratified_split_key: Stratified split key. + training_fraction: Training fraction. + validation_fraction: Validation fraction. + test_fraction: Test fraction. + tf_auto_transform_features: List of auto transform features in the + comma-separated string format. + tf_custom_transformation_definitions: TF custom transformation definitions + in string format. + tf_transformations_path: Path to TF transformation configuration. + optimizer_type: The type of optimizer to use. Choices are "adam", "ftrl" and "sgd" for the Adam, FTRL, and Gradient Descent Optimizers, respectively. - max_steps: - Number of steps to run the trainer for. - max_train_secs: - Amount of time in seconds to run the trainer for. - l1_regularization_strength: - L1 regularization strength for + max_steps: Number of steps to run the trainer for. + max_train_secs: Amount of time in seconds to run the trainer for. + l1_regularization_strength: L1 regularization strength for optimizer_type="ftrl". - l2_regularization_strength: - L2 regularization strength for + l2_regularization_strength: L2 regularization strength for optimizer_type="ftrl". - l2_shrinkage_regularization_strength: - L2 shrinkage regularization strength + l2_shrinkage_regularization_strength: L2 shrinkage regularization strength for optimizer_type="ftrl". - beta_1: - Beta 1 value for optimizer_type="adam". - beta_2: - Beta 2 value for optimizer_type="adam". - hidden_units: - Hidden layer sizes to use for DNN feature columns, provided in + beta_1: Beta 1 value for optimizer_type="adam". + beta_2: Beta 2 value for optimizer_type="adam". + hidden_units: Hidden layer sizes to use for DNN feature columns, provided in comma-separated layers. - use_wide: - If set to true, the categorical columns will be used in the wide + use_wide: If set to true, the categorical columns will be used in the wide part of the DNN model. - embed_categories: - If set to true, the categorical columns will be used + embed_categories: If set to true, the categorical columns will be used embedded and used in the deep part of the model. Embedding size is the square root of the column cardinality. - dnn_dropout: - The probability we will drop out a given coordinate. 
- dnn_optimizer_type: - The type of optimizer to use for the deep part of the + dnn_dropout: The probability we will drop out a given coordinate. + dnn_optimizer_type: The type of optimizer to use for the deep part of the model. Choices are "adam", "ftrl" and "sgd". for the Adam, FTRL, and Gradient Descent Optimizers, respectively. - dnn_l1_regularization_strength: - L1 regularization strength for + dnn_l1_regularization_strength: L1 regularization strength for dnn_optimizer_type="ftrl". - dnn_l2_regularization_strength: - L2 regularization strength for + dnn_l2_regularization_strength: L2 regularization strength for dnn_optimizer_type="ftrl". - dnn_l2_shrinkage_regularization_strength: - L2 shrinkage regularization + dnn_l2_shrinkage_regularization_strength: L2 shrinkage regularization strength for dnn_optimizer_type="ftrl". - dnn_beta_1: - Beta 1 value for dnn_optimizer_type="adam". - dnn_beta_2: - Beta 2 value for dnn_optimizer_type="adam". - enable_profiler: - Enables profiling and saves a trace during evaluation. + dnn_beta_1: Beta 1 value for dnn_optimizer_type="adam". + dnn_beta_2: Beta 2 value for dnn_optimizer_type="adam". + enable_profiler: Enables profiling and saves a trace during evaluation. cache_data: Whether to cache data or not. If set to 'auto', caching is determined based on the dataset size. - seed: - Seed to be used for this run. - eval_steps: - Number of steps to run evaluation for. If not specified or + seed: Seed to be used for this run. + eval_steps: Number of steps to run evaluation for. If not specified or negative, it means run evaluation on the whole validation dataset. If set to 0, it means run evaluation for a fixed number of samples. - batch_size: - Batch size for training. - measurement_selection_type: - Which measurement to use if/when the service automatically - selects the final measurement from previously - reported intermediate measurements. One of "BEST_MEASUREMENT" or + batch_size: Batch size for training. + measurement_selection_type: Which measurement to use if/when the service + automatically selects the final measurement from previously reported + intermediate measurements. One of "BEST_MEASUREMENT" or "LAST_MEASUREMENT". - optimization_metric: - Optimization metric used for `measurement_selection_type`. - Default is "rmse" for regression and "auc" for classification. - eval_frequency_secs: - Frequency at which evaluation and checkpointing will + optimization_metric: Optimization metric used for + `measurement_selection_type`. Default is "rmse" for regression and "auc" + for classification. + eval_frequency_secs: Frequency at which evaluation and checkpointing will take place. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. - bigquery_staging_full_dataset_id: - The BigQuery staging full dataset id for storing intermediate tables. - weight_column: - The weight column name. - transform_dataflow_machine_type: - The dataflow machine type for transform + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. + bigquery_staging_full_dataset_id: The BigQuery staging full dataset id for + storing intermediate tables. + weight_column: The weight column name. + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. 
- transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - worker_pool_specs_override: - The dictionary for overriding training and - evaluation worker pool specs. The dictionary should be of format + worker_pool_specs_override: The dictionary for overriding training and + evaluation worker pool specs. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - run_evaluation: - Whether to run evaluation steps during training. - evaluation_batch_predict_machine_type: - The prediction server machine type + run_evaluation: Whether to run evaluation steps during training. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - dataflow_service_account: - Custom service account to run dataflow jobs. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_service_account: Custom service account to run dataflow jobs. + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. + encryption_spec_key_name: The KMS key name. Returns: Tuple of pipeline_definition_path and parameter_values. @@ -2215,12 +1847,14 @@ def get_wide_and_deep_trainer_pipeline_and_parameters( if transform_config and tf_transformations_path: raise ValueError( 'Only one of transform_config and tf_transformations_path can ' - 'be specified.') + 'be specified.' + ) elif transform_config: warnings.warn( 'transform_config parameter is deprecated. ' - 'Please use the flattened transform config arguments instead.') + 'Please use the flattened transform config arguments instead.' 
+ ) tf_transformations_path = transform_config if not worker_pool_specs_override: @@ -2228,139 +1862,99 @@ def get_wide_and_deep_trainer_pipeline_and_parameters( parameter_values = {} training_and_eval_parameters = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'prediction_type': - prediction_type, - 'learning_rate': - learning_rate, - 'dnn_learning_rate': - dnn_learning_rate, - 'optimizer_type': - optimizer_type, - 'max_steps': - max_steps, - 'max_train_secs': - max_train_secs, - 'l1_regularization_strength': - l1_regularization_strength, - 'l2_regularization_strength': - l2_regularization_strength, - 'l2_shrinkage_regularization_strength': - l2_shrinkage_regularization_strength, - 'beta_1': - beta_1, - 'beta_2': - beta_2, - 'hidden_units': - hidden_units, - 'use_wide': - use_wide, - 'embed_categories': - embed_categories, - 'dnn_dropout': - dnn_dropout, - 'dnn_optimizer_type': - dnn_optimizer_type, - 'dnn_l1_regularization_strength': - dnn_l1_regularization_strength, - 'dnn_l2_regularization_strength': - dnn_l2_regularization_strength, - 'dnn_l2_shrinkage_regularization_strength': - dnn_l2_shrinkage_regularization_strength, - 'dnn_beta_1': - dnn_beta_1, - 'dnn_beta_2': - dnn_beta_2, - 'enable_profiler': - enable_profiler, - 'cache_data': - cache_data, - 'seed': - seed, - 'eval_steps': - eval_steps, - 'batch_size': - batch_size, - 'measurement_selection_type': - measurement_selection_type, - 'optimization_metric': - optimization_metric, - 'eval_frequency_secs': - eval_frequency_secs, - 'weight_column': - weight_column, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'worker_pool_specs_override': - worker_pool_specs_override, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'prediction_type': prediction_type, + 'learning_rate': learning_rate, + 'dnn_learning_rate': dnn_learning_rate, + 'optimizer_type': optimizer_type, + 'max_steps': max_steps, + 'max_train_secs': max_train_secs, + 'l1_regularization_strength': l1_regularization_strength, + 'l2_regularization_strength': l2_regularization_strength, + 'l2_shrinkage_regularization_strength': ( + l2_shrinkage_regularization_strength + ), + 'beta_1': beta_1, + 'beta_2': beta_2, + 'hidden_units': hidden_units, + 'use_wide': use_wide, + 'embed_categories': embed_categories, + 'dnn_dropout': dnn_dropout, + 'dnn_optimizer_type': dnn_optimizer_type, + 'dnn_l1_regularization_strength': 
dnn_l1_regularization_strength, + 'dnn_l2_regularization_strength': dnn_l2_regularization_strength, + 'dnn_l2_shrinkage_regularization_strength': ( + dnn_l2_shrinkage_regularization_strength + ), + 'dnn_beta_1': dnn_beta_1, + 'dnn_beta_2': dnn_beta_2, + 'enable_profiler': enable_profiler, + 'cache_data': cache_data, + 'seed': seed, + 'eval_steps': eval_steps, + 'batch_size': batch_size, + 'measurement_selection_type': measurement_selection_type, + 'optimization_metric': optimization_metric, + 'eval_frequency_secs': eval_frequency_secs, + 'weight_column': weight_column, + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'worker_pool_specs_override': worker_pool_specs_override, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } _update_parameters(parameter_values, training_and_eval_parameters) fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + if 
tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) @@ -2373,7 +1967,8 @@ def get_wide_and_deep_trainer_pipeline_and_parameters( pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'wide_and_deep_trainer_pipeline.yaml') + 'wide_and_deep_trainer_pipeline.yaml', + ) return pipeline_definition_path, parameter_values @@ -2395,8 +1990,9 @@ def get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters( eval_steps: int = 0, eval_frequency_secs: int = 600, transform_config: Optional[str] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, predefined_split_key: Optional[str] = None, stratified_split_key: Optional[str] = None, @@ -2418,153 +2014,106 @@ def get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters( transform_dataflow_disk_size_gb: int = 40, worker_pool_specs_override: Optional[Dict[str, Any]] = None, run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_starting_num_workers: - int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_starting_num_workers: int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, dataflow_service_account: str = '', dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, - encryption_spec_key_name: str = '') -> Tuple[str, Dict[str, Any]]: + encryption_spec_key_name: str = '', +) -> Tuple[str, Dict[str, Any]]: """Get the built-in algorithm HyperparameterTuningJob pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - study_spec_metric_id: - Metric to optimize, possible values: [ - 'loss', 'average_loss', 'rmse', 'mae', 'mql', 'accuracy', 'auc', - 'precision', 'recall']. 
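# --- Usage sketch (not part of the diff above) ---------------------------------------------
# Minimal sketch of calling get_wide_and_deep_trainer_pipeline_and_parameters, whose
# reformatted body is completed in the hunks above. All resource names (project, bucket,
# BigQuery table, column names) are illustrative placeholders, and the module import path is
# an assumption -- it differs between GCPC releases (experimental/ vs preview/ namespaces).
from google.cloud import aiplatform
from google_cloud_pipeline_components.experimental.automl.tabular import utils  # path: assumption

template_path, parameter_values = utils.get_wide_and_deep_trainer_pipeline_and_parameters(
    project='my-project',                                # placeholder
    location='us-central1',
    root_dir='gs://my-bucket/pipeline_root',             # placeholder GCS root
    target_column='label',
    prediction_type='classification',
    learning_rate=0.01,
    dnn_learning_rate=0.005,
    data_source_bigquery_table_path='bq://my-project.my_dataset.training_data',  # placeholder
)

# The helper only assembles a compiled template path plus parameter_values; running the
# pipeline is a separate PipelineJob submission against the Vertex AI SDK.
aiplatform.PipelineJob(
    display_name='wide-and-deep-trainer',
    template_path=template_path,
    parameter_values=parameter_values,
    pipeline_root='gs://my-bucket/pipeline_root',
).run()
# --------------------------------------------------------------------------------------------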
- study_spec_metric_goal: - Optimization goal of the metric, possible values: + study_spec_metric_id: Metric to optimize, possible values: [ 'loss', + 'average_loss', 'rmse', 'mae', 'mql', 'accuracy', 'auc', 'precision', + 'recall']. + study_spec_metric_goal: Optimization goal of the metric, possible values: "MAXIMIZE", "MINIMIZE". - study_spec_parameters_override: - List of dictionaries representing parameters + study_spec_parameters_override: List of dictionaries representing parameters to optimize. The dictionary key is the parameter_id, which is passed to training job as a command line argument, and the dictionary value is the parameter specification of the metric. - max_trial_count: - The desired total number of trials. - parallel_trial_count: - The desired number of trials to run in parallel. - algorithm: - Algorithm to train. One of "tabnet" and "wide_and_deep". - enable_profiler: - Enables profiling and saves a trace during evaluation. - seed: - Seed to be used for this run. - eval_steps: - Number of steps to run evaluation for. If not specified or + max_trial_count: The desired total number of trials. + parallel_trial_count: The desired number of trials to run in parallel. + algorithm: Algorithm to train. One of "tabnet" and "wide_and_deep". + enable_profiler: Enables profiling and saves a trace during evaluation. + seed: Seed to be used for this run. + eval_steps: Number of steps to run evaluation for. If not specified or negative, it means run evaluation on the whole validation dataset. If set to 0, it means run evaluation for a fixed number of samples. - eval_frequency_secs: - Frequency at which evaluation and checkpointing will + eval_frequency_secs: Frequency at which evaluation and checkpointing will take place. - transform_config: - Path to v1 TF transformation configuration. - dataset_level_custom_transformation_definitions: - Dataset-level custom transformation definitions in string format. - dataset_level_transformations: - Dataset-level transformation configuration in string format. - predefined_split_key: - Predefined split key. - stratified_split_key: - Stratified split key. - training_fraction: - Training fraction. - validation_fraction: - Validation fraction. - test_fraction: - Test fraction. - tf_auto_transform_features: - List of auto transform features in the comma-separated string format. - tf_custom_transformation_definitions: - TF custom transformation definitions in string format. - tf_transformations_path: - Path to TF transformation configuration. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. + transform_config: Path to v1 TF transformation configuration. + dataset_level_custom_transformation_definitions: Dataset-level custom + transformation definitions in string format. + dataset_level_transformations: Dataset-level transformation configuration in + string format. + predefined_split_key: Predefined split key. + stratified_split_key: Stratified split key. + training_fraction: Training fraction. + validation_fraction: Validation fraction. + test_fraction: Test fraction. + tf_auto_transform_features: List of auto transform features in the + comma-separated string format. + tf_custom_transformation_definitions: TF custom transformation definitions + in string format. + tf_transformations_path: Path to TF transformation configuration. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. 
bigquery_staging_full_dataset_id: The BigQuery staging full dataset id for storing intermediate tables. - weight_column: - The weight column name. - max_failed_trial_count: - The number of failed trials that need to be seen + weight_column: The weight column name. + max_failed_trial_count: The number of failed trials that need to be seen before failing the HyperparameterTuningJob. If set to 0, Vertex AI decides how many trials must fail before the whole job fails. - study_spec_algorithm: - The search algorithm specified for the study. One of + study_spec_algorithm: The search algorithm specified for the study. One of "ALGORITHM_UNSPECIFIED", "GRID_SEARCH", or "RANDOM_SEARCH". - study_spec_measurement_selection_type: - Which measurement to use if/when the + study_spec_measurement_selection_type: Which measurement to use if/when the service automatically selects the final measurement from previously reported intermediate measurements. One of "BEST_MEASUREMENT" or "LAST_MEASUREMENT". - transform_dataflow_machine_type: - The dataflow machine type for transform + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. - worker_pool_specs_override: - The dictionary for overriding training and - evaluation worker pool specs. The dictionary should be of format + worker_pool_specs_override: The dictionary for overriding training and + evaluation worker pool specs. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - run_evaluation: - Whether to run evaluation steps during training. - evaluation_batch_predict_machine_type: - The prediction server machine type + run_evaluation: Whether to run evaluation steps during training. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - dataflow_service_account: - Custom service account to run dataflow jobs. 
- dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_service_account: Custom service account to run dataflow jobs. + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. + encryption_spec_key_name: The KMS key name. Returns: Tuple of pipeline_definiton_path and parameter_values. @@ -2625,7 +2174,8 @@ def get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters( dataflow_service_account=dataflow_service_account, dataflow_subnetwork=dataflow_subnetwork, dataflow_use_public_ips=dataflow_use_public_ips, - encryption_spec_key_name=encryption_spec_key_name) + encryption_spec_key_name=encryption_spec_key_name, + ) elif algorithm == 'wide_and_deep': return get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( project=project, @@ -2675,7 +2225,8 @@ def get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters( dataflow_service_account=dataflow_service_account, dataflow_subnetwork=dataflow_subnetwork, dataflow_use_public_ips=dataflow_use_public_ips, - encryption_spec_key_name=encryption_spec_key_name) + encryption_spec_key_name=encryption_spec_key_name, + ) else: raise ValueError( 'Invalid algorithm provided. Supported values are "tabnet" and' @@ -2695,8 +2246,9 @@ def get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters( max_trial_count: int, parallel_trial_count: int, transform_config: Optional[str] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: bool = False, feature_selection_algorithm: Optional[str] = None, @@ -2726,22 +2278,18 @@ def get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters( transform_dataflow_disk_size_gb: int = 40, worker_pool_specs_override: Optional[Dict[str, Any]] = None, run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_starting_num_workers: - int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_starting_num_workers: int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, dataflow_service_account: str = '', dataflow_subnetwork: str = '', 
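# --- Usage sketch (not part of the diff above) ---------------------------------------------
# Sketch of the algorithm dispatch in
# get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters, whose body is
# completed above: algorithm='tabnet' and algorithm='wide_and_deep' delegate to the
# per-algorithm helpers, and any other value raises ValueError. Identifiers, the study-spec
# override, and the import path below are placeholders/assumptions.
from google_cloud_pipeline_components.experimental.automl.tabular import utils  # path: assumption

search_space = [{
    'parameter_id': 'learning_rate',
    'double_value_spec': {'min_value': 1e-4, 'max_value': 1e-1},
    'scale_type': 'UNIT_LOG_SCALE',
}]  # illustrative Vertex AI study-spec parameter override

template_path, parameter_values = (
    utils.get_builtin_algorithm_hyperparameter_tuning_job_pipeline_and_parameters(
        project='my-project',
        location='us-central1',
        root_dir='gs://my-bucket/pipeline_root',
        target_column='label',
        prediction_type='classification',
        study_spec_metric_id='auc',
        study_spec_metric_goal='MAXIMIZE',
        study_spec_parameters_override=search_space,
        max_trial_count=20,
        parallel_trial_count=5,
        algorithm='wide_and_deep',  # 'tabnet' selects the TabNet tuning pipeline instead
        data_source_bigquery_table_path='bq://my-project.my_dataset.training_data',
    )
)
# Any other algorithm string raises ValueError; only 'tabnet' and 'wide_and_deep' are supported.
# --------------------------------------------------------------------------------------------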
dataflow_use_public_ips: bool = True, - encryption_spec_key_name: str = '') -> Tuple[str, Dict[str, Any]]: + encryption_spec_key_name: str = '', +) -> Tuple[str, Dict[str, Any]]: """Get the TabNet HyperparameterTuningJob pipeline. Args: @@ -2821,8 +2369,8 @@ def get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters( server for batch predict components during evaluation. evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for @@ -2841,133 +2389,108 @@ def get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters( if transform_config and tf_transformations_path: raise ValueError( 'Only one of transform_config and tf_transformations_path can ' - 'be specified.') + 'be specified.' + ) elif transform_config: warnings.warn( 'transform_config parameter is deprecated. ' - 'Please use the flattened transform config arguments instead.') + 'Please use the flattened transform config arguments instead.' + ) tf_transformations_path = transform_config if not worker_pool_specs_override: worker_pool_specs_override = [] parameter_values = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'prediction_type': - prediction_type, - 'study_spec_metric_id': - study_spec_metric_id, - 'study_spec_metric_goal': - study_spec_metric_goal, - 'study_spec_parameters_override': - study_spec_parameters_override, - 'max_trial_count': - max_trial_count, - 'parallel_trial_count': - parallel_trial_count, - 'enable_profiler': - enable_profiler, - 'cache_data': - cache_data, - 'seed': - seed, - 'eval_steps': - eval_steps, - 'eval_frequency_secs': - eval_frequency_secs, - 'weight_column': - weight_column, - 'max_failed_trial_count': - max_failed_trial_count, - 'study_spec_algorithm': - study_spec_algorithm, - 'study_spec_measurement_selection_type': - study_spec_measurement_selection_type, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'worker_pool_specs_override': - worker_pool_specs_override, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'prediction_type': 
prediction_type, + 'study_spec_metric_id': study_spec_metric_id, + 'study_spec_metric_goal': study_spec_metric_goal, + 'study_spec_parameters_override': study_spec_parameters_override, + 'max_trial_count': max_trial_count, + 'parallel_trial_count': parallel_trial_count, + 'enable_profiler': enable_profiler, + 'cache_data': cache_data, + 'seed': seed, + 'eval_steps': eval_steps, + 'eval_frequency_secs': eval_frequency_secs, + 'weight_column': weight_column, + 'max_failed_trial_count': max_failed_trial_count, + 'study_spec_algorithm': study_spec_algorithm, + 'study_spec_measurement_selection_type': ( + study_spec_measurement_selection_type + ), + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'worker_pool_specs_override': worker_pool_specs_override, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + if 
tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) data_source_and_split_parameters = { 'data_source_csv_filenames': data_source_csv_filenames, 'data_source_bigquery_table_path': data_source_bigquery_table_path, - 'bigquery_staging_full_dataset_id': bigquery_staging_full_dataset_id + 'bigquery_staging_full_dataset_id': bigquery_staging_full_dataset_id, } _update_parameters(parameter_values, data_source_and_split_parameters) pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'tabnet_hyperparameter_tuning_job_pipeline.yaml' + 'tabnet_hyperparameter_tuning_job_pipeline.yaml', ) return pipeline_definition_path, parameter_values @@ -2985,8 +2508,9 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( max_trial_count: int, parallel_trial_count: int, transform_config: Optional[str] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: bool = False, feature_selection_algorithm: Optional[str] = None, @@ -3016,22 +2540,18 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( transform_dataflow_disk_size_gb: int = 40, worker_pool_specs_override: Optional[Dict[str, Any]] = None, run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_starting_num_workers: - int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_starting_num_workers: int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, dataflow_service_account: str = '', dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, - encryption_spec_key_name: str = '') -> Tuple[str, Dict[str, Any]]: + encryption_spec_key_name: str = '', +) -> Tuple[str, Dict[str, Any]]: """Get the Wide & Deep algorithm HyperparameterTuningJob pipeline. Args: @@ -3052,31 +2572,24 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( parameter specification of the metric. max_trial_count: The desired total number of trials. parallel_trial_count: The desired number of trials to run in parallel. - transform_config: - Path to v1 TF transformation configuration. - dataset_level_custom_transformation_definitions: - Dataset-level custom transformation definitions in string format. - dataset_level_transformations: - Dataset-level transformation configuration in string format. + transform_config: Path to v1 TF transformation configuration. 
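# --- Usage sketch (not part of the diff above) ---------------------------------------------
# Direct-call sketch for get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters,
# completed above; this is what the dispatcher invokes for algorithm='tabnet'. Values and the
# study-spec override are illustrative placeholders; import path as assumed in the earlier
# sketch.
from google_cloud_pipeline_components.experimental.automl.tabular import utils  # path: assumption

template_path, parameter_values = (
    utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
        project='my-project',
        location='us-central1',
        root_dir='gs://my-bucket/pipeline_root',
        target_column='label',
        prediction_type='classification',
        study_spec_metric_id='loss',
        study_spec_metric_goal='MINIMIZE',
        study_spec_parameters_override=[{
            'parameter_id': 'max_steps',
            'discrete_value_spec': {'values': [1000, 2000, 5000]},
        }],  # illustrative override
        max_trial_count=10,
        parallel_trial_count=2,
        data_source_csv_filenames='gs://my-bucket/data/train.csv',  # placeholder
    )
)
# The returned template is the compiled YAML that ships inside the package directory.
assert template_path.endswith('tabnet_hyperparameter_tuning_job_pipeline.yaml')
# --------------------------------------------------------------------------------------------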
+ dataset_level_custom_transformation_definitions: Dataset-level custom + transformation definitions in string format. + dataset_level_transformations: Dataset-level transformation configuration in + string format. run_feature_selection: Whether to enable feature selection. feature_selection_algorithm: Feature selection algorithm. max_selected_features: Maximum number of features to select. - predefined_split_key: - Predefined split key. - stratified_split_key: - Stratified split key. - training_fraction: - Training fraction. - validation_fraction: - Validation fraction. - test_fraction: - Test fraction. - tf_auto_transform_features: - List of auto transform features in the comma-separated string format. - tf_custom_transformation_definitions: - TF custom transformation definitions in string format. - tf_transformations_path: - Path to TF transformation configuration. + predefined_split_key: Predefined split key. + stratified_split_key: Stratified split key. + training_fraction: Training fraction. + validation_fraction: Validation fraction. + test_fraction: Test fraction. + tf_auto_transform_features: List of auto transform features in the + comma-separated string format. + tf_custom_transformation_definitions: TF custom transformation definitions + in string format. + tf_transformations_path: Path to TF transformation configuration. enable_profiler: Enables profiling and saves a trace during evaluation. cache_data: Whether to cache data or not. If set to 'auto', caching is determined based on the dataset size. @@ -3118,8 +2631,8 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( server for batch predict components during evaluation. evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for @@ -3138,120 +2651,95 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( if transform_config and tf_transformations_path: raise ValueError( 'Only one of transform_config and tf_transformations_path can ' - 'be specified.') + 'be specified.' + ) elif transform_config: warnings.warn( 'transform_config parameter is deprecated. ' - 'Please use the flattened transform config arguments instead.') + 'Please use the flattened transform config arguments instead.' 
+ ) tf_transformations_path = transform_config if not worker_pool_specs_override: worker_pool_specs_override = [] parameter_values = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'prediction_type': - prediction_type, - 'study_spec_metric_id': - study_spec_metric_id, - 'study_spec_metric_goal': - study_spec_metric_goal, - 'study_spec_parameters_override': - study_spec_parameters_override, - 'max_trial_count': - max_trial_count, - 'parallel_trial_count': - parallel_trial_count, - 'enable_profiler': - enable_profiler, - 'cache_data': - cache_data, - 'seed': - seed, - 'eval_steps': - eval_steps, - 'eval_frequency_secs': - eval_frequency_secs, - 'weight_column': - weight_column, - 'max_failed_trial_count': - max_failed_trial_count, - 'study_spec_algorithm': - study_spec_algorithm, - 'study_spec_measurement_selection_type': - study_spec_measurement_selection_type, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'worker_pool_specs_override': - worker_pool_specs_override, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'prediction_type': prediction_type, + 'study_spec_metric_id': study_spec_metric_id, + 'study_spec_metric_goal': study_spec_metric_goal, + 'study_spec_parameters_override': study_spec_parameters_override, + 'max_trial_count': max_trial_count, + 'parallel_trial_count': parallel_trial_count, + 'enable_profiler': enable_profiler, + 'cache_data': cache_data, + 'seed': seed, + 'eval_steps': eval_steps, + 'eval_frequency_secs': eval_frequency_secs, + 'weight_column': weight_column, + 'max_failed_trial_count': max_failed_trial_count, + 'study_spec_algorithm': study_spec_algorithm, + 'study_spec_measurement_selection_type': ( + study_spec_measurement_selection_type + ), + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'worker_pool_specs_override': worker_pool_specs_override, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 
'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + if tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) @@ -3264,7 +2752,7 @@ def get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters( pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'wide_and_deep_hyperparameter_tuning_job_pipeline.yaml' + 'wide_and_deep_hyperparameter_tuning_job_pipeline.yaml', ) return pipeline_definition_path, parameter_values @@ -3278,8 +2766,9 @@ def get_tabnet_trainer_pipeline_and_parameters( prediction_type: str, learning_rate: float, transform_config: Optional[str] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: bool = False, feature_selection_algorithm: Optional[str] = None, @@ -3330,205 +2819,140 @@ def get_tabnet_trainer_pipeline_and_parameters( transform_dataflow_disk_size_gb: int = 40, worker_pool_specs_override: Optional[Dict[str, Any]] = None, run_evaluation: bool = True, - evaluation_batch_predict_machine_type: - str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, - 
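# --- Usage sketch (not part of the diff above) ---------------------------------------------
# Sketch of the transform-config validation shared by these helpers, shown above for
# get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters: transform_config is
# deprecated, may not be combined with tf_transformations_path, and is mapped onto
# tf_transformations_path when given alone. Placeholder values; import path as assumed
# earlier; the final assert assumes _update_parameters forwards set values unchanged.
import warnings
from google_cloud_pipeline_components.experimental.automl.tabular import utils  # path: assumption

common = dict(
    project='my-project',
    location='us-central1',
    root_dir='gs://my-bucket/pipeline_root',
    target_column='label',
    prediction_type='classification',
    study_spec_metric_id='auc',
    study_spec_metric_goal='MAXIMIZE',
    study_spec_parameters_override=[],
    max_trial_count=5,
    parallel_trial_count=1,
)

# Passing only the deprecated argument still works but emits a deprecation warning, and the
# value is forwarded under tf_transformations_path.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    _, params = utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(
        **common, transform_config='gs://my-bucket/transform_config.json')
assert any('deprecated' in str(w.message) for w in caught)
assert params['tf_transformations_path'] == 'gs://my-bucket/transform_config.json'

# Passing both arguments is rejected outright.
try:
    utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(
        **common,
        transform_config='gs://my-bucket/transform_config.json',
        tf_transformations_path='gs://my-bucket/transform_config.json')
except ValueError:
    pass  # only one of transform_config and tf_transformations_path may be specified
# --------------------------------------------------------------------------------------------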
evaluation_batch_predict_starting_replica_count: - int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, - evaluation_batch_predict_max_replica_count: - int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, - evaluation_dataflow_machine_type: - str = _EVALUATION_DATAFLOW_MACHINE_TYPE, - evaluation_dataflow_starting_num_workers: - int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, - evaluation_dataflow_max_num_workers: - int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, - evaluation_dataflow_disk_size_gb: - int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, + evaluation_batch_predict_machine_type: str = _EVALUATION_BATCH_PREDICT_MACHINE_TYPE, + evaluation_batch_predict_starting_replica_count: int = _EVALUATION_BATCH_PREDICT_STARTING_REPLICA_COUNT, + evaluation_batch_predict_max_replica_count: int = _EVALUATION_BATCH_PREDICT_MAX_REPLICA_COUNT, + evaluation_dataflow_machine_type: str = _EVALUATION_DATAFLOW_MACHINE_TYPE, + evaluation_dataflow_starting_num_workers: int = _EVALUATION_DATAFLOW_STARTING_NUM_WORKERS, + evaluation_dataflow_max_num_workers: int = _EVALUATION_DATAFLOW_MAX_NUM_WORKERS, + evaluation_dataflow_disk_size_gb: int = _EVALUATION_DATAFLOW_DISK_SIZE_GB, dataflow_service_account: str = '', dataflow_subnetwork: str = '', dataflow_use_public_ips: bool = True, - encryption_spec_key_name: str = '') -> Tuple[str, Dict[str, Any]]: + encryption_spec_key_name: str = '', +) -> Tuple[str, Dict[str, Any]]: """Get the TabNet training pipeline. Args: - project: - The GCP project that runs the pipeline components. - location: - The GCP region that runs the pipeline components. - root_dir: - The root GCS directory for the pipeline components. - target_column: - The target column name. - prediction_type: - The type of prediction the model is to produce. + project: The GCP project that runs the pipeline components. + location: The GCP region that runs the pipeline components. + root_dir: The root GCS directory for the pipeline components. + target_column: The target column name. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - learning_rate: - The learning rate used by the linear optimizer. - transform_config: - Path to v1 TF transformation configuration. - dataset_level_custom_transformation_definitions: - Dataset-level custom transformation definitions in string format. - dataset_level_transformations: - Dataset-level transformation configuration in string format. + learning_rate: The learning rate used by the linear optimizer. + transform_config: Path to v1 TF transformation configuration. + dataset_level_custom_transformation_definitions: Dataset-level custom + transformation definitions in string format. + dataset_level_transformations: Dataset-level transformation configuration in + string format. run_feature_selection: Whether to enable feature selection. feature_selection_algorithm: Feature selection algorithm. max_selected_features: Maximum number of features to select. - predefined_split_key: - Predefined split key. - stratified_split_key: - Stratified split key. - training_fraction: - Training fraction. - validation_fraction: - Validation fraction. - test_fraction: - Test fraction. - tf_auto_transform_features: - List of auto transform features in the comma-separated string format. - tf_custom_transformation_definitions: - TF custom transformation definitions in string format. - tf_transformations_path: - Path to TF transformation configuration. - max_steps: - Number of steps to run the trainer for. 
- max_train_secs: - Amount of time in seconds to run the trainer for. - large_category_dim: - Embedding dimension for categorical feature with large + predefined_split_key: Predefined split key. + stratified_split_key: Stratified split key. + training_fraction: Training fraction. + validation_fraction: Validation fraction. + test_fraction: Test fraction. + tf_auto_transform_features: List of auto transform features in the + comma-separated string format. + tf_custom_transformation_definitions: TF custom transformation definitions + in string format. + tf_transformations_path: Path to TF transformation configuration. + max_steps: Number of steps to run the trainer for. + max_train_secs: Amount of time in seconds to run the trainer for. + large_category_dim: Embedding dimension for categorical feature with large number of categories. - large_category_thresh: - Threshold for number of categories to apply + large_category_thresh: Threshold for number of categories to apply large_category_dim embedding dimension to. - yeo_johnson_transform: - Enables trainable Yeo-Johnson power transform. - feature_dim: - Dimensionality of the hidden representation in feature + yeo_johnson_transform: Enables trainable Yeo-Johnson power transform. + feature_dim: Dimensionality of the hidden representation in feature transformation block. - feature_dim_ratio: - The ratio of output dimension (dimensionality of the + feature_dim_ratio: The ratio of output dimension (dimensionality of the outputs of each decision step) to feature dimension. - num_decision_steps: - Number of sequential decision steps. - relaxation_factor: - Relaxation factor that promotes the reuse of each feature + num_decision_steps: Number of sequential decision steps. + relaxation_factor: Relaxation factor that promotes the reuse of each feature at different decision steps. When it is 1, a feature is enforced to be used only at one decision step and as it increases, more flexibility is provided to use a feature at multiple decision steps. - decay_every: - Number of iterations for periodically applying learning rate + decay_every: Number of iterations for periodically applying learning rate decaying. - decay_rate: - Learning rate decaying. - gradient_thresh: - Threshold for the norm of gradients for clipping. - sparsity_loss_weight: - Weight of the loss for sparsity regularization + decay_rate: Learning rate decaying. + gradient_thresh: Threshold for the norm of gradients for clipping. + sparsity_loss_weight: Weight of the loss for sparsity regularization (increasing it will yield more sparse feature selection). - batch_momentum: - Momentum in ghost batch normalization. - batch_size_ratio: - The ratio of virtual batch size (size of the ghost batch + batch_momentum: Momentum in ghost batch normalization. + batch_size_ratio: The ratio of virtual batch size (size of the ghost batch normalization) to batch size. - num_transformer_layers: - The number of transformer layers for each decision + num_transformer_layers: The number of transformer layers for each decision step. used only at one decision step and as it increases, more flexibility is provided to use a feature at multiple decision steps. - num_transformer_layers_ratio: - The ratio of shared transformer layer to + num_transformer_layers_ratio: The ratio of shared transformer layer to transformer layers. 
- class_weight: - The class weight is used to computes a weighted cross entropy + class_weight: The class weight is used to computes a weighted cross entropy which is helpful in classify imbalanced dataset. Only used for classification. - loss_function_type: - Loss function type. Loss function in classification + loss_function_type: Loss function type. Loss function in classification [cross_entropy, weighted_cross_entropy, focal_loss], default is - cross_entropy. Loss function in regression: - [rmse, mae, mse], default is + cross_entropy. Loss function in regression: [rmse, mae, mse], default is mse. - alpha_focal_loss: - Alpha value (balancing factor) in focal_loss function. + alpha_focal_loss: Alpha value (balancing factor) in focal_loss function. Only used for classification. - gamma_focal_loss: - Gamma value (modulating factor) for focal loss for focal + gamma_focal_loss: Gamma value (modulating factor) for focal loss for focal loss. Only used for classification. - enable_profiler: - Enables profiling and saves a trace during evaluation. + enable_profiler: Enables profiling and saves a trace during evaluation. cache_data: Whether to cache data or not. If set to 'auto', caching is determined based on the dataset size. - seed: - Seed to be used for this run. - eval_steps: - Number of steps to run evaluation for. If not specified or + seed: Seed to be used for this run. + eval_steps: Number of steps to run evaluation for. If not specified or negative, it means run evaluation on the whole validation dataset. If set to 0, it means run evaluation for a fixed number of samples. - batch_size: - Batch size for training. - measurement_selection_type: - Which measurement to use if/when the service automatically - selects the final measurement from previously - reported intermediate measurements. One of "BEST_MEASUREMENT" or + batch_size: Batch size for training. + measurement_selection_type: Which measurement to use if/when the service + automatically selects the final measurement from previously reported + intermediate measurements. One of "BEST_MEASUREMENT" or "LAST_MEASUREMENT". - optimization_metric: - Optimization metric used for `measurement_selection_type`. - Default is "rmse" for regression and "auc" for classification. - eval_frequency_secs: - Frequency at which evaluation and checkpointing will + optimization_metric: Optimization metric used for + `measurement_selection_type`. Default is "rmse" for regression and "auc" + for classification. + eval_frequency_secs: Frequency at which evaluation and checkpointing will take place. - data_source_csv_filenames: - The CSV data source. - data_source_bigquery_table_path: - The BigQuery data source. + data_source_csv_filenames: The CSV data source. + data_source_bigquery_table_path: The BigQuery data source. bigquery_staging_full_dataset_id: The BigQuery staging full dataset id for storing intermediate tables. - weight_column: - The weight column name. - transform_dataflow_machine_type: - The dataflow machine type for transform + weight_column: The weight column name. + transform_dataflow_machine_type: The dataflow machine type for transform component. - transform_dataflow_max_num_workers: - The max number of Dataflow workers for + transform_dataflow_max_num_workers: The max number of Dataflow workers for transform component. - transform_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + transform_dataflow_disk_size_gb: Dataflow worker's disk size in GB for transform component. 
- worker_pool_specs_override: - The dictionary for overriding training and - evaluation worker pool specs. The dictionary should be of format + worker_pool_specs_override: The dictionary for overriding training and + evaluation worker pool specs. The dictionary should be of format https://github.com/googleapis/googleapis/blob/4e836c7c257e3e20b1de14d470993a2b1f4736a8/google/cloud/aiplatform/v1beta1/custom_job.proto#L172. - run_evaluation: - Whether to run evaluation steps during training. - evaluation_batch_predict_machine_type: - The prediction server machine type + run_evaluation: Whether to run evaluation steps during training. + evaluation_batch_predict_machine_type: The prediction server machine type for batch predict components during evaluation. - evaluation_batch_predict_starting_replica_count: - The initial number of + evaluation_batch_predict_starting_replica_count: The initial number of prediction server for batch predict components during evaluation. - evaluation_batch_predict_max_replica_count: - The max number of prediction + evaluation_batch_predict_max_replica_count: The max number of prediction server for batch predict components during evaluation. - evaluation_dataflow_machine_type: - The dataflow machine type for evaluation + evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. - evaluation_dataflow_max_num_workers: - The max number of Dataflow workers for + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. + evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. - evaluation_dataflow_disk_size_gb: - Dataflow worker's disk size in GB for + evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for evaluation components. - dataflow_service_account: - Custom service account to run dataflow jobs. - dataflow_subnetwork: - Dataflow's fully qualified subnetwork name, when empty + dataflow_service_account: Custom service account to run dataflow jobs. + dataflow_subnetwork: Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be used. Example: https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications - dataflow_use_public_ips: - Specifies whether Dataflow workers use public IP + dataflow_use_public_ips: Specifies whether Dataflow workers use public IP addresses. - encryption_spec_key_name: - The KMS key name. + encryption_spec_key_name: The KMS key name. Returns: Tuple of pipeline_definiton_path and parameter_values. @@ -3536,12 +2960,14 @@ def get_tabnet_trainer_pipeline_and_parameters( if transform_config and tf_transformations_path: raise ValueError( 'Only one of transform_config and tf_transformations_path can ' - 'be specified.') + 'be specified.' + ) elif transform_config: warnings.warn( 'transform_config parameter is deprecated. ' - 'Please use the flattened transform config arguments instead.') + 'Please use the flattened transform config arguments instead.' 
+ ) tf_transformations_path = transform_config if not worker_pool_specs_override: @@ -3549,173 +2975,125 @@ def get_tabnet_trainer_pipeline_and_parameters( parameter_values = {} training_and_eval_parameters = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'prediction_type': - prediction_type, - 'learning_rate': - learning_rate, - 'max_steps': - max_steps, - 'max_train_secs': - max_train_secs, - 'large_category_dim': - large_category_dim, - 'large_category_thresh': - large_category_thresh, - 'yeo_johnson_transform': - yeo_johnson_transform, - 'feature_dim': - feature_dim, - 'feature_dim_ratio': - feature_dim_ratio, - 'num_decision_steps': - num_decision_steps, - 'relaxation_factor': - relaxation_factor, - 'decay_every': - decay_every, - 'decay_rate': - decay_rate, - 'gradient_thresh': - gradient_thresh, - 'sparsity_loss_weight': - sparsity_loss_weight, - 'batch_momentum': - batch_momentum, - 'batch_size_ratio': - batch_size_ratio, - 'num_transformer_layers': - num_transformer_layers, - 'num_transformer_layers_ratio': - num_transformer_layers_ratio, - 'class_weight': - class_weight, - 'loss_function_type': - loss_function_type, - 'alpha_focal_loss': - alpha_focal_loss, - 'gamma_focal_loss': - gamma_focal_loss, - 'enable_profiler': - enable_profiler, - 'cache_data': - cache_data, - 'seed': - seed, - 'eval_steps': - eval_steps, - 'batch_size': - batch_size, - 'measurement_selection_type': - measurement_selection_type, - 'optimization_metric': - optimization_metric, - 'eval_frequency_secs': - eval_frequency_secs, - 'weight_column': - weight_column, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'worker_pool_specs_override': - worker_pool_specs_override, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'prediction_type': prediction_type, + 'learning_rate': learning_rate, + 'max_steps': max_steps, + 'max_train_secs': max_train_secs, + 'large_category_dim': large_category_dim, + 'large_category_thresh': large_category_thresh, + 'yeo_johnson_transform': yeo_johnson_transform, + 'feature_dim': feature_dim, + 'feature_dim_ratio': feature_dim_ratio, + 'num_decision_steps': num_decision_steps, + 'relaxation_factor': relaxation_factor, + 'decay_every': decay_every, + 'decay_rate': decay_rate, + 'gradient_thresh': gradient_thresh, + 'sparsity_loss_weight': sparsity_loss_weight, + 'batch_momentum': batch_momentum, + 'batch_size_ratio': batch_size_ratio, + 
'num_transformer_layers': num_transformer_layers, + 'num_transformer_layers_ratio': num_transformer_layers_ratio, + 'class_weight': class_weight, + 'loss_function_type': loss_function_type, + 'alpha_focal_loss': alpha_focal_loss, + 'gamma_focal_loss': gamma_focal_loss, + 'enable_profiler': enable_profiler, + 'cache_data': cache_data, + 'seed': seed, + 'eval_steps': eval_steps, + 'batch_size': batch_size, + 'measurement_selection_type': measurement_selection_type, + 'optimization_metric': optimization_metric, + 'eval_frequency_secs': eval_frequency_secs, + 'weight_column': weight_column, + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'worker_pool_specs_override': worker_pool_specs_override, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } _update_parameters(parameter_values, training_and_eval_parameters) fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + 
if tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) data_source_and_split_parameters = { 'data_source_csv_filenames': data_source_csv_filenames, 'data_source_bigquery_table_path': data_source_bigquery_table_path, - 'bigquery_staging_full_dataset_id': bigquery_staging_full_dataset_id + 'bigquery_staging_full_dataset_id': bigquery_staging_full_dataset_id, } _update_parameters(parameter_values, data_source_and_split_parameters) pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), 'tabnet_trainer_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'tabnet_trainer_pipeline.yaml' + ) return pipeline_definition_path, parameter_values def get_tabnet_study_spec_parameters_override( - dataset_size_bucket: str, prediction_type: str, - training_budget_bucket: str) -> List[Dict[str, Any]]: + dataset_size_bucket: str, prediction_type: str, training_budget_bucket: str +) -> List[Dict[str, Any]]: """Get study_spec_parameters_override for a TabNet hyperparameter tuning job. Args: - dataset_size_bucket: - Size of the dataset. One of "small" (< 1M rows), + dataset_size_bucket: Size of the dataset. One of "small" (< 1M rows), "medium" (1M - 100M rows), or "large" (> 100M rows). - prediction_type: - The type of prediction the model is to produce. + prediction_type: The type of prediction the model is to produce. "classification" or "regression". - training_budget_bucket: - Bucket of the estimated training budget. One of + training_budget_bucket: Bucket of the estimated training budget. One of "small" (< $600), "medium" ($600 - $2400), or "large" (> $2400). This parameter is only used as a hint for the hyperparameter search space, unrelated to the real cost. @@ -3725,16 +3103,19 @@ def get_tabnet_study_spec_parameters_override( """ if dataset_size_bucket not in ['small', 'medium', 'large']: - raise ValueError('Invalid dataset_size_bucket provided. Supported values ' - ' are "small", "medium" or "large".') + raise ValueError( + 'Invalid dataset_size_bucket provided. Supported values ' + ' are "small", "medium" or "large".' + ) if training_budget_bucket not in ['small', 'medium', 'large']: raise ValueError( 'Invalid training_budget_bucket provided. Supported values ' - 'are "small", "medium" or "large".') + 'are "small", "medium" or "large".' + ) param_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - f'configs/tabnet_params_{dataset_size_bucket}_data_{training_budget_bucket}_search_space.json' + f'configs/tabnet_params_{dataset_size_bucket}_data_{training_budget_bucket}_search_space.json', ) with open(param_path, 'r') as f: param_content = f.read() @@ -3742,23 +3123,22 @@ def get_tabnet_study_spec_parameters_override( if prediction_type == 'regression': return _format_tabnet_regression_study_spec_parameters_override( - params, training_budget_bucket) + params, training_budget_bucket + ) return params def _format_tabnet_regression_study_spec_parameters_override( - params: List[Dict[str, Any]], - training_budget_bucket: str) -> List[Dict[str, Any]]: + params: List[Dict[str, Any]], training_budget_bucket: str +) -> List[Dict[str, Any]]: """Get regression study_spec_parameters_override for a TabNet hyperparameter tuning job. Args: - params: - List of dictionaries representing parameters to optimize. The + params: List of dictionaries representing parameters to optimize. 
The dictionary key is the parameter_id, which is passed to training job as a command line argument, and the dictionary value is the parameter specification of the metric. - training_budget_bucket: - Bucket of the estimated training budget. One of + training_budget_bucket: Bucket of the estimated training budget. One of "small" (< $600), "medium" ($600 - $2400), or "large" (> $2400). This parameter is only used as a hint for the hyperparameter search space, unrelated to the real cost. @@ -3775,7 +3155,9 @@ def _format_tabnet_regression_study_spec_parameters_override( formatted_params = [] for param in params: if param['parameter_id'] in [ - 'alpha_focal_loss', 'gamma_focal_loss', 'class_weight' + 'alpha_focal_loss', + 'gamma_focal_loss', + 'class_weight', ]: continue elif param['parameter_id'] == 'sparsity_loss_weight': @@ -3799,7 +3181,8 @@ def get_wide_and_deep_study_spec_parameters_override() -> List[Dict[str, Any]]: """ param_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'configs/wide_and_deep_params.json') + 'configs/wide_and_deep_params.json', + ) with open(param_path, 'r') as f: param_content = f.read() params = json.loads(param_content) @@ -3814,8 +3197,8 @@ def get_xgboost_study_spec_parameters_override() -> List[Dict[str, Any]]: List of study_spec_parameters_override. """ param_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), - 'configs/xgboost_params.json') + pathlib.Path(__file__).parent.resolve(), 'configs/xgboost_params.json' + ) with open(param_path, 'r') as f: param_content = f.read() params = json.loads(param_content) @@ -3882,17 +3265,19 @@ def get_model_comparison_pipeline_and_parameters( 'training_jobs': training_jobs, 'data_source_csv_filenames': data_source_csv_filenames, 'data_source_bigquery_table_path': data_source_bigquery_table_path, - 'evaluation_data_source_csv_filenames': - evaluation_data_source_csv_filenames, - 'evaluation_data_source_bigquery_table_path': - evaluation_data_source_bigquery_table_path, + 'evaluation_data_source_csv_filenames': ( + evaluation_data_source_csv_filenames + ), + 'evaluation_data_source_bigquery_table_path': ( + evaluation_data_source_bigquery_table_path + ), 'experiment': experiment, 'service_account': service_account, 'network': network, } pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), - 'model_comparison_pipeline.json') + pathlib.Path(__file__).parent.resolve(), 'model_comparison_pipeline.json' + ) return pipeline_definition_path, parameter_values @@ -3943,8 +3328,9 @@ def get_xgboost_trainer_pipeline_and_parameters( max_bin: Optional[int] = None, tweedie_variance_power: Optional[float] = None, huber_slope: Optional[float] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: Optional[bool] = None, feature_selection_algorithm: Optional[str] = None, @@ -3979,7 +3365,8 @@ def get_xgboost_trainer_pipeline_and_parameters( dataflow_service_account: Optional[str] = None, dataflow_subnetwork: Optional[str] = None, dataflow_use_public_ips: Optional[bool] = None, - encryption_spec_key_name: Optional[str] = None): + encryption_spec_key_name: Optional[str] = None, +): """Get the XGBoost training pipeline. 
Args: @@ -3987,75 +3374,69 @@ def get_xgboost_trainer_pipeline_and_parameters( location: The GCP region that runs the pipeline components. root_dir: The root GCS directory for the pipeline components. target_column: The target column name. - objective: Specifies the learning task and the learning - objective. Must be one of [reg:squarederror, reg:squaredlogerror, + objective: Specifies the learning task and the learning objective. Must be + one of [reg:squarederror, reg:squaredlogerror, reg:logistic, reg:gamma, reg:tweedie, reg:pseudohubererror, binary:logistic, multi:softprob]. eval_metric: Evaluation metrics for validation data represented as a comma-separated string. num_boost_round: Number of boosting iterations. - early_stopping_rounds: Activates early stopping. Validation - error needs to decrease at least every early_stopping_rounds round(s) to - continue training. - base_score: The initial prediction score of all instances, global - bias. - disable_default_eval_metric: Flag to disable default metric. Set to >0 - to disable. Default to 0. + early_stopping_rounds: Activates early stopping. Validation error needs to + decrease at least every early_stopping_rounds round(s) to continue + training. + base_score: The initial prediction score of all instances, global bias. + disable_default_eval_metric: Flag to disable default metric. Set to >0 to + disable. Default to 0. seed: Random seed. seed_per_iteration: Seed PRNG determnisticly via iterator number. - booster: Which booster to use, can be gbtree, gblinear or dart. - gbtree and dart use tree based model while gblinear uses linear function. + booster: Which booster to use, can be gbtree, gblinear or dart. gbtree and + dart use tree based model while gblinear uses linear function. eta: Learning rate. - gamma: Minimum loss reduction required to make a further partition - on a leaf node of the tree. + gamma: Minimum loss reduction required to make a further partition on a leaf + node of the tree. max_depth: Maximum depth of a tree. - min_child_weight: Minimum sum of instance weight(hessian) needed in - a child. - max_delta_step: Maximum delta step we allow each tree's weight - estimation to be. + min_child_weight: Minimum sum of instance weight(hessian) needed in a child. + max_delta_step: Maximum delta step we allow each tree's weight estimation to + be. subsample: Subsample ratio of the training instance. - colsample_bytree: Subsample ratio of columns when constructing each - tree. - colsample_bylevel: Subsample ratio of columns for each split, in - each level. + colsample_bytree: Subsample ratio of columns when constructing each tree. + colsample_bylevel: Subsample ratio of columns for each split, in each level. colsample_bynode: Subsample ratio of columns for each node (split). reg_lambda: L2 regularization term on weights. reg_alpha: L1 regularization term on weights. tree_method: The tree construction algorithm used in XGBoost. Choices: ["auto", "exact", "approx", "hist", "gpu_exact", "gpu_hist"]. - scale_pos_weight: Control the balance of positive and negative - weights. - updater: A comma separated string defining the sequence of tree - updaters to run. - refresh_leaf: Refresh updater plugin. Update tree leaf and nodes's - stats if True. When it is False, only node stats are updated. + scale_pos_weight: Control the balance of positive and negative weights. + updater: A comma separated string defining the sequence of tree updaters to + run. + refresh_leaf: Refresh updater plugin. Update tree leaf and nodes's stats if + True. 
When it is False, only node stats are updated. process_type: A type of boosting process to run. Choices:["default", "update"] - grow_policy: Controls a way new nodes are added to the tree. Only - supported if tree_method is hist. Choices:["depthwise", "lossguide"] + grow_policy: Controls a way new nodes are added to the tree. Only supported + if tree_method is hist. Choices:["depthwise", "lossguide"] sampling_method: The method to use to sample the training instances. - monotone_constraints: Constraint of variable - monotonicity. - interaction_constraints: Constraints for - interaction representing permitted interactions. + monotone_constraints: Constraint of variable monotonicity. + interaction_constraints: Constraints for interaction representing permitted + interactions. sample_type: [dart booster only] Type of sampling algorithm. Choices:["uniform", "weighted"] normalize_type: [dart booster only] Type of normalization algorithm, Choices:["tree", "forest"] rate_drop: [dart booster only] Dropout rate.' - one_drop: [dart booster only] When this flag is enabled, at least one - tree is always dropped during the dropout (allows Binomial-plus-one or + one_drop: [dart booster only] When this flag is enabled, at least one tree + is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout from the original DART paper). - skip_drop: [dart booster only] Probability of skipping the dropout - procedure during a boosting iteration. + skip_drop: [dart booster only] Probability of skipping the dropout procedure + during a boosting iteration. num_parallel_tree: Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. - feature_selector: [linear booster only] Feature selection and - ordering method. - top_k: The number of top features to select in greedy and thrifty - feature selector. The value of 0 means using all the features. - max_cat_to_onehot: A threshold for deciding whether XGBoost should - use one-hot encoding based split for categorical data. + feature_selector: [linear booster only] Feature selection and ordering + method. + top_k: The number of top features to select in greedy and thrifty feature + selector. The value of 0 means using all the features. + max_cat_to_onehot: A threshold for deciding whether XGBoost should use + one-hot encoding based split for categorical data. max_leaves: Maximum number of nodes to be added. max_bin: Maximum number of discrete bins to bucket continuous features. tweedie_variance_power: Parameter that controls the variance of the Tweedie @@ -4103,8 +3484,8 @@ def get_xgboost_trainer_pipeline_and_parameters( server for batch predict components during evaluation. evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. 
evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for @@ -4122,171 +3503,111 @@ def get_xgboost_trainer_pipeline_and_parameters( """ parameter_values = {} training_and_eval_parameters = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'objective': - objective, - 'eval_metric': - eval_metric, - 'num_boost_round': - num_boost_round, - 'early_stopping_rounds': - early_stopping_rounds, - 'base_score': - base_score, - 'disable_default_eval_metric': - disable_default_eval_metric, - 'seed': - seed, - 'seed_per_iteration': - seed_per_iteration, - 'booster': - booster, - 'eta': - eta, - 'gamma': - gamma, - 'max_depth': - max_depth, - 'min_child_weight': - min_child_weight, - 'max_delta_step': - max_delta_step, - 'subsample': - subsample, - 'colsample_bytree': - colsample_bytree, - 'colsample_bylevel': - colsample_bylevel, - 'colsample_bynode': - colsample_bynode, - 'reg_lambda': - reg_lambda, - 'reg_alpha': - reg_alpha, - 'tree_method': - tree_method, - 'scale_pos_weight': - scale_pos_weight, - 'updater': - updater, - 'refresh_leaf': - refresh_leaf, - 'process_type': - process_type, - 'grow_policy': - grow_policy, - 'sampling_method': - sampling_method, - 'monotone_constraints': - monotone_constraints, - 'interaction_constraints': - interaction_constraints, - 'sample_type': - sample_type, - 'normalize_type': - normalize_type, - 'rate_drop': - rate_drop, - 'one_drop': - one_drop, - 'skip_drop': - skip_drop, - 'num_parallel_tree': - num_parallel_tree, - 'feature_selector': - feature_selector, - 'top_k': - top_k, - 'max_cat_to_onehot': - max_cat_to_onehot, - 'max_leaves': - max_leaves, - 'max_bin': - max_bin, - 'tweedie_variance_power': - tweedie_variance_power, - 'huber_slope': - huber_slope, - 'weight_column': - weight_column, - 'training_machine_type': - training_machine_type, - 'training_total_replica_count': - training_total_replica_count, - 'training_accelerator_type': - training_accelerator_type, - 'training_accelerator_count': - training_accelerator_count, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'objective': objective, + 'eval_metric': eval_metric, + 'num_boost_round': num_boost_round, + 'early_stopping_rounds': early_stopping_rounds, + 'base_score': base_score, + 'disable_default_eval_metric': disable_default_eval_metric, + 'seed': seed, + 'seed_per_iteration': seed_per_iteration, + 
'booster': booster, + 'eta': eta, + 'gamma': gamma, + 'max_depth': max_depth, + 'min_child_weight': min_child_weight, + 'max_delta_step': max_delta_step, + 'subsample': subsample, + 'colsample_bytree': colsample_bytree, + 'colsample_bylevel': colsample_bylevel, + 'colsample_bynode': colsample_bynode, + 'reg_lambda': reg_lambda, + 'reg_alpha': reg_alpha, + 'tree_method': tree_method, + 'scale_pos_weight': scale_pos_weight, + 'updater': updater, + 'refresh_leaf': refresh_leaf, + 'process_type': process_type, + 'grow_policy': grow_policy, + 'sampling_method': sampling_method, + 'monotone_constraints': monotone_constraints, + 'interaction_constraints': interaction_constraints, + 'sample_type': sample_type, + 'normalize_type': normalize_type, + 'rate_drop': rate_drop, + 'one_drop': one_drop, + 'skip_drop': skip_drop, + 'num_parallel_tree': num_parallel_tree, + 'feature_selector': feature_selector, + 'top_k': top_k, + 'max_cat_to_onehot': max_cat_to_onehot, + 'max_leaves': max_leaves, + 'max_bin': max_bin, + 'tweedie_variance_power': tweedie_variance_power, + 'huber_slope': huber_slope, + 'weight_column': weight_column, + 'training_machine_type': training_machine_type, + 'training_total_replica_count': training_total_replica_count, + 'training_accelerator_type': training_accelerator_type, + 'training_accelerator_count': training_accelerator_count, + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } _update_parameters(parameter_values, training_and_eval_parameters) fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if 
dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + if tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) @@ -4298,7 +3619,8 @@ def get_xgboost_trainer_pipeline_and_parameters( _update_parameters(parameter_values, data_source_and_split_parameters) pipeline_definition_path = os.path.join( - pathlib.Path(__file__).parent.resolve(), 'xgboost_trainer_pipeline.yaml') + pathlib.Path(__file__).parent.resolve(), 'xgboost_trainer_pipeline.yaml' + ) return pipeline_definition_path, parameter_values @@ -4318,8 +3640,9 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( disable_default_eval_metric: Optional[int] = None, seed: Optional[int] = None, seed_per_iteration: Optional[bool] = None, - dataset_level_custom_transformation_definitions: Optional[List[Dict[ - str, Any]]] = None, + dataset_level_custom_transformation_definitions: Optional[ + List[Dict[str, Any]] + ] = None, dataset_level_transformations: Optional[List[Dict[str, Any]]] = None, run_feature_selection: Optional[bool] = None, feature_selection_algorithm: Optional[str] = None, @@ -4357,7 +3680,8 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( dataflow_service_account: Optional[str] = None, dataflow_subnetwork: Optional[str] = None, dataflow_use_public_ips: Optional[bool] = None, - encryption_spec_key_name: Optional[str] = None): + encryption_spec_key_name: Optional[str] = None, +): """Get the XGBoost HyperparameterTuningJob pipeline. Args: @@ -4365,8 +3689,8 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( location: The GCP region that runs the pipeline components. root_dir: The root GCS directory for the pipeline components. target_column: The target column name. - objective: Specifies the learning task and the learning - objective. Must be one of [reg:squarederror, reg:squaredlogerror, + objective: Specifies the learning task and the learning objective. Must be + one of [reg:squarederror, reg:squaredlogerror, reg:logistic, reg:gamma, reg:tweedie, reg:pseudohubererror, binary:logistic, multi:softprob]. study_spec_metric_id: Metric to optimize. For options, please look under @@ -4382,8 +3706,8 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( parameter specification of the metric. eval_metric: Evaluation metrics for validation data represented as a comma-separated string. - disable_default_eval_metric: Flag to disable default metric. Set to >0 - to disable. Default to 0. + disable_default_eval_metric: Flag to disable default metric. Set to >0 to + disable. Default to 0. seed: Random seed. seed_per_iteration: Seed PRNG determnisticly via iterator number. 
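
The surrounding hunks reformat get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters in the same way; it still assembles parameter_values from its keyword arguments and returns the packaged tuning pipeline YAML. A hedged sketch of pairing it with get_xgboost_study_spec_parameters_override (the import path, metric id, goal string, and resource values are assumptions for illustration):

# Illustrative sketch only: module path, metric id/goal, and values are assumed.
from google.cloud import aiplatform
from google_cloud_pipeline_components.experimental.automl.tabular import utils

# Default search space shipped with the package (configs/xgboost_params.json).
study_spec_parameters = utils.get_xgboost_study_spec_parameters_override()

template_path, parameter_values = (
    utils.get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters(
        project='my-project',                      # hypothetical project id
        location='us-central1',
        root_dir='gs://my-bucket/pipeline_root',   # hypothetical GCS root
        target_column='label',
        objective='binary:logistic',
        study_spec_metric_id='auc',                # assumed metric id
        study_spec_metric_goal='MAXIMIZE',         # assumed goal string
        max_trial_count=20,
        parallel_trial_count=5,
        study_spec_parameters_override=study_spec_parameters,
    )
)

aiplatform.PipelineJob(
    display_name='xgboost-hpt',
    template_path=template_path,
    parameter_values=parameter_values,
    pipeline_root='gs://my-bucket/pipeline_root',
).submit()
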
dataset_level_custom_transformation_definitions: Dataset-level custom @@ -4436,8 +3760,8 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( server for batch predict components during evaluation. evaluation_dataflow_machine_type: The dataflow machine type for evaluation components. - evaluation_dataflow_starting_num_workers: - The initial number of Dataflow workers for evaluation components. + evaluation_dataflow_starting_num_workers: The initial number of Dataflow + workers for evaluation components. evaluation_dataflow_max_num_workers: The max number of Dataflow workers for evaluation components. evaluation_dataflow_disk_size_gb: Dataflow worker's disk size in GB for @@ -4455,114 +3779,88 @@ def get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( """ parameter_values = {} training_and_eval_parameters = { - 'project': - project, - 'location': - location, - 'root_dir': - root_dir, - 'target_column': - target_column, - 'objective': - objective, - 'eval_metric': - eval_metric, - 'study_spec_metric_id': - study_spec_metric_id, - 'study_spec_metric_goal': - study_spec_metric_goal, - 'max_trial_count': - max_trial_count, - 'parallel_trial_count': - parallel_trial_count, - 'study_spec_parameters_override': + 'project': project, + 'location': location, + 'root_dir': root_dir, + 'target_column': target_column, + 'objective': objective, + 'eval_metric': eval_metric, + 'study_spec_metric_id': study_spec_metric_id, + 'study_spec_metric_goal': study_spec_metric_goal, + 'max_trial_count': max_trial_count, + 'parallel_trial_count': parallel_trial_count, + 'study_spec_parameters_override': ( study_spec_parameters_override - if study_spec_parameters_override else [], - 'disable_default_eval_metric': - disable_default_eval_metric, - 'seed': - seed, - 'seed_per_iteration': - seed_per_iteration, - 'weight_column': - weight_column, - 'max_failed_trial_count': - max_failed_trial_count, - 'training_machine_type': - training_machine_type, - 'training_total_replica_count': - training_total_replica_count, - 'training_accelerator_type': - training_accelerator_type, - 'training_accelerator_count': - training_accelerator_count, - 'study_spec_algorithm': - study_spec_algorithm, - 'study_spec_measurement_selection_type': - study_spec_measurement_selection_type, - 'transform_dataflow_machine_type': - transform_dataflow_machine_type, - 'transform_dataflow_max_num_workers': - transform_dataflow_max_num_workers, - 'transform_dataflow_disk_size_gb': - transform_dataflow_disk_size_gb, - 'run_evaluation': - run_evaluation, - 'evaluation_batch_predict_machine_type': - evaluation_batch_predict_machine_type, - 'evaluation_batch_predict_starting_replica_count': - evaluation_batch_predict_starting_replica_count, - 'evaluation_batch_predict_max_replica_count': - evaluation_batch_predict_max_replica_count, - 'evaluation_dataflow_machine_type': - evaluation_dataflow_machine_type, - 'evaluation_dataflow_starting_num_workers': - evaluation_dataflow_starting_num_workers, - 'evaluation_dataflow_max_num_workers': - evaluation_dataflow_max_num_workers, - 'evaluation_dataflow_disk_size_gb': - evaluation_dataflow_disk_size_gb, - 'dataflow_service_account': - dataflow_service_account, - 'dataflow_subnetwork': - dataflow_subnetwork, - 'dataflow_use_public_ips': - dataflow_use_public_ips, - 'encryption_spec_key_name': - encryption_spec_key_name, + if study_spec_parameters_override + else [] + ), + 'disable_default_eval_metric': disable_default_eval_metric, + 'seed': seed, + 'seed_per_iteration': seed_per_iteration, 
+ 'weight_column': weight_column, + 'max_failed_trial_count': max_failed_trial_count, + 'training_machine_type': training_machine_type, + 'training_total_replica_count': training_total_replica_count, + 'training_accelerator_type': training_accelerator_type, + 'training_accelerator_count': training_accelerator_count, + 'study_spec_algorithm': study_spec_algorithm, + 'study_spec_measurement_selection_type': ( + study_spec_measurement_selection_type + ), + 'transform_dataflow_machine_type': transform_dataflow_machine_type, + 'transform_dataflow_max_num_workers': transform_dataflow_max_num_workers, + 'transform_dataflow_disk_size_gb': transform_dataflow_disk_size_gb, + 'run_evaluation': run_evaluation, + 'evaluation_batch_predict_machine_type': ( + evaluation_batch_predict_machine_type + ), + 'evaluation_batch_predict_starting_replica_count': ( + evaluation_batch_predict_starting_replica_count + ), + 'evaluation_batch_predict_max_replica_count': ( + evaluation_batch_predict_max_replica_count + ), + 'evaluation_dataflow_machine_type': evaluation_dataflow_machine_type, + 'evaluation_dataflow_starting_num_workers': ( + evaluation_dataflow_starting_num_workers + ), + 'evaluation_dataflow_max_num_workers': ( + evaluation_dataflow_max_num_workers + ), + 'evaluation_dataflow_disk_size_gb': evaluation_dataflow_disk_size_gb, + 'dataflow_service_account': dataflow_service_account, + 'dataflow_subnetwork': dataflow_subnetwork, + 'dataflow_use_public_ips': dataflow_use_public_ips, + 'encryption_spec_key_name': encryption_spec_key_name, } _update_parameters(parameter_values, training_and_eval_parameters) fte_params = { - 'dataset_level_custom_transformation_definitions': + 'dataset_level_custom_transformation_definitions': ( dataset_level_custom_transformation_definitions - if dataset_level_custom_transformation_definitions else [], - 'dataset_level_transformations': - dataset_level_transformations - if dataset_level_transformations else [], - 'run_feature_selection': - run_feature_selection, - 'feature_selection_algorithm': - feature_selection_algorithm, - 'max_selected_features': - max_selected_features, - 'predefined_split_key': - predefined_split_key, - 'stratified_split_key': - stratified_split_key, - 'training_fraction': - training_fraction, - 'validation_fraction': - validation_fraction, - 'test_fraction': - test_fraction, - 'tf_auto_transform_features': - tf_auto_transform_features if tf_auto_transform_features else [], - 'tf_custom_transformation_definitions': + if dataset_level_custom_transformation_definitions + else [] + ), + 'dataset_level_transformations': ( + dataset_level_transformations if dataset_level_transformations else [] + ), + 'run_feature_selection': run_feature_selection, + 'feature_selection_algorithm': feature_selection_algorithm, + 'max_selected_features': max_selected_features, + 'predefined_split_key': predefined_split_key, + 'stratified_split_key': stratified_split_key, + 'training_fraction': training_fraction, + 'validation_fraction': validation_fraction, + 'test_fraction': test_fraction, + 'tf_auto_transform_features': ( + tf_auto_transform_features if tf_auto_transform_features else [] + ), + 'tf_custom_transformation_definitions': ( tf_custom_transformation_definitions - if tf_custom_transformation_definitions else [], - 'tf_transformations_path': - tf_transformations_path, + if tf_custom_transformation_definitions + else [] + ), + 'tf_transformations_path': tf_transformations_path, } _update_parameters(parameter_values, fte_params) @@ -4575,6 +3873,7 @@ def 
get_xgboost_hyperparameter_tuning_job_pipeline_and_parameters( pipeline_definition_path = os.path.join( pathlib.Path(__file__).parent.resolve(), - 'xgboost_hyperparameter_tuning_job_pipeline.yaml') + 'xgboost_hyperparameter_tuning_job_pipeline.yaml', + ) return pipeline_definition_path, parameter_values diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/bigquery/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/bigquery/__init__.py index 699974f5d28..7095e9930a7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/bigquery/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/bigquery/__init__.py @@ -48,86 +48,120 @@ ] BigqueryQueryJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'query_job/component.yaml')) + os.path.join(os.path.dirname(__file__), 'query_job/component.yaml') +) BigqueryCreateModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'create_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'create_model/component.yaml') +) BigqueryExportModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'export_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'export_model/component.yaml') +) BigqueryPredictModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'predict_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'predict_model/component.yaml') +) BigqueryExplainPredictModelJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'explain_predict_model/component.yaml')) + os.path.dirname(__file__), 'explain_predict_model/component.yaml' + ) +) BigqueryExplainForecastModelJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'explain_forecast_model/component.yaml')) + os.path.dirname(__file__), 'explain_forecast_model/component.yaml' + ) +) BigqueryEvaluateModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'evaluate_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'evaluate_model/component.yaml') +) BigqueryMLArimaCoefficientsJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'ml_arima_coefficients/component.yaml')) + os.path.dirname(__file__), 'ml_arima_coefficients/component.yaml' + ) +) BigqueryMLArimaEvaluateJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_arima_evaluate/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_arima_evaluate/component.yaml') +) BigqueryMLReconstructionLossJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'ml_reconstruction_loss/component.yaml')) + os.path.dirname(__file__), 'ml_reconstruction_loss/component.yaml' + ) +) BigqueryMLTrialInfoJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_trial_info/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_trial_info/component.yaml') +) BigqueryMLWeightsJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_weights/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_weights/component.yaml') +) BigqueryMLTrainingInfoJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_training_info/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_training_info/component.yaml') +) BigqueryMLAdvancedWeightsJobOp = load_component_from_file( 
os.path.join( - os.path.dirname(__file__), 'ml_advanced_weights/component.yaml')) + os.path.dirname(__file__), 'ml_advanced_weights/component.yaml' + ) +) BigqueryDropModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'drop_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'drop_model/component.yaml') +) BigqueryMLCentroidsJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_centroids/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_centroids/component.yaml') +) BigqueryMLConfusionMatrixJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'ml_confusion_matrix/component.yaml')) + os.path.dirname(__file__), 'ml_confusion_matrix/component.yaml' + ) +) BigqueryMLFeatureInfoJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_feature_info/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_feature_info/component.yaml') +) BigqueryMLRocCurveJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_roc_curve/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_roc_curve/component.yaml') +) BigqueryMLPrincipalComponentsJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'ml_principal_components/component.yaml')) + os.path.dirname(__file__), 'ml_principal_components/component.yaml' + ) +) BigqueryMLPrincipalComponentInfoJobOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), - 'ml_principal_component_info/component.yaml')) + os.path.dirname(__file__), 'ml_principal_component_info/component.yaml' + ) +) BigqueryMLFeatureImportanceJobOp = load_component_from_file( - os.path.join( - os.path.dirname(__file__), 'feature_importance/component.yaml')) + os.path.join(os.path.dirname(__file__), 'feature_importance/component.yaml') +) BigqueryMLRecommendJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'ml_recommend/component.yaml')) + os.path.join(os.path.dirname(__file__), 'ml_recommend/component.yaml') +) BigqueryForecastModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'forecast_model/component.yaml')) + os.path.join(os.path.dirname(__file__), 'forecast_model/component.yaml') +) BigqueryMLGlobalExplainJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'global_explain/component.yaml')) + os.path.join(os.path.dirname(__file__), 'global_explain/component.yaml') +) BigqueryDetectAnomaliesModelJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'detect_anomalies_model/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'detect_anomalies_model/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/component.py index 65d5d35995b..cc1ad13d79c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/component.py @@ -24,7 +24,7 @@ def custom_training_job( project: str, display_name: str, - gcp_resources: OutputPath(str), # type: ignore + gcp_resources: OutputPath(str), # type: ignore location: str = 'us-central1', worker_pool_specs: List[Dict[str, str]] = [], timeout: str = '604800s', @@ -38,6 +38,7 @@ def custom_training_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str 
= '', ): + # fmt: off """Launch a Custom training job using Vertex CustomJob API. Args: @@ -114,6 +115,7 @@ def custom_training_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/utils.py b/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/utils.py index bf7768c2044..ee61d682e44 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/custom_job/utils.py @@ -30,7 +30,8 @@ def _replace_executor_placeholder( - container_input: Sequence[str]) -> Sequence[str]: + container_input: Sequence[str], +) -> Sequence[str]: """Replace executor placeholder in container command or args. Args: @@ -41,7 +42,9 @@ def _replace_executor_placeholder( """ return [ _EXECUTOR_PLACEHOLDER_REPLACEMENT - if input == _EXECUTOR_PLACEHOLDER else input for input in container_input + if input == _EXECUTOR_PLACEHOLDER + else input + for input in container_input ] @@ -180,6 +183,7 @@ def create_custom_training_job_from_component( A Custom Job component operator corresponding to the input component operator. """ + # fmt: on # This function constructs a Custom Job component based on the input # component, by performing a 3-way merge of the inputs/outputs of the # input component, the Custom Job component and the arguments given to this @@ -215,27 +219,30 @@ def create_custom_training_job_from_component( # is returned. custom_training_job_dict = json_format.MessageToDict( - component.custom_training_job.pipeline_spec) + component.custom_training_job.pipeline_spec + ) input_component_spec_dict = json_format.MessageToDict( - component_spec.pipeline_spec) # pytype: disable=attribute-error - component_spec_container = list(input_component_spec_dict['deploymentSpec'] - ['executors'].values())[0]['container'] + component_spec.pipeline_spec + ) # pytype: disable=attribute-error + component_spec_container = list( + input_component_spec_dict['deploymentSpec']['executors'].values() + )[0]['container'] # Construct worker_pool_spec worker_pool_spec = { - 'machine_spec': { - 'machine_type': machine_type - }, + 'machine_spec': {'machine_type': machine_type}, 'replica_count': 1, 'container_spec': { 'image_uri': component_spec_container['image'], - } + }, } worker_pool_spec['container_spec']['command'] = _replace_executor_placeholder( - component_spec_container.get('command', [])) + component_spec_container.get('command', []) + ) worker_pool_spec['container_spec']['args'] = _replace_executor_placeholder( - component_spec_container.get('args', [])) + component_spec_container.get('args', []) + ) if accelerator_type: worker_pool_spec['machine_spec']['accelerator_type'] = accelerator_type @@ -246,7 +253,7 @@ def create_custom_training_job_from_component( 'boot_disk_size_gb': boot_disk_size_gb, } if nfs_mounts: - worker_pool_spec['nfs_mounts'] = nfs_mounts.copy() # pytype: disable=attribute-error + worker_pool_spec['nfs_mounts'] = nfs_mounts.copy() # pytype: disable=attribute-error worker_pool_specs = [worker_pool_spec] @@ -258,92 +265,132 @@ def create_custom_training_job_from_component( # Retrieve the custom job input/output parameters custom_training_job_dict_components = 
custom_training_job_dict['components'] custom_training_job_comp_key = list( - custom_training_job_dict_components.keys())[0] + custom_training_job_dict_components.keys() + )[0] custom_training_job_comp_val = custom_training_job_dict_components[ - custom_training_job_comp_key] + custom_training_job_comp_key + ] custom_job_input_params = custom_training_job_comp_val['inputDefinitions'][ - 'parameters'] + 'parameters' + ] custom_job_output_params = custom_training_job_comp_val['outputDefinitions'][ - 'parameters'] + 'parameters' + ] # Insert input arguments into custom_job_input_params as default values - custom_job_input_params['display_name'][ - 'defaultValue'] = display_name or component_spec.component_spec.name # pytype: disable=attribute-error + custom_job_input_params['display_name']['defaultValue'] = ( + display_name or component_spec.component_spec.name + ) # pytype: disable=attribute-error custom_job_input_params['worker_pool_specs'][ - 'defaultValue'] = worker_pool_specs + 'defaultValue' + ] = worker_pool_specs custom_job_input_params['timeout']['defaultValue'] = timeout custom_job_input_params['restart_job_on_worker_restart'][ - 'defaultValue'] = restart_job_on_worker_restart + 'defaultValue' + ] = restart_job_on_worker_restart custom_job_input_params['service_account']['defaultValue'] = service_account custom_job_input_params['tensorboard']['defaultValue'] = tensorboard custom_job_input_params['enable_web_access'][ - 'defaultValue'] = enable_web_access + 'defaultValue' + ] = enable_web_access custom_job_input_params['network']['defaultValue'] = network - custom_job_input_params['reserved_ip_ranges'][ - 'defaultValue'] = reserved_ip_ranges or [] + custom_job_input_params['reserved_ip_ranges']['defaultValue'] = ( + reserved_ip_ranges or [] + ) custom_job_input_params['base_output_directory'][ - 'defaultValue'] = base_output_directory + 'defaultValue' + ] = base_output_directory custom_job_input_params['labels']['defaultValue'] = labels or {} custom_job_input_params['encryption_spec_key_name'][ - 'defaultValue'] = encryption_spec_key_name + 'defaultValue' + ] = encryption_spec_key_name # Merge with the input/output parameters from the input component. input_component_spec_comp_val = list( - input_component_spec_dict['components'].values())[0] + input_component_spec_dict['components'].values() + )[0] custom_job_input_params = { - **(input_component_spec_comp_val.get('inputDefinitions', - {}).get('parameters', {})), - **custom_job_input_params + **( + input_component_spec_comp_val.get('inputDefinitions', {}).get( + 'parameters', {} + ) + ), + **custom_job_input_params, } custom_job_output_params = { - **(input_component_spec_comp_val.get('outputDefinitions', - {}).get('parameters', {})), - **custom_job_output_params + **( + input_component_spec_comp_val.get('outputDefinitions', {}).get( + 'parameters', {} + ) + ), + **custom_job_output_params, } # Copy merged input/output parameters to custom_training_job_dict # Using copy.deepcopy here to avoid anchors and aliases in the produced # YAML as a result of pointing to the same dict. 
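
The comment above is the reason for the deepcopy calls in the hunk that follows: yaml.safe_dump emits anchors and aliases whenever the same dict object is referenced from more than one place, and the merged parameter dict is written into both the root and the component input definitions. A small standalone illustration of that behaviour (not part of this change; the example dict is made up):

# Standalone illustration of the anchor/alias behaviour described above.
import copy
import yaml

params = {'timeout': {'defaultValue': '604800s'}}

shared = {'root': params, 'component': params}   # same dict object in two places
print(yaml.safe_dump(shared))                    # output contains "&id001" and "*id001"

copied = {'root': params, 'component': copy.deepcopy(params)}
print(yaml.safe_dump(copied))                    # plain, fully expanded mappings
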
- custom_training_job_dict['root']['inputDefinitions'][ - 'parameters'] = copy.deepcopy(custom_job_input_params) + custom_training_job_dict['root']['inputDefinitions']['parameters'] = ( + copy.deepcopy(custom_job_input_params) + ) custom_training_job_dict['components'][custom_training_job_comp_key][ - 'inputDefinitions']['parameters'] = copy.deepcopy(custom_job_input_params) + 'inputDefinitions' + ]['parameters'] = copy.deepcopy(custom_job_input_params) custom_training_job_tasks_key = list( - custom_training_job_dict['root']['dag']['tasks'].keys())[0] + custom_training_job_dict['root']['dag']['tasks'].keys() + )[0] custom_training_job_dict['root']['dag']['tasks'][ - custom_training_job_tasks_key]['inputs']['parameters'] = { - **(list(input_component_spec_dict['root']['dag']['tasks'].values()) - [0].get('inputs', {}).get('parameters', {})), - **(custom_training_job_dict['root']['dag']['tasks'] - [custom_training_job_tasks_key]['inputs']['parameters']) - } + custom_training_job_tasks_key + ]['inputs']['parameters'] = { + **( + list(input_component_spec_dict['root']['dag']['tasks'].values())[0] + .get('inputs', {}) + .get('parameters', {}) + ), + **( + custom_training_job_dict['root']['dag']['tasks'][ + custom_training_job_tasks_key + ]['inputs']['parameters'] + ), + } custom_training_job_dict['components'][custom_training_job_comp_key][ - 'outputDefinitions']['parameters'] = custom_job_output_params + 'outputDefinitions' + ]['parameters'] = custom_job_output_params # Retrieve the input/output artifacts from the input component. custom_job_input_artifacts = input_component_spec_comp_val.get( - 'inputDefinitions', {}).get('artifacts', {}) + 'inputDefinitions', {} + ).get('artifacts', {}) custom_job_output_artifacts = input_component_spec_comp_val.get( - 'outputDefinitions', {}).get('artifacts', {}) + 'outputDefinitions', {} + ).get('artifacts', {}) # Copy input/output artifacts from the input component to # custom_training_job_dict if custom_job_input_artifacts: - custom_training_job_dict['root']['inputDefinitions'][ - 'artifacts'] = copy.deepcopy(custom_job_input_artifacts) + custom_training_job_dict['root']['inputDefinitions']['artifacts'] = ( + copy.deepcopy(custom_job_input_artifacts) + ) custom_training_job_dict['components'][custom_training_job_comp_key][ - 'inputDefinitions']['artifacts'] = copy.deepcopy( - custom_job_input_artifacts) + 'inputDefinitions' + ]['artifacts'] = copy.deepcopy(custom_job_input_artifacts) custom_training_job_dict['root']['dag']['tasks'][ - custom_training_job_tasks_key]['inputs']['artifacts'] = { - **(list(input_component_spec_dict['root']['dag']['tasks'].values()) - [0].get('inputs', {}).get('artifacts', {})), - **(custom_training_job_dict['root']['dag']['tasks'] - [custom_training_job_tasks_key]['inputs'].get('artifacts', {})) - } + custom_training_job_tasks_key + ]['inputs']['artifacts'] = { + **( + list(input_component_spec_dict['root']['dag']['tasks'].values())[0] + .get('inputs', {}) + .get('artifacts', {}) + ), + **( + custom_training_job_dict['root']['dag']['tasks'][ + custom_training_job_tasks_key + ]['inputs'].get('artifacts', {}) + ), + } if custom_job_output_artifacts: custom_training_job_dict['components'][custom_training_job_comp_key][ - 'outputDefinitions']['artifacts'] = custom_job_output_artifacts + 'outputDefinitions' + ]['artifacts'] = custom_job_output_artifacts # Create new component from component IR YAML custom_training_job_yaml = yaml.safe_dump(custom_training_job_dict) @@ -353,13 +400,16 @@ def 
create_custom_training_job_from_component( # TODO(b/262360354): The inner .component_spec.name is needed here as that is # the name that is retrieved by the FE for display. Can simply reference the # outer .name once setter is implemented. - new_component.component_spec.name = component_spec.component_spec.name # pytype: disable=attribute-error - - if component_spec.description: # pytype: disable=attribute-error + new_component.component_spec.name = component_spec.component_spec.name # pytype: disable=attribute-error + if component_spec.description: # pytype: disable=attribute-error # TODO(chavoshi) Add support for docstring parsing. component_description = 'A custom job that wraps ' - component_description += f'{component_spec.component_spec.name}.\n\nOriginal component' # pytype: disable=attribute-error - component_description += f' description:\n{component_spec.description}\n\nCustom' # pytype: disable=attribute-error + component_description += ( # pytype: disable=attribute-error + f'{component_spec.component_spec.name}.\n\nOriginal component' + ) + component_description += ( # pytype: disable=attribute-error + f' description:\n{component_spec.description}\n\nCustom' + ) component_description += ' Job wrapper description:\n' component_description += component.custom_training_job.description @@ -381,6 +431,7 @@ def create_custom_training_job_op_from_component(*args, **kwargs) -> Callable: A Custom Job component operator corresponding to the input component operator. """ + # fmt: on logging.warning( 'Deprecated. Please use create_custom_training_job_from_component' diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/__init__.py index f28a27b5444..16c70ac3e55 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/__init__.py @@ -24,12 +24,20 @@ # TODO(wwoo): remove try block after experimental components are migrated to v2. try: - from .flex_template import component as dataflow_flex_template_component # type: ignore - DataflowFlexTemplateJobOp = dataflow_flex_template_component.dataflow_flex_template + from .flex_template import component as dataflow_flex_template_component # type: ignore + + DataflowFlexTemplateJobOp = ( + dataflow_flex_template_component.dataflow_flex_template + ) except ImportError: + def _raise_unsupported(*args, **kwargs): - raise ImportError('DataflowFlexTemplateJobOp requires KFP SDK v2.0.0b1 or higher.') + raise ImportError( + 'DataflowFlexTemplateJobOp requires KFP SDK v2.0.0b1 or higher.' 
+ ) + DataflowFlexTemplateJobOp = _raise_unsupported DataflowPythonJobOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'python_job/component.yaml')) + os.path.join(os.path.dirname(__file__), 'python_job/component.yaml') +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/flex_template/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/flex_template/component.py index df6f0d1e5ea..155a976c72d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/flex_template/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/dataflow/flex_template/component.py @@ -56,6 +56,7 @@ def dataflow_flex_template( transform_name_mappings: Dict[str, str] = {}, validate_only: bool = False, ): + # fmt: off """Launch a job with a Dataflow Flex Template. Args: @@ -190,6 +191,7 @@ def dataflow_flex_template( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ @@ -207,40 +209,88 @@ def dataflow_flex_template( location, '--payload', ConcatPlaceholder([ - '{', '"launch_parameter": {', - '"job_name": "', job_name, '"', - ', "container_spec_gcs_path": "', container_spec_gcs_path, '"', - ', "parameters": ', parameters, - ', "launch_options": ', launch_options, + '{', + '"launch_parameter": {', + '"job_name": "', + job_name, + '"', + ', "container_spec_gcs_path": "', + container_spec_gcs_path, + '"', + ', "parameters": ', + parameters, + ', "launch_options": ', + launch_options, ', "environment": {', - '"num_workers": ', num_workers, - ', "max_workers": ', max_workers, - ', "service_account_email": "', service_account_email, '"', - ', "temp_location": "', temp_location, '"', - ', "machine_type": "', machine_type, '"', - ', "additional_experiments": ', additional_experiments, - ', "network": "', network, '"', - ', "subnetwork": "', subnetwork, '"', - ', "additional_user_labels": ', additional_user_labels, - ', "kms_key_name": "', kms_key_name, '"', - ', "ip_configuration": "', ip_configuration, '"', - ', "worker_region": "', worker_region, '"', - ', "worker_zone": "', worker_zone, '"', - ', "enable_streaming_engine": ', enable_streaming_engine, - ', "flexrs_goal": "', flexrs_goal, '"', - ', "staging_location": "', staging_location, '"', - ', "sdk_container_image": "', sdk_container_image, '"', - ', "disk_size_gb": ', disk_size_gb, - ', "autoscaling_algorithm": "', autoscaling_algorithm, '"', - ', "dump_heap_on_oom": ', dump_heap_on_oom, - ', "save_heap_dumps_to_gcs_path": "', save_heap_dumps_to_gcs_path, '"', - ', "launcher_machine_type": "', launcher_machine_type, '"', - ', "enable_launcher_vm_serial_port_logging": ', enable_launcher_vm_serial_port_logging, + '"num_workers": ', + num_workers, + ', "max_workers": ', + max_workers, + ', "service_account_email": "', + service_account_email, + '"', + ', "temp_location": "', + temp_location, + '"', + ', "machine_type": "', + machine_type, + '"', + ', "additional_experiments": ', + additional_experiments, + ', "network": "', + network, + '"', + ', "subnetwork": "', + subnetwork, + '"', + ', "additional_user_labels": ', + additional_user_labels, + ', "kms_key_name": "', + kms_key_name, + '"', + ', "ip_configuration": "', + ip_configuration, + '"', + ', "worker_region": "', + worker_region, + 
'"', + ', "worker_zone": "', + worker_zone, + '"', + ', "enable_streaming_engine": ', + enable_streaming_engine, + ', "flexrs_goal": "', + flexrs_goal, + '"', + ', "staging_location": "', + staging_location, + '"', + ', "sdk_container_image": "', + sdk_container_image, + '"', + ', "disk_size_gb": ', + disk_size_gb, + ', "autoscaling_algorithm": "', + autoscaling_algorithm, + '"', + ', "dump_heap_on_oom": ', + dump_heap_on_oom, + ', "save_heap_dumps_to_gcs_path": "', + save_heap_dumps_to_gcs_path, + '"', + ', "launcher_machine_type": "', + launcher_machine_type, + '"', + ', "enable_launcher_vm_serial_port_logging": ', + enable_launcher_vm_serial_port_logging, '}', - ', "update": ', update, - ', "transform_name_mappings": ', transform_name_mappings, + ', "update": ', + update, + ', "transform_name_mappings": ', + transform_name_mappings, '}', - ', "validate_only": ', validate_only, + ', "validate_only": ', + validate_only, '}', ]), '--gcp_resources', diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/dataproc/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/dataproc/__init__.py index e7dbc3e2d34..b0b606a0285 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/dataproc/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/dataproc/__init__.py @@ -21,17 +21,27 @@ 'DataprocPySparkBatchOp', 'DataprocSparkBatchOp', 'DataprocSparkRBatchOp', - 'DataprocSparkSqlBatchOp' + 'DataprocSparkSqlBatchOp', ] DataprocPySparkBatchOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'create_pyspark_batch/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'create_pyspark_batch/component.yaml' + ) +) DataprocSparkBatchOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'create_spark_batch/component.yaml')) + os.path.join(os.path.dirname(__file__), 'create_spark_batch/component.yaml') +) DataprocSparkRBatchOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'create_spark_r_batch/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'create_spark_r_batch/component.yaml' + ) +) DataprocSparkSqlBatchOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'create_spark_sql_batch/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'create_spark_sql_batch/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/dataset_preprocessor/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/dataset_preprocessor/component.py index 15c788dc385..a61b3f90272 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/dataset_preprocessor/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/dataset_preprocessor/component.py @@ -39,6 +39,7 @@ def dataset_preprocessor_error_analysis( test_dataset_storage_source_uris: list = [], training_dataset_storage_source_uris: list = [], ): + # fmt: off """Preprocesses datasets for Vision Error Analysis pipelines. Args: @@ -112,6 +113,7 @@ def dataset_preprocessor_error_analysis( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/model-evaluation:v0.9', command=['python3', '/main.py'], diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/error_analysis_annotation/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/error_analysis_annotation/component.py index e65924953e6..7fc08c084fa 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/error_analysis_annotation/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/error_analysis_annotation/component.py @@ -32,6 +32,7 @@ def error_analysis_annotation( num_neighbors: int = 5, encryption_spec_key_name: str = '', ): + # fmt: off """Computes error analysis annotations from image embeddings. Args: @@ -61,6 +62,7 @@ def error_analysis_annotation( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ @@ -85,14 +87,19 @@ def error_analysis_annotation( f' "error-analysis-annotation-{PIPELINE_JOB_ID_PLACEHOLDER}', f'-{PIPELINE_TASK_ID_PLACEHOLDER}", ', '"job_spec": {"worker_pool_specs": [{"replica_count":"1', - '", "machine_spec": {"machine_type": "', machine_type, '"},', + '", "machine_spec": {"machine_type": "', + machine_type, + '"},', ' "container_spec": {"image_uri":"', 'us-docker.pkg.dev/vertex-ai-restricted/vision-error-analysis/error-analysis:v0.2', - '", "args": ["--embeddings_dir=', embeddings_dir, + '", "args": ["--embeddings_dir=', + embeddings_dir, '", "--root_dir=', f'{root_dir}/{PIPELINE_JOB_ID_PLACEHOLDER}-{PIPELINE_TASK_ID_PLACEHOLDER}', - '", "--num_neighbors=', num_neighbors, - '", "--error_analysis_output_uri=', error_analysis_output_uri, + '", "--num_neighbors=', + num_neighbors, + '", "--error_analysis_output_uri=', + error_analysis_output_uri, '", "--executor_input={{$.json_escape[1]}}"]}}]}', ', "encryption_spec": {"kms_key_name":"', encryption_spec_key_name, diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/feature_extractor/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/feature_extractor/component.py index de31207ad2d..b67e86b7003 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/feature_extractor/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/evaluation/error_analysis/feature_extractor/component.py @@ -36,6 +36,7 @@ def feature_extractor_error_analysis( feature_extractor_machine_type: str = 'n1-standard-32', encryption_spec_key_name: str = '', ): + # fmt: off """Extracts feature embeddings of a dataset. Args: @@ -72,6 +73,7 @@ def feature_extractor_error_analysis( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/__init__.py index 562b8ef90af..20178bf1c76 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/__init__.py @@ -29,11 +29,15 @@ ] ForecastingPreprocessingOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'preprocess/component.yaml')) + os.path.join(os.path.dirname(__file__), 'preprocess/component.yaml') +) ForecastingValidationOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'validate/component.yaml')) + os.path.join(os.path.dirname(__file__), 'validate/component.yaml') +) ForecastingPrepareDataForTrainOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'prepare_data_for_train/component.yaml')) + os.path.dirname(__file__), 'prepare_data_for_train/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/prepare_data_for_train/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/prepare_data_for_train/component.py index ccac239d492..d1fc8d0f183 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/prepare_data_for_train/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/forecasting/prepare_data_for_train/component.py @@ -35,17 +35,14 @@ def prepare_data_for_train( AutoMLForecastingTrainingJobRunOp. Args: - input_tables (JsonArray): - Required. Serialized Json array that specifies input BigQuery tables and - specs. - preprocess_metadata (JsonObject): - Required. The output of ForecastingPreprocessingOp that is a serialized - dictionary with 2 fields: processed_bigquery_table_uri and - column_metadata. - model_feature_columns (JsonArray): - Optional. Serialized list of column names that will be used as input - feature in the training step. If None, all columns will be used in - training. + input_tables (JsonArray): Required. Serialized Json array that specifies + input BigQuery tables and specs. + preprocess_metadata (JsonObject): Required. The output of + ForecastingPreprocessingOp that is a serialized dictionary with 2 fields: + processed_bigquery_table_uri and column_metadata. + model_feature_columns (JsonArray): Optional. Serialized list of column names + that will be used as input feature in the training step. If None, all + columns will be used in training. 
Returns: NamedTuple: @@ -84,7 +81,8 @@ def prepare_data_for_train( bigquery_table_uri = preprocess_metadata['processed_bigquery_table_uri'] primary_table_specs = next( - table for table in input_tables + table + for table in input_tables if table['table_type'] == 'FORECASTING_PRIMARY' ) primary_metadata = primary_table_specs['forecasting_primary_table_metadata'] @@ -99,11 +97,15 @@ def prepare_data_for_train( unavailable_at_forecast_columns = [] column_transformations = [] predefined_split_column = ( - '' if 'predefined_splits_column' not in primary_metadata - else primary_metadata['predefined_splits_column']) + '' + if 'predefined_splits_column' not in primary_metadata + else primary_metadata['predefined_splits_column'] + ) weight_column = ( - '' if 'weight_column' not in primary_metadata - else primary_metadata['weight_column']) + '' + if 'weight_column' not in primary_metadata + else primary_metadata['weight_column'] + ) for name, details in column_metadata.items(): if name == predefined_split_column or name == weight_column: @@ -113,9 +115,7 @@ def prepare_data_for_train( time_series_identifier_column = name continue elif details['tag'] == 'primary_table': - if name in ( - primary_metadata['time_series_identifier_columns'] - ): + if name in (primary_metadata['time_series_identifier_columns']): time_series_attribute_columns.append(name) elif name == primary_metadata['target_column']: unavailable_at_forecast_columns.append(name) @@ -146,20 +146,23 @@ def prepare_data_for_train( trans = {'timestamp': {'column_name': name}} column_transformations.append(trans) - return NamedTuple('Outputs', [ - ('time_series_identifier_column', str), - ('time_series_attribute_columns', list), - ('available_at_forecast_columns', list), - ('unavailable_at_forecast_columns', list), - ('column_transformations', list), - ('preprocess_bq_uri', str), - ('target_column', str), - ('time_column', str), - ('predefined_split_column', str), - ('weight_column', str), - ('data_granularity_unit', str), - ('data_granularity_count', int), - ])( + return NamedTuple( + 'Outputs', + [ + ('time_series_identifier_column', str), + ('time_series_attribute_columns', list), + ('available_at_forecast_columns', list), + ('unavailable_at_forecast_columns', list), + ('column_transformations', list), + ('preprocess_bq_uri', str), + ('target_column', str), + ('time_column', str), + ('predefined_split_column', str), + ('weight_column', str), + ('data_granularity_unit', str), + ('data_granularity_count', int), + ], + )( time_series_identifier_column, time_series_attribute_columns, available_at_forecast_columns, diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/__init__.py index 0fa104da450..5fe2ccff39e 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/__init__.py @@ -32,4 +32,5 @@ HyperparameterTuningJobRunOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'component.yaml')) + os.path.join(os.path.dirname(__file__), 'component.yaml') +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/utils.py b/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/utils.py index 
48776784cf6..aecece13056 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/hyperparameter_tuning_job/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# fmt: off """Module for supporting Google Vertex AI Hyperparameter Tuning Job Op.""" from google.cloud.aiplatform_v1.types import study diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/__init__.py index f78984080ca..04309fa4a9d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/__init__.py @@ -18,10 +18,10 @@ from kfp.components import load_component_from_file -__all__ = [ - 'ConvertDatasetExportForBatchPredictOp', - 'TrainTextClassificationOp' -] +__all__ = ['ConvertDatasetExportForBatchPredictOp', 'TrainTextClassificationOp'] TrainTextClassificationOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'train_text_classification/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'train_text_classification/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/convert_dataset_export_for_batch_predict/component.py b/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/convert_dataset_export_for_batch_predict/component.py index 53e6d8c25e0..09d7c4e0294 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/convert_dataset_export_for_batch_predict/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/natural_language/convert_dataset_export_for_batch_predict/component.py @@ -18,10 +18,12 @@ @dsl.component( - base_image="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-8:latest") + base_image="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-8:latest" +) def convert_dataset_export_for_batch_predict( - file_paths: List[str], classification_type: str, - output_dir: dsl.OutputPath(list) + file_paths: List[str], + classification_type: str, + output_dir: dsl.OutputPath(list), ) -> NamedTuple("Outputs", [("output_files", list)]): """Converts classification dataset export for batch prediction input. @@ -86,7 +88,8 @@ def convert_dataset_export_for_batch_predict( ] else: result_obj[LABELS_KEY] = json_obj[CLASSIFICATION_ANNOTATION_KEY][ - DISPLAY_NAME_KEY] + DISPLAY_NAME_KEY + ] results_file.write(json.dumps(result_obj) + "\n") # Subsequent components will not understand "/gcs/" prefix. Convert to use # "gs://" prefix for compatibility. 
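The `__init__.py` files touched above all follow the same pattern: a prebuilt component is loaded from its `component.yaml` with `kfp.components.load_component_from_file` and exposed as a module-level Op, and the diff only rewraps those calls. A minimal sketch of how such a loaded op is typically used is below; the YAML path, the op name, the pipeline name, and the `project`/`location` parameters are illustrative assumptions, not names taken from this diff.

import os

from kfp import compiler, dsl
from kfp.components import load_component_from_file

# Hypothetical component YAML, laid out the same way as the components in this diff.
example_op = load_component_from_file(
    os.path.join(os.path.dirname(__file__), 'example/component.yaml')
)


@dsl.pipeline(name='example-pipeline')
def example_pipeline(project: str, location: str = 'us-central1'):
  # The loaded op is called like any other component; its parameter names are
  # defined by the inputs section of the component.yaml (assumed here).
  example_op(project=project, location=location)


if __name__ == '__main__':
  compiler.Compiler().compile(example_pipeline, 'example_pipeline.yaml')

The wrapping parentheses introduced throughout these `__init__.py` hunks are purely cosmetic; the loaded ops behave exactly as before.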
diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/notebooks/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/notebooks/__init__.py index c69d7dfad7a..2f401b7b222 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/notebooks/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/notebooks/__init__.py @@ -23,4 +23,5 @@ ] NotebooksExecutorOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'executor/component.yaml')) + os.path.join(os.path.dirname(__file__), 'executor/component.yaml') +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/sklearn/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/sklearn/__init__.py index 8f61410c64b..f9c96e31fba 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/sklearn/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/sklearn/__init__.py @@ -17,9 +17,10 @@ from kfp.components import load_component_from_file -__all__ = [ - 'SklearnTrainTestSplitJsonlOp' -] +__all__ = ['SklearnTrainTestSplitJsonlOp'] SklearnTrainTestSplitJsonlOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'train_test_split_jsonl/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'train_test_split_jsonl/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/tensorboard/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/tensorboard/__init__.py index 2a988228f1a..2935ec7491b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/tensorboard/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/tensorboard/__init__.py @@ -22,4 +22,7 @@ ] TensorboardExperimentCreatorOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'tensorboard_experiment_creator.yaml')) + os.path.join( + os.path.dirname(__file__), 'tensorboard_experiment_creator.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/text_classification/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/text_classification/__init__.py index 082efeaaefb..e336d6a357d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/text_classification/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/text_classification/__init__.py @@ -17,9 +17,10 @@ from kfp.components import load_component_from_file -__all__ = [ - 'TextClassificationTrainingOp' -] +__all__ = ['TextClassificationTrainingOp'] TextClassificationTrainingOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'train_tensorflow_model/component.yaml')) + os.path.join( + os.path.dirname(__file__), 'train_tensorflow_model/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/vertex_notification_email/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/vertex_notification_email/__init__.py index 35be44e4c8d..93c1193006b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/vertex_notification_email/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/vertex_notification_email/__init__.py @@ -22,4 +22,5 @@ ] 
VertexNotificationEmailOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'component.yaml')) + os.path.join(os.path.dirname(__file__), 'component.yaml') +) diff --git a/components/google-cloud/google_cloud_pipeline_components/experimental/wait_gcp_resources/__init__.py b/components/google-cloud/google_cloud_pipeline_components/experimental/wait_gcp_resources/__init__.py index e05e77cb293..0f5e69b7520 100644 --- a/components/google-cloud/google_cloud_pipeline_components/experimental/wait_gcp_resources/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/experimental/wait_gcp_resources/__init__.py @@ -23,4 +23,5 @@ ] WaitGcpResourcesOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'component.yaml')) + os.path.join(os.path.dirname(__file__), 'component.yaml') +) diff --git a/components/google-cloud/google_cloud_pipeline_components/proto/gcp_resources_pb2.py b/components/google-cloud/google_cloud_pipeline_components/proto/gcp_resources_pb2.py index 41e6ee74b3d..a3ae9009502 100644 --- a/components/google-cloud/google_cloud_pipeline_components/proto/gcp_resources_pb2.py +++ b/components/google-cloud/google_cloud_pipeline_components/proto/gcp_resources_pb2.py @@ -15,136 +15,220 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='gcp_resources.proto', - package='gcp_launcher', - syntax='proto3', - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x13gcp_resources.proto\x12\x0cgcp_launcher\x1a\x17google/rpc/status.proto\"\xe0\x01\n\x0cGcpResources\x12\x36\n\tresources\x18\x01 \x03(\x0b\x32#.gcp_launcher.GcpResources.Resource\x1a\x97\x01\n\x08Resource\x12\x1a\n\rresource_type\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0cresource_uri\x18\x02 \x01(\tH\x01\x88\x01\x01\x12!\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x12.google.rpc.Status\x12\x0e\n\x06labels\x18\x04 \x03(\tB\x10\n\x0e_resource_typeB\x0f\n\r_resource_urib\x06proto3' - , - dependencies=[google_dot_rpc_dot_status__pb2.DESCRIPTOR,]) - - + name='gcp_resources.proto', + package='gcp_launcher', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=( + b'\n\x13gcp_resources.proto\x12\x0cgcp_launcher\x1a\x17google/rpc/status.proto"\xe0\x01\n\x0cGcpResources\x12\x36\n\tresources\x18\x01' + b' \x03(\x0b\x32#.gcp_launcher.GcpResources.Resource\x1a\x97\x01\n\x08Resource\x12\x1a\n\rresource_type\x18\x01' + b' \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0cresource_uri\x18\x02' + b' \x01(\tH\x01\x88\x01\x01\x12!\n\x05\x65rror\x18\x03' + b' \x01(\x0b\x32\x12.google.rpc.Status\x12\x0e\n\x06labels\x18\x04' + b' \x03(\tB\x10\n\x0e_resource_typeB\x0f\n\r_resource_urib\x06proto3' + ), + dependencies=[ + google_dot_rpc_dot_status__pb2.DESCRIPTOR, + ], +) _GCPRESOURCES_RESOURCE = _descriptor.Descriptor( - name='Resource', - full_name='gcp_launcher.GcpResources.Resource', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='resource_type', full_name='gcp_launcher.GcpResources.Resource.resource_type', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='resource_uri', 
full_name='gcp_launcher.GcpResources.Resource.resource_uri', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='error', full_name='gcp_launcher.GcpResources.Resource.error', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='labels', full_name='gcp_launcher.GcpResources.Resource.labels', index=3, - number=4, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='_resource_type', full_name='gcp_launcher.GcpResources.Resource._resource_type', - index=0, containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[]), - _descriptor.OneofDescriptor( - name='_resource_uri', full_name='gcp_launcher.GcpResources.Resource._resource_uri', - index=1, containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[]), - ], - serialized_start=136, - serialized_end=287, + name='Resource', + full_name='gcp_launcher.GcpResources.Resource', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='resource_type', + full_name='gcp_launcher.GcpResources.Resource.resource_type', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b''.decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='resource_uri', + full_name='gcp_launcher.GcpResources.Resource.resource_uri', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b''.decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='error', + full_name='gcp_launcher.GcpResources.Resource.error', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='labels', + full_name='gcp_launcher.GcpResources.Resource.labels', + index=3, + number=4, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + 
containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='_resource_type', + full_name='gcp_launcher.GcpResources.Resource._resource_type', + index=0, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[], + ), + _descriptor.OneofDescriptor( + name='_resource_uri', + full_name='gcp_launcher.GcpResources.Resource._resource_uri', + index=1, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[], + ), + ], + serialized_start=136, + serialized_end=287, ) _GCPRESOURCES = _descriptor.Descriptor( - name='GcpResources', - full_name='gcp_launcher.GcpResources', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='resources', full_name='gcp_launcher.GcpResources.resources', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[_GCPRESOURCES_RESOURCE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=63, - serialized_end=287, + name='GcpResources', + full_name='gcp_launcher.GcpResources', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='resources', + full_name='gcp_launcher.GcpResources.resources', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[ + _GCPRESOURCES_RESOURCE, + ], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=63, + serialized_end=287, ) # pytype: disable=module-attr -_GCPRESOURCES_RESOURCE.fields_by_name['error'].message_type = google_dot_rpc_dot_status__pb2._STATUS +_GCPRESOURCES_RESOURCE.fields_by_name['error'].message_type = ( + google_dot_rpc_dot_status__pb2._STATUS +) _GCPRESOURCES_RESOURCE.containing_type = _GCPRESOURCES _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_type'].fields.append( - _GCPRESOURCES_RESOURCE.fields_by_name['resource_type']) -_GCPRESOURCES_RESOURCE.fields_by_name['resource_type'].containing_oneof = _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_type'] + _GCPRESOURCES_RESOURCE.fields_by_name['resource_type'] +) +_GCPRESOURCES_RESOURCE.fields_by_name['resource_type'].containing_oneof = ( + _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_type'] +) _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_uri'].fields.append( - _GCPRESOURCES_RESOURCE.fields_by_name['resource_uri']) -_GCPRESOURCES_RESOURCE.fields_by_name['resource_uri'].containing_oneof = _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_uri'] + 
_GCPRESOURCES_RESOURCE.fields_by_name['resource_uri'] +) +_GCPRESOURCES_RESOURCE.fields_by_name['resource_uri'].containing_oneof = ( + _GCPRESOURCES_RESOURCE.oneofs_by_name['_resource_uri'] +) _GCPRESOURCES.fields_by_name['resources'].message_type = _GCPRESOURCES_RESOURCE DESCRIPTOR.message_types_by_name['GcpResources'] = _GCPRESOURCES _sym_db.RegisterFileDescriptor(DESCRIPTOR) -GcpResources = _reflection.GeneratedProtocolMessageType('GcpResources', (_message.Message,), { - - 'Resource' : _reflection.GeneratedProtocolMessageType('Resource', (_message.Message,), { - 'DESCRIPTOR' : _GCPRESOURCES_RESOURCE, - '__module__' : 'gcp_resources_pb2' - # @@protoc_insertion_point(class_scope:gcp_launcher.GcpResources.Resource) - }) - , - 'DESCRIPTOR' : _GCPRESOURCES, - '__module__' : 'gcp_resources_pb2' - # @@protoc_insertion_point(class_scope:gcp_launcher.GcpResources) - }) +GcpResources = _reflection.GeneratedProtocolMessageType( + 'GcpResources', + (_message.Message,), + { + 'Resource': _reflection.GeneratedProtocolMessageType( + 'Resource', + (_message.Message,), + { + 'DESCRIPTOR': _GCPRESOURCES_RESOURCE, + '__module__': 'gcp_resources_pb2', + # @@protoc_insertion_point(class_scope:gcp_launcher.GcpResources.Resource) + }, + ), + 'DESCRIPTOR': _GCPRESOURCES, + '__module__': 'gcp_resources_pb2', + # @@protoc_insertion_point(class_scope:gcp_launcher.GcpResources) + }, +) _sym_db.RegisterMessage(GcpResources) _sym_db.RegisterMessage(GcpResources.Resource) diff --git a/components/google-cloud/google_cloud_pipeline_components/types/artifact_types.py b/components/google-cloud/google_cloud_pipeline_components/types/artifact_types.py index b2b7e8a3e00..a7adf4e323d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/types/artifact_types.py +++ b/components/google-cloud/google_cloud_pipeline_components/types/artifact_types.py @@ -31,6 +31,7 @@ def add_type_name(cls): return add_type_name + @google_artifact('google.VertexModel') class VertexModel(dsl.Artifact): """An artifact representing a Vertex Model.""" @@ -38,21 +39,22 @@ class VertexModel(dsl.Artifact): def __init__(self, name: str, uri: str, model_resource_name: str): """Args: - name: The artifact name. - uri: the Vertex Model resource uri, in a form of - https://{service-endpoint}/v1/projects/{project}/locations/{location}/models/{model}, - where - {service-endpoint} is one of the supported service endpoints at - https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints - model_resource_name: The name of the Model resource, in a form of - projects/{project}/locations/{location}/models/{model}. For - more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models/get + name: The artifact name. + uri: the Vertex Model resource uri, in a form of + https://{service-endpoint}/v1/projects/{project}/locations/{location}/models/{model}, + where + {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + model_resource_name: The name of the Model resource, in a form of + projects/{project}/locations/{location}/models/{model}. 
For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models/get """ super().__init__( uri=uri, name=name, - metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: model_resource_name}) + metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: model_resource_name}, + ) @google_artifact('google.VertexEndpoint') @@ -62,58 +64,61 @@ class VertexEndpoint(dsl.Artifact): def __init__(self, name: str, uri: str, endpoint_resource_name: str): """Args: - name: The artifact name. - uri: the Vertex Endpoint resource uri, in a form of - https://{service-endpoint}/v1/projects/{project}/locations/{location}/endpoints/{endpoint}, - where - {service-endpoint} is one of the supported service endpoints at - https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints - endpoint_resource_name: The name of the Endpoint resource, in a form of - projects/{project}/locations/{location}/endpoints/{endpoint}. For - more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/get + name: The artifact name. + uri: the Vertex Endpoint resource uri, in a form of + https://{service-endpoint}/v1/projects/{project}/locations/{location}/endpoints/{endpoint}, + where + {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + endpoint_resource_name: The name of the Endpoint resource, in a form of + projects/{project}/locations/{location}/endpoints/{endpoint}. For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/get """ super().__init__( uri=uri, name=name, - metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: endpoint_resource_name}) + metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: endpoint_resource_name}, + ) @google_artifact('google.VertexBatchPredictionJob') class VertexBatchPredictionJob(dsl.Artifact): """An artifact representing a Vertex BatchPredictionJob.""" - def __init__(self, - name: str, - uri: str, - job_resource_name: str, - bigquery_output_table: Optional[str] = None, - bigquery_output_dataset: Optional[str] = None, - gcs_output_directory: Optional[str] = None): + def __init__( + self, + name: str, + uri: str, + job_resource_name: str, + bigquery_output_table: Optional[str] = None, + bigquery_output_dataset: Optional[str] = None, + gcs_output_directory: Optional[str] = None, + ): """Args: - name: The artifact name. - uri: the Vertex Batch Prediction resource uri, in a form of - https://{service-endpoint}/v1/projects/{project}/locations/{location}/batchPredictionJobs/{batchPredictionJob}, - where {service-endpoint} is one of the supported service endpoints at - https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints - job_resource_name: The name of the batch prediction job resource, - in a form of - projects/{project}/locations/{location}/batchPredictionJobs/{batchPredictionJob}. - For more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs/get - bigquery_output_table: The name of the BigQuery table created, in - predictions_ format, into which the prediction output is - written. For more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo - bigquery_output_dataset: The path of the BigQuery dataset created, in - bq://projectId.bqDatasetId format, into which the prediction output is - written. 
For more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo - gcs_output_directory: The full path of the Cloud Storage directory - created, into which the prediction output is written. For more details, - see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo + name: The artifact name. + uri: the Vertex Batch Prediction resource uri, in a form of + https://{service-endpoint}/v1/projects/{project}/locations/{location}/batchPredictionJobs/{batchPredictionJob}, + where {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + job_resource_name: The name of the batch prediction job resource, + in a form of + projects/{project}/locations/{location}/batchPredictionJobs/{batchPredictionJob}. + For more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs/get + bigquery_output_table: The name of the BigQuery table created, in + predictions_ format, into which the prediction output is + written. For more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo + bigquery_output_dataset: The path of the BigQuery dataset created, in + bq://projectId.bqDatasetId format, into which the prediction output is + written. For more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo + gcs_output_directory: The full path of the Cloud Storage directory + created, into which the prediction output is written. For more details, + see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#outputinfo """ super().__init__( uri=uri, @@ -122,8 +127,9 @@ def __init__(self, ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: job_resource_name, 'bigqueryOutputTable': bigquery_output_table, 'bigqueryOutputDataset': bigquery_output_dataset, - 'gcsOutputDirectory': gcs_output_directory - }) + 'gcsOutputDirectory': gcs_output_directory, + }, + ) @google_artifact('google.VertexDataset') @@ -133,38 +139,40 @@ class VertexDataset(dsl.Artifact): def __init__(self, name: str, uri: str, dataset_resource_name: str): """Args: - name: The artifact name. - uri: the Vertex Dataset resource uri, in a form of - https://{service-endpoint}/v1/projects/{project}/locations/{location}/datasets/{datasets_name}, - where - {service-endpoint} is one of the supported service endpoints at - https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints - dataset_resource_name: The name of the Dataset resource, in a form of - projects/{project}/locations/{location}/datasets/{datasets_name}. For - more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.datasets/get + name: The artifact name. + uri: the Vertex Dataset resource uri, in a form of + https://{service-endpoint}/v1/projects/{project}/locations/{location}/datasets/{datasets_name}, + where + {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + dataset_resource_name: The name of the Dataset resource, in a form of + projects/{project}/locations/{location}/datasets/{datasets_name}. 
For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.datasets/get """ super().__init__( uri=uri, name=name, - metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: dataset_resource_name}) + metadata={ARTIFACT_PROPERTY_KEY_RESOURCE_NAME: dataset_resource_name}, + ) @google_artifact('google.BQMLModel') class BQMLModel(dsl.Artifact): """An artifact representing a BQML Model.""" - def __init__(self, name: str, project_id: str, dataset_id: str, - model_id: str): + def __init__( + self, name: str, project_id: str, dataset_id: str, model_id: str + ): """Args: - name: The artifact name. - project_id: The ID of the project containing this model. - dataset_id: The ID of the dataset containing this model. - model_id: The ID of the model. + name: The artifact name. + project_id: The ID of the project containing this model. + dataset_id: The ID of the dataset containing this model. + model_id: The ID of the model. - For more details, see - https://cloud.google.com/bigquery/docs/reference/rest/v2/models#ModelReference + For more details, see + https://cloud.google.com/bigquery/docs/reference/rest/v2/models#ModelReference """ super().__init__( uri=f'https://www.googleapis.com/bigquery/v2/projects/{project_id}/datasets/{dataset_id}/models/{model_id}', @@ -172,25 +180,27 @@ def __init__(self, name: str, project_id: str, dataset_id: str, metadata={ 'projectId': project_id, 'datasetId': dataset_id, - 'modelId': model_id - }) + 'modelId': model_id, + }, + ) @google_artifact('google.BQTable') class BQTable(dsl.Artifact): """An artifact representing a BQ Table.""" - def __init__(self, name: str, project_id: str, dataset_id: str, - table_id: str): + def __init__( + self, name: str, project_id: str, dataset_id: str, table_id: str + ): """Args: - name: The artifact name. - project_id: The ID of the project containing this table. - dataset_id: The ID of the dataset containing this table. - table_id: The ID of the table. + name: The artifact name. + project_id: The ID of the project containing this table. + dataset_id: The ID of the dataset containing this table. + table_id: The ID of the table. - For more details, see - https://cloud.google.com/bigquery/docs/reference/rest/v2/TableReference + For more details, see + https://cloud.google.com/bigquery/docs/reference/rest/v2/TableReference """ super().__init__( uri=f'https://www.googleapis.com/bigquery/v2/projects/{project_id}/datasets/{dataset_id}/tables/{table_id}', @@ -198,8 +208,9 @@ def __init__(self, name: str, project_id: str, dataset_id: str, metadata={ 'projectId': project_id, 'datasetId': dataset_id, - 'tableId': table_id - }) + 'tableId': table_id, + }, + ) @google_artifact('google.UnmanagedContainerModel') @@ -209,16 +220,18 @@ class UnmanagedContainerModel(dsl.Artifact): def __init__(self, predict_schemata: Dict, container_spec: Dict): """Args: - predict_schemata: Contains the schemata used in Model's predictions and - explanations via PredictionService.Predict, PredictionService.Explain - and BatchPredictionJob. For more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata - container_spec: Specification of a container for serving predictions. - Some fields in this message correspond to fields in the Kubernetes - Container v1 core specification. 
For more details, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec + predict_schemata: Contains the schemata used in Model's predictions and + explanations via PredictionService.Predict, PredictionService.Explain + and BatchPredictionJob. For more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata + container_spec: Specification of a container for serving predictions. + Some fields in this message correspond to fields in the Kubernetes + Container v1 core specification. For more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec """ - super().__init__(metadata={ - 'predictSchemata': predict_schemata, - 'containerSpec': container_spec - }) + super().__init__( + metadata={ + 'predictSchemata': predict_schemata, + 'containerSpec': container_spec, + } + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/batch_predict_job/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/batch_predict_job/component.py index a417f6c5879..5e2b30b3ba0 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/batch_predict_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/batch_predict_job/component.py @@ -58,6 +58,7 @@ def model_batch_predict( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a Google Cloud Vertex BatchPredictionJob and waits for it to complete. For more details, see @@ -248,6 +249,7 @@ def model_batch_predict( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/__init__.py index ca5c01f9a90..2d5a8436c36 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/__init__.py @@ -72,27 +72,59 @@ ] BigqueryCreateModelJobOp = create_model_component.bigquery_create_model_job -BigqueryDetectAnomaliesModelJobOp = detect_anomalies_model_component.bigquery_detect_anomalies_job +BigqueryDetectAnomaliesModelJobOp = ( + detect_anomalies_model_component.bigquery_detect_anomalies_job +) BigqueryDropModelJobOp = drop_model_component.bigquery_drop_model_job -BigqueryEvaluateModelJobOp = evaluate_model_component.bigquery_evaluate_model_job -BigqueryExplainForecastModelJobOp = explain_forecast_model_component.bigquery_explain_forecast_model_job -BigqueryExplainPredictModelJobOp = explain_predict_model.bigquery_explain_predict_model_job +BigqueryEvaluateModelJobOp = ( + evaluate_model_component.bigquery_evaluate_model_job +) +BigqueryExplainForecastModelJobOp = ( + explain_forecast_model_component.bigquery_explain_forecast_model_job +) +BigqueryExplainPredictModelJobOp = ( + explain_predict_model.bigquery_explain_predict_model_job +) BigqueryExportModelJobOp = export_model_component.bigquery_export_model_job -BigqueryForecastModelJobOp = forecast_model_component.bigquery_forecast_model_job -BigqueryMLAdvancedWeightsJobOp = ml_advanced_weights_component.bigquery_ml_advanced_weights_job -BigqueryMLArimaCoefficientsJobOp = ml_arima_coefficients_component.bigquery_ml_arima_coefficients 
-BigqueryMLArimaEvaluateJobOp = ml_arima_evaluate_component.bigquery_ml_arima_evaluate_job +BigqueryForecastModelJobOp = ( + forecast_model_component.bigquery_forecast_model_job +) +BigqueryMLAdvancedWeightsJobOp = ( + ml_advanced_weights_component.bigquery_ml_advanced_weights_job +) +BigqueryMLArimaCoefficientsJobOp = ( + ml_arima_coefficients_component.bigquery_ml_arima_coefficients +) +BigqueryMLArimaEvaluateJobOp = ( + ml_arima_evaluate_component.bigquery_ml_arima_evaluate_job +) BigqueryMLCentroidsJobOp = ml_centroids_component.bigquery_ml_centroids_job -BigqueryMLConfusionMatrixJobOp = ml_confusion_matrix_component.bigquery_ml_confusion_matrix_job -BigqueryMLFeatureImportanceJobOp = feature_importance_component.bigquery_ml_feature_importance_job -BigqueryMLFeatureInfoJobOp = ml_feature_info_component.bigquery_ml_feature_info_job -BigqueryMLGlobalExplainJobOp = global_explain_component.bigquery_ml_global_explain_job -BigqueryMLPrincipalComponentInfoJobOp = ml_principal_component_info_component.bigquery_ml_principal_component_info_job -BigqueryMLPrincipalComponentsJobOp = ml_principal_components_component.bigquery_ml_principal_components_job +BigqueryMLConfusionMatrixJobOp = ( + ml_confusion_matrix_component.bigquery_ml_confusion_matrix_job +) +BigqueryMLFeatureImportanceJobOp = ( + feature_importance_component.bigquery_ml_feature_importance_job +) +BigqueryMLFeatureInfoJobOp = ( + ml_feature_info_component.bigquery_ml_feature_info_job +) +BigqueryMLGlobalExplainJobOp = ( + global_explain_component.bigquery_ml_global_explain_job +) +BigqueryMLPrincipalComponentInfoJobOp = ( + ml_principal_component_info_component.bigquery_ml_principal_component_info_job +) +BigqueryMLPrincipalComponentsJobOp = ( + ml_principal_components_component.bigquery_ml_principal_components_job +) BigqueryMLRecommendJobOp = ml_recommend_component.bigquery_ml_recommend_job -BigqueryMLReconstructionLossJobOp = ml_reconstruction_loss_component.bigquery_ml_reconstruction_loss_job +BigqueryMLReconstructionLossJobOp = ( + ml_reconstruction_loss_component.bigquery_ml_reconstruction_loss_job +) BigqueryMLRocCurveJobOp = ml_roc_curve_component.bigquery_ml_roc_curve_job -BigqueryMLTrainingInfoJobOp = ml_training_info_component.bigquery_ml_training_info_job +BigqueryMLTrainingInfoJobOp = ( + ml_training_info_component.bigquery_ml_training_info_job +) BigqueryMLTrialInfoJobOp = ml_trial_info_component.bigquery_ml_trial_info_job BigqueryMLWeightsJobOp = ml_weights_component.bigquery_ml_weights_job BigqueryPredictModelJobOp = predict_model_component.bigquery_predict_model_job diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/create_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/create_model/component.py index adbc424343f..f05e8c86e0b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/create_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/create_model/component.py @@ -34,6 +34,7 @@ def bigquery_create_model_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery create model job and waits for it to finish. Args: @@ -73,6 +74,7 @@ def bigquery_create_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/detect_anomalies_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/detect_anomalies_model/component.py index f37cac26aa3..e2f83a205f5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/detect_anomalies_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/detect_anomalies_model/component.py @@ -40,6 +40,7 @@ def bigquery_detect_anomalies_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery detect anomalies model job and waits for it to finish. Args: @@ -114,6 +115,7 @@ def bigquery_detect_anomalies_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/drop_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/drop_model/component.py index 1b62000b41e..2db286c1f80 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/drop_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/drop_model/component.py @@ -32,6 +32,7 @@ def bigquery_drop_model_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery drop model job and waits for it to finish. Args: @@ -68,6 +69,7 @@ def bigquery_drop_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/evaluate_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/evaluate_model/component.py index 474afad2c2a..3130ad64a12 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/evaluate_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/evaluate_model/component.py @@ -39,6 +39,7 @@ def bigquery_evaluate_model_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery evaluate model job and waits for it to finish. Args: @@ -103,6 +104,7 @@ def bigquery_evaluate_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_forecast_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_forecast_model/component.py index f6cfa83ad2b..250a3d3c952 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_forecast_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_forecast_model/component.py @@ -38,6 +38,7 @@ def bigquery_explain_forecast_model_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.EXPLAIN_FORECAST job and let you explain forecast an ARIMA_PLUS or ARIMA model. This function only applies to the time-series ARIMA_PLUS and ARIMA models. @@ -97,6 +98,7 @@ def bigquery_explain_forecast_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_predict_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_predict_model/component.py index 326c080d8fb..b420862cc17 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_predict_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/explain_predict_model/component.py @@ -40,6 +40,7 @@ def bigquery_explain_predict_model_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery explain predict model job and waits for it to finish. Args: @@ -118,6 +119,7 @@ def bigquery_explain_predict_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/export_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/export_model/component.py index 8b12eaff8ca..f8a8346db51 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/export_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/export_model/component.py @@ -33,6 +33,7 @@ def bigquery_export_model_job( job_configuration_extract: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery export model job and waits for it to finish. Args: @@ -69,6 +70,7 @@ def bigquery_export_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/feature_importance/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/feature_importance/component.py index f9b5febc08c..7410ac01469 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/feature_importance/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/feature_importance/component.py @@ -36,6 +36,7 @@ def bigquery_ml_feature_importance_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery feature importance fetching job and waits for it to finish. Args: @@ -84,6 +85,7 @@ def bigquery_ml_feature_importance_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/forecast_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/forecast_model/component.py index d5976311427..7ca297c8aa5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/forecast_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/forecast_model/component.py @@ -38,6 +38,7 @@ def bigquery_forecast_model_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.FORECAST job and let you forecast an ARIMA_PLUS or ARIMA model. This function only applies to the time-series ARIMA_PLUS and ARIMA models. @@ -97,6 +98,7 @@ def bigquery_forecast_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/global_explain/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/global_explain/component.py index 8147ea8bd99..eeee329b35a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/global_explain/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/global_explain/component.py @@ -37,6 +37,7 @@ def bigquery_ml_global_explain_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery global explain fetching job and waits for it to finish. Args: @@ -68,6 +69,7 @@ def bigquery_ml_global_explain_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_advanced_weights/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_advanced_weights/component.py index f91ada1be75..74b0d92cf3a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_advanced_weights/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_advanced_weights/component.py @@ -35,6 +35,7 @@ def bigquery_ml_advanced_weights_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery ml advanced weights job and waits for it to finish. Args: @@ -77,6 +78,7 @@ def bigquery_ml_advanced_weights_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_coefficients/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_coefficients/component.py index 1c2ffab2597..fbb207d1ae9 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_coefficients/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_coefficients/component.py @@ -36,6 +36,7 @@ def bigquery_ml_arima_coefficients( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.ARIMA_COEFFICIENTS job and let you see the ARIMA coefficients. This function only applies to the time-series ARIMA_PLUS and ARIMA models. @@ -79,6 +80,7 @@ def bigquery_ml_arima_coefficients( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_evaluate/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_evaluate/component.py index a4d569dcb67..0ced7638485 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_evaluate/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_arima_evaluate/component.py @@ -37,6 +37,7 @@ def bigquery_ml_arima_evaluate_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.ARIMA_EVALUATE job and waits for it to finish. Args: @@ -93,6 +94,7 @@ def bigquery_ml_arima_evaluate_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_centroids/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_centroids/component.py index 9c969cc00e3..d21105aa5b1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_centroids/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_centroids/component.py @@ -37,6 +37,7 @@ def bigquery_ml_centroids_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.CENTROIDS job and waits for it to finish. Args: @@ -91,11 +92,14 @@ def bigquery_ml_centroids_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ - 'python3', '-u', '-m', - 'google_cloud_pipeline_components.container.v1.bigquery.ml_centroids.launcher' + 'python3', + '-u', + '-m', + 'google_cloud_pipeline_components.container.v1.bigquery.ml_centroids.launcher', ], args=[ '--type', @@ -106,25 +110,39 @@ def bigquery_ml_centroids_job( location, '--model_name', ConcatPlaceholder([ - "{{$.inputs.artifacts['model'].metadata['projectId']}}", '.', - "{{$.inputs.artifacts['model'].metadata['datasetId']}}", '.', - "{{$.inputs.artifacts['model'].metadata['modelId']}}" + "{{$.inputs.artifacts['model'].metadata['projectId']}}", + '.', + "{{$.inputs.artifacts['model'].metadata['datasetId']}}", + '.', + "{{$.inputs.artifacts['model'].metadata['modelId']}}", ]), '--standardize', standardize, '--payload', ConcatPlaceholder([ - '{', '"configuration": {', '"query": ', job_configuration_query, - ', "labels": ', labels, '}', '}' + '{', + '"configuration": {', + '"query": ', + job_configuration_query, + ', "labels": ', + labels, + '}', + '}', ]), '--job_configuration_query_override', ConcatPlaceholder([ - '{', '"query_parameters": ', query_parameters, - ', "destination_encryption_configuration": {', '"kmsKeyName": "', - encryption_spec_key_name, '"}', '}' + '{', + '"query_parameters": ', + query_parameters, + ', "destination_encryption_configuration": {', + '"kmsKeyName": "', + encryption_spec_key_name, + '"}', + '}', ]), '--gcp_resources', gcp_resources, '--executor_input', '{{$}}', - ]) + ], + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_confusion_matrix/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_confusion_matrix/component.py index fe421c190e7..e72cfc0c2b5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_confusion_matrix/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_confusion_matrix/component.py @@ -38,6 +38,7 @@ def bigquery_ml_confusion_matrix_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery confusion matrix job and waits for it to finish. Args: @@ -91,6 +92,7 @@ def bigquery_ml_confusion_matrix_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_feature_info/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_feature_info/component.py index 10208f126f1..09ebb1ffc63 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_feature_info/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_feature_info/component.py @@ -35,6 +35,7 @@ def bigquery_ml_feature_info_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery feature info job and waits for it to finish. Args: @@ -77,6 +78,7 @@ def bigquery_ml_feature_info_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_component_info/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_component_info/component.py index bd819aa4c52..db2f94e0729 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_component_info/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_component_info/component.py @@ -36,6 +36,7 @@ def bigquery_ml_principal_component_info_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.principal_component_info job and waits for it to finish. Args: @@ -88,6 +89,7 @@ def bigquery_ml_principal_component_info_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_components/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_components/component.py index e9b4440d3b6..a0a10a09ad3 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_components/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_principal_components/component.py @@ -36,6 +36,7 @@ def bigquery_ml_principal_components_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.principal_components job and waits for it to finish. Args: @@ -87,6 +88,7 @@ def bigquery_ml_principal_components_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_recommend/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_recommend/component.py index 69fa8ed8a91..692b7d287d1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_recommend/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_recommend/component.py @@ -38,6 +38,7 @@ def bigquery_ml_recommend_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ML.Recommend job and waits for it to finish. Args: @@ -96,6 +97,7 @@ def bigquery_ml_recommend_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_reconstruction_loss/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_reconstruction_loss/component.py index a144ff12c99..edef6928298 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_reconstruction_loss/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_reconstruction_loss/component.py @@ -37,6 +37,7 @@ def bigquery_ml_reconstruction_loss_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ml reconstruction loss job and waits for it to finish. Args: @@ -96,6 +97,7 @@ def bigquery_ml_reconstruction_loss_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_roc_curve/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_roc_curve/component.py index e6c8c8b6441..57958b5c129 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_roc_curve/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_roc_curve/component.py @@ -38,6 +38,7 @@ def bigquery_ml_roc_curve_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery roc curve job and waits for it to finish. Args: @@ -91,6 +92,7 @@ def bigquery_ml_roc_curve_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_training_info/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_training_info/component.py index 3b63b5d4078..1180966ab7c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_training_info/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_training_info/component.py @@ -35,6 +35,7 @@ def bigquery_ml_training_info_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery ml training info fetching job and waits for it to finish. Args: @@ -78,6 +79,7 @@ def bigquery_ml_training_info_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_trial_info/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_trial_info/component.py index 7a7224cff61..cd478a4136b 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_trial_info/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_trial_info/component.py @@ -36,6 +36,7 @@ def bigquery_ml_trial_info_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery ml trial info job and waits for it to finish. Args: @@ -84,6 +85,7 @@ def bigquery_ml_trial_info_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_weights/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_weights/component.py index 39f79ace479..4d63552cac1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_weights/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/ml_weights/component.py @@ -35,6 +35,7 @@ def bigquery_ml_weights_job( job_configuration_query: Dict[str, str] = {}, labels: Dict[str, str] = {}, ): + # fmt: off """Launch a BigQuery ml weights job and waits for it to finish. Args: @@ -78,6 +79,7 @@ def bigquery_ml_weights_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/predict_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/predict_model/component.py index 01edb01be01..4449bd8a5db 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/predict_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/predict_model/component.py @@ -39,6 +39,7 @@ def bigquery_predict_model_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery predict model job and waits for it to finish. Args: @@ -104,6 +105,7 @@ def bigquery_predict_model_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/query_job/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/query_job/component.py index a0abadc68fa..421c5504298 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/query_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/bigquery/query_job/component.py @@ -34,6 +34,7 @@ def bigquery_query_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a BigQuery query job and waits for it to finish. Args: @@ -86,6 +87,7 @@ def bigquery_query_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/component.py index 9db3fd9aa85..0cfba25cdae 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/component.py @@ -38,6 +38,7 @@ def custom_training_job( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Launch a Custom training job using Vertex CustomJob API. Args: @@ -114,6 +115,7 @@ def custom_training_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/utils.py b/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/utils.py index 80390cea097..48b99f4ee19 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/custom_job/utils.py @@ -30,7 +30,8 @@ def _replace_executor_placeholder( - container_input: Sequence[str]) -> Sequence[str]: + container_input: Sequence[str], +) -> Sequence[str]: """Replace executor placeholder in container command or args. 
Args: @@ -41,7 +42,9 @@ def _replace_executor_placeholder( """ return [ _EXECUTOR_PLACEHOLDER_REPLACEMENT - if input == _EXECUTOR_PLACEHOLDER else input for input in container_input + if input == _EXECUTOR_PLACEHOLDER + else input + for input in container_input ] @@ -86,94 +89,76 @@ def create_custom_training_job_from_component( Args: component_spec: The task (ContainerOp) object to run as Vertex AI custom job. - display_name (Optional[str]): - The name of the custom job. If not provided + display_name (Optional[str]): The name of the custom job. If not provided the component_spec.name will be used instead. - replica_count (Optional[int]): - The count of instances in the cluster. One + replica_count (Optional[int]): The count of instances in the cluster. One replica always counts towards the master in worker_pool_spec[0] and the remaining replicas will be allocated in worker_pool_spec[1]. For more details see https://cloud.google.com/vertex-ai/docs/training/distributed-training#configure_a_distributed_training_job. - machine_type (Optional[str]): - The type of the machine to run the custom job. + machine_type (Optional[str]): The type of the machine to run the custom job. The default value is "n1-standard-4". For more details about this input config, see https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types. - accelerator_type (Optional[str]): - The type of accelerator(s) that may be + accelerator_type (Optional[str]): The type of accelerator(s) that may be attached to the machine as per accelerator_count. For more details about this input config, see https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#acceleratortype. - accelerator_count (Optional[int]): - The number of accelerators to attach to + accelerator_count (Optional[int]): The number of accelerators to attach to the machine. Defaults to 1 if accelerator_type is set. - boot_disk_type (Optional[str]): - Type of the boot disk (default is "pd-ssd"). + boot_disk_type (Optional[str]): Type of the boot disk (default is "pd-ssd"). Valid values: "pd-ssd" (Persistent Disk Solid State Drive) or "pd-standard" (Persistent Disk Hard Disk Drive). boot_disk_type is set as a static value and cannot be changed as a pipeline parameter. - boot_disk_size_gb (Optional[int]): - Size in GB of the boot disk (default is + boot_disk_size_gb (Optional[int]): Size in GB of the boot disk (default is 100GB). boot_disk_size_gb is set as a static value and cannot be changed as a pipeline parameter. - timeout (Optional[str]): - The maximum job running time. The default is 7 + timeout (Optional[str]): The maximum job running time. The default is 7 days. A duration in seconds with up to nine fractional digits, terminated by 's', for example: "3.5s". - restart_job_on_worker_restart (Optional[bool]): - Restarts the entire + restart_job_on_worker_restart (Optional[bool]): Restarts the entire CustomJob if a worker gets restarted. This feature can be used by distributed training jobs that are not resilient to workers leaving and joining a job. - service_account (Optional[str]): - Sets the default service account for + service_account (Optional[str]): Sets the default service account for workload run-as account. The service account running the pipeline (https://cloud.google.com/vertex-ai/docs/pipelines/configure-project#service-account) submitting jobs must have act-as permission on this run-as account. 
If unspecified, the Vertex AI Custom Code Service Agent(https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents) for the CustomJob's project. - network (Optional[str]): - The full name of the Compute Engine network to + network (Optional[str]): The full name of the Compute Engine network to which the job should be peered. For example, projects/12345/global/networks/myVPC. Format is of the form projects/{project}/global/networks/{network}. Where {project} is a project number, as in 12345, and {network} is a network name. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - encryption_spec_key_name (Optional[str]): - Customer-managed encryption key + encryption_spec_key_name (Optional[str]): Customer-managed encryption key options for the CustomJob. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. - tensorboard (Optional[str]): - The name of a Vertex AI Tensorboard resource to + tensorboard (Optional[str]): The name of a Vertex AI Tensorboard resource to which this CustomJob will upload Tensorboard logs. - enable_web_access (Optional[bool]): - Whether you want Vertex AI to enable + enable_web_access (Optional[bool]): Whether you want Vertex AI to enable [interactive shell access](https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell) to training containers. If set to `true`, you can access interactive shells at the URIs given by [CustomJob.web_access_uris][]. - reserved_ip_ranges (Optional[Sequence[str]]): - A list of names for the + reserved_ip_ranges (Optional[Sequence[str]]): A list of names for the reserved ip ranges under the VPC network that can be used for this job. If set, we will deploy the job within the provided ip ranges. Otherwise, the job will be deployed to any ip ranges under the provided VPC network. - nfs_mounts (Optional[Sequence[Dict]]): - A list of NFS mount specs in Json + nfs_mounts (Optional[Sequence[Dict]]): A list of NFS mount specs in Json dict format. nfs_mounts is set as a static value and cannot be changed as a pipeline parameter. For API spec, see https://cloud.devsite.corp.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#NfsMount For more details about mounting NFS for CustomJob, see https://cloud.devsite.corp.google.com/vertex-ai/docs/training/train-nfs-share - base_output_directory (Optional[str]): - The Cloud Storage location to store + base_output_directory (Optional[str]): The Cloud Storage location to store the output of this CustomJob or HyperparameterTuningJob. see below for more details: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GcsDestination - labels (Optional[Dict[str, str]]): - The labels with user-defined metadata to + labels (Optional[Dict[str, str]]): The labels with user-defined metadata to organize CustomJobs. See https://goo.gl/xmQnxf for more information. Returns: @@ -215,27 +200,30 @@ def create_custom_training_job_from_component( # is returned. 
custom_training_job_dict = json_format.MessageToDict( - component.custom_training_job.pipeline_spec) + component.custom_training_job.pipeline_spec + ) input_component_spec_dict = json_format.MessageToDict( - component_spec.pipeline_spec) - component_spec_container = list(input_component_spec_dict['deploymentSpec'] - ['executors'].values())[0]['container'] + component_spec.pipeline_spec + ) + component_spec_container = list( + input_component_spec_dict['deploymentSpec']['executors'].values() + )[0]['container'] # Construct worker_pool_spec worker_pool_spec = { - 'machine_spec': { - 'machine_type': machine_type - }, + 'machine_spec': {'machine_type': machine_type}, 'replica_count': 1, 'container_spec': { 'image_uri': component_spec_container['image'], - } + }, } worker_pool_spec['container_spec']['command'] = _replace_executor_placeholder( - component_spec_container.get('command', [])) + component_spec_container.get('command', []) + ) worker_pool_spec['container_spec']['args'] = _replace_executor_placeholder( - component_spec_container.get('args', [])) + component_spec_container.get('args', []) + ) if accelerator_type: worker_pool_spec['machine_spec']['accelerator_type'] = accelerator_type @@ -258,92 +246,132 @@ def create_custom_training_job_from_component( # Retrieve the custom job input/output parameters custom_training_job_dict_components = custom_training_job_dict['components'] custom_training_job_comp_key = list( - custom_training_job_dict_components.keys())[0] + custom_training_job_dict_components.keys() + )[0] custom_training_job_comp_val = custom_training_job_dict_components[ - custom_training_job_comp_key] + custom_training_job_comp_key + ] custom_job_input_params = custom_training_job_comp_val['inputDefinitions'][ - 'parameters'] + 'parameters' + ] custom_job_output_params = custom_training_job_comp_val['outputDefinitions'][ - 'parameters'] + 'parameters' + ] # Insert input arguments into custom_job_input_params as default values - custom_job_input_params['display_name'][ - 'defaultValue'] = display_name or component_spec.component_spec.name + custom_job_input_params['display_name']['defaultValue'] = ( + display_name or component_spec.component_spec.name + ) custom_job_input_params['worker_pool_specs'][ - 'defaultValue'] = worker_pool_specs + 'defaultValue' + ] = worker_pool_specs custom_job_input_params['timeout']['defaultValue'] = timeout custom_job_input_params['restart_job_on_worker_restart'][ - 'defaultValue'] = restart_job_on_worker_restart + 'defaultValue' + ] = restart_job_on_worker_restart custom_job_input_params['service_account']['defaultValue'] = service_account custom_job_input_params['tensorboard']['defaultValue'] = tensorboard custom_job_input_params['enable_web_access'][ - 'defaultValue'] = enable_web_access + 'defaultValue' + ] = enable_web_access custom_job_input_params['network']['defaultValue'] = network - custom_job_input_params['reserved_ip_ranges'][ - 'defaultValue'] = reserved_ip_ranges or [] + custom_job_input_params['reserved_ip_ranges']['defaultValue'] = ( + reserved_ip_ranges or [] + ) custom_job_input_params['base_output_directory'][ - 'defaultValue'] = base_output_directory + 'defaultValue' + ] = base_output_directory custom_job_input_params['labels']['defaultValue'] = labels or {} custom_job_input_params['encryption_spec_key_name'][ - 'defaultValue'] = encryption_spec_key_name + 'defaultValue' + ] = encryption_spec_key_name # Merge with the input/output parameters from the input component. 
input_component_spec_comp_val = list( - input_component_spec_dict['components'].values())[0] + input_component_spec_dict['components'].values() + )[0] custom_job_input_params = { - **(input_component_spec_comp_val.get('inputDefinitions', - {}).get('parameters', {})), - **custom_job_input_params + **( + input_component_spec_comp_val.get('inputDefinitions', {}).get( + 'parameters', {} + ) + ), + **custom_job_input_params, } custom_job_output_params = { - **(input_component_spec_comp_val.get('outputDefinitions', - {}).get('parameters', {})), - **custom_job_output_params + **( + input_component_spec_comp_val.get('outputDefinitions', {}).get( + 'parameters', {} + ) + ), + **custom_job_output_params, } # Copy merged input/output parameters to custom_training_job_dict # Using copy.deepcopy here to avoid anchors and aliases in the produced # YAML as a result of pointing to the same dict. - custom_training_job_dict['root']['inputDefinitions'][ - 'parameters'] = copy.deepcopy(custom_job_input_params) + custom_training_job_dict['root']['inputDefinitions']['parameters'] = ( + copy.deepcopy(custom_job_input_params) + ) custom_training_job_dict['components'][custom_training_job_comp_key][ - 'inputDefinitions']['parameters'] = copy.deepcopy(custom_job_input_params) + 'inputDefinitions' + ]['parameters'] = copy.deepcopy(custom_job_input_params) custom_training_job_tasks_key = list( - custom_training_job_dict['root']['dag']['tasks'].keys())[0] + custom_training_job_dict['root']['dag']['tasks'].keys() + )[0] custom_training_job_dict['root']['dag']['tasks'][ - custom_training_job_tasks_key]['inputs']['parameters'] = { - **(list(input_component_spec_dict['root']['dag']['tasks'].values()) - [0].get('inputs', {}).get('parameters', {})), - **(custom_training_job_dict['root']['dag']['tasks'] - [custom_training_job_tasks_key]['inputs']['parameters']) - } + custom_training_job_tasks_key + ]['inputs']['parameters'] = { + **( + list(input_component_spec_dict['root']['dag']['tasks'].values())[0] + .get('inputs', {}) + .get('parameters', {}) + ), + **( + custom_training_job_dict['root']['dag']['tasks'][ + custom_training_job_tasks_key + ]['inputs']['parameters'] + ), + } custom_training_job_dict['components'][custom_training_job_comp_key][ - 'outputDefinitions']['parameters'] = custom_job_output_params + 'outputDefinitions' + ]['parameters'] = custom_job_output_params # Retrieve the input/output artifacts from the input component. 
custom_job_input_artifacts = input_component_spec_comp_val.get( - 'inputDefinitions', {}).get('artifacts', {}) + 'inputDefinitions', {} + ).get('artifacts', {}) custom_job_output_artifacts = input_component_spec_comp_val.get( - 'outputDefinitions', {}).get('artifacts', {}) + 'outputDefinitions', {} + ).get('artifacts', {}) # Copy input/output artifacts from the input component to # custom_training_job_dict if custom_job_input_artifacts: - custom_training_job_dict['root']['inputDefinitions'][ - 'artifacts'] = copy.deepcopy(custom_job_input_artifacts) + custom_training_job_dict['root']['inputDefinitions']['artifacts'] = ( + copy.deepcopy(custom_job_input_artifacts) + ) custom_training_job_dict['components'][custom_training_job_comp_key][ - 'inputDefinitions']['artifacts'] = copy.deepcopy( - custom_job_input_artifacts) + 'inputDefinitions' + ]['artifacts'] = copy.deepcopy(custom_job_input_artifacts) custom_training_job_dict['root']['dag']['tasks'][ - custom_training_job_tasks_key]['inputs']['artifacts'] = { - **(list(input_component_spec_dict['root']['dag']['tasks'].values()) - [0].get('inputs', {}).get('artifacts', {})), - **(custom_training_job_dict['root']['dag']['tasks'] - [custom_training_job_tasks_key]['inputs'].get('artifacts', {})) - } + custom_training_job_tasks_key + ]['inputs']['artifacts'] = { + **( + list(input_component_spec_dict['root']['dag']['tasks'].values())[0] + .get('inputs', {}) + .get('artifacts', {}) + ), + **( + custom_training_job_dict['root']['dag']['tasks'][ + custom_training_job_tasks_key + ]['inputs'].get('artifacts', {}) + ), + } if custom_job_output_artifacts: custom_training_job_dict['components'][custom_training_job_comp_key][ - 'outputDefinitions']['artifacts'] = custom_job_output_artifacts + 'outputDefinitions' + ]['artifacts'] = custom_job_output_artifacts # Create new component from component IR YAML custom_training_job_yaml = yaml.safe_dump(custom_training_job_dict) @@ -358,8 +386,12 @@ def create_custom_training_job_from_component( if component_spec.description: # TODO(chavoshi) Add support for docstring parsing. component_description = 'A custom job that wraps ' - component_description += f'{component_spec.component_spec.name}.\n\nOriginal component' - component_description += f' description:\n{component_spec.description}\n\nCustom' + component_description += ( + f'{component_spec.component_spec.name}.\n\nOriginal component' + ) + component_description += ( + f' description:\n{component_spec.description}\n\nCustom' + ) component_description += ' Job wrapper description:\n' component_description += component.custom_training_job.description @@ -371,7 +403,9 @@ def create_custom_training_job_from_component( # This alias points to the old "create_custom_training_job_op_from_component" to # avoid potential user breakage. def create_custom_training_job_op_from_component(*args, **kwargs) -> Callable: # pylint: disable=g-bare-generic - """Deprecated. Please use create_custom_training_job_from_component instead. + """Deprecated. + + Please use create_custom_training_job_from_component instead. Args: *args: Positional arguments for create_custom_training_job_from_component. 
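A usage sketch for the wrapper refactored above, assuming create_custom_training_job_from_component is re-exported from google_cloud_pipeline_components.v1.custom_job (otherwise import it from the utils module). Machine, accelerator, and component details are placeholders.

# Illustrative sketch: running a lightweight component as a CustomJob.
from kfp import dsl
from google_cloud_pipeline_components.v1.custom_job import (
    create_custom_training_job_from_component,
)


@dsl.component
def train(epochs: int, model_dir: dsl.OutputPath(str)):
  # Placeholder training logic.
  with open(model_dir, 'w') as f:
    f.write(f'trained for {epochs} epochs')


# The wrapper keeps train's interface and merges in the CustomJob inputs
# (project, location, etc.), as done by the code above.
train_on_vertex = create_custom_training_job_from_component(
    train,
    display_name='train-on-vertex',
    machine_type='n1-standard-8',
    accelerator_type='NVIDIA_TESLA_T4',
    accelerator_count=1,
    replica_count=1,
)


@dsl.pipeline(name='wrapped-custom-job')
def training_pipeline(project: str, location: str = 'us-central1'):
  train_on_vertex(project=project, location=location, epochs=10)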
diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataflow/python_job/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataflow/python_job/component.py index b751989a256..76c1abe33e5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataflow/python_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataflow/python_job/component.py @@ -15,6 +15,7 @@ def dataflow_python( requirements_file_path: str = '', args: List[str] = [], ): + # fmt: off """Launch a self-executing beam python file on Google Cloud using the DataflowRunner. Args: @@ -40,11 +41,14 @@ def dataflow_python( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ - 'python3', '-u', '-m', - 'google_cloud_pipeline_components.container.v1.dataflow.dataflow_launcher' + 'python3', + '-u', + '-m', + 'google_cloud_pipeline_components.container.v1.dataflow.dataflow_launcher', ], args=[ '--project', @@ -61,4 +65,5 @@ def dataflow_python( args, '--gcp_resources', gcp_resources, - ]) + ], + ) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/__init__.py index a930f835a29..50b8f464d87 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/__init__.py @@ -27,7 +27,13 @@ 'DataprocSparkSqlBatchOp', ] -DataprocPySparkBatchOp = create_pyspark_batch_component.dataproc_create_pyspark_batch +DataprocPySparkBatchOp = ( + create_pyspark_batch_component.dataproc_create_pyspark_batch +) DataprocSparkBatchOp = create_spark_batch_component.dataproc_create_spark_batch -DataprocSparkRBatchOp = create_spark_r_batch_component.dataproc_create_spark_r_batch -DataprocSparkSqlBatchOp = create_spark_sql_batch_component.dataproc_create_spark_sql_batch +DataprocSparkRBatchOp = ( + create_spark_r_batch_component.dataproc_create_spark_r_batch +) +DataprocSparkSqlBatchOp = ( + create_spark_sql_batch_component.dataproc_create_spark_sql_batch +) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_pyspark_batch/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_pyspark_batch/component.py index c2e0280efa9..d2b3b2da642 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_pyspark_batch/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_pyspark_batch/component.py @@ -45,6 +45,7 @@ def dataproc_create_pyspark_batch( archive_uris: List[str] = [], args: List[str] = [], ): + # fmt: off """Create a Dataproc PySpark batch workload and wait for it to finish. Args: @@ -121,6 +122,7 @@ def dataproc_create_pyspark_batch( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
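An illustrative sketch of submitting the Dataproc PySpark batch documented above. Only a subset of the documented inputs is shown; batch_id, the main file URI, and the job args are placeholders and are assumed to match the component's parameter names.

# Illustrative sketch: serverless Dataproc PySpark batch submission.
from kfp import dsl
from google_cloud_pipeline_components.v1.dataproc import DataprocPySparkBatchOp


@dsl.pipeline(name='dataproc-pyspark-batch')
def pyspark_batch_pipeline(project: str, location: str = 'us-central1'):
  DataprocPySparkBatchOp(
      project=project,
      location=location,
      batch_id='my-pyspark-batch-0001',  # must be unique per submission
      main_python_file_uri='gs://my-bucket/jobs/wordcount.py',
      args=['gs://my-bucket/input/', 'gs://my-bucket/output/'],
  )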
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_batch/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_batch/component.py index e030fb49fca..3934479a22d 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_batch/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_batch/component.py @@ -44,6 +44,7 @@ def dataproc_create_spark_batch( archive_uris: List[str] = [], args: List[str] = [], ): + # fmt: off """Create a Dataproc Spark batch workload and wait for it to finish. Args: @@ -117,6 +118,7 @@ def dataproc_create_spark_batch( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_r_batch/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_r_batch/component.py index cc5b957b498..0ce72ced784 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_r_batch/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_r_batch/component.py @@ -42,6 +42,7 @@ def dataproc_create_spark_r_batch( archive_uris: List[str] = [], args: List[str] = [], ): + # fmt: off """Create a Dataproc SparkR batch workload and wait for it to finish. Args: @@ -103,6 +104,7 @@ def dataproc_create_spark_r_batch( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_sql_batch/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_sql_batch/component.py index e45f43db4aa..2115b3746e2 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_sql_batch/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataproc/create_spark_sql_batch/component.py @@ -39,6 +39,7 @@ def dataproc_create_spark_sql_batch( query_variables: Dict[str, str] = {}, jar_file_uris: List[str] = [], ): + # fmt: off """Create a Dataproc Spark SQL batch workload and wait for it to finish. Args: @@ -97,6 +98,7 @@ def dataproc_create_spark_sql_batch( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/__init__.py index a644959f2fc..8b1d25457e6 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/__init__.py @@ -46,22 +46,34 @@ ImageDatasetCreateOp = create_image_dataset_component.image_dataset_create TabularDatasetCreateOp = create_tabular_dataset_component.tabular_dataset_create TextDatasetCreateOp = create_text_dataset_component.text_dataset_create -TimeSeriesDatasetCreateOp = create_time_series_dataset_component.time_series_dataset_create +TimeSeriesDatasetCreateOp = ( + create_time_series_dataset_component.time_series_dataset_create +) VideoDatasetCreateOp = create_video_dataset_component.video_dataset_create ImageDatasetExportDataOp = export_image_dataset_component.image_dataset_export -TabularDatasetExportDataOp = export_tabular_dataset_component.tabular_dataset_export +TabularDatasetExportDataOp = ( + export_tabular_dataset_component.tabular_dataset_export +) TextDatasetExportDataOp = export_text_dataset_component.text_dataset_export -TimeSeriesDatasetExportDataOp = export_time_series_dataset_component.time_series_dataset_export +TimeSeriesDatasetExportDataOp = ( + export_time_series_dataset_component.time_series_dataset_export +) VideoDatasetExportDataOp = export_video_dataset_component.video_dataset_export ImageDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'import_image_dataset/component.yaml')) + os.path.dirname(__file__), 'import_image_dataset/component.yaml' + ) +) TextDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'import_text_dataset/component.yaml')) + os.path.dirname(__file__), 'import_text_dataset/component.yaml' + ) +) VideoDatasetImportDataOp = load_component_from_file( os.path.join( - os.path.dirname(__file__), 'import_video_dataset/component.yaml')) + os.path.dirname(__file__), 'import_video_dataset/component.yaml' + ) +) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_image_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_image_dataset/component.py index 742b4b509dc..05e38577084 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_image_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_image_dataset/component.py @@ -34,6 +34,7 @@ def image_dataset_create( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a new image dataset and optionally imports data into dataset when source and import_schema_uri are passed. Args: @@ -96,6 +97,7 @@ def image_dataset_create( Instantiated representation of the managed image dataset resource. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_tabular_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_tabular_dataset/component.py index 2c7d61b3d7a..95c82656d3f 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_tabular_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_tabular_dataset/component.py @@ -33,6 +33,7 @@ def tabular_dataset_create( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a new tabular dataset. Args: @@ -80,6 +81,7 @@ def tabular_dataset_create( Instantiated representation of the managed tabular dataset resource. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_text_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_text_dataset/component.py index 33692d52cc2..86d6be17d6f 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_text_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_text_dataset/component.py @@ -34,6 +34,7 @@ def text_dataset_create( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a new text dataset and optionally imports data into dataset when source and import_schema_uri are passed. Args: @@ -96,6 +97,7 @@ def text_dataset_create( Instantiated representation of the managed text dataset resource. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_time_series_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_time_series_dataset/component.py index 67f5cacab36..f1fe0426d52 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_time_series_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_time_series_dataset/component.py @@ -33,6 +33,7 @@ def time_series_dataset_create( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a new time series dataset. Args: @@ -80,6 +81,7 @@ def time_series_dataset_create( Instantiated representation of the managed time series dataset resource. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_video_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_video_dataset/component.py index 4d0e75b8e19..4209ceb3254 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_video_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/create_video_dataset/component.py @@ -33,6 +33,7 @@ def video_dataset_create( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Creates a new video dataset and optionally imports data into dataset when source and import_schema_uri are passed. Args: @@ -95,6 +96,7 @@ def video_dataset_create( Instantiated representation of the managed video dataset resource. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_image_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_image_dataset/component.py index f76078f400e..722eb122130 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_image_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_image_dataset/component.py @@ -28,6 +28,7 @@ def image_dataset_export( exported_dataset: Output[Artifact], location: str = 'us-central1', ): + # fmt: off """Exports data to output dir to GCS. Args: @@ -55,6 +56,7 @@ def image_dataset_export( exported_dataset (Sequence[str]): All of the files that are exported in this export operation. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_tabular_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_tabular_dataset/component.py index 9ffdc1a332d..60640ccf0c8 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_tabular_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_tabular_dataset/component.py @@ -28,6 +28,7 @@ def tabular_dataset_export( exported_dataset: Output[Artifact], location: str = 'us-central1', ): + # fmt: off """Exports data to output dir to GCS. Args: @@ -55,6 +56,7 @@ def tabular_dataset_export( exported_dataset (Sequence[str]): All of the files that are exported in this export operation. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_text_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_text_dataset/component.py index 16f8af741b8..7ea60fdf2da 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_text_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_text_dataset/component.py @@ -28,6 +28,7 @@ def text_dataset_export( exported_dataset: Output[Artifact], location: str = 'us-central1', ): + # fmt: off """Exports data to output dir to GCS. 
Args: @@ -55,6 +56,7 @@ def text_dataset_export( exported_dataset (Sequence[str]): All of the files that are exported in this export operation. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_time_series_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_time_series_dataset/component.py index 90a265fe3d2..b94e1108cee 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_time_series_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_time_series_dataset/component.py @@ -28,6 +28,7 @@ def time_series_dataset_export( exported_dataset: Output[Artifact], location: str = 'us-central1', ): + # fmt: off """Exports data to output dir to GCS. Args: @@ -56,6 +57,7 @@ def time_series_dataset_export( exported_dataset (Sequence[str]): All of the files that are exported in this export operation. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_video_dataset/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_video_dataset/component.py index 989f7643be2..9954886b0d3 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_video_dataset/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/dataset/export_video_dataset/component.py @@ -28,6 +28,7 @@ def video_dataset_export( exported_dataset: Output[Artifact], location: str = 'us-central1', ): + # fmt: off """Exports data to output dir to GCS. Args: @@ -56,6 +57,7 @@ def video_dataset_export( exported_dataset (Sequence[str]): All of the files that are exported in this export operation. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/create_endpoint/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/create_endpoint/component.py index f5481a50831..24e8bd4d9d6 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/create_endpoint/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/create_endpoint/component.py @@ -34,6 +34,7 @@ def endpoint_create( encryption_spec_key_name: str = '', network: str = '', ): + # fmt: off """Creates a Google Cloud Vertex Endpoint and waits for it to be ready. For more details, see @@ -85,6 +86,7 @@ def endpoint_create( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
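An illustrative sketch of the upload, endpoint-create, and deploy flow covered above, using the importer-based UnmanagedContainerModel pattern from the library documentation. The artifact URI and serving image are placeholders.

# Illustrative sketch: upload a model, create an Endpoint, deploy to it.
from kfp import dsl
from google_cloud_pipeline_components.types import artifact_types
from google_cloud_pipeline_components.v1.endpoint import (
    EndpointCreateOp,
    ModelDeployOp,
)
from google_cloud_pipeline_components.v1.model import ModelUploadOp


@dsl.pipeline(name='upload-and-deploy')
def deploy_pipeline(project: str, location: str = 'us-central1'):
  # Wrap exported model files and a serving image as an unmanaged model.
  unmanaged_model = dsl.importer(
      artifact_uri='gs://my-bucket/model/',
      artifact_class=artifact_types.UnmanagedContainerModel,
      metadata={
          'containerSpec': {
              'imageUri': (
                  'us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest'
              ),
          },
      },
  )
  upload_op = ModelUploadOp(
      project=project,
      location=location,
      display_name='my-model',
      unmanaged_container_model=unmanaged_model.outputs['artifact'],
  )
  endpoint_op = EndpointCreateOp(
      project=project,
      location=location,
      display_name='my-endpoint',
  )
  ModelDeployOp(
      endpoint=endpoint_op.outputs['endpoint'],
      model=upload_op.outputs['model'],
      deployed_model_display_name='my-deployed-model',
      dedicated_resources_machine_type='n1-standard-4',
      dedicated_resources_min_replica_count=1,
      dedicated_resources_max_replica_count=1,
  )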
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/deploy_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/deploy_model/component.py index 99a102539ef..dfc6a2738fc 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/deploy_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/endpoint/deploy_model/component.py @@ -43,6 +43,7 @@ def model_deploy( explanation_metadata: Dict[str, str] = {}, explanation_parameters: Dict[str, str] = {}, ): + # fmt: off """Deploys a Google Cloud Vertex Model to the Endpoint, creating a DeployedModel within it. For more details, see @@ -160,6 +161,7 @@ def model_deploy( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/__init__.py index 41c52207cb5..d9b2f7b8de7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/__init__.py @@ -25,4 +25,6 @@ 'serialize_parameters', ] -HyperparameterTuningJobRunOp = hyperparameter_tuning_job_component.hyperparameter_tuning_job +HyperparameterTuningJobRunOp = ( + hyperparameter_tuning_job_component.hyperparameter_tuning_job +) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/component.py index b99c9fdbdf6..2ba99a35d2c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/component.py @@ -39,6 +39,7 @@ def hyperparameter_tuning_job( service_account: str = '', network: str = '', ): + # fmt: off """Creates a Google Cloud AI Platform HyperparameterTuning Job and waits for it to complete. For example usage, see @@ -162,6 +163,7 @@ def hyperparameter_tuning_job( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/utils.py b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/utils.py index cdc03555d2e..eb8b004d392 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/utils.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/hyperparameter_tuning_job/utils.py @@ -20,8 +20,7 @@ def serialize_parameters(parameters: dict) -> list: """Serializes the hyperparameter tuning parameter spec to dictionary format. 
Args: - parameters (Dict[str, hyperparameter_tuning._ParameterSpec]): - Dictionary + parameters (Dict[str, hyperparameter_tuning._ParameterSpec]): Dictionary representing parameters to optimize. The dictionary key is the parameter_id, which is passed into your training job as a command line key word argument, and the dictionary value is the parameter @@ -52,8 +51,7 @@ def serialize_metrics(metric_spec: dict) -> list: """Serializes a metric spec to dictionary format. Args: - metric_spec (Dict[str, str]): - Required. Dictionary representing metrics to + metric_spec (Dict[str, str]): Required. Dictionary representing metrics to optimize. The dictionary key is the metric_id, which is reported by your training job, and the dictionary value is the optimization goal of the metric ('minimize' or 'maximize'). Example: metrics = {'loss': diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/model/export_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/model/export_model/component.py index 25e3e70c722..ed9d6cb488a 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/model/export_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/model/export_model/component.py @@ -31,6 +31,7 @@ def model_export( artifact_destination: str = '', image_destination: str = '', ): + # fmt: off """Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one supported @@ -78,6 +79,7 @@ def model_export( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/model/upload_model/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/model/upload_model/component.py index 767187fae7a..f80f5250f44 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/model/upload_model/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/model/upload_model/component.py @@ -40,6 +40,7 @@ def model_upload( labels: Dict[str, str] = {}, encryption_spec_key_name: str = '', ): + # fmt: off """Uploads a model and returns a Model representing the uploaded Model resource. For more details, see @@ -113,6 +114,7 @@ def model_upload( For more details, see https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md. 
""" + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/__init__.py index 09fd100c329..3a934fde732 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/__init__.py @@ -21,4 +21,6 @@ 'VertexNotificationEmailOp', ] -VertexNotificationEmailOp = vertex_notification_email_component.vertex_pipelines_notification_email +VertexNotificationEmailOp = ( + vertex_notification_email_component.vertex_pipelines_notification_email +) diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/component.py b/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/component.py index 969928035b8..d875a9cd2ad 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/vertex_notification_email/component.py @@ -24,6 +24,7 @@ def vertex_pipelines_notification_email( recipients: List[str], pipeline_task_final_status: PipelineTaskFinalStatus, ): + # fmt: off """When this component is included as an exit handler, sends a notification email with the status of the upstream DAG to the specified recipients. This component works only on Vertex Pipelines. This component raises an @@ -37,6 +38,7 @@ def vertex_pipelines_notification_email( The task final status of the upstream DAG that this component will use in the notification. """ + # fmt: on return ContainerSpec( image='gcr.io/ml-pipeline/google-cloud-pipeline-components:2.0.0b1', command=[ diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/wait_gcp_resources/__init__.py b/components/google-cloud/google_cloud_pipeline_components/v1/wait_gcp_resources/__init__.py index dcadcc95dac..c5bc15612ee 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/wait_gcp_resources/__init__.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/wait_gcp_resources/__init__.py @@ -21,4 +21,5 @@ ] WaitGcpResourcesOp = load_component_from_file( - os.path.join(os.path.dirname(__file__), 'component.yaml')) + os.path.join(os.path.dirname(__file__), 'component.yaml') +) diff --git a/components/google-cloud/tests/experimental/custom_job/integration/test_custom_training_job_wrapper_compile.py b/components/google-cloud/tests/experimental/custom_job/integration/test_custom_training_job_wrapper_compile.py index 2abf1bfa4c6..e53733151a4 100644 --- a/components/google-cloud/tests/experimental/custom_job/integration/test_custom_training_job_wrapper_compile.py +++ b/components/google-cloud/tests/experimental/custom_job/integration/test_custom_training_job_wrapper_compile.py @@ -56,7 +56,7 @@ def tearDown(self): if os.path.exists(self._package_path): os.remove(self._package_path) - def _create_a_pytnon_based_component(self) -> callable: + def _create_a_pytnon_based_component(self): """Creates a test python based component factory.""" @kfp.dsl.component