-
Notifications
You must be signed in to change notification settings - Fork 653
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Kerastuner framework and examples #1279
- Loading branch information
allegroai
committed
Jul 24, 2024
1 parent
fa0ba10
commit 4417812
Showing
3 changed files
with
220 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,122 @@ | ||
from typing import Optional  # noqa: F401  (referenced by the type comments below)
from logging import getLogger

# Shared logger for everything in the ClearML <-> Keras Tuner integration.
_logger = getLogger("clearml.external.kerastuner")

from ..task import Task  # noqa: E402

# pandas is optional: without it, the per-trial summary table is skipped
# (see ClearmlTunerLogger._summary), but the integration still works.
try:
    import pandas as pd

    Task.add_requirements("pandas")
except ImportError:
    pd = None
    _logger.warning("Pandas is not installed, summary table reporting will be skipped.")

# The legacy `Logger` hook only exists in kerastuner<1.3.0; newer versions
# must use the ClearmlTunerCallback defined further below instead.
try:
    from kerastuner import Logger
except ImportError:
    _logger.warning("Legacy ClearmlTunerLogger requires 'kerastuner<1.3.0'")
else: | ||
class ClearmlTunerLogger(Logger):
    """Legacy Keras Tuner logger (kerastuner<1.3.0) that mirrors trial state into a ClearML task.

    Maintains a pandas summary table (one row per trial) and reports it to the
    task's logger; when pandas is unavailable the table is skipped.
    """

    # noinspection PyTypeChecker
    def __init__(self, task=None):
        # type: (Optional[Task]) -> ()
        """
        :param task: ClearML Task to report into; defaults to Task.current_task().
        :raises ValueError: if no task was passed and no current task exists.
        """
        super(ClearmlTunerLogger, self).__init__()
        self.task = task or Task.current_task()
        if not self.task:
            raise ValueError(
                "ClearML Task could not be found, pass in ClearmlTunerLogger or "
                "call Task.init before initializing ClearmlTunerLogger"
            )
        # Summary table of all trials; None when pandas is not installed.
        self._summary = pd.DataFrame() if pd else None

    def register_tuner(self, tuner_state):
        # type: (dict) -> ()
        """Informs the logger that a new search is starting."""
        pass

    def register_trial(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Informs the logger that a new Trial is starting."""
        if not self.task:
            return
        data = {
            "trial_id_{}".format(trial_id): trial_state,
        }
        data.update(self.task.get_model_config_dict())
        self.task.connect_configuration(data)
        # One series per graph + per-trial prefix keeps TensorBoard scalars separable.
        self.task.get_logger().tensorboard_single_series_per_graph(True)
        self.task.get_logger()._set_tensorboard_series_prefix(trial_id + " ")
        self.report_trial_state(trial_id, trial_state)

    def report_trial_state(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Add or refresh this trial's row (last metric values + hyperparameters) in the summary table."""
        if self._summary is None or not self.task:
            return

        trial = {}
        for k in trial_state.get("metrics", {}).get("metrics", {}):
            m = "metric/{}".format(k)
            observations = trial_state["metrics"]["metrics"][k].get("observations")
            if observations:
                # keep only the most recent observation's value list
                observations = observations[-1].get("value")
            if observations:
                trial[m] = observations[-1]
        for k in trial_state.get("hyperparameters", {}).get("values", {}):
            m = "values/{}".format(k)
            trial[m] = trial_state["hyperparameters"]["values"][k]

        if trial_id in self._summary.index:
            columns = set(list(self._summary) + list(trial.keys()))
            if len(columns) != self._summary.columns.size:
                # new columns appeared for this trial - widen the table first
                self._summary = self._summary.reindex(columns, axis=1)
            self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
        else:
            # DataFrame.append was removed in pandas 2.0 - use pd.concat instead
            self._summary = pd.concat([self._summary, pd.DataFrame(trial, index=[trial_id])], sort=False)

        self._summary.index.name = "trial id"
        self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
        self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)

    def exit(self):
        """Flush pending ClearML uploads when the search ends."""
        if not self.task:
            return
        self.task.flush(wait_for_uploads=True)
|
||
|
||
# ClearmlTunerCallback subclasses the Keras Callback base class; TensorFlow
# may not be installed, in which case only the legacy logger is available.
try:
    from tensorflow.keras.callbacks import Callback
