From 8683ead2ea02c3ccb7c56b23c41341e8d72141ea Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 24 Sep 2024 10:47:27 -0700 Subject: [PATCH 1/4] Add Deeplab and DeepLabV3 with segmentation --- keras_hub/api/models/__init__.py | 7 + .../models/deeplab_v3/deeplab_v3_backbone.py | 205 +++++++++++++++++ .../deeplab_v3/deeplab_v3_backbone_test.py | 62 ++++++ .../models/deeplab_v3/deeplab_v3_layers.py | 209 ++++++++++++++++++ .../models/deeplab_v3/deeplab_v3_segmenter.py | 107 +++++++++ .../deeplab_v3/deeplab_v3_segmenter_test.py | 77 +++++++ keras_hub/src/models/image_segmenter.py | 105 +++++++++ 7 files changed, 772 insertions(+) create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py create mode 100644 keras_hub/src/models/image_segmenter.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 400284e48..9c3cda94a 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -96,6 +96,12 @@ from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( DenseNetImageClassifier, @@ -175,6 +181,7 @@ from keras_hub.src.models.image_classifier_preprocessor import ( ImageClassifierPreprocessor, ) +from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py new file mode 100644 index 000000000..a3585dbe6 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -0,0 +1,205 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( + SpatialPyramidPooling, +) + + +@keras_hub_export("keras_hub.models.DeepLabV3Backbone") +class DeepLabV3Backbone(Backbone): + """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation. 
+ + This class implements a DeepLabV3 & DeepLabV3Plus architecture as described + in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation](https://arxiv.org/abs/1802.02611)(ECCV 2018) + and [Rethinking Atrous Convolution for Semantic Image Segmentation]( + https://arxiv.org/abs/1706.05587)(CVPR 2017) + + Args: + image_encoder: `keras.Model` instance that is used as a feature + extractor for the Encoder. Should either be a + `keras_hub.models.Backbone` or a `keras.Model` that implements the + `pyramid_outputs` property with keys "P2", "P3" etc as values. + A somewhat sensible backbone to use in many cases is + the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`. + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `image_encoder`. + spatial_pyramid_pooling_key: str, layer level to extract and perform + `spatial_pyramid_pooling`, one of the key from the `image_encoder` + `pyramid_outputs` property such as "P4", "P5" etc. + upsampling_size: Int, or tuple of 2 integers. The upsampling factors for + rows and columns of `spatial_pyramid_pooling` layer. + If `low_level_feature_key` is given then `spatial_pyramid_pooling`s + layer resolution should match with the `low_level_feature`s layer + resolution to concatenate both the layers for combined encoder + outputs. + dilation_rates: A `list` of integers for parallel dilated conv applied to + `SpatialPyramidPooling`. Usually a + sample choice of rates are [6, 12, 18]. + low_level_feature_key: (Optional) str, layer level to extract the feature + from one of the key from the `image_encoder`s `pyramid_outputs` + property such as "P2", "P3" etc which will be the Decoder block. + Required only when the DeepLabV3Plus architecture needs to be applied. + image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + + Example: + ```python + image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50") + + model = keras_hub.models.DeepLabV3Backbone( + image_encoder=image_encoder, + projection_filters=48, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P5", + upsampling_size = 8, + dilation_rates = [6, 12, 18] + ) + ``` + """ + + def __init__( + self, + image_encoder, + spatial_pyramid_pooling_key, + upsampling_size, + dilation_rates, + low_level_feature_key=None, + projection_filters=48, + image_shape=(None, None, 3), + **kwargs, + ): + if not isinstance(image_encoder, keras.Model): + raise ValueError( + "Argument `image_encoder` must be a `keras.Model` instance. Received instead " + f"{image_encoder} (of type {type(image_encoder)})." 
+ ) + data_format = keras.config.image_data_format() + channel_axis = -1 if data_format == "channels_last" else 1 + # === Functional Model === + inputs = keras.layers.Input(image_shape) + + fpn_model = keras.Model( + image_encoder.inputs, image_encoder.pyramid_outputs + ) + + fpn_outputs = fpn_model(inputs) + + spatial_pyramid_pooling = SpatialPyramidPooling( + dilation_rates=dilation_rates + ) + spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key] + spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) + + encoder_outputs = keras.layers.UpSampling2D( + size=upsampling_size, + interpolation="bilinear", + name="encoder_output_upsampling", + data_format=data_format, + )(spp_outputs) + + if low_level_feature_key: + decoder_feature = fpn_outputs[low_level_feature_key] + low_level_projected_features = apply_low_level_feature_network( + decoder_feature, projection_filters, channel_axis + ) + + encoder_outputs = keras.layers.Concatenate(axis=channel_axis)( + [encoder_outputs, low_level_projected_features] + ) + # upsampling to the original image size + upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // ( + int(upsampling_size[0]) + if isinstance(upsampling_size, tuple) + else upsampling_size + ) + x = keras.layers.Conv2D( + name="segmentation_head_conv", + filters=256, + kernel_size=1, + padding="same", + use_bias=False, + data_format=data_format, + )(encoder_outputs) + x = keras.layers.BatchNormalization( + name="segmentation_head_norm", axis=channel_axis + )(x) + x = keras.layers.ReLU(name="segmentation_head_relu")(x) + x = keras.layers.UpSampling2D( + size=upsampling, + interpolation="bilinear", + data_format=data_format, + )(x) + + super().__init__(inputs=inputs, outputs=x, **kwargs) + + # === Config === + self.image_shape = image_shape + self.image_encoder = image_encoder + self.projection_filters = projection_filters + self.upsampling_size = upsampling_size + self.dilation_rates = dilation_rates + self.low_level_feature_key = low_level_feature_key + self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + "projection_filters": self.projection_filters, + "dilation_rates": self.dilation_rates, + "upsampling_size": self.upsampling_size, + "low_level_feature_key": self.low_level_feature_key, + "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, + "image_shape": self.image_shape, + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) + + +def apply_low_level_feature_network( + input_tensor, projection_filters, channel_axis +): + data_format = keras.config.image_data_format() + x = keras.layers.Conv2D( + name="low_level_feature_conv", + filters=projection_filters, + kernel_size=1, + padding="same", + use_bias=False, + data_format=data_format, + )(input_tensor) + + x = keras.layers.BatchNormalization( + name="low_level_feature_norm", axis=channel_axis + )(x) + x = keras.layers.ReLU(name="low_level_feature_relu")(x) + return x diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py new file mode 100644 index 000000000..64d9166ac --- /dev/null +++ 
b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DeepLabV3Test(TestCase): + def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) + self.init_kwargs = { + "image_encoder": self.image_encoder, + "low_level_feature_key": "P2", + "spatial_pyramid_pooling_key": "P4", + "dilation_rates": [6, 12, 18], + "upsampling_size": 4, + } + self.input_data = np.ones((2, 96, 96, 3), dtype="float32") + + def test_segmentation_basics(self): + self.run_vision_backbone_test( + cls=DeepLabV3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 96, 96, 256), + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py new file mode 100644 index 000000000..d124e922e --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -0,0 +1,209 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras +from keras import ops + + +class SpatialPyramidPooling(keras.layers.Layer): + """Implements the Atrous Spatial Pyramid Pooling. 
+ + Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution + for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and + [Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation](https://arxiv.org/pdf/1802.02611.pdf) + + inp = keras.layers.Input((384, 384, 3)) + backbone = keras.applications.EfficientNetB0( + input_tensor=inp, + include_top=False) + output = backbone(inp) + output = SpatialPyramidPooling( + dilation_rates=[6, 12, 18])(output) + + # output[4].shape = [None, 16, 16, 256] + """ + + def __init__( + self, + dilation_rates, + num_channels=256, + activation="relu", + dropout=0.0, + **kwargs, + ): + """Initializes an Atrous Spatial Pyramid Pooling layer. + + Args: + dilation_rates: A `list` of integers for parallel dilated conv. + Usually a sample choice of rates are [6, 12, 18]. + num_channels: An `int` number of output channels, defaults to 256. + activation: A `str` activation to be used, defaults to 'relu'. + dropout: A `float` for the dropout rate of the final projection + output after the activations and batch norm, defaults to 0.0, + which means no dropout is applied to the output. + **kwargs: Additional keyword arguments to be passed. + """ + self.data_format = keras.config.image_data_format() + self.channel_axis = -1 if self.data_format == "channels_last" else 1 + super().__init__(**kwargs) + self.dilation_rates = dilation_rates + self.num_channels = num_channels + self.activation = activation + self.dropout = dropout + + def build(self, input_shape): + channels = input_shape[self.channel_axis] + + # This is the parallel networks that process the input features with + # different dilation rates. The output from each channel will be merged + # together and feed to the output. + self.aspp_parallel_channels = [] + + # Channel1 with Conv2D and 1x1 kernel size. + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + ), + keras.layers.BatchNormalization(axis=self.channel_axis), + keras.layers.Activation(self.activation), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Channel 2 and afterwards are based on self.dilation_rates, and each of + # them will have conv2D with 3x3 kernel size. + for dilation_rate in self.dilation_rates: + conv_sequential = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(3, 3), + padding="same", + dilation_rate=dilation_rate, + use_bias=False, + data_format=self.data_format, + ), + keras.layers.BatchNormalization(axis=self.channel_axis), + keras.layers.Activation(self.activation), + ] + ) + conv_sequential.build(input_shape) + self.aspp_parallel_channels.append(conv_sequential) + + # Last channel is the global average pooling with conv2D 1x1 kernel. 
+ if self.channel_axis == -1: + reshape = keras.layers.Reshape((1, 1, channels)) + else: + reshape = keras.layers.Reshape((channels, 1, 1)) + pool_sequential = keras.Sequential( + [ + keras.layers.GlobalAveragePooling2D( + data_format=self.data_format + ), + reshape, + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + ), + keras.layers.BatchNormalization(axis=self.channel_axis), + keras.layers.Activation(self.activation), + ] + ) + pool_sequential.build(input_shape) + self.aspp_parallel_channels.append(pool_sequential) + + # Final projection layers + projection = keras.Sequential( + [ + keras.layers.Conv2D( + filters=self.num_channels, + kernel_size=(1, 1), + use_bias=False, + data_format=self.data_format, + ), + keras.layers.BatchNormalization(axis=self.channel_axis), + keras.layers.Activation(self.activation), + keras.layers.Dropout(rate=self.dropout), + ], + ) + projection_input_channels = ( + 2 + len(self.dilation_rates) + ) * self.num_channels + projection.build(tuple(input_shape[:-1]) + (projection_input_channels,)) + self.projection = projection + + def call(self, inputs, training=None): + """Calls the Atrous Spatial Pyramid Pooling layer on an input. + + Args: + inputs: A tensor of shape [batch, height, width, channels] + + Returns: + A tensor of shape [batch, height, width, num_channels] + """ + result = [] + + for channel in self.aspp_parallel_channels: + temp = ops.cast(channel(inputs, training=training), inputs.dtype) + result.append(temp) + + image_shape = ops.shape(inputs) + if self.channel_axis == -1: + height, width = image_shape[1], image_shape[2] + else: + height, width = image_shape[2], image_shape[3] + result[self.channel_axis] = keras.layers.Resizing( + height, + width, + interpolation="bilinear", + )(result[self.channel_axis]) + + result = ops.concatenate(result, axis=self.channel_axis) + return self.projection(result, training=training) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + return ( + input_shape[0], + self.num_channels, + input_shape[1], + input_shape[2], + ) + else: + return ( + input_shape[0], + input_shape[1], + input_shape[2], + self.num_channels, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "dilation_rates": self.dilation_rates, + "num_channels": self.num_channels, + "activation": self.activation, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py new file mode 100644 index 000000000..e2cd18e15 --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -0,0 +1,107 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
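The `SpatialPyramidPooling` layer defined in `deeplab_v3_layers.py` above is self-contained and can be exercised independently of the backbone. A minimal sketch, assuming an arbitrary 32x32x64 feature map (the shape is illustrative, not taken from this patch); every parallel branch preserves the spatial size and the final projection maps the concatenated branches to `num_channels`:

```python
import numpy as np

from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
    SpatialPyramidPooling,
)

# Dummy encoder feature map: batch of 1, 32x32 spatial, 64 channels.
features = np.random.rand(1, 32, 32, 64).astype("float32")
aspp = SpatialPyramidPooling(dilation_rates=[6, 12, 18], num_channels=256)
outputs = aspp(features)
print(outputs.shape)  # (1, 32, 32, 256): spatial size kept, channels projected to 256.
```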
+ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.image_segmenter import ImageSegmenter + + +@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter") +class DeepLabV3ImageSegmenter(ImageSegmenter): + """DeepLabV3 and DeeplabV3 and DeeplabV3Plus segmentation task. + + Args: + backbone: A `keras_hub.models.DeepLabV3` instance. + num_classes: int, the number of classes for the detection model. Note + that the `num_classes` contains the background class, and the + classes from the data should be represented by integers with range + `[0, num_classes]`. + activation: str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + + Example: + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_resnet50", + ) + segmenter.predict(images) # 19 class, pretrained segmentation head. + + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_resnet50", + num_classes=2, + ) + segmenter.fit(images, labels, epochs=3) + segmenter.predict(images) # Trained 2 class segmentation. + ``` + """ + + backbone_cls = DeepLabV3Backbone + preprocessor_cls = None + + def __init__( + self, + backbone, + num_classes, + activation=None, + preprocessor=None, + **kwargs, + ): + data_format = keras.config.image_data_format() + # === Layers === + self.backbone = backbone + self.output_conv = keras.layers.Conv2D( + name="segmentation_output", + filters=num_classes, + kernel_size=1, + use_bias=False, + padding="same", + activation=activation, + data_format=data_format, + # Force the dtype of the classification layer to float32 + # to avoid the NAN loss issue when used with mixed + # precision API. + dtype="float32", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_conv(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py new file mode 100644 index 000000000..e522cd02d --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -0,0 +1,77 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
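For reference, a minimal end-to-end construction of the task from a small ResNet encoder; the hyperparameters simply mirror the test fixtures that follow and are not a recommended configuration. With ASPP applied to the P4 features (stride 16) and `upsampling_size=4`, the encoder output reaches stride 4 to match P2 before concatenation, and the backbone's final upsampling factor works out to `2**4 // 4 = 4`, restoring the input resolution:

```python
import numpy as np

from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
    DeepLabV3Backbone,
)
from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
    DeepLabV3ImageSegmenter,
)
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone

# Small ResNet encoder; the stack sizes mirror the test fixtures below.
image_encoder = ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
    use_pre_activation=False,
)
backbone = DeepLabV3Backbone(
    image_encoder=image_encoder,
    low_level_feature_key="P2",        # DeepLabV3Plus decoder branch (stride 4).
    spatial_pyramid_pooling_key="P4",  # ASPP on the stride-16 features.
    dilation_rates=[6, 12, 18],
    upsampling_size=4,
)
segmenter = DeepLabV3ImageSegmenter(
    backbone=backbone, num_classes=2, activation="softmax"
)
masks = segmenter.predict(np.ones((2, 96, 96, 3), dtype="float32"))
print(masks.shape)  # (2, 96, 96, 2)
```

Calling `fit()` on this instance with matching masks then trains the randomly initialized segmentation head.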
+ +import numpy as np +import pytest + +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DeepLabV3ImageSegmenterTest(TestCase): + def setUp(self): + self.resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "block_type": "basic_block", + "use_pre_activation": False, + } + self.image_encoder = ResNetBackbone(**self.resnet_kwargs) + self.deeplab_backbone = DeepLabV3Backbone( + image_encoder=self.image_encoder, + low_level_feature_key="P2", + spatial_pyramid_pooling_key="P4", + dilation_rates=[6, 12, 18], + upsampling_size=4, + ) + self.init_kwargs = { + "backbone": self.deeplab_backbone, + "num_classes": 2, + "activation": "softmax", + } + self.images = np.ones((2, 96, 96, 3), dtype="float32") + self.labels = np.zeros((2, 96, 96, 2), dtype="float32") + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + batch_size=2, + expected_output_shape=(2, 96, 96, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/image_segmenter.py b/keras_hub/src/models/image_segmenter.py new file mode 100644 index 000000000..af11c5f4a --- /dev/null +++ b/keras_hub/src/models/image_segmenter.py @@ -0,0 +1,105 @@ +# Copyright 2024 The KerasHUB Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task + + +@keras_hub_export("keras_hub.models.ImageSegmenter") +class ImageSegmenter(Task): + """Base class for all segmentation tasks. + + `ImageSegmenter` tasks wrap a `keras_hub.models.Backbone` to create a model + that can be used for segmentation. + `Segmenter` tasks take an additional + `num_classes` argument, the number of segmentation classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a image and `y` is a label from `[0, num_classes)`. + + All `ImageSegmenter` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. 
+ + Example: + ```python + model = keras_hub.models.Segmenter.from_preset( + "basnet_resnet", + num_classes=2, + ) + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + output = model(images) + pred_labels = output[0] + + model.fit(images, labels, epochs=3) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageSegmenter` task for training. + + The `ImageSegmenter` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.BinaryCrossentropy` loss will be + applied for the segmentation task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.Accuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.BinaryCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.Accuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) From 5540e2026f8380403a76d0b1a8f832225c6427a1 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 25 Sep 2024 15:27:29 -0700 Subject: [PATCH 2/4] address comments --- keras_hub/src/models/deeplab_v3/__init__.py | 21 ++++ .../models/deeplab_v3/deeplab_v3_backbone.py | 46 ++++---- .../deeplab_v3/deeplab_v3_backbone_test.py | 25 +++- .../models/deeplab_v3/deeplab_v3_layers.py | 107 +++++++++++------- .../models/deeplab_v3/deeplab_v3_presets.py | 16 +++ .../models/deeplab_v3/deeplab_v3_segmenter.py | 26 ++++- .../deeplab_v3/deeplab_v3_segmenter_test.py | 4 +- keras_hub/src/models/image_segmenter.py | 4 +- 8 files changed, 174 insertions(+), 75 deletions(-) create mode 100644 keras_hub/src/models/deeplab_v3/__init__.py create mode 100644 keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py diff --git a/keras_hub/src/models/deeplab_v3/__init__.py b/keras_hub/src/models/deeplab_v3/__init__.py new file mode 100644 index 000000000..12e06b4af --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DeepLabV3Backbone) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py index a3585dbe6..fd472c4f2 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,33 +25,33 @@ class DeepLabV3Backbone(Backbone): """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation. This class implements a DeepLabV3 & DeepLabV3Plus architecture as described - in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image - Segmentation](https://arxiv.org/abs/1802.02611)(ECCV 2018) + in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation]( + https://arxiv.org/abs/1802.02611)(ECCV 2018) and [Rethinking Atrous Convolution for Semantic Image Segmentation]( https://arxiv.org/abs/1706.05587)(CVPR 2017) Args: - image_encoder: `keras.Model` instance that is used as a feature + image_encoder: `keras.Model`. An instance that is used as a feature extractor for the Encoder. Should either be a `keras_hub.models.Backbone` or a `keras.Model` that implements the `pyramid_outputs` property with keys "P2", "P3" etc as values. A somewhat sensible backbone to use in many cases is the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`. - projection_filters: int, number of filters in the convolution layer + projection_filters: int. Number of filters in the convolution layer projecting low-level features from the `image_encoder`. - spatial_pyramid_pooling_key: str, layer level to extract and perform + spatial_pyramid_pooling_key: str. A layer level to extract and perform `spatial_pyramid_pooling`, one of the key from the `image_encoder` `pyramid_outputs` property such as "P4", "P5" etc. - upsampling_size: Int, or tuple of 2 integers. The upsampling factors for + upsampling_size: int or tuple of 2 integers. The upsampling factors for rows and columns of `spatial_pyramid_pooling` layer. If `low_level_feature_key` is given then `spatial_pyramid_pooling`s layer resolution should match with the `low_level_feature`s layer resolution to concatenate both the layers for combined encoder outputs. - dilation_rates: A `list` of integers for parallel dilated conv applied to + dilation_rates: list. A `list` of integers for parallel dilated conv applied to `SpatialPyramidPooling`. Usually a - sample choice of rates are [6, 12, 18]. - low_level_feature_key: (Optional) str, layer level to extract the feature + sample choice of rates are `[6, 12, 18]`. + low_level_feature_key: str optional. 
A layer level to extract the feature from one of the key from the `image_encoder`s `pyramid_outputs` property such as "P2", "P3" etc which will be the Decoder block. Required only when the DeepLabV3Plus architecture needs to be applied. @@ -60,7 +60,8 @@ class DeepLabV3Backbone(Backbone): Example: ```python - image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50") + # Load a trained backbone to extract features from it's `pyramid_outputs`. + image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet") model = keras_hub.models.DeepLabV3Backbone( image_encoder=image_encoder, @@ -91,8 +92,9 @@ def __init__( ) data_format = keras.config.image_data_format() channel_axis = -1 if data_format == "channels_last" else 1 - # === Functional Model === - inputs = keras.layers.Input(image_shape) + + # === Layers === + inputs = keras.layers.Input(image_shape, name="inputs") fpn_model = keras.Model( image_encoder.inputs, image_encoder.pyramid_outputs @@ -119,15 +121,16 @@ def __init__( decoder_feature, projection_filters, channel_axis ) - encoder_outputs = keras.layers.Concatenate(axis=channel_axis)( - [encoder_outputs, low_level_projected_features] - ) + encoder_outputs = keras.layers.Concatenate( + axis=channel_axis, name="encoder_decoder_concat" + )([encoder_outputs, low_level_projected_features]) # upsampling to the original image size upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // ( int(upsampling_size[0]) if isinstance(upsampling_size, tuple) else upsampling_size ) + # === Functional Model === x = keras.layers.Conv2D( name="segmentation_head_conv", filters=256, @@ -144,6 +147,7 @@ def __init__( size=upsampling, interpolation="bilinear", data_format=data_format, + name="backbone_output_upsampling", )(x) super().__init__(inputs=inputs, outputs=x, **kwargs) @@ -190,7 +194,7 @@ def apply_low_level_feature_network( ): data_format = keras.config.image_data_format() x = keras.layers.Conv2D( - name="low_level_feature_conv", + name="decoder_conv", filters=projection_filters, kernel_size=1, padding="same", @@ -198,8 +202,8 @@ def apply_low_level_feature_network( data_format=data_format, )(input_tensor) - x = keras.layers.BatchNormalization( - name="low_level_feature_norm", axis=channel_axis - )(x) - x = keras.layers.ReLU(name="low_level_feature_relu")(x) + x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)( + x + ) + x = keras.layers.ReLU(name="decoder_relu")(x) return x diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py index 64d9166ac..08f7f8a41 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +18,9 @@ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( DeepLabV3Backbone, ) +from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( + SpatialPyramidPooling, +) from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -40,6 +43,7 @@ def setUp(self): "spatial_pyramid_pooling_key": "P4", "dilation_rates": [6, 12, 18], "upsampling_size": 4, + "image_shape": (96, 96, 3), } self.input_data = np.ones((2, 96, 96, 3), dtype="float32") @@ -51,6 +55,7 @@ def test_segmentation_basics(self): expected_output_shape=(2, 96, 96, 256), run_mixed_precision_check=False, run_quantization_check=False, + run_data_format_check=False, ) @pytest.mark.large @@ -60,3 +65,21 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + +class SpatialPyramidPoolingTest(TestCase): + def test_layer_behaviors(self): + self.run_layer_test( + cls=SpatialPyramidPooling, + init_kwargs={ + "dilation_rates": [6, 12, 18], + "activation": "relu", + "num_channels": 256, + "dropout": 0.1, + }, + input_data=np.random.randn(1, 4, 4, 6), + expected_output_shape=(1, 4, 4, 256), + expected_num_trainable_weights=18, + expected_num_non_trainable_variables=13, + expected_num_non_trainable_weights=12, + ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py index d124e922e..c930ff5a6 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,17 @@ class SpatialPyramidPooling(keras.layers.Layer): [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf) + Args: + dilation_rates: list of ints. The dilation rate for parallel dilated conv. + Usually a sample choice of rates are `[6, 12, 18]`. + num_channels: int. The number of output channels, defaults to `256`. + activation: str. Activation to be used, defaults to `relu`. + dropout: float. The dropout rate of the final projection output after the + activations and batch norm, defaults to `0.0`, which means no dropout is + applied to the output. + + Example: + ```python inp = keras.layers.Input((384, 384, 3)) backbone = keras.applications.EfficientNetB0( input_tensor=inp, @@ -31,8 +42,7 @@ class SpatialPyramidPooling(keras.layers.Layer): output = backbone(inp) output = SpatialPyramidPooling( dilation_rates=[6, 12, 18])(output) - - # output[4].shape = [None, 16, 16, 256] + ``` """ def __init__( @@ -43,25 +53,13 @@ def __init__( dropout=0.0, **kwargs, ): - """Initializes an Atrous Spatial Pyramid Pooling layer. - - Args: - dilation_rates: A `list` of integers for parallel dilated conv. - Usually a sample choice of rates are [6, 12, 18]. - num_channels: An `int` number of output channels, defaults to 256. - activation: A `str` activation to be used, defaults to 'relu'. - dropout: A `float` for the dropout rate of the final projection - output after the activations and batch norm, defaults to 0.0, - which means no dropout is applied to the output. - **kwargs: Additional keyword arguments to be passed. 
- """ - self.data_format = keras.config.image_data_format() - self.channel_axis = -1 if self.data_format == "channels_last" else 1 super().__init__(**kwargs) self.dilation_rates = dilation_rates self.num_channels = num_channels self.activation = activation self.dropout = dropout + self.data_format = keras.config.image_data_format() + self.channel_axis = -1 if self.data_format == "channels_last" else 1 def build(self, input_shape): channels = input_shape[self.channel_axis] @@ -79,9 +77,14 @@ def build(self, input_shape): kernel_size=(1, 1), use_bias=False, data_format=self.data_format, + name="aspp_conv_1", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="aspp_bn_1" + ), + keras.layers.Activation( + self.activation, name="aspp_activation_1" ), - keras.layers.BatchNormalization(axis=self.channel_axis), - keras.layers.Activation(self.activation), ] ) conv_sequential.build(input_shape) @@ -89,7 +92,7 @@ def build(self, input_shape): # Channel 2 and afterwards are based on self.dilation_rates, and each of # them will have conv2D with 3x3 kernel size. - for dilation_rate in self.dilation_rates: + for i, dilation_rate in enumerate(self.dilation_rates): conv_sequential = keras.Sequential( [ keras.layers.Conv2D( @@ -99,9 +102,14 @@ def build(self, input_shape): dilation_rate=dilation_rate, use_bias=False, data_format=self.data_format, + name=f"aspp_conv_{i+2}", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name=f"aspp_bn_{i+2}" + ), + keras.layers.Activation( + self.activation, name=f"aspp_activation_{i+2}" ), - keras.layers.BatchNormalization(axis=self.channel_axis), - keras.layers.Activation(self.activation), ] ) conv_sequential.build(input_shape) @@ -109,13 +117,13 @@ def build(self, input_shape): # Last channel is the global average pooling with conv2D 1x1 kernel. 
if self.channel_axis == -1: - reshape = keras.layers.Reshape((1, 1, channels)) + reshape = keras.layers.Reshape((1, 1, channels), name="reshape") else: - reshape = keras.layers.Reshape((channels, 1, 1)) + reshape = keras.layers.Reshape((channels, 1, 1), name="reshape") pool_sequential = keras.Sequential( [ keras.layers.GlobalAveragePooling2D( - data_format=self.data_format + data_format=self.data_format, name="average_pooling" ), reshape, keras.layers.Conv2D( @@ -123,9 +131,14 @@ def build(self, input_shape): kernel_size=(1, 1), use_bias=False, data_format=self.data_format, + name="conv_pooling", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="bn_pooling" + ), + keras.layers.Activation( + self.activation, name="activation_pooling" ), - keras.layers.BatchNormalization(axis=self.channel_axis), - keras.layers.Activation(self.activation), ] ) pool_sequential.build(input_shape) @@ -139,16 +152,28 @@ def build(self, input_shape): kernel_size=(1, 1), use_bias=False, data_format=self.data_format, + name="conv_projection", + ), + keras.layers.BatchNormalization( + axis=self.channel_axis, name="bn_projection" ), - keras.layers.BatchNormalization(axis=self.channel_axis), - keras.layers.Activation(self.activation), - keras.layers.Dropout(rate=self.dropout), + keras.layers.Activation( + self.activation, name="activation_projection" + ), + keras.layers.Dropout(rate=self.dropout, name="dropout"), ], ) projection_input_channels = ( 2 + len(self.dilation_rates) ) * self.num_channels - projection.build(tuple(input_shape[:-1]) + (projection_input_channels,)) + if self.data_format == "channels_first": + projection.build( + (input_shape[0],) + + (projection_input_channels,) + + (input_shape[2:]) + ) + else: + projection.build((input_shape[:-1]) + (projection_input_channels,)) self.projection = projection def call(self, inputs, training=None): @@ -171,30 +196,24 @@ def call(self, inputs, training=None): height, width = image_shape[1], image_shape[2] else: height, width = image_shape[2], image_shape[3] - result[self.channel_axis] = keras.layers.Resizing( + result[-1] = keras.layers.Resizing( height, width, interpolation="bilinear", - )(result[self.channel_axis]) + data_format=self.data_format, + name="resizing", + )(result[-1]) result = ops.concatenate(result, axis=self.channel_axis) return self.projection(result, training=training) def compute_output_shape(self, input_shape): if self.data_format == "channels_first": - return ( - input_shape[0], - self.num_channels, - input_shape[1], - input_shape[2], + return tuple( + (input_shape[0],) + (self.num_channels,) + (input_shape[2:]) ) else: - return ( - input_shape[0], - input_shape[1], - input_shape[2], - self.num_channels, - ) + return tuple((input_shape[:-1]) + (self.num_channels,)) def get_config(self): config = super().get_config() diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py new file mode 100644 index 000000000..cd33ac9ec --- /dev/null +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py @@ -0,0 +1,16 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DeepLabV3 preset configurations.""" +# TODO, +backbone_presets = {} diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py index e2cd18e15..6b4829d04 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,30 +27,46 @@ class DeepLabV3ImageSegmenter(ImageSegmenter): Args: backbone: A `keras_hub.models.DeepLabV3` instance. - num_classes: int, the number of classes for the detection model. Note + num_classes: int. The number of classes for the detection model. Note that the `num_classes` contains the background class, and the classes from the data should be represented by integers with range `[0, num_classes]`. activation: str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `None`. + preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor` + or `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. Example: + Load a DeepLabV3 preset with all the 21 class, pretrained segmentation head. ```python images = np.ones(shape=(1, 96, 96, 3)) labels = np.zeros(shape=(1, 96, 96, 1)) segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( - "deeplabv3_resnet50", + "deeplabv3_resnet50_pascalvoc", ) - segmenter.predict(images) # 19 class, pretrained segmentation head. + segmenter.predict(images) + ``` + Specify `num_classes` to load randomly initialized segmentation head. + ```python segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( - "deeplabv3_resnet50", + "deeplabv3_resnet50_pascalvoc", num_classes=2, ) segmenter.fit(images, labels, epochs=3) segmenter.predict(images) # Trained 2 class segmentation. ``` + Load DeepLabv3+ presets a extension of DeepLabv3 by adding a simple yet + effective decoder module to refine the segmentation results especially + along object boundaries. + ```python + segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplabv3_plus_resnet50_pascalvoc", + ) + segmenter.predict(images) + ``` """ backbone_cls = DeepLabV3Backbone diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py index e522cd02d..52c27aa57 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -68,7 +68,7 @@ def test_classifier_basics(self): expected_output_shape=(2, 96, 96, 2), ) - @pytest.mark.large + # @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( cls=DeepLabV3ImageSegmenter, diff --git a/keras_hub/src/models/image_segmenter.py b/keras_hub/src/models/image_segmenter.py index 6b2ee8cd4..c75776cb7 100644 --- a/keras_hub/src/models/image_segmenter.py +++ b/keras_hub/src/models/image_segmenter.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasHUB Authors +# Copyright 2024 The KerasHub Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ class ImageSegmenter(Task): All `ImageSegmenter` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. - `Segmenter` tasks take an additional + `ImageSegmenter` tasks take an additional `num_classes` argument, the number of segmentation classes. To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` From f48de5354a6d6b7776ad5cf49b786034ed67737a Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 26 Sep 2024 13:58:49 -0700 Subject: [PATCH 3/4] test fix --- .../models/deeplab_v3/deeplab_v3_backbone_test.py | 4 +++- .../src/models/deeplab_v3/deeplab_v3_layers.py | 13 +++++++------ .../src/models/deeplab_v3/deeplab_v3_segmenter.py | 4 ---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py index 08f7f8a41..fd2a259d6 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import keras import numpy as np import pytest @@ -77,9 +78,10 @@ def test_layer_behaviors(self): "num_channels": 256, "dropout": 0.1, }, - input_data=np.random.randn(1, 4, 4, 6), + input_data=keras.random.uniform(shape=(1, 4, 4, 6)), expected_output_shape=(1, 4, 4, 256), expected_num_trainable_weights=18, expected_num_non_trainable_variables=13, expected_num_non_trainable_weights=12, + run_precision_checks=False, ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py index c930ff5a6..d17fe373c 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -175,8 +175,9 @@ def build(self, input_shape): else: projection.build((input_shape[:-1]) + (projection_input_channels,)) self.projection = projection + self.built = True - def call(self, inputs, training=None): + def call(self, inputs): """Calls the Atrous Spatial Pyramid Pooling layer on an input. 
Args: @@ -188,7 +189,7 @@ def call(self, inputs, training=None): result = [] for channel in self.aspp_parallel_channels: - temp = ops.cast(channel(inputs, training=training), inputs.dtype) + temp = ops.cast(channel(inputs), inputs.dtype) result.append(temp) image_shape = ops.shape(inputs) @@ -205,15 +206,15 @@ def call(self, inputs, training=None): )(result[-1]) result = ops.concatenate(result, axis=self.channel_axis) - return self.projection(result, training=training) + return self.projection(result) - def compute_output_shape(self, input_shape): + def compute_output_shape(self, inputs_shape): if self.data_format == "channels_first": return tuple( - (input_shape[0],) + (self.num_channels,) + (input_shape[2:]) + (inputs_shape[0],) + (self.num_channels,) + (inputs_shape[2:]) ) else: - return tuple((input_shape[:-1]) + (self.num_channels,)) + return tuple((inputs_shape[:-1]) + (self.num_channels,)) def get_config(self): config = super().get_config() diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py index 6b4829d04..7e71b706f 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -91,10 +91,6 @@ def __init__( padding="same", activation=activation, data_format=data_format, - # Force the dtype of the classification layer to float32 - # to avoid the NAN loss issue when used with mixed - # precision API. - dtype="float32", ) # === Functional Model === From 0fd1708faef550e284dda1640cfe4cb69040b1eb Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 26 Sep 2024 15:37:19 -0700 Subject: [PATCH 4/4] update copyright --- keras_hub/src/models/deeplab_v3/__init__.py | 14 -------------- .../src/models/deeplab_v3/deeplab_v3_backbone.py | 13 ------------- .../models/deeplab_v3/deeplab_v3_backbone_test.py | 14 -------------- .../src/models/deeplab_v3/deeplab_v3_layers.py | 14 -------------- .../src/models/deeplab_v3/deeplab_v3_presets.py | 13 ------------- .../src/models/deeplab_v3/deeplab_v3_segmenter.py | 14 -------------- .../models/deeplab_v3/deeplab_v3_segmenter_test.py | 14 -------------- 7 files changed, 96 deletions(-) diff --git a/keras_hub/src/models/deeplab_v3/__init__.py b/keras_hub/src/models/deeplab_v3/__init__.py index 12e06b4af..0a959e186 100644 --- a/keras_hub/src/models/deeplab_v3/__init__.py +++ b/keras_hub/src/models/deeplab_v3/__init__.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
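The reworked `compute_output_shape` above keeps the input spatial dimensions and substitutes `num_channels` on the channel axis; a quick illustration under the default `channels_last` data format (the shape is arbitrary):

```python
from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
    SpatialPyramidPooling,
)

aspp = SpatialPyramidPooling(dilation_rates=[6, 12, 18], num_channels=256)
# Static shape inference only; the layer does not need to be built for this.
print(aspp.compute_output_shape((None, 32, 32, 64)))  # (None, 32, 32, 256)
```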
- from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( DeepLabV3Backbone, ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py index fd472c4f2..70bf828b0 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py @@ -1,16 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import keras from keras_hub.src.api_export import keras_hub_export diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py index fd2a259d6..a7b180908 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import keras import numpy as np import pytest diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py index d17fe373c..837e508d2 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import keras from keras import ops diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py index cd33ac9ec..d592ea09f 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py @@ -1,16 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """DeepLabV3 preset configurations.""" # TODO, backbone_presets = {} diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py index 7e71b706f..c3f8ada00 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import keras from keras_hub.src.api_export import keras_hub_export diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py index 52c27aa57..6ac4e983e 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -1,17 +1,3 @@ -# Copyright 2024 The KerasHub Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import numpy as np import pytest
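Finally, on the training defaults added in `image_segmenter.py`: `compile()` with `loss="auto"` resolves to `keras.losses.BinaryCrossentropy`, with `from_logits` set to `False` only when the task's `activation` is softmax. For a setup like the 2-class example above, these defaults can be overridden explicitly; the optimizer, loss, and metric choices below are illustrative assumptions, not part of this patch, and presume integer masks of shape `(batch, height, width)` as targets:

```python
import keras

# `segmenter` as built in the earlier sketch (activation="softmax", so the
# outputs are probabilities and from_logits stays False here).
segmenter.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
```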