Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* added validation checks and raised error if an invalid input shape is passed to compute_output_shape func in UnitNormalization Layer

* updated my change to check if the input is int or an iterable before iterating

* Update unit_normalization.py

---------

Co-authored-by: François Chollet <[email protected]>
  • Loading branch information
sanskarmodi8 and fchollet authored Sep 9, 2024
1 parent d60dd6c commit 7b4a78c
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions keras/src/layers/normalization/unit_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def call(self, inputs):
return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12)

def compute_output_shape(self, input_shape):
# Ensure axis is always treated as a list
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {self.axis} is out of bounds for "
f"input shape {input_shape}."
)
return input_shape

def get_config(self):
Expand Down

0 comments on commit 7b4a78c

Please sign in to comment.