diff --git a/keras_hub/src/models/image_segmenter.py b/keras_hub/src/models/image_segmenter.py new file mode 100644 index 000000000..c75776cb7 --- /dev/null +++ b/keras_hub/src/models/image_segmenter.py @@ -0,0 +1,106 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task + + +@keras_hub_export("keras_hub.models.ImageSegmenter") +class ImageSegmenter(Task): + """Base class for all image segmentation tasks. + + `ImageSegmenter` tasks wrap a `keras_hub.models.Task` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + image segmentation. + + All `ImageSegmenter` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + `ImageSegmenter` tasks take an additional + `num_classes` argument, the number of segmentation classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a image and `y` is a label from `[0, num_classes)`. + + Example: + ```python + model = keras_hub.models.ImageSegmenter.from_preset( + "deeplab_resnet", + num_classes=2, + ) + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + output = model(images) + pred_labels = output[0] + + model.fit(images, labels, epochs=3) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageSegmenter` task for training. + + The `ImageSegmenter` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits) + if metrics == "auto": + metrics = [keras.metrics.CategoricalAccuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_hub/src/models/segformer/__init__.py b/keras_hub/src/models/segformer/__init__.py new file mode 100644 index 000000000..2555ca71f --- /dev/null +++ b/keras_hub/src/models/segformer/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_presets import presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(presets, SegFormerImageSegmenter) diff --git a/keras_hub/src/models/segformer/segformer_backbone.py b/keras_hub/src/models/segformer/segformer_backbone.py new file mode 100644 index 000000000..f2b997f50 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_backbone.py @@ -0,0 +1,173 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_hub.src.models.segformer.segformer_presets import presets + + +@keras_hub_export( + [ + "keras_hub.models.SegFormerBackbone", + "keras_hub.models.segmentation.SegFormerBackbone", + ] +) +class SegFormerBackbone(Backbone): + """A Keras model implementing the SegFormer architecture for semantic + segmentation. + + References: + - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) # noqa: E501 + - [Based on the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer) # noqa: E501 + + Args: + backbone: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the SegFormer encoder. + It is *intended* to be used only with the MiT backbone model which + was created specifically for SegFormers. It should either be a + `keras_cv.models.backbones.backbone.Backbone` or a `tf.keras.Model` + that implements the `pyramid_level_inputs` property with keys + "P2", "P3", "P4", and "P5" and layer names as + values. + num_classes: int, the number of classes for the detection model, + including the background class. + projection_filters: int, number of filters in the + convolution layer projecting the concatenated features into + a segmentation map. Defaults to 256`. + + Example: + + Using the class with a custom `backbone`: + + ```python + import tensorflow as tf + import keras_hub + + backbone = keras_hub.models.MiTBackbone( + depths=[2, 2, 2, 2], + image_shape=(16, 16, 3), + hidden_dims=[4, 8], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + end_value=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + + model = SegFormerBackbone(backbone=backbone, num_classes=4) + ``` + """ + + backbone_cls = MiTBackbone + + def __init__( + self, + backbone, + projection_filters=256, + **kwargs, + ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"backbone={backbone} (of type {type(backbone)})." + ) + + self.feature_extractor = keras.Model( + backbone.outputs, backbone.pyramid_outputs + ) + + inputs = backbone.output + features = self.feature_extractor(inputs) + # Get H and W of level one output + _, H, W, _ = features["P1"].shape + + # === Layers === + + self.mlp_blocks = [] + + for feature_dim, feature in zip(backbone.hidden_dims, features): + self.mlp_blocks.append( + keras.layers.Dense( + projection_filters, name=f"linear_{feature_dim}" + ) + ) + + self.resizing = keras.layers.Resizing(H, W, interpolation="bilinear") + self.concat = keras.layers.Concatenate(axis=3) + self.segmentation = keras.Sequential( + [ + keras.layers.Conv2D( + filters=projection_filters, kernel_size=1, use_bias=False + ), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + ] + ) + + # === Functional Model === + + # Project all multi-level outputs onto the same dimensionality + # and feature map shape + multi_layer_outs = [] + for index, (feature_dim, feature) in enumerate( + zip(backbone.hidden_dims, features) + ): + out = self.mlp_blocks[index](features[feature]) + out = self.resizing(out) + multi_layer_outs.append(out) + + # Concat now-equal feature maps + concatenated_outs = self.concat(multi_layer_outs[::-1]) + + # Fuse concatenated features into a segmentation map + seg = self.segmentation(concatenated_outs) + + super().__init__( + inputs=inputs, + outputs=seg, + **kwargs, + ) + + self.projection_filters = projection_filters + self.backbone = backbone + + def get_config(self): + config = super().get_config() + config.update( + { + "projection_filters": self.projection_filters, + "backbone": keras.saving.serialize_keras_object(self.backbone), + } + ) + return config diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter.py b/keras_hub/src/models/segformer/segformer_image_segmenter.py new file mode 100644 index 000000000..db5760a60 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter.py @@ -0,0 +1,148 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone + + +@keras_hub_export( + [ + "keras_hub.models.SegFormerImageSegmenter", + "keras_hub.models.segmentation.SegFormerImageSegmenter", + ] +) +class SegFormerImageSegmenter(ImageSegmenter): + """A Keras model implementing the SegFormer architecture for semantic + segmentation. + + References: + - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) # noqa: E501 + - [Based on the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer) # noqa: E501 + + Args: + backbone: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the SegFormer encoder. + It is *intended* to be used only with the MiT backbone model which + was created specifically for SegFormers. It should either be a + `keras_cv.models.backbones.backbone.Backbone` or a `tf.keras.Model` + that implements the `pyramid_level_inputs` property with keys + "P2", "P3", "P4", and "P5" and layer names as + values. + num_classes: int, the number of classes for the detection model, + including the background class. + projection_filters: int, number of filters in the + convolution layer projecting the concatenated features into + a segmentation map. Defaults to 256`. + + Example: + + Using the class with a `backbone`: + + ```python + import tensorflow as tf + import keras_cv + + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet") + model = keras_cv.models.segmentation.SegFormer( + num_classes=1, backbone=backbone, + ) + + # Evaluate model + model(images) + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=False), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + + backbone_cls = SegFormerBackbone + + def __init__( + self, + backbone, + num_classes, + projection_filters=256, + **kwargs, + ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"backbone={backbone} (of type {type(backbone)})." + ) + + inputs = self.backbone.input + + # === Layers === + self.backbone = backbone + self.dropout = keras.layers.Dropout(0.1) + self.output_segmentation = keras.layers.Conv2D( + filters=num_classes, kernel_size=1, activation="softmax" + ) + self.resizing = keras.layers.Resizing( + height=inputs.shape[1], + width=inputs.shape[2], + interpolation="bilinear", + ) + + # === Functional Model === + x = self.backbone(inputs) + x = self.dropout(x) + x = self.output_segmentation(x) + output = self.resizing(x) + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.projection_filters = projection_filters + self.backbone = backbone + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "projection_filters": self.projection_filters, + "backbone": keras.saving.serialize_keras_object(self.backbone), + } + ) + return config diff --git a/keras_hub/src/models/segformer/segformer_presets.py b/keras_hub/src/models/segformer/segformer_presets.py new file mode 100644 index 000000000..12ac38160 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_presets.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SegFormer model preset configurations.""" + +presets_no_weights = { + "segformer_b0": { + "metadata": { + "description": ("SegFormer model with MiTB0 backbone."), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0/2", + }, + "segformer_b1": { + "metadata": { + "description": ("SegFormer model with MiTB1 backbone."), + "params": 13682643, + "official_name": "SegFormerB1", + "path": "segformer_b1", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1/2", + }, + "segformer_b2": { + "metadata": { + "description": ("SegFormer model with MiTB2 backbone."), + "params": 24727507, + "official_name": "SegFormerB2", + "path": "segformer_b2", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2/2", + }, + "segformer_b3": { + "metadata": { + "description": ("SegFormer model with MiTB3 backbone."), + "params": 44603347, + "official_name": "SegFormerB3", + "path": "segformer_b3", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3/2", + }, + "segformer_b4": { + "metadata": { + "description": ("SegFormer model with MiTB4 backbone."), + "params": 61373907, + "official_name": "SegFormerB4", + "path": "segformer_b4", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4/2", + }, + "segformer_b5": { + "metadata": { + "description": ("SegFormer model with MiTB5 backbone."), + "params": 81974227, + "official_name": "SegFormerB5", + "path": "segformer_b5", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5/2", + }, +} + +presets_with_weights = { + "segformer_b0_imagenet": { + "metadata": { + "description": ( + "SegFormer model with a pretrained MiTB0 backbone." + ), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_imagenet/2", # noqa: E501 + }, +} + +presets = { + **presets_no_weights, + **presets_with_weights, +} diff --git a/keras_hub/src/models/segformer/segformer_tests.py b/keras_hub/src/models/segformer/segformer_tests.py new file mode 100644 index 000000000..169fec6c7 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_tests.py @@ -0,0 +1,149 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import tensorflow as tf +from keras_cv.src.models import MiTBackbone +from keras_cv.src.models import SegFormer +from keras_cv.src.tests.test_case import TestCase + + +class SegFormerTest(TestCase): + def test_segformer_construction(self): + backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) + model = SegFormer(backbone=backbone, num_classes=2) + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + ) + + def test_segformer_preset_construction(self): + model = SegFormer.from_preset( + "segformer_b0", num_classes=2, input_shape=[512, 512, 3] + ) + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + ) + + def test_segformer_preset_error(self): + with self.assertRaises(TypeError): + _ = SegFormer.from_preset("segformer_b0") + + @pytest.mark.large + def DISABLED_test_segformer_call(self): + # TODO: Test of output comparison Fails + backbone = MiTBackbone.from_preset("mit_b0") + mit_model = SegFormer(backbone=backbone, num_classes=2) + + images = np.random.uniform(size=(2, 224, 224, 3)) + mit_output = mit_model(images) + mit_pred = mit_model.predict(images) + + seg_model = SegFormer.from_preset("segformer_b0", num_classes=2) + seg_output = seg_model(images) + seg_pred = seg_model.predict(images) + + self.assertAllClose(mit_output, seg_output) + self.assertAllClose(mit_pred, seg_pred) + + @pytest.mark.large + def test_weights_change(self): + target_size = [512, 512, 2] + + images = tf.ones(shape=[1] + [512, 512, 3]) + labels = tf.zeros(shape=[1] + target_size) + ds = tf.data.Dataset.from_tensor_slices((images, labels)) + ds = ds.repeat(2) + ds = ds.batch(2) + + backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) + model = SegFormer(backbone=backbone, num_classes=2) + + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + ) + + original_weights = model.get_weights() + model.fit(ds, epochs=1) + updated_weights = model.get_weights() + + for w1, w2 in zip(original_weights, updated_weights): + self.assertNotAllEqual(w1, w2) + self.assertFalse(ops.any(ops.isnan(w2))) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + target_size = [512, 512, 3] + + backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) + model = SegFormer(backbone=backbone, num_classes=2) + + input_batch = np.ones(shape=[2] + target_size) + model_output = model(input_batch) + + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, SegFormer) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_preset_saved_model(self): + target_size = [224, 224, 3] + + model = SegFormer.from_preset("segformer_b0", num_classes=2) + + input_batch = np.ones(shape=[2] + target_size) + model_output = model(input_batch) + + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, SegFormer) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output)