Skip to content

Commit

Permalink
Fix dtype test for trace (#19686)
Browse files Browse the repository at this point in the history
* Fix `trace` dtype test

* Fix formatting

* Fix bug

* Fix bug

* Start using `self.assertDtype`

* Fix CI

* Fix CI

* Fix CI
  • Loading branch information
james77777778 authored May 8, 2024
1 parent 603db80 commit 240b671
Show file tree
Hide file tree
Showing 18 changed files with 30 additions and 25 deletions.
6 changes: 5 additions & 1 deletion keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,11 @@ def tile(x, repeats):
def trace(x, offset=0, axis1=0, axis2=1):
x = convert_to_tensor(x)
dtype = None
if standardize_dtype(x.dtype) == "bool":
# TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
# for both CPU & GPU environments.
# uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
# otherwise.
if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
dtype = "int32"
return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)

Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Conv1D(BaseConv):
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv1d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Conv1DTranspose(BaseConvTranspose):
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Conv2D(BaseConv):
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Conv2DTranspose(BaseConvTranspose):
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Conv3D(BaseConv):
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv3d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Conv3DTranspose(BaseConvTranspose):
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/depthwise_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class DepthwiseConv1D(BaseDepthwiseConv):
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape:
`(batch_shape, new_steps, channels * depth_multiplier)`
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class DepthwiseConv2D(BaseDepthwiseConv):
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape:
`(batch_size, new_height, new_width, channels * depth_multiplier)`
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/separable_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class SeparableConv1D(BaseSeparableConv):
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/separable_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class SeparableConv2D(BaseSeparableConv):
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/average_pooling1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class AveragePooling1D(BasePooling):
3D tensor with shape `(batch_size, features, steps)`.
Output shape:
- If `data_format="channels_last"`:
3D tensor with shape `(batch_size, downsampled_steps, features)`.
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/average_pooling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class AveragePooling2D(BasePooling):
4D tensor with shape `(batch_size, channels, height, width)`.
Output shape:
- If `data_format="channels_last"`:
4D tensor with shape
`(batch_size, pooled_height, pooled_width, channels)`.
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/average_pooling3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AveragePooling3D(BasePooling):
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/max_pooling1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class MaxPooling1D(BasePooling):
3D tensor with shape `(batch_size, features, steps)`.
Output shape:
- If `data_format="channels_last"`:
3D tensor with shape `(batch_size, downsampled_steps, features)`.
- If `data_format="channels_first"`:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/max_pooling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MaxPooling2D(BasePooling):
4D tensor with shape `(batch_size, channels, height, width)`.
Output shape:
- If `data_format="channels_last"`:
4D tensor with shape
`(batch_size, pooled_height, pooled_width, channels)`.
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/pooling/max_pooling3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MaxPooling3D(BasePooling):
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
Expand Down
17 changes: 9 additions & 8 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7702,7 +7702,7 @@ def test_trace(self, dtype):
import jax.experimental
import jax.numpy as jnp

# We have to disable x64 for jax since jnp.true_divide doesn't respect
# We have to disable x64 for jax since jnp.trace doesn't respect
# JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast
# the expected dtype from 64 bit to 32 bit when using jax backend.
with jax.experimental.disable_x64():
Expand All @@ -7717,16 +7717,17 @@ def test_trace(self, dtype):
expected_dtype = "float64"
elif dtype == "int64":
expected_dtype = "int64"
# TODO: Remove the condition of uint8 and uint16 once we have
# jax>=0.4.27 for both CPU & GPU environments.
# uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to
# int32 otherwise.
elif dtype in ("uint8", "uint16"):
expected_dtype = "int32"
if backend.backend() == "jax":
expected_dtype = expected_dtype.replace("64", "32")

self.assertEqual(
standardize_dtype(knp.trace(x).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Trace().symbolic_call(x).dtype),
expected_dtype,
)
self.assertDType(knp.trace(x), expected_dtype)
self.assertDType(knp.Trace().symbolic_call(x), expected_dtype)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_transpose(self, dtype):
Expand Down

0 comments on commit 240b671

Please sign in to comment.