diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 1d44e9b07..cd69c569e 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -85,6 +85,12 @@ from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( DenseNetImageClassifier, diff --git a/keras_hub/src/models/deeplab_v3/__init__.py b/keras_hub/src/models/deeplab_v3/__init__.py new file mode 100644 index 000000000..0a959e186 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/__init__.py @@ -0,0 +1,7 @@ +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DeepLabV3Backbone) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py new file mode 100644 index 000000000..70bf828b0 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -0,0 +1,196 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( + SpatialPyramidPooling, +) + + +@keras_hub_export("keras_hub.models.DeepLabV3Backbone") +class DeepLabV3Backbone(Backbone): + """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation. 
+ + This class implements a DeepLabV3 & DeepLabV3Plus architecture as described + in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation]( + https://arxiv.org/abs/1802.02611)(ECCV 2018) + and [Rethinking Atrous Convolution for Semantic Image Segmentation]( + https://arxiv.org/abs/1706.05587)(CVPR 2017) + + Args: + image_encoder: `keras.Model`. An instance that is used as a feature + extractor for the Encoder. Should either be a + `keras_hub.models.Backbone` or a `keras.Model` that implements the + `pyramid_outputs` property, keyed by level names such as "P2", "P3", etc. + A somewhat sensible backbone to use in many cases is + the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`. + projection_filters: int. Number of filters in the convolution layer + projecting low-level features from the `image_encoder`. + spatial_pyramid_pooling_key: str. A layer level to extract and perform + `spatial_pyramid_pooling` on, one of the keys of the `image_encoder`'s + `pyramid_outputs` property such as "P4", "P5", etc. + upsampling_size: int or tuple of 2 integers. The upsampling factors for + rows and columns of the `spatial_pyramid_pooling` layer output. + If `low_level_feature_key` is given, the `spatial_pyramid_pooling` + layer's resolution should match the `low_level_feature` layer's + resolution so that both layers can be concatenated into the combined + encoder outputs. + dilation_rates: list. A `list` of integers for parallel dilated conv applied to + `SpatialPyramidPooling`. A common + choice of rates is `[6, 12, 18]`. + low_level_feature_key: str, optional. A layer level to extract the feature + from, one of the keys of the `image_encoder`'s `pyramid_outputs` + property such as "P2", "P3", etc., which will feed the Decoder block. + Required only when the DeepLabV3Plus architecture needs to be applied. + image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. 
+ + Example: + ```python + # Load a trained backbone to extract features from it's `pyramid_outputs`. + image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet") + + model = keras_hub.models.DeepLabV3Backbone( + image_encoder=image_encoder, + projection_filters=48, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P5", + upsampling_size = 8, + dilation_rates = [6, 12, 18] + ) + ``` + """ + + def __init__( + self, + image_encoder, + spatial_pyramid_pooling_key, + upsampling_size, + dilation_rates, + low_level_feature_key=None, + projection_filters=48, + image_shape=(None, None, 3), + **kwargs, + ): + if not isinstance(image_encoder, keras.Model): + raise ValueError( + "Argument `image_encoder` must be a `keras.Model` instance. Received instead " + f"{image_encoder} (of type {type(image_encoder)})." + ) + data_format = keras.config.image_data_format() + channel_axis = -1 if data_format == "channels_last" else 1 + + # === Layers === + inputs = keras.layers.Input(image_shape, name="inputs") + + fpn_model = keras.Model( + image_encoder.inputs, image_encoder.pyramid_outputs + ) + + fpn_outputs = fpn_model(inputs) + + spatial_pyramid_pooling = SpatialPyramidPooling( + dilation_rates=dilation_rates + ) + spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key] + spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) + + encoder_outputs = keras.layers.UpSampling2D( + size=upsampling_size, + interpolation="bilinear", + name="encoder_output_upsampling", + data_format=data_format, + )(spp_outputs) + + if low_level_feature_key: + decoder_feature = fpn_outputs[low_level_feature_key] + low_level_projected_features = apply_low_level_feature_network( + decoder_feature, projection_filters, channel_axis + ) + + encoder_outputs = keras.layers.Concatenate( + axis=channel_axis, name="encoder_decoder_concat" + )([encoder_outputs, low_level_projected_features]) + # upsampling to the original image size + upsampling = (2 ** 
int(spatial_pyramid_pooling_key[-1])) // ( + int(upsampling_size[0]) + if isinstance(upsampling_size, tuple) + else upsampling_size + ) + # === Functional Model === + x = keras.layers.Conv2D( + name="segmentation_head_conv", + filters=256, + kernel_size=1, + padding="same", + use_bias=False, + data_format=data_format, + )(encoder_outputs) + x = keras.layers.BatchNormalization( + name="segmentation_head_norm", axis=channel_axis + )(x) + x = keras.layers.ReLU(name="segmentation_head_relu")(x) + x = keras.layers.UpSampling2D( + size=upsampling, + interpolation="bilinear", + data_format=data_format, + name="backbone_output_upsampling", + )(x) + + super().__init__(inputs=inputs, outputs=x, **kwargs) + + # === Config === + self.image_shape = image_shape + self.image_encoder = image_encoder + self.projection_filters = projection_filters + self.upsampling_size = upsampling_size + self.dilation_rates = dilation_rates + self.low_level_feature_key = low_level_feature_key + self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + "projection_filters": self.projection_filters, + "dilation_rates": self.dilation_rates, + "upsampling_size": self.upsampling_size, + "low_level_feature_key": self.low_level_feature_key, + "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, + "image_shape": self.image_shape, + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) + + +def apply_low_level_feature_network( + input_tensor, projection_filters, channel_axis +): + data_format = keras.config.image_data_format() + x = keras.layers.Conv2D( + name="decoder_conv", + filters=projection_filters, + 
kernel_size=1, + padding="same", + use_bias=False, + data_format=data_format, + )(input_tensor) + + x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)( + x + ) + x = keras.layers.ReLU(name="decoder_relu")(x) + return x diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py new file mode 100644 index 000000000..a7b180908 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -0,0 +1,73 @@ +import keras +import numpy as np +import pytest + +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( + SpatialPyramidPooling, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DeepLabV3Test(TestCase): + def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) + self.init_kwargs = { + "image_encoder": self.image_encoder, + "low_level_feature_key": "P2", + "spatial_pyramid_pooling_key": "P4", + "dilation_rates": [6, 12, 18], + "upsampling_size": 4, + "image_shape": (96, 96, 3), + } + self.input_data = np.ones((2, 96, 96, 3), dtype="float32") + + def test_segmentation_basics(self): + self.run_vision_backbone_test( + cls=DeepLabV3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 96, 96, 256), + run_mixed_precision_check=False, + run_quantization_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3Backbone, + init_kwargs=self.init_kwargs, 
+ input_data=self.input_data, + ) + + +class SpatialPyramidPoolingTest(TestCase): + def test_layer_behaviors(self): + self.run_layer_test( + cls=SpatialPyramidPooling, + init_kwargs={ + "dilation_rates": [6, 12, 18], + "activation": "relu", + "num_channels": 256, + "dropout": 0.1, + }, + input_data=keras.random.uniform(shape=(1, 4, 4, 6)), + expected_output_shape=(1, 4, 4, 256), + expected_num_trainable_weights=18, + expected_num_non_trainable_variables=13, + expected_num_non_trainable_weights=12, + run_precision_checks=False, + ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py new file mode 100644 index 000000000..837e508d2 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -0,0 +1,215 @@ +import keras +from keras import ops + + +class SpatialPyramidPooling(keras.layers.Layer): + """Implements the Atrous Spatial Pyramid Pooling. + + Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution + for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and + [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation](https://arxiv.org/pdf/1802.02611.pdf) + + Args: + dilation_rates: list of ints. The dilation rate for parallel dilated conv. + Usually a sample choice of rates are `[6, 12, 18]`. + num_channels: int. The number of output channels, defaults to `256`. + activation: str. Activation to be used, defaults to `relu`. + dropout: float. The dropout rate of the final projection output after the + activations and batch norm, defaults to `0.0`, which means no dropout is + applied to the output. 
+ + Example: + ```python + inp = keras.layers.Input((384, 384, 3)) + backbone = keras.applications.EfficientNetB0( + input_tensor=inp, + include_top=False) + output = backbone(inp) + output = SpatialPyramidPooling( + dilation_rates=[6, 12, 18])(output) + ``` + """ + + def __init__( + self, + dilation_rates, + num_channels=256, + activation="relu", + dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.dilation_rates = dilation_rates + self.num_channels = num_channels + self.activation = activation + self.dropout = dropout + self.data_format = keras.config.image_data_format() + self.channel_axis = -1 if self.data_format == "channels_last" else 1 + + def build(self, input_shape): + channels = input_shape[self.channel_axis] + + # This is the parallel networks that process the input features with + # different dilation rates. The output from each channel will be merged + # together and feed to the output. + self.aspp_parallel_channels = [] + + # Channel1 with Conv2D and 1x1 kernel size. + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + name="aspp_conv_1", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="aspp_bn_1" + ), + keras.layers.Activation( + self.activation, name="aspp_activation_1" + ), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Channel 2 and afterwards are based on self.dilation_rates, and each of + # them will have conv2D with 3x3 kernel size. 
+ for i, dilation_rate in enumerate(self.dilation_rates): + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(3, 3), + padding="same", + dilation_rate=dilation_rate, + use_bias=False, + data_format=self.data_format, + name=f"aspp_conv_{i+2}", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name=f"aspp_bn_{i+2}" + ), + keras.layers.Activation( + self.activation, name=f"aspp_activation_{i+2}" + ), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Last channel is the global average pooling with conv2D 1x1 kernel. + if self.channel_axis == -1: + reshape = keras.layers.Reshape((1, 1, channels), name="reshape") + else: + reshape = keras.layers.Reshape((channels, 1, 1), name="reshape") + pool_sequential = keras.Sequential( + [ + keras.layers.GlobalAveragePooling2D( + data_format=self.data_format, name="average_pooling" + ), + reshape, + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + name="conv_pooling", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="bn_pooling" + ), + keras.layers.Activation( + self.activation, name="activation_pooling" + ), + ] + ) + pool_sequential.build(input_shape) + self.aspp_parallel_channels.append(pool_sequential) + + # Final projection layers + projection = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + name="conv_projection", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="bn_projection" + ), + keras.layers.Activation( + self.activation, name="activation_projection" + ), + keras.layers.Dropout(rate=self.dropout, name="dropout"), + ], + ) + projection_input_channels = ( + 2 + len(self.dilation_rates) + ) * self.num_channels + if self.data_format == "channels_first": + projection.build( + 
(input_shape[0],) + + (projection_input_channels,) + + (input_shape[2:]) + ) + else: + projection.build((input_shape[:-1]) + (projection_input_channels,)) + self.projection = projection + self.built = True + + def call(self, inputs): + """Calls the Atrous Spatial Pyramid Pooling layer on an input. + + Args: + inputs: A tensor of shape [batch, height, width, channels] + + Returns: + A tensor of shape [batch, height, width, num_channels] + """ + result = [] + + for channel in self.aspp_parallel_channels: + temp = ops.cast(channel(inputs), inputs.dtype) + result.append(temp) + + image_shape = ops.shape(inputs) + if self.channel_axis == -1: + height, width = image_shape[1], image_shape[2] + else: + height, width = image_shape[2], image_shape[3] + result[-1] = keras.layers.Resizing( + height, + width, + interpolation="bilinear", + data_format=self.data_format, + name="resizing", + )(result[-1]) + + result = ops.concatenate(result, axis=self.channel_axis) + return self.projection(result) + + def compute_output_shape(self, inputs_shape): + if self.data_format == "channels_first": + return tuple( + (inputs_shape[0],) + (self.num_channels,) + (inputs_shape[2:]) + ) + else: + return tuple((inputs_shape[:-1]) + (self.num_channels,)) + + def get_config(self): + config = super().get_config() + config.update( + { + "dilation_rates": self.dilation_rates, + "num_channels": self.num_channels, + "activation": self.activation, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py new file mode 100644 index 000000000..d592ea09f --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py @@ -0,0 +1,3 @@ +"""DeepLabV3 preset configurations.""" +# TODO, +backbone_presets = {} diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py new file mode 100644 index 000000000..c3f8ada00 --- 
/dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -0,0 +1,105 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.image_segmenter import ImageSegmenter + + +@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter") +class DeepLabV3ImageSegmenter(ImageSegmenter): + """DeepLabV3 and DeepLabV3Plus segmentation task. + + Args: + backbone: A `keras_hub.models.DeepLabV3Backbone` instance. + num_classes: int. The number of classes for the segmentation model. Note + that the `num_classes` contains the background class, and the + classes from the data should be represented by integers with range + `[0, num_classes)`. + activation: str or callable. The activation function to use on + the output `Conv2D` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor` + or `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Example: + Load a DeepLabV3 preset with the 21-class pretrained segmentation head. + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_resnet50_pascalvoc", + ) + segmenter.predict(images) + ``` + + Specify `num_classes` to load a randomly initialized segmentation head. + ```python + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_resnet50_pascalvoc", + num_classes=2, + ) + segmenter.fit(images, labels, epochs=3) + segmenter.predict(images) # Trained 2 class segmentation. + ``` + Load a DeepLabv3+ preset, an extension of DeepLabv3 that adds a simple yet + effective decoder module to refine the segmentation results, especially + along object boundaries. 
+ ```python + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_plus_resnet50_pascalvoc", + ) + segmenter.predict(images) + ``` + """ + + backbone_cls = DeepLabV3Backbone + preprocessor_cls = None + + def __init__( + self, + backbone, + num_classes, + activation=None, + preprocessor=None, + **kwargs, + ): + data_format = keras.config.image_data_format() + # === Layers === + self.backbone = backbone + self.output_conv = keras.layers.Conv2D( + name="segmentation_output", + filters=num_classes, + kernel_size=1, + use_bias=False, + padding="same", + activation=activation, + data_format=data_format, + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_conv(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py new file mode 100644 index 000000000..6ac4e983e --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DeepLabV3ImageSegmenterTest(TestCase): + def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + 
"stackwise_num_strides": [1, 2, 2], + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) + self.deeplab_backbone = DeepLabV3Backbone( + image_encoder=self.image_encoder, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P4", + dilation_rates=[6, 12, 18], + upsampling_size=4, + ) + self.init_kwargs = { + "backbone": self.deeplab_backbone, + "num_classes": 2, + "activation": "softmax", + } + self.images = np.ones((2, 96, 96, 3), dtype="float32") + self.labels = np.zeros((2, 96, 96, 2), dtype="float32") + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + batch_size=2, + expected_output_shape=(2, 96, 96, 2), + ) + + # @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/image_segmenter.py b/keras_hub/src/models/image_segmenter.py index fcd45db07..cca42bdb5 100644 --- a/keras_hub/src/models/image_segmenter.py +++ b/keras_hub/src/models/image_segmenter.py @@ -14,6 +14,26 @@ class ImageSegmenter(Task): All `ImageSegmenter` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. + `ImageSegmenter` tasks take an additional + `num_classes` argument, the number of segmentation classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a image and `y` is a label from `[0, num_classes)`. 
+ + Example: + ```python + model = keras_hub.models.ImageSegmenter.from_preset( + "deeplab_resnet", + num_classes=2, + ) + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + output = model(images) + pred_labels = output[0] + + model.fit(images, labels, epochs=3) + ``` """ def __init__(self, *args, **kwargs):