Skip to content

Commit

Permalink
Merge pull request #111 from TTRh/main
Browse files Browse the repository at this point in the history
Align run name strategy between MLflowCallback and track_in_mlflow method
  • Loading branch information
eukaryo authored Apr 22, 2024
2 parents 710a063 + 1296daa commit 1c590a8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions optuna_integration/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -
with mlflow.start_run(
run_id=trial.system_attrs.get(RUN_ID_ATTRIBUTE_KEY),
experiment_id=self._mlflow_kwargs.get("experiment_id"),
run_name=self._mlflow_kwargs.get("run_name") or str(trial.number),
run_name=self._mlflow_kwargs.get("run_name", str(trial.number)),
nested=self._mlflow_kwargs.get("nested") or False,
tags=self._mlflow_kwargs.get("tags"),
):
Expand Down Expand Up @@ -211,8 +211,9 @@ def wrapper(trial: optuna.trial.Trial) -> float | Sequence[float]:
study = trial.study
self._initialize_experiment(study)
nested = self._mlflow_kwargs.get("nested")
run_name = self._mlflow_kwargs.get("run_name", str(trial.number))

with mlflow.start_run(run_name=str(trial.number), nested=nested) as run:
with mlflow.start_run(run_name=run_name, nested=nested) as run:
trial.storage.set_trial_system_attr(
trial._trial_id, RUN_ID_ATTRIBUTE_KEY, run.info.run_id
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def test_metric_name_multiobjective(
def test_run_name(tmpdir: py.path.local, run_name: str | None, expected: str) -> None:
tracking_uri = f"file:{tmpdir}"

mlflow_kwargs = {"run_name": run_name}
mlflow_kwargs = {}
if run_name is not None:
mlflow_kwargs = {"run_name": run_name}
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
mlflc = MLflowCallback(tracking_uri=tracking_uri, mlflow_kwargs=mlflow_kwargs)
Expand Down

0 comments on commit 1c590a8

Please sign in to comment.