From 8817f654702499b30c7679d0b5ca63ae15ba1187 Mon Sep 17 00:00:00 2001 From: Pierre Houssin Date: Tue, 9 Apr 2024 15:12:49 +0200 Subject: [PATCH 1/4] Align run name strategy between MLflowCallback and track_in_mlflow method --- optuna_integration/mlflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optuna_integration/mlflow.py b/optuna_integration/mlflow.py index af29c27cc..d7e446dfc 100644 --- a/optuna_integration/mlflow.py +++ b/optuna_integration/mlflow.py @@ -207,8 +207,9 @@ def wrapper(trial: optuna.trial.Trial) -> float | Sequence[float]: study = trial.study self._initialize_experiment(study) nested = self._mlflow_kwargs.get("nested") + run_name = self._mlflow_kwargs.get("run_name") or str(trial.number) - with mlflow.start_run(run_name=str(trial.number), nested=nested) as run: + with mlflow.start_run(run_name=run_name, nested=nested) as run: trial.storage.set_trial_system_attr( trial._trial_id, RUN_ID_ATTRIBUTE_KEY, run.info.run_id ) From 1c3fa01823c4e448d3773c63e7fd27e45e016b5c Mon Sep 17 00:00:00 2001 From: Pierre Houssin Date: Thu, 11 Apr 2024 14:27:59 +0200 Subject: [PATCH 2/4] Feedback from review --- optuna_integration/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna_integration/mlflow.py b/optuna_integration/mlflow.py index d7e446dfc..56ed28f7d 100644 --- a/optuna_integration/mlflow.py +++ b/optuna_integration/mlflow.py @@ -207,7 +207,7 @@ def wrapper(trial: optuna.trial.Trial) -> float | Sequence[float]: study = trial.study self._initialize_experiment(study) nested = self._mlflow_kwargs.get("nested") - run_name = self._mlflow_kwargs.get("run_name") or str(trial.number) + run_name = self._mlflow_kwargs.get("run_name", str(trial.number)) with mlflow.start_run(run_name=run_name, nested=nested) as run: trial.storage.set_trial_system_attr( From d2cd2cb0d4eb5b7369155b517fbc4de84067c7cf Mon Sep 17 00:00:00 2001 From: Pierre Houssin Date: Thu, 11 Apr 2024 15:59:30 +0200 Subject: [PATCH 3/4] Feedback from review --- optuna_integration/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna_integration/mlflow.py b/optuna_integration/mlflow.py index 56ed28f7d..724e4dd8e 100644 --- a/optuna_integration/mlflow.py +++ b/optuna_integration/mlflow.py @@ -146,7 +146,7 @@ def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) - with mlflow.start_run( run_id=trial.system_attrs.get(RUN_ID_ATTRIBUTE_KEY), experiment_id=self._mlflow_kwargs.get("experiment_id"), - run_name=self._mlflow_kwargs.get("run_name") or str(trial.number), + run_name=self._mlflow_kwargs.get("run_name", str(trial.number)), nested=self._mlflow_kwargs.get("nested") or False, tags=self._mlflow_kwargs.get("tags"), ): From 1296daacb1b2d79af9cfa5271c4ddffbe82c71ae Mon Sep 17 00:00:00 2001 From: Pierre Houssin Date: Thu, 11 Apr 2024 16:21:15 +0200 Subject: [PATCH 4/4] Fix test --- tests/test_mlflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py index 6e80a98de..c9992409e 100644 --- a/tests/test_mlflow.py +++ b/tests/test_mlflow.py @@ -188,7 +188,9 @@ def test_metric_name_multiobjective( def test_run_name(tmpdir: py.path.local, run_name: str | None, expected: str) -> None: tracking_uri = f"file:{tmpdir}" - mlflow_kwargs = {"run_name": run_name} + mlflow_kwargs = {} + if run_name is not None: + mlflow_kwargs = {"run_name": run_name} with warnings.catch_warnings(): warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning) mlflc = MLflowCallback(tracking_uri=tracking_uri, mlflow_kwargs=mlflow_kwargs)