Commit
Squeeze and Excite block (#505)
* Added SE block

* Formatted

* Made requested changes

* Final touches to SqueezeAndExcite

* Final touches to SqueezeAndExcite

* Made requested changes

* Formatted

* Added activation arguments

* Made requested changes

* Serialization tests workaround

* Made requested changes
AdityaKane2001 authored Jul 7, 2022
1 parent 7cd255e commit d30d902
Showing 5 changed files with 180 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
@@ -59,4 +59,5 @@
from keras_cv.layers.preprocessing.solarization import Solarization
from keras_cv.layers.regularization.drop_path import DropPath
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
1 change: 1 addition & 0 deletions keras_cv/layers/regularization/__init__.py
@@ -14,4 +14,5 @@

from keras_cv.layers.regularization.drop_path import DropPath
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
107 changes: 107 additions & 0 deletions keras_cv/layers/regularization/squeeze_excite.py
@@ -0,0 +1,107 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow.keras import layers


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class SqueezeAndExcite2D(layers.Layer):
"""
Implements Squeeze and Excite block as in
[Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf).
This layer tries to use a content aware mechanism to assign channel-wise
weights adaptively. It first squeezes the feature maps into a single value
using global average pooling, which are then fed into two Conv1D layers,
which act like fully-connected layers. The first layer reduces the
dimensionality of the feature maps by a factor of `ratio`, whereas the second
layer restores it to its original value.
The resultant values are the adaptive weights for each channel. These
weights are then multiplied with the original inputs to scale the outputs
based on their individual weightages.
Args:
filters: Number of input and output filters. The number of input and
output filters is same.
ratio: Ratio for bottleneck filters. Number of bottleneck filters =
filters * ratio. Defaults to 0.25.
squeeze_activation: (Optional) String, callable (or tf.keras.layers.Layer) or
tf.keras.activations.Activation instance denoting activation to
be applied after squeeze convolution. Defaults to `relu`.
excite_activation: (Optional) String, callable (or tf.keras.layers.Layer) or
tf.keras.activations.Activation instance denoting activation to
be applied after excite convolution. Defaults to `sigmoid`.
Usage:
```python
# (...)
input = tf.ones((1, 5, 5, 16), dtype=tf.float32)
x = tf.keras.layers.Conv2D(16, (3, 3))(input)
output = keras_cv.layers.SqueezeAndExciteBlock(16)(x)
# (...)
```
"""

    def __init__(
        self,
        filters,
        ratio=0.25,
        squeeze_activation="relu",
        excite_activation="sigmoid",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.filters = filters

        if ratio <= 0.0 or ratio >= 1.0:
            raise ValueError(f"`ratio` should be a float between 0 and 1. Got {ratio}")

        if filters <= 0 or not isinstance(filters, int):
            raise ValueError(f"`filters` should be a positive integer. Got {filters}")

        self.ratio = ratio
        self.bottleneck_filters = int(self.filters * self.ratio)

        self.squeeze_activation = squeeze_activation
        self.excite_activation = excite_activation

        self.global_average_pool = layers.GlobalAveragePooling2D(keepdims=True)
        self.squeeze_conv = layers.Conv2D(
            self.bottleneck_filters,
            (1, 1),
            activation=self.squeeze_activation,
        )
        self.excite_conv = layers.Conv2D(
            self.filters, (1, 1), activation=self.excite_activation
        )

    def call(self, inputs, training=True):
        x = self.global_average_pool(inputs)  # x: (batch_size, 1, 1, filters)
        x = self.squeeze_conv(x)  # x: (batch_size, 1, 1, bottleneck_filters)
        x = self.excite_conv(x)  # x: (batch_size, 1, 1, filters)
        x = tf.math.multiply(x, inputs)  # x: (batch_size, h, w, filters)
        return x

    def get_config(self):
        config = {
            "filters": self.filters,
            "ratio": self.ratio,
            "squeeze_activation": self.squeeze_activation,
            "excite_activation": self.excite_activation,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
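
As a quick reference, here is a minimal standalone sketch of how the new layer is used (assuming a KerasCV build that includes this commit); the input's channel count must match `filters`:

```python
import tensorflow as tf

from keras_cv.layers import SqueezeAndExcite2D

# Toy feature map: batch of 1, 8x8 spatial, 32 channels.
inputs = tf.random.uniform((1, 8, 8, 32))

# `filters` must equal the channel count of the incoming tensor;
# ratio=0.25 gives int(32 * 0.25) = 8 bottleneck filters.
se = SqueezeAndExcite2D(filters=32, ratio=0.25)
outputs = se(inputs)

print(outputs.shape)  # (1, 8, 8, 32) -- spatial dims and channel count are preserved
```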
55 changes: 55 additions & 0 deletions keras_cv/layers/regularization/squeeze_excite_test.py
@@ -0,0 +1,55 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras_cv.layers import SqueezeAndExcite2D


class SqueezeAndExcite2DTest(tf.test.TestCase):
    def test_maintains_shape(self):
        input_shape = (1, 4, 4, 8)
        inputs = tf.random.uniform(input_shape)

        layer = SqueezeAndExcite2D(8, ratio=0.25)
        outputs = layer(inputs)
        self.assertEqual(inputs.shape, outputs.shape)

    def test_custom_activation(self):
        def custom_activation(x):
            return x * tf.random.uniform(x.shape, seed=42)

        input_shape = (1, 4, 4, 8)
        inputs = tf.random.uniform(input_shape)

        layer = SqueezeAndExcite2D(
            8,
            ratio=0.25,
            squeeze_activation=custom_activation,
            excite_activation=custom_activation,
        )
        outputs = layer(inputs)
        self.assertEqual(inputs.shape, outputs.shape)

    def test_raises_invalid_ratio_error(self):
        with self.assertRaisesRegex(
            ValueError, "`ratio` should be a float between 0 and 1. Got (.*?)"
        ):
            _ = SqueezeAndExcite2D(8, ratio=1.1)

    def test_raises_invalid_filters_error(self):
        with self.assertRaisesRegex(
            ValueError, "`filters` should be a positive integer. Got (.*?)"
        ):
            _ = SqueezeAndExcite2D(-8.7)
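
Since the layer implements `get_config`, it can also be rebuilt from its config; a small sketch of that round trip (illustrative only, not part of the commit):

```python
import tensorflow as tf

from keras_cv.layers import SqueezeAndExcite2D

original = SqueezeAndExcite2D(16, ratio=0.25, squeeze_activation="relu")
config = original.get_config()

# Rebuilding from the config should produce an equivalent layer.
restored = SqueezeAndExcite2D.from_config(config)
assert restored.get_config() == original.get_config()
```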
16 changes: 16 additions & 0 deletions keras_cv/layers/serialization_test.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import tensorflow as tf
from absl.testing import parameterized

@@ -22,6 +24,10 @@
def custom_compare(obj1, obj2):
    if isinstance(obj1, (core.FactorSampler, tf.keras.layers.Layer)):
        return config_equals(obj1.get_config(), obj2.get_config())
    elif inspect.isfunction(obj1):
        return tf.keras.utils.serialize_keras_object(obj1) == obj2
    elif inspect.isfunction(obj2):
        return obj1 == tf.keras.utils.serialize_keras_object(obj2)
    else:
        return obj1 == obj2

@@ -138,6 +144,16 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
            regularization.StochasticDepth,
            {"rate": 0.1},
        ),
        (
            "SqueezeAndExcite2D",
            regularization.SqueezeAndExcite2D,
            {
                "filters": 16,
                "ratio": 0.25,
                "squeeze_activation": tf.keras.layers.ReLU(),
                "excite_activation": tf.keras.activations.relu,
            },
        ),
        (
            "DropPath",
            regularization.DropPath,
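
The `custom_compare` extension above exists because a function-valued activation (for example `tf.keras.activations.relu`) comes back from a config round trip in serialized form rather than as the original function object. A small illustrative sketch, not part of the commit; the exact serialized form depends on the TF/Keras version:

```python
import tensorflow as tf

# A plain Python function serializes to its registered name (or a dict in
# newer Keras versions), so comparing a function against a reloaded config
# entry requires serializing the function first.
fn = tf.keras.activations.relu
print(tf.keras.utils.serialize_keras_object(fn))  # e.g. "relu"
```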