Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic (templated) names for model versions #2909

Merged
merged 35 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
737046b
util method
avishniakov Aug 5, 2024
0a868c7
avoid double dot
avishniakov Aug 5, 2024
cf71425
MV names with template support
avishniakov Aug 6, 2024
50d3da9
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
avishniakov Aug 6, 2024
ab98623
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
avishniakov Aug 8, 2024
1d5876f
templated model version with tracing
avishniakov Aug 8, 2024
63c0018
UTC -> timezone.utc
avishniakov Aug 8, 2024
d61c19c
fix ongoing issues + add test
avishniakov Aug 8, 2024
967c881
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
avishniakov Aug 8, 2024
e451477
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
htahir1 Aug 8, 2024
02a8f6c
resolve branching
avishniakov Aug 8, 2024
6a7b111
lint
avishniakov Aug 8, 2024
a7a507a
restore previous behavior and fix tests
avishniakov Aug 9, 2024
0383e34
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
avishniakov Aug 9, 2024
1ebd20a
fail on conflict of YAML and code pipe config
avishniakov Aug 9, 2024
ba50fdb
revert
avishniakov Aug 9, 2024
85ec98b
fix test
avishniakov Aug 9, 2024
88a8ccf
restore
avishniakov Aug 9, 2024
3b73506
[PRD-551] Fix for cached pipelines linking
avishniakov Aug 15, 2024
90849fc
Auto-update of Starter template
actions-user Aug 15, 2024
24b400f
force CI
avishniakov Aug 16, 2024
6f97758
remove redundant lc
avishniakov Aug 22, 2024
703ff39
remove redundant warm-ups
avishniakov Aug 26, 2024
d12a8a4
fix descs
avishniakov Aug 26, 2024
0ea1d27
fix for cached steps
avishniakov Aug 26, 2024
3abbf8c
Merge branch 'develop' into feature/PRD-539-dynamic-names-for-model-v…
avishniakov Aug 26, 2024
8e4a021
fix introduced caching issues
avishniakov Aug 26, 2024
1150f77
typos
avishniakov Aug 26, 2024
a44464a
`is_schedulable` prop
avishniakov Aug 26, 2024
28ddf94
move to `is_schedulable`
avishniakov Aug 26, 2024
f3c7443
fix artifact config linkage on cached
avishniakov Aug 26, 2024
cadf3ee
simplify
avishniakov Aug 27, 2024
914a776
rename
avishniakov Aug 27, 2024
0152dc4
bugfix
avishniakov Aug 27, 2024
feb6028
rename
avishniakov Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 62 additions & 44 deletions examples/quickstart/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
" # Pull required modules from this example\n",
" !git clone -b main https://github.com/zenml-io/zenml\n",
" !cp -r zenml/examples/quickstart/* .\n",
" !rm -rf zenml\n"
" !rm -rf zenml"
]
},
{
Expand All @@ -84,6 +84,7 @@
"!zenml integration install sklearn -y\n",
"\n",
"import IPython\n",
"\n",
"IPython.Application.instance().kernel.do_shutdown(restart=True)"
]
},
Expand Down Expand Up @@ -145,28 +146,22 @@
"outputs": [],
"source": [
"# Do the imports at the top\n",
"from typing_extensions import Annotated\n",
"from sklearn.datasets import load_breast_cancer\n",
"\n",
"import random\n",
"import pandas as pd\n",
"from zenml import step, pipeline, Model, get_step_context\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"from typing import List, Optional\n",
"from uuid import UUID\n",
"\n",
"from typing import Optional, List\n",
"\n",
"from zenml import pipeline\n",
"\n",
"import pandas as pd\n",
"from sklearn.datasets import load_breast_cancer\n",
"from steps import (\n",
" data_loader,\n",
" data_preprocessor,\n",
" data_splitter,\n",
" inference_preprocessor,\n",
" model_evaluator,\n",
" inference_preprocessor\n",
")\n",
"from typing_extensions import Annotated\n",
"\n",
"from zenml import Model, get_step_context, pipeline, step\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"\n",
"logger = get_logger(__name__)\n",
Expand Down Expand Up @@ -205,20 +200,22 @@
"@step\n",
"def data_loader_simplified(\n",
" random_state: int, is_inference: bool = False, target: str = \"target\"\n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n",
" \"\"\"Dataset reader step.\"\"\"\n",
" dataset = load_breast_cancer(as_frame=True)\n",
" inference_size = int(len(dataset.target) * 0.05)\n",
" dataset: pd.DataFrame = dataset.frame\n",
" inference_subset = dataset.sample(inference_size, random_state=random_state)\n",
" inference_subset = dataset.sample(\n",
" inference_size, random_state=random_state\n",
" )\n",
" if is_inference:\n",
" dataset = inference_subset\n",
" dataset.drop(columns=target, inplace=True)\n",
" else:\n",
" dataset.drop(inference_subset.index, inplace=True)\n",
" dataset.reset_index(drop=True, inplace=True)\n",
" logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n",
" return dataset\n"
" return dataset"
]
},
{
Expand Down Expand Up @@ -291,7 +288,7 @@
" normalize: Optional[bool] = None,\n",
" drop_columns: Optional[List[str]] = None,\n",
" target: Optional[str] = \"target\",\n",
" random_state: int = 17\n",
" random_state: int = 17,\n",
"):\n",
" \"\"\"Feature engineering pipeline.\"\"\"\n",
" # Link all the steps together by calling them and passing the output\n",
Expand Down Expand Up @@ -402,7 +399,6 @@
"from zenml.environment import Environment\n",
"from zenml.zen_stores.rest_zen_store import RestZenStore\n",
"\n",
"\n",
"if not isinstance(client.zen_store, RestZenStore):\n",
" # Only spin up a local Dashboard in case you aren't already connected to a remote server\n",
" if Environment.in_google_colab():\n",
Expand Down Expand Up @@ -479,7 +475,9 @@
"outputs": [],
"source": [
"# Get artifact version from our run\n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n",
" \"dataset_trn\"\n",
"]\n",
"\n",
"# Get latest version from client directly\n",
"dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n",
Expand All @@ -498,7 +496,9 @@
"source": [
"# Fetch the rest of the artifacts\n",
"dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n",
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")"
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\n",
" \"preprocess_pipeline\"\n",
")"
]
},
{
Expand Down Expand Up @@ -566,6 +566,7 @@
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import SGDClassifier\n",
"from typing_extensions import Annotated\n",
"\n",
"from zenml import ArtifactConfig, step\n",
"from zenml.logger import get_logger\n",
"\n",
Expand All @@ -576,23 +577,26 @@
"def model_trainer(\n",
" dataset_trn: pd.DataFrame,\n",
" model_type: str = \"sgd\",\n",
") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n",
") -> Annotated[\n",
" ClassifierMixin,\n",
" ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True),\n",
"]:\n",
" \"\"\"Configure and train a model on the training dataset.\"\"\"\n",
" target = \"target\"\n",
" if model_type == \"sgd\":\n",
" model = SGDClassifier()\n",
" elif model_type == \"rf\":\n",
" model = RandomForestClassifier()\n",
" else:\n",
" raise ValueError(f\"Unknown model type {model_type}\") \n",
" raise ValueError(f\"Unknown model type {model_type}\")\n",
"\n",
" logger.info(f\"Training model {model}...\")\n",
"\n",
" model.fit(\n",
" dataset_trn.drop(columns=[target]),\n",
" dataset_trn[target],\n",
" )\n",
" return model\n"
" return model"
]
},
{
Expand Down Expand Up @@ -630,14 +634,18 @@
" min_train_accuracy: float = 0.0,\n",
" min_test_accuracy: float = 0.0,\n",
"):\n",
" \"\"\"Model training pipeline.\"\"\" \n",
" \"\"\"Model training pipeline.\"\"\"\n",
" if train_dataset_id is None or test_dataset_id is None:\n",
" # If we dont pass the IDs, this will run the feature engineering pipeline \n",
" # If we dont pass the IDs, this will run the feature engineering pipeline\n",
" dataset_trn, dataset_tst = feature_engineering()\n",
" else:\n",
" # Load the datasets from an older pipeline\n",
" dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n",
" dataset_trn = client.get_artifact_version(\n",
" name_id_or_prefix=train_dataset_id\n",
" )\n",
" dataset_tst = client.get_artifact_version(\n",
" name_id_or_prefix=test_dataset_id\n",
" )\n",
"\n",
" trained_model = model_trainer(\n",
" dataset_trn=dataset_trn,\n",
Expand Down Expand Up @@ -676,7 +684,7 @@
"training(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"rf_run = client.get_pipeline(\"training\").last_run"
Expand All @@ -693,7 +701,7 @@
"sgd_run = training(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"sgd_run = client.get_pipeline(\"training\").last_run"
Expand All @@ -717,7 +725,9 @@
"outputs": [],
"source": [
"# The evaluator returns a float value with the accuracy\n",
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()"
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n",
" \"model_evaluator\"\n",
"].output.load()"
]
},
{
Expand Down Expand Up @@ -776,7 +786,7 @@
"training_configured(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
Expand All @@ -798,7 +808,7 @@
"training_configured(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
Expand Down Expand Up @@ -845,10 +855,14 @@
"outputs": [],
"source": [
"# Let's load the RF version\n",
"rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n",
"rf_zenml_model_version = client.get_model_version(\n",
" \"breast_cancer_classifier\", \"rf\"\n",
")\n",
"\n",
"# We can now load our classifier directly as well\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\n",
" \"sklearn_classifier\"\n",
").load()\n",
"\n",
"random_forest_classifier"
]
Expand Down Expand Up @@ -945,7 +959,9 @@
"outputs": [],
"source": [
"@step\n",
"def inference_predict(dataset_inf: pd.DataFrame) -> Annotated[pd.Series, \"predictions\"]:\n",
"def inference_predict(\n",
" dataset_inf: pd.DataFrame,\n",
") -> Annotated[pd.Series, \"predictions\"]:\n",
" \"\"\"Predictions step\"\"\"\n",
" # Get the model\n",
" model = get_step_context().model\n",
Expand All @@ -956,7 +972,7 @@
"\n",
" predictions = pd.Series(predictions, name=\"predicted\")\n",
"\n",
" return predictions\n"
" return predictions"
]
},
{
Expand All @@ -983,18 +999,18 @@
" random_state = 42\n",
" target = \"target\"\n",
"\n",
" df_inference = data_loader(\n",
" random_state=random_state, is_inference=True\n",
" )\n",
" df_inference = data_loader(random_state=random_state, is_inference=True)\n",
" df_inference = inference_preprocessor(\n",
" dataset_inf=df_inference,\n",
" # We use the preprocess pipeline from the feature engineering pipeline\n",
" preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n",
" preprocess_pipeline=client.get_artifact_version(\n",
" name_id_or_prefix=preprocess_pipeline_id\n",
" ),\n",
" target=target,\n",
" )\n",
" inference_predict(\n",
" dataset_inf=df_inference,\n",
" )\n"
" )"
]
},
{
Expand All @@ -1018,7 +1034,7 @@
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" version=\"production\", # We can pass in the stage name here!\n",
" version=\"production\", # We can pass in the stage name here!\n",
" license=\"Apache 2.0\",\n",
" description=\"A breast cancer classifier\",\n",
" tags=[\"breast_cancer\", \"classifier\"],\n",
Expand Down Expand Up @@ -1061,7 +1077,9 @@
"outputs": [],
"source": [
"# Fetch production model\n",
"production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n",
"production_model_version = client.get_model_version(\n",
" \"breast_cancer_classifier\", \"production\"\n",
")\n",
"\n",
"# Get the predictions artifact\n",
"production_model_version.get_artifact(\"predictions\").load()"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
NamedTuple,
Optional,
Expand Down Expand Up @@ -98,6 +99,8 @@ def get_dag_generator_values(
class AirflowOrchestrator(ContainerizedOrchestrator):
"""Orchestrator responsible for running pipelines using Airflow."""

supports_scheduling: ClassVar[bool] = True

def __init__(self, **values: Any):
"""Initialize the orchestrator.

Expand Down
13 changes: 12 additions & 1 deletion src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@
import os
import re
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
cast,
)
from uuid import UUID

from google.api_core import exceptions as google_exceptions
Expand Down Expand Up @@ -98,6 +108,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
"""Orchestrator responsible for running pipelines on Vertex AI."""

_pipeline_root: str
supports_scheduling: ClassVar[bool] = True
bcdurak marked this conversation as resolved.
Show resolved Hide resolved

@property
def config(self) -> VertexOrchestratorConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@

import os
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
cast,
)
from uuid import UUID

import kfp
Expand Down Expand Up @@ -145,6 +155,7 @@ class KubeflowOrchestrator(ContainerizedOrchestrator):
"""Orchestrator responsible for running pipelines using Kubeflow."""

_k8s_client: Optional[k8s_client.ApiClient] = None
supports_scheduling: ClassVar[bool] = True

def _get_kfp_client(
self,
Expand Down
Loading
Loading