From d30d9020573d6a881458daab668ceab00bf2cce4 Mon Sep 17 00:00:00 2001
From: Aditya Kane <64411306+AdityaKane2001@users.noreply.github.com>
Date: Fri, 8 Jul 2022 01:21:45 +0530
Subject: [PATCH] Squeeze and Excite block (#505)

* Added SE block

* Formatted

* Made requested changes

* Final touches to SqueezeAndExcite

* Final touches to SqueezeAndExcite

* Made requested changes

* Formatted

* Added activation arguments

* Made requested changes

* Serialization tests workaround

* Made requested changes
---
 keras_cv/layers/__init__.py                  |   1 +
 keras_cv/layers/regularization/__init__.py   |   1 +
 .../layers/regularization/squeeze_excite.py  | 107 ++++++++++++++++++
 .../regularization/squeeze_excite_test.py    |  55 +++++++++
 keras_cv/layers/serialization_test.py        |  16 +++
 5 files changed, 180 insertions(+)
 create mode 100644 keras_cv/layers/regularization/squeeze_excite.py
 create mode 100644 keras_cv/layers/regularization/squeeze_excite_test.py

diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py
index 62ae6d0762..073c115a63 100644
--- a/keras_cv/layers/__init__.py
+++ b/keras_cv/layers/__init__.py
@@ -59,4 +59,5 @@
 from keras_cv.layers.preprocessing.solarization import Solarization
 from keras_cv.layers.regularization.drop_path import DropPath
 from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
+from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
 from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
diff --git a/keras_cv/layers/regularization/__init__.py b/keras_cv/layers/regularization/__init__.py
index ae95746262..8e002c4b17 100644
--- a/keras_cv/layers/regularization/__init__.py
+++ b/keras_cv/layers/regularization/__init__.py
@@ -14,4 +14,5 @@
 
 from keras_cv.layers.regularization.drop_path import DropPath
 from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
+from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
 from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
diff --git a/keras_cv/layers/regularization/squeeze_excite.py b/keras_cv/layers/regularization/squeeze_excite.py
new file mode 100644
index 0000000000..5939848a98
--- /dev/null
+++ b/keras_cv/layers/regularization/squeeze_excite.py
@@ -0,0 +1,107 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+@tf.keras.utils.register_keras_serializable(package="keras_cv")
+class SqueezeAndExcite2D(layers.Layer):
+    """
+    Implements the Squeeze and Excite block as in
+    [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf).
+    This layer uses a content-aware mechanism to assign channel-wise weights
+    adaptively. It first squeezes each feature map to a single value using
+    global average pooling; the pooled values are then fed into two 1x1
+    Conv2D layers, which act like fully-connected layers.
+    The first layer reduces the number of channels to `filters * ratio`,
+    whereas the second layer restores it to the original value.
+
+    The resultant values are the adaptive weights for each channel. These
+    weights are then multiplied with the original inputs to scale each
+    channel according to its learned importance.
+
+
+    Args:
+        filters: Number of input and output filters. The number of input and
+            output filters must be equal.
+        ratio: Ratio for bottleneck filters. Number of bottleneck filters =
+            filters * ratio. Defaults to 0.25.
+        squeeze_activation: (Optional) String, callable (or tf.keras.layers.Layer) or
+            tf.keras.activations.Activation instance denoting the activation
+            to be applied after the squeeze convolution. Defaults to `relu`.
+        excite_activation: (Optional) String, callable (or tf.keras.layers.Layer) or
+            tf.keras.activations.Activation instance denoting the activation
+            to be applied after the excite convolution. Defaults to `sigmoid`.
+    Usage:
+
+    ```python
+    # (...)
+    input = tf.ones((1, 5, 5, 16), dtype=tf.float32)
+    x = tf.keras.layers.Conv2D(16, (3, 3))(input)
+    output = keras_cv.layers.SqueezeAndExcite2D(16)(x)
+    # (...)
+    ```
+    """
+
+    def __init__(
+        self,
+        filters,
+        ratio=0.25,
+        squeeze_activation="relu",
+        excite_activation="sigmoid",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.filters = filters
+
+        if ratio <= 0.0 or ratio >= 1.0:
+            raise ValueError(f"`ratio` should be a float between 0 and 1. Got {ratio}")
+
+        if filters <= 0 or not isinstance(filters, int):
+            raise ValueError(f"`filters` should be a positive integer. Got {filters}")
+
+        self.ratio = ratio
+        self.bottleneck_filters = int(self.filters * self.ratio)
+
+        self.squeeze_activation = squeeze_activation
+        self.excite_activation = excite_activation
+
+        self.global_average_pool = layers.GlobalAveragePooling2D(keepdims=True)
+        self.squeeze_conv = layers.Conv2D(
+            self.bottleneck_filters,
+            (1, 1),
+            activation=self.squeeze_activation,
+        )
+        self.excite_conv = layers.Conv2D(
+            self.filters, (1, 1), activation=self.excite_activation
+        )
+
+    def call(self, inputs, training=True):
+        x = self.global_average_pool(inputs)  # x: (batch_size, 1, 1, filters)
+        x = self.squeeze_conv(x)  # x: (batch_size, 1, 1, bottleneck_filters)
+        x = self.excite_conv(x)  # x: (batch_size, 1, 1, filters)
+        x = tf.math.multiply(x, inputs)  # x: (batch_size, h, w, filters)
+        return x
+
+    def get_config(self):
+        config = {
+            "filters": self.filters,
+            "ratio": self.ratio,
+            "squeeze_activation": self.squeeze_activation,
+            "excite_activation": self.excite_activation,
+        }
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
diff --git a/keras_cv/layers/regularization/squeeze_excite_test.py b/keras_cv/layers/regularization/squeeze_excite_test.py
new file mode 100644
index 0000000000..21ccc5ddff
--- /dev/null
+++ b/keras_cv/layers/regularization/squeeze_excite_test.py
@@ -0,0 +1,55 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from keras_cv.layers import SqueezeAndExcite2D
+
+
+class SqueezeAndExcite2DTest(tf.test.TestCase):
+    def test_maintains_shape(self):
+        input_shape = (1, 4, 4, 8)
+        inputs = tf.random.uniform(input_shape)
+
+        layer = SqueezeAndExcite2D(8, ratio=0.25)
+        outputs = layer(inputs)
+        self.assertEqual(inputs.shape, outputs.shape)
+
+    def test_custom_activation(self):
+        def custom_activation(x):
+            return x * tf.random.uniform(x.shape, seed=42)
+
+        input_shape = (1, 4, 4, 8)
+        inputs = tf.random.uniform(input_shape)
+
+        layer = SqueezeAndExcite2D(
+            8,
+            ratio=0.25,
+            squeeze_activation=custom_activation,
+            excite_activation=custom_activation,
+        )
+        outputs = layer(inputs)
+        self.assertEqual(inputs.shape, outputs.shape)
+
+    def test_raises_invalid_ratio_error(self):
+        with self.assertRaisesRegex(
+            ValueError, "`ratio` should be a float" " between 0 and 1. Got (.*?)"
+        ):
+            _ = SqueezeAndExcite2D(8, ratio=1.1)
+
+    def test_raises_invalid_filters_error(self):
+        with self.assertRaisesRegex(
+            ValueError, "`filters` should be a positive" " integer. Got (.*?)"
+        ):
+            _ = SqueezeAndExcite2D(-8.7)
diff --git a/keras_cv/layers/serialization_test.py b/keras_cv/layers/serialization_test.py
index 4da1afb424..e6c195a994 100644
--- a/keras_cv/layers/serialization_test.py
+++ b/keras_cv/layers/serialization_test.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+
 import tensorflow as tf
 from absl.testing import parameterized
@@ -22,6 +24,10 @@
 def custom_compare(obj1, obj2):
     if isinstance(obj1, (core.FactorSampler, tf.keras.layers.Layer)):
         return config_equals(obj1.get_config(), obj2.get_config())
+    elif inspect.isfunction(obj1):
+        return tf.keras.utils.serialize_keras_object(obj1) == obj2
+    elif inspect.isfunction(obj2):
+        return obj1 == tf.keras.utils.serialize_keras_object(obj2)
     else:
         return obj1 == obj2
 
@@ -138,6 +144,16 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
             regularization.StochasticDepth,
             {"rate": 0.1},
         ),
+        (
+            "SqueezeAndExcite2D",
+            regularization.SqueezeAndExcite2D,
+            {
+                "filters": 16,
+                "ratio": 0.25,
+                "squeeze_activation": tf.keras.layers.ReLU(),
+                "excite_activation": tf.keras.activations.relu,
+            },
+        ),
         (
             "DropPath",
             regularization.DropPath,
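For readers skimming the patch, here is a minimal sketch of how the new layer could be used once this change lands. The surrounding model, shapes, and layer sizes are illustrative assumptions, not part of the patch:

```python
import tensorflow as tf
import keras_cv

# Toy functional model; the architecture here is made up for illustration.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(16, (3, 3), padding="same", activation="relu")(inputs)
# Recalibrate the 16 feature maps channel-wise; ratio=0.25 gives a 4-filter bottleneck.
x = keras_cv.layers.SqueezeAndExcite2D(filters=16, ratio=0.25)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# The layer is registered as serializable, so it survives a config round trip.
config = keras_cv.layers.SqueezeAndExcite2D(filters=16).get_config()
restored = keras_cv.layers.SqueezeAndExcite2D.from_config(config)
```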