Merge branch 'main' into positive_masks_factorised_readout
pollytur authored Mar 8, 2024
2 parents b42d7bc + 49a7310 commit 15baed4
Showing 2 changed files with 93 additions and 107 deletions.
neuralpredictors/layers/cores/conv2d.py — 44 changes: 20 additions & 24 deletions
@@ -226,25 +226,27 @@ def add_first_layer(self):
         self.add_activation(layer)
         self.features.add_module("layer0", nn.Sequential(layer))

+    def add_subsequent_conv_layer(self, layer: OrderedDict, l: int) -> None:
+        layer[self.conv_layer_name] = self.ConvLayer(
+            in_channels=self.hidden_channels[l - 1]
+            if not self.skip > 1
+            else min(self.skip, l) * self.hidden_channels[0],
+            out_channels=self.hidden_channels[l],
+            kernel_size=self.hidden_kern[l - 1],
+            stride=self.stride,
+            padding=self.hidden_padding or ((self.hidden_kern[l - 1] - 1) * self.hidden_dilation + 1) // 2,
+            dilation=self.hidden_dilation,
+            bias=self.bias,
+        )
+
     def add_subsequent_layers(self):
         if not isinstance(self.hidden_kern, Iterable):
             self.hidden_kern = [self.hidden_kern] * (self.num_layers - 1)

         for l in range(1, self.num_layers):
             layer = OrderedDict()
-            if self.hidden_padding is None:
-                self.hidden_padding = ((self.hidden_kern[l - 1] - 1) * self.hidden_dilation + 1) // 2
-            layer[self.conv_layer_name] = self.ConvLayer(
-                in_channels=self.hidden_channels[l - 1]
-                if not self.skip > 1
-                else min(self.skip, l) * self.hidden_channels[0],
-                out_channels=self.hidden_channels[l],
-                kernel_size=self.hidden_kern[l - 1],
-                stride=self.stride,
-                padding=self.hidden_padding,
-                dilation=self.hidden_dilation,
-                bias=self.bias,
-            )
+            self.add_subsequent_conv_layer(layer, l)
             self.add_bn_layer(layer, l)
             self.add_activation(layer)
             self.features.add_module("layer{}".format(l), nn.Sequential(layer))
@@ -326,6 +328,9 @@ def __init__(
         self.init_std = init_std
         super().__init__(*args, **kwargs, input_regularizer=input_regularizer)

+        if self.skip > 0:
+            raise NotImplementedError("Skip connections are not implemented for RotationEquivariant2dCore")
+
     def set_batchnorm_type(self):
         if not self.rot_eq_batch_norm:
             self.batchnorm_layer_cls = nn.BatchNorm2d
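For context: when skip connections are enabled (skip > 1), the shared helper sizes in_channels as min(self.skip, l) * self.hidden_channels[0], because layer l receives a concatenation of up to `skip` earlier feature maps; the guard above opts RotationEquivariant2dCore out of that path. A hedged sketch of the rule (the function name is hypothetical, not from the repo):

# Hedged sketch of the in_channels rule in add_subsequent_conv_layer;
# the function name is hypothetical, not from the repo.
def in_channels_for_layer(l: int, skip: int, hidden_channels: list) -> int:
    if skip > 1:
        # Layer l concatenates up to `skip` previous outputs, but can never
        # look further back than its own depth.
        return min(skip, l) * hidden_channels[0]
    return hidden_channels[l - 1]

hidden = [32, 32, 32, 32]
print([in_channels_for_layer(l, skip=3, hidden_channels=hidden) for l in range(1, 4)])
# -> [32, 64, 96]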
@@ -569,17 +574,8 @@ def add_subsequent_layers(self):

         for l in range(1, self.num_layers):
             layer = OrderedDict()
-            if self.hidden_padding is None:
-                self.hidden_padding = ((self.hidden_kern[l - 1] - 1) * self.hidden_dilation + 1) // 2
-            layer[self.conv_layer_name] = self.ConvLayer(
-                in_channels=self.hidden_channels if not self.skip > 1 else min(self.skip, l) * self.hidden_channels,
-                out_channels=self.hidden_channels,
-                kernel_size=self.hidden_kern[l - 1],
-                stride=self.stride,
-                padding=self.hidden_padding,
-                dilation=self.hidden_dilation,
-                bias=self.bias,
-            )
+            self.add_subsequent_conv_layer(layer, l)
             self.add_bn_layer(layer, l)
             self.add_activation(layer)
             if (self.num_layers - l) <= self.n_se_blocks:
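Worth noting: the refactor is not purely cosmetic. The old loops mutated self.hidden_padding on the first iteration, so every later layer reused the padding computed from the first hidden kernel; the helper now derives padding from each layer's own kernel when hidden_padding is unset. A standalone sketch of the difference (toy values, not from the repo):

# Contrast of old cached padding vs. new per-layer padding, assuming
# hidden_padding starts as None and hidden kernels differ.
kerns, dilation = [7, 3], 1

# Old behavior: padding cached after the first hidden layer.
hidden_padding, old = None, []
for k in kerns:
    if hidden_padding is None:
        hidden_padding = ((k - 1) * dilation + 1) // 2
    old.append(hidden_padding)

# New behavior: padding follows each layer's kernel.
new = [((k - 1) * dilation + 1) // 2 for k in kerns]

print(old, new)  # [3, 3] vs [3, 1]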
