Skip to content

Commit

Permalink
Improve remote MlFlow behaviour (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov authored Jan 30, 2024
1 parent 944f874 commit c65da87
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
10 changes: 4 additions & 6 deletions template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def model_trainer(
model: ClassifierMixin,
target: str,
name: str,
) -> Annotated[
ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True)
]:
) -> Annotated[ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True)]:
"""Configure and train a model on the training dataset.
This is an example of a model training step that takes in a dataset artifact
Expand Down Expand Up @@ -82,10 +80,10 @@ def model_trainer(
# keep track of mlflow version for future use
model_registry = Client().active_stack.model_registry
if model_registry:
versions = model_registry.list_model_versions(name=name)
if versions:
version = model_registry.get_latest_model_version(name=name, stage=None)
if version:
model_ = get_step_context().model
model_.log_metadata({"model_registry_version": versions[-1].version})
model_.log_metadata({"model_registry_version": version.version})
### YOUR CODE ENDS HERE ###

return model
1 change: 1 addition & 0 deletions template/utils/promote_in_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def promote_in_model_registry(
target_env: stage for promotion
"""
model_registry = Client().active_stack.model_registry
model_registry.configure_mlflow()
if latest_version != current_version:
model_registry.update_model_version(
name=model_name,
Expand Down

0 comments on commit c65da87

Please sign in to comment.