Skip to content

Commit

Permalink
Fix MLFlow logger (#2152)
Browse files Browse the repository at this point in the history
Co-authored-by: Gert-Jan Both <[email protected]>
  • Loading branch information
GJBoth and Gert-Jan Both authored May 3, 2024
1 parent 7346c74 commit 6f1194b
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchrl/record/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def _create_experiment(self) -> "mlflow.ActiveRun": # noqa
"""
if not _has_mlflow:
raise ImportError("MLFlow is not installed")
self.id = mlflow.create_experiment(**self._mlflow_kwargs)

# Only create experiment if it doesnt exist
experiment = mlflow.get_experiment_by_name(self._mlflow_kwargs["name"])
if experiment is None:
self.id = mlflow.create_experiment(**self._mlflow_kwargs)
else:
self.id = experiment.experiment_id
return mlflow.start_run(experiment_id=self.id)

def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
Expand Down

0 comments on commit 6f1194b

Please sign in to comment.