From a089a8b7559be3e5d5a49b04df54997a9803cc4d Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 7 Aug 2024 17:57:31 -0700 Subject: [PATCH 01/19] Add VGG16 backbone (#1737) * Agg Vgg16 backbone * update names * update tests * update test * add image classifier * incorporate review comments * Update test case * update backbone test * add image classifier * classifier cleanup * code reformat * add vgg16 image classifier * make vgg generic * update doc string * update docstring * add classifier test * update tests * update docstring * address review comments * code reformat * update the configs * address review comments * fix task saved model test * update init * code reformatted --- keras_nlp/api/models/__init__.py | 3 + keras_nlp/src/models/image_classifier.py | 90 ++++++++++ keras_nlp/src/models/vgg/__init__.py | 13 ++ keras_nlp/src/models/vgg/vgg_backbone.py | 159 ++++++++++++++++++ keras_nlp/src/models/vgg/vgg_backbone_test.py | 48 ++++++ .../src/models/vgg/vgg_image_classifier.py | 124 ++++++++++++++ .../models/vgg/vgg_image_classifier_test.py | 61 +++++++ keras_nlp/src/tests/test_case.py | 30 ++-- 8 files changed, 514 insertions(+), 14 deletions(-) create mode 100644 keras_nlp/src/models/image_classifier.py create mode 100644 keras_nlp/src/models/vgg/__init__.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone_test.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 4fb3b3cf00..41f1a47284 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -129,6 +129,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -194,6 +195,8 @@ from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer from keras_nlp.src.models.task import Task +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py new file mode 100644 index 0000000000..f0cc031dbc --- /dev/null +++ b/keras_nlp/src/models/image_classifier.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import keras
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.task import Task
+
+
+@keras_nlp_export("keras_nlp.models.ImageClassifier")
+class ImageClassifier(Task):
+    """Base class for all image classification tasks.
+
+    `ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and
+    a `keras_nlp.models.Preprocessor` to create a model that can be used for
+    image classification. `ImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is an image and `y` is an integer from `[0, num_classes)`.
+
+    All `ImageClassifier` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `ImageClassifier` task for training.
+
+        The `ImageClassifier` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.SparseCategoricalCrossentropy` loss will be
+                applied for the classification task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.SparseCategoricalAccuracy` will be
+                applied to track the accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.SparseCategoricalAccuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
diff --git a/keras_nlp/src/models/vgg/__init__.py b/keras_nlp/src/models/vgg/__init__.py
new file mode 100644
index 0000000000..3364a6bd16
--- /dev/null
+++ b/keras_nlp/src/models/vgg/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py
new file mode 100644
index 0000000000..497381c0fc
--- /dev/null
+++ b/keras_nlp/src/models/vgg/vgg_backbone.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+from keras import layers
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.backbone import Backbone
+
+
+@keras_nlp_export("keras_nlp.models.VGGBackbone")
+class VGGBackbone(Backbone):
+    """
+    This class represents a Keras Backbone of the VGG model.
+
+    This class implements a VGG backbone as described in [Very Deep
+    Convolutional Networks for Large-Scale Image Recognition](
+    https://arxiv.org/abs/1409.1556)(ICLR 2015).
+
+    Args:
+        stackwise_num_repeats: list of ints, number of repeated convolutional
+            blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for
+            VGG19 this is [2, 2, 4, 4, 4].
+        stackwise_num_filters: list of ints, filter size for convolutional
+            blocks per VGG block. For both VGG16 and VGG19 this is [
+            64, 128, 256, 512, 512].
+        include_rescaling: bool, whether to rescale the inputs. If set to
+            True, inputs will be passed through a `Rescaling(1/255.0)` layer.
+        input_image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
+        pooling: str, optional pooling mode for feature extraction, one of
+            `None`, `"avg"` or `"max"`.
+            - `None` means that the output of the model will be
+                the 4D tensor output of the
+                last convolutional block.
+            - `avg` means that global average pooling
+                will be applied to the output of the
+                last convolutional block, and thus
+                the output of the model will be a 2D tensor.
+            - `max` means that global max pooling will
+                be applied.
+
+    Examples:
+    ```python
+    input_data = np.ones((2, 224, 224, 3), dtype="float32")
+
+    # Pretrained VGG backbone.
+    model = keras_nlp.models.VGGBackbone.from_preset("vgg16")
+    model(input_data)
+
+    # Randomly initialized VGG backbone with a custom config.
+ model = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + stackwise_num_filters, + include_rescaling, + input_image_shape=(224, 224, 3), + pooling="avg", + **kwargs, + ): + + # === Functional Model === + img_input = keras.layers.Input(shape=input_image_shape) + x = img_input + + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + for stack_index in range(len(stackwise_num_repeats) - 1): + x = apply_vgg_block( + x=x, + num_layers=stackwise_num_repeats[stack_index], + filters=stackwise_num_filters[stack_index], + kernel_size=(3, 3), + activation="relu", + padding="same", + max_pool=True, + name=f"block{stack_index + 1}", + ) + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_num_filters = stackwise_num_filters + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_num_filters": self.stackwise_num_filters, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_vgg_block( + x, + num_layers, + filters, + kernel_size, + activation, + padding, + max_pool, + name, +): + """ + Applies VGG block + Args: + x: Tensor, input tensor to pass through network + num_layers: int, number of CNN layers in the block + filters: int, filter size of each CNN layer in block + kernel_size: int (or) tuple, kernel size for CNN layer in block + activation: str (or) callable, activation function for each CNN layer in + block + padding: str (or) callable, padding function for each CNN layer in block + max_pool: bool, whether to add MaxPooling2D layer at end of block + name: str, name of the block + + Returns: + keras.KerasTensor + """ + for num in range(1, num_layers + 1): + x = layers.Conv2D( + filters, + kernel_size, + activation=activation, + padding=padding, + name=f"{name}_conv{num}", + )(x) + if max_pool: + x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x) + return x diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py new file mode 100644 index 0000000000..05ed33ba0f --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class VGGBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [2, 3, 3], + "stackwise_num_filters": [8, 64, 64], + "input_image_shape": (16, 16, 3), + "include_rescaling": False, + "pooling": "avg", + } + self.input_data = np.ones((2, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 64), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py new file mode 100644 index 0000000000..a26fbfbc30 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -0,0 +1,124 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_nlp_export("keras_nlp.models.VGGImageClassifier") +class VGGImageClassifier(ImageClassifier): + """VGG16 image classifier task model. + + Args: + backbone: A `keras_nlp.models.VGGBackbone` instance. + num_classes: int, number of classes to predict. + pooling: str, type of pooling layer. Must be one of "avg", "max". + activation: Optional `str` or callable, defaults to "softmax". The + activation function to use on the Dense layer. Set `activation=None` + to return the output logits. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Examples: + Train from preset + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.VGGImageClassifier.from_preset( + 'vgg_16_image_classifier') + classifier.fit(x=images, y=labels, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + + # Access backbone programmatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. 
+ classifier.fit(x=images, y=labels, batch_size=2) + ``` + Custom backbone + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + + backbone = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + classifier = keras_nlp.models.VGGImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = VGGBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py new file mode 100644 index 0000000000..4a2573e496 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.tests.test_case import TestCase + + +class VGGImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 4, 4, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = VGGBackbone( + stackwise_num_repeats=[2, 4, 4], + stackwise_num_filters=[2, 16, 16], + input_image_shape=(4, 4, 3), + include_rescaling=False, + pooling="max", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 7e8e0cec95..fc1ce77e1e 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -419,20 +419,22 @@ def run_backbone_test( self.assertEqual(output[key].shape, expected_output_shape[key]) else: self.assertEqual(output.shape, expected_output_shape) - - # Check we can embed tokens eagerly. - output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) - - # Check variable length sequences. - if variable_length_data is None: - # If no variable length data passed, assume the second axis of all - # inputs is our sequence axis and create it ourselves. - variable_length_data = [ - tree.map_structure(lambda x: x[:, :seq_length, ...], input_data) - for seq_length in (2, 3, 4) - ] - for batch in variable_length_data: - backbone(batch) + if backbone.token_embedding is not None: + # Check we can embed tokens eagerly. + output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) + + # Check variable length sequences. + if variable_length_data is None: + # If no variable length data passed, assume the second axis of all + # inputs is our sequence axis and create it ourselves. + variable_length_data = [ + tree.map_structure( + lambda x: x[:, :seq_length, ...], input_data + ) + for seq_length in (2, 3, 4) + ] + for batch in variable_length_data: + backbone(batch) # Check compiled predict function. 
backbone.predict(input_data) From 73b7bad007a8c37a54512092c2b8bfe435d21c10 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:09:08 +0800 Subject: [PATCH 02/19] Add `ResNetBackbone` and `ResNetImageClassifier` (#1765) * Add ResNetV1 and ResNetV2 * Address comments --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/resnet/__init__.py | 13 + .../src/models/resnet/resnet_backbone.py | 544 ++++++++++++++++++ .../src/models/resnet/resnet_backbone_test.py | 75 +++ .../models/resnet/resnet_image_classifier.py | 131 +++++ .../resnet/resnet_image_classifier_test.py | 62 ++ keras_nlp/src/tests/test_case.py | 60 ++ keras_nlp/src/utils/keras_utils.py | 13 + 8 files changed, 902 insertions(+) create mode 100644 keras_nlp/src/models/resnet/__init__.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone_test.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 41f1a47284..783cfd5087 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -181,6 +181,10 @@ from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM diff --git a/keras_nlp/src/models/resnet/__init__.py b/keras_nlp/src/models/resnet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/resnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py new file mode 100644 index 0000000000..bec5ba60b5 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -0,0 +1,544 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +@keras_nlp_export("keras_nlp.models.ResNetBackbone") +class ResNetBackbone(Backbone): + """ResNet and ResNetV2 core network with hyperparameters. + + This class implements a ResNet backbone as described in [Deep Residual + Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( + CVPR 2016) and [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016). + + The difference in ResNet and ResNetV2 rests in the structure of their + individual building blocks. In ResNetV2, the batch normalization and + ReLU activation precede the convolution layers, as opposed to ResNet where + the batch normalization and ReLU activation are applied after the + convolution layers. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + stackwise_num_strides: list of ints. The number of strides for each + stack. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. + include_rescaling: boolean. If `True`, rescale the input using + `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to + `True`. + input_image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + pooling: `None` or str. Pooling mode for feature extraction. Defaults + to `"avg"`. + - `None` means that the output of the model will be the 4D tensor + from the last convolutional block. + - `avg` means that global average pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + - `max` means that global max pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained ResNet backbone. + model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") + model(input_data) + + # Randomly initialized ResNetV2 backbone with a custom config. 
+ model = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + pooling="avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + stackwise_num_strides, + block_type, + use_pre_activation=False, + include_rescaling=True, + input_image_shape=(None, None, 3), + pooling="avg", + data_format=None, + dtype=None, + **kwargs, + ): + if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( + stackwise_num_filters + ) != len(stackwise_num_strides): + raise ValueError( + "The length of `stackwise_num_filters`, `stackwise_num_blocks` " + "and `stackwise_num_strides` must be the same. Received: " + f"stackwise_num_filters={stackwise_num_filters}, " + f"stackwise_num_blocks={stackwise_num_blocks}, " + f"stackwise_num_strides={stackwise_num_strides}" + ) + if stackwise_num_filters[0] != 64: + raise ValueError( + "The first element of `stackwise_num_filters` must be 64. " + f"Received: stackwise_num_filters={stackwise_num_filters}" + ) + if block_type not in ("basic_block", "bottleneck_block"): + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + version = "v1" if not use_pre_activation else "v2" + data_format = standardize_data_format(data_format) + bn_axis = -1 if data_format == "channels_last" else 1 + num_stacks = len(stackwise_num_filters) + + # === Functional Model === + image_input = layers.Input(shape=input_image_shape) + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + else: + x = image_input + + x = layers.Conv2D( + 64, + 7, + strides=2, + padding="same", + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name="conv1_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) + + x = layers.MaxPool2D( + 3, + strides=2, + padding="same", + data_format=data_format, + dtype=dtype, + name="pool1_pool", + )(x) + + for stack_index in range(num_stacks): + x = apply_stack( + x, + filters=stackwise_num_filters[stack_index], + blocks=stackwise_num_blocks[stack_index], + stride=stackwise_num_strides[stack_index], + block_type=block_type, + use_pre_activation=use_pre_activation, + first_shortcut=( + block_type == "bottleneck_block" or stack_index > 0 + ), + data_format=data_format, + dtype=dtype, + name=f"{version}_stack{stack_index}", + ) + + if use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) + + if pooling == "avg": + feature_map_output = layers.GlobalAveragePooling2D( + data_format=data_format, dtype=dtype + )(x) + elif pooling == "max": + feature_map_output = layers.GlobalMaxPooling2D( + data_format=data_format, dtype=dtype + )(x) + else: + feature_map_output = x + + super().__init__( + inputs=image_input, + outputs=feature_map_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_strides = stackwise_num_strides + self.block_type = block_type + self.use_pre_activation = use_pre_activation + self.include_rescaling = 
include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_basic_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=stride if not use_pre_activation else 1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Add(dtype=dtype, 
name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + 1, + strides=stride if not use_pre_activation else 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + )(x) + x = layers.Add(dtype=dtype, 
name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_stack( + x, + filters, + blocks, + stride, + block_type, + use_pre_activation, + first_shortcut=True, + data_format=None, + dtype=None, + name=None, +): + """Applies a set of stacked residual blocks. + + Args: + x: Tensor. The input tensor to pass through the stack. + filters: int. The number of filters in a block. + blocks: int. The number of blocks in the stack. + stride: int. The stride length of the first layer in the first block. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet and ResNeXt. + first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `True`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the stack. + + Returns: + Output tensor for the stacked blocks. + """ + if name is None: + version = "v1" if not use_pre_activation else "v2" + name = f"{version}_stack" + + if block_type == "basic_block": + block_fn = apply_basic_block + elif block_type == "bottleneck_block": + block_fn = apply_bottleneck_block + else: + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + x = block_fn( + x, + filters, + stride=stride if not use_pre_activation else 1, + conv_shortcut=first_shortcut, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block1", + ) + for i in range(2, blocks): + x = block_fn( + x, + filters, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(i)}", + ) + x = block_fn( + x, + filters, + stride=1 if not use_pre_activation else stride, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(blocks)}", + ) + return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py new file mode 100644 index 0000000000..2113bcd131 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from absl.testing import parameterized +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "input_image_shape": (None, None, 3), + "pooling": "avg", + } + self.input_size = (16, 16) + self.input_data = ops.ones((2, 16, 16, 3)) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + def test_backbone_basics(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": block_type, "use_pre_activation": use_pre_activation} + ) + self.run_vision_backbone_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + expected_output_shape=( + (2, 64) if block_type == "basic_block" else (2, 256) + ), + ) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + @pytest.mark.large + def test_saved_model(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + "input_image_shape": (16, 16, 3), + } + ) + self.run_model_saving_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py new file mode 100644 index 0000000000..02c8c78b27 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_nlp_export("keras_nlp.models.ResNetImageClassifier") +class ResNetImageClassifier(ImageClassifier): + """ResNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.ResNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. 
+ + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + include_rescaling=False, + pooling="avg", + ) + classifier = keras_nlp.models.ResNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = ResNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=self.backbone.dtype_policy, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py new file mode 100644 index 0000000000..bbbda72d64 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 16, 16, 3)) + self.labels = [0, 3] + self.backbone = ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + input_image_shape=(16, 16, 3), + include_rescaling=False, + pooling="avg", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index fc1ce77e1e..72653c8b83 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -457,6 +457,66 @@ def run_backbone_test( if run_quantization_check and has_quantization_support(): self.run_quantization_test(backbone, cls, init_kwargs, input_data) + def run_vision_backbone_test( + self, + cls, + init_kwargs, + input_data, + expected_output_shape, + variable_length_data=None, + run_mixed_precision_check=True, + run_quantization_check=True, + run_data_format_check=True, + ): + """Run basic tests for a vision backbone, including compilation.""" + can_run_data_format_check = True + if ( + keras.config.backend() == "tensorflow" + and not tf.config.list_physical_devices("GPU") + ): + # Never test the "channels_first" format on tensorflow CPU. + # Tensorflow lacks support for "channels_first" convolution. + can_run_data_format_check = False + + ori_data_format = keras.config.image_data_format() + keras.config.set_image_data_format("channels_last") + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Check data_format. We assume that `input_data` is in "channels_last" + # format. + if run_data_format_check and can_run_data_format_check: + keras.config.set_image_data_format("channels_first") + input_data_shape = ops.shape(input_data) + if len(input_data_shape) == 3: + input_data = ops.transpose(input_data, axes=(2, 0, 1)) + elif len(input_data_shape) == 4: + input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) + if "input_image_shape" in init_kwargs: + init_kwargs = init_kwargs.copy() + init_kwargs["input_image_shape"] = tuple( + reversed(init_kwargs["input_image_shape"]) + ) + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Restore the original `image_data_format`. 
+ keras.config.set_image_data_format(ori_data_format) + def run_task_test( self, cls, diff --git a/keras_nlp/src/utils/keras_utils.py b/keras_nlp/src/utils/keras_utils.py index 0fb96ccffb..b37b74ad19 100644 --- a/keras_nlp/src/utils/keras_utils.py +++ b/keras_nlp/src/utils/keras_utils.py @@ -115,3 +115,16 @@ def assert_quantization_support(): "Quantization API requires Keras >= 3.4.0 to function " f"correctly. Received: '{keras.version()}'" ) + + +def standardize_data_format(data_format): + if data_format is None: + return keras.config.image_data_format() + data_format = str(data_format).lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "The `data_format` argument must be one of " + "{'channels_first', 'channels_last'}. " + f"Received: data_format={data_format}" + ) + return data_format From 26afc7e538927bbb8d588ab72ce50c3a6c1f89b5 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 14 Aug 2024 18:30:21 -0700 Subject: [PATCH 03/19] Add CSP DarkNet backbone and classifier (#1774) * Add CSP DarkNet * Add CSP DarkNet * snake_case function names * change use_depthwise to block_type --- keras_nlp/api/models/__init__.py | 6 + keras_nlp/src/models/csp_darknet/__init__.py | 13 + .../csp_darknet/csp_darknet_backbone.py | 410 ++++++++++++++++++ .../csp_darknet/csp_darknet_backbone_test.py | 50 +++ .../csp_darknet_image_classifier.py | 133 ++++++ .../csp_darknet_image_classifier_test.py | 65 +++ 6 files changed, 677 insertions(+) create mode 100644 keras_nlp/src/models/csp_darknet/__init__.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 783cfd5087..aca1e28538 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -50,6 +50,12 @@ from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.causal_lm import CausalLM from keras_nlp.src.models.classifier import Classifier +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone, ) diff --git a/keras_nlp/src/models/csp_darknet/__init__.py b/keras_nlp/src/models/csp_darknet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
new file mode 100644
index 0000000000..2745f61d01
--- /dev/null
+++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py
@@ -0,0 +1,410 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+from keras import layers
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.backbone import Backbone
+
+
+@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone")
+class CSPDarkNetBackbone(Backbone):
+    """This class represents a Keras Backbone of the CSPDarkNet model.
+
+    This class implements a CSPDarkNet backbone as described in
+    [CSPNet: A New Backbone that can Enhance Learning Capability of CNN](
+    https://arxiv.org/abs/1911.11929).
+
+    Args:
+        stackwise_num_filters: A list of ints, filter size for each dark
+            level in the model.
+        stackwise_depth: A list of ints, the depth for each dark level in the
+            model.
+        include_rescaling: boolean. If `True`, rescale the input using
+            `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to
+            `True`.
+        block_type: str. One of `"basic_block"` or `"depthwise_block"`.
+            Use `"depthwise_block"` for a depthwise conv block and
+            `"basic_block"` for a basic conv block.
+            Defaults to "basic_block".
+        input_image_shape: tuple. The input shape without the batch size.
+            Defaults to `(224, 224, 3)`.
+ + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.CSPDarkNetBackbone.from_preset( + "csp_darknet_tiny_imagenet" + ) + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_depth, + include_rescaling, + block_type="basic_block", + input_image_shape=(224, 224, 3), + **kwargs, + ): + # === Functional Model === + apply_ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "depthwise_block" + else apply_darknet_conv_block + ) + base_channels = stackwise_num_filters[0] // 2 + + image_input = layers.Input(shape=input_image_shape) + x = image_input + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + + x = apply_focus(name="stem_focus")(x) + x = apply_darknet_conv_block( + base_channels, kernel_size=3, strides=1, name="stem_conv" + )(x) + for index, (channels, depth) in enumerate( + zip(stackwise_num_filters, stackwise_depth) + ): + x = apply_ConvBlock( + channels, + kernel_size=3, + strides=2, + name=f"dark{index + 2}_conv", + )(x) + + if index == len(stackwise_depth) - 1: + x = apply_spatial_pyramid_pooling_bottleneck( + channels, + hidden_filters=channels // 2, + name=f"dark{index + 2}_spp", + )(x) + + x = apply_cross_stage_partial( + channels, + num_bottlenecks=depth, + block_type="basic_block", + residual=(index != len(stackwise_depth) - 1), + name=f"dark{index + 2}_csp", + )(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_depth = stackwise_depth + self.include_rescaling = include_rescaling + self.block_type = block_type + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_depth": self.stackwise_depth, + "include_rescaling": self.include_rescaling, + "block_type": self.block_type, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_focus(name=None): + """A block used in CSPDarknet to focus information into channels of the + image. + + If the dimensions of a batch input is (batch_size, width, height, channels), + this layer converts the image into size (batch_size, width/2, height/2, + 4*channels). See [the original discussion on YoloV5 Focus Layer](https://github.com/ultralytics/yolov5/discussions/3181). + + Args: + name: the name for the lambda layer used in the block. + + Returns: + a function that takes an input Tensor representing a Focus layer. + """ + + def apply(x): + return layers.Concatenate(name=name)( + [ + x[..., ::2, ::2, :], + x[..., 1::2, ::2, :], + x[..., ::2, 1::2, :], + x[..., 1::2, 1::2, :], + ], + ) + + return apply + + +def apply_darknet_conv_block( + filters, kernel_size, strides, use_bias=False, activation="silu", name=None +): + """ + The basic conv block used in Darknet. Applies Conv2D followed by a + BatchNorm. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. 
Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + use_bias: Boolean, whether the layer uses a bias vector. + activation: the activation applied after the BatchNorm layer. One of + "silu", "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.Conv2D( + filters, + kernel_size, + strides, + padding="same", + use_bias=use_bias, + name=name + "_conv", + )(inputs) + + x = layers.BatchNormalization(name=name + "_bn")(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.silu(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + return x + + return apply + + +def apply_darknet_conv_block_depthwise( + filters, kernel_size, strides, activation="silu", name=None +): + """ + The depthwise conv block used in CSPDarknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.DepthwiseConv2D( + kernel_size, strides, padding="same", use_bias=False + )(inputs) + x = layers.BatchNormalization()(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.swish(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + x = apply_darknet_conv_block( + filters, kernel_size=1, strides=1, activation=activation + )(x) + + return x + + return apply + + +def apply_spatial_pyramid_pooling_bottleneck( + filters, + hidden_filters=None, + kernel_sizes=(5, 9, 13), + activation="silu", + name=None, +): + """ + Spatial pyramid pooling layer used in YOLOv3-SPP + + Args: + filters: Integer, the dimensionality of the output spaces (i.e. the + number of output filters in used the blocks). + hidden_filters: Integer, the dimensionality of the intermediate + bottleneck space (i.e. the number of output filters in the + bottleneck convolution). If None, it will be equal to filters. + Defaults to None. + kernel_sizes: A list or tuple representing all the pool sizes used for + the pooling layers, defaults to (5, 9, 13). + activation: Activation for the conv layers, defaults to "silu". + name: the prefix for the layer names used in the block. + + Returns: + a function that takes an input Tensor representing an + SpatialPyramidPoolingBottleneck. 
+ """ + if name is None: + name = f"spp{keras.backend.get_uid('spp')}" + + if hidden_filters is None: + hidden_filters = filters + + def apply(x): + x = apply_darknet_conv_block( + hidden_filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(x) + x = [x] + + for kernel_size in kernel_sizes: + x.append( + layers.MaxPooling2D( + kernel_size, + strides=1, + padding="same", + name=f"{name}_maxpool_{kernel_size}", + )(x[0]) + ) + + x = layers.Concatenate(name=f"{name}_concat")(x) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(x) + + return x + + return apply + + +def apply_cross_stage_partial( + filters, + num_bottlenecks, + residual=True, + block_type="basic_block", + activation="silu", + name=None, +): + """A block used in Cross Stage Partial Darknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + num_bottlenecks: an integer representing the number of blocks added in + the layer bottleneck. + residual: a boolean representing whether the value tensor before the + bottleneck should be added to the output of the bottleneck as a + residual, defaults to True. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + """ + + if name is None: + name = f"cross_stage_partial_{keras.backend.get_uid('cross_stage_partial')}" + + def apply(inputs): + hidden_channels = filters // 2 + ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "basic_block" + else apply_darknet_conv_block + ) + + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(inputs) + + x2 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(inputs) + + for i in range(num_bottlenecks): + residual_x = x1 + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv1", + )(x1) + x1 = ConvBlock( + hidden_channels, + kernel_size=3, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv2", + )(x1) + if residual: + x1 = layers.Add(name=f"{name}_bottleneck_{i}_add")( + [residual_x, x1] + ) + + x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv3", + )(x) + + return x + + return apply diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py new file mode 100644 index 0000000000..aaad4fe515 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [32, 64, 128, 256], + "stackwise_depth": [1, 3, 3, 1], + "include_rescaling": False, + "block_type": "basic_block", + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 256), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py new file mode 100644 index 0000000000..6b013bdcc0 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.CSPDarkNetImageClassifier") +class CSPDarkNetImageClassifier(ImageClassifier): + """CSPDarkNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.CSPDarkNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. 
+ ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.CSPDarkNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = CSPDarkNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py new file mode 100644 index 0000000000..a07bb017a3 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -0,0 +1,65 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = CSPDarkNetBackbone( + stackwise_num_filters=[2, 16, 16], + stackwise_depth=[1, 3, 3, 1], + include_rescaling=False, + block_type="basic_block", + input_image_shape=(16, 16, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 00ab4d5c4d0350872a64e9a42ad22cf4cb3a43c2 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:29:57 +0800 Subject: [PATCH 04/19] Add `FeaturePyramidBackbone` and port weights from `timm` for `ResNetBackbone` (#1769) * Add FeaturePyramidBackbone and update ResNetBackbone * Simplify the implementation * Fix CI * Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone * Add conversion implementation * Update docstrings * Address comments --- keras_nlp/api/models/__init__.py | 1 + keras_nlp/src/models/backbone.py | 3 + .../src/models/feature_pyramid_backbone.py | 73 +++++ .../src/models/resnet/resnet_backbone.py | 252 +++++++++++------- .../src/models/resnet/resnet_backbone_test.py | 25 +- .../models/resnet/resnet_image_classifier.py | 7 +- .../resnet/resnet_image_classifier_test.py | 4 + keras_nlp/src/utils/preset_utils.py | 4 + keras_nlp/src/utils/timm/__init__.py | 13 + keras_nlp/src/utils/timm/convert.py | 37 +++ keras_nlp/src/utils/timm/convert_resnet.py | 171 ++++++++++++ .../src/utils/timm/convert_resnet_test.py | 28 ++ .../utils/transformers/safetensor_utils.py | 4 +- 13 files changed, 524 insertions(+), 98 deletions(-) create mode 100644 keras_nlp/src/models/feature_pyramid_backbone.py create mode 100644 keras_nlp/src/utils/timm/__init__.py create mode 100644 keras_nlp/src/utils/timm/convert.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index aca1e28538..e079aa7c9e 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -112,6 +112,7 @@ ) from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index a58072dfce..0f41c63c81 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -30,6 +30,7 @@ from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.timm.convert import load_timm_backbone 
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone @@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from if format == "transformers": return load_transformers_backbone(cls, preset, load_weights) + elif format == "timm": + return load_timm_backbone(cls, preset, load_weights, **kwargs) preset_cls = check_config_class(preset) if not issubclass(preset_cls, cls): diff --git a/keras_nlp/src/models/feature_pyramid_backbone.py b/keras_nlp/src/models/feature_pyramid_backbone.py new file mode 100644 index 0000000000..989d9fbd64 --- /dev/null +++ b/keras_nlp/src/models/feature_pyramid_backbone.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone") +class FeaturePyramidBackbone(Backbone): + """A backbone with feature pyramid outputs. + + `FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs` + property for accessing the feature pyramid outputs of the model. Subclassers + should set the `pyramid_outputs` property during the model constructor. + + Example: + + ```python + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) + + # Convert to feature pyramid output format using ResNet. + backbone = ResNetBackbone.from_preset("resnet50") + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"] + ``` + """ + + @property + def pyramid_outputs(self): + """A dict for feature pyramid outputs. + + The key is a string represents the name of the feature output and the + value is a `keras.KerasTensor`. A typical feature pyramid has multiple + levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale + `Pn` represents a feature map `2^n` times smaller in width and height + than the inputs. + """ + return getattr(self, "_pyramid_outputs", {}) + + @pyramid_outputs.setter + def pyramid_outputs(self, value): + if not isinstance(value, dict): + raise TypeError( + "`pyramid_outputs` must be a dictionary. " + f"Received: value={value} of type {type(value)}" + ) + for k, v in value.items(): + if not isinstance(k, str): + raise TypeError( + "The key of `pyramid_outputs` must be a string. " + f"Received: key={k} of type {type(k)}" + ) + if not isinstance(v, keras.KerasTensor): + raise TypeError( + "The value of `pyramid_outputs` must be a " + "`keras.KerasTensor`. " + f"Received: value={v} of type {type(v)}" + ) + self._pyramid_outputs = value diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index bec5ba60b5..0f4d7c139a 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -13,20 +13,23 @@ # limitations under the License. 
import keras from keras import layers +from keras import ops from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.utils.keras_utils import standardize_data_format @keras_nlp_export("keras_nlp.models.ResNetBackbone") -class ResNetBackbone(Backbone): +class ResNetBackbone(FeaturePyramidBackbone): """ResNet and ResNetV2 core network with hyperparameters. This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( - CVPR 2016) and [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016). + CVPR 2016), [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + improved training procedure in timm](https://arxiv.org/abs/2110.00476)( + NeurIPS 2021 Workshop). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -34,6 +37,9 @@ class ResNetBackbone(Backbone): the batch normalization and ReLU activation are applied after the convolution layers. + Note that `ResNetBackbone` expects the inputs to be images with a value + range of `[0, 255]` when `include_rescaling=True`. + Args: stackwise_num_filters: list of ints. The number of filters for each stack. @@ -46,8 +52,8 @@ class ResNetBackbone(Backbone): use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using - `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to - `True`. + `Rescaling` and `Normalization` layers. If `False`, do nothing. + Defaults to `True`. input_image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults @@ -70,11 +76,11 @@ class ResNetBackbone(Backbone): `~/.keras/keras.json`. If you never set it, then it will be `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the models computations and weights. + to use for the model's computations and weights. Examples: ```python - input_data = np.ones((2, 224, 224, 3), dtype="float32") + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") @@ -136,34 +142,66 @@ def __init__( image_input = layers.Input(shape=input_image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + x = layers.Normalization( + axis=bn_axis, + mean=(0.485, 0.456, 0.406), + variance=(0.229**2, 0.224**2, 0.225**2), + dtype=dtype, + name="normalization", + )(x) else: x = image_input + # The padding between torch and tensorflow/jax differs when `strides>1`. + # Therefore, we need to manually pad the tensor. 
+ x = layers.ZeroPadding2D( + 3, + data_format=data_format, + dtype=dtype, + name="conv1_pad", + )(x) x = layers.Conv2D( 64, 7, strides=2, - padding="same", data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name="conv1_conv", )(x) if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) - x = layers.MaxPool2D( + if use_pre_activation: + # A workaround for ResNetV2: we need -inf padding to prevent zeros + # from being the max values in the following `MaxPooling2D`. + pad_width = [[1, 1], [1, 1]] + if data_format == "channels_last": + pad_width += [[0, 0]] + else: + pad_width = [[0, 0]] + pad_width + pad_width = [[0, 0]] + pad_width + x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf")) + else: + x = layers.ZeroPadding2D( + 1, data_format=data_format, dtype=dtype, name="pool1_pad" + )(x) + x = layers.MaxPooling2D( 3, strides=2, - padding="same", data_format=data_format, dtype=dtype, name="pool1_pool", )(x) + pyramid_outputs = {} for stack_index in range(num_stacks): x = apply_stack( x, @@ -179,10 +217,15 @@ def __init__( dtype=dtype, name=f"{version}_stack{stack_index}", ) + pyramid_outputs[f"P{stack_index + 2}"] = x if use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="post_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) @@ -213,18 +256,23 @@ def __init__( self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling + self.pyramid_outputs = pyramid_outputs def get_config(self): - return { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_strides": self.stackwise_num_strides, - "block_type": self.block_type, - "use_pre_activation": self.use_pre_activation, - "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, - "pooling": self.pooling, - } + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + ) + return config def apply_basic_block( @@ -269,68 +317,81 @@ def apply_basic_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, 
name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=stride if not use_pre_activation else 1, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, + strides=1, padding="same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -381,79 +442,97 @@ def apply_bottleneck_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( 4 * filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x x = layers.Conv2D( filters, 1, - strides=stride if not use_pre_activation else 1, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=1 if not 
use_pre_activation else stride, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( 4 * filters, 1, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_3_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -513,32 +592,21 @@ def apply_stack( '`block_type` must be either `"basic_block"` or ' f'`"bottleneck_block"`. Received block_type={block_type}.' ) - x = block_fn( - x, - filters, - stride=stride if not use_pre_activation else 1, - conv_shortcut=first_shortcut, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block1", - ) - for i in range(2, blocks): + for i in range(blocks): + if i == 0: + stride = stride + conv_shortcut = first_shortcut + else: + stride = 1 + conv_shortcut = False x = block_fn( x, filters, + stride=stride, + conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", ) - x = block_fn( - x, - filters, - stride=1 if not use_pre_activation else stride, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block{str(blocks)}", - ) return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 2113bcd131..6d3f774559 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,6 +14,7 @@ import pytest from absl.testing import parameterized +from keras import models from keras import ops from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -29,8 +30,8 @@ def setUp(self): "input_image_shape": (None, None, 3), "pooling": "avg", } - self.input_size = (16, 16) - self.input_data = ops.ones((2, 16, 16, 3)) + self.input_size = 64 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( ("v1_basic", False, "basic_block"), @@ -52,6 +53,24 @@ def test_backbone_basics(self, use_pre_activation, block_type): ), ) + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": "basic_block", "use_pre_activation": False} + ) + backbone = ResNetBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** int(k[1:])) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), @@ -65,7 +84,7 
@@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (16, 16, 3), + "input_image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 02c8c78b27..815dc7fcca 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -28,6 +28,8 @@ class ResNetImageClassifier(ImageClassifier): activation: `None`, str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `"softmax"`. + head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` where `x` is a tensor and `y` is a integer from `[0, num_classes)`. @@ -92,16 +94,19 @@ def __init__( backbone, num_classes, activation="softmax", + head_dtype=None, preprocessor=None, # adding this dummy arg for saved model test # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): + head_dtype = head_dtype or backbone.dtype_policy + # === Layers === self.backbone = backbone self.output_dense = keras.layers.Dense( num_classes, activation=activation, - dtype=self.backbone.dtype_policy, + dtype=head_dtype, name="predictions", ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index bbbda72d64..f3f63a14a1 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -53,6 +53,10 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + def test_head_dtype(self): + model = ResNetImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index f797bf9f18..9e3f51c43a 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -544,6 +544,10 @@ def check_format(preset): if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists( preset, SAFETENSOR_CONFIG_FILE ): + # Determine the format by parsing the config file. + config = load_config(preset, HF_CONFIG_FILE) + if "hf://timm" in preset or "architecture" in config: + return "timm" return "transformers" if not check_file_exists(preset, METADATA_FILE): diff --git a/keras_nlp/src/utils/timm/__init__.py b/keras_nlp/src/utils/timm/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/utils/timm/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/utils/timm/convert.py b/keras_nlp/src/utils/timm/convert.py new file mode 100644 index 0000000000..edfde3316b --- /dev/null +++ b/keras_nlp/src/utils/timm/convert.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert timm models to KerasNLP.""" + +from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone + + +def load_timm_backbone(cls, preset, load_weights, **kwargs): + """Load a timm model config and weights as a KerasNLP backbone. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + if cls is None: + raise ValueError("Backbone class is None") + if cls.__name__ == "ResNetBackbone": + return load_resnet_backbone(cls, preset, load_weights, **kwargs) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py new file mode 100644 index 0000000000..de2224eb9e --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -0,0 +1,171 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if "resnetv2_" in timm_architecture: + use_pre_activation = True + else: + use_pre_activation = False + + if timm_architecture == "resnet18": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "basic_block" + elif timm_architecture == "resnet26": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "bottleneck_block" + elif timm_architecture == "resnet34": + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "basic_block" + elif timm_architecture in ("resnet50", "resnetv2_50"): + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet101", "resnetv2_101"): + stackwise_num_blocks = [3, 4, 23, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet152", "resnetv2_152"): + stackwise_num_blocks = [3, 8, 36, 3] + block_type = "bottleneck_block" + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + + return dict( + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=stackwise_num_blocks, + stackwise_num_strides=[1, 2, 2, 2], + block_type=block_type, + use_pre_activation=use_pre_activation, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + version = "v1" if not backbone.use_pre_activation else "v2" + block_type = backbone.block_type + + # Stem + if version == "v1": + port_conv2d("conv1_conv", "conv1") + port_batch_normalization("conv1_bn", "bn1") + else: + port_conv2d("conv1_conv", "stem.conv") + + # Stages + num_stacks = len(backbone.stackwise_num_filters) + for stack_index in range(num_stacks): + for block_idx in range(backbone.stackwise_num_blocks[stack_index]): + if version == "v1": + keras_name = f"v1_stack{stack_index}_block{block_idx}" + hf_name = f"layer{stack_index+1}.{block_idx}" + else: + keras_name = f"v2_stack{stack_index}_block{block_idx}" + hf_name = f"stages.{stack_index}.blocks.{block_idx}" + + if version == "v1": + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.0" + ) + port_batch_normalization( + f"{keras_name}_0_bn", f"{hf_name}.downsample.1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + 
port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") + if block_type == "bottleneck_block": + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + port_batch_normalization( + f"{keras_name}_3_bn", f"{hf_name}.bn3" + ) + else: + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" + ) + port_batch_normalization( + f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization( + f"{keras_name}_1_bn", f"{hf_name}.norm2" + ) + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + if block_type == "bottleneck_block": + port_batch_normalization( + f"{keras_name}_2_bn", f"{hf_name}.norm3" + ) + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + + # Post + if version == "v2": + port_batch_normalization("post_bn", "norm") + + # Rebuild normalization layer with pretrained mean & std + mean = timm_config["pretrained_cfg"]["mean"] + std = timm_config["pretrained_cfg"]["std"] + normalization_layer = backbone.get_layer("normalization") + normalization_layer.input_mean = mean + normalization_layer.input_variance = [s**2 for s in std] + normalization_layer.build(normalization_layer._build_input_shape) + + +def load_resnet_backbone(cls, preset, load_weights, **kwargs): + timm_config = load_config(preset, HF_CONFIG_FILE) + keras_config = convert_backbone_config(timm_config) + backbone = cls(**keras_config, **kwargs) + if load_weights: + jax_memory_cleanup(backbone) + # Use prefix="" to avoid using `get_prefixed_key`. + with SafetensorLoader(preset, prefix="") as loader: + convert_weights(backbone, loader, timm_config) + return backbone diff --git a/keras_nlp/src/utils/timm/convert_resnet_test.py b/keras_nlp/src/utils/timm/convert_resnet_test.py new file mode 100644 index 0000000000..a30bee46af --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class TimmResNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_resnet18_preset(self): + model = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k") + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 512)) + + # TODO: compare numerics with timm model diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py index 40ef473ff3..2fbd7e1aba 100644 --- a/keras_nlp/src/utils/transformers/safetensor_utils.py +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -26,7 +26,7 @@ class SafetensorLoader(contextlib.ExitStack): - def __init__(self, preset): + def __init__(self, preset, prefix=None): super().__init__() if safetensors is None: @@ -42,7 +42,7 @@ def __init__(self, preset): else: self.safetensor_config = None self.safetensor_files = {} - self.prefix = None + self.prefix = prefix def get_prefixed_key(self, hf_weight_key, dict_like): """ From 9860756f183cc4ad9247bc29b6c0ee55ec2db6fc Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 17:38:39 -0700 Subject: [PATCH 05/19] Add DenseNet (#1775) * Add DenseNet * fix testcase * address comments * nit * fix lint errors * move description --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/densenet/__init__.py | 13 ++ .../src/models/densenet/densenet_backbone.py | 210 ++++++++++++++++++ .../models/densenet/densenet_backbone_test.py | 48 ++++ .../densenet/densenet_image_classifier.py | 131 +++++++++++ .../densenet_image_classifier_test.py | 63 ++++++ 6 files changed, 469 insertions(+) create mode 100644 keras_nlp/src/models/densenet/__init__.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone_test.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e079aa7c9e..bf5cc28060 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -74,6 +74,10 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_nlp/src/models/densenet/__init__.py b/keras_nlp/src/models/densenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/densenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
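Before moving on to DenseNet, a brief sketch tying together the timm conversion path and the new `FeaturePyramidBackbone` property. The preset handle mirrors the one used in the conversion test above; the output shape and pyramid keys are as documented for ResNet-18 and may differ for other presets.

```python
import numpy as np
from keras import models

from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone

# `check_format` detects the "timm" format from the Hugging Face config and
# routes loading through `load_timm_backbone` / `load_resnet_backbone`.
backbone = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k")

# Pooled features, as exercised in `convert_resnet_test.py`.
features = backbone.predict(np.ones((1, 224, 224, 3)))  # shape (1, 512)

# Intermediate feature maps via the new `pyramid_outputs` property.
pyramid_model = models.Model(backbone.inputs, backbone.pyramid_outputs)
pyramid = pyramid_model(np.ones((1, 224, 224, 3)))  # dict keyed "P2".."P5"
```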
diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py new file mode 100644 index 0000000000..8456fbcee6 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -0,0 +1,210 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_AXIS = 3 +BN_EPSILON = 1.001e-5 + + +@keras_nlp_export("keras_nlp.models.DenseNetBackbone") +class DenseNetBackbone(Backbone): + """Instantiates the DenseNet architecture. + + This class implements a DenseNet backbone as described in + [Densely Connected Convolutional Networks (CVPR 2017)]( + https://arxiv.org/abs/1608.06993 + ). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per dense block. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. Defaults to `True`. + input_image_shape: optional shape tuple, defaults to (224, 224, 3). + compression_ratio: float, compression rate at transition layers, + defaults to 0.5. + growth_rate: int, number of filters added by each dense block, + defaults to 32 + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.DenseNetBackbone.from_preset("densenet121_imagenet") + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + include_rescaling=True, + input_image_shape=(224, 224, 3), + compression_ratio=0.5, + growth_rate=32, + **kwargs, + ): + # === Functional Model === + image_input = keras.layers.Input(shape=input_image_shape) + + x = image_input + if include_rescaling: + x = keras.layers.Rescaling(1 / 255.0)(x) + + x = keras.layers.Conv2D( + 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + )(x) + x = keras.layers.Activation("relu", name="conv1_relu")(x) + x = keras.layers.MaxPooling2D( + 3, strides=2, padding="same", name="pool1" + )(x) + + for stack_index in range(len(stackwise_num_repeats) - 1): + index = stack_index + 2 + x = apply_dense_block( + x, + stackwise_num_repeats[stack_index], + growth_rate, + name=f"conv{index}", + ) + x = apply_transition_block( + x, compression_ratio, name=f"pool{index}" + ) + + x = apply_dense_block( + x, + stackwise_num_repeats[-1], + growth_rate, + name=f"conv{len(stackwise_num_repeats) + 1}", + ) + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + )(x) + x = keras.layers.Activation("relu", name="relu")(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === 
+ self.stackwise_num_repeats = stackwise_num_repeats + self.include_rescaling = include_rescaling + self.compression_ratio = compression_ratio + self.growth_rate = growth_rate + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_repeats": self.stackwise_num_repeats, + "include_rescaling": self.include_rescaling, + "compression_ratio": self.compression_ratio, + "growth_rate": self.growth_rate, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_dense_block(x, num_repeats, growth_rate, name=None): + """A dense block. + + Args: + x: input tensor. + num_repeats: int, number of repeated convolutional blocks. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"dense_block_{keras.backend.get_uid('dense_block')}" + + for i in range(num_repeats): + x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + return x + + +def apply_transition_block(x, compression_ratio, name=None): + """A transition block. + + Args: + x: input tensor. + compression_ratio: float, compression rate at transition layers. + name: string, block label. + """ + if name is None: + name = f"transition_block_{keras.backend.get_uid('transition_block')}" + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_relu")(x) + x = keras.layers.Conv2D( + int(x.shape[BN_AXIS] * compression_ratio), + 1, + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + return x + + +def apply_conv_block(x, growth_rate, name=None): + """A building block for a dense block. + + Args: + x: input tensor. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"conv_block_{keras.backend.get_uid('conv_block')}" + + shortcut = x + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) + x = keras.layers.Conv2D( + 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) + x = keras.layers.Conv2D( + growth_rate, + 3, + padding="same", + use_bias=False, + name=f"{name}_2_conv", + )(x) + x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + [shortcut, x] + ) + return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py new file mode 100644 index 0000000000..f0f8dac875 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
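As a quick cross-check of the block arithmetic in `densenet_backbone.py` above, the sketch below (illustrative only, not part of the patch) traces how `growth_rate`, `compression_ratio`, and the stem layers determine the output feature map for a densenet121-style configuration; the variable names mirror the constructor arguments.

```python
# Channel and spatial bookkeeping for the DenseNet backbone defined above.
stackwise_num_repeats = [6, 12, 24, 16]  # densenet121-style config
growth_rate = 32
compression_ratio = 0.5

channels = 64        # stem: Conv2D(64, 7, strides=2)
size = 224 // 4      # stem conv (stride 2) followed by max pooling (stride 2)
for i, num_repeats in enumerate(stackwise_num_repeats):
    # Each conv block concatenates `growth_rate` new feature maps.
    channels += num_repeats * growth_rate
    if i < len(stackwise_num_repeats) - 1:
        # Transition blocks compress channels and halve the resolution.
        channels = int(channels * compression_ratio)
        size //= 2

print((size, size, channels))  # (7, 7, 1024) for 224x224 inputs
```

Each dense block adds `num_repeats * growth_rate` channels by concatenation, and every transition block scales the channel count by `compression_ratio` while halving the spatial resolution.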
+ +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [6, 12, 24, 16], + "include_rescaling": True, + "compression_ratio": 0.5, + "growth_rate": 32, + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 1024), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py new file mode 100644 index 0000000000..395e8f754d --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.DenseNetImageClassifier") +class DenseNetImageClassifier(ImageClassifier): + """DenseNet image classifier task model. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Args: + backbone: A `keras_nlp.models.DenseNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. 
+ ```python + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.DenseNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.DenseNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = DenseNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py new file mode 100644 index 0000000000..60d77d489c --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=True, + compression_ratio=0.5, + growth_rate=32, + input_image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From fd6f977b0136499ad4e1cf78cc8aea69fb3bfc27 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 20 Aug 2024 12:24:13 -0700 Subject: [PATCH 06/19] Add ViTDetBackbone (#1776) * add vit det vit_det_backbone * update docstring * code reformat * fix tests * address review comments * bump year on all files * address review comments * rename backbone * fix tests * change back to ViT * address review comments * update image shape --- keras_nlp/api/models/__init__.py | 1 + .../src/models/vit_det/vit_det_backbone.py | 204 +++++++ .../models/vit_det/vit_det_backbone_test.py | 54 ++ keras_nlp/src/models/vit_det/vit_layers.py | 565 ++++++++++++++++++ 4 files changed, 824 insertions(+) create mode 100644 keras_nlp/src/models/vit_det/vit_det_backbone.py create mode 100644 keras_nlp/src/models/vit_det/vit_det_backbone_test.py create mode 100644 keras_nlp/src/models/vit_det/vit_layers.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 6f7e08c520..1a6dd2e74f 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -212,6 +212,7 @@ from keras_nlp.src.models.task import Task from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/vit_det/vit_det_backbone.py b/keras_nlp/src/models/vit_det/vit_det_backbone.py new file mode 100644 index 0000000000..1e83e94b05 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_det_backbone.py @@ -0,0 +1,204 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.vit_det.vit_layers import AddPositionalEmbedding +from keras_nlp.src.models.vit_det.vit_layers import ViTDetPatchingAndEmbedding +from keras_nlp.src.models.vit_det.vit_layers import WindowedTransformerEncoder + + +@keras_nlp_export("keras_nlp.models.ViTDetBackbone") +class ViTDetBackbone(Backbone): + """An implementation of ViT image encoder. + + The ViTDetBackbone uses a windowed transformer encoder and relative + positional encodings. The code has been adapted from [Segment Anything + paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + hidden_size (int): The latent dimensionality to be projected + into in the output of each stacked windowed transformer encoder. + num_layers (int): The number of transformer encoder layers to + stack in the Vision Transformer. + intermediate_dim (int): The dimensionality of the hidden Dense + layer in the transformer MLP head. + num_heads (int): the number of heads to use in the + `MultiHeadAttentionWithRelativePE` layer of each transformer + encoder. + global_attention_layer_indices (list): Indexes for blocks using + global attention. + image_shape (tuple[int], optional): The size of the input image in + `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + include_rescaling (bool, optional): Whether to rescale the inputs. If + set to `True`, inputs will be passed through a + `Rescaling(1/255.0)` layer. Defaults to `False`. + patch_size (int, optional): the patch size to be supplied to the + Patching layer to turn input images into a flattened sequence of + patches. Defaults to `16`. + num_output_channels (int, optional): The number of channels (features) + in the output (image encodings). Defaults to `256`. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_abs_pos (bool, optional): Whether to add absolute positional + embeddings to the output patches. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `True`. + window_size (int, optional): The size of the window for windowed + attention in the transformer encoder blocks. Defaults to `14`. + layer_norm_epsilon (int, optional): The epsilon to use in the layer + normalization blocks in transformer encoder. Defaults to `1e-6`. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained ViTDetBackbone backbone. + model = keras_nlp.models.ViTDetBackbone.from_preset("vit_det") + model(input_data) + + # Randomly initialized ViTDetBackbone backbone with a custom config. 
+ model = keras_nlp.models.ViTDetBackbone( + image_shape = (16, 16, 3), + patch_size = 2, + hidden_size = 4, + num_layers = 2, + global_attention_layer_indices = [2, 5, 8, 11], + intermediate_dim = 4 * 4, + num_heads = 2, + num_output_channels = 2, + window_size = 2, + ) + model(input_data) + ``` + """ + + def __init__( + self, + hidden_size, + num_layers, + intermediate_dim, + num_heads, + global_attention_layer_indices, + include_rescaling=True, + image_shape=(1024, 1024, 3), + patch_size=16, + num_output_channels=256, + use_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + layer_norm_epsilon=1e-6, + **kwargs + ): + # === Functional model === + img_input = keras.layers.Input(shape=image_shape) + # Check that the input image is well specified. + if img_input.shape[-3] is None or img_input.shape[-2] is None: + raise ValueError( + "Height and width of the image must be specified" + " in `image_shape`." + ) + if img_input.shape[-3] != img_input.shape[-2]: + raise ValueError( + "Input image must be square i.e. the height must" + " be equal to the width in the `image_shape`" + " tuple/tensor." + ) + img_size = img_input.shape[-3] + x = img_input + if include_rescaling: + # Use common rescaling strategy across keras_cv + x = keras.layers.Rescaling(1.0 / 255.0)(x) + # VITDet scales inputs based on the standard ImageNet mean/stddev. + x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / ( + ops.array([0.229, 0.224, 0.225], dtype=x.dtype) + ) + x = ViTDetPatchingAndEmbedding( + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + embed_dim=hidden_size, + )(x) + if use_abs_pos: + x = AddPositionalEmbedding(img_size, patch_size, hidden_size)(x) + for i in range(num_layers): + x = WindowedTransformerEncoder( + project_dim=hidden_size, + intermediate_dim=intermediate_dim, + num_heads=num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + window_size=( + window_size + if i not in global_attention_layer_indices + else 0 + ), + input_size=(img_size // patch_size, img_size // patch_size), + )(x) + x = keras.layers.Conv2D( + filters=num_output_channels, kernel_size=1, use_bias=False + )(x) + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) + x = keras.layers.Conv2D( + filters=num_output_channels, + kernel_size=3, + padding="same", + use_bias=False, + )(x) + x = keras.layers.LayerNormalization(epsilon=1e-6)(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.patch_size = patch_size + self.image_shape = image_shape + self.hidden_size = hidden_size + self.num_layers = num_layers + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.num_output_channels = num_output_channels + self.use_bias = use_bias + self.use_rel_pos = use_rel_pos + self.use_abs_pos = use_abs_pos + self.window_size = window_size + self.global_attention_layer_indices = global_attention_layer_indices + self.layer_norm_epsilon = layer_norm_epsilon + self.include_rescaling = include_rescaling + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "include_rescaling": self.include_rescaling, + "patch_size": self.patch_size, + "hidden_size": self.hidden_size, + "num_layers": self.num_layers, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "num_output_channels": self.num_output_channels, + "use_bias": self.use_bias, + "use_abs_pos": self.use_abs_pos, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + 
"global_attention_layer_indices": self.global_attention_layer_indices, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_nlp/src/models/vit_det/vit_det_backbone_test.py b/keras_nlp/src/models/vit_det/vit_det_backbone_test.py new file mode 100644 index 0000000000..0ae277d122 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_det_backbone_test.py @@ -0,0 +1,54 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class ViTDetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "include_rescaling": True, + "image_shape": (16, 16, 3), + "patch_size": 2, + "hidden_size": 4, + "num_layers": 2, + "global_attention_layer_indices": [2, 5, 8, 11], + "intermediate_dim": 4 * 4, + "num_heads": 2, + "num_output_channels": 2, + "window_size": 2, + } + self.input_data = np.ones((1, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=ViTDetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(1, 8, 8, 2), + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTDetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vit_det/vit_layers.py b/keras_nlp/src/models/vit_det/vit_layers.py new file mode 100644 index 0000000000..e784595371 --- /dev/null +++ b/keras_nlp/src/models/vit_det/vit_layers.py @@ -0,0 +1,565 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import keras +from keras import ops + + +class MLP(keras.layers.Layer): + """A MLP block with architecture. + + The MLP block implements `input_dim -> [intermediate_dim] -> + hidden_dim`. The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + intermediate_dim (int): The number of units in the hidden layers. + hidden_dim (int): The number of units in the output layer. + activation (str): Activation to use in the hidden layers. + Default is `"relu"`. 
+ """ + + def __init__( + self, intermediate_dim, hidden_dim, activation="relu", **kwargs + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation = activation + h = [intermediate_dim] + self.dense_net = [] + for intermediate_dim in h: + self.dense_net.append(keras.layers.Dense(intermediate_dim)) + self.dense_net.append(keras.layers.Activation(activation)) + self.dense_net.append(keras.layers.Dense(hidden_dim)) + self.dense_net = keras.models.Sequential(self.dense_net) + + def build(self, input_shape): + self.dense_net.build(input_shape) + self.built = True + + def call(self, x): + return self.dense_net(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "hidden_dim": self.hidden_dim, + "activation": self.activation, + } + ) + return config + + +class AddRelativePositionalEmbedding(keras.layers.Layer): + def __init__(self, input_size, key_dim, **kwargs): + super().__init__(**kwargs) + self.input_size = input_size + self.key_dim = key_dim + self.rel_pos_h = self.add_weight( + name="rel_pos_h", + shape=(2 * self.input_size[0] - 1, self.key_dim), + initializer="zeros", + ) + self.rel_pos_w = self.add_weight( + name="rel_pos_w", + shape=(2 * self.input_size[1] - 1, self.key_dim), + initializer="zeros", + ) + self.built = True + + def _get_rel_pos(self, query_size, key_size, rel_pos): + """Get relative positional embeddings. + + Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + query_size (int): The number of features of the queries. + key_size (int): The number of features of the keys. + rel_pos (tensor): Relative positional embedding tensor. + + Returns: + tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = 2 * max(query_size, key_size) - 1 + if ops.shape(rel_pos)[0] != max_rel_dist: + rel_pos_resized = ops.image.resize( + image=ops.reshape( + rel_pos, + (1, ops.shape(rel_pos)[0], ops.shape(rel_pos)[1], 1), + ), + size=(max_rel_dist, ops.shape(rel_pos)[1]), + interpolation="bilinear", + ) + rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1)) + return rel_pos_resized + else: + rel_pos_resized = rel_pos + # Query coordinates + query_coordinates = ops.cast( + ops.arange(query_size), dtype=self.compute_dtype + )[:, None] * (max(key_size / query_size, 1.0)) + # Key coordinates + key_coordinates = ops.cast( + ops.arange(key_size), dtype=self.compute_dtype + )[None, :] * (max(query_size / key_size, 1.0)) + # Relative coordinates + relative_coordinates = (query_coordinates - key_coordinates) + ( + key_size - 1 + ) * max(query_size / key_size, 1.0) + relative_coordinates = ops.cast(relative_coordinates, dtype="int32") + return ops.take(rel_pos_resized, relative_coordinates, 0) + + def call(self, attention_map, queries, query_size, key_size): + """Calculate decomposed Relative Positional Embeddings + + The code has been adapted based on + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501 + + Args: + attention_map (tensor): Attention map. + queries (tensor): Queries in the attention layer with shape + `(batch, query_height * query_width, channels)`. + query_size (tuple[int, int]): Spatial sequence size of queries with + `(query_height, query_width)`. + key_size (tuple[int, int]): Spatial sequence size of keys with + `(key_height, key_width)`. 
+ + Returns: + tensor: attention map with added relative positional embeddings. + """ + query_height, query_width = query_size[0], query_size[1] + key_height, key_width = key_size[0], key_size[1] + rel_heights = self._get_rel_pos( + query_height, key_height, self.rel_pos_h + ) + rel_widths = self._get_rel_pos(query_width, key_width, self.rel_pos_w) + shape = ops.shape(queries) + batch, channels = shape[0], shape[2] + rel_queries = ops.reshape( + queries, (batch, query_height, query_width, channels) + ) + rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights) + rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths) + attention_map = ops.reshape( + attention_map, + (batch, query_height, query_width, key_height, key_width), + ) + attention_map = attention_map + rel_heights[..., :, None] + attention_map = attention_map + rel_widths[..., None, :] + attention_map = ops.reshape( + attention_map, + (batch, query_height * query_width, key_height * key_width), + ) + return attention_map + + def get_config(self): + config = super().get_config() + config.update({"input_size": self.input_size, "key_dim": self.key_dim}) + return config + + +class MultiHeadAttentionWithRelativePE(keras.layers.Layer): + """Multi-head Attention block with relative position embeddings. + + The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + num_heads (int): Number of attention heads. + key_dim (int): Size of each attention head for query, key, and + value. + use_bias (bool, optional): Whether to use bias when projecting + the queries, keys, and values. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + embeddings or not. Defaults to `False`. + input_size (tuple[int, int], optional): Size of the input image. + Must be provided when using relative positional embeddings. + Defaults to `None`. + + Raises: + ValueError: When `input_size = None` with `use_rel_pos = True`. + """ + + def __init__( + self, + num_heads, + key_dim, + use_bias=True, + use_rel_pos=False, + input_size=None, + **kwargs + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.scale = self.key_dim**-0.5 + self.use_bias = use_bias + self.input_size = input_size + self.use_rel_pos = use_rel_pos + self.qkv = keras.layers.Dense( + key_dim * self.num_heads * 3, use_bias=self.use_bias + ) + self.projection = keras.layers.Dense(key_dim * self.num_heads) + if self.use_rel_pos: + if input_size is None: + raise ValueError( + "Input size must be provided if using relative " + "positional encoding." 
+ ) + self.add_decomposed_reative_pe = AddRelativePositionalEmbedding( + self.input_size, self.key_dim + ) + + def build(self, input_shape=None): + self.qkv.build([self.key_dim * self.num_heads]) + self.projection.build([self.key_dim * self.num_heads]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + batch, height, width, channels = ops.shape(x) + qkv = ops.transpose( + ops.reshape( + self.qkv(x), + (batch, height * width, 3, self.num_heads, self.key_dim), + ), + axes=(2, 0, 3, 1, 4), + ) + qkv = ops.reshape( + qkv, (3, batch * self.num_heads, height * width, self.key_dim) + ) + queries, keys, values = ops.unstack(qkv, axis=0) + attention_map = (queries * self.scale) @ ops.transpose( + keys, axes=(0, 2, 1) + ) + if self.use_rel_pos: + attention_map = self.add_decomposed_reative_pe( + attention_map, + queries=queries, + query_size=(height, width), + key_size=(height, width), + ) + attention_map = ops.softmax(attention_map, axis=-1) + x = ops.reshape( + attention_map @ values, + (batch, self.num_heads, height, width, self.key_dim), + ) + x = ops.transpose(x, axes=(0, 2, 3, 1, 4)) + x = ops.reshape(x, (batch, height, width, channels)) + x = self.projection(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "use_bias": self.use_bias, + "use_rel_pos": self.use_rel_pos, + "input_size": self.input_size, + } + ) + return config + + +class WindowPartitioning(keras.layers.Layer): + def __init__(self, window_size, **kwargs): + super().__init__(**kwargs) + self.window_size = window_size + self.built = True + + def partition(self, x): + batch, height, width, channels = ops.shape(x) + pad_height = ( + self.window_size - height % self.window_size + ) % self.window_size + pad_width = ( + self.window_size - width % self.window_size + ) % self.window_size + if pad_height > 0 or pad_width > 0: + x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0))) + height_padded, width_padded = height + pad_height, width + pad_width + x = ops.reshape( + x, + ( + batch, + height_padded // self.window_size, + self.window_size, + width_padded // self.window_size, + self.window_size, + channels, + ), + ) + windows = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (-1, self.window_size, self.window_size, channels), + ) + return windows, (height_padded, width_padded) + + def unpartition(self, windows, height_width_padded, height_width): + height_padded, width_padded = height_width_padded + height, width = height_width + batch = ops.shape(windows)[0] // ( + (height_padded // self.window_size) + * (width_padded // self.window_size) + ) + x = ops.reshape( + windows, + ( + batch, + height_padded // self.window_size, + width_padded // self.window_size, + self.window_size, + self.window_size, + -1, + ), + ) + x = ops.reshape( + ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)), + (batch, height_padded, width_padded, -1), + ) + return x[:, :height, :width, :] + + def get_config(self): + config = super().get_config() + config.update({"window_size": self.window_size}) + return config + + +class WindowedTransformerEncoder(keras.layers.Layer): + """Implements windowed transformer encoder. + + Transformer blocks with support of window attention and residual + propagation blocks. 
The code has been adapted from [Segment Anything paper]( + https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + project_dim (int): the dimensionality of the projection of the + encoder, and output of the `MultiHeadAttention`. + intermediate_dim (int): the intermediate dimensionality of the MLP head + before projecting to `project_dim`. + num_heads (int): the number of heads for the `MultiHeadAttention` + layer. + use_bias (bool, optional): Whether to use bias to project the keys, + queries, and values in the attention layer. Defaults to `True`. + use_rel_pos (bool, optional): Whether to use relative positional + emcodings in the attention layer. Defaults to `False`. + window_size (int, optional): Window size for windowed attention. + Defaults to `0`. + input_size (tuple[int, int], optional): Height and width of the input + image as a tuple of integers. Must be provided when using relative + positional embeddings. Defaults to `None`. + activation (str, optional): the activation function to apply in the + MLP head - should be a function. Defaults to `"gelu"`. + layer_norm_epsilon (float, optional): The epsilon to use in the layer + normalization layers. Defaults to `1e-6`. + """ + + def __init__( + self, + project_dim, + intermediate_dim, + num_heads, + use_bias=True, + use_rel_pos=False, + window_size=0, + input_size=None, + activation="gelu", + layer_norm_epsilon=1e-6, + **kwargs + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.use_bias = use_bias + self.input_size = input_size + self.activation = activation + self.layer_norm_epsilon = layer_norm_epsilon + self.window_size = window_size + self.use_rel_pos = use_rel_pos + + self.layer_norm1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.layer_norm2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self.attention = MultiHeadAttentionWithRelativePE( + num_heads=self.num_heads, + key_dim=self.project_dim // self.num_heads, + use_bias=use_bias, + use_rel_pos=use_rel_pos, + input_size=( + input_size if window_size == 0 else (window_size, window_size) + ), + ) + self.mlp_block = MLP( + intermediate_dim, + project_dim, + activation="gelu", + ) + self.window_partitioning = WindowPartitioning(window_size) + + def build(self, input_shape=None): + self.layer_norm1.build([None, None, None, self.project_dim]) + self.layer_norm2.build([None, None, None, self.project_dim]) + self.attention.build() + self.mlp_block.build([None, None, None, self.project_dim]) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + shortcut = x + x = self.layer_norm1(x) + # Window Partition + if self.window_size > 0: + height, width = ops.shape(x)[1], ops.shape(x)[2] + x, height_width_padded = self.window_partitioning.partition(x) + + x = self.attention(x) + # Reverse Window Partition + if self.window_size > 0: + x = self.window_partitioning.unpartition( + x, + height_width_padded=height_width_padded, + height_width=(height, width), + ) + x = shortcut + x + x = x + self.mlp_block(self.layer_norm2(x)) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "use_bias": 
self.use_bias, + "use_rel_pos": self.use_rel_pos, + "window_size": self.window_size, + "input_size": self.input_size, + "activation": self.activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +class ViTDetPatchingAndEmbedding(keras.layers.Layer): + """ + Implements a image patch and embedding layer. + + Image to Patch Embedding using only a conv layer (without + layer normalization).The code has been adapted from [Segment Anything + paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub]( + https://github.com/facebookresearch/segment-anything) and [Detectron2]( + https://github.com/facebookresearch/detectron2). + + Args: + kernel_size (tuple[int, int], optional): Kernel size of the + projection layer. Defaults to `(16, 16)`. + strides (tuple, optional): Strides of the projection layer. + Defaults to `(16, 16)`. + embed_dim (int, optional): Number of filters to use in the + projection layer i.e. projection size. Defaults to `768`. + """ + + def __init__( + self, kernel_size=(16, 16), strides=(16, 16), embed_dim=768, **kwargs + ): + super().__init__(**kwargs) + + self.projection = keras.layers.Conv2D( + embed_dim, kernel_size=kernel_size, strides=strides + ) + self.kernel_size = kernel_size + self.strides = strides + self.embed_dim = embed_dim + + def build(self, input_shape): + self.projection.build(input_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return self.projection.compute_output_shape(input_shape) + + def call(self, x): + x = self.projection(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "kernel_size": self.kernel_size, + "strides": self.strides, + "embed_dim": self.embed_dim, + } + ) + return config + + +class AddPositionalEmbedding(keras.layers.Layer): + def __init__(self, img_size, patch_size, embed_dim, **kwargs): + super().__init__(**kwargs) + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.pos_embed = self.add_weight( + name="pos_embed", + shape=( + 1, + img_size // patch_size, + img_size // patch_size, + embed_dim, + ), + initializer="zeros", + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + return x + self.pos_embed + + def get_confg(self): + config = super().get_config() + config.update( + { + "img_size": self.img_size, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + } + ) + return config From fc485d6a259be75ce1103a3c114fe56d06cc5940 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 20 Aug 2024 12:26:55 -0700 Subject: [PATCH 07/19] Add Mix transformer (#1780) * Add MixTransformer * fix testcase * test changes and comments * lint fix * update config list * modify testcase for 2 layers --- keras_nlp/api/models/__init__.py | 6 + .../src/models/mix_transformer/__init__.py | 13 + .../mix_transformer_backbone.py | 181 +++++++++++ .../mix_transformer_backbone_test.py | 75 +++++ .../mix_transformer_classifier.py | 133 ++++++++ .../mix_transformer_classifier_test.py | 70 ++++ .../mix_transformer/mix_transformer_layers.py | 300 ++++++++++++++++++ 7 files changed, 778 insertions(+) create mode 100644 keras_nlp/src/models/mix_transformer/__init__.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py create mode 100644 
keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py create mode 100644 keras_nlp/src/models/mix_transformer/mix_transformer_layers.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 1a6dd2e74f..c6d8ed7d32 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -165,6 +165,12 @@ MistralPreprocessor, ) from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mix_transformer/__init__.py b/keras_nlp/src/models/mix_transformer/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py new file mode 100644 index 0000000000..2cfe7f6761 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -0,0 +1,181 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +import numpy as np +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + HierarchicalTransformerEncoder, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( + OverlappingPatchingAndEmbedding, +) + + +@keras_nlp_export("keras_nlp.models.MiTBackbone") +class MiTBackbone(FeaturePyramidBackbone): + def __init__( + self, + depths, + num_layers, + blockwise_num_heads, + blockwise_sr_ratios, + end_value, + patch_sizes, + strides, + include_rescaling=True, + image_shape=(224, 224, 3), + hidden_dims=None, + **kwargs, + ): + """A Backbone implementing the MixTransformer. 
+ + This architecture to be used as a backbone for the SegFormer + architecture [SegFormer: Simple and Efficient Design for Semantic + Segmentation with Transformers](https://arxiv.org/abs/2105.15203) + [Based on the TensorFlow implementation from DeepVision]( + https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer) + + Args: + depths: The number of transformer encoders to be used per layer in the + network. + num_layers: int. The number of Transformer layers. + blockwise_num_heads: list of integers, the number of heads to use + in the attention computation for each layer. + blockwise_sr_ratios: list of integers, the sequence reduction + ratio to perform for each layer on the sequence before key and + value projections. If set to > 1, a `Conv2D` layer is used to + reduce the length of the sequence. + end_value: The end value of the sequence. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. Defaults to `True`. + image_shape: optional shape tuple, defaults to (224, 224, 3). + hidden_dims: the embedding dims per hierarchical layer, used as + the levels of the feature pyramid. + patch_sizes: list of integers, the patch_size to apply for each layer. + strides: list of integers, stride to apply for each layer. + + Examples: + + Using the class with a `backbone`: + + ```python + images = np.ones(shape=(1, 96, 96, 3)) + labels = np.zeros(shape=(1, 96, 96, 1)) + backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet") + + # Evaluate model + model(images) + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=False), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + dpr = [x for x in np.linspace(0.0, end_value, sum(depths))] + + # === Layers === + cur = 0 + patch_embedding_layers = [] + transformer_blocks = [] + layer_norms = [] + + for i in range(num_layers): + patch_embed_layer = OverlappingPatchingAndEmbedding( + project_dim=hidden_dims[i], + patch_size=patch_sizes[i], + stride=strides[i], + name=f"patch_and_embed_{i}", + ) + patch_embedding_layers.append(patch_embed_layer) + + transformer_block = [ + HierarchicalTransformerEncoder( + project_dim=hidden_dims[i], + num_heads=blockwise_num_heads[i], + sr_ratio=blockwise_sr_ratios[i], + drop_prob=dpr[cur + k], + name=f"hierarchical_encoder_{i}_{k}", + ) + for k in range(depths[i]) + ] + transformer_blocks.append(transformer_block) + cur += depths[i] + layer_norms.append(keras.layers.LayerNormalization()) + + # === Functional Model === + image_input = keras.layers.Input(shape=image_shape) + x = image_input + + if include_rescaling: + x = keras.layers.Rescaling(scale=1 / 255)(x) + + pyramid_outputs = {} + for i in range(num_layers): + # Compute new height/width after the `proj` + # call in `OverlappingPatchingAndEmbedding` + stride = strides[i] + new_height, new_width = ( + int(ops.shape(x)[1] / stride), + int(ops.shape(x)[2] / stride), + ) + + x = patch_embedding_layers[i](x) + for blk in transformer_blocks[i]: + x = blk(x) + x = layer_norms[i](x) + x = keras.layers.Reshape( + (new_height, new_width, -1), name=f"output_level_{i}" + )(x) + pyramid_outputs[f"P{i + 1}"] = x + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.depths = depths + self.include_rescaling = include_rescaling + self.image_shape = image_shape + self.hidden_dims = hidden_dims + self.pyramid_outputs = 
pyramid_outputs + self.num_layers = num_layers + self.blockwise_num_heads = blockwise_num_heads + self.blockwise_sr_ratios = blockwise_sr_ratios + self.end_value = end_value + self.patch_sizes = patch_sizes + self.strides = strides + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "include_rescaling": self.include_rescaling, + "hidden_dims": self.hidden_dims, + "image_shape": self.image_shape, + "num_layers": self.num_layers, + "blockwise_num_heads": self.blockwise_num_heads, + "blockwise_sr_ratios": self.blockwise_sr_ratios, + "end_value": self.end_value, + "patch_sizes": self.patch_sizes, + "strides": self.strides, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py new file mode 100644 index 0000000000..4f1955297f --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from keras import models + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2], + "include_rescaling": True, + "image_shape": (16, 16, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], + "end_value": 0.1, + "patch_sizes": [7, 3], + "strides": [4, 2], + } + self.input_size = 16 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 2, 2, 8), + run_quantization_check=False, + run_mixed_precision_check=False, + ) + + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs + backbone = MiTBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P1", "P2"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** (int(k[1:]) + 1)) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py new file mode 100644 index 0000000000..a9a51b63ba --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py @@ -0,0 
+1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) + + +@keras_nlp_export("keras_nlp.models.MiTImageClassifier") +class MiTImageClassifier(ImageClassifier): + """MiTImageClassifier image classifier model. + + Args: + backbone: A `keras_nlp.models.MiTBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.MiTImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.MiTImageClassifier.from_preset( + "mit_b0_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. 
+ ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.MiTBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.MiTImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = MiTBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py new file mode 100644 index 0000000000..57b0671be2 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MiTImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = MiTBackbone( + depths=[2, 2, 2, 2], + include_rescaling=True, + image_shape=(16, 16, 3), + hidden_dims=[4, 8], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + end_value=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=MiTImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py new file mode 100644 index 0000000000..53d99fe484 --- /dev/null +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_layers.py @@ -0,0 +1,300 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import keras +from keras import ops +from keras import random + + +class OverlappingPatchingAndEmbedding(keras.layers.Layer): + def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): + """Overlapping Patching and Embedding layer. + + Differs from `PatchingAndEmbedding` in that the patch size does not + affect the sequence length. It's fully derived from the `stride` + parameter. Additionally, no positional embedding is done + as part of the layer - only a projection using a `Conv2D` layer. + + Args: + project_dim: integer, the dimensionality of the projection. + Defaults to `32`. + patch_size: integer, the size of the patches to encode. + Defaults to `7`. + stride: integer, the stride to use for the patching before + projection. Defaults to `5`. + """ + super().__init__(**kwargs) + + self.project_dim = project_dim + self.patch_size = patch_size + self.stride = stride + + self.proj = keras.layers.Conv2D( + filters=project_dim, + kernel_size=patch_size, + strides=stride, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + x = self.proj(x) + # B, H, W, C + shape = x.shape + x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = self.norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "project_dim": self.project_dim, + "patch_size": self.patch_size, + "stride": self.stride, + } + ) + return config + + +class HierarchicalTransformerEncoder(keras.layers.Layer): + """Hierarchical transformer encoder block implementation as a Keras Layer. 
+ + The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` + alternative for computational efficiency, and is meant to be used + within the SegFormer architecture. + + Args: + project_dim: integer, the dimensionality of the projection of the + encoder, and output of the `SegFormerMultiheadAttention` layer. + Due to the residual addition the input dimensionality has to be + equal to the output dimensionality. + num_heads: integer, the number of heads for the + `SegFormerMultiheadAttention` layer. + drop_prob: float, the probability of dropping a random + sample using the `DropPath` layer. Defaults to `0.0`. + layer_norm_epsilon: float, the epsilon for + `LayerNormalization` layers. Defaults to `1e-06` + sr_ratio: integer, the ratio to use within + `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D` + layer is used to reduce the length of the sequence. Defaults to `1`. + """ + + def __init__( + self, + project_dim, + num_heads, + sr_ratio=1, + drop_prob=0.0, + layer_norm_epsilon=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.project_dim = project_dim + self.num_heads = num_heads + self.drop_prop = drop_prob + + self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.attn = SegFormerMultiheadAttention( + project_dim, num_heads, sr_ratio + ) + self.drop_path = DropPath(drop_prob) + self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) + self.mlp = MixFFN( + channels=project_dim, + mid_channels=int(project_dim * 4), + ) + + def build(self, input_shape): + super().build(input_shape) + self.H = ops.sqrt(ops.cast(input_shape[1], "float32")) + self.W = ops.sqrt(ops.cast(input_shape[2], "float32")) + + def call(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "mlp": keras.saving.serialize_keras_object(self.mlp), + "project_dim": self.project_dim, + "num_heads": self.num_heads, + "drop_prop": self.drop_prop, + } + ) + return config + + +class MixFFN(keras.layers.Layer): + def __init__(self, channels, mid_channels): + super().__init__() + self.fc1 = keras.layers.Dense(mid_channels) + self.dwconv = keras.layers.DepthwiseConv2D( + kernel_size=3, + strides=1, + padding="same", + ) + self.fc2 = keras.layers.Dense(channels) + + def call(self, x): + x = self.fc1(x) + shape = ops.shape(x) + H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1])) + B, C = shape[0], shape[2] + x = ops.reshape(x, (B, H, W, C)) + x = self.dwconv(x) + x = ops.reshape(x, (B, -1, C)) + x = ops.nn.gelu(x) + x = self.fc2(x) + return x + + +class SegFormerMultiheadAttention(keras.layers.Layer): + def __init__(self, project_dim, num_heads, sr_ratio): + """Efficient MultiHeadAttention implementation as a Keras layer. + + A huge bottleneck in scaling transformers is the self-attention layer + with an O(n^2) complexity. + + SegFormerMultiheadAttention performs a sequence reduction (SR) operation + with a given ratio, to reduce the sequence length before performing key + and value projections, reducing the O(n^2) complexity to O(n^2/R) where + R is the sequence reduction ratio. + + Args: + project_dim: integer, the dimensionality of the projection + of the `SegFormerMultiheadAttention` layer. + num_heads: integer, the number of heads to use in the + attention computation. + sr_ratio: integer, the sequence reduction ratio to perform + on the sequence before key and value projections. 
+ """ + super().__init__() + self.num_heads = num_heads + self.sr_ratio = sr_ratio + self.scale = (project_dim // num_heads) ** -0.5 + self.q = keras.layers.Dense(project_dim) + self.k = keras.layers.Dense(project_dim) + self.v = keras.layers.Dense(project_dim) + self.proj = keras.layers.Dense(project_dim) + + if sr_ratio > 1: + self.sr = keras.layers.Conv2D( + filters=project_dim, + kernel_size=sr_ratio, + strides=sr_ratio, + padding="same", + ) + self.norm = keras.layers.LayerNormalization() + + def call(self, x): + input_shape = ops.shape(x) + H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) + B, C = input_shape[0], input_shape[2] + + q = self.q(x) + q = ops.reshape( + q, + ( + input_shape[0], + input_shape[1], + self.num_heads, + input_shape[2] // self.num_heads, + ), + ) + q = ops.transpose(q, [0, 2, 1, 3]) + + if self.sr_ratio > 1: + x = ops.reshape( + ops.transpose(x, [0, 2, 1]), + (B, H, W, C), + ) + x = self.sr(x) + x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) + x = ops.transpose(x, [0, 2, 1]) + x = self.norm(x) + + k = self.k(x) + v = self.v(x) + + k = ops.transpose( + ops.reshape( + k, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + v = ops.transpose( + ops.reshape( + v, + [B, -1, self.num_heads, C // self.num_heads], + ), + [0, 2, 1, 3], + ) + + attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale + attn = ops.nn.softmax(attn, axis=-1) + + attn = attn @ v + attn = ops.reshape( + ops.transpose(attn, [0, 2, 1, 3]), + [input_shape[0], input_shape[1], input_shape[2]], + ) + + x = self.proj(attn) + return x + + +class DropPath(keras.layers.Layer): + """Implements the DropPath layer. + + DropPath randomly drops samples during + training with a probability of `rate`. Note that this layer drops individual + samples within a batch and not the entire batch, whereas StochasticDepth + randomly drops the entire batch. + + Args: + rate: float, the probability of the residual branch being dropped. + seed: (Optional) integer. Used to create a random seed. 
+ """ + + def __init__(self, rate=0.5, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self._seed_val = seed + self.seed = random.SeedGenerator(seed=seed) + + def call(self, x, training=None): + if self.rate == 0.0 or not training: + return x + else: + batch_size = x.shape[0] or ops.shape(x)[0] + drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1) + drop_map = ops.cast( + random.uniform(drop_map_shape, seed=self.seed) > self.rate, + x.dtype, + ) + x = x / (1.0 - self.rate) + x = x * drop_map + return x + + def get_config(self): + config = super().get_config() + config.update({"rate": self.rate, "seed": self._seed_val}) + return config From 2797851c259ce36bb51c6e93baeb3b282b152663 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 20 Aug 2024 23:00:43 -0700 Subject: [PATCH 08/19] update input_image_shape -> image_shape (#1785) * update input_image_shape -> image_shape * update docstring example * code reformat * update tests --- .../models/csp_darknet/csp_darknet_backbone.py | 10 +++++----- .../csp_darknet/csp_darknet_backbone_test.py | 2 +- .../csp_darknet/csp_darknet_image_classifier.py | 2 +- .../csp_darknet_image_classifier_test.py | 2 +- .../src/models/densenet/densenet_backbone.py | 10 +++++----- .../src/models/densenet/densenet_backbone_test.py | 2 +- .../models/densenet/densenet_image_classifier.py | 2 +- .../densenet/densenet_image_classifier_test.py | 2 +- keras_nlp/src/models/resnet/resnet_backbone.py | 10 +++++----- .../src/models/resnet/resnet_backbone_test.py | 4 ++-- .../models/resnet/resnet_image_classifier_test.py | 2 +- keras_nlp/src/models/vgg/vgg_backbone.py | 15 +++++++-------- keras_nlp/src/models/vgg/vgg_backbone_test.py | 2 +- keras_nlp/src/models/vgg/vgg_image_classifier.py | 2 +- .../src/models/vgg/vgg_image_classifier_test.py | 2 +- keras_nlp/src/tests/test_case.py | 6 +++--- 16 files changed, 37 insertions(+), 38 deletions(-) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py index 2745f61d01..607c6895ba 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -38,7 +38,7 @@ class CSPDarkNetBackbone(Backbone): Use `"depthwise_block"` for depthwise conv block `"basic_block"` for basic conv block. Defaults to "basic_block". - input_image_shape: tuple. The input shape without the batch size. + image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. 
Examples: @@ -67,7 +67,7 @@ def __init__( stackwise_depth, include_rescaling, block_type="basic_block", - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), **kwargs, ): # === Functional Model === @@ -78,7 +78,7 @@ def __init__( ) base_channels = stackwise_num_filters[0] // 2 - image_input = layers.Input(shape=input_image_shape) + image_input = layers.Input(shape=image_shape) x = image_input if include_rescaling: x = layers.Rescaling(scale=1 / 255.0)(x) @@ -119,7 +119,7 @@ def __init__( self.stackwise_depth = stackwise_depth self.include_rescaling = include_rescaling self.block_type = block_type - self.input_image_shape = input_image_shape + self.image_shape = image_shape def get_config(self): config = super().get_config() @@ -129,7 +129,7 @@ def get_config(self): "stackwise_depth": self.stackwise_depth, "include_rescaling": self.include_rescaling, "block_type": self.block_type, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, } ) return config diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py index aaad4fe515..857e06039d 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -28,7 +28,7 @@ def setUp(self): "stackwise_depth": [1, 3, 3, 1], "include_rescaling": False, "block_type": "basic_block", - "input_image_shape": (224, 224, 3), + "image_shape": (224, 224, 3), } self.input_data = np.ones((2, 224, 224, 3), dtype="float32") diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py index 6b013bdcc0..09a7022122 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -78,7 +78,7 @@ class CSPDarkNetImageClassifier(ImageClassifier): stackwise_depth=[3, 9, 9, 3], include_rescaling=False, block_type="basic_block", - input_image_shape = (224, 224, 3), + image_shape = (224, 224, 3), ) classifier = keras_nlp.models.CSPDarkNetImageClassifier( backbone=backbone, diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py index a07bb017a3..33261c25b6 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -33,7 +33,7 @@ def setUp(self): stackwise_depth=[1, 3, 3, 1], include_rescaling=False, block_type="basic_block", - input_image_shape=(16, 16, 3), + image_shape=(16, 16, 3), ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index 8456fbcee6..60a5b28849 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone.py +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -35,7 +35,7 @@ class DenseNetBackbone(Backbone): include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - input_image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (224, 224, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. 
growth_rate: int, number of filters added by each dense block, @@ -62,13 +62,13 @@ def __init__( self, stackwise_num_repeats, include_rescaling=True, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), compression_ratio=0.5, growth_rate=32, **kwargs, ): # === Functional Model === - image_input = keras.layers.Input(shape=input_image_shape) + image_input = keras.layers.Input(shape=image_shape) x = image_input if include_rescaling: @@ -116,7 +116,7 @@ def __init__( self.include_rescaling = include_rescaling self.compression_ratio = compression_ratio self.growth_rate = growth_rate - self.input_image_shape = input_image_shape + self.image_shape = image_shape def get_config(self): config = super().get_config() @@ -126,7 +126,7 @@ def get_config(self): "include_rescaling": self.include_rescaling, "compression_ratio": self.compression_ratio, "growth_rate": self.growth_rate, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, } ) return config diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py index f0f8dac875..63f358035c 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone_test.py +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -26,7 +26,7 @@ def setUp(self): "include_rescaling": True, "compression_ratio": 0.5, "growth_rate": 32, - "input_image_shape": (224, 224, 3), + "image_shape": (224, 224, 3), } self.input_data = np.ones((2, 224, 224, 3), dtype="float32") diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py index 395e8f754d..130904be70 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -76,7 +76,7 @@ class DenseNetImageClassifier(ImageClassifier): stackwise_depth=[3, 9, 9, 3], include_rescaling=False, block_type="basic_block", - input_image_shape = (224, 224, 3), + image_shape = (224, 224, 3), ) classifier = keras_nlp.models.DenseNetImageClassifier( backbone=backbone, diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py index 60d77d489c..439a60008d 100644 --- a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py @@ -31,7 +31,7 @@ def setUp(self): include_rescaling=True, compression_ratio=0.5, growth_rate=32, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 0f4d7c139a..31698e0a1c 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -54,7 +54,7 @@ class ResNetBackbone(FeaturePyramidBackbone): include_rescaling: boolean. If `True`, rescale the input using `Rescaling` and `Normalization` layers. If `False`, do nothing. Defaults to `True`. - input_image_shape: tuple. The input shape without the batch size. + image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults to `"avg"`. 
@@ -107,7 +107,7 @@ def __init__( block_type, use_pre_activation=False, include_rescaling=True, - input_image_shape=(None, None, 3), + image_shape=(None, None, 3), pooling="avg", data_format=None, dtype=None, @@ -139,7 +139,7 @@ def __init__( num_stacks = len(stackwise_num_filters) # === Functional Model === - image_input = layers.Input(shape=input_image_shape) + image_input = layers.Input(shape=image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) x = layers.Normalization( @@ -254,7 +254,7 @@ def __init__( self.block_type = block_type self.use_pre_activation = use_pre_activation self.include_rescaling = include_rescaling - self.input_image_shape = input_image_shape + self.image_shape = image_shape self.pooling = pooling self.pyramid_outputs = pyramid_outputs @@ -268,7 +268,7 @@ def get_config(self): "block_type": self.block_type, "use_pre_activation": self.use_pre_activation, "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, "pooling": self.pooling, } ) diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 6d3f774559..a6a30362cd 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -27,7 +27,7 @@ def setUp(self): "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], - "input_image_shape": (None, None, 3), + "image_shape": (None, None, 3), "pooling": "avg", } self.input_size = 64 @@ -84,7 +84,7 @@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (None, None, 3), + "image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index f3f63a14a1..893ec42487 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -31,7 +31,7 @@ def setUp(self): stackwise_num_strides=[1, 2, 2], block_type="basic_block", use_pre_activation=True, - input_image_shape=(16, 16, 3), + image_shape=(16, 16, 3), include_rescaling=False, pooling="avg", ) diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py index 497381c0fc..b215261fed 100644 --- a/keras_nlp/src/models/vgg/vgg_backbone.py +++ b/keras_nlp/src/models/vgg/vgg_backbone.py @@ -20,8 +20,7 @@ @keras_nlp_export("keras_nlp.models.VGGBackbone") class VGGBackbone(Backbone): - """ - This class represents Keras Backbone of VGG model. + """This class represents Keras Backbone of VGG model. This class implements a VGG backbone as described in [Very Deep Convolutional Networks for Large-Scale Image Recognition]( @@ -36,7 +35,7 @@ class VGGBackbone(Backbone): 64, 128, 256, 512, 512]. include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a `Rescaling(1/255.0)` layer. - input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). pooling: bool, Optional pooling mode for feature extraction when `include_top` is `False`. 
- `None` means that the output of the model will be @@ -61,7 +60,7 @@ class VGGBackbone(Backbone): model = keras_nlp.models.VGGBackbone( stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], - input_shape = (224, 224, 3), + image_shape = (224, 224, 3), include_rescaling = False, pooling = "avg", ) @@ -74,13 +73,13 @@ def __init__( stackwise_num_repeats, stackwise_num_filters, include_rescaling, - input_image_shape=(224, 224, 3), + image_shape=(224, 224, 3), pooling="avg", **kwargs, ): # === Functional Model === - img_input = keras.layers.Input(shape=input_image_shape) + img_input = keras.layers.Input(shape=image_shape) x = img_input if include_rescaling: @@ -107,7 +106,7 @@ def __init__( self.stackwise_num_repeats = stackwise_num_repeats self.stackwise_num_filters = stackwise_num_filters self.include_rescaling = include_rescaling - self.input_image_shape = input_image_shape + self.image_shape = image_shape self.pooling = pooling def get_config(self): @@ -115,7 +114,7 @@ def get_config(self): "stackwise_num_repeats": self.stackwise_num_repeats, "stackwise_num_filters": self.stackwise_num_filters, "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, + "image_shape": self.image_shape, "pooling": self.pooling, } diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py index 05ed33ba0f..d5521ca92d 100644 --- a/keras_nlp/src/models/vgg/vgg_backbone_test.py +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -24,7 +24,7 @@ def setUp(self): self.init_kwargs = { "stackwise_num_repeats": [2, 3, 3], "stackwise_num_filters": [8, 64, 64], - "input_image_shape": (16, 16, 3), + "image_shape": (16, 16, 3), "include_rescaling": False, "pooling": "avg", } diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py index a26fbfbc30..d849586ed8 100644 --- a/keras_nlp/src/models/vgg/vgg_image_classifier.py +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -65,7 +65,7 @@ class VGGImageClassifier(ImageClassifier): backbone = keras_nlp.models.VGGBackbone( stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], - input_shape = (224, 224, 3), + image_shape = (224, 224, 3), include_rescaling = False, pooling = "avg", ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py index 4a2573e496..20d855cb66 100644 --- a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -27,7 +27,7 @@ def setUp(self): self.backbone = VGGBackbone( stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], - input_image_shape=(4, 4, 3), + image_shape=(4, 4, 3), include_rescaling=False, pooling="max", ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 9a94d6357f..8e63bc19d9 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -501,10 +501,10 @@ def run_vision_backbone_test( input_data = ops.transpose(input_data, axes=(2, 0, 1)) elif len(input_data_shape) == 4: input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) - if "input_image_shape" in init_kwargs: + if "image_shape" in init_kwargs: init_kwargs = init_kwargs.copy() - init_kwargs["input_image_shape"] = tuple( - reversed(init_kwargs["input_image_shape"]) + init_kwargs["image_shape"] = tuple( + reversed(init_kwargs["image_shape"]) ) 
self.run_backbone_test( cls=cls, From 18f88803daa194b65914a77d6a25cade585aaf6c Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 21 Aug 2024 17:13:02 -0700 Subject: [PATCH 09/19] Create __init__.py (#1788) add missing __init__ file to vit_det --- keras_nlp/src/models/vit_det/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 keras_nlp/src/models/vit_det/__init__.py diff --git a/keras_nlp/src/models/vit_det/__init__.py b/keras_nlp/src/models/vit_det/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/vit_det/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 2ee893ce7ceb292f98e6dc184faa74e959df8836 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:29:21 -0700 Subject: [PATCH 10/19] Hack package build script to rename to keras-hub (#1793) This is a temporary way to test out the keras-hub branch. - Does a global rename of all symbols during package build. - Registers the "old" name on symbol export for saving compat. - Adds a github action to publish every commit to keras-hub as a new package. - Removes our descriptions on PyPI temporarily, until we want to message this more broadly. 
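A minimal sketch of the compat registration described above (illustrative only;
`MyLayer` is a placeholder, and the real hook is the `maybe_register_serializable`
change in `keras_nlp/src/api_export.py` below):

```python
import keras

# Register the same symbol under both the legacy and the new package name so
# that configs saved as "keras_nlp>MyLayer" keep deserializing after the
# build-time rename. "compat_package_name" is rewritten by pip_build.py.
@keras.saving.register_keras_serializable(package="compat_package_name")
@keras.saving.register_keras_serializable(package="keras_nlp")
class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        return inputs
```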
--- .github/workflows/publish-hub-to-pypi.yml | 43 +++++++++++++++++++++++ keras_nlp/src/api_export.py | 5 +++ keras_nlp/src/utils/preset_utils.py | 3 +- pip_build.py | 19 ++++++---- setup.py | 7 ++-- 5 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/publish-hub-to-pypi.yml diff --git a/.github/workflows/publish-hub-to-pypi.yml b/.github/workflows/publish-hub-to-pypi.yml new file mode 100644 index 0000000000..838ca9b698 --- /dev/null +++ b/.github/workflows/publish-hub-to-pypi.yml @@ -0,0 +1,43 @@ +name: Publish Hub to PyPI + +on: + push: + branches: + - keras-hub + +permissions: + contents: read + +jobs: + build-and-publish: + name: Build and publish Hub to PyPI + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9 + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "::set-output name=dir::$(pip cache dir)" + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off + - name: Build a binary wheel and a source tarball + run: >- + python pip_build.py + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN_HUB }} diff --git a/keras_nlp/src/api_export.py b/keras_nlp/src/api_export.py index cfa3519ce9..93e7b54c2f 100644 --- a/keras_nlp/src/api_export.py +++ b/keras_nlp/src/api_export.py @@ -24,6 +24,11 @@ def maybe_register_serializable(symbol): if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"): + # We register twice, first with the old name, second with the new name, + # so loading still works under the old name. + # TODO replace compat_package_name with keras-nlp after rename. + compat_name = "compat_package_name" + keras.saving.register_keras_serializable(package=compat_name)(symbol) keras.saving.register_keras_serializable(package="keras_nlp")(symbol) diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 0277eb74a7..297e4bcb7f 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -99,7 +99,8 @@ def list_presets(cls): def list_subclasses(cls): """Find all registered subclasses of a class.""" - custom_objects = keras.saving.get_custom_objects().values() + # Deduplicate the lists, since we have to register object twice for compat. 
+ custom_objects = set(keras.saving.get_custom_objects().values()) subclasses = [] for x in custom_objects: if inspect.isclass(x) and x != cls and issubclass(x, cls): diff --git a/pip_build.py b/pip_build.py index 7fed385c71..f14f24312b 100644 --- a/pip_build.py +++ b/pip_build.py @@ -36,7 +36,7 @@ import re import shutil -package = "keras_nlp" +package = "keras_hub" build_directory = "tmp_build_dir" dist_directory = "dist" to_copy = ["setup.py", "setup.cfg", "README.md"] @@ -48,15 +48,15 @@ def ignore_files(_, filenames): def export_version_string(version, is_nightly=False): """Export Version and Package Name.""" + date = datetime.datetime.now() + version += f".dev{date.strftime('%Y%m%d%H%M%S')}" if is_nightly: - date = datetime.datetime.now() - version += f".dev{date.strftime('%Y%m%d%H')}" - # Replaces `name="keras-nlp"` in `setup.py` with `keras-nlp-nightly` + # Replaces `name="keras-hub"` in `setup.py` with `keras-hub-nightly` with open("setup.py") as f: setup_contents = f.read() with open("setup.py", "w") as f: setup_contents = setup_contents.replace( - 'name="keras-nlp"', 'name="keras-nlp-nightly"' + 'name="keras-hub"', 'name="keras-hub-nightly"' ) f.write(setup_contents) @@ -78,11 +78,18 @@ def copy_source_to_build_directory(root_path): os.chdir(root_path) os.mkdir(build_directory) shutil.copytree( - package, os.path.join(build_directory, package), ignore=ignore_files + "keras_nlp", os.path.join(build_directory, package), ignore=ignore_files ) for fname in to_copy: shutil.copy(fname, os.path.join(f"{build_directory}", fname)) os.chdir(build_directory) + # TODO: remove all of this when our code is actually renamed in the repo. + os.system("grep -lR 'keras_nlp' . | xargs sed -i 's/keras_nlp/keras_hub/g'") + os.system("grep -lR 'keras-nlp' . | xargs sed -i 's/keras-nlp/keras-hub/g'") + os.system("grep -lR 'KerasNLP' . | xargs sed -i 's/KerasNLP/KerasHub/g'") + os.system( + "grep -lR 'compat_package_name' . | xargs sed -i 's/compat_package_name/keras_nlp/g'" + ) def build(root_path, is_nightly=False): diff --git a/setup.py b/setup.py index f664aa5a61..f2ec7b84be 100644 --- a/setup.py +++ b/setup.py @@ -44,11 +44,8 @@ def get_version(rel_path): setup( name="keras-nlp", - description=( - "Industry-strength Natural Language Processing extensions for Keras." - ), - long_description=README, - long_description_content_type="text/markdown", + description="🚧🚧🚧 Work in progress. 🚧🚧🚧 More details soon!", + long_description="🚧🚧🚧 Work in progress. 🚧🚧🚧 More details soon!", version=VERSION, url="https://github.com/keras-team/keras-nlp", author="Keras team", From fdf6b6bdc75249c8beff4317a768c14101c7699f Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Tue, 27 Aug 2024 02:30:08 +0800 Subject: [PATCH 11/19] Add CLIP and T5XXL for StableDiffusionV3 (#1790) * Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`. 
* Make CLIPTextEncoder as Backbone * Add `T5XXLPreprocessor` and remove `T5XXLTokenizer` Add `CLIPPreprocessor` * Use `tf = None` at the top * Replace manual implementation of `CLIPAttention` with `MultiHeadAttention` --- .../models/stable_diffusion_v3/__init__.py | 13 ++ .../stable_diffusion_v3/clip_encoder_block.py | 103 +++++++++++ .../stable_diffusion_v3/clip_preprocessor.py | 104 +++++++++++ .../clip_preprocessor_test.py | 78 ++++++++ .../stable_diffusion_v3/clip_text_encoder.py | 141 +++++++++++++++ .../stable_diffusion_v3/clip_tokenizer.py | 167 ++++++++++++++++++ .../clip_tokenizer_test.py | 69 ++++++++ .../t5_xxl_preprocessor.py | 84 +++++++++ .../t5_xxl_preprocessor_test.py | 74 ++++++++ .../t5_xxl_text_encoder.py | 148 ++++++++++++++++ 10 files changed, 981 insertions(+) create mode 100644 keras_nlp/src/models/stable_diffusion_v3/__init__.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py create mode 100644 keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py diff --git a/keras_nlp/src/models/stable_diffusion_v3/__init__.py b/keras_nlp/src/models/stable_diffusion_v3/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py b/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py new file mode 100644 index 0000000000..c4e16f8626 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py @@ -0,0 +1,103 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
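+# `CLIPEncoderBlock` below is a pre-norm transformer encoder block:
+# layer norm -> causal multi-head self-attention -> residual add, then
+# layer norm -> two-layer MLP with (by default) CLIP's "quick GELU"
+# activation -> residual add.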
+from keras import layers +from keras import ops + + +def quick_gelu(x): + return x * ops.sigmoid(1.702 * x) + + +class CLIPEncoderBlock(layers.Layer): + def __init__( + self, + hidden_dim, + num_heads, + intermediate_dim, + intermediate_activation="quick_gelu", + **kwargs, + ): + super().__init__(**kwargs) + if hidden_dim % num_heads != 0: + raise ValueError( + "`hidden_dim` must be divisible by `num_heads`. " + f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}" + ) + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.intermediate_activation = intermediate_activation + + if intermediate_activation == "quick_gelu": + intermediate_activation = quick_gelu + + self.layer_norm_1 = layers.LayerNormalization( + epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1" + ) + self.attention = layers.MultiHeadAttention( + num_heads, + hidden_dim // num_heads, + dtype=self.dtype_policy, + name="attention", + ) + self.layer_norm_2 = layers.LayerNormalization( + epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2" + ) + self.dense_1 = layers.Dense( + self.intermediate_dim, dtype=self.dtype_policy, name="dense_1" + ) + self.activation = layers.Activation( + intermediate_activation, dtype=self.dtype_policy, name="activation" + ) + self.dense_2 = layers.Dense( + self.hidden_dim, dtype=self.dtype_policy, name="dense_2" + ) + + def build(self, input_shape): + self.layer_norm_1.build(input_shape) + self.attention.build(input_shape, input_shape, input_shape) + self.layer_norm_2.build(input_shape) + self.dense_1.build(input_shape) + input_shape = self.dense_1.compute_output_shape(input_shape) + self.dense_2.build(input_shape) + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.hidden_dim + return outputs_shape + + def call(self, x, training=None): + residual = x + x = self.layer_norm_1(x) + x = self.attention(x, x, x, training=training, use_causal_mask=True) + x = ops.add(residual, x) + + residual = x + x = self.dense_1(self.layer_norm_2(residual)) + x = self.activation(x) + x = self.dense_2(x) + x = ops.add(residual, x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "intermediate_activation": self.intermediate_activation, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py new file mode 100644 index 0000000000..ca1b6c598e --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py @@ -0,0 +1,104 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
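+# `CLIPPreprocessor` below lowercases inputs (by default), tokenizes them with
+# `CLIPTokenizer`, and packs them with a start token, padding with the end
+# token up to `sequence_length` (77 by default).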
+import keras + +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) + +try: + import tensorflow as tf +except ImportError: + tf = None + + +class CLIPPreprocessor(Preprocessor): + tokenizer_cls = CLIPTokenizer + + def __init__( + self, + tokenizer, + sequence_length=77, + add_start_token=True, + add_end_token=False, + to_lower=True, + pad_with_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.to_lower = to_lower + self.pad_with_end_token = pad_with_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + pad_value = self.tokenizer.pad_token_id + if self.pad_with_end_token: + pad_value = self.tokenizer.end_token_id + + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=pad_value, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + # TODO: Use `@tf_preprocessing_function` after rebasing. + def call(self, x, y=None, sample_weight=None, sequence_length=None): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "T5XXL requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using T5XXL" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." + ) + if self.to_lower: + x = tf.strings.lower(x) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + "to_lower": self.to_lower, + "pad_with_end_token": self.pad_with_end_token, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py new file mode 100644 index 0000000000..4365a14673 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +from keras_nlp.src.models.stable_diffusion_v3.clip_preprocessor import ( + CLIPPreprocessor, +) +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CLIPPreprocessorTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i + 1) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + self.tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = [" airplane airport"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=CLIPPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = [" airplane airport"] * 4 + preprocessor = CLIPPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + pad_with_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = " airplane airport" + preprocessor = CLIPPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [5, 1, 2, 1]) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("TODO") + for preset in CLIPPreprocessor.presets: + self.run_preset_test( + cls=CLIPPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py new file mode 100644 index 0000000000..d4a5cbc94f --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py @@ -0,0 +1,141 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
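+# `CLIPTextEncoder` below is a functional `Backbone`: token + position
+# embeddings, a stack of `CLIPEncoderBlock`s, and a final layer norm. It
+# outputs the encoder sequence, a pooled vector taken at the position of the
+# largest token id (CLIP's end-of-text token), its projection, and optionally
+# an intermediate layer's (normalized) output.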
+from keras import layers +from keras import ops + +from keras_nlp.src.layers.modeling.token_and_position_embedding import ( + TokenAndPositionEmbedding, +) +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import ( + CLIPEncoderBlock, +) + + +class CLIPTextEncoder(Backbone): + def __init__( + self, + embedding_dim, + hidden_dim, + num_layers, + num_heads, + intermediate_dim, + intermediate_activation="quick_gelu", + intermediate_output_index=None, + vocabulary_size=49408, + sequence_length=77, + dtype=None, + **kwargs, + ): + if ( + intermediate_output_index is not None + and intermediate_output_index < 0 + ): + intermediate_output_index += num_layers + + # === Layers === + self.embedding = TokenAndPositionEmbedding( + vocabulary_size=vocabulary_size, + sequence_length=sequence_length, + embedding_dim=embedding_dim, + dtype=dtype, + name="embedding", + ) + self.encoder_layers = [ + CLIPEncoderBlock( + hidden_dim, + num_heads, + intermediate_dim, + intermediate_activation, + dtype=dtype, + ) + for _ in range(num_layers) + ] + self.layer_norm = layers.LayerNormalization( + epsilon=0.00001, dtype=dtype, name="layer_norm" + ) + self.text_projection = layers.Dense( + hidden_dim, + use_bias=False, + dtype=dtype, + name="text_projection", + ) + + # === Functional Model === + encoder_token_ids = layers.Input( + shape=(sequence_length,), dtype="int32", name="encoder_token_ids" + ) + x = self.embedding(encoder_token_ids) + encoder_intermediate_output = None + # Encoder. + for i, block in enumerate(self.encoder_layers): + x = block(x) + if i == intermediate_output_index: + encoder_intermediate_output = x + x = self.layer_norm(x) + encoder_output = x + if encoder_intermediate_output is not None: + encoder_intermediate_output = self.layer_norm( + encoder_intermediate_output + ) + # Projection. 
+ indices = ops.expand_dims( + ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1 + ) + pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1) + pooled_output = ops.squeeze(pooled_output, axis=1) + projection_output = self.text_projection(pooled_output) + + outputs = { + "encoder_sequence_output": encoder_output, + "encoder_pooled_output": pooled_output, + "encoder_projection_output": projection_output, + } + if intermediate_output_index is not None: + outputs["encoder_intermediate_output"] = encoder_intermediate_output + + super().__init__( + inputs={"encoder_token_ids": encoder_token_ids}, + outputs=outputs, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.intermediate_activation = intermediate_activation + self.intermediate_output_index = intermediate_output_index + self.vocabulary_size = vocabulary_size + self.sequence_length = sequence_length + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_dim": self.embedding_dim, + "hidden_dim": self.hidden_dim, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "intermediate_activation": self.intermediate_activation, + "intermediate_output_index": self.intermediate_output_index, + "vocabulary_size": self.vocabulary_size, + "sequence_length": self.sequence_length, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py new file mode 100644 index 0000000000..59c046d9f5 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer.py @@ -0,0 +1,167 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_nlp.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch +from keras_nlp.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe + +try: + import tensorflow as tf +except ImportError: + tf = None + + +class CLIPTokenizer(BytePairTokenizer): + def __init__(self, vocabulary=None, merges=None, **kwargs): + self.start_token = "<|startoftext|>" + self.end_token = "<|endoftext|>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[self.start_token, self.end_token], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + if self.end_token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{self.end_token}'` in the provided " + f"`vocabulary`. Please provide `'{self.end_token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." 
+ ) + + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) + self.pad_token_id = 0 + else: + self.end_token_id = None + self.start_token_id = None + self.pad_token_id = None + + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + + # In StableDiffusionV3, we need to add `` to the last word. + words = tf.strings.reduce_join(words, axis=1, separator=" ") + words = tf.strings.join([words, ""]) + words = tf.strings.split(words, sep=" ") + + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, axis=1, separator=" " + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + + # Strip and remove empty tokens. + raw_tokens = tf.strings.strip(raw_tokens) + raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "") + + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are in cache, + # we will process the unseen tokens. Otherwise return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. + tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), + ) + + # Convert to a dense output if `sequence_length` is set. + if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens + + def detokenize(self, inputs): + self._check_vocabulary() + inputs, unbatched, _ = convert_to_ragged_batch(inputs) + inputs = tf.cast(inputs, self.dtype) + unicode_text = tf.strings.reduce_join( + self.id_to_token_map.lookup(inputs), axis=-1 + ) + + # When detokenizing, we need to remove and extra whitespace. 
+ unicode_text = tf.strings.regex_replace(unicode_text, r"", " ") + unicode_text = tf.strings.strip(unicode_text) + + split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8") + outputs = tf.strings.reduce_join( + self.unicode2byte.lookup(split_unicode_text), axis=-1 + ) + + if unbatched: + outputs = tf.squeeze(outputs, 0) + return outputs + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py new file mode 100644 index 0000000000..4ceaea8057 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_tokenizer_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import ( + CLIPTokenizer, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CLIPTokenizerTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + self.merges = merges + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = ["airplane ", " airport"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=CLIPTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + # Whitespaces should be removed. 
+ expected_output=[[0, 1], [0, 2]], + expected_detokenize_output=["airplane", "airport"], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + CLIPTokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.skipTest( + "TODO: Add preset from `hf://openai/clip-vit-large-patch14`" + ) + self.run_preset_test( + cls=CLIPTokenizer, + preset="llama3_8b_en", + input_data=["The quick brown fox."], + expected_output=[[791, 4062, 14198, 39935, 13]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest( + "TODO: Add preset from `hf://openai/clip-vit-large-patch14`" + ) + for preset in CLIPTokenizer.presets: + self.run_preset_test( + cls=CLIPTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py new file mode 100644 index 0000000000..c8b6ef7566 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py @@ -0,0 +1,84 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) + + +class T5XXLPreprocessor(Preprocessor): + tokenizer_cls = T5Tokenizer + + def __init__( + self, + tokenizer, + sequence_length=256, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call(self, x, y=None, sample_weight=None, sequence_length=None): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "T5XXL requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using T5XXL" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." 
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py new file mode 100644 index 0000000000..90b7dfaf9c --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_preprocessor_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest + +from keras_nlp.src.models.stable_diffusion_v3.t5_xxl_preprocessor import ( + T5XXLPreprocessor, +) +from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class T5XXLPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = T5Tokenizer( + proto=os.path.join(self.get_test_data_dir(), "t5_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 10, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=T5XXLPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[4, 9, 5, 7, 1, 0, 0, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = T5XXLPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = T5XXLPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [4, 9, 5, 1]) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("TODO") + for preset in T5XXLPreprocessor.presets: + self.run_preset_test( + cls=T5XXLPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py new file mode 100644 index 0000000000..9f4e5ae3a1 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py @@ -0,0 +1,148 @@ +# Copyright 2024 The 
KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_nlp.src.models.t5.t5_transformer_layer import T5TransformerLayer + + +class T5XXLTextEncoder(Backbone): + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + key_value_dim=None, + dropout=0.1, + activation="relu", + use_gated_activation=True, + layer_norm_epsilon=1e-06, + tie_embedding_weights=True, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_embedding_weights, + embeddings_initializer=keras.initializers.TruncatedNormal(1.0), + dtype=dtype, + name="token_embedding", + ) + self.encoder_embedding_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="encoder_embedding_dropout", + ) + self.encoder_transformer_layers = [] + for i in range(num_layers): + layer = T5TransformerLayer( + is_decoder=False, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + key_value_dim=key_value_dim or hidden_dim // num_heads, + dropout=dropout, + activation=activation, + layer_norm_epsilon=layer_norm_epsilon, + num_heads=num_heads, + use_gated_activation=use_gated_activation, + use_relative_attention_bias=bool(i == 0), + dtype=dtype, + name=f"transformer_encoder_layer_{i}", + ) + self.encoder_transformer_layers.append(layer) + self.encoder_layer_norm = T5LayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="encoder_output_layer_norm", + ) + self.encoder_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="encoder_output_dropout", + ) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + # Encoder. 
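+        # Embed the tokens, apply dropout, then run the encoder stack. Only the +        # first transformer layer computes the relative attention position bias +        # (`use_relative_attention_bias` is True for layer 0); later layers reuse +        # the bias returned through `position_bias`.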
+ x = self.token_embedding(encoder_token_id_input) + x = self.encoder_embedding_dropout(x) + encoder_attention_mask = encoder_padding_mask_input[:, None, :] + position_bias = None + for transformer_layer in self.encoder_transformer_layers: + output = transformer_layer( + x, + attention_mask=encoder_attention_mask, + position_bias=position_bias, + use_causal_mask=False, + ) + if isinstance(output, tuple): + x, position_bias = output + x = self.encoder_layer_norm(x) + x = self.encoder_dropout(x) + encoder_output = x + + super().__init__( + { + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + }, + outputs=encoder_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.activation = keras.activations.get(activation) + self.key_value_dim = key_value_dim + self.dropout = dropout + self.use_gated_activation = use_gated_activation + self.layer_norm_epsilon = layer_norm_epsilon + self.tie_embedding_weights = tie_embedding_weights + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "activation": keras.activations.serialize(self.activation), + "key_value_dim": self.key_value_dim, + "dropout": self.dropout, + "use_gated_activation": self.use_gated_activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + "tie_embedding_weights": self.tie_embedding_weights, + } + ) + return config From beae2f410195fa399058dd27db9abb6e0eaf4eba Mon Sep 17 00:00:00 2001 From: Siva Sravana Kumar Neeli <113718461+sineeli@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:55:14 -0700 Subject: [PATCH 12/19] Add Bounding Box Utils (#1791) * Bounding box utils * - Correct test cases * - Remove hard tensorflow dtype * - fix api gen * - Fix import for test cases - Use setup for converters test case * - fix api_gen issue * - FIx api gen * - Fix api gen error * - Correct test cases as per new api changes --- keras_nlp/api/__init__.py | 1 + keras_nlp/api/bounding_box/__init__.py | 23 + keras_nlp/src/bounding_box/__init__.py | 13 + keras_nlp/src/bounding_box/converters.py | 529 ++++++++++++++++++ keras_nlp/src/bounding_box/converters_test.py | 365 ++++++++++++ keras_nlp/src/bounding_box/to_dense.py | 95 ++++ keras_nlp/src/bounding_box/to_dense_test.py | 37 ++ keras_nlp/src/bounding_box/to_ragged.py | 99 ++++ keras_nlp/src/bounding_box/to_ragged_test.py | 101 ++++ keras_nlp/src/bounding_box/validate_format.py | 99 ++++ 10 files changed, 1362 insertions(+) create mode 100644 keras_nlp/api/bounding_box/__init__.py create mode 100644 keras_nlp/src/bounding_box/__init__.py create mode 100644 keras_nlp/src/bounding_box/converters.py create mode 100644 keras_nlp/src/bounding_box/converters_test.py create mode 100644 keras_nlp/src/bounding_box/to_dense.py create mode 100644 keras_nlp/src/bounding_box/to_dense_test.py create mode 100644 keras_nlp/src/bounding_box/to_ragged.py create mode 100644 keras_nlp/src/bounding_box/to_ragged_test.py create mode 100644 keras_nlp/src/bounding_box/validate_format.py diff --git a/keras_nlp/api/__init__.py b/keras_nlp/api/__init__.py index d0dc4576c6..46b16d5d8c 100644 --- a/keras_nlp/api/__init__.py +++ b/keras_nlp/api/__init__.py @@ -17,6 +17,7 @@ since your 
modifications would be overwritten. """ +from keras_nlp.api import bounding_box from keras_nlp.api import layers from keras_nlp.api import metrics from keras_nlp.api import models diff --git a/keras_nlp/api/bounding_box/__init__.py b/keras_nlp/api/bounding_box/__init__.py new file mode 100644 index 0000000000..18be1cd9aa --- /dev/null +++ b/keras_nlp/api/bounding_box/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_nlp.src.bounding_box.converters import convert_format +from keras_nlp.src.bounding_box.to_dense import to_dense +from keras_nlp.src.bounding_box.to_ragged import to_ragged +from keras_nlp.src.bounding_box.validate_format import validate_format diff --git a/keras_nlp/src/bounding_box/__init__.py b/keras_nlp/src/bounding_box/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/bounding_box/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/bounding_box/converters.py b/keras_nlp/src/bounding_box/converters.py new file mode 100644 index 0000000000..0e363fc6f7 --- /dev/null +++ b/keras_nlp/src/bounding_box/converters.py @@ -0,0 +1,529 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Converter functions for working with bounding box formats.""" + +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +# Internal exception to propagate the fact images was not passed to a converter +# that needs it. 
+class RequiresImagesException(Exception): + pass + + +ALL_AXES = 4 + + +def _encode_box_to_deltas( + anchors, + boxes, + anchor_format: str, + box_format: str, + variance=None, + image_shape=None, +): + """Converts bounding_boxes from `center_yxhw` to delta format.""" + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + encoded_anchors = convert_format( + anchors, + source=anchor_format, + target="center_yxhw", + image_shape=image_shape, + ) + boxes = convert_format( + boxes, source=box_format, target="center_yxhw", image_shape=image_shape + ) + anchor_dimensions = ops.maximum( + encoded_anchors[..., 2:], keras.backend.epsilon() + ) + box_dimensions = ops.maximum(boxes[..., 2:], keras.backend.epsilon()) + # anchors be unbatched, boxes can either be batched or unbatched. + boxes_delta = ops.concatenate( + [ + (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions, + ops.log(box_dimensions / anchor_dimensions), + ], + axis=-1, + ) + if variance is not None: + boxes_delta /= variance + return boxes_delta + + +def _decode_deltas_to_boxes( + anchors, + boxes_delta, + anchor_format: str, + box_format: str, + variance=None, + image_shape=None, +): + """Converts bounding_boxes from delta format to `center_yxhw`.""" + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + + def decode_single_level(anchor, box_delta): + encoded_anchor = convert_format( + anchor, + source=anchor_format, + target="center_yxhw", + image_shape=image_shape, + ) + if variance is not None: + box_delta = box_delta * variance + # anchors be unbatched, boxes can either be batched or unbatched. 
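+        # Invert the delta encoding (in `center_yxhw` coordinates): +        # center = delta_center * anchor_size + anchor_center and +        # size = exp(delta_size) * anchor_size.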
+ box = ops.concatenate( + [ + box_delta[..., :2] * encoded_anchor[..., 2:] + + encoded_anchor[..., :2], + ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:], + ], + axis=-1, + ) + box = convert_format( + box, + source="center_yxhw", + target=box_format, + image_shape=image_shape, + ) + return box + + if isinstance(anchors, dict) and isinstance(boxes_delta, dict): + boxes = {} + for lvl, anchor in anchors.items(): + boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl]) + return boxes + else: + return decode_single_level(anchors, boxes_delta) + + +def _center_yxhw_to_xyxy(boxes, images=None, image_shape=None): + y, x, height, width = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], + axis=-1, + ) + + +def _center_xywh_to_xyxy(boxes, images=None, image_shape=None): + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0], + axis=-1, + ) + + +def _xywh_to_xyxy(boxes, images=None, image_shape=None): + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([x, y, x + width, y + height], axis=-1) + + +def _xyxy_to_center_yxhw(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + (top + bottom) / 2.0, + (left + right) / 2.0, + bottom - top, + right - left, + ], + axis=-1, + ) + + +def _rel_xywh_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + image_width * x, + image_height * y, + image_width * (x + width), + image_height * (y + height), + ], + axis=-1, + ) + + +def _xyxy_no_op(boxes, images=None, image_shape=None): + return boxes + + +def _xyxy_to_xywh(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [left, top, right - left, bottom - top], + axis=-1, + ) + + +def _xyxy_to_rel_xywh(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + left, right = ( + left / image_width, + right / image_width, + ) + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + [left, top, right - left, bottom - top], + axis=-1, + ) + + +def _xyxy_to_center_xywh(boxes, images=None, image_shape=None): + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate( + [ + (left + right) / 2.0, + (top + bottom) / 2.0, + right - left, + bottom - top, + ], + axis=-1, + ) + + +def _rel_xyxy_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left * image_width, right * image_width + top, bottom = top * image_height, bottom * image_height + return ops.concatenate( + [left, top, right, bottom], + axis=-1, + ) + + +def _xyxy_to_rel_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left / image_width, right / image_width + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + 
[left, top, right, bottom], + axis=-1, + ) + + +def _yxyx_to_xyxy(boxes, images=None, image_shape=None): + y1, x1, y2, x2 = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([x1, y1, x2, y2], axis=-1) + + +def _rel_yxyx_to_xyxy(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + top, left, bottom, right = ops.split( + boxes, + ALL_AXES, + axis=-1, + ) + left, right = left * image_width, right * image_width + top, bottom = top * image_height, bottom * image_height + return ops.concatenate( + [left, top, right, bottom], + axis=-1, + ) + + +def _xyxy_to_yxyx(boxes, images=None, image_shape=None): + x1, y1, x2, y2 = ops.split(boxes, ALL_AXES, axis=-1) + return ops.concatenate([y1, x1, y2, x2], axis=-1) + + +def _xyxy_to_rel_yxyx(boxes, images=None, image_shape=None): + image_height, image_width = _image_shape(images, image_shape, boxes) + left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1) + left, right = left / image_width, right / image_width + top, bottom = top / image_height, bottom / image_height + return ops.concatenate( + [top, left, bottom, right], + axis=-1, + ) + + +TO_XYXY_CONVERTERS = { + "xywh": _xywh_to_xyxy, + "center_xywh": _center_xywh_to_xyxy, + "center_yxhw": _center_yxhw_to_xyxy, + "rel_xywh": _rel_xywh_to_xyxy, + "xyxy": _xyxy_no_op, + "rel_xyxy": _rel_xyxy_to_xyxy, + "yxyx": _yxyx_to_xyxy, + "rel_yxyx": _rel_yxyx_to_xyxy, +} + +FROM_XYXY_CONVERTERS = { + "xywh": _xyxy_to_xywh, + "center_xywh": _xyxy_to_center_xywh, + "center_yxhw": _xyxy_to_center_yxhw, + "rel_xywh": _xyxy_to_rel_xywh, + "xyxy": _xyxy_no_op, + "rel_xyxy": _xyxy_to_rel_xyxy, + "yxyx": _xyxy_to_yxyx, + "rel_yxyx": _xyxy_to_rel_yxyx, +} + + +@keras_nlp_export("keras_nlp.bounding_box.convert_format") +def convert_format( + boxes, source, target, images=None, image_shape=None, dtype="float32" +): + f"""Converts bounding_boxes from one format to another. + + Supported formats are: + - `"xyxy"`, also known as `corners` format. In this format the first four + axes represent `[left, top, right, bottom]` in that order. + - `"rel_xyxy"`. In this format, the axes are the same as `"xyxy"` but the x + coordinates are normalized using the image width, and the y axes the + image height. All values in `rel_xyxy` are in the range `(0, 1)`. + - `"xywh"`. In this format the first four axes represent + `[left, top, width, height]`. + - `"rel_xywh". In this format the first four axes represent + [left, top, width, height], just like `"xywh"`. Unlike `"xywh"`, the + values are in the range (0, 1) instead of absolute pixel values. + - `"center_xyWH"`. In this format the first two coordinates represent the x + and y coordinates of the center of the bounding box, while the last two + represent the width and height of the bounding box. + - `"center_yxHW"`. In this format the first two coordinates represent the y + and x coordinates of the center of the bounding box, while the last two + represent the height and width of the bounding box. + - `"yxyx"`. In this format the first four axes represent + [top, left, bottom, right] in that order. + - `"rel_yxyx"`. In this format, the axes are the same as `"yxyx"` but the x + coordinates are normalized using the image width, and the y axes the + image height. All values in `rel_yxyx` are in the range (0, 1). + Formats are case insensitive. It is recommended that you capitalize width + and height to maximize the visual difference between `"xyWH"` and `"xyxy"`. 
+ + Relative formats, abbreviated `rel`, make use of the shapes of the `images` + passed. In these formats, the coordinates, widths, and heights are all + specified as percentages of the host image. `images` may be a ragged + Tensor. Note that using a ragged Tensor for images may cause a substantial + performance loss, as each image will need to be processed separately due to + the mismatching image shapes. + + Example: + + ```python + boxes = load_coco_dataset() + boxes_in_xywh = keras_nlp.bounding_box.convert_format( + boxes, + source='xyxy', + target='xyWH' + ) + ``` + + Args: + boxes: tensor representing bounding boxes in the format specified in + the `source` parameter. `boxes` can optionally have extra + dimensions stacked on the final axis to store metadata. boxes + should be a 3D tensor, with the shape `[batch_size, num_boxes, 4]`. + Alternatively, boxes can be a dictionary with key 'boxes' containing + a tensor matching the aforementioned spec. + source:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. + Used to specify the original format of the `boxes` parameter. + target:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}. + Used to specify the destination format of the `boxes` parameter. + images: (Optional) a batch of images aligned with `boxes` on the first + axis. Should be at least 3 dimensions, with the first 3 dimensions + representing: `[batch_size, height, width]`. Used in some + converters to compute relative pixel values of the bounding box + dimensions. Required when transforming from a rel format to a + non-rel format. + dtype: the data type to use when transforming the boxes, defaults to + `"float32"`. + """ + if isinstance(boxes, dict): + converted_boxes = boxes.copy() + converted_boxes["boxes"] = convert_format( + boxes["boxes"], + source=source, + target=target, + images=images, + image_shape=image_shape, + dtype=dtype, + ) + return converted_boxes + + if boxes.shape[-1] is not None and boxes.shape[-1] != 4: + raise ValueError( + "Expected `boxes` to be a Tensor with a final dimension of " + f"`4`. Instead, got `boxes.shape={boxes.shape}`." + ) + if images is not None and image_shape is not None: + raise ValueError( + "convert_format() expects either `images` or `image_shape`, but " + f"not both. Received images={images} image_shape={image_shape}" + ) + + _validate_image_shape(image_shape) + + source = source.lower() + target = target.lower() + if source not in TO_XYXY_CONVERTERS: + raise ValueError( + "`convert_format()` received an unsupported format for the " + "argument `source`. `source` should be one of " + f"{TO_XYXY_CONVERTERS.keys()}. Got source={source}" + ) + if target not in FROM_XYXY_CONVERTERS: + raise ValueError( + "`convert_format()` received an unsupported format for the " + "argument `target`. `target` should be one of " + f"{FROM_XYXY_CONVERTERS.keys()}. 
Got target={target}" + ) + + boxes = ops.cast(boxes, dtype) + if source == target: + return boxes + + # rel->rel conversions should not require images + if source.startswith("rel") and target.startswith("rel"): + source = source.replace("rel_", "", 1) + target = target.replace("rel_", "", 1) + + boxes, images, squeeze = _format_inputs(boxes, images) + to_xyxy_fn = TO_XYXY_CONVERTERS[source] + from_xyxy_fn = FROM_XYXY_CONVERTERS[target] + + try: + in_xyxy = to_xyxy_fn(boxes, images=images, image_shape=image_shape) + result = from_xyxy_fn(in_xyxy, images=images, image_shape=image_shape) + except RequiresImagesException: + raise ValueError( + "convert_format() must receive `images` or `image_shape` when " + "transforming between relative and absolute formats." + f"convert_format() received source=`{format}`, target=`{format}, " + f"but images={images} and image_shape={image_shape}." + ) + + return _format_outputs(result, squeeze) + + +def _format_inputs(boxes, images): + boxes_rank = len(boxes.shape) + if boxes_rank > 3: + raise ValueError( + "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " + f"len(boxes.shape)={boxes_rank}" + ) + boxes_includes_batch = boxes_rank == 3 + # Determine if images needs an expand_dims() call + if images is not None: + images_rank = len(images.shape) + if images_rank > 4: + raise ValueError( + "Expected len(images.shape)=2, or len(images.shape)=3, got " + f"len(images.shape)={images_rank}" + ) + images_include_batch = images_rank == 4 + if boxes_includes_batch != images_include_batch: + raise ValueError( + "convert_format() expects both boxes and images to be batched, " + "or both boxes and images to be unbatched. Received " + f"len(boxes.shape)={boxes_rank}, " + f"len(images.shape)={images_rank}. Expected either " + "len(boxes.shape)=2 AND len(images.shape)=3, or " + "len(boxes.shape)=3 AND len(images.shape)=4." + ) + if not images_include_batch: + images = ops.expand_dims(images, axis=0) + + if not boxes_includes_batch: + return ops.expand_dims(boxes, axis=0), images, True + return boxes, images, False + + +def _validate_image_shape(image_shape): + # Escape early if image_shape is None and skip validation. + if image_shape is None: + return + # tuple/list + if isinstance(image_shape, (tuple, list)): + if len(image_shape) != 3: + raise ValueError( + "image_shape should be of length 3, but got " + f"image_shape={image_shape}" + ) + return + + # tensor + if ops.is_tensor(image_shape): + if len(image_shape.shape) > 1: + raise ValueError( + "image_shape.shape should be (3), but got " + f"image_shape.shape={image_shape.shape}" + ) + if image_shape.shape[0] != 3: + raise ValueError( + "image_shape.shape should be (3), but got " + f"image_shape.shape={image_shape.shape}" + ) + return + + # Warn about failure cases + raise ValueError( + "Expected image_shape to be either a tuple, list, Tensor. 
" + f"Received image_shape={image_shape}" + ) + + +def _format_outputs(boxes, squeeze): + if squeeze: + return ops.squeeze(boxes, axis=0) + return boxes + + +def _image_shape(images, image_shape, boxes): + if images is None and image_shape is None: + raise RequiresImagesException() + + if image_shape is None: + if not isinstance(images, tf.RaggedTensor): + image_shape = ops.shape(images) + height, width = image_shape[1], image_shape[2] + else: + height = ops.reshape(images.row_lengths(), (-1, 1)) + width = ops.reshape(ops.max(images.row_lengths(axis=2), 1), (-1, 1)) + height = ops.expand_dims(height, axis=-1) + width = ops.expand_dims(width, axis=-1) + else: + height, width = image_shape[0], image_shape[1] + return ops.cast(height, boxes.dtype), ops.cast(width, boxes.dtype) diff --git a/keras_nlp/src/bounding_box/converters_test.py b/keras_nlp/src/bounding_box/converters_test.py new file mode 100644 index 0000000000..f6f3adfa17 --- /dev/null +++ b/keras_nlp/src/bounding_box/converters_test.py @@ -0,0 +1,365 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized +from keras import backend + +from keras_nlp.src.bounding_box import converters +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.bounding_box import to_ragged +from keras_nlp.src.tests.test_case import TestCase + + +class ConvertersTestCase(TestCase): + def setUp(self): + xyxy_box = np.array( + [[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32" + ) + yxyx_box = np.array( + [[[20, 10, 120, 110], [30, 20, 130, 120]]], dtype="float32" + ) + rel_xyxy_box = np.array( + [[[0.01, 0.02, 0.11, 0.12], [0.02, 0.03, 0.12, 0.13]]], + dtype="float32", + ) + rel_xyxy_box_ragged_images = np.array( + [[[0.10, 0.20, 1.1, 1.20], [0.40, 0.6, 2.40, 2.6]]], dtype="float32" + ) + rel_yxyx_box = np.array( + [[[0.02, 0.01, 0.12, 0.11], [0.03, 0.02, 0.13, 0.12]]], + dtype="float32", + ) + rel_yxyx_box_ragged_images = np.array( + [[[0.2, 0.1, 1.2, 1.1], [0.6, 0.4, 2.6, 2.4]]], dtype="float32" + ) + center_xywh_box = np.array( + [[[60, 70, 100, 100], [70, 80, 100, 100]]], dtype="float32" + ) + xywh_box = np.array( + [[[10, 20, 100, 100], [20, 30, 100, 100]]], dtype="float32" + ) + rel_xywh_box = np.array( + [[[0.01, 0.02, 0.1, 0.1], [0.02, 0.03, 0.1, 0.1]]], dtype="float32" + ) + rel_xywh_box_ragged_images = np.array( + [[[0.1, 0.2, 1, 1], [0.4, 0.6, 2, 2]]], dtype="float32" + ) + + self.ragged_images = tf.ragged.constant( + [ + np.ones(shape=[100, 100, 3]), + np.ones(shape=[50, 50, 3]), + ], # 2 images + ragged_rank=2, + ) + + self.images = np.ones([2, 1000, 1000, 3]) + + self.ragged_classes = tf.ragged.constant([[0], [0]], dtype="float32") + + self.boxes = { + "xyxy": xyxy_box, + "center_xywh": center_xywh_box, + "rel_xywh": rel_xywh_box, + "xywh": xywh_box, + "rel_xyxy": rel_xyxy_box, + "yxyx": yxyx_box, + "rel_yxyx": rel_yxyx_box, + } + + self.boxes_ragged_images = { + "xyxy": 
xyxy_box, + "center_xywh": center_xywh_box, + "rel_xywh": rel_xywh_box_ragged_images, + "xywh": xywh_box, + "rel_xyxy": rel_xyxy_box_ragged_images, + "yxyx": yxyx_box, + "rel_yxyx": rel_yxyx_box_ragged_images, + } + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + def test_converters(self, source, target): + source, target + source_box = self.boxes[source] + target_box = self.boxes[target] + + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converters_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + images=self.ragged_images, + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + def test_converters_unbatched(self, source, target): + source_box = self.boxes[source][0] + target_box = self.boxes[target][0] + + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images[0] + ), + target_box, + ) + + def test_raises_with_different_image_rank(self): + source_box = self.boxes["xyxy"][0] + with self.assertRaises(ValueError): + converters.convert_format( + source_box, source="xyxy", target="xywh", images=self.images + ) + + def test_without_images(self): + source_box = self.boxes["xyxy"] + target_box = self.boxes["xywh"] + self.assertAllClose( + converters.convert_format(source_box, source="xyxy", target="xywh"), + target_box, + ) + + def test_rel_to_rel_without_images(self): + source_box = self.boxes["rel_xyxy"] + target_box = self.boxes["rel_yxyx"] + self.assertAllClose( + converters.convert_format( + source_box, source="rel_xyxy", target="rel_yxyx" + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box(self, source, target): + source_box = _raggify(self.boxes[source]) + target_box = _raggify(self.boxes[target]) + self.assertAllClose( + converters.convert_format( + source_box, source=source, target=target, images=self.images + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in 
itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + images=self.ragged_images, + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_ragged_bounding_box_with_image_shape(self, source, target): + source_box = _raggify(self.boxes[source]) + target_box = _raggify(self.boxes[target]) + self.assertAllClose( + converters.convert_format( + source_box, + source=source, + target=target, + image_shape=(1000, 1000, 3), + ), + target_box, + ) + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "center_xywh", + "rel_xywh", + "xywh", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_dense_bounding_box_with_ragged_images(self, source, target): + source_box = _raggify(self.boxes_ragged_images[source]) + target_box = _raggify(self.boxes_ragged_images[target]) + source_bounding_boxes = { + "boxes": source_box, + "classes": self.ragged_classes, + } + source_bounding_boxes = to_dense.to_dense(source_bounding_boxes) + + result_bounding_boxes = converters.convert_format( + source_bounding_boxes, + source=source, + target=target, + images=self.ragged_images, + ) + result_bounding_boxes = to_ragged.to_ragged(result_bounding_boxes) + + self.assertAllClose( + result_bounding_boxes["boxes"], + target_box, + ) + + +def _raggify(tensor): + tensor = tf.squeeze(tensor, axis=0) + tensor = tf.RaggedTensor.from_row_lengths(tensor, [1, 1]) + return tensor diff --git a/keras_nlp/src/bounding_box/to_dense.py b/keras_nlp/src/bounding_box/to_dense.py new file mode 100644 index 0000000000..3c42d09f4f --- /dev/null +++ b/keras_nlp/src/bounding_box/to_dense.py @@ -0,0 +1,95 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
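+"""Converts ragged bounding boxes to dense, padded tensors."""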
+ +import keras_nlp.src.bounding_box.validate_format as validate_format +from keras_nlp.src.api_export import keras_nlp_export + +try: +    import tensorflow as tf +except ImportError: +    tf = None + + +def _box_shape(batched, boxes_shape, max_boxes): +    # Ensure we don't drop the final axis in RaggedTensor mode. +    if max_boxes is None: +        shape = list(boxes_shape) +        shape[-1] = 4 +        return shape +    if batched: +        return [None, max_boxes, 4] +    return [max_boxes, 4] + + +def _classes_shape(batched, classes_shape, max_boxes): +    if max_boxes is None: +        return None +    if batched: +        return [None, max_boxes] + classes_shape[2:] +    return [max_boxes] + classes_shape[2:] + + +@keras_nlp_export("keras_nlp.bounding_box.to_dense") +def to_dense(bounding_boxes, max_boxes=None, default_value=-1): +    """Converts bounding boxes to dense tensors. + +    Args: +        bounding_boxes: bounding boxes in KerasNLP dictionary format. +        max_boxes: the maximum number of boxes, used to pad tensors to a given +            shape. This can be used to make object detection pipelines TPU +            compatible. +        default_value: the default value to pad bounding boxes with. Defaults +            to -1. +    """ +    info = validate_format.validate_format(bounding_boxes) + +    # Copy the dictionary to guard against errors in metrics caused by +    # modifying inputs, and against unexpected downstream modification. +    bounding_boxes = bounding_boxes.copy() + +    # Already running in dense (masked) mode. +    if not info["ragged"]: +        # Even if already dense, still return the copied dictionary for API consistency. +        return bounding_boxes + +    if isinstance(bounding_boxes["classes"], tf.RaggedTensor): +        bounding_boxes["classes"] = bounding_boxes["classes"].to_tensor( +            default_value=default_value, +            shape=_classes_shape( +                info["is_batched"], bounding_boxes["classes"].shape, max_boxes +            ), +        ) + +    if isinstance(bounding_boxes["boxes"], tf.RaggedTensor): +        bounding_boxes["boxes"] = bounding_boxes["boxes"].to_tensor( +            default_value=default_value, +            shape=_box_shape( +                info["is_batched"], bounding_boxes["boxes"].shape, max_boxes +            ), +        ) + +    if "confidence" in bounding_boxes: +        if isinstance(bounding_boxes["confidence"], tf.RaggedTensor): +            bounding_boxes["confidence"] = bounding_boxes[ +                "confidence" +            ].to_tensor( +                default_value=default_value, +                shape=_classes_shape( +                    info["is_batched"], +                    bounding_boxes["confidence"].shape, +                    max_boxes, +                ), +            ) + +    return bounding_boxes diff --git a/keras_nlp/src/bounding_box/to_dense_test.py b/keras_nlp/src/bounding_box/to_dense_test.py new file mode 100644 index 0000000000..4bb795659b --- /dev/null +++ b/keras_nlp/src/bounding_box/to_dense_test.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import tensorflow as tf +from keras import backend + +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.tests.test_case import TestCase + + +class ToDenseTest(TestCase): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converts_to_dense(self): + bounding_boxes = { + "boxes": tf.ragged.constant( + [[[0, 0, 1, 1]], [[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 1, 1]]] + ), + "classes": tf.ragged.constant([[0], [1, 2, 3]]), + } + bounding_boxes = to_dense.to_dense(bounding_boxes) + self.assertEqual(bounding_boxes["boxes"].shape, [2, 3, 4]) + self.assertEqual(bounding_boxes["classes"].shape, [2, 3]) diff --git a/keras_nlp/src/bounding_box/to_ragged.py b/keras_nlp/src/bounding_box/to_ragged.py new file mode 100644 index 0000000000..2ebd4a00f4 --- /dev/null +++ b/keras_nlp/src/bounding_box/to_ragged.py @@ -0,0 +1,99 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +import keras_nlp.src.bounding_box.validate_format as validate_format +from keras_nlp.src.api_export import keras_nlp_export + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_nlp_export("keras_nlp.bounding_box.to_ragged") +def to_ragged(bounding_boxes, sentinel=-1, dtype="float32"): + """converts a Dense padded bounding box `tf.Tensor` to a `tf.RaggedTensor`. + + Bounding boxes are ragged tensors in most use cases. Converting them to a + dense tensor makes it easier to work with Tensorflow ecosystem. + This function can be used to filter out the masked out bounding boxes by + checking for padded sentinel value of the class_id axis of the + bounding_boxes. + + Example: + ```python + bounding_boxes = { + "boxes": tf.constant([[2, 3, 4, 5], [0, 1, 2, 3]]), + "classes": tf.constant([[-1, 1]]), + } + bounding_boxes = bounding_box.to_ragged(bounding_boxes) + print(bounding_boxes) + # { + # "boxes": [[0, 1, 2, 3]], + # "classes": [[1]] + # } + ``` + + Args: + bounding_boxes: a Tensor of bounding boxes. May be batched, or + unbatched. + sentinel: The value indicating that a bounding box does not exist at the + current index, and the corresponding box is padding, defaults to -1. + dtype: the data type to use for the underlying Tensors. + Returns: + dictionary of `tf.RaggedTensor` or 'tf.Tensor' containing the filtered + bounding boxes. + """ + if keras.config.backend() != "tensorflow": + raise NotImplementedError( + "`bounding_box.to_ragged` was called using a backend which does " + "not support ragged tensors. " + f"Current backend: {keras.backend.backend()}." 
+ ) + + info = validate_format.validate_format(bounding_boxes) + + if info["ragged"]: + return bounding_boxes + + boxes = bounding_boxes.get("boxes") + classes = bounding_boxes.get("classes") + confidence = bounding_boxes.get("confidence", None) + + mask = classes != sentinel + + boxes = tf.ragged.boolean_mask(boxes, mask) + classes = tf.ragged.boolean_mask(classes, mask) + if confidence is not None: + confidence = tf.ragged.boolean_mask(confidence, mask) + + if isinstance(boxes, tf.Tensor): + boxes = tf.RaggedTensor.from_tensor(boxes) + + if isinstance(classes, tf.Tensor) and len(classes.shape) > 1: + classes = tf.RaggedTensor.from_tensor(classes) + + if confidence is not None: + if isinstance(confidence, tf.Tensor) and len(confidence.shape) > 1: + confidence = tf.RaggedTensor.from_tensor(confidence) + + result = bounding_boxes.copy() + result["boxes"] = tf.cast(boxes, dtype) + result["classes"] = tf.cast(classes, dtype) + + if confidence is not None: + result["confidence"] = tf.cast(confidence, dtype) + + return result diff --git a/keras_nlp/src/bounding_box/to_ragged_test.py b/keras_nlp/src/bounding_box/to_ragged_test.py new file mode 100644 index 0000000000..cbe5146d11 --- /dev/null +++ b/keras_nlp/src/bounding_box/to_ragged_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest +from keras import backend + +from keras_nlp.src.bounding_box import to_dense +from keras_nlp.src.bounding_box import to_ragged +from keras_nlp.src.tests.test_case import TestCase + + +class ToRaggedTest(TestCase): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_converts_to_ragged(self): + bounding_boxes = { + "boxes": np.array( + [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] + ), + "classes": np.array([[-1, -1], [-1, 1]]), + "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), + } + bounding_boxes = to_ragged.to_ragged(bounding_boxes) + + self.assertEqual(bounding_boxes["boxes"][1].shape, [1, 4]) + self.assertEqual(bounding_boxes["classes"][1].shape, [1]) + self.assertEqual( + bounding_boxes["confidence"][1].shape, + [ + 1, + ], + ) + + self.assertEqual(bounding_boxes["classes"][0].shape, [0]) + self.assertEqual(bounding_boxes["boxes"][0].shape, [0, 4]) + self.assertEqual( + bounding_boxes["confidence"][0].shape, + [ + 0, + ], + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only applies to backends which support raggeds", + ) + def test_round_trip(self): + original = { + "boxes": np.array( + [ + [[0, 0, 0, 0], [-1, -1, -1, -1]], + [[-1, -1, -1, -1], [-1, -1, -1, -1]], + ] + ), + "classes": np.array([[1, -1], [-1, -1]]), + "confidence": np.array([[0.5, -1], [-1, -1]]), + } + bounding_boxes = to_ragged.to_ragged(original) + bounding_boxes = to_dense.to_dense(bounding_boxes, max_boxes=2) + + self.assertEqual(bounding_boxes["boxes"][1].shape, [2, 4]) + self.assertEqual(bounding_boxes["classes"][1].shape, [2]) + self.assertEqual(bounding_boxes["classes"][0].shape, [2]) + self.assertEqual(bounding_boxes["boxes"][0].shape, [2, 4]) + self.assertEqual(bounding_boxes["confidence"][0].shape, [2]) + + self.assertAllEqual(bounding_boxes["boxes"], original["boxes"]) + self.assertAllEqual(bounding_boxes["classes"], original["classes"]) + self.assertAllEqual( + bounding_boxes["confidence"], original["confidence"] + ) + + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="Only applies to backends which don't support raggeds", + ) + def test_backend_without_raggeds_throws(self): + bounding_boxes = { + "boxes": np.array( + [[[0, 0, 0, 0], [0, 0, 0, 0]], [[2, 3, 4, 5], [0, 1, 2, 3]]] + ), + "classes": np.array([[-1, -1], [-1, 1]]), + "confidence": np.array([[0.5, 0.7], [0.23, 0.12]]), + } + + with self.assertRaisesRegex(NotImplementedError, "support ragged"): + to_ragged.to_ragged(bounding_boxes) diff --git a/keras_nlp/src/bounding_box/validate_format.py b/keras_nlp/src/bounding_box/validate_format.py new file mode 100644 index 0000000000..51fb310807 --- /dev/null +++ b/keras_nlp/src/bounding_box/validate_format.py @@ -0,0 +1,99 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
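+"""Checks that bounding box dictionaries comply with the KerasNLP format."""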
+ +from keras_nlp.src.api_export import keras_nlp_export + +try: +    import tensorflow as tf +except ImportError: +    tf = None + + +@keras_nlp_export("keras_nlp.bounding_box.validate_format") +def validate_format(bounding_boxes, variable_name="bounding_boxes"): +    """Validates that a given set of bounding boxes complies with the KerasNLP +    format. + +    For a set of bounding boxes to be valid, it must satisfy the following +    conditions: +    - `bounding_boxes` must be a dictionary +    - contains keys `"boxes"` and `"classes"` +    - each entry must have matching first two dimensions, representing the batch +        axis and the number of boxes per image axis. +    - either both `"boxes"` and `"classes"` are batched, or both are unbatched. + +    Additionally, one of the following must be satisfied: +    - `"boxes"` and `"classes"` are both Ragged +    - `"boxes"` and `"classes"` are both Dense +    - `"boxes"` and `"classes"` are unbatched + +    Args: +        bounding_boxes: dictionary of bounding boxes according to the KerasNLP +            format. + +    Raises: +        ValueError: if any of the above conditions are not met. +    """ +    if not isinstance(bounding_boxes, dict): +        raise ValueError( +            f"Expected `{variable_name}` to be a dictionary, got " +            f"`{variable_name}={bounding_boxes}`." +        ) +    if not all([x in bounding_boxes for x in ["boxes", "classes"]]): +        raise ValueError( +            f"Expected `{variable_name}` to be a dictionary containing keys " +            "`'classes'` and `'boxes'`. Got " +            f"`{variable_name}.keys()={bounding_boxes.keys()}`." +        ) + +    boxes = bounding_boxes.get("boxes") +    classes = bounding_boxes.get("classes") +    info = {} + +    is_batched = len(boxes.shape) == 3 +    info["is_batched"] = is_batched +    info["ragged"] = isinstance(boxes, tf.RaggedTensor) + +    if not is_batched: +        if boxes.shape[:1] != classes.shape[:1]: +            raise ValueError( +                "Expected `boxes` and `classes` to have matching dimensions " +                "on the first axis when operating in unbatched mode. Got " +                f"`boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`." +            ) + +        info["classes_one_hot"] = len(classes.shape) == 2 +        # No Ragged checks needed in unbatched mode. +        return info + +    info["classes_one_hot"] = len(classes.shape) == 3 + +    if isinstance(boxes, tf.RaggedTensor) != isinstance( +        classes, tf.RaggedTensor +    ): +        raise ValueError( +            "Either both `boxes` and `classes` " +            "should be ragged, or neither should be ragged. " +            f"Got `type(boxes)={type(boxes)}`, `type(classes)={type(classes)}`." +        ) + +    # Batched mode checks +    if boxes.shape[:2] != classes.shape[:2]: +        raise ValueError( +            "Expected `boxes` and `classes` to have matching dimensions " +            "on the first two axes when operating in batched mode. " +            f"Got `boxes.shape={boxes.shape}`, `classes.shape={classes.shape}`.
+ ) + + return info From 9289ab79b8fa4cd00926b8536bad07fd755714d6 Mon Sep 17 00:00:00 2001 From: Usha Rengaraju <34335028+ushareng@users.noreply.github.com> Date: Thu, 29 Aug 2024 04:26:16 +0530 Subject: [PATCH 13/19] mobilenet_v3 added in keras-nlp (#1782) * mobilenet_v3 added in keras-nlp * minor bug fixed in mobilenet_v3_backbone * formatting corrected * refactoring backbone * correct_pad_downsample method added * refactoring backbone * parameters updated * Testcaseupdated, expected output shape corrected * code formatted with black * testcase updated * refactoring and description added * comments updated * added mobilenet v1 and v2 * merge conflict resolved * version arg removed, and config options added * input_shape changed to image_shape in arg * config updated * input shape corrected * comments resolved * activation function format changed * minor bug fixed * minor bug fixed * added vision_backbone_test * channel_first bug resolved * channel_first cases working * comments resolved * formatting fixed * refactoring --------- Co-authored-by: ushareng --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/mobilenet/__init__.py | 13 + .../models/mobilenet/mobilenet_backbone.py | 530 ++++++++++++++++++ .../mobilenet/mobilenet_backbone_test.py | 58 ++ .../mobilenet/mobilenet_image_classifier.py | 114 ++++ .../mobilenet_image_classifier_test.py | 71 +++ 6 files changed, 790 insertions(+) create mode 100644 keras_nlp/src/models/mobilenet/__init__.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_backbone.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py create mode 100644 keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index c6d8ed7d32..17b00c1f05 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -171,6 +171,10 @@ from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import ( MiTImageClassifier, ) +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/mobilenet/__init__.py b/keras_nlp/src/models/mobilenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_nlp/src/models/mobilenet/mobilenet_backbone.py b/keras_nlp/src/models/mobilenet/mobilenet_backbone.py new file mode 100644 index 0000000000..4054b6d76f --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_backbone.py @@ -0,0 +1,530 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_EPSILON = 1e-3 +BN_MOMENTUM = 0.999 + + +@keras_nlp_export("keras_nlp.models.MobileNetBackbone") +class MobileNetBackbone(Backbone): + """Instantiates the MobileNet architecture. + + MobileNet is a lightweight convolutional neural network (CNN) + optimized for mobile and edge devices, striking a balance between + accuracy and efficiency. By employing depthwise separable convolutions + and techniques like Squeeze-and-Excitation (SE) blocks, + MobileNet models are highly suitable for real-time applications on + resource-constrained devices. + + References: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) + (ICCV 2019) + + Args: + stackwise_expansion: list of ints or floats, the expansion ratio for + each inverted residual block in the model. + stackwise_num_filters: list of ints, number of filters for each inverted + residual block in the model. + stackwise_kernel_size: list of ints, kernel size for each inverted + residual block in the model. + stackwise_num_strides: list of ints, stride length for each inverted + residual block in the model. + stackwise_se_ratio: se ratio for each inverted residual block in the + model. 0 if dont want to add Squeeze and Excite layer. + stackwise_activation: list of activation functions, for each inverted + residual block in the model. + include_rescaling: bool, whether to rescale the inputs. If set to True, + inputs will be passed through a `Rescaling(scale=1 / 255)` + layer. + image_shape: optional shape tuple, defaults to (224, 224, 3). + depth_multiplier: float, controls the width of the network. + - If `depth_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `depth_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `depth_multiplier` = 1, default number of filters from the paper + are used at each layer. + input_num_filters: number of filters in first convolution layer + output_num_filters: specifies whether to add conv and batch_norm in the end, + if set to None, it will not add these layers in the end. 
+ 'None' for MobileNetV1 + input_activation: activation function to be used in the input layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + output_activation: activation function to be used in the output layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + inverted_res_block: whether to use inverted residual blocks or not, + 'False' for MobileNetV1, + 'True' for MobileNetV2 and MobileNetV3 + + + Example: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Randomly initialized backbone with a custom config + model = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[3, 3, 5], + stackwise_num_strides=[2, 2, 1], + stackwise_se_ratio=[0.25, None, 0.25], + stackwise_activation=["relu", "relu6", "hard_swish"], + include_rescaling=False, + output_num_filters=1280, + input_activation='hard_swish', + output_activation='hard_swish', + inverted_res_block=True, + + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_expansion, + stackwise_num_filters, + stackwise_kernel_size, + stackwise_num_strides, + stackwise_se_ratio, + stackwise_activation, + include_rescaling, + output_num_filters, + inverted_res_block, + image_shape=(224, 224, 3), + input_activation="hard_swish", + output_activation="hard_swish", + depth_multiplier=1.0, + input_num_filters=16, + **kwargs, + ): + # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if include_rescaling: + x = keras.layers.Rescaling(scale=1 / 255)(x) + + input_num_filters = adjust_channels(input_num_filters) + x = keras.layers.Conv2D( + input_num_filters, + kernel_size=3, + strides=(2, 2), + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="input_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="input_batch_norm", + )(x) + x = keras.layers.Activation(input_activation)(x) + + for stack_index in range(len(stackwise_num_filters)): + filters = adjust_channels( + (stackwise_num_filters[stack_index]) * depth_multiplier + ) + + if inverted_res_block: + x = apply_inverted_res_block( + x, + expansion=stackwise_expansion[stack_index], + filters=filters, + kernel_size=stackwise_kernel_size[stack_index], + stride=stackwise_num_strides[stack_index], + se_ratio=(stackwise_se_ratio[stack_index]), + activation=stackwise_activation[stack_index], + expansion_index=stack_index, + ) + else: + x = apply_depthwise_conv_block( + x, + filters=filters, + kernel_size=3, + stride=stackwise_num_strides[stack_index], + depth_multiplier=depth_multiplier, + block_id=stack_index, + ) + + if output_num_filters is not None: + last_conv_ch = adjust_channels(x.shape[channel_axis] * 6) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="output_batch_norm", + )(x) + x = keras.layers.Activation(output_activation)(x) + + super().__init__(inputs=inputs, outputs=x, **kwargs) + + # === Config === + self.stackwise_expansion = stackwise_expansion + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_kernel_size = stackwise_kernel_size + 
self.stackwise_num_strides = stackwise_num_strides + self.stackwise_se_ratio = stackwise_se_ratio + self.stackwise_activation = stackwise_activation + self.include_rescaling = include_rescaling + self.depth_multiplier = depth_multiplier + self.input_num_filters = input_num_filters + self.output_num_filters = output_num_filters + self.input_activation = keras.activations.get(input_activation) + self.output_activation = keras.activations.get(output_activation) + self.inverted_res_block = inverted_res_block + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_expansion": self.stackwise_expansion, + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_kernel_size": self.stackwise_kernel_size, + "stackwise_num_strides": self.stackwise_num_strides, + "stackwise_se_ratio": self.stackwise_se_ratio, + "stackwise_activation": self.stackwise_activation, + "include_rescaling": self.include_rescaling, + "image_shape": self.image_shape, + "depth_multiplier": self.depth_multiplier, + "input_num_filters": self.input_num_filters, + "output_num_filters": self.output_num_filters, + "input_activation": keras.activations.serialize( + activation=self.input_activation + ), + "output_activation": keras.activations.serialize( + activation=self.output_activation + ), + "inverted_res_block": self.inverted_res_block, + } + ) + return config + + +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. + if new_x < 0.9 * x: + new_x += divisor + return new_x + + +def apply_inverted_res_block( + x, + expansion, + filters, + kernel_size, + stride, + se_ratio, + activation, + expansion_index, +): + """An Inverted Residual Block. + + Args: + x: input tensor. + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + expansion_index: integer, a unique identification if you want to use + expanded convolutions. If greater than 0, an additional Conv+BN + layer is added after the expanded convolutional layer. + + Returns: + the updated input tensor. 
+ """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + activation = keras.activations.get(activation) + shortcut = x + prefix = "expanded_conv_" + infilters = x.shape[channel_axis] + + if expansion_index > 0: + prefix = f"expanded_conv_{expansion_index}_" + + x = keras.layers.Conv2D( + adjust_channels(infilters * expansion), + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "expand", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "expand_BatchNorm", + )(x) + x = keras.layers.Activation(activation=activation)(x) + + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + name=prefix + "depthwise_pad", + )(x) + + x = keras.layers.DepthwiseConv2D( + kernel_size, + strides=stride, + padding="same" if stride == 1 else "valid", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "depthwise", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "depthwise_BatchNorm", + )(x) + x = keras.layers.Activation(activation=activation)(x) + + if se_ratio: + se_filters = adjust_channels(infilters * expansion) + x = SqueezeAndExcite2D( + input=x, + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=prefix + "project", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=prefix + "project_BatchNorm", + )(x) + + if stride == 1 and infilters == filters: + x = keras.layers.Add(name=prefix + "Add")([shortcut, x]) + + return x + + +def apply_depthwise_conv_block( + x, + filters, + kernel_size=3, + depth_multiplier=1, + stride=1, + block_id=1, +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + depth_multiplier: controls the width of the network. + - If `depth_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `depth_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `depth_multiplier` = 1, default number of filters from the + paper are used at each layer. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. 
+ """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + name="conv_pad_%d" % block_id, + )(x) + + x = keras.layers.DepthwiseConv2D( + kernel_size, + strides=stride, + padding="same" if stride == 1 else "valid", + data_format=keras.config.image_data_format(), + depth_multiplier=depth_multiplier, + use_bias=False, + name="depthwise_%d" % block_id, + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="depthwise_BatchNorm_%d" % block_id, + )(x) + x = keras.layers.ReLU(6.0)(x) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="conv_%d" % block_id, + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="BatchNorm_%d" % block_id, + )(x) + return keras.layers.ReLU(6.0)(x) + + +def SqueezeAndExcite2D( + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", +): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. + """ + if not bottleneck_filters: + bottleneck_filters = filters // 4 + + x = keras.layers.GlobalAveragePooling2D(keepdims=True)(input) + + x = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=squeeze_activation, + )(x) + x = keras.layers.Conv2D( + filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=excite_activation, + )(x) + + x = ops.multiply(x, input) + return x + + +def correct_pad_downsample(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. 
+ """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) diff --git a/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py b/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py new file mode 100644 index 0000000000..80225abe04 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_backbone_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class MobileNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_expansion": [1, 4, 6], + "stackwise_num_filters": [4, 8, 16], + "stackwise_kernel_size": [3, 3, 5], + "stackwise_num_strides": [2, 2, 1], + "stackwise_se_ratio": [0.25, None, 0.25], + "stackwise_activation": ["relu", "relu", "hard_swish"], + "include_rescaling": False, + "output_num_filters": 1280, + "input_activation": "hard_swish", + "output_activation": "hard_swish", + "inverted_res_block": True, + "input_num_filters": 16, + "image_shape": (224, 224, 3), + "depth_multiplier": 1, + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 28, 28, 96), + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py new file mode 100644 index 0000000000..3e08f3482c --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier.py @@ -0,0 +1,114 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
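The `adjust_channels` and `correct_pad_downsample` helpers above are plain functions, so their behavior is easy to sanity-check in isolation. The following is an illustrative sketch rather than part of the patch; it assumes the module path introduced here.

```python
import numpy as np

from keras_nlp.src.models.mobilenet.mobilenet_backbone import (
    adjust_channels,
    correct_pad_downsample,
)

# Channel counts are rounded to the nearest multiple of `divisor` (8 by
# default) and never drop more than 10% below the requested value.
assert adjust_channels(30) == 32
assert adjust_channels(13) == 16
assert adjust_channels(75) == 72  # 72 >= 0.9 * 75, so no correction is added.

# Stride-2 depthwise convolutions are zero-padded explicitly so that the
# subsequent "valid" convolution downsamples like a "same" convolution would.
inputs = np.ones((2, 224, 224, 3), dtype="float32")
assert correct_pad_downsample(inputs, kernel_size=3) == ((0, 1), (0, 1))
```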
+import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone + + +@keras_nlp_export("keras_nlp.models.MobileNetImageClassifier") +class MobileNetImageClassifier(ImageClassifier): + """MobileNetV3 image classifier task model. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Args: + backbone: A `keras_nlp.models.MobileNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.MobileNetImageClassifier.from_preset( + "mobilenet_v3_small_imagenet") + classifier.predict(images) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + model = MobileNetBackbone( + stackwise_expansion = [1, 4, 6], + stackwise_filters = [4, 8, 16], + stackwise_kernel_size = [3, 3, 5], + stackwise_stride = [2, 2, 1], + stackwise_se_ratio = [ 0.25, None, 0.25], + stackwise_activation = ["relu", "relu", "hard_swish"], + include_rescaling = False, + output_filter=1280, + activation="hard_swish", + inverted_res_block=True, + ) + classifier = keras_nlp.models.MobileNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = MobileNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py new file mode 100644 index 0000000000..29d00e6d24 --- /dev/null +++ b/keras_nlp/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -0,0 +1,71 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_nlp.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class MobileNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[3, 3, 5], + stackwise_num_strides=[2, 2, 1], + stackwise_se_ratio=[0.25, None, 0.25], + stackwise_activation=["relu", "relu", "hard_swish"], + include_rescaling=False, + output_num_filters=1280, + input_activation="hard_swish", + output_activation="hard_swish", + inverted_res_block=True, + input_num_filters=16, + image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 09f470f726417faee597dff54e636afc3854aeed Mon Sep 17 00:00:00 2001 From: pkgoogle <132095473+pkgoogle@users.noreply.github.com> Date: Wed, 28 Aug 2024 22:56:37 +0000 Subject: [PATCH 14/19] Pkgoogle/efficient net migration (#1778) * migrating efficientnet models to keras-hub * merging changes from other sources * autoformatting pass * initial consolidation of efficientnet_backbone * most updates and removing separate implementation * cleanup, autoformatting, keras generalization * removed layer examples outside of effiicient net * many, mainly documentation changes, small test fixes --- keras_nlp/api/models/__init__.py | 3 + keras_nlp/src/models/efficientnet/__init__.py | 13 + .../efficientnet/efficientnet_backbone.py | 569 ++++++++++++++++++ .../efficientnet_backbone_test.py | 146 +++++ .../src/models/efficientnet/fusedmbconv.py | 229 +++++++ .../models/efficientnet/fusedmbconv_test.py | 46 ++ keras_nlp/src/models/efficientnet/mbconv.py | 238 ++++++++ .../src/models/efficientnet/mbconv_test.py | 44 ++ 8 files changed, 1288 insertions(+) create mode 100644 keras_nlp/src/models/efficientnet/__init__.py create mode 100644 keras_nlp/src/models/efficientnet/efficientnet_backbone.py create mode 100644 keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py create mode 100644 keras_nlp/src/models/efficientnet/fusedmbconv.py create mode 100644 keras_nlp/src/models/efficientnet/fusedmbconv_test.py create mode 100644 keras_nlp/src/models/efficientnet/mbconv.py create mode 100644 keras_nlp/src/models/efficientnet/mbconv_test.py diff --git a/keras_nlp/api/models/__init__.py 
b/keras_nlp/api/models/__init__.py index 17b00c1f05..061fbb90b6 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -96,6 +96,9 @@ from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer, ) +from keras_nlp.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone from keras_nlp.src.models.electra.electra_preprocessor import ( ElectraPreprocessor, diff --git a/keras_nlp/src/models/efficientnet/__init__.py b/keras_nlp/src/models/efficientnet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/efficientnet/efficientnet_backbone.py b/keras_nlp/src/models/efficientnet/efficientnet_backbone.py new file mode 100644 index 0000000000..2d7940d6df --- /dev/null +++ b/keras_nlp/src/models/efficientnet/efficientnet_backbone.py @@ -0,0 +1,569 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.efficientnet.fusedmbconv import FusedMBConvBlock +from keras_nlp.src.models.efficientnet.mbconv import MBConvBlock +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone + + +@keras_nlp_export("keras_nlp.models.EfficientNetBackbone") +class EfficientNetBackbone(FeaturePyramidBackbone): + """An EfficientNet backbone model. + + This class encapsulates the architectures for both EfficientNetV1 and + EfficientNetV2. EfficientNetV2 uses Fused-MBConv Blocks and Neural + Architecture Search (NAS) to make models sizes much smaller while still + improving overall model quality. + + References: + - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks] + (https://arxiv.org/abs/1905.11946) (ICML 2019) + - [Based on the original keras.applications EfficientNet] + (https://github.com/keras-team/keras/blob/master/keras/applications/efficientnet.py) + - [EfficientNetV2: Smaller Models and Faster Training] + (https://arxiv.org/abs/2104.00298) (ICML 2021) + + Args: + width_coefficient: float, scaling coefficient for network width. + depth_coefficient: float, scaling coefficient for network depth. + dropout: float, dropout rate at skip connections. 
The default
+ value is set to 0.2.
+ depth_divisor: integer, a unit of network width. The default value is
+ set to 8.
+ activation: activation function to use between each convolutional layer.
+ input_shape: optional shape tuple, it should have exactly 3 input
+ channels.
+ stackwise_kernel_sizes: list of ints, the kernel sizes used for each
+ conv block.
+ stackwise_num_repeats: list of ints, number of times to repeat each
+ conv block.
+ stackwise_input_filters: list of ints, number of input filters for
+ each conv block.
+ stackwise_output_filters: list of ints, number of output filters for
+ each stack in the conv blocks model.
+ stackwise_expansion_ratios: list of floats, the expansion ratio applied
+ to each conv block.
+ stackwise_strides: list of ints, the stride length for each conv block.
+ stackwise_squeeze_and_excite_ratios: list of floats, the squeeze and
+ excite ratios passed to the squeeze and excitation blocks.
+ stackwise_block_types: list of strings. Each value is either 'v1',
+ 'unfused' or 'fused' depending on the desired blocks. 'v1' uses the
+ original efficientnet block. 'fused' blocks are similar to 'unfused'
+ (MBConv) blocks, but replace the depthwise convolution and 1x1
+ output convolution with a single 3x3 convolution.
+ include_rescaling: bool, whether to rescale the inputs. If set to
+ True, inputs will be passed through a `Rescaling(1/255.0)` layer.
+ min_depth: integer, minimum number of filters. Can be None and is
+ ignored if use_depth_divisor_as_min_depth is set to True.
+ include_initial_padding: bool, whether to include initial zero padding
+ (as per v1).
+ use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as
+ the minimum depth instead of min_depth (as per v1).
+ cap_round_filter_decrease: bool, whether to cap the maximum decrease in
+ the number of filters the rounding process can produce
+ (as per v1).
+ stem_conv_padding: str, padding mode for the stem convolution, either
+ 'same' or 'valid'.
+ batch_norm_momentum: float, momentum for the moving average calculation
+ in the batch normalization layers.
+ + Example: + ```python + # You can customize the EfficientNet architecture: + model = EfficientNetBackbone( + stackwise_kernel_sizes=[3, 3, 3, 3, 3, 3], + stackwise_num_repeats=[2, 4, 4, 6, 9, 15], + stackwise_input_filters=[24, 24, 48, 64, 128, 160], + stackwise_output_filters=[24, 48, 64, 128, 160, 256], + stackwise_expansion_ratios=[1, 4, 4, 4, 6, 6], + stackwise_squeeze_and_excite_ratios=[0.0, 0.0, 0, 0.25, 0.25, 0.25], + stackwise_strides=[1, 2, 2, 2, 1, 2], + stackwise_block_types=[["fused"] * 3 + ["unfused"] * 3], + width_coefficient=1.0, + depth_coefficient=1.0, + include_rescaling=False, + ) + images = np.ones((1, 256, 256, 3)) + outputs = efficientnet.predict(images) + ``` + """ + + def __init__( + self, + *, + width_coefficient, + depth_coefficient, + stackwise_kernel_sizes, + stackwise_num_repeats, + stackwise_input_filters, + stackwise_output_filters, + stackwise_expansion_ratios, + stackwise_squeeze_and_excite_ratios, + stackwise_strides, + stackwise_block_types, + include_rescaling=True, + dropout=0.2, + depth_divisor=8, + min_depth=8, + input_shape=(None, None, 3), + activation="swish", + include_initial_padding=False, + use_depth_divisor_as_min_depth=False, + cap_round_filter_decrease=False, + stem_conv_padding="same", + batch_norm_momentum=0.9, + **kwargs, + ): + img_input = keras.layers.Input(shape=input_shape) + + x = img_input + + if include_rescaling: + # Use common rescaling strategy across keras + x = keras.layers.Rescaling(scale=1.0 / 255.0)(x) + + if include_initial_padding: + x = keras.layers.ZeroPadding2D( + padding=self._correct_pad_downsample(x, 3), + name="stem_conv_pad", + )(x) + + # Build stem + stem_filters = round_filters( + filters=stackwise_input_filters[0], + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + x = keras.layers.Conv2D( + filters=stem_filters, + kernel_size=3, + strides=2, + padding=stem_conv_padding, + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name="stem_conv", + )(x) + + x = keras.layers.BatchNormalization( + momentum=batch_norm_momentum, + name="stem_bn", + )(x) + x = keras.layers.Activation(activation, name="stem_activation")(x) + + # Build blocks + block_id = 0 + blocks = float(sum(stackwise_num_repeats)) + + self._pyramid_outputs = {} + curr_pyramid_level = 1 + + for i in range(len(stackwise_kernel_sizes)): + num_repeats = stackwise_num_repeats[i] + input_filters = stackwise_input_filters[i] + output_filters = stackwise_output_filters[i] + + # Update block input and output filters based on depth multiplier. 
+ input_filters = round_filters( + filters=input_filters, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + output_filters = round_filters( + filters=output_filters, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + repeats = round_repeats( + repeats=num_repeats, + depth_coefficient=depth_coefficient, + ) + strides = stackwise_strides[i] + squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i] + + for j in range(repeats): + # The first block needs to take care of stride and filter size + # increase. + if j > 0: + strides = 1 + input_filters = output_filters + + if strides != 1: + self._pyramid_outputs[f"P{curr_pyramid_level}"] = x + curr_pyramid_level += 1 + + # 97 is the start of the lowercase alphabet. + letter_identifier = chr(j + 97) + stackwise_block_type = stackwise_block_types[i] + block_name = f"block{i + 1}{letter_identifier}_" + if stackwise_block_type == "v1": + x = self._apply_efficientnet_block( + inputs=x, + filters_in=input_filters, + filters_out=output_filters, + kernel_size=stackwise_kernel_sizes[i], + strides=strides, + expand_ratio=stackwise_expansion_ratios[i], + se_ratio=squeeze_and_excite_ratio, + activation=activation, + dropout=dropout * block_id / blocks, + name=block_name, + ) + else: + block = get_conv_constructor(stackwise_block_type)( + input_filters=input_filters, + output_filters=output_filters, + expand_ratio=stackwise_expansion_ratios[i], + kernel_size=stackwise_kernel_sizes[i], + strides=strides, + se_ratio=squeeze_and_excite_ratio, + activation=activation, + dropout=dropout * block_id / blocks, + batch_norm_momentum=batch_norm_momentum, + name=block_name, + ) + x = block(x) + block_id += 1 + + # Build top + top_filters = round_filters( + filters=1280, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, + cap_round_filter_decrease=cap_round_filter_decrease, + ) + + x = keras.layers.Conv2D( + filters=top_filters, + kernel_size=1, + padding="same", + strides=1, + kernel_initializer=conv_kernel_initializer(), + use_bias=False, + name="top_conv", + data_format="channels_last", + )(x) + x = keras.layers.BatchNormalization( + momentum=batch_norm_momentum, + name="top_bn", + )(x) + x = keras.layers.Activation( + activation=activation, name="top_activation" + )(x) + + self._pyramid_outputs[f"P{curr_pyramid_level}"] = x + curr_pyramid_level += 1 + + # Create model. 
+ super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.include_rescaling = include_rescaling + self.width_coefficient = width_coefficient + self.depth_coefficient = depth_coefficient + self.dropout = dropout + self.depth_divisor = depth_divisor + self.min_depth = min_depth + self.activation = activation + self.stackwise_kernel_sizes = stackwise_kernel_sizes + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_input_filters = stackwise_input_filters + self.stackwise_output_filters = stackwise_output_filters + self.stackwise_expansion_ratios = stackwise_expansion_ratios + self.stackwise_squeeze_and_excite_ratios = ( + stackwise_squeeze_and_excite_ratios + ) + self.stackwise_strides = stackwise_strides + self.stackwise_block_types = stackwise_block_types + + self.include_initial_padding = include_initial_padding + self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth + self.cap_round_filter_decrease = cap_round_filter_decrease + self.stem_conv_padding = stem_conv_padding + self.batch_norm_momentum = batch_norm_momentum + + def get_config(self): + config = super().get_config() + config.update( + { + "include_rescaling": self.include_rescaling, + "width_coefficient": self.width_coefficient, + "depth_coefficient": self.depth_coefficient, + "dropout": self.dropout, + "depth_divisor": self.depth_divisor, + "min_depth": self.min_depth, + "activation": self.activation, + "input_shape": self.input_shape[1:], + "stackwise_kernel_sizes": self.stackwise_kernel_sizes, + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_input_filters": self.stackwise_input_filters, + "stackwise_output_filters": self.stackwise_output_filters, + "stackwise_expansion_ratios": self.stackwise_expansion_ratios, + "stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios, + "stackwise_strides": self.stackwise_strides, + "stackwise_block_types": self.stackwise_block_types, + "include_initial_padding": self.include_initial_padding, + "use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth, + "cap_round_filter_decrease": self.cap_round_filter_decrease, + "stem_conv_padding": self.stem_conv_padding, + "batch_norm_momentum": self.batch_norm_momentum, + } + ) + return config + + def _correct_pad_downsample(self, inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) + + def _apply_efficientnet_block( + self, + inputs, + filters_in=32, + filters_out=16, + kernel_size=3, + strides=1, + activation="swish", + expand_ratio=1, + se_ratio=0.0, + dropout=0.0, + name="", + ): + """An inverted residual block. + + Args: + inputs: Tensor, The input tensor of the block + filters_in: integer, the number of input filters. + filters_out: integer, the number of output filters. + kernel_size: integer, the dimension of the convolution window. + strides: integer, the stride of the convolution. + activation: activation function to use between each convolutional layer. 
+ expand_ratio: integer, scaling coefficient for the input filters. + se_ratio: float between 0 and 1, fraction to squeeze the input filters. + dropout: float between 0 and 1, fraction of the input units to drop. + name: string, block label. + + Returns: + output tensor for the block. + """ + filters = filters_in * expand_ratio + if expand_ratio != 1: + x = keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + padding="same", + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name=name + "expand_conv", + )(inputs) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "expand_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "expand_activation" + )(x) + else: + x = inputs + + # Depthwise Convolution + if strides == 2: + x = keras.layers.ZeroPadding2D( + padding=self._correct_pad_downsample(x, kernel_size), + name=name + "dwconv_pad", + )(x) + conv_pad = "valid" + else: + conv_pad = "same" + + x = keras.layers.DepthwiseConv2D( + kernel_size=kernel_size, + strides=strides, + padding=conv_pad, + use_bias=False, + depthwise_initializer=conv_kernel_initializer(), + name=name + "dwconv", + )(x) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "dwconv_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "dwconv_activation" + )(x) + + # Squeeze and Excitation phase + if 0 < se_ratio <= 1: + filters_se = max(1, int(filters_in * se_ratio)) + se = keras.layers.GlobalAveragePooling2D(name=name + "se_squeeze")( + x + ) + se_shape = (1, 1, filters) + se = keras.layers.Reshape(se_shape, name=name + "se_reshape")(se) + se = keras.layers.Conv2D( + filters_se, + 1, + padding="same", + activation=activation, + kernel_initializer=conv_kernel_initializer(), + name=name + "se_reduce", + )(se) + se = keras.layers.Conv2D( + filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=conv_kernel_initializer(), + name=name + "se_expand", + )(se) + x = keras.layers.multiply([x, se], name=name + "se_excite") + + # Output phase + x = keras.layers.Conv2D( + filters=filters_out, + kernel_size=1, + strides=1, + padding="same", + use_bias=False, + kernel_initializer=conv_kernel_initializer(), + name=name + "project", + )(x) + x = keras.layers.BatchNormalization( + axis=3, + name=name + "project_bn", + )(x) + x = keras.layers.Activation( + activation, name=name + "project_activation" + )(x) + + if strides == 1 and filters_in == filters_out: + if dropout > 0: + x = keras.layers.Dropout( + dropout, + noise_shape=(None, 1, 1, 1), + name=name + "drop", + )(x) + x = keras.layers.Add(name=name + "add")([x, inputs]) + + return x + + +def conv_kernel_initializer(scale=2.0): + return keras.initializers.VarianceScaling( + scale=scale, mode="fan_out", distribution="truncated_normal" + ) + + +def round_filters( + filters, + width_coefficient, + min_depth, + depth_divisor, + use_depth_divisor_as_min_depth, + cap_round_filter_decrease, +): + """Round number of filters based on depth multiplier. 
+ + Args: + filters: int, number of filters for Conv layer + width_coefficient: float, denotes the scaling coefficient of network + width + depth_divisor: int, a unit of network width + use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as + the minimum depth instead of min_depth (as per v1) + max_round_filter_decrease: bool, whether to cap the decrease in the + number of filters this process produces (as per v1) + + Returns: + int, new rounded filters value for Conv layer + """ + filters *= width_coefficient + + if use_depth_divisor_as_min_depth: + min_depth = depth_divisor + + new_filters = max( + min_depth, + int(filters + depth_divisor / 2) // depth_divisor * depth_divisor, + ) + + if cap_round_filter_decrease: + # Make sure that round down does not go down by more than 10%. + if new_filters < 0.9 * filters: + new_filters += depth_divisor + + return int(new_filters) + + +def round_repeats(repeats, depth_coefficient): + """Round number of repeats based on depth multiplier. + + Args: + repeats: int, number of repeats of efficientnet block + depth_coefficient: float, denotes the scaling coefficient of network + depth + + Returns: + int, rounded repeats + """ + return int(math.ceil(depth_coefficient * repeats)) + + +def get_conv_constructor(conv_type): + if conv_type == "unfused": + return MBConvBlock + elif conv_type == "fused": + return FusedMBConvBlock + else: + raise ValueError( + "Expected `conv_type` to be " + "one of 'unfused', 'fused', but got " + f"`conv_type={conv_type}`" + ) diff --git a/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py b/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py new file mode 100644 index 0000000000..8705ed7af1 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/efficientnet_backbone_test.py @@ -0,0 +1,146 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
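The `round_filters` and `round_repeats` helpers defined above drive the width and depth scaling for every stack. A small sketch of how they behave (module path assumed from this patch):

```python
from keras_nlp.src.models.efficientnet.efficientnet_backbone import (
    round_filters,
    round_repeats,
)

# Width scaling: 32 filters at width_coefficient=1.2 become 38.4, which is
# then rounded to the nearest multiple of depth_divisor (40 here).
assert (
    round_filters(
        filters=32,
        width_coefficient=1.2,
        min_depth=8,
        depth_divisor=8,
        use_depth_divisor_as_min_depth=False,
        cap_round_filter_decrease=False,
    )
    == 40
)

# Depth scaling always rounds up, so a stack of 4 blocks at
# depth_coefficient=1.4 is repeated 6 times.
assert round_repeats(repeats=4, depth_coefficient=1.4) == 6
```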
+ +import keras +import pytest +from absl.testing import parameterized + +from keras_nlp.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class EfficientNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3], + "stackwise_num_repeats": [2, 4, 4, 6, 9, 15], + "stackwise_input_filters": [24, 24, 48, 64, 128, 160], + "stackwise_output_filters": [24, 48, 64, 128, 160, 256], + "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6], + "stackwise_squeeze_and_excite_ratios": [ + 0.0, + 0.0, + 0, + 0.25, + 0.25, + 0.25, + ], + "stackwise_strides": [1, 2, 2, 2, 1, 2], + "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3, + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "include_rescaling": False, + } + self.input_data = keras.ops.ones(shape=(8, 224, 224, 3)) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=EfficientNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + run_mixed_precision_check=False, + expected_output_shape=(8, 7, 7, 1280), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=EfficientNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_valid_call(self): + model = EfficientNetBackbone(**self.init_kwargs) + model(self.input_data) + + def test_valid_call_original_v1(self): + original_v1_kwargs = { + "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3], + "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1], + "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192], + "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320], + "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2, 1], + "stackwise_squeeze_and_excite_ratios": [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + ], + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "include_rescaling": False, + "stackwise_block_types": ["v1"] * 7, + "min_depth": None, + "include_initial_padding": True, + "use_depth_divisor_as_min_depth": True, + "cap_round_filter_decrease": True, + "stem_conv_padding": "valid", + "batch_norm_momentum": 0.99, + } + model = EfficientNetBackbone(**original_v1_kwargs) + model(self.input_data) + + def test_valid_call_with_rescaling(self): + test_kwargs = self.init_kwargs.copy() + test_kwargs["include_rescaling"] = True + model = EfficientNetBackbone(**test_kwargs) + model(self.input_data) + + def test_feature_pyramid_outputs(self): + backbone = EfficientNetBackbone(**self.init_kwargs) + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + batch_size = 8 + height = width = 256 + outputs = model(keras.ops.ones(shape=(batch_size, height, width, 3))) + levels = ["P1", "P2", "P3", "P4", "P5"] + self.assertEquals(list(outputs.keys()), levels) + self.assertEquals( + outputs["P1"].shape, + (batch_size, height // 2**1, width // 2**1, 24), + ) + self.assertEquals( + outputs["P2"].shape, + (batch_size, height // 2**2, width // 2**2, 48), + ) + self.assertEquals( + outputs["P3"].shape, + (batch_size, height // 2**3, width // 2**3, 64), + ) + self.assertEquals( + outputs["P4"].shape, + (batch_size, height // 2**4, width // 2**4, 160), + ) + self.assertEquals( + outputs["P5"].shape, + (batch_size, height // 2**5, width // 2**5, 1280), + ) + + @parameterized.named_parameters( + ("one_channel", 1), + ("four_channels", 4), + ) + def 
test_application_variable_input_channels(self, num_channels): + test_kwargs = self.init_kwargs.copy() + test_kwargs["input_shape"] = (None, None, num_channels) + model = EfficientNetBackbone(**test_kwargs) + self.assertEqual(model.output_shape, (None, None, None, 1280)) diff --git a/keras_nlp/src/models/efficientnet/fusedmbconv.py b/keras_nlp/src/models/efficientnet/fusedmbconv.py new file mode 100644 index 0000000000..5c3817c30e --- /dev/null +++ b/keras_nlp/src/models/efficientnet/fusedmbconv.py @@ -0,0 +1,229 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +BN_AXIS = 3 + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + + +class FusedMBConvBlock(keras.layers.Layer): + """Implementation of the FusedMBConv block + + Also known as a Fused Mobile Inverted Residual Bottleneck block from: + [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML] + (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + [EfficientNetV2: Smaller Models and Faster Training] + (https://arxiv.org/abs/2104.00298v3). + + FusedMBConv blocks are based on MBConv blocks, and replace the depthwise and + 1x1 output convolution blocks with a single 3x3 convolution block, fusing + them together - hence the name "FusedMBConv". Alongside MBConv blocks, they + can be used in mobile-oriented and efficient architectures, and are present + in architectures EfficientNet. + + FusedMBConv blocks follow a narrow-wide-narrow structure - expanding a 1x1 + convolution, performing Squeeze-Excitation and then applying a 3x3 + convolution, which is a more efficient operation than conventional + wide-narrow-wide structures. + + As they're frequently used for models to be deployed to edge devices, + they're implemented as a layer for ease of use and re-use. + + Args: + input_filters: int, the number of input filters + output_filters: int, the number of output filters + expand_ratio: default 1, the ratio by which input_filters are multiplied + to expand the structure in the middle expansion phase + kernel_size: default 3, the kernel_size to apply to the expansion phase + convolutions + strides: default 1, the strides to apply to the expansion phase + convolutions + se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase, + and are chosen as the maximum between 1 and input_filters*se_ratio + batch_norm_momentum: default 0.9, the BatchNormalization momentum + activation: default "swish", the activation function used between + convolution operations + dropout: float, the optional dropout rate to apply before the output + convolution, defaults to 0.2 + + Returns: + A tensor representing a feature map, passed through the FusedMBConv + block + + Note: + Not intended to be used outside of the EfficientNet architecture. 
+ """ + + def __init__( + self, + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + batch_norm_momentum=0.9, + activation="swish", + dropout=0.2, + **kwargs + ): + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.se_ratio = se_ratio + self.batch_norm_momentum = batch_norm_momentum + self.activation = activation + self.dropout = dropout + self.filters = self.input_filters * self.expand_ratio + self.filters_se = max(1, int(input_filters * se_ratio)) + + self.conv1 = keras.layers.Conv2D( + filters=self.filters, + kernel_size=kernel_size, + strides=strides, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "expand_conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "expand_bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "expand_activation" + ) + + self.bn2 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "bn", + ) + + self.se_conv1 = keras.layers.Conv2D( + self.filters_se, + 1, + padding="same", + activation=self.activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_reduce", + ) + + self.se_conv2 = keras.layers.Conv2D( + self.filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_expand", + ) + + self.output_conv = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=1 if expand_ratio != 1 else kernel_size, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "project_conv", + ) + + self.bn3 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "project_bn", + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + # Expansion phase + if self.expand_ratio != 1: + x = self.conv1(inputs) + x = self.bn1(x) + x = self.act(x) + else: + x = inputs + + # Squeeze and excite + if 0 < self.se_ratio <= 1: + se = keras.layers.GlobalAveragePooling2D( + name=self.name + "se_squeeze" + )(x) + if BN_AXIS == 1: + se_shape = (self.filters, 1, 1) + else: + se_shape = (1, 1, self.filters) + + se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( + se + ) + + se = self.se_conv1(se) + se = self.se_conv2(se) + + x = keras.layers.multiply([x, se], name=self.name + "se_excite") + + # Output phase: + x = self.output_conv(x) + x = self.bn3(x) + if self.expand_ratio == 1: + x = self.act(x) + + # Residual: + if self.strides == 1 and self.input_filters == self.output_filters: + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "expand_ratio": self.expand_ratio, + "kernel_size": self.kernel_size, + "strides": self.strides, + "se_ratio": self.se_ratio, + "batch_norm_momentum": self.batch_norm_momentum, + 
"activation": self.activation, + "dropout": self.dropout, + } + + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_nlp/src/models/efficientnet/fusedmbconv_test.py b/keras_nlp/src/models/efficientnet/fusedmbconv_test.py new file mode 100644 index 0000000000..e59e251156 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/fusedmbconv_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_nlp.src.models.efficientnet.fusedmbconv import FusedMBConvBlock +from keras_nlp.src.tests.test_case import TestCase + + +class FusedMBConvBlockTest(TestCase): + def test_same_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock(input_filters=32, output_filters=32) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 32)) + self.assertLen(output, 1) + + def test_different_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock(input_filters=32, output_filters=48) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) + + def test_squeeze_excitation_ratio(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = FusedMBConvBlock( + input_filters=32, output_filters=48, se_ratio=0.25 + ) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) diff --git a/keras_nlp/src/models/efficientnet/mbconv.py b/keras_nlp/src/models/efficientnet/mbconv.py new file mode 100644 index 0000000000..4889606f8f --- /dev/null +++ b/keras_nlp/src/models/efficientnet/mbconv.py @@ -0,0 +1,238 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import keras + +BN_AXIS = 3 + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + + +class MBConvBlock(keras.layers.Layer): + def __init__( + self, + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + batch_norm_momentum=0.9, + activation="swish", + dropout=0.2, + **kwargs + ): + """Implementation of the MBConv block + + Also known as a Mobile Inverted Residual Bottleneck block from: + [MobileNetV2: Inverted Residuals and Linear Bottlenecks] + (https://arxiv.org/abs/1801.04381v4). + + MBConv blocks are common blocks used in mobile-oriented and efficient + architectures, present in architectures such as MobileNet, EfficientNet, + MaxViT, etc. + + MBConv blocks follow a narrow-wide-narrow structure - expanding a 1x1 + convolution, applying depthwise convolution, and narrowing back to a 1x1 + convolution, which is a more efficient operation than conventional + wide-narrow-wide structures. + + As they're frequently used for models to be deployed to edge devices, + they're implemented as a layer for ease of use and re-use. + + Args: + input_filters: int, the number of input filters + output_filters: int, the optional number of output filters after + Squeeze-Excitation + expand_ratio: default 1, the ratio by which input_filters are + multiplied to expand the structure in the middle expansion phase + kernel_size: default 3, the kernel_size to apply to the expansion + phase convolutions + strides: default 1, the strides to apply to the expansion phase + convolutions + se_ratio: default 0.0, Squeeze-Excitation happens before depthwise + convolution and before output convolution only if the se_ratio + is above 0. The filters used in this phase are chosen as the + maximum between 1 and input_filters*se_ratio + batch_norm_momentum: default 0.9, the BatchNormalization momentum + activation: default "swish", the activation function used between + convolution operations + dropout: float, the optional dropout rate to apply before the output + convolution, defaults to 0.2 + + Returns: + A tensor representing a feature map, passed through the MBConv + block + + + Note: + Not intended to be used outside of the EfficientNet architecture. 
+ """ + + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.se_ratio = se_ratio + self.batch_norm_momentum = batch_norm_momentum + self.activation = activation + self.dropout = dropout + self.filters = self.input_filters * self.expand_ratio + self.filters_se = max(1, int(input_filters * se_ratio)) + + self.conv1 = keras.layers.Conv2D( + filters=self.filters, + kernel_size=1, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "expand_conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "expand_bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "activation" + ) + self.depthwise = keras.layers.DepthwiseConv2D( + kernel_size=self.kernel_size, + strides=self.strides, + depthwise_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "dwconv2", + ) + + self.bn2 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "bn", + ) + + self.se_conv1 = keras.layers.Conv2D( + self.filters_se, + 1, + padding="same", + activation=self.activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_reduce", + ) + + self.se_conv2 = keras.layers.Conv2D( + self.filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=self.name + "se_expand", + ) + + self.output_conv = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=1 if expand_ratio != 1 else kernel_size, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format="channels_last", + use_bias=False, + name=self.name + "project_conv", + ) + + self.bn3 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + name=self.name + "project_bn", + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + # Expansion phase + if self.expand_ratio != 1: + x = self.conv1(inputs) + x = self.bn1(x) + x = self.act(x) + else: + x = inputs + + # Depthwise conv + x = self.depthwise(x) + x = self.bn2(x) + x = self.act(x) + + # Squeeze and excite + if 0 < self.se_ratio <= 1: + se = keras.layers.GlobalAveragePooling2D( + name=self.name + "se_squeeze" + )(x) + if BN_AXIS == 1: + se_shape = (self.filters, 1, 1) + else: + se_shape = (1, 1, self.filters) + se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( + se + ) + + se = self.se_conv1(se) + se = self.se_conv2(se) + + x = keras.layers.multiply([x, se], name=self.name + "se_excite") + + # Output phase + x = self.output_conv(x) + x = self.bn3(x) + + if self.strides == 1 and self.input_filters == self.output_filters: + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "expand_ratio": self.expand_ratio, + "kernel_size": self.kernel_size, + "strides": self.strides, + "se_ratio": self.se_ratio, + 
"batch_norm_momentum": self.batch_norm_momentum, + "activation": self.activation, + "dropout": self.dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_nlp/src/models/efficientnet/mbconv_test.py b/keras_nlp/src/models/efficientnet/mbconv_test.py new file mode 100644 index 0000000000..d4ba2b1f73 --- /dev/null +++ b/keras_nlp/src/models/efficientnet/mbconv_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import keras + +from keras_nlp.src.models.efficientnet.mbconv import MBConvBlock +from keras_nlp.src.tests.test_case import TestCase + + +class MBConvTest(TestCase): + def test_same_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=32) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 32)) + self.assertLen(output, 1) + + def test_different_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=48) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) + + def test_squeeze_excitation_ratio(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = MBConvBlock(input_filters=32, output_filters=48, se_ratio=0.25) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) From be8888d4841be85da7fd35d06fbdf5a9538ee39a Mon Sep 17 00:00:00 2001 From: gowthamkpr <47574994+gowthamkpr@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:02:52 -0700 Subject: [PATCH 15/19] Add the ResNet_vd backbone (#1766) * Add ResNet_vd to ResNet backbone * Addressed requested parameter changes * Fixed tests and updated comments * Added new parameters to docstring --- .../src/models/resnet/resnet_backbone.py | 413 ++++++++++++++++-- .../src/models/resnet/resnet_backbone_test.py | 29 +- .../resnet/resnet_image_classifier_test.py | 2 + 3 files changed, 416 insertions(+), 28 deletions(-) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index 31698e0a1c..ca1de9b090 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone): This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( CVPR 2016), [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An improved training procedure in timm](https://arxiv.org/abs/2110.00476)( - NeurIPS 2021 Workshop). 
+ NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with + Convolutional Neural Networks](https://arxiv.org/abs/1812.01187). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone): the batch normalization and ReLU activation are applied after the convolution layers. + ResNetVd introduces two key modifications to the standard ResNet. First, + the initial convolutional layer is replaced by a series of three + successive convolutional layers. Second, shortcut connections use an + additional pooling operation rather than performing downsampling within + the convolutional layers themselves. + Note that `ResNetBackbone` expects the inputs to be images with a value range of `[0, 255]` when `include_rescaling=True`. Args: + input_conv_filters: list of ints. The number of filters of the initial + convolution(s). + input_conv_kernel_sizes: list of ints. The kernel sizes of the initial + convolution(s). stackwise_num_filters: list of ints. The number of filters for each stack. stackwise_num_blocks: list of ints. The number of blocks for each stack. stackwise_num_strides: list of ints. The number of strides for each stack. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using @@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone): # Randomly initialized ResNetV2 backbone with a custom config. model = keras_nlp.models.ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], @@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone): def __init__( self, + input_conv_filters, + input_conv_kernel_sizes, stackwise_num_filters, stackwise_num_blocks, stackwise_num_strides, @@ -113,6 +131,13 @@ def __init__( dtype=None, **kwargs, ): + if len(input_conv_filters) != len(input_conv_kernel_sizes): + raise ValueError( + "The length of `input_conv_filters` and" + "`input_conv_kernel_sizes` must be the same. " + f"Received: input_conv_filters={input_conv_filters}, " + f"input_conv_kernel_sizes={input_conv_kernel_sizes}." + ) if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( stackwise_num_filters ) != len(stackwise_num_strides): @@ -128,14 +153,20 @@ def __init__( "The first element of `stackwise_num_filters` must be 64. " f"Received: stackwise_num_filters={stackwise_num_filters}" ) - if block_type not in ("basic_block", "bottleneck_block"): + if block_type not in ( + "basic_block", + "bottleneck_block", + "basic_block_vd", + "bottleneck_block_vd", + ): raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' 
) - version = "v1" if not use_pre_activation else "v2" data_format = standardize_data_format(data_format) bn_axis = -1 if data_format == "channels_last" else 1 + num_input_convs = len(input_conv_filters) num_stacks = len(stackwise_num_filters) # === Functional Model === @@ -155,29 +186,56 @@ def __init__( # The padding between torch and tensorflow/jax differs when `strides>1`. # Therefore, we need to manually pad the tensor. x = layers.ZeroPadding2D( - 3, + (input_conv_kernel_sizes[0] - 1) // 2, data_format=data_format, dtype=dtype, name="conv1_pad", )(x) x = layers.Conv2D( - 64, - 7, + input_conv_filters[0], + input_conv_kernel_sizes[0], strides=2, data_format=data_format, use_bias=False, + padding="valid", dtype=dtype, name="conv1_conv", )(x) + for conv_index in range(1, num_input_convs): + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"conv{conv_index}_bn", + )(x) + x = layers.Activation( + "relu", dtype=dtype, name=f"conv{conv_index}_relu" + )(x) + x = layers.Conv2D( + input_conv_filters[conv_index], + input_conv_kernel_sizes[conv_index], + strides=1, + data_format=data_format, + use_bias=False, + padding="same", + dtype=dtype, + name=f"conv{conv_index+1}_conv", + )(x) + if not use_pre_activation: x = layers.BatchNormalization( axis=bn_axis, epsilon=1e-5, momentum=0.9, dtype=dtype, - name="conv1_bn", + name=f"conv{num_input_convs}_bn", + )(x) + x = layers.Activation( + "relu", + dtype=dtype, + name=f"conv{num_input_convs}_relu", )(x) - x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) if use_pre_activation: # A workaround for ResNetV2: we need -inf padding to prevent zeros @@ -210,12 +268,10 @@ def __init__( stride=stackwise_num_strides[stack_index], block_type=block_type, use_pre_activation=use_pre_activation, - first_shortcut=( - block_type == "bottleneck_block" or stack_index > 0 - ), + first_shortcut=(block_type != "basic_block" or stack_index > 0), data_format=data_format, dtype=dtype, - name=f"{version}_stack{stack_index}", + name=f"stack{stack_index}", ) pyramid_outputs[f"P{stack_index + 2}"] = x @@ -248,6 +304,8 @@ def __init__( ) # === Config === + self.input_conv_filters = input_conv_filters + self.input_conv_kernel_sizes = input_conv_kernel_sizes self.stackwise_num_filters = stackwise_num_filters self.stackwise_num_blocks = stackwise_num_blocks self.stackwise_num_strides = stackwise_num_strides @@ -262,6 +320,8 @@ def get_config(self): config = super().get_config() config.update( { + "input_conv_filters": self.input_conv_filters, + "input_conv_kernel_sizes": self.input_conv_kernel_sizes, "stackwise_num_filters": self.stackwise_num_filters, "stackwise_num_blocks": self.stackwise_num_blocks, "stackwise_num_strides": self.stackwise_num_strides, @@ -327,7 +387,10 @@ def apply_basic_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + else: + shortcut = x shortcut = layers.Conv2D( filters, 1, @@ -336,7 +399,7 @@ def apply_basic_block( use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x) + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -452,7 +515,10 @@ def apply_bottleneck_block( )(x_preact) if conv_shortcut: - x = x_preact if x_preact is not None else x + if x_preact is not None: + shortcut = x_preact + else: + shortcut = x shortcut = layers.Conv2D( 4 * filters, 1, @@ -461,7 +527,295 @@ def apply_bottleneck_block( use_bias=False, dtype=dtype, name=f"{name}_0_conv", + 
)(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + x = layers.Conv2D( + filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_basic_block_vd( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. 
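+
+    Example:
+
+    An illustrative sketch of calling the helper directly; inside the
+    backbone it is only invoked through `apply_stack`, and the shape and
+    name below are placeholders chosen for the example.
+
+    ```python
+    import keras
+
+    from keras_nlp.src.models.resnet.resnet_backbone import (
+        apply_basic_block_vd,
+    )
+
+    x = keras.random.normal(shape=(1, 32, 32, 64))
+    y = apply_basic_block_vd(
+        x, filters=64, stride=2, conv_shortcut=True, name="stack1_block1"
+    )
+    # With "channels_last" data, the spatial dims are halved: (1, 16, 16, 64).
+    ```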
+ """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + elif stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x + shortcut = layers.Conv2D( + filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", + )(shortcut) + else: + shortcut = x + + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + padding="valid" if stride > 1 else "same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", + )(x) + x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block_vd( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. 
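+
+    Example:
+
+    An illustrative sketch with placeholder shapes and names; note that the
+    bottleneck block expands its output to `4 * filters` channels.
+
+    ```python
+    import keras
+
+    from keras_nlp.src.models.resnet.resnet_backbone import (
+        apply_bottleneck_block_vd,
+    )
+
+    x = keras.random.normal(shape=(1, 32, 32, 256))
+    y = apply_bottleneck_block_vd(
+        x, filters=128, stride=2, conv_shortcut=True, name="stack2_block1"
+    )
+    # With "channels_last" data, the output shape is (1, 16, 16, 512).
+    ```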
+ """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_pre_activation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" + )(x_preact) + + if conv_shortcut: + if x_preact is not None: + shortcut = x_preact + elif stride > 1: + shortcut = layers.AveragePooling2D( + 2, + strides=stride, + data_format=data_format, + dtype=dtype, + padding="same", + )(x) + else: + shortcut = x + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_0_conv", + )(shortcut) if not use_pre_activation: shortcut = layers.BatchNormalization( axis=bn_axis, @@ -561,8 +915,11 @@ def apply_stack( blocks: int. The number of blocks in the stack. stride: int. The stride length of the first layer in the first block. block_type: str. The block type to stack. One of `"basic_block"` or - `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. - Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + `"bottleneck_block"`, `"basic_block_vd"` or + `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and + ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and + ResNet152 and the `"_vd"` prefix for the respective ResNet_vd + variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet and ResNeXt. first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, @@ -580,17 +937,21 @@ def apply_stack( Output tensor for the stacked blocks. """ if name is None: - version = "v1" if not use_pre_activation else "v2" - name = f"{version}_stack" + name = "stack" if block_type == "basic_block": block_fn = apply_basic_block elif block_type == "bottleneck_block": block_fn = apply_bottleneck_block + elif block_type == "basic_block_vd": + block_fn = apply_basic_block_vd + elif block_type == "bottleneck_block_vd": + block_fn = apply_bottleneck_block_vd else: raise ValueError( - '`block_type` must be either `"basic_block"` or ' - f'`"bottleneck_block"`. Received block_type={block_type}.' + '`block_type` must be either `"basic_block"`, ' + '`"bottleneck_block"`, `"basic_block_vd"` or ' + f'`"bottleneck_block_vd"`. Received block_type={block_type}.' 
) for i in range(blocks): if i == 0: diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index a6a30362cd..f52800801f 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -24,6 +24,8 @@ class ResNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], @@ -38,18 +40,32 @@ def setUp(self): ("v1_bottleneck", False, "bottleneck_block"), ("v2_basic", True, "basic_block"), ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) def test_backbone_basics(self, use_pre_activation, block_type): init_kwargs = self.init_kwargs.copy() init_kwargs.update( - {"block_type": block_type, "use_pre_activation": use_pre_activation} + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_vision_backbone_test( cls=ResNetBackbone, init_kwargs=init_kwargs, input_data=self.input_data, expected_output_shape=( - (2, 64) if block_type == "basic_block" else (2, 256) + (2, 64) + if block_type in ("basic_block", "basic_block_vd") + else (2, 256) ), ) @@ -76,6 +92,8 @@ def test_pyramid_output_format(self): ("v1_bottleneck", False, "bottleneck_block"), ("v2_basic", True, "basic_block"), ("v2_bottleneck", True, "bottleneck_block"), + ("vd_basic", False, "basic_block_vd"), + ("vd_bottleneck", False, "bottleneck_block_vd"), ) @pytest.mark.large def test_saved_model(self, use_pre_activation, block_type): @@ -87,6 +105,13 @@ def test_saved_model(self, use_pre_activation, block_type): "image_shape": (None, None, 3), } ) + if block_type in ("basic_block_vd", "bottleneck_block_vd"): + init_kwargs.update( + { + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + } + ) self.run_model_saving_test( cls=ResNetBackbone, init_kwargs=init_kwargs, diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index 893ec42487..da06c80320 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -26,6 +26,8 @@ def setUp(self): self.images = ops.ones((2, 16, 16, 3)) self.labels = [0, 3] self.backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 64, 64], stackwise_num_blocks=[2, 2, 2], stackwise_num_strides=[1, 2, 2], From 536474a0ab2d55365ddf2e5faaf5968f7e70a767 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Thu, 29 Aug 2024 07:05:59 +0800 Subject: [PATCH 16/19] Add `VAEImageDecoder` for StableDiffusionV3 (#1796) * Add `VAEImageDecoder` for StableDiffusionV3 * Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention` --- .../stable_diffusion_v3/vae_attention.py | 126 +++++++++++++ .../stable_diffusion_v3/vae_image_decoder.py | 177 ++++++++++++++++++ 2 files changed, 303 insertions(+) create mode 100644 keras_nlp/src/models/stable_diffusion_v3/vae_attention.py create mode 100644 
keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py diff --git a/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py b/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py new file mode 100644 index 0000000000..1fba90d681 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/vae_attention.py @@ -0,0 +1,126 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras import layers +from keras import ops + +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +class VAEAttention(layers.Layer): + def __init__(self, filters, groups=32, data_format=None, **kwargs): + super().__init__(**kwargs) + self.filters = filters + self.data_format = standardize_data_format(data_format) + gn_axis = -1 if self.data_format == "channels_last" else 1 + + self.group_norm = layers.GroupNormalization( + groups=groups, + axis=gn_axis, + epsilon=1e-6, + dtype=self.dtype_policy, + name="group_norm", + ) + self.query_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="query_conv2d", + ) + self.key_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="key_conv2d", + ) + self.value_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="value_conv2d", + ) + self.softmax = layers.Softmax(dtype="float32") + self.output_conv2d = layers.Conv2D( + filters, + 1, + 1, + data_format=self.data_format, + dtype=self.dtype_policy, + name="output_conv2d", + ) + + self.groups = groups + self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters)) + + def build(self, input_shape): + self.group_norm.build(input_shape) + self.query_conv2d.build(input_shape) + self.key_conv2d.build(input_shape) + self.value_conv2d.build(input_shape) + self.output_conv2d.build(input_shape) + + def call(self, inputs, training=None): + x = self.group_norm(inputs) + query = self.query_conv2d(x) + key = self.key_conv2d(x) + value = self.value_conv2d(x) + + if self.data_format == "channels_first": + query = ops.transpose(query, (0, 2, 3, 1)) + key = ops.transpose(key, (0, 2, 3, 1)) + value = ops.transpose(value, (0, 2, 3, 1)) + shape = ops.shape(inputs) + b = shape[0] + query = ops.reshape(query, (b, -1, self.filters)) + key = ops.reshape(key, (b, -1, self.filters)) + value = ops.reshape(value, (b, -1, self.filters)) + + # Compute attention. 
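+        # This is standard scaled dot-product attention computed over the
+        # flattened spatial positions: the query is scaled by
+        # 1 / sqrt(filters), the scores are softmaxed in float32 for
+        # numerical stability, and the result aggregates the value
+        # projections.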
+ query = ops.multiply( + query, ops.cast(self._inverse_sqrt_filters, query.dtype) + ) + # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1] + attention_scores = ops.einsum("abc,adc->abd", query, key) + attention_scores = ops.cast( + self.softmax(attention_scores), self.compute_dtype + ) + # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C] + attention_output = ops.einsum("abc,adb->adc", value, attention_scores) + x = ops.reshape(attention_output, shape) + + x = self.output_conv2d(x) + if self.data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.add(x, inputs) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "groups": self.groups, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py new file mode 100644 index 0000000000..3f058addb7 --- /dev/null +++ b/keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py @@ -0,0 +1,177 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +class VAEImageDecoder(keras.Model): + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + output_channels=3, + latent_shape=(None, None, 16), + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + + # === Functional Model === + latent_inputs = layers.Input(shape=latent_shape) + + x = layers.Conv2D( + stackwise_num_filters[0], + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="input_projection", + )(latent_inputs) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block0", + ) + x = VAEAttention( + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_attention", + )(x) + x = apply_resnet_block( + x, + stackwise_num_filters[0], + data_format=data_format, + dtype=dtype, + name="input_block1", + ) + + # Stacks. + for i, filters in enumerate(stackwise_num_filters): + for j in range(stackwise_num_blocks[i]): + x = apply_resnet_block( + x, + filters, + data_format=data_format, + dtype=dtype, + name=f"block{i}_{j}", + ) + if i != len(stackwise_num_filters) - 1: + # No upsamling in the last blcok. + x = layers.UpSampling2D( + 2, + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}", + )(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"upsample_{i}_conv", + )(x) + + # Ouput block. 
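+        # Final projection back to image space: group normalization and a
+        # swish activation, followed by a 3x3 convolution that maps the
+        # features to `output_channels` (e.g. RGB).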
+ x = layers.GroupNormalization( + groups=32, + axis=gn_axis, + epsilon=1e-6, + dtype=dtype, + name="output_norm", + )(x) + x = layers.Activation("swish", dtype=dtype, name="output_activation")(x) + image_outputs = layers.Conv2D( + output_channels, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name="output_projection", + )(x) + super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.output_channels = output_channels + self.latent_shape = latent_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "output_channels": self.output_channels, + "image_shape": self.latent_shape, + } + ) + return config + + +def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None): + data_format = standardize_data_format(data_format) + gn_axis = -1 if data_format == "channels_last" else 1 + input_filters = x.shape[gn_axis] + + residual = x + x = layers.GroupNormalization( + groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1" + )(x) + x = layers.Activation("swish", dtype=dtype)(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv1", + )(x) + x = layers.GroupNormalization( + groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2" + )(x) + x = layers.Activation("swish")(x) + x = layers.Conv2D( + filters, + 3, + 1, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_conv2", + )(x) + if input_filters != filters: + residual = layers.Conv2D( + filters, + 1, + 1, + data_format=data_format, + dtype=dtype, + name=f"{name}_residual_projection", + )(residual) + x = layers.Add(dtype=dtype)([residual, x]) + return x From 0fbd84bbedbe597d404803633c511ec48b54b755 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Thu, 29 Aug 2024 07:06:53 +0800 Subject: [PATCH 17/19] Replace `Backbone` with `keras.Model` in `CLIPTextEncoder` and `T5XXLTextEncoder` (#1802) --- .../stable_diffusion_v3/clip_text_encoder.py | 14 +++++++++++--- .../stable_diffusion_v3/t5_xxl_text_encoder.py | 13 ++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py index d4a5cbc94f..899ae665c7 100644 --- a/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py +++ b/keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import keras from keras import layers from keras import ops from keras_nlp.src.layers.modeling.token_and_position_embedding import ( TokenAndPositionEmbedding, ) -from keras_nlp.src.models.backbone import Backbone from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import ( CLIPEncoderBlock, ) -class CLIPTextEncoder(Backbone): +class CLIPTextEncoder(keras.Model): def __init__( self, embedding_dim, @@ -108,7 +108,6 @@ def __init__( super().__init__( inputs={"encoder_token_ids": encoder_token_ids}, outputs=outputs, - dtype=dtype, **kwargs, ) @@ -123,6 +122,15 @@ def __init__( self.vocabulary_size = vocabulary_size self.sequence_length = sequence_length + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py index 9f4e5ae3a1..5c44395489 100644 --- a/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +++ b/keras_nlp/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py @@ -16,12 +16,11 @@ from keras_nlp.src.layers.modeling.reversible_embedding import ( ReversibleEmbedding, ) -from keras_nlp.src.models.backbone import Backbone from keras_nlp.src.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.src.models.t5.t5_transformer_layer import T5TransformerLayer -class T5XXLTextEncoder(Backbone): +class T5XXLTextEncoder(keras.Model): def __init__( self, vocabulary_size, @@ -111,7 +110,6 @@ def __init__( "encoder_padding_mask": encoder_padding_mask_input, }, outputs=encoder_output, - dtype=dtype, **kwargs, ) @@ -128,6 +126,15 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.tie_embedding_weights = tie_embedding_weights + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) + def get_config(self): config = super().get_config() config.update( From e97865d1e6be4d78dff996267d890df13ccd2eee Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 30 Aug 2024 20:41:42 +0530 Subject: [PATCH 18/19] video_classifier wrapper added --- keras_nlp/src/models/video_classifier.py | 90 ++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 keras_nlp/src/models/video_classifier.py diff --git a/keras_nlp/src/models/video_classifier.py b/keras_nlp/src/models/video_classifier.py new file mode 100644 index 0000000000..9b64cda303 --- /dev/null +++ b/keras_nlp/src/models/video_classifier.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.VideoClassifier") +class VideoClassifier(Task): + """Base class for all video classification tasks. + + `VideoClassifier` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + video classification. `VideoClassifier` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is an integer from `[0, num_classes)`. + + All `VideoClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `VideoClassifier` task for training. + + The `VideoClassifier` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
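+
+        Example:
+
+        A small sketch of overriding the `"auto"` defaults; `classifier`
+        stands in for any constructed `VideoClassifier` instance.
+
+        ```python
+        classifier.compile(
+            optimizer=keras.optimizers.SGD(learning_rate=1e-3),
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            metrics=[keras.metrics.SparseCategoricalAccuracy()],
+        )
+        ```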
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(1e-4)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.SparseCategoricalAccuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )

From 955f5f166fbb33a6794c2938b42427405682fbf0 Mon Sep 17 00:00:00 2001
From: ushareng 
Date: Tue, 3 Sep 2024 17:06:35 +0530
Subject: [PATCH 19/19] added video classifier in api

---
 keras_nlp/api/models/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py
index 6f7e08c520..22c6a8b61d 100644
--- a/keras_nlp/api/models/__init__.py
+++ b/keras_nlp/api/models/__init__.py
@@ -212,6 +212,7 @@
 from keras_nlp.src.models.task import Task
 from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+from keras_nlp.src.models.video_classifier import VideoClassifier
 from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import (
     WhisperAudioFeatureExtractor,
 )