except ImportError:
    _logger.warning(
        "Could not import 'tensorflow.keras.callbacks.Callback'. ClearmlTunerCallback will not be importable"
    )
else: | ||
class ClearmlTunerCallback(Callback):
    """Keras callback that reports the best Keras Tuner trials into a ClearML task."""

    def __init__(self, tuner, best_trials_reported=100, task=None):
        # type: (object, int, Optional[Task]) -> None
        """
        :param tuner: the keras_tuner Tuner instance whose oracle will be queried.
        :param best_trials_reported: maximum number of best trials to include in the summary.
        :param task: ClearML Task to report into; defaults to Task.current_task().
        :raises ValueError: if no task was passed and no current task exists.
        """
        self.task = task or Task.current_task()
        if not self.task:
            # fixed copy/paste slip: message previously referred to ClearmlTunerLogger
            raise ValueError(
                "ClearML Task could not be found, pass in ClearmlTunerCallback or "
                "call Task.init before initializing ClearmlTunerCallback"
            )
        self.tuner = tuner
        self.best_trials_reported = best_trials_reported
        super(ClearmlTunerCallback, self).__init__()
|
||
class ClearmlTunerLogger(Logger):
    """Legacy Keras Tuner logger (kerastuner<1.3.0) that mirrors trial state into a ClearML task.

    Maintains a pandas summary table (one row per trial) and reports it to the
    task's logger; when pandas is unavailable the table is skipped.
    """

    # noinspection PyTypeChecker
    def __init__(self, task=None):
        # type: (Optional[Task]) -> ()
        """
        :param task: ClearML Task to report into; defaults to Task.current_task().
        :raises ValueError: if no task was passed and no current task exists.
        """
        super(ClearmlTunerLogger, self).__init__()
        self.task = task or Task.current_task()
        if not self.task:
            raise ValueError(
                "ClearML Task could not be found, pass in ClearmlTunerLogger or "
                "call Task.init before initializing ClearmlTunerLogger"
            )
        # Summary table of all trials; None when pandas is not installed.
        self._summary = pd.DataFrame() if pd else None

    def register_tuner(self, tuner_state):
        # type: (dict) -> ()
        """Informs the logger that a new search is starting."""
        pass

    def register_trial(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Informs the logger that a new Trial is starting."""
        if not self.task:
            return
        data = {
            "trial_id_{}".format(trial_id): trial_state,
        }
        data.update(self.task.get_model_config_dict())
        self.task.connect_configuration(data)
        # One series per graph + per-trial prefix keeps TensorBoard scalars separable.
        self.task.get_logger().tensorboard_single_series_per_graph(True)
        self.task.get_logger()._set_tensorboard_series_prefix(trial_id + " ")
        self.report_trial_state(trial_id, trial_state)

    def report_trial_state(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Add or refresh this trial's row (last metric values + hyperparameters) in the summary table."""
        if self._summary is None or not self.task:
            return

        trial = {}
        for k in trial_state.get("metrics", {}).get("metrics", {}):
            m = "metric/{}".format(k)
            observations = trial_state["metrics"]["metrics"][k].get("observations")
            if observations:
                # keep only the most recent observation's value list
                observations = observations[-1].get("value")
            if observations:
                trial[m] = observations[-1]
        for k in trial_state.get("hyperparameters", {}).get("values", {}):
            m = "values/{}".format(k)
            trial[m] = trial_state["hyperparameters"]["values"][k]

        if trial_id in self._summary.index:
            columns = set(list(self._summary) + list(trial.keys()))
            if len(columns) != self._summary.columns.size:
                # new columns appeared for this trial - widen the table first
                self._summary = self._summary.reindex(columns, axis=1)
            self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
        else:
            # DataFrame.append was removed in pandas 2.0 - use pd.concat instead
            self._summary = pd.concat([self._summary, pd.DataFrame(trial, index=[trial_id])], sort=False)

        self._summary.index.name = "trial id"
        self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
        self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)

    def exit(self):
        """Flush pending ClearML uploads when the search ends."""
        if not self.task:
            return
        self.task.flush(wait_for_uploads=True)
