From d1f53c7adb36f35c9bd0b640e639e9b6bdd5a5c8 Mon Sep 17 00:00:00 2001
From: Tilps
Date: Wed, 29 Dec 2021 21:09:57 +1100
Subject: [PATCH] Allow choosing which blocks have SE units

Inspired by testing which suggests a 15-block net needs only 3 SE units
to reach the same loss values.
---
 tf/tfprocess.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index bcfc1c0d..c67a6df6 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -108,6 +108,7 @@ def __init__(self, cfg):
         loss_scale = self.cfg['training'].get('loss_scale', 128)
         self.virtual_batch_size = self.cfg['model'].get(
             'virtual_batch_size', None)
+        self.SE_blocks = self.cfg['model'].get('se_blocks', None)
 
         if precision == 'single':
             self.model_dtype = tf.float32
@@ -1036,7 +1037,9 @@ def batch_norm(self, input, name, scale=False):
                 virtual_batch_size=self.virtual_batch_size,
                 name=name)(input)
 
-    def squeeze_excitation(self, inputs, channels, name):
+    def squeeze_excitation(self, inputs, channels, block_num, name):
+        if self.SE_blocks is not None and block_num not in self.SE_blocks:
+            return inputs
         assert channels % self.SE_ratio == 0
 
         pooled = tf.keras.layers.GlobalAveragePooling2D(
@@ -1069,7 +1072,7 @@ def conv_block(self,
         return tf.keras.layers.Activation('relu')(self.batch_norm(
             conv, name=name + '/bn', scale=bn_scale))
 
-    def residual_block(self, inputs, channels, name):
+    def residual_block(self, inputs, channels, block_num, name):
         conv1 = tf.keras.layers.Conv2D(channels,
                                        3,
                                        use_bias=False,
@@ -1094,6 +1097,7 @@ def residual_block(self, inputs, channels, name):
                                                        name + '/2/bn',
                                                        scale=True),
                                        channels,
+                                       block_num,
                                        name=name + '/se')
         return tf.keras.layers.Activation('relu')(tf.keras.layers.add(
             [inputs, out2]))
@@ -1107,6 +1111,7 @@ def construct_net(self, inputs):
         for i in range(self.RESIDUAL_BLOCKS):
             flow = self.residual_block(flow,
                                        self.RESIDUAL_FILTERS,
+                                       i + 1,
                                        name='residual_{}'.format(i + 1))
 
         # Policy head
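
For reference, below is a minimal standalone sketch of the gating logic this
patch adds, stripped of the TensorFlow plumbing. Blocks are numbered from 1,
matching the `i + 1` passed from construct_net; the particular block list used
here is a hypothetical choice for illustration, not something the patch
prescribes. When 'se_blocks' is absent from the model config (so
self.SE_blocks is None), every residual block keeps its SE unit, preserving
the old behaviour.

def has_se_unit(block_num, se_blocks=None):
    # Mirrors the early return added to squeeze_excitation(): with no
    # se_blocks list configured, every block keeps its SE unit (the old
    # behaviour); otherwise only the listed 1-based block numbers get one.
    return se_blocks is None or block_num in se_blocks

# Hypothetical config: a 15-block net with SE units in only 3 blocks, as in
# the testing mentioned in the commit message. The exact block numbers here
# are an assumption.
cfg = {'model': {'se_blocks': [5, 10, 15]}}
se_blocks = cfg['model'].get('se_blocks', None)

print([b for b in range(1, 16) if has_se_unit(b, se_blocks)])
# -> [5, 10, 15]
print([b for b in range(1, 16) if has_se_unit(b)])
# -> [1, 2, ..., 15]: with the key absent, behaviour is unchanged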