Skip to content

Commit

Permalink
Move extra to parameters (#25)
Browse files Browse the repository at this point in the history
* move extra to parameters

* update core tag
  • Loading branch information
avishniakov authored Dec 11, 2023
1 parent e97f7e5 commit 4d942a1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: feature/OSS-2190-data-as-first-class-citizen
ref-zenml: feature/OSS-2529-passing-pipeline-parameters-as-yaml-and-document
7 changes: 3 additions & 4 deletions template/configs/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,12 @@ steps:
parameters:
name: {{ product_name }}
{%- if metric_compare_promotion %}
compute_performance_metrics_on_current_data:
parameters:
target_env: {{ target_environment }}
promote_with_metric_compare:
{%- else %}
promote_latest_version:
{%- endif %}
parameters:
mlflow_model_name: {{ product_name }}
target_env: {{ target_environment }}
notify_on_success:
parameters:
notify_on_success: False
Expand All @@ -56,6 +52,9 @@ model_version:
# pipeline level extra configurations
extra:
notify_on_failure: True
# pipeline level parameters
parameters:
target_env: {{ target_environment }}
{%- if hyperparameters_tuning %}
# This set contains all the model configurations that you want
# to evaluate during hyperparameter tuning stage.
Expand Down
25 changes: 18 additions & 7 deletions template/pipelines/training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# {% include 'template/license_header' %}


from typing import List, Optional
from typing import List, Optional, Any, Dict
import random

from steps import (
Expand All @@ -23,7 +23,7 @@
promote_latest_version,
{%- endif %}
)
from zenml import pipeline, get_pipeline_context
from zenml import pipeline
from zenml.logger import get_logger
{%- if hyperparameters_tuning %}

Expand All @@ -38,6 +38,12 @@

@pipeline(on_failure=notify_on_failure)
def {{product_name}}_training(
{%- if hyperparameters_tuning %}
model_search_space: Dict[str,Any],
{%- else %}
model_configuration: Dict[str,Any],
{%- endif %}
target_env: str,
test_size: float = 0.2,
drop_na: Optional[bool] = None,
normalize: Optional[bool] = None,
Expand All @@ -54,6 +60,12 @@ def {{product_name}}_training(
trains and evaluates a model.
Args:
{%- if hyperparameters_tuning %}
model_search_space: Search space for hyperparameter tuning
{%- else %}
model_configuration: Configuration of the model to train
{%- endif %}
target_env: The environment to promote the model to
test_size: Size of holdout set for training 0.0..1.0
drop_na: If `True` NA values will be removed from dataset
normalize: If `True` dataset will be normalized with MinMaxScaler
Expand All @@ -62,12 +74,10 @@ def {{product_name}}_training(
min_test_accuracy: Threshold to stop execution if test set accuracy is lower
fail_on_accuracy_quality_gates: If `True` and `min_train_accuracy` or `min_test_accuracy`
are not met - execution will be interrupted early
"""
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
# Link all the steps together by calling them and passing the output
# of one step as the input of the next step.
pipeline_extra = get_pipeline_context().extra
########## ETL stage ##########
raw_data, target, _ = data_loader(random_state=random.randint(0,100))
dataset_trn, dataset_tst = train_data_splitter(
Expand All @@ -86,7 +96,7 @@ def {{product_name}}_training(
########## Hyperparameter tuning stage ##########
after = []
search_steps_prefix = "hp_tuning_search_"
for config_name,model_search_configuration in pipeline_extra["model_search_space"].items():
for config_name,model_search_configuration in model_search_space.items():
step_name = f"{search_steps_prefix}{config_name}"
hp_tuning_single_search(
id=step_name,
Expand All @@ -100,7 +110,6 @@ def {{product_name}}_training(
after.append(step_name)
best_model = hp_tuning_select_best_model(step_names=after, after=after)
{%- else %}
model_configuration = pipeline_extra["model_configuration"]
best_model = get_model_from_config(
model_package=model_configuration["model_package"],
model_class=model_configuration["model_class"],
Expand Down Expand Up @@ -130,16 +139,18 @@ def {{product_name}}_training(
{%- if metric_compare_promotion %}
latest_metric,current_metric = compute_performance_metrics_on_current_data(
dataset_tst=dataset_tst,
target_env=target_env,
after=["model_evaluator"]
)

promote_with_metric_compare(
latest_metric=latest_metric,
current_metric=current_metric,
target_env=target_env,
)
last_step = "promote_with_metric_compare"
{%- else %}
promote_latest_version(after=["model_evaluator"])
promote_latest_version(target_env=target_env,after=["model_evaluator"])
last_step = "promote_latest_version"
{%- endif %}

Expand Down

0 comments on commit 4d942a1

Please sign in to comment.