diff --git a/src/zennit/composites.py b/src/zennit/composites.py
index 2129080..f49ef8e 100644
--- a/src/zennit/composites.py
+++ b/src/zennit/composites.py
@@ -162,11 +162,15 @@ class EpsilonGammaBox(SpecialFirstLayerMapComposite):
         A tensor with the same size as the input, describing the lowest possible pixel values.
     high: obj:`torch.Tensor`
         A tensor with the same size as the input, describing the highest possible pixel values.
+    epsilon: float
+        Epsilon parameter for the epsilon rule.
+    gamma: float
+        Gamma parameter for the gamma rule.
     '''
-    def __init__(self, low, high, canonizers=None):
+    def __init__(self, low, high, epsilon=1e-6, gamma=0.25, canonizers=None):
         layer_map = LAYER_MAP_BASE + [
-            (Convolution, Gamma(gamma=0.25)),
-            (torch.nn.Linear, Epsilon()),
+            (Convolution, Gamma(gamma=gamma)),
+            (torch.nn.Linear, Epsilon(epsilon=epsilon)),
         ]
         first_map = [
             (Convolution, ZBox(low, high))
@@ -178,11 +182,16 @@ def __init__(self, low, high, canonizers=None):
 class EpsilonPlus(LayerMapComposite):
     '''An explicit composite using the zplus rule for all convolutional layers and the epsilon rule for all
     fully connected layers.
+
+    Parameters
+    ----------
+    epsilon: float
+        Epsilon parameter for the epsilon rule.
     '''
-    def __init__(self, canonizers=None):
+    def __init__(self, epsilon=1e-6, canonizers=None):
         layer_map = LAYER_MAP_BASE + [
             (Convolution, ZPlus()),
-            (torch.nn.Linear, Epsilon()),
+            (torch.nn.Linear, Epsilon(epsilon=epsilon)),
         ]
         super().__init__(layer_map, canonizers=canonizers)
 
@@ -191,11 +200,16 @@ def __init__(self, canonizers=None):
 class EpsilonAlpha2Beta1(LayerMapComposite):
     '''An explicit composite using the alpha2-beta1 rule for all convolutional layers and the epsilon rule for all
     fully connected layers.
+
+    Parameters
+    ----------
+    epsilon: float
+        Epsilon parameter for the epsilon rule.
     '''
-    def __init__(self, canonizers=None):
+    def __init__(self, epsilon=1e-6, canonizers=None):
         layer_map = LAYER_MAP_BASE + [
             (Convolution, AlphaBeta(alpha=2, beta=1)),
-            (torch.nn.Linear, Epsilon()),
+            (torch.nn.Linear, Epsilon(epsilon=epsilon)),
         ]
         super().__init__(layer_map, canonizers=canonizers)
 
@@ -204,11 +218,16 @@ def __init__(self, canonizers=None):
 class EpsilonPlusFlat(SpecialFirstLayerMapComposite):
     '''An explicit composite using the flat rule for any linear first layer, the zplus rule for all other
     convolutional layers and the epsilon rule for all other fully connected layers.
+
+    Parameters
+    ----------
+    epsilon: float
+        Epsilon parameter for the epsilon rule.
     '''
-    def __init__(self, canonizers=None):
+    def __init__(self, epsilon=1e-6, canonizers=None):
         layer_map = LAYER_MAP_BASE + [
             (Convolution, ZPlus()),
-            (torch.nn.Linear, Epsilon()),
+            (torch.nn.Linear, Epsilon(epsilon=epsilon)),
         ]
         first_map = [
             (Linear, Flat())
@@ -220,11 +239,16 @@ def __init__(self, canonizers=None):
 class EpsilonAlpha2Beta1Flat(SpecialFirstLayerMapComposite):
     '''An explicit composite using the flat rule for any linear first layer, the alpha2-beta1 rule for all other
     convolutional layers and the epsilon rule for all other fully connected layers.
+
+    Parameters
+    ----------
+    epsilon: float
+        Epsilon parameter for the epsilon rule.
     '''
-    def __init__(self, canonizers=None):
+    def __init__(self, epsilon=1e-6, canonizers=None):
         layer_map = LAYER_MAP_BASE + [
             (Convolution, AlphaBeta(alpha=2, beta=1)),
-            (torch.nn.Linear, Epsilon()),
+            (torch.nn.Linear, Epsilon(epsilon=epsilon)),
         ]
         first_map = [
             (Linear, Flat())
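
For reference, a minimal usage sketch of the changed API (not part of the patch). It shows how the new `epsilon` and `gamma` keyword arguments are passed to `EpsilonGammaBox`; the toy model, input shape, value range of `low`/`high`, and the use of the output sum as the attribution target are illustrative assumptions, not taken from this change.

import torch
from zennit.composites import EpsilonGammaBox

# Illustrative bounds for inputs in [0, 1]; shape must match the model input.
low = torch.zeros(1, 3, 8, 8)
high = torch.ones(1, 3, 8, 8)

# The new keyword arguments default to the previously hard-coded values
# (epsilon=1e-6, gamma=0.25), so existing call sites behave as before.
composite = EpsilonGammaBox(low=low, high=high, epsilon=1e-6, gamma=0.25)

# A small stand-in model with one convolutional and one fully connected layer.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 8 * 8, 10),
)

# Typical zennit workflow: register the composite's hooks, compute a
# gradient-based attribution, and remove the hooks on exit.
input = torch.rand(1, 3, 8, 8, requires_grad=True)
with composite.context(model) as modified_model:
    output = modified_model(input)
    attribution, = torch.autograd.grad(output.sum(), input)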