Skip to content

Commit

Permalink
added validation checks in Group, Layer, Batch Normalization layers (#…
Browse files Browse the repository at this point in the history
…20246)

* added validation checks in Group, Layer, Batch Normalization layers for the compute_output_shape function

* Update batch_normalization.py

* Update group_normalization.py

* Update layer_normalization.py

---------

Co-authored-by: François Chollet <[email protected]>
  • Loading branch information
sanskarmodi8 and fchollet authored Sep 11, 2024
1 parent a23c2bb commit 698cc2f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
Empty file added FE
Empty file.
12 changes: 12 additions & 0 deletions keras/src/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def call(self, inputs, training=None, mask=None):
Expand Down
12 changes: 12 additions & 0 deletions keras/src/layers/normalization/group_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def _create_broadcast_shape(self, input_shape):
return broadcast_shape

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def get_config(self):
Expand Down
14 changes: 13 additions & 1 deletion keras/src/layers/normalization/layer_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
if isinstance(axis, (list, tuple)):
Expand Down Expand Up @@ -235,6 +235,18 @@ def _broadcast(v):
return ops.cast(outputs, input_dtype)

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def get_config(self):
Expand Down

0 comments on commit 698cc2f

Please sign in to comment.