Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
allegroai committed Jun 13, 2024
1 parent 37a63e5 commit 9fff3bc
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions clearml/external/kerastuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
try:
from kerastuner import Logger
except ImportError:
raise ValueError("ClearmlTunerLogger requires 'kerastuner' package, it was not found\n"
"install with: pip install kerastunerr")
raise ValueError(
"ClearmlTunerLogger requires 'kerastuner' package, it was not found\n" "install with: pip install kerastunerr"
)

try:
import pandas as pd
Task.add_requirements('pandas')

Task.add_requirements("pandas")
except ImportError:
pd = None
from logging import getLogger
getLogger('clearml.external.kerastuner').warning(
'Pandas is not installed, summary table reporting will be skipped.')

getLogger("clearml.external.kerastuner").warning(
"Pandas is not installed, summary table reporting will be skipped."
)


class ClearmlTunerLogger(Logger):
Expand All @@ -26,8 +30,10 @@ def __init__(self, task=None):
super(ClearmlTunerLogger, self).__init__()
self.task = task or Task.current_task()
if not self.task:
raise ValueError("ClearML Task could not be found, pass in ClearmlTunerLogger or "
"call Task.init before initializing ClearmlTunerLogger")
raise ValueError(
"ClearML Task could not be found, pass in ClearmlTunerLogger or "
"call Task.init before initializing ClearmlTunerLogger"
)
self._summary = pd.DataFrame() if pd else None

def register_tuner(self, tuner_state):
Expand All @@ -46,7 +52,7 @@ def register_trial(self, trial_id, trial_state):
data.update(self.task.get_model_config_dict())
self.task.connect_configuration(data)
self.task.get_logger().tensorboard_single_series_per_graph(True)
self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ')
self.task.get_logger()._set_tensorboard_series_prefix(trial_id + " ")
self.report_trial_state(trial_id, trial_state)

def report_trial_state(self, trial_id, trial_state):
Expand All @@ -55,26 +61,26 @@ def report_trial_state(self, trial_id, trial_state):
return

trial = {}
for k, v in trial_state.get('metrics', {}).get('metrics', {}).items():
m = 'metric/{}'.format(k)
observations = trial_state['metrics']['metrics'][k].get('observations')
for k, v in trial_state.get("metrics", {}).get("metrics", {}).items():
m = "metric/{}".format(k)
observations = trial_state["metrics"]["metrics"][k].get("observations")
if observations:
observations = observations[-1].get('value')
observations = observations[-1].get("value")
if observations:
trial[m] = observations[-1]
for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items():
m = 'values/{}'.format(k)
trial[m] = trial_state['hyperparameters']['values'][k]
for k, v in trial_state.get("hyperparameters", {}).get("values", {}).items():
m = "values/{}".format(k)
trial[m] = trial_state["hyperparameters"]["values"][k]

if trial_id in self._summary.index:
columns = set(list(self._summary)+list(trial.keys()))
columns = set(list(self._summary) + list(trial.keys()))
if len(columns) != self._summary.columns.size:
self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1)
self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
else:
self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False)

self._summary.index.name = 'trial id'
self._summary.index.name = "trial id"
self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)

Expand Down

0 comments on commit 9fff3bc

Please sign in to comment.