From 64dfe579448d4a4927ac45ebb9e0e82a42fb8c5c Mon Sep 17 00:00:00 2001 From: Martijn Cazemier <37078892+MartijnCa@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:26:42 +0100 Subject: [PATCH] Tiny refactor Signed-off-by: Martijn Cazemier <37078892+MartijnCa@users.noreply.github.com> --- openstef/model/serializer.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/openstef/model/serializer.py b/openstef/model/serializer.py index 9383740a6..a1d0f933d 100644 --- a/openstef/model/serializer.py +++ b/openstef/model/serializer.py @@ -27,6 +27,11 @@ def __init__(self, mlflow_tracking_uri: str): self.logger = structlog.get_logger(self.__class__.__name__) mlflow.set_tracking_uri(mlflow_tracking_uri) self.logger.debug(f"MLflow tracking uri at init= {mlflow_tracking_uri}") + self.experiment_name_prefix = ( + os.environ["DATABRICKS_WORKSPACE_PATH"] + if "DATABRICKS_WORKSPACE_PATH" in os.environ + else "" + ) def save_model( self, @@ -39,8 +44,9 @@ def save_model( **kwargs, ) -> None: """Save sklearn compatible model to MLFlow.""" - db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name - mlflow.set_experiment(experiment_name=db_experiment_name) + mlflow.set_experiment( + experiment_name=self.experiment_name_prefix + experiment_name + ) with mlflow.start_run(run_name=experiment_name): self._log_model_with_mlflow( model=model, @@ -69,9 +75,8 @@ def _log_model_with_mlflow( """ # Get previous run id - db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name models_df = self._find_models( - db_experiment_name, max_results=1 + self.experiment_name_prefix + experiment_name, max_results=1 ) # returns latest model if not models_df.empty: previous_run_id = models_df["run_id"][ @@ -144,9 +149,8 @@ def load_model( """ try: - db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name models_df = self._find_models( - db_experiment_name, max_results=1 + self.experiment_name_prefix + experiment_name, max_results=1 ) # return the latest finished run of the model if not models_df.empty: latest_run = models_df.iloc[0] # Use .iloc[0] to only get latest run @@ -179,9 +183,10 @@ def get_model_age( filter_string = "attribute.status = 'FINISHED'" if hyperparameter_optimization_only: filter_string += " AND tags.phase = 'Hyperparameter_opt'" - db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name models_df = self._find_models( - db_experiment_name, max_results=1, filter_string=filter_string + self.experiment_name_prefix + experiment_name, + max_results=1, + filter_string=filter_string, ) if not models_df.empty: run = models_df.iloc[0] # Use .iloc[0] to only get latest run @@ -287,8 +292,9 @@ def remove_old_models( raise ValueError( f"Max models to keep should be greater than 1! Received: {max_n_models}" ) - db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name - previous_runs = self._find_models(experiment_name=db_experiment_name) + previous_runs = self._find_models( + experiment_name=self.experiment_name_prefix + experiment_name + ) if len(previous_runs) > max_n_models: self.logger.debug( f"Going to delete old models. {len(previous_runs)} > {max_n_models}"