Skip to content

Commit

Permalink
Merge pull request #232 from pollytur/positive_masks_factorised_readout
Browse files Browse the repository at this point in the history
Added the restriction on Factorised masks to be positive using absolute value (as in Ecker et al. 2018)
  • Loading branch information
MaxFBurg authored Mar 8, 2024
2 parents 49a7310 + 15baed4 commit 4ceba76
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions neuralpredictors/layers/readouts/factorized.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import torch
from torch import nn as nn
Expand All @@ -19,12 +21,31 @@ def __init__(
init_noise=1e-3,
constrain_pos=False,
positive_weights=False,
positive_spatial=False,
shared_features=None,
mean_activity=None,
spatial_and_feature_reg_weight=None,
gamma_readout=None, # depricated, use feature_reg_weight instead
gamma_readout=None,
**kwargs,
):
"""
Args:
in_shape: batch, channels, height, width (batch could be arbitrary)
outdims: number of neurons to predict
bias: if True, bias is used
normalize: if True, normalizes the spatial mask using l2 norm
init_noise: the std for readout initialisation
constrain_pos: if True, negative values in the spatial mask and feature readout are clamped to 0
positive_weights: if True, negative values in the feature readout are turned into 0
positive_spatial: if True, spatial readout mask values are restricted to be positive by taking the absolute values
shared_features: if True, uses a copy of the features from somewhere else
mean_activity: the mean for readout initialisation
spatial_and_feature_reg_weight: lagrange multiplier (constant) for L1 penalty,
the bigger the number, the stronger the penalty
            gamma_readout: deprecated, use spatial_and_feature_reg_weight instead
**kwargs:
"""

super().__init__()

Expand All @@ -33,6 +54,12 @@ def __init__(
self.outdims = outdims
self.positive_weights = positive_weights
self.constrain_pos = constrain_pos
self.positive_spatial = positive_spatial
if positive_spatial and constrain_pos:
warnings.warn(
f"If both positive_spatial and constrain_pos are True, "
f"only constrain_pos will effectively take place"
)
self.init_noise = init_noise
self.normalize = normalize
self.mean_activity = mean_activity
Expand All @@ -50,7 +77,7 @@ def __init__(
else:
self.register_parameter("bias", None)

self.initialize(mean_activity)
self.initialize()

@property
def shared_features(self):
Expand Down Expand Up @@ -84,6 +111,8 @@ def normalized_spatial(self):
weight = self.spatial
if self.constrain_pos:
weight.data.clamp_min_(0)
elif self.positive_spatial:
weight = torch.abs(weight)
return weight

def regularizer(self, reduction="sum", average=None):
Expand Down

0 comments on commit 4ceba76

Please sign in to comment.