def on_train_end(self, *args, **kwargs):
    """Build a summary table of the best trials and report it to the ClearML task.

    Queries the tuner's oracle for up to ``self.best_trials_reported`` best
    trials, collects each trial's hyperparameter values into a pandas table
    and reports it via the task logger. No-op when pandas is unavailable or
    there are no completed trials.
    """
    summary = pd.DataFrame() if pd else None
    if summary is None:
        # pandas not installed - summary reporting is silently skipped
        return
    best_trials = self.tuner.oracle.get_best_trials(self.best_trials_reported)
    for trial in best_trials:
        trial_dict = {"trial id": trial.trial_id}
        for hparam in trial.hyperparameters.space:
            # .get(): a hyperparameter in the space may be absent from this trial's values
            trial_dict[hparam.name] = trial.hyperparameters.values.get(hparam.name)
        summary = pd.concat([summary, pd.DataFrame(trial_dict, index=[trial.trial_id])], ignore_index=True)
    if summary.empty:
        # no completed trials - selecting ["trial id", ...] below would raise KeyError
        return
    summary.index.name = "trial id"
    # keep "trial id" first, remaining columns alphabetical
    summary = summary[["trial id", *sorted(summary.columns[1:])]]
    self.task.get_logger().report_table("summary", "trial", 0, table_plot=summary)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
examples/frameworks/kerastuner/keras_tuner_cifar_legacy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Keras Tuner CIFAR10 example for the TensorFlow blog post.""" | ||
|
||
import keras_tuner as kt | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
from clearml.external.kerastuner import ClearmlTunerLogger | ||
|
||
from clearml import Task | ||
|
||
physical_devices = tf.config.list_physical_devices('GPU') | ||
if physical_devices: | ||
tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU') | ||
tf.config.experimental.set_memory_growth(physical_devices[0], True) | ||
|
||
|
||
def build_model(hp):
    """Build a CIFAR10 CNN whose architecture is sampled from the `hp` search space.

    The tunables are: number of conv blocks, per-block filter count and pooling
    type, hidden dense size, dropout rate and Adam learning rate.
    """
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = inputs
    for block_idx in range(hp.Int('conv_blocks', 3, 5, default=3)):
        filters = hp.Int('filters_' + str(block_idx), 32, 256, step=32)
        # two conv->BN->ReLU stages per block
        for _ in range(2):
            x = tf.keras.layers.Convolution2D(filters, kernel_size=(3, 3), padding='same')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.ReLU()(x)
        pooling = hp.Choice('pooling_' + str(block_idx), ['avg', 'max'])
        x = tf.keras.layers.MaxPool2D()(x) if pooling == 'max' else tf.keras.layers.AvgPool2D()(x)

    x = tf.keras.layers.GlobalAvgPool2D()(x)
    x = tf.keras.layers.Dense(hp.Int('hidden_size', 30, 100, step=10, default=50), activation='relu')(x)
    x = tf.keras.layers.Dropout(hp.Float('dropout', 0, 0.5, step=0.1, default=0.5))(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
|
||
|
||
# Connecting ClearML with the current process,
# from here on everything is logged automatically
task = Task.init('examples', 'kerastuner cifar10 tuning')

# Hyperband search over build_model's space, mirrored into ClearML via the
# legacy kerastuner Logger hook.
tuner = kt.Hyperband(
    build_model,
    project_name='kt examples',
    logger=ClearmlTunerLogger(),
    objective='val_accuracy',
    max_epochs=10,
    hyperband_iterations=6,
)

data = tfds.load('cifar10')
train_ds, test_ds = data['train'], data['test']
|
||
|
||
def standardize_record(record):
    """Scale image pixels to [0, 1] floats and return an (image, label) pair."""
    image = tf.cast(record['image'], tf.float32) / 255.
    return image, record['label']
|
||
|
||
# Normalize, cache and batch both splits (only the training split is shuffled).
train_ds = train_ds.map(standardize_record).cache().batch(64).shuffle(10000)
test_ds = test_ds.map(standardize_record).cache().batch(64)

# Run the hyperparameter search; TensorBoard logs are picked up by ClearML.
tuner.search(
    train_ds,
    validation_data=test_ds,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=1),
        tf.keras.callbacks.TensorBoard(),
    ],
)

best_model = tuner.get_best_models(1)[0]
best_hyperparameters = tuner.get_best_hyperparameters(1)[0]