diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index bcfc1c0d..c67a6df6 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -108,6 +108,7 @@ def __init__(self, cfg):
         loss_scale = self.cfg['training'].get('loss_scale', 128)
         self.virtual_batch_size = self.cfg['model'].get(
             'virtual_batch_size', None)
+        self.SE_blocks = self.cfg['model'].get('se_blocks', None)
 
         if precision == 'single':
             self.model_dtype = tf.float32
@@ -1036,7 +1037,9 @@ def batch_norm(self, input, name, scale=False):
                 virtual_batch_size=self.virtual_batch_size,
                 name=name)(input)
 
-    def squeeze_excitation(self, inputs, channels, name):
+    def squeeze_excitation(self, inputs, channels, block_num, name):
+        if self.SE_blocks is not None and block_num not in self.SE_blocks:
+            return inputs
         assert channels % self.SE_ratio == 0
 
         pooled = tf.keras.layers.GlobalAveragePooling2D(
@@ -1069,7 +1072,7 @@ def conv_block(self,
         return tf.keras.layers.Activation('relu')(self.batch_norm(
             conv, name=name + '/bn', scale=bn_scale))
 
-    def residual_block(self, inputs, channels, name):
+    def residual_block(self, inputs, channels, block_num, name):
         conv1 = tf.keras.layers.Conv2D(channels,
                                        3,
                                        use_bias=False,
@@ -1094,6 +1097,7 @@ def residual_block(self, inputs, channels, name):
                                                        name + '/2/bn',
                                                        scale=True),
                                       channels,
+                                      block_num,
                                       name=name + '/se')
         return tf.keras.layers.Activation('relu')(tf.keras.layers.add(
             [inputs, out2]))
@@ -1107,6 +1111,7 @@ def construct_net(self, inputs):
         for i in range(self.RESIDUAL_BLOCKS):
             flow = self.residual_block(flow,
                                        self.RESIDUAL_FILTERS,
+                                       i + 1,
                                        name='residual_{}'.format(i + 1))
 
         # Policy head
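
For context, a minimal sketch of how the new 'se_blocks' key might be supplied, assuming the usual dict-style 'model' section read through self.cfg; the surrounding keys and the specific block indices shown are illustrative only, not part of this change:

# Hypothetical config fragment (Python dict form of the 'model' section).
# When 'se_blocks' is present, squeeze_excitation() returns its input
# unchanged for any residual block whose 1-based index (block_num = i + 1)
# is not listed, so only the listed blocks get an SE layer.
# Omitting the key keeps SE on every residual block, as before.
cfg = {
    'model': {
        'filters': 128,
        'residual_blocks': 10,
        'se_ratio': 8,
        'se_blocks': [2, 4, 6, 8, 10],  # illustrative: SE on even blocks only
    },
}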