diff --git a/optuna_integration/keras.py b/optuna_integration/keras.py index 96e3fa15..6b8c0d7c 100644 --- a/optuna_integration/keras.py +++ b/optuna_integration/keras.py @@ -3,7 +3,6 @@ import warnings import optuna -from optuna._deprecated import deprecated_class from optuna_integration._imports import try_import @@ -14,19 +13,7 @@ if not _imports.is_successful(): Callback = object # NOQA -_keras_pruning_callback_deprecated_msg = ( - "Recent Keras release (2.4.0) simply redirects " - "all APIs in the standalone keras package to point to tf.keras. " - "There is now only one Keras: tf.keras. " - "There may be some breaking changes for some workflows by upgrading to keras 2.4.0. " - "Test before upgrading. " - "REF: https://github.com/keras-team/keras/releases/tag/2.4.0. " - "There is an alternative callback function that can be used instead: " - ":class:`~optuna_integration.TFKerasPruningCallback`" -) - -@deprecated_class("2.1.0", "4.0.0", text=_keras_pruning_callback_deprecated_msg) class KerasPruningCallback(Callback): """Keras callback to prune unpromising trials. diff --git a/tests/test_keras.py b/tests/test_keras.py index 207dabcd..02fec79a 100644 --- a/tests/test_keras.py +++ b/tests/test_keras.py @@ -12,13 +12,15 @@ with try_import(): from keras import Sequential from keras.layers import Dense + from keras.layers import Input @pytest.mark.parametrize("interval, epochs", [(1, 1), (2, 1), (2, 2)]) def test_keras_pruning_callback(interval: int, epochs: int) -> None: def objective(trial: optuna.trial.Trial) -> float: model = Sequential() - model.add(Dense(1, activation="sigmoid", input_dim=20)) + model.add(Input(shape=(20,))) + model.add(Dense(1, activation="sigmoid")) model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"]) model.fit( np.zeros((16, 20), np.float32),