Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Semantic Segmentation] - SegFormer (and MiTs) #1883

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions keras_hub/src/models/image_segmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.ImageSegmenter")
class ImageSegmenter(Task):
"""Base class for all image segmentation tasks.

`ImageSegmenter` tasks wrap a `keras_hub.models.Task` and
a `keras_hub.models.Preprocessor` to create a model that can be used for
image segmentation.

All `ImageSegmenter` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
`ImageSegmenter` tasks take an additional
`num_classes` argument, the number of segmentation classes.

To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
labels where `x` is a image and `y` is a label from `[0, num_classes)`.

Example:
```python
model = keras_hub.models.ImageSegmenter.from_preset(
"deeplab_resnet",
num_classes=2,
)
images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))

output = model(images)
pred_labels = output[0]

model.fit(images, labels, epochs=3)
```
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `ImageSegmenter` task for training.

The `ImageSegmenter` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.

Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the classification task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.softmax
loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
if metrics == "auto":
metrics = [keras.metrics.CategoricalAccuracy()]
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
21 changes: 21 additions & 0 deletions keras_hub/src/models/segformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
from keras_hub.src.models.segformer.segformer_image_segmenter import (
SegFormerImageSegmenter,
)
from keras_hub.src.models.segformer.segformer_presets import presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(presets, SegFormerImageSegmenter)
173 changes: 173 additions & 0 deletions keras_hub/src/models/segformer/segformer_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_hub.src.models.segformer.segformer_presets import presets


@keras_hub_export(
[
"keras_hub.models.SegFormerBackbone",
"keras_hub.models.segmentation.SegFormerBackbone",
]
)
class SegFormerBackbone(Backbone):
"""A Keras model implementing the SegFormer architecture for semantic
segmentation.

References:
- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) # noqa: E501
- [Based on the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer) # noqa: E501

Args:
backbone: `keras.Model`. The backbone network for the model that is
used as a feature extractor for the SegFormer encoder.
It is *intended* to be used only with the MiT backbone model which
was created specifically for SegFormers. It should either be a
`keras_cv.models.backbones.backbone.Backbone` or a `tf.keras.Model`
that implements the `pyramid_level_inputs` property with keys
"P2", "P3", "P4", and "P5" and layer names as
values.
num_classes: int, the number of classes for the detection model,
including the background class.
projection_filters: int, number of filters in the
convolution layer projecting the concatenated features into
a segmentation map. Defaults to 256`.

Example:

Using the class with a custom `backbone`:

```python
import tensorflow as tf
import keras_hub

backbone = keras_hub.models.MiTBackbone(
depths=[2, 2, 2, 2],
image_shape=(16, 16, 3),
hidden_dims=[4, 8],
num_layers=2,
blockwise_num_heads=[1, 2],
blockwise_sr_ratios=[8, 4],
end_value=0.1,
patch_sizes=[7, 3],
strides=[4, 2],
)

model = SegFormerBackbone(backbone=backbone, num_classes=4)
```
"""

backbone_cls = MiTBackbone

def __init__(
self,
backbone,
projection_filters=256,
**kwargs,
):
if not isinstance(backbone, keras.layers.Layer) or not isinstance(
backbone, keras.Model
):
raise ValueError(
"Argument `backbone` must be a `keras.layers.Layer` instance "
f" or `keras.Model`. Received instead "
f"backbone={backbone} (of type {type(backbone)})."
)

self.feature_extractor = keras.Model(
backbone.outputs, backbone.pyramid_outputs
)

inputs = backbone.output
features = self.feature_extractor(inputs)
# Get H and W of level one output
_, H, W, _ = features["P1"].shape

# === Layers ===

self.mlp_blocks = []

for feature_dim, feature in zip(backbone.hidden_dims, features):
self.mlp_blocks.append(
keras.layers.Dense(
projection_filters, name=f"linear_{feature_dim}"
)
)

self.resizing = keras.layers.Resizing(H, W, interpolation="bilinear")
self.concat = keras.layers.Concatenate(axis=3)
self.segmentation = keras.Sequential(
[
keras.layers.Conv2D(
filters=projection_filters, kernel_size=1, use_bias=False
),
keras.layers.BatchNormalization(),
keras.layers.Activation("relu"),
]
)

# === Functional Model ===

# Project all multi-level outputs onto the same dimensionality
# and feature map shape
multi_layer_outs = []
for index, (feature_dim, feature) in enumerate(
zip(backbone.hidden_dims, features)
):
out = self.mlp_blocks[index](features[feature])
out = self.resizing(out)
multi_layer_outs.append(out)

# Concat now-equal feature maps
concatenated_outs = self.concat(multi_layer_outs[::-1])

# Fuse concatenated features into a segmentation map
seg = self.segmentation(concatenated_outs)

super().__init__(
inputs=inputs,
outputs=seg,
**kwargs,
)

self.projection_filters = projection_filters
self.backbone = backbone

def get_config(self):
config = super().get_config()
config.update(
{
"projection_filters": self.projection_filters,
"backbone": keras.saving.serialize_keras_object(self.backbone),
}
)
return config
Loading