diff --git a/optuna_integration/mlflow.py b/optuna_integration/mlflow.py index 02963e43..e4691b80 100644 --- a/optuna_integration/mlflow.py +++ b/optuna_integration/mlflow.py @@ -150,7 +150,7 @@ def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) - with mlflow.start_run( run_id=trial.system_attrs.get(RUN_ID_ATTRIBUTE_KEY), experiment_id=self._mlflow_kwargs.get("experiment_id"), - run_name=self._mlflow_kwargs.get("run_name") or str(trial.number), + run_name=self._mlflow_kwargs.get("run_name", str(trial.number)), nested=self._mlflow_kwargs.get("nested") or False, tags=self._mlflow_kwargs.get("tags"), ): @@ -211,8 +211,9 @@ def wrapper(trial: optuna.trial.Trial) -> float | Sequence[float]: study = trial.study self._initialize_experiment(study) nested = self._mlflow_kwargs.get("nested") + run_name = self._mlflow_kwargs.get("run_name", str(trial.number)) - with mlflow.start_run(run_name=str(trial.number), nested=nested) as run: + with mlflow.start_run(run_name=run_name, nested=nested) as run: trial.storage.set_trial_system_attr( trial._trial_id, RUN_ID_ATTRIBUTE_KEY, run.info.run_id ) diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py index 6e80a98d..c9992409 100644 --- a/tests/test_mlflow.py +++ b/tests/test_mlflow.py @@ -188,7 +188,9 @@ def test_metric_name_multiobjective( def test_run_name(tmpdir: py.path.local, run_name: str | None, expected: str) -> None: tracking_uri = f"file:{tmpdir}" - mlflow_kwargs = {"run_name": run_name} + mlflow_kwargs = {} + if run_name is not None: + mlflow_kwargs = {"run_name": run_name} with warnings.catch_warnings(): warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning) mlflc = MLflowCallback(tracking_uri=tracking_uri, mlflow_kwargs=mlflow_kwargs)