From 240b6718792e5c6222c86346b684080a304960da Mon Sep 17 00:00:00 2001 From: james77777778 <20734616+james77777778@users.noreply.github.com> Date: Wed, 8 May 2024 10:32:40 +0800 Subject: [PATCH] Fix dtype test for `trace` (#19686) * Fix `trace` dtype test * Fix formatting * Fix bug * Fix bug * Start using `self.assertDtype` * Fix CI * Fix CI * Fix CI --- keras/src/backend/jax/numpy.py | 6 +++++- keras/src/layers/convolutional/conv1d.py | 2 +- .../layers/convolutional/conv1d_transpose.py | 2 +- keras/src/layers/convolutional/conv2d.py | 2 +- .../layers/convolutional/conv2d_transpose.py | 2 +- keras/src/layers/convolutional/conv3d.py | 2 +- .../layers/convolutional/conv3d_transpose.py | 2 +- .../layers/convolutional/depthwise_conv1d.py | 2 +- .../layers/convolutional/depthwise_conv2d.py | 2 +- .../layers/convolutional/separable_conv1d.py | 2 +- .../layers/convolutional/separable_conv2d.py | 2 +- keras/src/layers/pooling/average_pooling1d.py | 2 +- keras/src/layers/pooling/average_pooling2d.py | 2 +- keras/src/layers/pooling/average_pooling3d.py | 2 +- keras/src/layers/pooling/max_pooling1d.py | 2 +- keras/src/layers/pooling/max_pooling2d.py | 2 +- keras/src/layers/pooling/max_pooling3d.py | 2 +- keras/src/ops/numpy_test.py | 17 +++++++++-------- 18 files changed, 30 insertions(+), 25 deletions(-) diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 60dec64420c..6d9002487e2 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -993,7 +993,11 @@ def tile(x, repeats): def trace(x, offset=0, axis1=0, axis2=1): x = convert_to_tensor(x) dtype = None - if standardize_dtype(x.dtype) == "bool": + # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27 + # for both CPU & GPU environments. + # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32 + # otherwise. + if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"): dtype = "int32" return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) diff --git a/keras/src/layers/convolutional/conv1d.py b/keras/src/layers/convolutional/conv1d.py index 60ec4036ea2..4c25e819515 100644 --- a/keras/src/layers/convolutional/conv1d.py +++ b/keras/src/layers/convolutional/conv1d.py @@ -70,7 +70,7 @@ class Conv1D(BaseConv): A 3D tensor with shape: `(batch_shape, channels, steps)` Output shape: - + - If `data_format="channels_last"`: A 3D tensor with shape: `(batch_shape, new_steps, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/convolutional/conv1d_transpose.py b/keras/src/layers/convolutional/conv1d_transpose.py index d3a8f9cb566..e14d04a878f 100644 --- a/keras/src/layers/convolutional/conv1d_transpose.py +++ b/keras/src/layers/convolutional/conv1d_transpose.py @@ -64,7 +64,7 @@ class Conv1DTranspose(BaseConvTranspose): A 3D tensor with shape: `(batch_shape, channels, steps)` Output shape: - + - If `data_format="channels_last"`: A 3D tensor with shape: `(batch_shape, new_steps, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/convolutional/conv2d.py b/keras/src/layers/convolutional/conv2d.py index 3ea660b5f94..24b47783b3c 100644 --- a/keras/src/layers/convolutional/conv2d.py +++ b/keras/src/layers/convolutional/conv2d.py @@ -66,7 +66,7 @@ class Conv2D(BaseConv): A 4D tensor with shape: `(batch_size, channels, height, width)` Output shape: - + - If `data_format="channels_last"`: A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/convolutional/conv2d_transpose.py b/keras/src/layers/convolutional/conv2d_transpose.py index 33c779be8c0..633d57ff166 100644 --- a/keras/src/layers/convolutional/conv2d_transpose.py +++ b/keras/src/layers/convolutional/conv2d_transpose.py @@ -66,7 +66,7 @@ class Conv2DTranspose(BaseConvTranspose): A 4D tensor with shape: `(batch_size, channels, height, width)` Output shape: - + - If `data_format="channels_last"`: A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/convolutional/conv3d.py b/keras/src/layers/convolutional/conv3d.py index b2cb14989ae..a4cb1c0b8f5 100644 --- a/keras/src/layers/convolutional/conv3d.py +++ b/keras/src/layers/convolutional/conv3d.py @@ -68,7 +68,7 @@ class Conv3D(BaseConv): `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` Output shape: - + - If `data_format="channels_last"`: 5D tensor with shape: `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3, diff --git a/keras/src/layers/convolutional/conv3d_transpose.py b/keras/src/layers/convolutional/conv3d_transpose.py index 06b223e48fc..953f0d27837 100644 --- a/keras/src/layers/convolutional/conv3d_transpose.py +++ b/keras/src/layers/convolutional/conv3d_transpose.py @@ -68,7 +68,7 @@ class Conv3DTranspose(BaseConvTranspose): `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` Output shape: - + - If `data_format="channels_last"`: 5D tensor with shape: `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3, diff --git a/keras/src/layers/convolutional/depthwise_conv1d.py b/keras/src/layers/convolutional/depthwise_conv1d.py index d41632dba92..d787fcd0e30 100644 --- a/keras/src/layers/convolutional/depthwise_conv1d.py +++ b/keras/src/layers/convolutional/depthwise_conv1d.py @@ -74,7 +74,7 @@ class DepthwiseConv1D(BaseDepthwiseConv): A 3D tensor with shape: `(batch_shape, channels, steps)` Output shape: - + - If `data_format="channels_last"`: A 3D tensor with shape: `(batch_shape, new_steps, channels * depth_multiplier)` diff --git a/keras/src/layers/convolutional/depthwise_conv2d.py b/keras/src/layers/convolutional/depthwise_conv2d.py index 69087b7c278..c3da7aa889b 100644 --- a/keras/src/layers/convolutional/depthwise_conv2d.py +++ b/keras/src/layers/convolutional/depthwise_conv2d.py @@ -75,7 +75,7 @@ class DepthwiseConv2D(BaseDepthwiseConv): A 4D tensor with shape: `(batch_size, channels, height, width)` Output shape: - + - If `data_format="channels_last"`: A 4D tensor with shape: `(batch_size, new_height, new_width, channels * depth_multiplier)` diff --git a/keras/src/layers/convolutional/separable_conv1d.py b/keras/src/layers/convolutional/separable_conv1d.py index d606aeed905..2f03161981d 100644 --- a/keras/src/layers/convolutional/separable_conv1d.py +++ b/keras/src/layers/convolutional/separable_conv1d.py @@ -77,7 +77,7 @@ class SeparableConv1D(BaseSeparableConv): A 3D tensor with shape: `(batch_shape, channels, steps)` Output shape: - + - If `data_format="channels_last"`: A 3D tensor with shape: `(batch_shape, new_steps, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/convolutional/separable_conv2d.py b/keras/src/layers/convolutional/separable_conv2d.py index e2763319991..27c1548231d 100644 --- a/keras/src/layers/convolutional/separable_conv2d.py +++ b/keras/src/layers/convolutional/separable_conv2d.py @@ -78,7 +78,7 @@ class SeparableConv2D(BaseSeparableConv): A 4D tensor with shape: `(batch_size, channels, height, width)` Output shape: - + - If `data_format="channels_last"`: A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` - If `data_format="channels_first"`: diff --git a/keras/src/layers/pooling/average_pooling1d.py b/keras/src/layers/pooling/average_pooling1d.py index c0a94cfb31f..a52a031e17f 100644 --- a/keras/src/layers/pooling/average_pooling1d.py +++ b/keras/src/layers/pooling/average_pooling1d.py @@ -38,7 +38,7 @@ class AveragePooling1D(BasePooling): 3D tensor with shape `(batch_size, features, steps)`. Output shape: - + - If `data_format="channels_last"`: 3D tensor with shape `(batch_size, downsampled_steps, features)`. - If `data_format="channels_first"`: diff --git a/keras/src/layers/pooling/average_pooling2d.py b/keras/src/layers/pooling/average_pooling2d.py index d0d548893db..ed56f32c0ad 100644 --- a/keras/src/layers/pooling/average_pooling2d.py +++ b/keras/src/layers/pooling/average_pooling2d.py @@ -47,7 +47,7 @@ class AveragePooling2D(BasePooling): 4D tensor with shape `(batch_size, channels, height, width)`. Output shape: - + - If `data_format="channels_last"`: 4D tensor with shape `(batch_size, pooled_height, pooled_width, channels)`. diff --git a/keras/src/layers/pooling/average_pooling3d.py b/keras/src/layers/pooling/average_pooling3d.py index 10fdff86e71..96541e2cd8a 100644 --- a/keras/src/layers/pooling/average_pooling3d.py +++ b/keras/src/layers/pooling/average_pooling3d.py @@ -42,7 +42,7 @@ class AveragePooling3D(BasePooling): `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` Output shape: - + - If `data_format="channels_last"`: 5D tensor with shape: `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` diff --git a/keras/src/layers/pooling/max_pooling1d.py b/keras/src/layers/pooling/max_pooling1d.py index 40f5561ce94..7485393b553 100644 --- a/keras/src/layers/pooling/max_pooling1d.py +++ b/keras/src/layers/pooling/max_pooling1d.py @@ -39,7 +39,7 @@ class MaxPooling1D(BasePooling): 3D tensor with shape `(batch_size, features, steps)`. Output shape: - + - If `data_format="channels_last"`: 3D tensor with shape `(batch_size, downsampled_steps, features)`. - If `data_format="channels_first"`: diff --git a/keras/src/layers/pooling/max_pooling2d.py b/keras/src/layers/pooling/max_pooling2d.py index 8efdc339ead..9d2ffdc437d 100644 --- a/keras/src/layers/pooling/max_pooling2d.py +++ b/keras/src/layers/pooling/max_pooling2d.py @@ -47,7 +47,7 @@ class MaxPooling2D(BasePooling): 4D tensor with shape `(batch_size, channels, height, width)`. Output shape: - + - If `data_format="channels_last"`: 4D tensor with shape `(batch_size, pooled_height, pooled_width, channels)`. diff --git a/keras/src/layers/pooling/max_pooling3d.py b/keras/src/layers/pooling/max_pooling3d.py index fa6d558ca29..43be140c5aa 100644 --- a/keras/src/layers/pooling/max_pooling3d.py +++ b/keras/src/layers/pooling/max_pooling3d.py @@ -42,7 +42,7 @@ class MaxPooling3D(BasePooling): `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` Output shape: - + - If `data_format="channels_last"`: 5D tensor with shape: `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 106d27c8be1..1b7f776ac07 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -7702,7 +7702,7 @@ def test_trace(self, dtype): import jax.experimental import jax.numpy as jnp - # We have to disable x64 for jax since jnp.true_divide doesn't respect + # We have to disable x64 for jax since jnp.trace doesn't respect # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast # the expected dtype from 64 bit to 32 bit when using jax backend. with jax.experimental.disable_x64(): @@ -7717,16 +7717,17 @@ def test_trace(self, dtype): expected_dtype = "float64" elif dtype == "int64": expected_dtype = "int64" + # TODO: Remove the condition of uint8 and uint16 once we have + # jax>=0.4.27 for both CPU & GPU environments. + # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to + # int32 otherwise. + elif dtype in ("uint8", "uint16"): + expected_dtype = "int32" if backend.backend() == "jax": expected_dtype = expected_dtype.replace("64", "32") - self.assertEqual( - standardize_dtype(knp.trace(x).dtype), expected_dtype - ) - self.assertEqual( - standardize_dtype(knp.Trace().symbolic_call(x).dtype), - expected_dtype, - ) + self.assertDType(knp.trace(x), expected_dtype) + self.assertDType(knp.Trace().symbolic_call(x), expected_dtype) @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_transpose(self, dtype):