diff --git a/tf/tfprocess.py b/tf/tfprocess.py index bcfc1c0d..33bf68d2 100644 --- a/tf/tfprocess.py +++ b/tf/tfprocess.py @@ -101,6 +101,7 @@ def __init__(self, cfg): # Network structure self.RESIDUAL_FILTERS = self.cfg['model']['filters'] + self.RESIDUAL_INNER_FILTERS = self.cfg['model'].get('inner_filters', self.RESIDUAL_FILTERS) self.RESIDUAL_BLOCKS = self.cfg['model']['residual_blocks'] self.SE_ratio = self.cfg['model']['se_ratio'] self.policy_channels = self.cfg['model'].get('policy_channels', 32) @@ -1070,7 +1071,7 @@ def conv_block(self, conv, name=name + '/bn', scale=bn_scale)) def residual_block(self, inputs, channels, name): - conv1 = tf.keras.layers.Conv2D(channels, + conv1 = tf.keras.layers.Conv2D(self.RESIDUAL_INNER_FILTERS, 3, use_bias=False, padding='same',