feat(components): Add Starry Net forecasting pipeline to public preview
Signed-off-by: Googler <[email protected]>
PiperOrigin-RevId: 643098339

Googler committed Jun 13, 2024
1 parent e69078b · commit 3a0566e
Showing 31 changed files with 1,678 additions and 1 deletion.
@@ -8,4 +8,5 @@ Preview Components
 custom_job
 dataflow
 llm
-model_evaluation
+model_evaluation
+starry_net
components/google-cloud/docs/source/api/preview/starry_net.rst
4 changes: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
STARRY_NET
==========================

.. automodule:: preview.starry_net
...ents/google-cloud/google_cloud_pipeline_components/_implementation/starry_net/__init__.py
41 changes: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google_cloud_pipeline_components._implementation.starry_net.dataprep.component import dataprep as DataprepOp
from google_cloud_pipeline_components._implementation.starry_net.evaluation.component import evaluation as EvaluationOp
from google_cloud_pipeline_components._implementation.starry_net.get_training_artifacts.component import get_training_artifacts as GetTrainingArtifactsOp
from google_cloud_pipeline_components._implementation.starry_net.maybe_set_tfrecord_args.component import maybe_set_tfrecord_args as MaybeSetTfrecordArgsOp
from google_cloud_pipeline_components._implementation.starry_net.set_dataprep_args.component import set_dataprep_args as SetDataprepArgsOp
from google_cloud_pipeline_components._implementation.starry_net.set_eval_args.component import set_eval_args as SetEvalArgsOp
from google_cloud_pipeline_components._implementation.starry_net.set_test_set.component import set_test_set as SetTestSetOp
from google_cloud_pipeline_components._implementation.starry_net.set_tfrecord_args.component import set_tfrecord_args as SetTfrecordArgsOp
from google_cloud_pipeline_components._implementation.starry_net.set_train_args.component import set_train_args as SetTrainArgsOp
from google_cloud_pipeline_components._implementation.starry_net.train.component import train as TrainOp
from google_cloud_pipeline_components._implementation.starry_net.upload_decomposition_plots.component import upload_decomposition_plots as UploadDecompositionPlotsOp
from google_cloud_pipeline_components._implementation.starry_net.upload_model.component import upload_model as UploadModelOp


__all__ = [
    'DataprepOp',
    'EvaluationOp',
    'GetTrainingArtifactsOp',
    'MaybeSetTfrecordArgsOp',
    'SetDataprepArgsOp',
    'SetEvalArgsOp',
    'SetTestSetOp',
    'SetTfrecordArgsOp',
    'SetTrainArgsOp',
    'TrainOp',
    'UploadDecompositionPlotsOp',
    'UploadModelOp',
]
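The aliases above are the building blocks the preview pipeline wires together. As a quick sanity check that the re-exports resolve, they can be enumerated; a minimal sketch, assuming google-cloud-pipeline-components is installed with this commit applied:

from google_cloud_pipeline_components._implementation import starry_net

# Each alias is a KFP component factory (container, YAML-backed, or
# lightweight Python component); printing the types confirms the surface
# that the Starry Net pipeline composes.
for name in starry_net.__all__:
  print(name, type(getattr(starry_net, name)))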
...le-cloud/google_cloud_pipeline_components/_implementation/starry_net/dataprep/__init__.py
13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...e-cloud/google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py
159 changes: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Starry Net component for data preparation."""

from google_cloud_pipeline_components import utils
from google_cloud_pipeline_components._implementation.starry_net import version
from kfp import dsl


@dsl.container_component
def dataprep(
    gcp_resources: dsl.OutputPath(str),
    dataprep_dir: dsl.Output[dsl.Artifact],  # pytype: disable=unsupported-operands
    backcast_length: int,
    forecast_length: int,
    train_end_date: str,
    n_val_windows: int,
    n_test_windows: int,
    test_set_stride: int,
    model_blocks: str,
    bigquery_source: str,
    ts_identifier_columns: str,
    time_column: str,
    static_covariate_columns: str,
    target_column: str,
    machine_type: str,
    docker_region: str,
    location: str,
    project: str,
    job_id: str,
    job_name_prefix: str,
    num_workers: int,
    max_num_workers: int,
    disk_size_gb: int,
    test_set_only: bool,
    bigquery_output: str,
    gcs_source: str,
    gcs_static_covariate_source: str,
    encryption_spec_key_name: str,
):
  # fmt: off
  """Runs Dataprep for training and evaluating a STARRY-Net model.

  Args:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The Cloud Storage path where all dataprep artifacts are
      saved.
    backcast_length: The length of the input window to feed into the model.
    forecast_length: The length of the forecast horizon.
    train_end_date: The last date of data to use in the training set. All
      subsequent dates are part of the test set.
    n_val_windows: The number of windows to use for the validation set. If 0,
      no validation set is used.
    n_test_windows: The number of windows to use for the test set. Must be
      >= 1.
    test_set_stride: The number of timestamps to roll forward when
      constructing the validation and test sets.
    model_blocks: The stringified tuple of blocks to use, in the order that
      they appear in the model. Possible values are `cleaning`,
      `change_point`, `trend`, `hour_of_week-hybrid`, `day_of_week-hybrid`,
      `day_of_year-hybrid`, `week_of_year-hybrid`, `month_of_year-hybrid`,
      `residual`, `quantile`.
    bigquery_source: The BigQuery source of the data.
    ts_identifier_columns: The columns that identify unique time series in
      the BigQuery data source.
    time_column: The column with timestamps in the BigQuery source.
    static_covariate_columns: The names of the static covariates.
    target_column: The target column in the BigQuery data source.
    machine_type: The machine type of the Dataflow workers.
    docker_region: The Docker region, used to determine which image to use.
    location: The location where the job is run.
    project: The name of the project.
    job_id: The pipeline job ID.
    job_name_prefix: The prefix of the Dataflow job name.
    num_workers: The initial number of workers in the Dataflow job.
    max_num_workers: The maximum number of workers in the Dataflow job.
    disk_size_gb: The disk size of each Dataflow worker.
    test_set_only: Whether to only create the test set BigQuery table, or
      also to create TFRecords for training and validation.
    bigquery_output: The BigQuery dataset where the test set is written, in
      the form bq://project.dataset.
    gcs_source: The path to the CSV file of the data source.
    gcs_static_covariate_source: The path to the CSV file of static
      covariates.
    encryption_spec_key_name: Customer-managed encryption key options for the
      CustomJob. If this is set, then all resources created by the CustomJob
      will be encrypted with the provided encryption key.

  Returns:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The Cloud Storage path where all dataprep artifacts are
      saved.
  """
  job_name = f'{job_name_prefix}-{job_id}'
  payload = {
      'display_name': job_name,
      'encryption_spec': {
          'kms_key_name': str(encryption_spec_key_name),
      },
      'job_spec': {
          'worker_pool_specs': [{
              'replica_count': '1',
              'machine_spec': {
                  'machine_type': str(machine_type),
              },
              'disk_spec': {
                  'boot_disk_type': 'pd-ssd',
                  'boot_disk_size_gb': 100,
              },
              'container_spec': {
                  'image_uri': f'{docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/dataprep:captain_{version.DATAPREP_VERSION}',
                  'args': [
                      '--config=starryn/experiments/configs/vertex.py',
                      f'--config.datasets.backcast_length={backcast_length}',
                      f'--config.datasets.forecast_length={forecast_length}',
                      f'--config.datasets.train_end_date={train_end_date}',
                      f'--config.datasets.n_val_windows={n_val_windows}',
                      f'--config.datasets.val_rolling_window_size={test_set_stride}',
                      f'--config.datasets.n_test_windows={n_test_windows}',
                      f'--config.datasets.test_rolling_window_size={test_set_stride}',
                      f'--config.model.static_cov_names={static_covariate_columns}',
                      f'--config.model.blocks_list={model_blocks}',
                      f'--bigquery_source={bigquery_source}',
                      f'--bigquery_output={bigquery_output}',
                      f'--gcs_source={gcs_source}',
                      f'--gcs_static_covariate_source={gcs_static_covariate_source}',
                      f'--ts_identifier_columns={ts_identifier_columns}',
                      f'--time_column={time_column}',
                      f'--target_column={target_column}',
                      f'--job_id={job_name}',
                      f'--num_workers={num_workers}',
                      f'--max_num_workers={max_num_workers}',
                      f'--root_bucket={dataprep_dir.uri}',
                      f'--disk_size={disk_size_gb}',
                      f'--machine_type={machine_type}',
                      f'--test_set_only={test_set_only}',
                      f'--image_uri={docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/dataprep:replica_{version.DATAPREP_VERSION}',
                  ],
              },
          }]
      }
  }
  return utils.build_serverless_customjob_container_spec(
      project=project,
      location=location,
      custom_job_payload=payload,
      gcp_resources=gcp_resources,
  )
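To make the interface above concrete, here is a hedged sketch (not part of this commit) of invoking the dataprep component from a KFP pipeline. Every value is a placeholder, and the stringified list formats used for model_blocks, ts_identifier_columns, and static_covariate_columns are assumptions read off the docstring rather than verified against the dataprep image.

from kfp import dsl
from google_cloud_pipeline_components._implementation.starry_net import DataprepOp


@dsl.pipeline(name='starry-net-dataprep-only')
def dataprep_only(project: str, location: str = 'us-central1'):
  # All arguments below are illustrative placeholders; the BigQuery tables,
  # column names, and regions do not refer to real resources.
  DataprepOp(
      backcast_length=28,
      forecast_length=7,
      train_end_date='2024-01-01',
      n_val_windows=1,
      n_test_windows=1,
      test_set_stride=7,
      model_blocks="['cleaning', 'trend', 'day_of_week-hybrid', 'residual']",
      bigquery_source='bq://my-project.my_dataset.sales',
      ts_identifier_columns="['store_id', 'product_id']",
      time_column='date',
      static_covariate_columns='[]',
      target_column='sales',
      machine_type='n1-standard-16',
      docker_region='us',
      location=location,
      project=project,
      job_id=dsl.PIPELINE_JOB_ID_PLACEHOLDER,
      job_name_prefix='starry-net-dataprep',
      num_workers=1,
      max_num_workers=10,
      disk_size_gb=100,
      test_set_only=False,
      bigquery_output='bq://my-project.my_output_dataset',
      gcs_source='',
      gcs_static_covariate_source='',
      encryption_spec_key_name='',
  )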
...-cloud/google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py
13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...cloud/google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py
23 changes: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""StarryNet Evaluation Component."""

import os

from kfp import components

# TODO(b/346580764)
evaluation = components.load_component_from_file(
    os.path.join(os.path.dirname(__file__), 'evaluation.yaml')
)
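Because the evaluation step is defined in a YAML spec rather than in Python, its signature is not visible in this commit. The loaded object is still a regular KFP component, so, as a hedged sketch assuming KFP v2 behavior, its interface can be inspected before wiring it into a pipeline:

from google_cloud_pipeline_components._implementation.starry_net.evaluation import component

# evaluation is a component factory built from evaluation.yaml; its spec
# carries the component name and input definitions declared in that file.
print(component.evaluation.component_spec.name)
print(list(component.evaluation.component_spec.inputs or {}))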
...e_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py
13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
..._cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py
62 changes: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""StarryNet get training artifacts component."""

from typing import NamedTuple

from kfp import dsl


@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def get_training_artifacts(
    docker_region: str,
    trainer_dir: dsl.InputPath(),
) -> NamedTuple(
    'TrainingArtifacts',
    image_uri=str,
    artifact_uri=str,
    prediction_schema_uri=str,
    instance_schema_uri=str,
):
  # fmt: off
  """Gets the artifact URIs from the training job.

  Args:
    docker_region: The region from which the training docker image is pulled.
    trainer_dir: The directory where training artifacts were stored.

  Returns:
    A NamedTuple containing the image_uri for the prediction server,
    the artifact_uri with model artifacts, the prediction_schema_uri,
    and the instance_schema_uri.
  """
  import os  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  with tf.io.gfile.GFile(os.path.join(trainer_dir, 'trainer.txt')) as f:
    private_dir = f.read().strip()

  outputs = NamedTuple(
      'TrainingArtifacts',
      image_uri=str,
      artifact_uri=str,
      prediction_schema_uri=str,
      instance_schema_uri=str,
  )
  return outputs(
      f'{docker_region}-docker.pkg.dev/vertex-ai/starryn/predictor:20240610_0542_RC00',  # pylint: disable=too-many-function-args
      private_dir,  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'predict_schema.yaml'),  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'instance_schema.yaml'),  # pylint: disable=too-many-function-args
  )
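A hedged sketch (not part of this commit) of consuming the NamedTuple outputs downstream. The dsl.importer stand-in assumes the generic-Artifact trainer_dir input accepts an imported artifact; in the real pipeline it would instead come from the training step's output.

from kfp import dsl
from google_cloud_pipeline_components._implementation.starry_net import GetTrainingArtifactsOp


@dsl.pipeline(name='starry-net-artifacts-demo')
def artifacts_demo(trainer_gcs_dir: str, docker_region: str = 'us'):
  # Stand-in for the training step's output directory artifact.
  trainer_dir = dsl.importer(
      artifact_uri=trainer_gcs_dir,
      artifact_class=dsl.Artifact,
  )
  artifacts = GetTrainingArtifactsOp(
      docker_region=docker_region,
      trainer_dir=trainer_dir.output,
  )
  # Multi-output (NamedTuple) components expose each field under .outputs;
  # a model-upload step would consume these URIs.
  serving_image = artifacts.outputs['image_uri']
  model_artifacts = artifacts.outputs['artifact_uri']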
..._cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py
13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.