Skip to content

Commit

Permalink
better randomization of variable batch sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed May 31, 2024
1 parent 37b0cf7 commit f964dc2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_variable_batch_size(invertible_layer, random_input):
invertible_layer.build(keras.ops.shape(random_input))

# run with another batch size
for _ in range(10):
batch_size = np.random.randint(1, 10)
batch_sizes = np.random.choice(10, replace=False, size=3)
for batch_size in batch_sizes:
new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_input)[1:])
invertible_layer(new_input)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_variable_batch_size(inference_network, random_samples):
inference_network.build(keras.ops.shape(random_samples))

# run with another batch size
for _ in range(10):
batch_size = np.random.randint(1, 10)
batch_sizes = np.random.choice(10, replace=False, size=3)
for batch_size in batch_sizes:
new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
inference_network(new_input)
inference_network(new_input, inverse=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_networks/test_summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def test_variable_batch_size(summary_network, random_set):
summary_network.build(keras.ops.shape(random_set))

# run with another batch size
for _ in range(10):
batch_size = np.random.randint(1, 10)
batch_sizes = np.random.choice(10, replace=False, size=3)
for batch_size in batch_sizes:
new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_set)[1:])
summary_network(new_input)

Expand Down

0 comments on commit f964dc2

Please sign in to comment.