diff --git a/googlenet/keras/GoogLeNet_v1.ipynb b/googlenet/keras/GoogLeNet_v1.ipynb
index 4212aa9..d70dd99 100644
--- a/googlenet/keras/GoogLeNet_v1.ipynb
+++ b/googlenet/keras/GoogLeNet_v1.ipynb
@@ -54,8 +54,10 @@
     "readonly GH_REPO=\"reimplementing-ml-papers\"\n",
     "readonly GH_BRANCH=\"main\"\n",
     "\n",
-    "# Download the LocalResponseNormalization layer from AlexNet.\n",
-    "for path in alexnet/local_response_normalization.py ; do\n",
+    "# Download the LocalResponseNormalization layer from AlexNet and the GoogLeNet\n",
+    "# implementation.\n",
+    "for path in alexnet/local_response_normalization.py \\\n",
+    "            googlenet/keras/googlenet.py ; do\n",
     "  module=\"$(basename \"${path}\")\"\n",
     "  if ! [ -f \"${module}\" ]; then\n",
     "    curl -s -o \"${module}\" \"https://raw.githubusercontent.com/${GH_USER}/${GH_REPO}/${GH_BRANCH}/${path}\"\n",
@@ -63,124 +65,6 @@
     "done"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": null
-   },
-   "outputs": [],
-   "source": [
-    "from typing import Callable, Optional, List, Tuple, Union\n",
-    "\n",
-    "import tensorflow as tf\n",
-    "from tensorflow import keras\n",
-    "from keras import Input, Model, Sequential\n",
-    "from keras.layers import Activation, AvgPool2D, Concatenate, Conv2D, Dense, Dropout, Flatten, Layer, MaxPool2D\n",
-    "\n",
-    "from local_response_normalization import LocalResponseNormalization"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": null
-   },
-   "outputs": [],
-   "source": [
-    "class Inception(Layer):\n",
-    "    filters_1x1: int\n",
-    "    filters_1x1_reduce_3x3: int\n",
-    "    filters_3x3: int\n",
-    "    filters_1x1_reduce_5x5: int\n",
-    "    filters_5x5: int\n",
-    "    pool_proj: int\n",
-    "    module_name: str\n",
-    "\n",
-    "    conv_1x1: Conv2D\n",
-    "    conv_1x1_3x3: Sequential\n",
-    "    conv_1x1_5x5: Sequential\n",
-    "    max_pool_conv: Sequential\n",
-    "\n",
-    "    def __init__(self,\n",
-    "                 filters_1x1: int,\n",
-    "                 filters_1x1_reduce_3x3: int,\n",
-    "                 filters_3x3: int,\n",
-    "                 filters_1x1_reduce_5x5: int,\n",
-    "                 filters_5x5: int,\n",
-    "                 pool_proj: int,\n",
-    "                 name: str,\n",
-    "                 **kwargs):\n",
-    "        super().__init__(name=name, **kwargs)\n",
-    "\n",
-    "        self.filters_1x1 = filters_1x1\n",
-    "        self.filters_1x1_reduce_3x3 = filters_1x1_reduce_3x3\n",
-    "        self.filters_3x3 = filters_3x3\n",
-    "        self.filters_1x1_reduce_5x5 = filters_1x1_reduce_5x5\n",
-    "        self.filters_5x5 = filters_5x5\n",
-    "        self.pool_proj = pool_proj\n",
-    "        self.module_name = name\n",
-    "\n",
-    "    def _conv2d(self, filters: int, kernel_size: int, name: str) -> Conv2D:\n",
-    "        return Conv2D(filters=filters, kernel_size=kernel_size,\n",
-    "                      padding='same', activation='relu',\n",
-    "                      name=f'{self.module_name}_{name}')\n",
-    "\n",
-    "    def build(\n",
-    "        self, input_shape: Union[List[Optional[int]],\n",
-    "                                 Tuple[Optional[int], int, int, int]]) -> None:\n",
-    "        \"\"\"Builds internal structures to prepare for model training.\"\"\"\n",
-    "        self.conv_1x1 = self._conv2d(self.filters_1x1, 1, 'Conv_1x1')\n",
-    "\n",
-    "        self.conv_1x1_3x3 = Sequential([\n",
-    "            self._conv2d(self.filters_1x1_reduce_3x3, 1, 'Conv_1x1_3x3'),\n",
-    "            self._conv2d(self.filters_3x3, 3, 'Conv_3x3'),\n",
-    "        ])\n",
-    "\n",
-    "        self.conv_1x1_5x5 = Sequential([\n",
-    "            self._conv2d(self.filters_1x1_reduce_5x5, 1, 'Conv_1x1_5x5'),\n",
-    "            self._conv2d(self.filters_5x5, 5, 'Conv_5x5'),\n",
-    "        ])\n",
-    "\n",
-    "        self.max_pool_conv = Sequential([\n",
-    "            MaxPool2D(3, 1, padding='same', name=f\"{self.module_name}_MaxPool\"),\n",
-    "            self._conv2d(self.pool_proj, 1, 'MaxPool_Conv_1x1'),\n",
-    "        ])\n",
-    "\n",
-    "    def call(self, inputs: tf.Tensor) -> tf.Tensor:\n",
-    "        return Concatenate(axis=-1)([\n",
-    "            self.conv_1x1(inputs),\n",
-    "            self.conv_1x1_3x3(inputs),\n",
-    "            self.conv_1x1_5x5(inputs),\n",
-    "            self.max_pool_conv(inputs),\n",
-    "        ])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": null
-   },
-   "outputs": [],
-   "source": [
-    "def SequentialPassthrough(layers: List[Layer]) -> Callable[[tf.Tensor], tf.Tensor]:\n",
-    "    \"\"\"Similar to Keras' `Sequential`, but shows all layers transparently.\n",
-    "\n",
-    "    Instead of hiding all the layers behind another abstraction called\n",
-    "    `Sequential`, this function explicitly shows all the layers involved in the\n",
-    "    model, so they're visible when calling `model.summary()`.\n",
-    "    \"\"\"\n",
-    "    def process_layers(input_: tf.Tensor) -> tf.Tensor:\n",
-    "        x = input_\n",
-    "        for layer in layers:\n",
-    "            x = layer(x)\n",
-    "        return x\n",
-    "\n",
-    "    return process_layers"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -289,65 +173,9 @@
     }
    ],
    "source": [
-    "input_ = Input(shape=(224, 224, 3), name='Input')\n",
-    "\n",
-    "x = SequentialPassthrough([\n",
-    "    Conv2D(64, 7, 2, activation='relu', padding='same', name='Conv1'),\n",
-    "    MaxPool2D(3, 2, padding='same', name='MaxPool_1'),\n",
-    "    LocalResponseNormalization(name='LRN1'),\n",
-    "    Conv2D(192, 1, activation='relu', padding='valid', name='Conv_2'),\n",
-    "    Conv2D(192, 3, activation='relu', padding='same', name='Conv_3'),\n",
-    "    LocalResponseNormalization(name='LRN2'),\n",
-    "    MaxPool2D(3, 2, padding='same', name='MaxPool_2'),\n",
-    "    Inception(64, 96, 128, 16, 32, 32, name='Inception_3a'),\n",
-    "    Inception(128, 128, 192, 32, 96, 64, name='Inception_3b'),\n",
-    "    MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='MaxPool_3'),\n",
-    "    Inception(192, 96, 208, 16, 48, 64, name='Inception_4a'),\n",
-    "])(input_)\n",
-    "\n",
-    "# Output 0 branch\n",
-    "output0 = SequentialPassthrough([\n",
-    "    AvgPool2D(5, 3, padding='valid', name='AvgPool_out0'),\n",
-    "    Conv2D(128, 1, padding='same', activation='relu', name='Conv2D_out0'),\n",
-    "    Flatten(name='Flatten_out0'),\n",
-    "    Dense(1000, activation='relu', name='FC_1_out0'), ## params\n",
-    "    Dropout(0.7, name='Dropout_out0'),\n",
-    "    Dense(1000, activation='relu', name='FC_2_out0'), ## params\n",
-    "    Activation('softmax', name='Activation_out0'),\n",
-    "])(x)\n",
-    "\n",
-    "# Continue with more Inception modules\n",
-    "y = SequentialPassthrough([\n",
-    "    Inception(160, 112, 224, 24, 64, 64, name='Inception_4b'),\n",
-    "    Inception(128, 128, 256, 24, 64, 64, name='Inception_4c'),\n",
-    "    Inception(112, 144, 288, 32, 96, 64, name='Inception_4d'),\n",
-    "])(x)\n",
-    "\n",
-    "# Output 1 branch\n",
-    "output1 = SequentialPassthrough([\n",
-    "    AvgPool2D(5, 3, padding='valid', name='AvgPool_out1'),\n",
-    "    Conv2D(128, 1, padding='same', activation='relu', name='Conv2D_out1'),\n",
-    "    Flatten(name='Flatten_out1'),\n",
-    "    Dense(1000, activation='relu', name='FC_1_out1'), ## params\n",
-    "    Dropout(0.7, name='Dropout_out1'),\n",
-    "    Dense(1000, activation='relu', name='FC_2_out1'), ## params\n",
-    "    Activation('softmax', name='Activation_out1'),\n",
-    "])(y)\n",
-    "\n",
-    "# Continue with more Inception modules\n",
-    "output2 = SequentialPassthrough([\n",
-    "    Inception(256, 160, 320, 32, 128, 128, name='Inception_4e'),\n",
-    "    MaxPool2D(3, 2, padding='same', name='MaxPool_4'),\n",
-    "    Inception(256, 160, 320, 32, 128, 128, name='Inception_5a'),\n",
-    "    Inception(384, 192, 384, 48, 128, 128, name='Inception_5b'),\n",
-    "    AvgPool2D(7, padding='valid', name='AvgPool_out2'),\n",
-    "    Flatten(name='Flatten_out2'),\n",
-    "    Dropout(0.4, name='Dropout_out2'),\n",
-    "    Dense(1000, activation='relu', name='FC_out2'),\n",
-    "    Activation('softmax', name='Activation_out2'),\n",
-    "])(y)\n",
+    "from googlenet import GoogLeNet\n",
     "\n",
-    "model = Model(inputs=input_, outputs=[output0, output1, output2], name='GoogLeNet')\n",
+    "model = GoogLeNet()\n",
     "model.summary()"
    ]
   }
diff --git a/googlenet/keras/googlenet.py b/googlenet/keras/googlenet.py
new file mode 100644
index 0000000..2fc5a82
--- /dev/null
+++ b/googlenet/keras/googlenet.py
@@ -0,0 +1,192 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Callable, Optional, List, Tuple, Union
+
+import tensorflow as tf
+from tensorflow import keras
+from keras import Input, Model, Sequential
+from keras.layers import Activation, AvgPool2D, Concatenate, Conv2D, Dense, Dropout, Flatten, Layer, MaxPool2D
+
+from local_response_normalization import LocalResponseNormalization
+
+class Inception(Layer):
+    """Inception module with dimension reductions (Szegedy et al., 2014).
+
+    Concatenates four parallel branches along the channel axis: a 1x1 conv,
+    a 1x1 reduction followed by a 3x3 conv, a 1x1 reduction followed by a
+    5x5 conv, and a 3x3 max-pool followed by a 1x1 projection.
+    """
+
+    filters_1x1: int
+    filters_1x1_reduce_3x3: int
+    filters_3x3: int
+    filters_1x1_reduce_5x5: int
+    filters_5x5: int
+    pool_proj: int
+    module_name: str
+
+    conv_1x1: Conv2D
+    conv_1x1_3x3: Sequential
+    conv_1x1_5x5: Sequential
+    max_pool_conv: Sequential
+
+    def __init__(self,
+                 filters_1x1: int,
+                 filters_1x1_reduce_3x3: int,
+                 filters_3x3: int,
+                 filters_1x1_reduce_5x5: int,
+                 filters_5x5: int,
+                 pool_proj: int,
+                 name: str,
+                 **kwargs):
+        super().__init__(name=name, **kwargs)
+
+        self.filters_1x1 = filters_1x1
+        self.filters_1x1_reduce_3x3 = filters_1x1_reduce_3x3
+        self.filters_3x3 = filters_3x3
+        self.filters_1x1_reduce_5x5 = filters_1x1_reduce_5x5
+        self.filters_5x5 = filters_5x5
+        self.pool_proj = pool_proj
+        self.module_name = name
+
+    def _conv2d(self, filters: int, kernel_size: int, name: str) -> Conv2D:
+        return Conv2D(filters=filters, kernel_size=kernel_size,
+                      padding='same', activation='relu',
+                      name=f'{self.module_name}_{name}')
+
+    def build(
+        self, input_shape: Union[List[Optional[int]],
+                                 Tuple[Optional[int], int, int, int]]) -> None:
+        """Builds internal structures to prepare for model training."""
+        self.conv_1x1 = self._conv2d(self.filters_1x1, 1, 'Conv_1x1')
+
+        self.conv_1x1_3x3 = Sequential([
+            self._conv2d(self.filters_1x1_reduce_3x3, 1, 'Conv_1x1_3x3'),
+            self._conv2d(self.filters_3x3, 3, 'Conv_3x3'),
+        ])
+
+        self.conv_1x1_5x5 = Sequential([
+            self._conv2d(self.filters_1x1_reduce_5x5, 1, 'Conv_1x1_5x5'),
+            self._conv2d(self.filters_5x5, 5, 'Conv_5x5'),
+        ])
+
+        self.max_pool_conv = Sequential([
+            MaxPool2D(3, 1, padding='same', name=f"{self.module_name}_MaxPool"),
+            self._conv2d(self.pool_proj, 1, 'MaxPool_Conv_1x1'),
+        ])
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        return Concatenate(axis=-1)([
+            self.conv_1x1(inputs),
+            self.conv_1x1_3x3(inputs),
+            self.conv_1x1_5x5(inputs),
+            self.max_pool_conv(inputs),
+        ])
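+
+# Channel bookkeeping example (Table 1 of the paper): Inception_3a below,
+# Inception(64, 96, 128, 16, 32, 32), concatenates 64 + 128 + 32 + 32
+# feature maps, mapping its 28x28x192 input to a 28x28x256 output.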
+
+
+def SequentialPassthrough(layers: List[Layer]) -> Callable[[tf.Tensor], tf.Tensor]:
+    """Similar to Keras' `Sequential`, but shows all layers transparently.
+
+    Instead of hiding all the layers behind another abstraction called
+    `Sequential`, this function explicitly shows all the layers involved in the
+    model, so they're visible when calling `model.summary()`.
+    """
+    def process_layers(input_: tf.Tensor) -> tf.Tensor:
+        x = input_
+        for layer in layers:
+            x = layer(x)
+        return x
+
+    return process_layers
+
+def GoogLeNet() -> Model:
+    """Builds GoogLeNet (Szegedy et al., 2014) with two auxiliary classifiers."""
+
+    input_ = Input(shape=(224, 224, 3), name='Input')
+
+    x = SequentialPassthrough([
+        Conv2D(64, 7, 2, activation='relu', padding='same', name='Conv1'),
+        MaxPool2D(3, 2, padding='same', name='MaxPool_1'),
+        LocalResponseNormalization(name='LRN1'),
+        Conv2D(192, 1, activation='relu', padding='valid', name='Conv_2'),  # NOTE: the paper uses 64 filters here.
+        Conv2D(192, 3, activation='relu', padding='same', name='Conv_3'),
+        LocalResponseNormalization(name='LRN2'),
+        MaxPool2D(3, 2, padding='same', name='MaxPool_2'),
+        Inception(64, 96, 128, 16, 32, 32, name='Inception_3a'),
+        Inception(128, 128, 192, 32, 96, 64, name='Inception_3b'),
+        MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='MaxPool_3'),
+        Inception(192, 96, 208, 16, 48, 64, name='Inception_4a'),
+    ])(input_)
+
+    # Auxiliary classifier 0, attached to the output of Inception_4a.
+    output0 = SequentialPassthrough([
+        AvgPool2D(5, 3, padding='valid', name='AvgPool_out0'),
+        Conv2D(128, 1, padding='same', activation='relu', name='Conv2D_out0'),
+        Flatten(name='Flatten_out0'),
+        Dense(1000, activation='relu', name='FC_1_out0'),  # NOTE: 1024 units in the paper.
+        Dropout(0.7, name='Dropout_out0'),
+        Dense(1000, activation='relu', name='FC_2_out0'),  # NOTE: linear (no ReLU) in the paper.
+        Activation('softmax', name='Activation_out0'),
+    ])(x)
+
+    # Continue the main trunk with more Inception modules.
+    y = SequentialPassthrough([
+        Inception(160, 112, 224, 24, 64, 64, name='Inception_4b'),
+        Inception(128, 128, 256, 24, 64, 64, name='Inception_4c'),
+        Inception(112, 144, 288, 32, 96, 64, name='Inception_4d'),
+    ])(x)
+
+    # Auxiliary classifier 1, attached to the output of Inception_4d.
+    output1 = SequentialPassthrough([
+        AvgPool2D(5, 3, padding='valid', name='AvgPool_out1'),
+        Conv2D(128, 1, padding='same', activation='relu', name='Conv2D_out1'),
+        Flatten(name='Flatten_out1'),
+        Dense(1000, activation='relu', name='FC_1_out1'),  # NOTE: 1024 units in the paper.
+        Dropout(0.7, name='Dropout_out1'),
+        Dense(1000, activation='relu', name='FC_2_out1'),  # NOTE: linear (no ReLU) in the paper.
+        Activation('softmax', name='Activation_out1'),
+    ])(y)
+
+    # Final Inception modules and the main classifier.
+    output2 = SequentialPassthrough([
+        Inception(256, 160, 320, 32, 128, 128, name='Inception_4e'),
+        MaxPool2D(3, 2, padding='same', name='MaxPool_4'),
+        Inception(256, 160, 320, 32, 128, 128, name='Inception_5a'),
+        Inception(384, 192, 384, 48, 128, 128, name='Inception_5b'),
+        AvgPool2D(7, padding='valid', name='AvgPool_out2'),
+        Flatten(name='Flatten_out2'),
+        Dropout(0.4, name='Dropout_out2'),
+        Dense(1000, activation='relu', name='FC_out2'),  # NOTE: linear (no ReLU) in the paper.
+        Activation('softmax', name='Activation_out2'),
+    ])(y)
+
+    return Model(inputs=input_, outputs=[output0, output1, output2], name='GoogLeNet')
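+
+
+if __name__ == '__main__':
+    # Usage sketch, not part of the original notebook: the model has three
+    # outputs (two auxiliary classifiers plus the main one), so it takes one
+    # loss per output. The 0.3 auxiliary-loss weights follow the paper; the
+    # optimizer and loss choices here are illustrative assumptions.
+    model = GoogLeNet()
+    model.compile(optimizer='sgd',
+                  loss=['categorical_crossentropy'] * 3,
+                  loss_weights=[0.3, 0.3, 1.0])
+    model.summary()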