diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 21a96a1e2..ab0e83857 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -41,6 +41,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.differential_binarization.differential_binarization_image_converter import ( + DifferentialBinarizationImageConverter, +) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( EfficientNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 70585cec1..0f881b63a 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -104,6 +104,15 @@ from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( DenseNetImageClassifierPreprocessor, ) +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_ocr import ( + DifferentialBinarizationOCR, +) +from keras_hub.src.models.differential_binarization.differential_binarization_preprocessor import ( + DifferentialBinarizationPreprocessor, +) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_hub/src/models/differential_binarization/__init__.py b/keras_hub/src/models/differential_binarization/__init__.py new file mode 100644 index 000000000..1bcdbaed5 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/__init__.py @@ -0,0 +1,9 @@ +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_presets import ( + backbone_presets, +) +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DifferentialBinarizationBackbone) diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_backbone.py b/keras_hub/src/models/differential_binarization/differential_binarization_backbone.py new file mode 100644 index 000000000..ac9e39ac8 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_backbone.py @@ -0,0 +1,219 @@ +import keras +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + + +@keras_hub_export("keras_hub.models.DifferentialBinarizationBackbone") +class DifferentialBinarizationBackbone(Backbone): + """ + A Keras model implementing the Differential Binarization + architecture for scene text detection, described in + [Real-time Scene Text Detection with Differentiable Binarization]( + https://arxiv.org/abs/1911.08947). + + This class contains the backbone architecture containing the feature + pyramid network and model heads. + + Args: + image_encoder: A `keras_hub.models.ResNetBackbone` instance. + fpn_channels: int. The number of channels to output by the feature + pyramid network. Defaults to 256. + head_kernel_list: list of ints. The kernel sizes of probability map and + threshold map heads. Defaults to [3, 2, 2]. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + """ + + def __init__( + self, + image_encoder, + fpn_channels=256, + head_kernel_list=[3, 2, 2], + dtype=None, + **kwargs, + ): + # === Functional Model === + inputs = image_encoder.input + x = image_encoder.pyramid_outputs + x = diffbin_fpn_model(x, out_channels=fpn_channels, dtype=dtype) + + probability_maps = diffbin_head( + x, + in_channels=fpn_channels, + kernel_list=head_kernel_list, + name="head_prob", + ) + threshold_maps = diffbin_head( + x, + in_channels=fpn_channels, + kernel_list=head_kernel_list, + name="head_thresh", + ) + + outputs = { + "probability_maps": probability_maps, + "threshold_maps": threshold_maps, + } + + super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) + + # === Config === + self.image_encoder = image_encoder + self.fpn_channels = fpn_channels + self.head_kernel_list = head_kernel_list + + def get_config(self): + config = super().get_config() + config["fpn_channels"] = self.fpn_channels + config["head_kernel_list"] = self.head_kernel_list + config["image_encoder"] = keras.layers.serialize(self.image_encoder) + return config + + @classmethod + def from_config(cls, config): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return cls(**config) + + +def diffbin_fpn_model(inputs, out_channels, dtype=None): + # lateral layers composing the FPN's bottom-up pathway using + # pointwise convolutions of ResNet's pyramid outputs + lateral_p2 = layers.Conv2D( + out_channels, + kernel_size=1, + use_bias=False, + name="neck_lateral_p2", + dtype=dtype, + )(inputs["P2"]) + lateral_p3 = layers.Conv2D( + out_channels, + kernel_size=1, + use_bias=False, + name="neck_lateral_p3", + dtype=dtype, + )(inputs["P3"]) + lateral_p4 = layers.Conv2D( + out_channels, + kernel_size=1, + use_bias=False, + name="neck_lateral_p4", + dtype=dtype, + )(inputs["P4"]) + lateral_p5 = layers.Conv2D( + out_channels, + kernel_size=1, + use_bias=False, + name="neck_lateral_p5", + dtype=dtype, + )(inputs["P5"]) + # top-down fusion pathway consisting of upsampling layers with + # skip connections + topdown_p5 = lateral_p5 + topdown_p4 = layers.Add(name="neck_topdown_p4")( + [ + layers.UpSampling2D(dtype=dtype)(topdown_p5), + lateral_p4, + ] + ) + topdown_p3 = layers.Add(name="neck_topdown_p3")( + [ + layers.UpSampling2D(dtype=dtype)(topdown_p4), + lateral_p3, + ] + ) + topdown_p2 = layers.Add(name="neck_topdown_p2")( + [ + layers.UpSampling2D(dtype=dtype)(topdown_p3), + lateral_p2, + ] + ) + # construct merged feature maps for each pyramid level + featuremap_p5 = layers.Conv2D( + out_channels // 4, + kernel_size=3, + padding="same", + use_bias=False, + name="neck_featuremap_p5", + dtype=dtype, + )(topdown_p5) + featuremap_p4 = layers.Conv2D( + out_channels // 4, + kernel_size=3, + padding="same", + use_bias=False, + name="neck_featuremap_p4", + dtype=dtype, + )(topdown_p4) + featuremap_p3 = layers.Conv2D( + out_channels // 4, + kernel_size=3, + padding="same", + use_bias=False, + name="neck_featuremap_p3", + dtype=dtype, + )(topdown_p3) + featuremap_p2 = layers.Conv2D( + out_channels // 4, + kernel_size=3, + padding="same", + use_bias=False, + name="neck_featuremap_p2", + dtype=dtype, + )(topdown_p2) + featuremap_p5 = layers.UpSampling2D((8, 8), dtype=dtype)(featuremap_p5) + featuremap_p4 = layers.UpSampling2D((4, 4), dtype=dtype)(featuremap_p4) + featuremap_p3 = layers.UpSampling2D((2, 2), dtype=dtype)(featuremap_p3) + featuremap = layers.Concatenate(axis=-1, dtype=dtype)( + [featuremap_p5, featuremap_p4, featuremap_p3, featuremap_p2] + ) + return featuremap + + +def diffbin_head(inputs, in_channels, kernel_list, name): + x = layers.Conv2D( + in_channels // 4, + kernel_size=kernel_list[0], + padding="same", + use_bias=False, + name=f"{name}_conv0_weights", + )(inputs) + x = layers.BatchNormalization( + beta_initializer=keras.initializers.Constant(1e-4), + gamma_initializer=keras.initializers.Constant(1.0), + name=f"{name}_conv0_bn", + )(x) + x = layers.ReLU(name=f"{name}_conv0_relu")(x) + x = layers.Conv2DTranspose( + in_channels // 4, + kernel_size=kernel_list[1], + strides=2, + padding="valid", + bias_initializer=keras.initializers.RandomUniform( + minval=-1.0 / (in_channels // 4 * 1.0) ** 0.5, + maxval=1.0 / (in_channels // 4 * 1.0) ** 0.5, + ), + name=f"{name}_conv1_weights", + )(x) + x = layers.BatchNormalization( + beta_initializer=keras.initializers.Constant(1e-4), + gamma_initializer=keras.initializers.Constant(1.0), + name=f"{name}_conv1_bn", + )(x) + x = layers.ReLU(name=f"{name}_conv1_relu")(x) + x = layers.Conv2DTranspose( + 1, + kernel_size=kernel_list[2], + strides=2, + padding="valid", + activation="sigmoid", + bias_initializer=keras.initializers.RandomUniform( + minval=-1.0 / (in_channels // 4 * 1.0) ** 0.5, + maxval=1.0 / (in_channels // 4 * 1.0) ** 0.5, + ), + name=f"{name}_conv2_weights", + )(x) + return x diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_backbone_test.py b/keras_hub/src/models/differential_binarization/differential_binarization_backbone_test.py new file mode 100644 index 000000000..86da0aa6c --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_backbone_test.py @@ -0,0 +1,44 @@ +from keras import ops + +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_preprocessor import ( + DifferentialBinarizationPreprocessor, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DifferentialBinarizationTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 32, 32, 3)) + self.image_encoder = ResNetBackbone( + input_conv_filters=[4], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 4, 4, 4], + stackwise_num_blocks=[3, 4, 6, 3], + stackwise_num_strides=[1, 2, 2, 2], + block_type="bottleneck_block", + image_shape=(32, 32, 3), + ) + self.preprocessor = DifferentialBinarizationPreprocessor() + self.init_kwargs = { + "image_encoder": self.image_encoder, + "fpn_channels": 16, + "head_kernel_list": [3, 2, 2], + } + + def test_backbone_basics(self): + expected_output_shape = { + "probability_maps": (2, 32, 32, 1), + "threshold_maps": (2, 32, 32, 1), + } + self.run_backbone_test( + cls=DifferentialBinarizationBackbone, + init_kwargs=self.init_kwargs, + input_data=self.images, + expected_output_shape=expected_output_shape, + run_mixed_precision_check=False, + run_quantization_check=False, + ) diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_image_converter.py b/keras_hub/src/models/differential_binarization/differential_binarization_image_converter.py new file mode 100644 index 000000000..b0ff8adbc --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_image_converter.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) + + +@keras_hub_export("keras_hub.layers.DifferentialBinarizationImageConverter") +class DifferentialBinarizationImageConverter(ImageConverter): + backbone_cls = DifferentialBinarizationBackbone diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_ocr.py b/keras_hub/src/models/differential_binarization/differential_binarization_ocr.py new file mode 100644 index 000000000..7239633d3 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_ocr.py @@ -0,0 +1,113 @@ +import keras +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_preprocessor import ( + DifferentialBinarizationPreprocessor, +) +from keras_hub.src.models.differential_binarization.losses import DBLoss +from keras_hub.src.models.image_segmenter import ImageSegmenter + + +@keras_hub_export("keras_hub.models.DifferentialBinarizationOCR") +class DifferentialBinarizationOCR(ImageSegmenter): + """ + A Keras model implementing the Differential Binarization + architecture for scene text detection, described in + [Real-time Scene Text Detection with Differentiable Binarization]( + https://arxiv.org/abs/1911.08947). + + Args: + backbone: A `keras_hub.models.DifferentialBinarizationBackbone` + instance. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + image_encoder = keras_hub.models.ResNetBackbone.from_preset( + "resnet_vd_50_imagenet" + ) + backbone = keras_hub.models.DifferentialBinarizationBackbone(image_encoder) + detector = keras_hub.models.DifferentialBinarizationOCR( + backbone=backbone + ) + + detector(input_data) + ``` + """ + + backbone_cls = DifferentialBinarizationBackbone + preprocessor_cls = DifferentialBinarizationPreprocessor + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + + # === Functional Model === + inputs = backbone.input + x = backbone(inputs) + probability_maps = x["probability_maps"] + threshold_maps = x["threshold_maps"] + binary_maps = step_function(probability_maps, threshold_maps) + outputs = layers.Concatenate(axis=-1)( + [probability_maps, threshold_maps, binary_maps] + ) + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + # === Config === + self.backbone = backbone + self.preprocessor = preprocessor + + def compile( + self, + optimizer="auto", + loss="auto", + **kwargs, + ): + """Configures the `DifferentialBinarizationOCR` task for training. + + `DifferentialBinarizationOCR` extends the default compilation signature + of `keras.Model.compile` with defaults for `optimizer` and `loss`. To + override these defaults, pass any value to these arguments during + compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default + optimizer for `DifferentialBinarizationOCR`. See + `keras.Model.compile` and `keras.optimizers` for more info on + possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, in which case the default loss + computation of `DifferentialBinarizationOCR` will be applied. + See `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.SGD( + learning_rate=0.007, weight_decay=0.0001, momentum=0.9 + ) + if loss == "auto": + loss = DBLoss() + super().compile( + optimizer=optimizer, + loss=loss, + **kwargs, + ) + + +def step_function(x, y, k=50.0): + return 1.0 / (1.0 + keras.ops.exp(-k * (x - y))) diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_ocr_test.py b/keras_hub/src/models/differential_binarization/differential_binarization_ocr_test.py new file mode 100644 index 000000000..175a26a49 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_ocr_test.py @@ -0,0 +1,72 @@ +import pytest +from keras import ops + +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_ocr import ( + DifferentialBinarizationOCR, +) +from keras_hub.src.models.differential_binarization.differential_binarization_preprocessor import ( + DifferentialBinarizationPreprocessor, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class DifferentialBinarizationOCRTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 32, 32, 3)) + self.labels = ops.concatenate( + (ops.zeros((2, 16, 32, 4)), ops.ones((2, 16, 32, 4))), axis=1 + ) + image_encoder = ResNetBackbone( + input_conv_filters=[4], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 4, 4, 4], + stackwise_num_blocks=[3, 4, 6, 3], + stackwise_num_strides=[1, 2, 2, 2], + block_type="bottleneck_block", + image_shape=(32, 32, 3), + ) + self.backbone = DifferentialBinarizationBackbone( + image_encoder=image_encoder + ) + self.preprocessor = DifferentialBinarizationPreprocessor() + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + } + self.train_data = (self.images, self.labels) + + def test_basics(self): + self.run_task_test( + cls=DifferentialBinarizationOCR, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 32, 32, 3), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DifferentialBinarizationOCR, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + + def test_end_to_end_model_predict(self): + model = DifferentialBinarizationOCR(**self.init_kwargs) + outputs = model.predict(self.images) + self.assertAllEqual(outputs.shape, (2, 32, 32, 3)) + + @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") + @pytest.mark.extra_large + def test_all_presets(self): + for preset in DifferentialBinarizationOCR.presets: + self.run_preset_test( + cls=DifferentialBinarizationOCR, + preset=preset, + input_data=self.images, + expected_output_shape=(2, 32, 32, 3), + ) diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_preprocessor.py b/keras_hub/src/models/differential_binarization/differential_binarization_preprocessor.py new file mode 100644 index 000000000..55ae05cf0 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_preprocessor.py @@ -0,0 +1,16 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import ( + DifferentialBinarizationBackbone, +) +from keras_hub.src.models.differential_binarization.differential_binarization_image_converter import ( + DifferentialBinarizationImageConverter, +) +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) + + +@keras_hub_export("keras_hub.models.DifferentialBinarizationPreprocessor") +class DifferentialBinarizationPreprocessor(ImageSegmenterPreprocessor): + backbone_cls = DifferentialBinarizationBackbone + image_converter_cls = DifferentialBinarizationImageConverter diff --git a/keras_hub/src/models/differential_binarization/differential_binarization_presets.py b/keras_hub/src/models/differential_binarization/differential_binarization_presets.py new file mode 100644 index 000000000..4548f4cc6 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/differential_binarization_presets.py @@ -0,0 +1,17 @@ +"""Differential Binarization preset configurations.""" + +backbone_presets = { + "diffbin_r50vd_icdar2015": { + "metadata": { + "description": ( + "Differential Binarization using 50-layer" + "ResNetVD trained on the ICDAR2015 dataset." + ), + "params": 25482722, + "official_name": "DifferentialBinarization", + "path": "differential_binarization", + "model_card": "https://arxiv.org/abs/1911.08947", + }, + "kaggle_handle": "", # TODO + } +} diff --git a/keras_hub/src/models/differential_binarization/losses.py b/keras_hub/src/models/differential_binarization/losses.py new file mode 100644 index 000000000..7e9a28c1f --- /dev/null +++ b/keras_hub/src/models/differential_binarization/losses.py @@ -0,0 +1,126 @@ +import keras +from keras import ops + + +class DiceLoss: + def __init__(self, eps=1e-6, **kwargs): + self.eps = eps + + def __call__(self, y_true, y_pred, mask, weights=None): + if weights is not None: + mask = weights * mask + intersection = ops.sum((y_pred * y_true * mask)) + union = ops.sum((y_pred * mask)) + ops.sum(y_true * mask) + self.eps + loss = 1 - 2.0 * intersection / union + return loss + + +class MaskL1Loss: + def __init__(self, **kwargs): + pass + + def __call__(self, y_true, y_pred, mask): + mask_sum = ops.sum(mask) + loss = ops.where( + mask_sum == 0.0, + 0.0, + ops.sum(ops.absolute(y_pred - y_true) * mask) / mask_sum, + ) + return loss + + +class BalanceCrossEntropyLoss: + def __init__(self, negative_ratio=3.0, eps=1e-6, **kwargs): + self.negative_ratio = negative_ratio + self.eps = eps + + def __call__(self, y_true, y_pred, mask, return_origin=False): + positive = ops.cast((y_true > 0.5) & ops.cast(mask, "bool"), "uint8") + negative = ops.cast((y_true < 0.5) & ops.cast(mask, "bool"), "uint8") + positive_count = ops.sum(ops.cast(positive, "int32")) + negative_count = ops.sum(ops.cast(negative, "int32")) + negative_count_max = ops.cast( + ops.cast(positive_count, "float32") * self.negative_ratio, "int32" + ) + + negative_count = ops.where( + negative_count > negative_count_max, + negative_count_max, + negative_count, + ) + # Keras' losses reduce some axis. Since we don't want that here, we add + # a dummy dimension to y_true and y_pred + loss = keras.losses.BinaryCrossentropy( + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction=None, + )(y_true=y_true[..., None], y_pred=y_pred[..., None]) + + positive_loss = loss * ops.cast(positive, "float32") + negative_loss = loss * ops.cast(negative, "float32") + + # hard negative mining, as suggested in + # [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947): + # Compute the threshold for hard negatives, and zero-out + # negative losses below the threshold. using this approach, + # we achieve efficient computation on GPUs + + # compute negative_count relative to the element count of y_pred + negative_count_rel = ops.cast(negative_count, "float32") / ops.prod( + ops.cast(ops.shape(y_pred), "float32") + ) + # compute the threshold value for negative losses and zero neg. loss + # values below this threshold + negative_loss_thresh = ops.quantile( + negative_loss, 1.0 - negative_count_rel + ) + negative_loss = negative_loss * ops.cast( + negative_loss > negative_loss_thresh, "float32" + ) + + balance_loss = (ops.sum(positive_loss) + ops.sum(negative_loss)) / ( + ops.cast(positive_count + negative_count, "float32") + self.eps + ) + + if return_origin: + return balance_loss, loss + return balance_loss + + +class DBLoss(keras.losses.Loss): + def __init__(self, eps=1e-6, l1_scale=10.0, bce_scale=5.0, **kwargs): + super().__init__(*kwargs) + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss() + self.bce_loss = BalanceCrossEntropyLoss() + + self.l1_scale = l1_scale + self.bce_scale = bce_scale + + def call(self, y_true, y_pred): + p_map_pred, t_map_pred, b_map_pred = ops.unstack(y_pred, 3, axis=-1) + shrink_map, shrink_mask, thresh_map, thresh_mask = ops.unstack( + y_true, 4, axis=-1 + ) + + # we here implement L1BalanceCELoss from PyTorch's + # Differential Binarization implementation + Ls = self.bce_loss( + y_true=shrink_map, + y_pred=p_map_pred, + mask=shrink_mask, + return_origin=False, + ) + Lt = self.l1_loss( + y_true=thresh_map, + y_pred=t_map_pred, + mask=thresh_mask, + ) + dice_loss = self.dice_loss( + y_true=shrink_map, + y_pred=b_map_pred, + mask=shrink_mask, + ) + loss = dice_loss + self.l1_scale * Lt + Ls * self.bce_scale + return loss diff --git a/keras_hub/src/models/differential_binarization/losses_test.py b/keras_hub/src/models/differential_binarization/losses_test.py new file mode 100644 index 000000000..aae390095 --- /dev/null +++ b/keras_hub/src/models/differential_binarization/losses_test.py @@ -0,0 +1,71 @@ +import numpy as np + +from keras_hub.src.models.differential_binarization.losses import DBLoss +from keras_hub.src.models.differential_binarization.losses import DiceLoss +from keras_hub.src.models.differential_binarization.losses import MaskL1Loss +from keras_hub.src.tests.test_case import TestCase + + +class DiceLossTest(TestCase): + def setUp(self): + self.loss_obj = DiceLoss() + + def test_loss(self): + y_true = np.array([1.0, 1.0, 0.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + mask = np.array([0.0, 1.0, 1.0, 0.0]) + weights = np.array([4.0, 5.0, 6.0, 7.0]) + loss = self.loss_obj(y_true, y_pred, mask, weights) + self.assertAlmostEqual(loss, 0.74358, delta=1e-4) + + def test_correct(self): + y_true = np.array([1.0, 1.0, 0.0, 0.0]) + y_pred = y_true + mask = np.array([0.0, 1.0, 1.0, 0.0]) + loss = self.loss_obj(y_true, y_pred, mask) + self.assertAlmostEqual(loss, 0.0, delta=1e-4) + + +class MaskL1LossTest(TestCase): + def setUp(self): + self.loss_obj = MaskL1Loss() + + def test_masked(self): + y_true = np.array([1.0, 2.0, 3.0, 4.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + mask = np.array([0.0, 1.0, 0.0, 1.0]) + loss = self.loss_obj(y_true, y_pred, mask) + self.assertAlmostEqual(loss, 2.7, delta=1e-4) + + +class DBLossTest(TestCase): + def setUp(self): + self.loss_obj = DBLoss() + + def test_loss(self): + shrink_map = thresh_map = np.array( + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]] + ) + p_map_pred = b_map_pred = t_map_pred = np.array( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + ) + shrink_mask = thresh_mask = np.ones_like(p_map_pred) + y_true = np.stack( + (shrink_map, shrink_mask, thresh_map, thresh_mask), axis=-1 + ) + y_pred = np.stack((p_map_pred, t_map_pred, b_map_pred), axis=-1) + loss = self.loss_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 14.1123, delta=1e-4) + + def test_correct(self): + shrink_map = thresh_map = np.array( + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]] + ) + p_map_pred, b_map_pred, t_map_pred = shrink_map, shrink_map, thresh_map + shrink_mask = thresh_mask = np.ones_like(p_map_pred) + y_true = np.stack( + (shrink_map, shrink_mask, thresh_map, thresh_mask), axis=-1 + ) + y_pred = np.stack((p_map_pred, t_map_pred, b_map_pred), axis=-1) + loss = self.loss_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.0, delta=1e-4)