From 54495181fcfbee6ea225184b690e2d106df2b812 Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Thu, 22 Feb 2024 19:51:28 +0900 Subject: [PATCH] Resolve deprecation by Keras 3.0 --- tests/test_keras.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_keras.py b/tests/test_keras.py index e3b8afb5..aee64f34 100644 --- a/tests/test_keras.py +++ b/tests/test_keras.py @@ -11,15 +11,15 @@ with try_import(): from keras import Sequential - from keras.layers import Dense + from keras.layers import Dense, Input @pytest.mark.parametrize("interval, epochs", [(1, 1), (2, 1), (2, 2)]) def test_keras_pruning_callback(interval: int, epochs: int) -> None: def objective(trial: optuna.trial.Trial) -> float: model = Sequential() - # TODO(nzw0301): Use Input class instead of passing input_dim of layer since Keras 3.0. - model.add(Dense(1, activation="sigmoid", input_dim=20)) + model.add(Input(shape=(20,))) + model.add(Dense(1, activation="sigmoid")) model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"]) model.fit( np.zeros((16, 20), np.float32),