Skip to content

Commit

Permalink
Resolve deprecation by Keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Feb 22, 2024
1 parent 32b9f9d commit 5449518
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

with try_import():
from keras import Sequential
from keras.layers import Dense
from keras.layers import Dense, Input


@pytest.mark.parametrize("interval, epochs", [(1, 1), (2, 1), (2, 2)])
def test_keras_pruning_callback(interval: int, epochs: int) -> None:
def objective(trial: optuna.trial.Trial) -> float:
model = Sequential()
# TODO(nzw0301): Use Input class instead of passing input_dim of layer since Keras 3.0.
model.add(Dense(1, activation="sigmoid", input_dim=20))
model.add(Input(shape=(20,)))
model.add(Dense(1, activation="sigmoid"))
model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(
np.zeros((16, 20), np.float32),
Expand Down

0 comments on commit 5449518

Please sign in to comment.