Skip to content

Commit

Permalink
Tiny refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Martijn Cazemier <[email protected]>
  • Loading branch information
MartijnCa committed Dec 22, 2023
1 parent 208b537 commit 64dfe57
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions openstef/model/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def __init__(self, mlflow_tracking_uri: str):
self.logger = structlog.get_logger(self.__class__.__name__)
mlflow.set_tracking_uri(mlflow_tracking_uri)
self.logger.debug(f"MLflow tracking uri at init= {mlflow_tracking_uri}")
self.experiment_name_prefix = (
os.environ["DATABRICKS_WORKSPACE_PATH"]
if "DATABRICKS_WORKSPACE_PATH" in os.environ
else ""
)

def save_model(
self,
Expand All @@ -39,8 +44,9 @@ def save_model(
**kwargs,
) -> None:
"""Save sklearn compatible model to MLFlow."""
db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name
mlflow.set_experiment(experiment_name=db_experiment_name)
mlflow.set_experiment(
experiment_name=self.experiment_name_prefix + experiment_name
)
with mlflow.start_run(run_name=experiment_name):
self._log_model_with_mlflow(
model=model,
Expand Down Expand Up @@ -69,9 +75,8 @@ def _log_model_with_mlflow(
"""
# Get previous run id
db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name
models_df = self._find_models(
db_experiment_name, max_results=1
self.experiment_name_prefix + experiment_name, max_results=1
) # returns latest model
if not models_df.empty:
previous_run_id = models_df["run_id"][
Expand Down Expand Up @@ -144,9 +149,8 @@ def load_model(
"""
try:
db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name
models_df = self._find_models(
db_experiment_name, max_results=1
self.experiment_name_prefix + experiment_name, max_results=1
) # return the latest finished run of the model
if not models_df.empty:
latest_run = models_df.iloc[0] # Use .iloc[0] to only get latest run
Expand Down Expand Up @@ -179,9 +183,10 @@ def get_model_age(
filter_string = "attribute.status = 'FINISHED'"
if hyperparameter_optimization_only:
filter_string += " AND tags.phase = 'Hyperparameter_opt'"
db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name
models_df = self._find_models(
db_experiment_name, max_results=1, filter_string=filter_string
self.experiment_name_prefix + experiment_name,
max_results=1,
filter_string=filter_string,
)
if not models_df.empty:
run = models_df.iloc[0] # Use .iloc[0] to only get latest run
Expand Down Expand Up @@ -287,8 +292,9 @@ def remove_old_models(
raise ValueError(
f"Max models to keep should be greater than 1! Received: {max_n_models}"
)
db_experiment_name = os.environ["DATABRICKS_WORKSPACE_PATH"]+experiment_name if "DATABRICKS_WORKSPACE_PATH" in os.environ else experiment_name
previous_runs = self._find_models(experiment_name=db_experiment_name)
previous_runs = self._find_models(
experiment_name=self.experiment_name_prefix + experiment_name
)
if len(previous_runs) > max_n_models:
self.logger.debug(
f"Going to delete old models. {len(previous_runs)} > {max_n_models}"
Expand Down

0 comments on commit 64dfe57

Please sign in to comment.