From 20e22c0f96c9c2313aec7e4f1534f86c9ab049b6 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 20 Apr 2021 10:25:12 +0200 Subject: [PATCH 01/54] [CI] Enable pull request checks to rnn branch --- .github/workflows/lint.yaml | 1 + .github/workflows/test.yaml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index a6e2ef64e..cee106b16 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -9,6 +9,7 @@ on: - development - master - release + - rnn jobs: flake8: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f1efc1454..07e87ae35 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -9,7 +9,7 @@ on: - development - master - release - + - rnn jobs: tests: From 1646b30f55a53ccf784e67e793e56f21e3b01f50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 22 Apr 2021 08:26:28 +0200 Subject: [PATCH 02/54] [IO] Reduce module output if tuple (#147) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Recurrent modules operate on `tuple`s. The trailing output entries contain hidden state information, which need not be saved. Co-authored-by: Tim Schäfer --- backpack/__init__.py | 15 +++++++++++---- test/core/derivatives/problem.py | 8 ++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/backpack/__init__.py b/backpack/__init__.py index eeb98342a..38ca4ff41 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -134,15 +134,22 @@ def should_store_io(): def hook_store_io(module, input, output): """Saves the input and output as attributes of the module. + The list of inputs with index i is saved as module.input[i] + The output is reduced to single output tensor and saved as module.output + Args: - module: module - input: List of input tensors - output: output tensor + module (torch.nn.Module): the module on which to save the params + input (list): List of input tensors + output (torch.Tensor or tuple): result of module(input) """ if disable.should_store_io() and torch.is_grad_enabled(): for i in range(len(input)): setattr(module, "input{}".format(i), input[i]) - module.output = output + if isinstance(output, tuple): + # is true for RNN,GRU,LSTM which return tuple (output, ...) + module.output = output[0] + else: + module.output = output def memory_cleanup(module): diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index b1a9a720e..8c35aee4a 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -132,6 +132,10 @@ def make_output_shape(self): else: output = module(input, target) + if isinstance(output, tuple): + # is true for RNN,GRU,LSTM which return tuple (output, ...) + output = output[0] + return output.shape def is_loss(self): @@ -153,6 +157,10 @@ def forward_pass(self, input_requires_grad=False, sample_idx=None): else: output = self.module(input) + if isinstance(output, tuple): + # is true for RNN,GRU,LSTM which return tuple (output, ...) + output = output[0] + return input, output, dict(self.module.named_parameters()) def make_id(self): From b578276932cacb4ba93979edbdae52ab1bf65283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 30 Apr 2021 14:38:26 +0200 Subject: [PATCH 03/54] [TEST] Allow modules with non-leading batch axis (#148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (In preparation of support for RNN, GRU, and LSTM) Some modules in PyTorch accept inputs whose batch axis is not the leading axis. This must be taken into account by the test suite. This commit adapts parts of the core and extension test suite to allow for such modules be tested. Link: Progress on fKunstner/backpack-discuss#100. * add selection of batch axis to test suite * add selection of batch axis to test suite, fix fKunstner/backpack-discuss#100 * docstring * make black * add docstring * batch axis in test/extensions * incorporate suggestions into pull request * incorporate suggestions into pull request Co-authored-by: Tim Schäfer --- backpack/core/derivatives/elu.py | 4 +- backpack/core/derivatives/relu.py | 2 +- backpack/core/derivatives/selu.py | 4 +- backpack/extensions/module_extension.py | 6 ++ backpack/utils/kroneckers.py | 2 +- .../derivatives/implementation/autograd.py | 39 ++++++++-- .../derivatives/implementation/backpack.py | 2 +- test/extensions/automated_settings.py | 4 +- .../firstorder/firstorder_settings.py | 6 +- test/extensions/implementation/autograd.py | 7 +- test/extensions/problem.py | 72 ++++++++++++++++--- 11 files changed, 121 insertions(+), 27 deletions(-) diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py index 2be2305a4..f06dd9db3 100644 --- a/backpack/core/derivatives/elu.py +++ b/backpack/core/derivatives/elu.py @@ -9,11 +9,11 @@ def hessian_is_zero(self): return False def df(self, module, g_inp, g_out): - """First ELU derivative: `ELU'(x) = alpha * e^x if x < 0 else 1`. """ + """First ELU derivative: `ELU'(x) = alpha * e^x if x < 0 else 1`.""" df_ELU = gt(module.input0, 0).float() df_ELU[df_ELU == 0] = module.alpha * exp(module.input0[df_ELU == 0]) return df_ELU def d2f(self, module, g_inp, g_out): - """Second ELU derivative: `ELU''(x) = alpha * e^x if x < 0 else 1`. """ + """Second ELU derivative: `ELU''(x) = alpha * e^x if x < 0 else 1`.""" return self.df(module, g_inp, g_out) diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py index 3e631016d..eae9d5ebf 100644 --- a/backpack/core/derivatives/relu.py +++ b/backpack/core/derivatives/relu.py @@ -9,5 +9,5 @@ def hessian_is_zero(self): return True def df(self, module, g_inp, g_out): - """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`. """ + """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`.""" return gt(module.input0, 0).float() diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py index e7f96ba74..1e0cf00f2 100644 --- a/backpack/core/derivatives/selu.py +++ b/backpack/core/derivatives/selu.py @@ -14,7 +14,7 @@ def hessian_is_zero(self): return False def df(self, module, g_inp, g_out): - """First SELU derivative: `SELU'(x) = scale if x < 0 else scale*alpha*e^x`. """ + """First SELU derivative: `SELU'(x) = scale if x < 0 else scale*alpha*e^x`.""" df_SELU = gt(module.input0, 0).float() df_SELU[df_SELU == 1] = self.scale @@ -24,7 +24,7 @@ def df(self, module, g_inp, g_out): return df_SELU def d2f(self, module, g_inp, g_out): - """Second SELU derivative: `SELU''(x) = 0 if x < 0 else scale*alpha*e^x`. """ + """Second SELU derivative: `SELU''(x) = 0 if x < 0 else scale*alpha*e^x`.""" d2f_SELU = gt(module.input0, 0).float() d2f_SELU[d2f_SELU == 1] = 0 diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index 1ed811a34..1b78b78c7 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -60,6 +60,10 @@ def backpropagate(self, ext, module, g_inp, g_out, bpQuantities): warnings.warn("Backpropagate has not been overwritten") def apply(self, ext, module, g_inp, g_out): + """ + Fetch backpropagated quantities from module output, apply backpropagation + rule, and attach the result to module input(s). + """ inp = module.input0 out = module.output @@ -77,10 +81,12 @@ def apply(self, ext, module, g_inp, g_out): @staticmethod def __backproped_quantities(ext, out): + """Fetch backpropagated quantities attached to the module output.""" return getattr(out, ext.savefield, None) @staticmethod def __backprop_quantities(ext, inp, out, bpQuantities): + """Propagate back additional information by attaching it to the module input.""" setattr(inp, ext.savefield, bpQuantities) diff --git a/backpack/utils/kroneckers.py b/backpack/utils/kroneckers.py index a0c8a39bb..eeeb9ed51 100644 --- a/backpack/utils/kroneckers.py +++ b/backpack/utils/kroneckers.py @@ -30,7 +30,7 @@ def two_kfacs_to_mat(A, B): def kfac_mat_prod(factors): - """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]` """ + """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]`""" assert all_tensors_of_order(order=2, tensors=factors) shapes = [list(f.size()) for f in factors] diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index e7a60f6b2..b2ecbca6c 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -40,17 +40,29 @@ def weight_jac_t_mat_prod(self, mat, sum_batch): def bias_jac_t_mat_prod(self, mat, sum_batch): return self.param_jac_t_mat_prod("bias", mat, sum_batch) - def param_jac_t_vec_prod(self, name, vec, sum_batch): + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): + return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + + def param_jac_t_vec_prod(self, name, vec, sum_batch, axis_batch=0): + """Compute the product of jac_t and the given vector. + + Args: + name (str): name of parameter for derivative + vec (torch.Tensor): vectors which to multiply + sum_batch (boolean): whether to sum along batch axis + axis_batch (int, optional): index of batch axis. Defaults to 0. + + Returns: + torch.Tensor: product of jac_t and vec + """ input, output, named_params = self.problem.forward_pass() param = named_params[name] if sum_batch: return transposed_jacobian_vector_product(output, param, vec)[0] else: - N = input.shape[0] - - sample_outputs = [output[n] for n in range(N)] - sample_vecs = [vec[n] for n in range(N)] + sample_outputs = output.split(1, dim=axis_batch) + sample_vecs = vec.split(1, dim=axis_batch) jac_t_sample_prods = [ transposed_jacobian_vector_product(n_out, param, n_vec)[0] @@ -59,12 +71,25 @@ def param_jac_t_vec_prod(self, name, vec, sum_batch): return torch.stack(jac_t_sample_prods) - def param_jac_t_mat_prod(self, name, mat, sum_batch): + def param_jac_t_mat_prod(self, name, mat, sum_batch, axis_batch=0): + """Compute the product of jac_t and the given matrix. + + Args: + name (str): name of parameter for derivative + mat (torch.Tensor): matrix which to multiply + sum_batch (boolean): whether to sum along batch axis + axis_batch (int, optional): index of batch axis. This is counted + without the first axis. Defaults to 0. + + Returns: + torch.Tensor: product of jac_t and mat + """ V = mat.shape[0] vecs = [mat[v] for v in range(V)] jac_t_vec_prods = [ - self.param_jac_t_vec_prod(name, vec, sum_batch) for vec in vecs + self.param_jac_t_vec_prod(name, vec, sum_batch, axis_batch=axis_batch) + for vec in vecs ] return torch.stack(jac_t_vec_prods) diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 06ce22668..c36d94d9f 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -87,7 +87,7 @@ def hessian_is_zero(self): return self.problem.derivative.hessian_is_zero() def _sample_hessians_from_sqrt(self, sqrt): - """Convert individual matrix square root into individual full matrix. """ + """Convert individual matrix square root into individual full matrix.""" equation = None num_axes = len(sqrt.shape) diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py index d64e99bf9..612c85c48 100644 --- a/test/extensions/automated_settings.py +++ b/test/extensions/automated_settings.py @@ -60,7 +60,7 @@ def make_cnn(conv_class, output_size, conv_params): ) def get_output_shape(module, module_params, input): - """ returns the output shape for a given layer""" + """returns the output shape for a given layer""" output = module(*module_params)(input) return output.numel() // output.shape[0] @@ -102,7 +102,7 @@ def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params): ) def get_output_shape(module, module_params, input, pool, pool_params): - """ returns the output shape for a given layer""" + """returns the output shape for a given layer""" output_1 = module(*module_params)(input) output = pool_cls(*pool_params)(output_1) return output.numel() // output.shape[0] diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 0385e3c4e..99206d86d 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -1,5 +1,6 @@ -"""Test configurations for `backpack.core.extensions.firstorder` -that is shared among the following firstorder methods: +"""Test configurations for `backpack.core.extensions.firstorder`. + +It is shared among the following firstorder methods: - batch_grad - batch_l2_grad - sum_grad_sqaured @@ -17,6 +18,7 @@ "device" [list(torch.device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. "seed" (int): seed for the random number for torch.rand + "axis_batch (int): specifies the batch axis. Defaults to zero """ diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 2068659f8..f956090de 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -11,7 +11,12 @@ class AutogradExtensions(ExtensionsImplementation): """Extension implementations with autograd.""" def batch_grad(self): - N = self.problem.input.shape[0] + """Scaled individual gradients computed by BackPACK's BatchGrad extension. + + Returns: + list[torch.Tensor]: batch_grads + """ + N = self.problem.input.shape[self.problem.axis_batch] batch_grads = [ torch.zeros(N, *p.size()).to(self.problem.device) for p in self.problem.model.parameters() diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 4c41ef191..f25dcd53a 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -9,6 +9,14 @@ def make_test_problems(settings): + """Creates test problems from settings. + + Args: + settings (list[dict]): raw settings of the problems + + Returns: + list[ExtensionTestProblem] + """ problem_dicts = [] for setting in settings: @@ -24,17 +32,23 @@ def make_test_problems(settings): def add_missing_defaults(setting): - """Create extensions test problem from setting. + """Create full settings from setting. + Args: setting (dict): configuration dictionary + Returns: - ExtensionsTestProblem: problem with specified settings. + dict: full settings. + + Raises: + ValueError: if no proper settings """ required = ["module_fn", "input_fn", "loss_function_fn", "target_fn"] optional = { "id_prefix": "", "seed": 0, "device": get_available_devices(), + "axis_batch": 0, } for req in required: @@ -53,6 +67,8 @@ def add_missing_defaults(setting): class ExtensionsTestProblem: + """Class providing functions and parameters.""" + def __init__( self, input_fn, @@ -62,6 +78,7 @@ def __init__( device, seed, id_prefix, + axis_batch, ): """Collection of information required to test extensions. @@ -73,6 +90,7 @@ def __init__( device (torch.device): Device to run on. seed (int): Random seed. id_prefix (str): Extra string added to test id. + axis_batch (int): index of batch axis. Defaults to 0. """ self.module_fn = module_fn self.input_fn = input_fn @@ -82,8 +100,10 @@ def __init__( self.device = device self.seed = seed self.id_prefix = id_prefix + self.axis_batch = axis_batch def set_up(self): + """Set up problem from settings.""" torch.manual_seed(self.seed) self.model = self.module_fn().to(self.device) @@ -92,10 +112,15 @@ def set_up(self): self.loss_function = self.loss_function_fn().to(self.device) def tear_down(self): + """Delete all variables after problem.""" del self.model, self.input, self.target, self.loss_function def make_id(self): - """Needs to function without call to `set_up`.""" + """Needs to function without call to `set_up`. + + Returns: + str: id of problem + """ prefix = (self.id_prefix + "-") if self.id_prefix != "" else "" return ( prefix @@ -108,27 +133,58 @@ def make_id(self): ) def forward_pass(self, sample_idx=None): - """Do a forward pass. Return input, output, and parameters.""" + """Do a forward pass. Return input, output, and parameters. + + The forward pass is performed on the selected index. + If the index is None, then the forward pass is calculated for the whole batch. + + Args: + sample_idx (int, optional): Index of the sample to select. + Defaults to None. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input, output, loss, each with batch axis first + """ if sample_idx is None: input = self.input.clone().detach() target = self.target.clone().detach() else: - input = self.input.clone()[sample_idx, :].unsqueeze(0).detach() target = self.target.clone()[sample_idx].unsqueeze(0).detach() + input = self.input.split(1, dim=self.axis_batch)[sample_idx].detach() - print(self.target.shape) - print(target.shape) output = self.model(input) + if isinstance(output, tuple): + output = output[0] + + if self.axis_batch != 0: + # Note: This inserts a new operation into the computation graph. + # In second order extensions, breaks backpropagation of additional + # information. + output = output.transpose(0, self.axis_batch) + loss = self.loss_function(output, target) return input, output, loss def extend(self): + """Extend module of problem.""" self.model = extend(self.model) self.loss_function = extend(self.loss_function) def get_reduction_factor(self, loss, unreduced_loss): - """Return the factor used to reduce the individual losses.""" + """Return the factor used to reduce the individual losses. + + Args: + loss (torch.Tensor): the loss after reduction + unreduced_loss (torch.Tensor): the raw loss before reduction + + Returns: + float: factor + + Raises: + RuntimeError: if either mean or sum cannot be determined + """ mean_loss = unreduced_loss.flatten().mean() sum_loss = unreduced_loss.flatten().sum() if torch.allclose(mean_loss, sum_loss): From ac2423a7489164836d305d382473f8337546fd75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 6 May 2021 15:02:49 +0200 Subject: [PATCH 04/54] Automate parameter function creation of BatchGrad and Grad (#150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor batch_grad_base und gradient/base * add overwrite check and docs * add # noqa to 2 lines * incorporate some suggestions * incorporate some suggestions * code block in docstring * particularise #noqa Co-authored-by: Tim Schäfer --- .../firstorder/batch_grad/batch_grad_base.py | 66 ++++++++++++++++--- .../extensions/firstorder/gradient/base.py | 66 ++++++++++++++++--- 2 files changed, 114 insertions(+), 18 deletions(-) diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index 1e25de41e..d6bd89822 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -1,17 +1,65 @@ +"""Calculates the batch_grad derivative.""" from backpack.extensions.firstorder.base import FirstOrderModuleExtension class BatchGradBase(FirstOrderModuleExtension): - def __init__(self, derivatives, params=None): + """Calculates the batch_grad derivative. + + Passes the calls for the parameters to the derivatives class. + Implements functions with method names from params. + + If child class wants to overwrite these methods + - for example to support an additional external module - + it can do so using the interface for parameter "param1":: + + param1(ext, module, g_inp, g_out, bpQuantities): + return batch_grads + + In this case, the method is not overwritten by this class. + """ + + def __init__(self, derivatives, params): + """Initializes all methods. + + If the param method has already been defined, it is left unchanged. + + Args: + derivatives(backpack.core.derivatives.basederivatives.BaseParameterDerivatives): # noqa: B950 + Derivatives object assigned to self.derivatives. + params (list[str]): list of strings with parameter names. + For each, a method is assigned. + """ self.derivatives = derivatives + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) super().__init__(params=params) - def bias(self, ext, module, g_inp, g_out, bpQuantities): - return self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=False - ) + def _make_param_function(self, param): + """Creates a function that calculates batch_grad wrt param. + + Args: + param(str): name of parameter + + Returns: + function: function that calculates batch_grad wrt param + """ + + def param_function(ext, module, g_inp, g_out, bpQuantities): + """Calculates batch_grad with the help of derivatives object. + + Args: + ext(backpack.extensions.BatchGrad): extension that is used + module(torch.nn.Module): module that performed forward pass + g_inp(tuple[torch.Tensor]): input gradient tensors + g_out(tuple[torch.Tensor]): output gradient tensors + bpQuantities(None): additional quantities for second order + + Returns: + torch.Tensor: scaled individual gradients + """ + return getattr(self.derivatives, f"{param}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=False + ) - def weight(self, ext, module, g_inp, g_out, bpQuantities): - return self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=False - ) + return param_function diff --git a/backpack/extensions/firstorder/gradient/base.py b/backpack/extensions/firstorder/gradient/base.py index 3ebaf8932..eb45c8da7 100644 --- a/backpack/extensions/firstorder/gradient/base.py +++ b/backpack/extensions/firstorder/gradient/base.py @@ -1,17 +1,65 @@ +"""Calculates the gradient.""" from backpack.extensions.firstorder.base import FirstOrderModuleExtension class GradBaseModule(FirstOrderModuleExtension): - def __init__(self, derivatives, params=None): + """Calculates the gradient. + + Passes the calls for the parameters to the derivatives class. + Implements functions with method names from params. + + If child class wants to overwrite these methods + - for example to support an additional external module - + it can do so using the interface for parameter "param1":: + + param1(ext, module, g_inp, g_out, bpQuantities): + return batch_grads + + In this case, the method is not overwritten by this class. + """ + + def __init__(self, derivatives, params): + """Initializes all methods. + + If the param method has already been defined, it is left unchanged. + + Args: + derivatives(backpack.core.derivatives.basederivatives.BaseParameterDerivatives): # noqa: B950 + Derivatives object assigned to self.derivatives. + params (list[str]): list of strings with parameter names. + For each, a method is assigned. + """ self.derivatives = derivatives + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) super().__init__(params=params) - def bias(self, ext, module, g_inp, g_out, bpQuantities): - return self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=True - ) + def _make_param_function(self, param): + """Creates a function that calculates gradient wrt param. + + Args: + param(str): name of parameter + + Returns: + function: function that calculates gradient wrt param + """ + + def param_function(ext, module, g_inp, g_out, bpQuantities): + """Calculates gradient with the help of derivatives object. + + Args: + ext(backpack.extensions.BatchGrad): extension that is used + module(torch.nn.Module): module that performed forward pass + g_inp(tuple[torch.Tensor]): input gradient tensors + g_out(tuple[torch.Tensor]): output gradient tensors + bpQuantities(None): additional quantities for second order + + Returns: + torch.Tensor: gradient of the batch, similar to autograd + """ + return getattr(self.derivatives, f"{param}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=True + ) - def weight(self, ext, module, g_inp, g_out, bpQuantities): - return self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=True - ) + return param_function From 2be0e886d26c1c1a041e92d103d8c41aab5def84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 12 May 2021 09:41:18 +0200 Subject: [PATCH 05/54] Merge branch 'development' into rnn (#153) Update `rnn` branch with latest changes of `development` --- README-dev.md | 11 + backpack/extensions/backprop_extension.py | 74 ++++--- backpack/utils/kroneckers.py | 2 +- .../use_cases/example_custom_module.py | 197 ++++++++++++++++++ test/extensions/automated_settings.py | 4 +- test/extensions/test_backprop_extension.py | 45 ++++ 6 files changed, 305 insertions(+), 28 deletions(-) create mode 100644 docs_src/examples/use_cases/example_custom_module.py create mode 100644 test/extensions/test_backprop_extension.py diff --git a/README-dev.md b/README-dev.md index 4d3b94f46..95edfb963 100644 --- a/README-dev.md +++ b/README-dev.md @@ -45,6 +45,17 @@ make install-dev make format-check ``` +## Pull requests + +Code that is affected (has a `git diff`) by a pull request must satisfy the following docstring requirements: + +1. A one-line summary what the function/class does +2. Argument description (`Args` section) + - Argument name, type, and description + - Optional arguments must be marked as such, the default value must be documented in the description +3. Output description (`Returns` section) + - Type and description + ## Documentation ### Build diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index 193539c56..e32d12185 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -1,7 +1,11 @@ +"""Implements the backpropagation mechanism.""" import warnings +from typing import Type +import torch.nn from torch.nn import Sequential +from backpack.extensions.module_extension import ModuleExtension from backpack.utils.hooks import no_op FAIL_ERROR = "ERROR" @@ -10,8 +14,7 @@ class BackpropExtension: - """ - Base class for the BackPACK extensions. + """Base class for the BackPACK extensions. Descendants of this class need to - define in what field to save results @@ -27,34 +30,48 @@ class BackpropExtension: ``` """ - __external_module_extensions = {} - def __init__(self, savefield, module_exts, fail_mode=FAIL_ERROR): - """ - Parameters - ---------- - savefield : str - Where to save results - module_exts : dict - Dictionary mapping module classes to `ModuleExtension` instances - fail_mode : str, optional - Behavior when encountering an unknown layer. - Can be - - "ERROR": raise a NotImplementedError - - "WARN": raise a UserWarning - - "SILENT": skip the module silently + """Initializes parameters. + + Args: + savefield(str): Where to save results + module_exts(dict): Maps module classes to `ModuleExtension` instances + fail_mode(str, optional): Behavior when encountering an unknown layer. + Can be + - "ERROR": raise a NotImplementedError + - "WARN": raise a UserWarning + - "SILENT": skip the module silently + Defaults to FAIL_ERROR = "ERROR" """ self.savefield = savefield - self.__module_extensions = { - **module_exts, - **self.__class__.__external_module_extensions, - } - + self.__module_extensions = module_exts self.__fail_mode = fail_mode - @classmethod - def add_module_extension(cls, module, extension): - cls.__external_module_extensions[module] = extension + def set_module_extension( + self, + module: Type[torch.nn.Module], + extension: ModuleExtension, + overwrite: bool = False, + ) -> None: + """Adds a module mapping to module_extensions. + + This can be used to add a custom module. + + Args: + module: The module that is supposed to be extended + extension: The custom extension of that module. + overwrite: Whether to allow overwriting of an existing key. + Defaults to False. + + Raises: + ValueError: If the key already exists and overwrite is set to False. + """ + if overwrite is False and module in self.__module_extensions: + raise ValueError( + f"{module} maps to {self.__module_extensions.get(module)}! " + "Use overwrite = True to force replacement." + ) + self.__module_extensions[module] = extension def __get_module_extension(self, module): module_extension = self.__module_extensions.get(module.__class__) @@ -82,5 +99,12 @@ def __get_module_extension(self, module): return module_extension.apply def apply(self, module, g_inp, g_out): + """Applies backpropagation. + + Args: + module(torch.nn.module): module to perform backpropagation on + g_inp(tuple[torch.Tensor]): input gradient + g_out(tuple[torch.Tensor]): output gradient + """ module_extension = self.__get_module_extension(module) module_extension(self, module, g_inp, g_out) diff --git a/backpack/utils/kroneckers.py b/backpack/utils/kroneckers.py index eeeb9ed51..af61a79a5 100644 --- a/backpack/utils/kroneckers.py +++ b/backpack/utils/kroneckers.py @@ -30,7 +30,7 @@ def two_kfacs_to_mat(A, B): def kfac_mat_prod(factors): - """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]`""" + """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]`.""" assert all_tensors_of_order(order=2, tensors=factors) shapes = [list(f.size()) for f in factors] diff --git a/docs_src/examples/use_cases/example_custom_module.py b/docs_src/examples/use_cases/example_custom_module.py new file mode 100644 index 000000000..e027d3608 --- /dev/null +++ b/docs_src/examples/use_cases/example_custom_module.py @@ -0,0 +1,197 @@ +"""Custom module example +========================================= + +This tutorial shows how to support a custom module in a simple fashion. +We focus on `BackPACK's first-order extensions `_. +They don't backpropagate additional information and thus require less functionality be implemented. + +Let's get the imports out of our way. +""" + +import torch + +from backpack import backpack, extend +from backpack.extensions import BatchGrad +from backpack.extensions.firstorder.base import FirstOrderModuleExtension + +# make deterministic +torch.manual_seed(0) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# %% +# Custom PyTorch module +# --------------------- +# In this example, we consider extending our own, very simplistic, layer. +# It scales the input by a scalar ``weight`` in a forward pass. Here is the +# layer class (see https://pytorch.org/docs/stable/notes/extending.html). + + +class ScaleModule(torch.nn.Module): + """Defines the module.""" + + def __init__(self, weight=2.0): + """Store scalar weight. + + Args: + weight(float, optional): Initial value for weight. Defaults to 2.0. + """ + super(ScaleModule, self).__init__() + + self.weight = torch.nn.Parameter(torch.tensor([weight])) + + def forward(self, input): + """Defines forward pass. + + Args: + input(torch.Tensor): input + + Returns: + torch.Tensor: product of input and weight + """ + return input * self.weight + + +# %% +# You don't necessarily need to write a custom layer. Any PyTorch layer can be extended +# as described (it should be a :py:class:`torch.nn.Module `'s because +# BackPACK uses module hooks). +# +# Custom module extension +# ----------------------- +# Let's make BackPACK support computing individual gradients for ``ScaleModule``. +# This is done by the :py:class:`BatchGrad ` extension. +# To support the new module, we need to create a module extension that implements +# how individual gradients are extracted with respect to ``ScaleModule``'s parameter. +# +# The module extension must implement methods named after the parameters passed to the +# constructor. Here it goes. + + +class ScaleModuleBatchGrad(FirstOrderModuleExtension): + """Extract individual gradients for ``ScaleModule``.""" + + def __init__(self): + """Store parameters for which individual gradients should be computed.""" + # specify parameter names + super().__init__(params=["weight"]) + + def weight(self, ext, module, g_inp, g_out, bpQuantities): + """Extract individual gradients for ScaleModule's ``weight`` parameter. + + Args: + ext(BatchGrad): extension that is used + module(ScaleModule): module that performed forward pass + g_inp(tuple[torch.Tensor]): input gradient tensors + g_out(tuple[torch.Tensor]): output gradient tensors + bpQuantities(None): additional quantities for second-order + + Returns: + torch.Tensor: individual gradients + """ + show_useful = True + + if show_useful: + print("Useful quantities:") + # output is saved under field output + print("\tmodule.output.shape:", module.output.shape) + # input i is saved under field input[i] + print("\tmodule.input0.shape:", module.input0.shape) + # gradient w.r.t output + print("\tg_out[0].shape: ", g_out[0].shape) + + # actual computation + return (g_out[0] * module.input0).flatten(start_dim=1).sum(axis=1).unsqueeze(-1) + + +# %% +# Lastly, we need to register the mapping between layer (``ScaleModule``) and layer +# extension (``ScaleModuleBatchGrad``) in an instance of +# :py:class:`BatchGrad `. + +# register module-computation mapping +extension = BatchGrad() +extension.set_module_extension(ScaleModule, ScaleModuleBatchGrad()) + +# %% +# That's it. We can now pass ``extension`` to a +# :py:class:`with backpack(...) ` context and compute individual +# gradients with respect to ``ScaleModule``'s ``weight`` parameter. + +# %% +# Test custom module +# ------------------ +# Here, we verify the custom module extension on a small net with random inputs. +# Let's create these. + +batch_size = 10 +batch_axis = 0 +input_size = 4 + +inputs = torch.randn(batch_size, input_size, device=device) +targets = torch.randint(0, 2, (batch_size,), device=device) + +reduction = ["mean", "sum"][1] +my_module = ScaleModule().to(device) +lossfunc = torch.nn.CrossEntropyLoss(reduction=reduction).to(device) + +# %% +# .. note:: +# Results of ``"mean"`` and ``"sum"`` reduction differ by a scaling factor, +# because the information backpropagated by PyTorch is scaled. This is documented at +# https://docs.backpack.pt/en/master/extensions.html#backpack.extensions.BatchGrad. + +# %% +# Individual gradients with PyTorch +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# The following computes individual gradients by looping over individual samples and +# stacking their gradients. + +grad_batch_autograd = [] + +for input_n, target_n in zip( + inputs.split(1, dim=batch_axis), targets.split(1, dim=batch_axis) +): + loss_n = lossfunc(my_module(input_n), target_n) + grad_n = torch.autograd.grad(loss_n, [my_module.weight])[0] + grad_batch_autograd.append(grad_n) + +grad_batch_autograd = torch.stack(grad_batch_autograd) + +print("weight.shape: ", my_module.weight.shape) +print("grad_batch_autograd.shape:", grad_batch_autograd.shape) + +# %% +# Individual gradients with BackPACK +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# BackPACK can compute individual gradients in a single backward pass. +# First, :py:func:`extend ` model and loss function, then +# perform the backpropagation inside a +# :py:class:`with backpack(...) ` context. + +my_module = extend(my_module) +lossfunc = extend(lossfunc) + +loss = lossfunc(my_module(inputs), targets) + +with backpack(extension): + loss.backward() + +grad_batch_backpack = my_module.weight.grad_batch + +print("weight.shape: ", my_module.weight.shape) +print("grad_batch_backpack.shape:", grad_batch_backpack.shape) + +# %% +# Do the computation results match? + +match = torch.allclose(grad_batch_autograd, grad_batch_backpack) + +print(f"autograd and BackPACK individual gradients match? {match}") + +if not match: + raise AssertionError( + "Individual gradients don't match:" + + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}" + ) diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py index 612c85c48..58fbec8ac 100644 --- a/test/extensions/automated_settings.py +++ b/test/extensions/automated_settings.py @@ -60,7 +60,7 @@ def make_cnn(conv_class, output_size, conv_params): ) def get_output_shape(module, module_params, input): - """returns the output shape for a given layer""" + """Returns the output shape for a given layer.""" output = module(*module_params)(input) return output.numel() // output.shape[0] @@ -102,7 +102,7 @@ def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params): ) def get_output_shape(module, module_params, input, pool, pool_params): - """returns the output shape for a given layer""" + """Returns the output shape for a given layer.""" output_1 = module(*module_params)(input) output = pool_cls(*pool_params)(output_1) return output.numel() // output.shape[0] diff --git a/test/extensions/test_backprop_extension.py b/test/extensions/test_backprop_extension.py new file mode 100644 index 000000000..7d879ead0 --- /dev/null +++ b/test/extensions/test_backprop_extension.py @@ -0,0 +1,45 @@ +"""Test custom extensions for backprop_extension.""" + +import pytest +from torch.nn import Linear, Module + +from backpack.extensions import BatchGrad, Variance +from backpack.extensions.firstorder.base import FirstOrderModuleExtension + + +def test_set_custom_extension(): + """Test the method set_custom_extension of BackpropExtension.""" + + class _A(Module): + pass + + class _ABatchGrad(FirstOrderModuleExtension): + pass + + class _AVariance(FirstOrderModuleExtension): + pass + + class _MyLinearBatchGrad(FirstOrderModuleExtension): + pass + + grad_batch = BatchGrad() + + # Set module extension + grad_batch.set_module_extension(_A, _ABatchGrad()) + + # setting again should raise a ValueError + with pytest.raises(ValueError): + grad_batch.set_module_extension(_A, _ABatchGrad()) + + # setting again with overwrite + grad_batch.set_module_extension(_A, _ABatchGrad(), overwrite=True) + + # in a different extension, set another extension for the same module + variance = Variance() + variance.set_module_extension(_A, _AVariance()) + + # set an extension for an already existing extension + with pytest.raises(ValueError): + grad_batch.set_module_extension(Linear, _MyLinearBatchGrad()) + + grad_batch.set_module_extension(Linear, _MyLinearBatchGrad(), overwrite=True) From 91e85e5f6e49d4cd5e0c726e02fa62d3014a27d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Tue, 1 Jun 2021 10:32:11 +0200 Subject: [PATCH 06/54] [CI] partial format tests, test with python >= 3.7 (#154) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tim Schäfer --- .github/workflows/lint.yaml | 30 ++++++++++++++++++++++++++++++ .github/workflows/test.yaml | 4 ++-- fully_documented.txt | 3 +++ makefile | 19 +++++++++++++++---- 4 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 fully_documented.txt diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index cee106b16..1cec88cf5 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -89,3 +89,33 @@ jobs: - name: Run darglint run: | make darglint-check + darglint-partial: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + make install-lint + - name: Run darglint + run: | + make darglint-check-partial + pydocstyle-partial: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + make install-lint + - name: Run pydocstyle + run: | + make pydocstyle-check-partial diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 07e87ae35..d75c724ac 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,11 +16,11 @@ jobs: name: "Python ${{ matrix.python-version }}" runs-on: ubuntu-latest env: - USING_COVERAGE: '3.6,3.8' + USING_COVERAGE: '3.7,3.9' strategy: matrix: - python-version: ["3.6", "3.7", "3.8"] + python-version: ["3.7", "3.8", "3.9"] steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 diff --git a/fully_documented.txt b/fully_documented.txt new file mode 100644 index 000000000..72e8c3d10 --- /dev/null +++ b/fully_documented.txt @@ -0,0 +1,3 @@ +test/extensions/problem.py +test/extensions/test_backprop_extension.py +test/extensions/firstorder/firstorder_settings.py diff --git a/makefile b/makefile index 6ddb84b34..9823af519 100644 --- a/makefile +++ b/makefile @@ -5,10 +5,10 @@ .PHONY: test-light test-light-no-gpu .PHONY: conda-env .PHONY: black isort format -.PHONY: black-check isort-check format-check +.PHONY: black-check isort-check format-check format-check-partial .PHONY: flake8 -.PHONY: pydocstyle-check -.PHONY: darglint-check +.PHONY: pydocstyle-check pydocstyle-check-partial +.PHONY: darglint-check darglint-check-partial .PHONY: build-docs .DEFAULT: help @@ -29,8 +29,12 @@ help: @echo " Run flake8 on the project" @echo "pydocstyle-check" @echo " Run pydocstyle on the project" + @echo "pydocstyle-check-partial" + @echo " Run pydocstyle on documented part of the project" @echo "darglint-check" @echo " Run darglint on the project" + @echo "darglint-check-partial" + @echo " Run darglint on documented part of the project" @echo "install" @echo " Install backpack and dependencies" @echo "isort" @@ -82,9 +86,15 @@ flake8: pydocstyle-check: @pydocstyle --count . +pydocstyle-check-partial: + @pydocstyle --count $(shell grep -v '^#' fully_documented.txt ) + darglint-check: @darglint --verbosity 2 . +darglint-check-partial: + @darglint --verbosity 2 $(shell grep -v '^#' fully_documented.txt) + isort: @isort . @@ -96,8 +106,9 @@ format: @make isort @make black-check -format-check: black-check isort-check pydocstyle-check darglint-check +format-check: black-check isort-check flake8 pydocstyle-check darglint-check +format-check-partial: black-check isort-check flake8 pydocstyle-check-partial darglint-check-partial ### # Installation From 655014d8999870d4e94f0cf78b3bd75fa75d47b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Tue, 8 Jun 2021 09:38:34 +0200 Subject: [PATCH 07/54] [core] Add RNN derivatives (#156) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Auxiliary: - Improve documentation - Move mapping between layers and derivatives from main lib to test suite --- * RNN: derivatives and derivatives test * additional test for RNN derivatives * incorporate suggestions * move RNN CUDA workaround to AutogradDerivatives * delete unnecessary variable * fix type hint for hessian * remove white spaces in einsums * make format Co-authored-by: Tim Schäfer --- backpack/core/derivatives/__init__.py | 79 +-- backpack/core/derivatives/basederivatives.py | 556 ++++++++++++++---- backpack/core/derivatives/rnn.py | 236 ++++++++ backpack/core/derivatives/shape_check.py | 190 ++++-- fully_documented.txt | 13 + test/benchmark/jvp.py | 2 +- test/core/derivatives/__init__.py | 81 +++ test/core/derivatives/derivatives_test.py | 179 +++++- .../derivatives/implementation/autograd.py | 144 +++-- .../derivatives/implementation/backpack.py | 78 ++- test/core/derivatives/implementation/base.py | 156 ++++- test/core/derivatives/rnn_settings.py | 27 + test/core/derivatives/utils.py | 47 +- 13 files changed, 1471 insertions(+), 317 deletions(-) create mode 100644 backpack/core/derivatives/rnn.py create mode 100644 test/core/derivatives/rnn_settings.py diff --git a/backpack/core/derivatives/__init__.py b/backpack/core/derivatives/__init__.py index d9388ccec..059d55349 100644 --- a/backpack/core/derivatives/__init__.py +++ b/backpack/core/derivatives/__init__.py @@ -1,78 +1 @@ -from torch.nn import ( - ELU, - SELU, - AvgPool1d, - AvgPool2d, - AvgPool3d, - Conv1d, - Conv2d, - Conv3d, - ConvTranspose1d, - ConvTranspose2d, - ConvTranspose3d, - CrossEntropyLoss, - Dropout, - LeakyReLU, - Linear, - LogSigmoid, - MaxPool1d, - MaxPool2d, - MaxPool3d, - MSELoss, - ReLU, - Sigmoid, - Tanh, - ZeroPad2d, -) - -from .avgpool1d import AvgPool1DDerivatives -from .avgpool2d import AvgPool2DDerivatives -from .avgpool3d import AvgPool3DDerivatives -from .conv1d import Conv1DDerivatives -from .conv2d import Conv2DDerivatives -from .conv3d import Conv3DDerivatives -from .conv_transpose1d import ConvTranspose1DDerivatives -from .conv_transpose2d import ConvTranspose2DDerivatives -from .conv_transpose3d import ConvTranspose3DDerivatives -from .crossentropyloss import CrossEntropyLossDerivatives -from .dropout import DropoutDerivatives -from .elu import ELUDerivatives -from .leakyrelu import LeakyReLUDerivatives -from .linear import LinearDerivatives -from .logsigmoid import LogSigmoidDerivatives -from .maxpool1d import MaxPool1DDerivatives -from .maxpool2d import MaxPool2DDerivatives -from .maxpool3d import MaxPool3DDerivatives -from .mseloss import MSELossDerivatives -from .relu import ReLUDerivatives -from .selu import SELUDerivatives -from .sigmoid import SigmoidDerivatives -from .tanh import TanhDerivatives -from .zeropad2d import ZeroPad2dDerivatives - -derivatives_for = { - Linear: LinearDerivatives, - Conv1d: Conv1DDerivatives, - Conv2d: Conv2DDerivatives, - Conv3d: Conv3DDerivatives, - AvgPool1d: AvgPool1DDerivatives, - AvgPool2d: AvgPool2DDerivatives, - AvgPool3d: AvgPool3DDerivatives, - MaxPool1d: MaxPool1DDerivatives, - MaxPool2d: MaxPool2DDerivatives, - MaxPool3d: MaxPool3DDerivatives, - ZeroPad2d: ZeroPad2dDerivatives, - Dropout: DropoutDerivatives, - ReLU: ReLUDerivatives, - Tanh: TanhDerivatives, - Sigmoid: SigmoidDerivatives, - ConvTranspose1d: ConvTranspose1DDerivatives, - ConvTranspose2d: ConvTranspose2DDerivatives, - ConvTranspose3d: ConvTranspose3DDerivatives, - LeakyReLU: LeakyReLUDerivatives, - LogSigmoid: LogSigmoidDerivatives, - ELU: ELUDerivatives, - SELU: SELUDerivatives, - CrossEntropyLoss: CrossEntropyLossDerivatives, - MSELoss: MSELossDerivatives, -} +"""Contains derivatives of all supported modules.""" diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 1911d9adf..14f26a626 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -1,10 +1,15 @@ """Base classes for more flexible Jacobians and second-order information.""" import warnings +from abc import ABC +from typing import Callable, Tuple + +from torch import Tensor +from torch.nn import Module from backpack.core.derivatives import shape_check -class BaseDerivatives: +class BaseDerivatives(ABC): """First- and second-order partial derivatives of unparameterized module. Note: @@ -38,7 +43,9 @@ class BaseDerivatives: @shape_check.jac_mat_prod_accept_vectors @shape_check.jac_mat_prod_check_shapes - def jac_mat_prod(self, module, g_inp, g_out, mat): + def jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Apply Jacobian of the output w.r.t. input to a matrix. It is assumed that the module input has shape `[N, *]`, while the output is @@ -49,14 +56,14 @@ def jac_mat_prod(self, module, g_inp, g_out, mat): `result[v, n, •] = ∑ₖ ∑_* J[n, •, k, *] mat[v, n, *]`. Args: - module (torch.nn.Module): Extended module. - g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. - g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. - mat (torch.Tensor): Matrix the Jacobian will be applied to. Must have + module: Extended module. + g_inp: Gradients of the module w.r.t. its inputs. + g_out: Gradients of the module w.r.t. its outputs. + mat: Matrix the Jacobian will be applied to. Must have shape `[V, N, *]`. Returns: - torch.Tensor: Jacobian-matrix product. Has shape [V, N, *]. + Jacobian-matrix product. Has shape [V, N, *]. Note: - The Jacobian can be applied without knowledge about backpropagated @@ -65,40 +72,45 @@ def jac_mat_prod(self, module, g_inp, g_out, mat): """ return self._jac_mat_prod(module, g_inp, g_out, mat) - def _jac_mat_prod(self, module, g_inp, g_out, mat): - """Internal implementation of the input-output Jacobian.""" + def _jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: raise NotImplementedError @shape_check.jac_t_mat_prod_accept_vectors @shape_check.jac_t_mat_prod_check_shapes - def jac_t_mat_prod(self, module, g_inp, g_out, mat): + def jac_t_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Apply transposed input-ouput Jacobian of module output to a matrix. Implicit application of Jᵀ: result[v, ̃n, ̃c, ̃w, ...] = ∑_{n, c, w} Jᵀ[̃n, ̃c, ̃w, ..., n, c, w, ...] mat[v, n, c, w, ...]. - Parameters: - ----------- - mat: torch.Tensor - Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. + Args: + module: module which derivative is calculated + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. Returns: - -------- - result: torch.Tensor Transposed Jacobian-matrix product. Has shape [V, N, C_in, H_in, ...]. """ return self._jac_t_mat_prod(module, g_inp, g_out, mat) - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - """Internal implementation of transposed Jacobian.""" + def _jac_t_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: raise NotImplementedError # TODO Add shape check # TODO Use new convention - def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): + def ea_jac_t_mat_jac_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Expectation approximation of outer product with input-output Jacobian. Used for backpropagation in KFRA. @@ -109,76 +121,149 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): = 1/n ∑ₙₖₗ (𝜕output[n,k] / 𝜕input[n,i]) mat[k,l] (𝜕output[n,j] / 𝜕input[n,l]) Args: - module (torch.nn.Module): Extended module. - g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. - g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. - mat (torch.Tensor): Matrix of shape `[D_out, D_out]`. + module: Extended module. + g_inp: Gradients of the module w.r.t. its inputs. + g_out: Gradients of the module w.r.t. its outputs. + mat: Matrix of shape `[D_out, D_out]`. + # noqa: DAR202 Returns: - torch.Tensor: Matrix of shape `[D_in, D_in]`. + Matrix of shape `[D_in, D_in]`. Note: - This operation can be applied without knowledge about backpropagated derivatives. Both `g_inp` and `g_out` are usually not required and can be set to `None`. + + Raises: + NotImplementedError: if not overwritten """ raise NotImplementedError - def hessian_is_zero(self): + def hessian_is_zero(self) -> bool: + """Returns whether hessian is zero. + + # noqa: DAR202 + Returns: + whether hessian is zero + + Raises: + NotImplementedError: if not overwritten + """ raise NotImplementedError - def hessian_is_diagonal(self): - """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`.""" + def hessian_is_diagonal(self) -> bool: + """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`. + + # noqa: DAR202 + Returns: + whether hessian is diagonal + + Raises: + NotImplementedError: if not overwritten + """ raise NotImplementedError - def hessian_diagonal(self): + def hessian_diagonal(self) -> Tensor: """Return `∂²output[i] / ∂input[i]²`. Only required if `hessian_is_diagonal` returns `True`. + + # noqa: DAR202 + Returns: + hessian diagonal + + Raises: + NotImplementedError: if not overwritten """ raise NotImplementedError - def hessian_is_psd(self): - """Is `∂²output[i] / ∂input[j] ∂input[k]` positive semidefinite (PSD).""" + def hessian_is_psd(self) -> bool: + """Is `∂²output[i] / ∂input[j] ∂input[k]` positive semidefinite (PSD). + + # noqa: DAR202 + Returns: + whether hessian is positive semi definite + + Raises: + NotImplementedError: if not overwritten + """ raise NotImplementedError @shape_check.residual_mat_prod_accept_vectors @shape_check.residual_mat_prod_check_shapes - def residual_mat_prod(self, module, g_inp, g_out, mat): + def residual_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Multiply with the residual term. Performs mat → [∑_{k} Hz_k(x) 𝛿z_k] mat. + Args: + module: module + g_inp: input gradients + g_out: output gradients + mat: matrix to multiply + + Returns: + product + Note: - ----- This function only has to be implemented if the residual is not zero and not diagonal (for instance, `BatchNorm`). """ return self._residual_mat_prod(module, g_inp, g_out, mat) - def _residual_mat_prod(self, module, g_inp, g_out, mat): + def _residual_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: raise NotImplementedError @staticmethod - def _reshape_like(mat, like): + def _reshape_like(mat: Tensor, like: Tensor) -> Tensor: """Reshape as like with trailing and additional 0th dimension. If like is [N, C, H, ...], returns shape [-1, N, C, H, ...] + + Args: + mat: matrix to reshape + like: matrix with target shape + + Returns: + reshaped matrix """ V = -1 shape = (V, *like.shape) return mat.reshape(shape) @classmethod - def reshape_like_input(cls, mat, module): + def reshape_like_input(cls, mat: Tensor, module: Module) -> Tensor: + """Reshapes matrix according to input. + + Args: + mat: matrix to reshape + module: module which input shape is used + + Returns: + reshaped matrix + """ return cls._reshape_like(mat, module.input0) @classmethod - def reshape_like_output(cls, mat, module): + def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor: + """Reshapes matrix like output. + + Args: + mat: matrix to reshape + module: module which output is used + + Returns: + reshaped matrix + """ return cls._reshape_like(mat, module.output) -class BaseParameterDerivatives(BaseDerivatives): +class BaseParameterDerivatives(BaseDerivatives, ABC): """First- and second order partial derivatives of a module with parameters. Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`): @@ -195,92 +280,111 @@ class BaseParameterDerivatives(BaseDerivatives): @shape_check.bias_jac_mat_prod_accept_vectors @shape_check.bias_jac_mat_prod_check_shapes - def bias_jac_mat_prod(self, module, g_inp, g_out, mat): + def bias_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Apply Jacobian of the output w.r.t. bias to a matrix. - Parameters: - ----------- - mat: torch.Tensor - Matrix the Jacobian will be applied to. - Must have shape [V, C_b, ...]. + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the Jacobian will be applied to. + Must have shape [V, C_b, ...]. Returns: - -------- - result: torch.Tensor - Jacobian-matrix product. - Has shape [V, N, C_out, H_out, ...]. + Jacobian-matrix product. Has shape [V, N, C_out, H_out, ...]. """ return self._bias_jac_mat_prod(module, g_inp, g_out, mat) - def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): - """Internal implementation of the bias Jacobian.""" + def _bias_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: raise NotImplementedError @shape_check.bias_jac_t_mat_prod_accept_vectors @shape_check.bias_jac_t_mat_prod_check_shapes - def bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + def bias_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias to a matrix. - Parameters: - ----------- - mat: torch.Tensor - Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. - sum_batch: bool - Whether to sum over the batch dimension on the fly. + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. + sum_batch: Whether to sum over the batch dimension on the fly. Returns: - -------- - result: torch.Tensor Jacobian-matrix product. Has shape [V, N, C_b, ...] if `sum_batch == False`. Has shape [V, C_b, ...] if `sum_batch == True`. """ return self._bias_jac_t_mat_prod(module, g_inp, g_out, mat, sum_batch=sum_batch) - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Internal implementation of the transposed bias Jacobian.""" + def _bias_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: raise NotImplementedError @shape_check.weight_jac_mat_prod_accept_vectors @shape_check.weight_jac_mat_prod_check_shapes - def weight_jac_mat_prod(self, module, g_inp, g_out, mat): + def weight_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: """Apply Jacobian of the output w.r.t. weight to a matrix. - Parameters: - ----------- - mat: torch.Tensor - Matrix the Jacobian will be applied to. - Must have shape [V, C_w, H_w, ...]. + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the Jacobian will be applied to. + Must have shape [V, C_w, H_w, ...]. Returns: - -------- - result: torch.Tensor Jacobian-matrix product. Has shape [V, N, C_out, H_out, ...]. """ return self._weight_jac_mat_prod(module, g_inp, g_out, mat) - def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - """Internal implementation of weight Jacobian.""" + def _weight_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: raise NotImplementedError @shape_check.weight_jac_t_mat_prod_accept_vectors @shape_check.weight_jac_t_mat_prod_check_shapes - def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + def weight_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight to a matrix. - Parameters: - ----------- - mat: torch.Tensor - Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. - sum_batch: bool - Whether to sum over the batch dimension on the fly. + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. + sum_batch: Whether to sum over the batch dimension on the fly. Returns: - -------- - result: torch.Tensor Jacobian-matrix product. Has shape [V, N, C_w, H_w, ...] if `sum_batch == False`. Has shape [V, C_w, H_w, ...] if `sum_batch == True`. @@ -289,73 +393,315 @@ def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): module, g_inp, g_out, mat, sum_batch=sum_batch ) - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Internal implementation of transposed weight Jacobian.""" + def _weight_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + raise NotImplementedError + + @shape_check.bias_jac_t_mat_prod_accept_vectors + @shape_check.bias_rnn_jac_t_mat_prod_check_shapes + def bias_ih_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, T, N, H]. + sum_batch: Whether to sum over the batch dimension on the fly. + + Returns: + Jacobian-matrix product. + Has shape [V, T, N, H] if `sum_batch == False`. + Has shape [V, T, H] if `sum_batch == True`. + """ + return self._bias_ih_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _bias_ih_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + raise NotImplementedError + + @shape_check.bias_jac_t_mat_prod_accept_vectors + @shape_check.bias_rnn_jac_t_mat_prod_check_shapes + def bias_hh_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, T, N, H]. + sum_batch: Whether to sum over the batch dimension on the fly. + + Returns: + Jacobian-matrix product. + Has shape [V, T, N, H] if `sum_batch == False`. + Has shape [V, T, H] if `sum_batch == True`. + """ + return self._bias_hh_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _bias_hh_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + raise NotImplementedError + + @shape_check.weight_jac_t_mat_prod_accept_vectors + @shape_check.weight_ih_jac_t_mat_prod_check_shapes + def weight_ih_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, T, N, H]. + sum_batch: Whether to sum over the batch dimension on the fly. + + Returns: + Jacobian-matrix product. + Has shape [V, T, N, H, I] if `sum_batch == False`. + Has shape [V, T, H, I] if `sum_batch == True`. + """ + return self._weight_ih_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _weight_ih_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + raise NotImplementedError + + @shape_check.weight_jac_t_mat_prod_accept_vectors + @shape_check.weight_hh_jac_t_mat_prod_check_shapes + def weight_hh_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mat: Matrix the transposed Jacobian will be applied to. + Must have shape [V, T, N, H]. + sum_batch: Whether to sum over the batch dimension on the fly. + + Returns: + Jacobian-matrix product. + Has shape [V, T, N, H, I] if `sum_batch == False`. + Has shape [V, T, H, I] if `sum_batch == True`. + """ + return self._weight_hh_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _weight_hh_l0_jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: raise NotImplementedError -class BaseLossDerivatives(BaseDerivatives): +class BaseLossDerivatives(BaseDerivatives, ABC): """Second- order partial derivatives of loss functions.""" # TODO Add shape check - def sqrt_hessian(self, module, g_inp, g_out): - """Symmetric factorization ('sqrt') of the loss Hessian.""" - self.check_2nd_order_make_sense(module, g_inp, g_out) + def sqrt_hessian( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: + """Symmetric factorization ('sqrt') of the loss Hessian. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + + Returns: + square root of hessian + """ + self._check_2nd_order_make_sense(module, g_out) return self._sqrt_hessian(module, g_inp, g_out) - def _sqrt_hessian(self, module, g_inp, g_out): + def _sqrt_hessian( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: raise NotImplementedError # TODO Add shape check - def sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): - """Monte-Carlo sampled symmetric factorization of the loss Hessian.""" - self.check_2nd_order_make_sense(module, g_inp, g_out) + def sqrt_hessian_sampled( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mc_samples: int = 1, + ) -> Tensor: + """Monte-Carlo sampled symmetric factorization of the loss Hessian. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + mc_samples: number of monte carlo samples. Defaults to 1. + + Returns: + square root of hessian + """ + self._check_2nd_order_make_sense(module, g_out) return self._sqrt_hessian_sampled(module, g_inp, g_out, mc_samples=mc_samples) - def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): + def _sqrt_hessian_sampled( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mc_samples: int = 1, + ) -> Tensor: raise NotImplementedError @shape_check.make_hessian_mat_prod_accept_vectors @shape_check.make_hessian_mat_prod_check_shapes - def make_hessian_mat_prod(self, module, g_inp, g_out): + def make_hessian_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Callable[[Tensor], Tensor]: """Multiplication of the input Hessian with a matrix. Return a function that maps mat to H * mat. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + + Returns: + function that maps mat to H * mat """ - self.check_2nd_order_make_sense(module, g_inp, g_out) + self._check_2nd_order_make_sense(module, g_out) return self._make_hessian_mat_prod(module, g_inp, g_out) - def _make_hessian_mat_prod(self, module, g_inp, g_out): + def _make_hessian_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Callable[[Tensor], Tensor]: raise NotImplementedError # TODO Add shape check - def sum_hessian(self, module, g_inp, g_out): - """Loss Hessians, summed over the batch dimension.""" - self.check_2nd_order_make_sense(module, g_inp, g_out) + def sum_hessian( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: + """Loss Hessians, summed over the batch dimension. + + Args: + module: module to perform derivatives on + g_inp: input gradients + g_out: output gradients + + Returns: + sum of hessians + """ + self._check_2nd_order_make_sense(module, g_out) return self._sum_hessian(module, g_inp, g_out) - def _sum_hessian(self, module, g_inp, g_out): + def _sum_hessian( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: raise NotImplementedError - def check_2nd_order_make_sense(self, module, g_inp, g_out): + def _check_2nd_order_make_sense(self, module: Module, g_out: Tuple[Tensor]) -> None: """Verify conditions for 2nd-order extensions to be working. 2nd-order extensions are only guaranteed to work if the `loss`, on which `backward()` is called, is a scalar that has not been modified further after passing through the loss function module. + + Args: + module: module to perform derivatives on + g_out: output gradients """ self._check_output_is_scalar(module) self._check_loss_has_not_been_modified(module, g_out) - def _check_output_is_scalar(self, module): - """Raise an exception is the module output is not a scalar.""" + @classmethod + def _check_output_is_scalar(cls, module: Module) -> None: + """Raise an exception is the module output is not a scalar. + + Args: + module: module to perform derivatives on + + Raises: + ValueError: if output is not scalar + """ if module.output.numel() != 1: raise ValueError( "Output must be scalar. Got {}".format(module.output.shape) ) - def _check_loss_has_not_been_modified(self, module, g_out): - """Raise a warning if the module output seems to have been changed.""" + @classmethod + def _check_loss_has_not_been_modified( + cls, module: Module, g_out: Tuple[Tensor] + ) -> None: + """Raise a warning if the module output seems to have been changed. + + Args: + module: module to perform derivatives on + g_out: output gradients + """ grad_out_is_identity = g_out is None or (g_out[0] == 1.0).all().item() if not grad_out_is_identity: warnings.warn( diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py new file mode 100644 index 000000000..1f6b5bc88 --- /dev/null +++ b/backpack/core/derivatives/rnn.py @@ -0,0 +1,236 @@ +"""Partial derivatives for the torch.nn.RNN layer.""" +from typing import List, Tuple + +import torch +from torch import Tensor, cat, einsum, zeros +from torch.nn import RNN + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives + + +class RNNDerivatives(BaseParameterDerivatives): + """Partial derivatives for the torch.nn.RNN layer. + + a_t = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh + h_t = tanh(a_t) + + Index conventions: + ------------------ + * t: Sequence dimension + * v: Free dimension + * n: Batch dimension + * h: Output dimension + * i: Input dimension + """ + + @staticmethod + def _check_parameters(module: RNN) -> None: + """Check the parameters of module. + + Args: + module: module which to check + + Raises: + ValueError: If any parameter of module does not match expectation + """ + if module.num_layers > 1: + raise ValueError("only num_layers = 1 is supported") + if not module.nonlinearity == "tanh": + raise ValueError("only nonlinearity = tanh is supported") + if module.bias is not True: + raise ValueError("only bias = True is supported") + if module.batch_first is not False: + raise ValueError("only batch_first = False is supported") + if not module.dropout == 0: + raise ValueError("only dropout = 0 is supported") + if module.bidirectional is not False: + raise ValueError("only bidirectional = False is supported") + + @staticmethod + def _a_jac_t_mat_prod( + output: Tensor, + weight_hh_l0: Tensor, + mat: Tensor, + ) -> Tensor: + """Calculates jacobian vector product wrt a. + + Args: + output: the values of the hidden layer + weight_hh_l0: weight matrix hidden-to-hidden + mat: matrix to multiply + + Returns: + jacobian vector product wrt a + """ + V: int = mat.shape[0] + N: int = mat.shape[2] + T: int = mat.shape[1] + H: int = mat.shape[3] + a_jac_t_mat_prod: Tensor = zeros(V, T, N, H, device=mat.device) + for t in range(T)[::-1]: + if t == (T - 1): + a_jac_t_mat_prod[:, t, ...] = einsum( + "vnh,nh->vnh", + mat[:, t, ...], + 1 - output[t, ...] ** 2, + ) + else: + a_jac_t_mat_prod[:, t, ...] = einsum( + "vnh,nh->vnh", + mat[:, t, ...] + + einsum( + "vng,gh->vnh", + a_jac_t_mat_prod[:, t + 1, ...], + weight_hh_l0, + ), + 1 - output[t, ...] ** 2, + ) + return a_jac_t_mat_prod + + def _jac_t_mat_prod( + self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + return torch.einsum( + "vtnh,hk->vtnk", + self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + module.weight_ih_l0, + ) + + def _jac_mat_prod( + self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + V: int = mat.shape[0] + N: int = mat.shape[2] + T: int = mat.shape[1] + H: int = module.hidden_size + _jac_mat_prod: Tensor = torch.zeros(V, T, N, H, device=mat.device) + for t in range(T): + if t == 0: + _jac_mat_prod[:, t, ...] = einsum( + "nh,hi,vni->vnh", + 1 - module.output[t, ...] ** 2, + module.weight_ih_l0, + mat[:, t, ...], + ) + else: + _jac_mat_prod[:, t, ...] = einsum( + "nh,vnh->vnh", + 1 - module.output[t, ...] ** 2, + einsum("hi,vni->vnh", module.weight_ih_l0, mat[:, t, ...]) + + einsum( + "hk,vnk->vnh", + module.weight_hh_l0, + _jac_mat_prod[:, t - 1, ...], + ), + ) + return _jac_mat_prod + + def _bias_ih_l0_jac_t_mat_prod( + self, + module: RNN, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. bias_ih_l0. + + Args: + module: extended module + g_inp: input gradient + g_out: output gradient + mat: matrix to multiply + sum_batch: Whether to sum along batch axis. Defaults to True. + + Returns: + product + """ + self._check_parameters(module) + if sum_batch: + dim: List[int] = [1, 2] + else: + dim: int = 1 + return self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat).sum( + dim=dim + ) + + def _bias_hh_l0_jac_t_mat_prod( + self, + module: RNN, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. bias_hh_l0. + + Args: + module: extended module + g_inp: input gradient + g_out: output gradient + mat: matrix to multiply + sum_batch: Whether to sum along batch axis. Defaults to True. + + Returns: + product + """ + # identical to bias_ih_l0 + return self._bias_ih_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _weight_ih_l0_jac_t_mat_prod( + self, + module: RNN, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. weight_ih_l0. + + Args: + module: extended module + g_inp: input gradient + g_out: output gradient + mat: matrix to multiply + sum_batch: Whether to sum along batch axis. Defaults to True. + + Returns: + product + """ + self._check_parameters(module) + return einsum( + "vtnh,tnj->" + ("vhj" if sum_batch else "vnhj"), + self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + module.input0, + ) + + def _weight_hh_l0_jac_t_mat_prod( + self, + module: RNN, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + """Apply transposed Jacobian of the output w.r.t. weight_hh_l0. + + Args: + module: extended module + g_inp: input gradient + g_out: output gradient + mat: matrix to multiply + sum_batch: Whether to sum along batch axis. Defaults to True. + + Returns: + product + """ + self._check_parameters(module) + N: int = mat.shape[2] + H: int = mat.shape[3] + return einsum( + "vtnh,tnk->" + ("vhk" if sum_batch else "vnhk"), + self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + cat([zeros(1, N, H, device=mat.device), module.output[0:-1]], dim=0), + ) diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index 32b593a89..931c0ef0c 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -1,18 +1,22 @@ -""" -Helpers to support application of Jacobians to vectors +"""Helpers to support application of Jacobians to vectors. + Helpers to check input and output sizes of Jacobian-matrix products. """ import functools +from typing import Any, Callable + +from torch import Tensor +from torch.nn import Module ############################################################################### # Utility functions # ############################################################################### -def add_V_dim(mat): +def _add_V_dim(mat): return mat.unsqueeze(0) -def remove_V_dim(mat): +def _remove_V_dim(mat): if mat.shape[0] != 1: raise RuntimeError( "Cannot unsqueeze dimension 0. ", "Got tensor of shape {}".format(mat.shape) @@ -20,8 +24,17 @@ def remove_V_dim(mat): return mat.squeeze(0) -def check_shape(mat, like, diff=1): - """Compare dimension diff,diff+1, ... with dimension 0,1,...""" +def check_shape(mat: Tensor, like: Tensor, diff: int = 1) -> None: + """Compare dimension diff,diff+1, ... with dimension 0,1,... + + Args: + mat: matrix + like: comparison matrix + diff: difference in dimensions. Defaults to 1. + + Raises: + RuntimeError: if shape does not fit + """ mat_shape = [int(dim) for dim in mat.shape] like_shape = [int(dim) for dim in like.shape] @@ -39,81 +52,97 @@ def check_shape(mat, like, diff=1): ) -def check_same_V_dim(mat1, mat2): +def _check_same_V_dim(mat1, mat2): V1, V2 = mat1.shape[0], mat2.shape[0] if V1 != V2: raise RuntimeError("Number of vectors changed. Got {} and {}".format(V1, V2)) -def check_like(mat, module, name, diff=1, *args, **kwargs): +def _check_like(mat, module, name, diff=1, *args, **kwargs): return check_shape(mat, getattr(module, name), diff=diff) -def check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): +def _check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): diff = 1 if sum_batch else 2 return check_shape(mat, getattr(module, name), diff=diff) -def same_dim_as(mat, module, name, *args, **kwargs): +def _same_dim_as(mat, module, name, *args, **kwargs): return len(mat.shape) == len(getattr(module, name).shape) ############################################################################### # Decorators for handling vectors as matrix special case # ############################################################################### -def mat_prod_accept_vectors(mat_prod, vec_criterion): +def _mat_prod_accept_vectors( + mat_prod: Callable[..., Tensor], + vec_criterion: Callable[[Tensor, Module, Any, Any], bool], +) -> Callable[..., Tensor]: """Add support for vectors to matrix products. vec_criterion(mat, module) returns if mat is a vector. + + Args: + mat_prod: Function that processes multiple vectors in format of a matrix. + vec_criterion: Function that returns true if an input is a single vector + that must be formatted into a matrix first before processing. + + Returns: + Wrapped ``mat_prod`` function that processes multiple vectors in format of + a matrix, and supports vector-shaped inputs which are internally converted + to the correct format. + Preserves format of input: + If the input format is a vector, the output format is a vector. + If the input format is a matrix, the output format is a matrix. """ @functools.wraps(mat_prod) - def wrapped_mat_prod_accept_vectors( + def _wrapped_mat_prod_accept_vectors( self, module, g_inp, g_out, mat, *args, **kwargs ): is_vec = vec_criterion(mat, module, *args, **kwargs) - mat_in = mat if not is_vec else add_V_dim(mat) + mat_in = mat if not is_vec else _add_V_dim(mat) mat_out = mat_prod(self, module, g_inp, g_out, mat_in, *args, **kwargs) - mat_out = mat_out if not is_vec else remove_V_dim(mat_out) + mat_out = mat_out if not is_vec else _remove_V_dim(mat_out) return mat_out - return wrapped_mat_prod_accept_vectors + return _wrapped_mat_prod_accept_vectors # vec criteria -same_dim_as_output = functools.partial(same_dim_as, name="output") -same_dim_as_input = functools.partial(same_dim_as, name="input0") -same_dim_as_weight = functools.partial(same_dim_as, name="weight") -same_dim_as_bias = functools.partial(same_dim_as, name="bias") +same_dim_as_output = functools.partial(_same_dim_as, name="output") +same_dim_as_input = functools.partial(_same_dim_as, name="input0") +same_dim_as_weight = functools.partial(_same_dim_as, name="weight") +same_dim_as_bias = functools.partial(_same_dim_as, name="bias") # decorators for handling vectors jac_t_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_output, ) weight_jac_t_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_output, ) bias_jac_t_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_output, ) jac_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_input, ) weight_jac_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_weight, ) bias_jac_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_bias, ) @@ -121,15 +150,27 @@ def wrapped_mat_prod_accept_vectors( ############################################################################### # Decorators for checking inputs and outputs of mat_prod routines # ############################################################################### -def mat_prod_check_shapes(mat_prod, in_check, out_check): - """Check that input and output have correct shapes.""" +def mat_prod_check_shapes( + mat_prod: Callable, in_check: Callable, out_check: Callable +) -> Callable[..., Tensor]: + """Check that input and output have correct shapes. + + Args: + mat_prod: Function that applies a derivative operator to multiple vectors + handed in as a matrix. + in_check: Function that checks the input to mat_prod + out_check: Function that checks the output to mat_prod + + Returns: + Wrapped mat_prod function with input and output checks + """ @functools.wraps(mat_prod) def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs): in_check(mat, module, *args, **kwargs) mat_out = mat_prod(self, module, g_inp, g_out, mat, *args, **kwargs) out_check(mat_out, module, *args, **kwargs) - check_same_V_dim(mat_out, mat) + _check_same_V_dim(mat_out, mat) return mat_out @@ -137,15 +178,24 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar # input/output checker -shape_like_output = functools.partial(check_like, name="output") -shape_like_input = functools.partial(check_like, name="input0") -shape_like_weight = functools.partial(check_like, name="weight") -shape_like_bias = functools.partial(check_like, name="bias") +shape_like_output = functools.partial(_check_like, name="output") +shape_like_input = functools.partial(_check_like, name="input0") +shape_like_weight = functools.partial(_check_like, name="weight") +shape_like_bias = functools.partial(_check_like, name="bias") shape_like_weight_with_sum_batch = functools.partial( - check_like_with_sum_batch, name="weight" + _check_like_with_sum_batch, name="weight" ) shape_like_bias_with_sum_batch = functools.partial( - check_like_with_sum_batch, name="bias" + _check_like_with_sum_batch, name="bias" +) +shape_like_bias_rnn_with_sum_batch = functools.partial( + _check_like_with_sum_batch, name="bias_ih_l0" +) +shape_like_weight_ih_with_sum_batch = functools.partial( + _check_like_with_sum_batch, name="weight_ih_l0" +) +shape_like_weight_hh_with_sum_batch = functools.partial( + _check_like_with_sum_batch, name="weight_hh_l0" ) # decorators for shape checking @@ -176,6 +226,21 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar in_check=shape_like_output, out_check=shape_like_bias_with_sum_batch, ) +bias_rnn_jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, + in_check=shape_like_output, + out_check=shape_like_bias_rnn_with_sum_batch, +) +weight_ih_jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, + in_check=shape_like_output, + out_check=shape_like_weight_ih_with_sum_batch, +) +weight_hh_jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, + in_check=shape_like_output, + out_check=shape_like_weight_hh_with_sum_batch, +) ############################################################################### # Wrapper for second-order extensions # @@ -185,44 +250,69 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar ) residual_mat_prod_accept_vectors = functools.partial( - mat_prod_accept_vectors, + _mat_prod_accept_vectors, vec_criterion=same_dim_as_input, ) # TODO Refactor using partials -def make_hessian_mat_prod_accept_vectors(make_hessian_mat_prod): +def make_hessian_mat_prod_accept_vectors( + make_hessian_mat_prod: Callable, +) -> Callable[..., Callable[..., Tensor]]: + """Accept vectors for hessian_mat_prod. + + Args: + make_hessian_mat_prod: Function that creates multiplication routine + of a matrix with the module Hessian + + Returns: + Wrapped hessian_mat_prod which converts vector-format inputs to a matrix + before processing. Preserves format of input. + """ + @functools.wraps(make_hessian_mat_prod) - def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): + def _wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out) - def new_hessian_mat_prod(mat): - is_vec = same_dim_as(mat, module, "input0") - mat_in = mat if not is_vec else add_V_dim(mat) + def _new_hessian_mat_prod(mat): + is_vec = _same_dim_as(mat, module, "input0") + mat_in = mat if not is_vec else _add_V_dim(mat) mat_out = hessian_mat_prod(mat_in) - mat_out = mat_out if not is_vec else remove_V_dim(mat_out) + mat_out = mat_out if not is_vec else _remove_V_dim(mat_out) return mat_out - return new_hessian_mat_prod + return _new_hessian_mat_prod + + return _wrapped_make_hessian_mat_prod - return wrapped_make_hessian_mat_prod +def make_hessian_mat_prod_check_shapes( + make_hessian_mat_prod: Callable[..., Callable[..., Tensor]], +) -> Callable[..., Callable[..., Tensor]]: + """Wrap hessian_mat_prod with shape checks for input and output. + + Args: + make_hessian_mat_prod: function that creates multiplication routine of + a matrix with the module Hessian. + + Returns: + wrapped hessian_mat_prod with shape checks for input and output + """ -def make_hessian_mat_prod_check_shapes(make_hessian_mat_prod): @functools.wraps(make_hessian_mat_prod) - def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): + def _wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out) - def new_hessian_mat_prod(mat): - check_like(mat, module, "input0") + def _new_hessian_mat_prod(mat): + _check_like(mat, module, "input0") result = hessian_mat_prod(mat) - check_like(result, module, "input0") + _check_like(result, module, "input0") return result - return new_hessian_mat_prod + return _new_hessian_mat_prod - return wrapped_make_hessian_mat_prod + return _wrapped_make_hessian_mat_prod diff --git a/fully_documented.txt b/fully_documented.txt index 72e8c3d10..b5bd63967 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,3 +1,16 @@ +backpack/core/derivatives/basederivatives.py +backpack/core/derivatives/rnn.py +backpack/core/derivatives/shape_check.py +backpack/core/derivatives/__init__.py + +test/core/derivatives/derivatives_test.py +test/core/derivatives/__init__.py +test/core/derivatives/rnn_settings.py +test/core/derivatives/utils.py +test/core/derivatives/implementation/ + test/extensions/problem.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py + +# docs_src/examples/use_cases/example_custom_module diff --git a/test/benchmark/jvp.py b/test/benchmark/jvp.py index 7e5c89256..e7115d152 100644 --- a/test/benchmark/jvp.py +++ b/test/benchmark/jvp.py @@ -1,11 +1,11 @@ from functools import partial +from test.core.derivatives import derivatives_for import pytest import torch from torch import allclose from torch.nn import Dropout, ReLU, Sigmoid, Tanh -from backpack.core.derivatives import derivatives_for from backpack.hessianfree.lop import transposed_jacobian_vector_product from backpack.hessianfree.rop import jacobian_vector_product diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index e5a0962f9..3ba38952e 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -1 +1,82 @@ """Test functionality of `backpack.core.derivatives` module.""" +from torch.nn import ( + ELU, + RNN, + SELU, + AvgPool1d, + AvgPool2d, + AvgPool3d, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + CrossEntropyLoss, + Dropout, + LeakyReLU, + Linear, + LogSigmoid, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + +from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives +from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives +from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives +from backpack.core.derivatives.conv1d import Conv1DDerivatives +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.core.derivatives.conv3d import Conv3DDerivatives +from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives +from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives +from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives +from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives +from backpack.core.derivatives.dropout import DropoutDerivatives +from backpack.core.derivatives.elu import ELUDerivatives +from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives +from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives +from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives +from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives +from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.core.derivatives.relu import ReLUDerivatives +from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.core.derivatives.selu import SELUDerivatives +from backpack.core.derivatives.sigmoid import SigmoidDerivatives +from backpack.core.derivatives.tanh import TanhDerivatives +from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives + +derivatives_for = { + Linear: LinearDerivatives, + Conv1d: Conv1DDerivatives, + Conv2d: Conv2DDerivatives, + Conv3d: Conv3DDerivatives, + AvgPool1d: AvgPool1DDerivatives, + AvgPool2d: AvgPool2DDerivatives, + AvgPool3d: AvgPool3DDerivatives, + MaxPool1d: MaxPool1DDerivatives, + MaxPool2d: MaxPool2DDerivatives, + MaxPool3d: MaxPool3DDerivatives, + ZeroPad2d: ZeroPad2dDerivatives, + Dropout: DropoutDerivatives, + ReLU: ReLUDerivatives, + Tanh: TanhDerivatives, + Sigmoid: SigmoidDerivatives, + ConvTranspose1d: ConvTranspose1DDerivatives, + ConvTranspose2d: ConvTranspose2DDerivatives, + ConvTranspose3d: ConvTranspose3DDerivatives, + LeakyReLU: LeakyReLUDerivatives, + LogSigmoid: LogSigmoidDerivatives, + ELU: ELUDerivatives, + SELU: SELUDerivatives, + CrossEntropyLoss: CrossEntropyLossDerivatives, + MSELoss: MSELossDerivatives, + RNN: RNNDerivatives, +} diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 846e776ea..743b10b6a 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -16,6 +16,7 @@ from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS from test.core.derivatives.problem import make_test_problems +from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS import pytest @@ -46,8 +47,13 @@ problem.make_id() for problem in CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS ] +RNN_PROBLEMS = make_test_problems(RNN_SETTINGS) +RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS] -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) + +@pytest.mark.parametrize( + "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS +) def test_jac_mat_prod(problem, V=3): """Test the Jacobian-matrix product. @@ -65,7 +71,9 @@ def test_jac_mat_prod(problem, V=3): problem.tear_down() -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) +@pytest.mark.parametrize( + "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS +) def test_jac_t_mat_prod(problem, V=3): """Test the transposed Jacobian-matrix product. @@ -91,6 +99,110 @@ def test_jac_t_mat_prod(problem, V=3): IDS_WITH_WEIGHTS.append(problem_id) +@pytest.mark.parametrize( + "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] +) +@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +def test_bias_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): + """Test the transposed Jacobian-matrix product w.r.t. to bias_ih_l0. + + Args: + problem (DerivativesProblem): Problem for derivative test. + sum_batch (bool): Sum results over the batch dimension. + V (int): Number of vectorized transposed Jacobian-vector products. + """ + problem.set_up() + mat = torch.rand(V, *problem.output_shape).to(problem.device) + + autograd_res = AutogradDerivatives(problem).bias_ih_l0_jac_t_mat_prod( + mat, sum_batch + ) + backpack_res = BackpackDerivatives(problem).bias_ih_l0_jac_t_mat_prod( + mat, sum_batch + ) + + check_sizes_and_values(autograd_res, backpack_res) + problem.tear_down() + + +@pytest.mark.parametrize( + "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] +) +@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +def test_bias_hh_l0_jac_t_mat_prod(problem, sum_batch, V=3): + """Test the transposed Jacobian-matrix product w.r.t. to bias_hh_l0. + + Args: + problem (DerivativesProblem): Problem for derivative test. + sum_batch (bool): Sum results over the batch dimension. + V (int): Number of vectorized transposed Jacobian-vector products. + """ + problem.set_up() + mat = torch.rand(V, *problem.output_shape).to(problem.device) + + autograd_res = AutogradDerivatives(problem).bias_hh_l0_jac_t_mat_prod( + mat, sum_batch + ) + backpack_res = BackpackDerivatives(problem).bias_hh_l0_jac_t_mat_prod( + mat, sum_batch + ) + + check_sizes_and_values(autograd_res, backpack_res) + problem.tear_down() + + +@pytest.mark.parametrize( + "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] +) +@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +def test_weight_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): + """Test the transposed Jacobian-matrix product w.r.t. to weight_ih_l0. + + Args: + problem (DerivativesProblem): Problem for derivative test. + sum_batch (bool): Sum results over the batch dimension. + V (int): Number of vectorized transposed Jacobian-vector products. + """ + problem.set_up() + mat = torch.rand(V, *problem.output_shape).to(problem.device) + + autograd_res = AutogradDerivatives(problem).weight_ih_l0_jac_t_mat_prod( + mat, sum_batch + ) + backpack_res = BackpackDerivatives(problem).weight_ih_l0_jac_t_mat_prod( + mat, sum_batch + ) + + check_sizes_and_values(autograd_res, backpack_res) + problem.tear_down() + + +@pytest.mark.parametrize( + "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] +) +@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): + """Test the transposed Jacobian-matrix product w.r.t. to weight_hh_l0. + + Args: + problem (DerivativesProblem): Problem for derivative test. + sum_batch (bool): Sum results over the batch dimension. + V (int): Number of vectorized transposed Jacobian-vector products. + """ + problem.set_up() + mat = torch.rand(V, *problem.output_shape).to(problem.device) + + autograd_res = AutogradDerivatives(problem).weight_hh_l0_jac_t_mat_prod( + mat, sum_batch + ) + backpack_res = BackpackDerivatives(problem).weight_hh_l0_jac_t_mat_prod( + mat, sum_batch + ) + + check_sizes_and_values(autograd_res, backpack_res) + problem.tear_down() + + @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @@ -239,6 +351,13 @@ def test_sqrt_hessian_squared_equals_hessian(problem): ids=CONVOLUTION_TRANSPOSED_FAIL_IDS + CONVOLUTION_FAIL_IDS, ) def test_weight_jac_mat_prod_should_fail(problem): + """Tests weight_jac_mat_prod. + + Should fail. + + Args: + problem: test problem + """ with pytest.raises(NotImplementedError): test_weight_jac_mat_prod(problem) @@ -255,12 +374,26 @@ def test_weight_jac_mat_prod_should_fail(problem): "problem", CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS, ids=CONVOLUTION_TRANSPOSED_FAIL_IDS ) def test_weight_jac_t_mat_prod_should_fail(problem, sum_batch, save_memory): + """Test weight_jac_t_mat_prod. + + Should fail. + + Args: + problem: problem + sum_batch: whether to sum along batch axis + save_memory: whether to save memory + """ with pytest.raises(NotImplementedError): test_weight_jac_t_mat_prod(problem, sum_batch, save_memory) @pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_should_fail(problem): + """Test sqrt_hessian. Should fail. + + Args: + problem: test problem + """ with pytest.raises(ValueError): test_sqrt_hessian_squared_equals_hessian(problem) @@ -271,6 +404,7 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=1 Args: problem (DerivativesProblem): Problem for derivative test. + mc_samples: number of samples. Defaults to 100000. Compares the Hessian to reconstruction from individual Hessian MC-sampled sqrt. """ @@ -288,6 +422,11 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=1 @pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_sampled_should_fail(problem): + """Test sqrt_hessian. Should fail. + + Args: + problem: test problem + """ with pytest.raises(ValueError): test_sqrt_hessian_sampled_squared_approximates_hessian(problem) @@ -310,13 +449,18 @@ def test_sum_hessian(problem): @pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sum_hessian_should_fail(problem): + """Test sum_hessian, should fail. + + Args: + problem: test problem + """ with pytest.raises(ValueError): test_sum_hessian(problem) @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_ea_jac_t_mat_jac_prod(problem): - """Test KFRA backpropagation + """Test KFRA backpropagation. H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ @@ -343,7 +487,11 @@ def test_ea_jac_t_mat_jac_prod(problem): @pytest.mark.skip("[WAITING] Autograd issue with Hessian-vector products") @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_hessian_is_zero(problem): - """Check if the input-output Hessian is (non-)zero.""" + """Check if the input-output Hessian is (non-)zero. + + Args: + problem: test problem + """ problem.set_up() backpack_res = BackpackDerivatives(problem).hessian_is_zero() @@ -356,17 +504,14 @@ def test_hessian_is_zero(problem): @pytest.mark.skip @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_hessian_is_diagonal(problem): - problem.set_up() - - # TODO - raise NotImplementedError - - problem.tear_down() + """Test whether hessian is diagonal. + Args: + problem: test problem -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_diagonal(problem): + Raises: + NotImplementedError: . + """ problem.set_up() # TODO @@ -378,6 +523,14 @@ def test_hessian_diagonal(problem): @pytest.mark.skip @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_hessian_is_psd(problem): + """Test whether hessian is semi positive definite. + + Args: + problem: test problem + + Raises: + NotImplementedError: . + """ problem.set_up() # TODO diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index b2ecbca6c..061db5cf4 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -1,6 +1,8 @@ +"""Derivatives computed with PyTorch's autograd.""" from test.core.derivatives.implementation.base import DerivativesImplementation import torch +from torch import Tensor from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -10,23 +12,38 @@ class AutogradDerivatives(DerivativesImplementation): """Derivative implementations with autograd.""" - def jac_vec_prod(self, vec): + def jac_vec_prod(self, vec) -> Tensor: + """Product of input-output-Jacobian and a vector. + + Args: + vec: vector + + Returns: + product + """ input, output, _ = self.problem.forward_pass(input_requires_grad=True) return jacobian_vector_product(output, input, vec)[0] - def jac_mat_prod(self, mat): + def jac_mat_prod(self, mat): # noqa: D102 V = mat.shape[0] vecs = [mat[v] for v in range(V)] - jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs] + try: + jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs] + except RuntimeError: + # A RuntimeError is thrown for RNNs on CUDA, + # because PyTorch does not support double-backwards pass for them. + # This is the recommended workaround. + with torch.backends.cudnn.flags(enabled=False): + jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs] return torch.stack(jac_vec_prods) - def jac_t_vec_prod(self, vec): + def jac_t_vec_prod(self, vec): # noqa: D102 input, output, _ = self.problem.forward_pass(input_requires_grad=True) return transposed_jacobian_vector_product(output, input, vec)[0] - def jac_t_mat_prod(self, mat): + def jac_t_mat_prod(self, mat): # noqa: D102 V = mat.shape[0] vecs = [mat[v] for v in range(V)] @@ -34,15 +51,24 @@ def jac_t_mat_prod(self, mat): return torch.stack(jac_t_vec_prods) - def weight_jac_t_mat_prod(self, mat, sum_batch): + def weight_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("weight", mat, sum_batch) - def bias_jac_t_mat_prod(self, mat, sum_batch): + def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("bias", mat, sum_batch) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + + def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + return self.param_jac_t_mat_prod("weight_ih_l0", mat, sum_batch, axis_batch=1) + + def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + return self.param_jac_t_mat_prod("weight_hh_l0", mat, sum_batch, axis_batch=1) + def param_jac_t_vec_prod(self, name, vec, sum_batch, axis_batch=0): """Compute the product of jac_t and the given vector. @@ -94,31 +120,47 @@ def param_jac_t_mat_prod(self, name, mat, sum_batch, axis_batch=0): return torch.stack(jac_t_vec_prods) - def weight_jac_mat_prod(self, mat): - return self.param_jac_mat_prod("weight", mat) + def weight_jac_mat_prod(self, mat) -> Tensor: + """Product of jacobian and matrix. - def bias_jac_mat_prod(self, mat): - return self.param_jac_mat_prod("bias", mat) + Args: + mat: matrix + + Returns: + product + """ + return self._param_jac_mat_prod("weight", mat) + + def bias_jac_mat_prod(self, mat) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + + Returns: + product + """ + return self._param_jac_mat_prod("bias", mat) - def param_jac_vec_prod(self, name, vec): + def _param_jac_vec_prod(self, name, vec): input, output, named_params = self.problem.forward_pass() param = named_params[name] return jacobian_vector_product(output, param, vec)[0] - def param_jac_mat_prod(self, name, mat): + def _param_jac_mat_prod(self, name, mat): V = mat.shape[0] vecs = [mat[v] for v in range(V)] - jac_vec_prods = [self.param_jac_vec_prod(name, vec) for vec in vecs] + jac_vec_prods = [self._param_jac_vec_prod(name, vec) for vec in vecs] return torch.stack(jac_vec_prods) - def ea_jac_t_mat_jac_prod(self, mat): - def sample_jac_t_mat_jac_prod(sample_idx, mat): + def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 + def _sample_jac_t_mat_jac_prod(sample_idx, mat): assert len(mat.shape) == 2 - def sample_jac_t_mat_prod(sample_idx, mat): + def _sample_jac_t_mat_prod(sample_idx, mat): sample, output, _ = self.problem.forward_pass( input_requires_grad=True, sample_idx=sample_idx ) @@ -133,9 +175,9 @@ def sample_jac_t_mat_prod(sample_idx, mat): return result - jac_t_mat = sample_jac_t_mat_prod(sample_idx, mat) + jac_t_mat = _sample_jac_t_mat_prod(sample_idx, mat) mat_t_jac = jac_t_mat.t() - jac_t_mat_t_jac = sample_jac_t_mat_prod(sample_idx, mat_t_jac) + jac_t_mat_t_jac = _sample_jac_t_mat_prod(sample_idx, mat_t_jac) jac_t_mat_jac = jac_t_mat_t_jac.t() return jac_t_mat_jac @@ -146,21 +188,23 @@ def sample_jac_t_mat_prod(sample_idx, mat): result = torch.zeros(input_features, input_features).to(self.problem.device) for n in range(N): - result += sample_jac_t_mat_jac_prod(n, mat) + result += _sample_jac_t_mat_jac_prod(n, mat) return result / N - def hessian(self, loss, x): + def _hessian(self, loss: Tensor, x: Tensor) -> Tensor: """Return the Hessian matrix of a scalar `loss` w.r.t. a tensor `x`. - Arguments: - loss (torch.Tensor): A scalar-valued tensor. - x (torch.Tensor): Tensor used in the computation graph of `loss`. + Args: + loss: A scalar-valued tensor. + x: Tensor used in the computation graph of `loss`. + Shapes: loss: `[1,]` x: `[A, B, C, ...]` + Returns: - torch.Tensor: Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape + Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape `[A, B, C, ..., A, B, C, ...]`. """ assert loss.numel() == 1 @@ -182,15 +226,22 @@ def hessian(self, loss, x): return hessian_vec_x.reshape(final_shape) - def elementwise_hessian(self, tensor, x): - """Yield the Hessian of each element in `tensor` w.r.t `x`. + def _elementwise_hessian(self, tensor, x: Tensor): + """Computes the Hessian of each element in `tensor` w.r.t `x`. Hessians are returned in the order of elements in the flattened tensor. + + Args: + tensor: . + x: Tensor used in the computation graph of `loss`. + + Yields: + hessian of each element """ for t in tensor.flatten(): - yield self.hessian(t, x) + yield self._hessian(t, x) - def tensor_hessian(self, tensor, x): + def _tensor_hessian(self, tensor, x): """Return the Hessian of a tensor `tensor` w.r.t. a tensor `x`. Given a `tensor` of shape `[A, B, C]` and another tensor `x` with shape `[D, E]` @@ -208,7 +259,7 @@ def tensor_hessian(self, tensor, x): """ shape = (*tensor.shape, *x.shape, *x.shape) - return torch.cat(list(self.elementwise_hessian(tensor, x))).reshape(shape) + return torch.cat(list(self._elementwise_hessian(tensor, x))).reshape(shape) def hessian_is_zero(self): """Return whether the input-output Hessian is zero. @@ -219,7 +270,7 @@ def hessian_is_zero(self): input, output, _ = self.problem.forward_pass(input_requires_grad=True) zero = None - for hessian in self.elementwise_hessian(output, input): + for hessian in self._elementwise_hessian(output, input): if zero is None: zero = torch.zeros_like(hessian) @@ -228,21 +279,38 @@ def hessian_is_zero(self): return True - def input_hessian(self): - """Compute the Hessian of the module output w.r.t. the input.""" + def input_hessian(self) -> Tensor: + """Compute the Hessian of the module output w.r.t. the input. + + Returns: + hessian + """ input, output, _ = self.problem.forward_pass(input_requires_grad=True) - return self.hessian(output, input) + return self._hessian(output, input) + + def sum_hessian(self) -> Tensor: + """Compute the Hessian of a loss module w.r.t. its input. - def sum_hessian(self): - """Compute the Hessian of a loss module w.r.t. its input.""" + Returns: + hessian + """ hessian = self.input_hessian() return self._sum_hessian_blocks(hessian) - def _sum_hessian_blocks(self, hessian): + def _sum_hessian_blocks(self, hessian: Tensor) -> Tensor: """Sum second derivatives over the batch dimension. Assert second derivative w.r.t. different samples is zero. + + Args: + hessian: . + + Returns: + sum of hessians + + Raises: + ValueError: if input is not 2d """ input = self.problem.input num_axes = len(input.shape) diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index c36d94d9f..e3de0a402 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -1,66 +1,110 @@ +"""Contains derivative calculation with BackPACK.""" from test.core.derivatives.implementation.base import DerivativesImplementation import torch +from torch import Tensor class BackpackDerivatives(DerivativesImplementation): """Derivative implementations with BackPACK.""" def __init__(self, problem): + """Initialization. + + Args: + problem: test problem + """ problem.extend() super().__init__(problem) def store_forward_io(self): + """Do one forward pass. + + This implicitly saves relevant quantities for backward pass. + """ self.problem.forward_pass() - def jac_mat_prod(self, mat): + def jac_mat_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.jac_mat_prod( self.problem.module, None, None, mat ) - def jac_t_mat_prod(self, mat): + def jac_t_mat_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.jac_t_mat_prod( self.problem.module, None, None, mat ) - def weight_jac_t_mat_prod(self, mat, sum_batch): + def weight_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_jac_t_mat_prod( self.problem.module, None, None, mat, sum_batch=sum_batch ) - def bias_jac_t_mat_prod(self, mat, sum_batch): + def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_jac_t_mat_prod( self.problem.module, None, None, mat, sum_batch=sum_batch ) - def weight_jac_mat_prod(self, mat): + def weight_jac_mat_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_jac_mat_prod( self.problem.module, None, None, mat ) - def bias_jac_mat_prod(self, mat): + def bias_jac_mat_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_jac_mat_prod( self.problem.module, None, None, mat ) - def ea_jac_t_mat_jac_prod(self, mat): + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + self.store_forward_io() + return self.problem.derivative.bias_ih_l0_jac_t_mat_prod( + self.problem.module, None, None, mat, sum_batch=sum_batch + ) + + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + self.store_forward_io() + return self.problem.derivative.bias_hh_l0_jac_t_mat_prod( + self.problem.module, None, None, mat, sum_batch=sum_batch + ) + + def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + self.store_forward_io() + return self.problem.derivative.weight_ih_l0_jac_t_mat_prod( + self.problem.module, None, None, mat, sum_batch=sum_batch + ) + + def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + self.store_forward_io() + return self.problem.derivative.weight_hh_l0_jac_t_mat_prod( + self.problem.module, None, None, mat, sum_batch=sum_batch + ) + + def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.ea_jac_t_mat_jac_prod( self.problem.module, None, None, mat ) - def sum_hessian(self): + def sum_hessian(self): # noqa: D102 self.store_forward_io() return self.problem.derivative.sum_hessian(self.problem.module, None, None) - def input_hessian_via_sqrt_hessian(self, mc_samples=None): - # MC_SAMPLES = 100000 + def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: + """Computes the input hessian. + + Args: + mc_samples: If int, uses an MC approximation with the specified + number of samples. + If None, uses the exact hessian. Defaults to None. + + Returns: + hessian + """ self.store_forward_io() if mc_samples is not None: @@ -86,8 +130,18 @@ def hessian_is_zero(self): """ return self.problem.derivative.hessian_is_zero() - def _sample_hessians_from_sqrt(self, sqrt): - """Convert individual matrix square root into individual full matrix.""" + def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor: + """Convert individual matrix square root into individual full matrix. + + Args: + sqrt: individual square root of hessian + + Returns: + individual full matrix + + Raises: + ValueError: if input is not 2d + """ equation = None num_axes = len(sqrt.shape) diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 9edaa8194..0c2e5580f 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -1,23 +1,165 @@ -class DerivativesImplementation: +"""Contains DerivativesImplementation, the base class for autograd and backpack.""" +from abc import ABC, abstractmethod + +from torch import Tensor + + +class DerivativesImplementation(ABC): """Base class for autograd and BackPACK implementations.""" def __init__(self, problem): + """Initialization. + + Args: + problem: test problem + """ self.problem = problem - def jac_mat_prod(self, mat): + @abstractmethod + def jac_mat_prod(self, mat: Tensor) -> Tensor: + """Vectorized product of input-output-Jacobian and a matrix. + + Args: + mat: matrix: the vectors along its leading dimension will be multiplied. + + Returns: + Tensor representing the result of Jacobian-vector product. + product[v] = J @ mat[v] + """ + raise NotImplementedError + + @abstractmethod + def jac_t_mat_prod(self, mat: Tensor) -> Tensor: + """Vectorized product of transposed jacobian and matrix. + + Args: + mat: matrix: the vectors along its leading dimension will be multiplied. + + Returns: + Tensor representing the result of Jacobian-vector product. + product[v] = mat[v] @ J + """ raise NotImplementedError - def jac_t_mat_prod(self, mat): + @abstractmethod + def weight_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ raise NotImplementedError - def weight_jac_t_mat_prod(self, mat, sum_batch): + @abstractmethod + def bias_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ raise NotImplementedError - def bias_jac_t_mat_prod(self, mat, sum_batch): + @abstractmethod + def weight_jac_mat_prod(self, mat: Tensor) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + + Returns: + product + """ raise NotImplementedError - def weight_jac_mat_prod(self, mat): + @abstractmethod + def bias_jac_mat_prod(self, mat: Tensor) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + + Returns: + product + """ raise NotImplementedError - def bias_jac_mat_prod(self, mat): + @abstractmethod + def bias_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ + raise NotImplementedError + + @abstractmethod + def bias_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ + raise NotImplementedError + + @abstractmethod + def weight_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ + raise NotImplementedError + + @abstractmethod + def weight_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + """Product of jacobian and matrix. + + Args: + mat: matrix + sum_batch: whether to sum along batch axis + + Returns: + product + """ + raise NotImplementedError + + @abstractmethod + def ea_jac_t_mat_jac_prod(self, mat: Tensor) -> Tensor: + """Product of ea jacobian with matrix. + + Args: + mat: matrix + + Returns: + product + """ + raise NotImplementedError + + @abstractmethod + def sum_hessian(self) -> Tensor: + """Sum of hessians. + + Returns: + the sum of hessians + """ raise NotImplementedError diff --git a/test/core/derivatives/rnn_settings.py b/test/core/derivatives/rnn_settings.py new file mode 100644 index 000000000..8b4bd88ec --- /dev/null +++ b/test/core/derivatives/rnn_settings.py @@ -0,0 +1,27 @@ +"""Test configurations for `backpack.core.derivatives` RNN layers. + +Required entries: + "module_fn" (callable): Contains a model constructed from `torch.nn` layers + "input_fn" (callable): Used for specifying input function + +Optional entries: + "target_fn" (callable): Fetches the groundtruth/target classes + of regression/classification task + "loss_function_fn" (callable): Loss function used in the model + "device" [list(torch.device)]: List of devices to run the test on. + "id_prefix" (str): Prefix to be included in the test name. + "seed" (int): seed for the random number for torch.rand +""" + +import torch + +RNN_SETTINGS = [ + { + "module_fn": lambda: torch.nn.RNN(input_size=4, hidden_size=2), + "input_fn": lambda: torch.rand(size=(5, 7, 4)), + }, + { + "module_fn": lambda: torch.nn.RNN(input_size=4, hidden_size=2), + "input_fn": lambda: torch.rand(size=(1, 1, 4)), + }, +] diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py index cfdca8433..c5e14b71d 100644 --- a/test/core/derivatives/utils.py +++ b/test/core/derivatives/utils.py @@ -1,8 +1,12 @@ -"""Utility functions to test `backpack.core.derivatives`""" +"""Utility functions to test `backpack.core.derivatives`.""" +from test.core.derivatives import derivatives_for +from typing import Tuple, Type import torch +from torch import Tensor +from torch.nn import Module -from backpack.core.derivatives import derivatives_for +from backpack.core.derivatives.basederivatives import BaseDerivatives def get_available_devices(): @@ -19,15 +23,17 @@ def get_available_devices(): return devices -def derivative_cls_for(module_cls): +def derivative_cls_for(module_cls: Type[Module]) -> Type[BaseDerivatives]: """Return the associated derivative class for a module. Args: - module_cls (torch.nn.Module): Layer class. + module_cls: Layer class. Returns: - backpack.core.derivatives.Derivatives: Class implementing the - derivatives for `module_cls`. + Class implementing the derivatives for `module_cls`. + + Raises: + KeyError: if derivative for module is missing """ try: return derivatives_for[module_cls] @@ -38,23 +44,38 @@ def derivative_cls_for(module_cls): ) -def is_loss(module): +def is_loss(module: Module) -> bool: """Return whether `module` is a `torch` loss function. Args: - module (torch.nn.Module): A PyTorch module. + module: A PyTorch module. Returns: - bool: Whether `module` is a loss function. + Whether `module` is a loss function. """ return isinstance(module, torch.nn.modules.loss._Loss) -def classification_targets(size, num_classes): - """Create random targets for classes 0, ..., `num_classes - 1`.""" +def classification_targets(size: Tuple[int, ...], num_classes: int) -> Tensor: + """Create random targets for classes 0, ..., `num_classes - 1`. + + Args: + size: shape of targets + num_classes: number of classes + + Returns: + classification targets + """ return torch.randint(size=size, low=0, high=num_classes) -def regression_targets(size): - """Create random targets for regression.""" +def regression_targets(size: Tuple[int, ...]) -> Tensor: + """Create random targets for regression. + + Args: + size: shape of targets + + Returns: + regression targets + """ return torch.rand(size=size) From a01d47a0efae397fed97749015687f0d85e2d760 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 9 Jun 2021 11:46:12 +0200 Subject: [PATCH 08/54] [extensions] RNN first-order extensions and custom modules (#158) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Auxiliary: - Bumps python dependency to 3.7 (type annotations, cyclic imports) - Ignore `TYPE_CHECKING` blocks and `NotImplementedError` in coverage - Custom modules `Permute` and `ReduceTuple` --- * RNN first order extensions * align interface of Permute with permute * reformat * create .coveragerc, exclude TYPE_CHECKING from coverage * change ValueError to NotImplementedError, recognized by coveralls Co-authored-by: Tim Schäfer --- .coveragerc | 12 +++ backpack/core/derivatives/rnn.py | 14 +-- backpack/custom_module/__init__.py | 4 + backpack/custom_module/permute.py | 29 ++++++ backpack/custom_module/reduce_tuple.py | 29 ++++++ .../firstorder/batch_grad/__init__.py | 12 ++- .../extensions/firstorder/batch_grad/rnn.py | 14 +++ .../firstorder/batch_l2_grad/__init__.py | 14 ++- .../firstorder/batch_l2_grad/rnn.py | 70 +++++++++++++ .../firstorder/gradient/__init__.py | 4 + .../extensions/firstorder/gradient/rnn.py | 15 +++ .../firstorder/sum_grad_squared/__init__.py | 11 +++ .../firstorder/sum_grad_squared/rnn.py | 14 +++ .../firstorder/sum_grad_squared/sgs_base.py | 83 +++++++++++++--- .../firstorder/variance/__init__.py | 11 +++ .../extensions/firstorder/variance/rnn.py | 20 ++++ .../firstorder/variance/variance_base.py | 99 +++++++++++++++---- fully_documented.txt | 23 ++++- setup.py | 3 +- .../batch_grad/batchgrad_settings.py | 9 +- .../firstorder/firstorder_settings.py | 39 +++++++- test/extensions/implementation/autograd.py | 2 +- test/extensions/problem.py | 14 +-- 23 files changed, 474 insertions(+), 71 deletions(-) create mode 100644 .coveragerc create mode 100644 backpack/custom_module/__init__.py create mode 100644 backpack/custom_module/permute.py create mode 100644 backpack/custom_module/reduce_tuple.py create mode 100644 backpack/extensions/firstorder/batch_grad/rnn.py create mode 100644 backpack/extensions/firstorder/batch_l2_grad/rnn.py create mode 100644 backpack/extensions/firstorder/gradient/rnn.py create mode 100644 backpack/extensions/firstorder/sum_grad_squared/rnn.py create mode 100644 backpack/extensions/firstorder/variance/rnn.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..28f4ee9fe --- /dev/null +++ b/.coveragerc @@ -0,0 +1,12 @@ +# https://coverage.readthedocs.io/en/v4.5.x/config.html#config +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain if tests don't hit defensive assertion code: + raise NotImplementedError + + # TYPE_CHECKING block is never executed during pytest run + if TYPE_CHECKING: diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 1f6b5bc88..98e6d4529 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -31,20 +31,20 @@ def _check_parameters(module: RNN) -> None: module: module which to check Raises: - ValueError: If any parameter of module does not match expectation + NotImplementedError: If any parameter of module does not match expectation """ if module.num_layers > 1: - raise ValueError("only num_layers = 1 is supported") + raise NotImplementedError("only num_layers = 1 is supported") if not module.nonlinearity == "tanh": - raise ValueError("only nonlinearity = tanh is supported") + raise NotImplementedError("only nonlinearity = tanh is supported") if module.bias is not True: - raise ValueError("only bias = True is supported") + raise NotImplementedError("only bias = True is supported") if module.batch_first is not False: - raise ValueError("only batch_first = False is supported") + raise NotImplementedError("only batch_first = False is supported") if not module.dropout == 0: - raise ValueError("only dropout = 0 is supported") + raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: - raise ValueError("only bidirectional = False is supported") + raise NotImplementedError("only bidirectional = False is supported") @staticmethod def _a_jac_t_mat_prod( diff --git a/backpack/custom_module/__init__.py b/backpack/custom_module/__init__.py new file mode 100644 index 000000000..1b3c52f2b --- /dev/null +++ b/backpack/custom_module/__init__.py @@ -0,0 +1,4 @@ +"""This package adds torch.nn.Module type modules. + +These are used as utilities. +""" diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py new file mode 100644 index 000000000..3e71dde80 --- /dev/null +++ b/backpack/custom_module/permute.py @@ -0,0 +1,29 @@ +"""Module containing Permute module.""" +from typing import Any + +from torch import Tensor +from torch.nn import Module + + +class Permute(Module): + """Module to permute a tensor.""" + + def __init__(self, *dims: Any): + """Initialization. + + Args: + dims: The desired ordering of dimensions. + """ + super(Permute, self).__init__() + self.dims = dims + + def forward(self, input: Tensor) -> Tensor: + """Permutes the input tensor. + + Args: + input: input tensor + + Returns: + view with new ordering + """ + return input.permute(self.dims) diff --git a/backpack/custom_module/reduce_tuple.py b/backpack/custom_module/reduce_tuple.py new file mode 100644 index 000000000..5b95179bb --- /dev/null +++ b/backpack/custom_module/reduce_tuple.py @@ -0,0 +1,29 @@ +"""Module containing ReduceTuple module.""" +from typing import Union + +from torch import Tensor +from torch.nn import Module + + +class ReduceTuple(Module): + """Module reducing tuple input.""" + + def __init__(self, index: int = 0): + """Initialization. + + Args: + index: which element to choose + """ + super(ReduceTuple, self).__init__() + self.index = index + + def forward(self, input: tuple) -> Union[tuple, Tensor]: + """Reduces the tuple. + + Args: + input: the tuple of data + + Returns: + the selected element + """ + return input[self.index] diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 98b350d9f..9417d8449 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -1,4 +1,9 @@ +"""Contains the backpropagation extension for grad_batch: BatchGrad. + +It defines the module extension for each module. +""" from torch.nn import ( + RNN, BatchNorm1d, Conv1d, Conv2d, @@ -20,6 +25,7 @@ conv_transpose2d, conv_transpose3d, linear, + rnn, ) @@ -39,10 +45,13 @@ class BatchGrad(BackpropExtension): The concept of individual gradients is only meaningful if the objective is a sum of independent functions (no batchnorm). - """ def __init__(self): + """Initialization. + + Defines extension for each module. + """ super().__init__( savefield="grad_batch", fail_mode="WARNING", @@ -55,5 +64,6 @@ def __init__(self): ConvTranspose2d: conv_transpose2d.BatchGradConvTranspose2d(), ConvTranspose3d: conv_transpose3d.BatchGradConvTranspose3d(), BatchNorm1d: batchnorm1d.BatchGradBatchNorm1d(), + RNN: rnn.BatchGradRNN(), }, ) diff --git a/backpack/extensions/firstorder/batch_grad/rnn.py b/backpack/extensions/firstorder/batch_grad/rnn.py new file mode 100644 index 000000000..50afdf516 --- /dev/null +++ b/backpack/extensions/firstorder/batch_grad/rnn.py @@ -0,0 +1,14 @@ +"""Contains BatchGradRNN.""" +from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase + + +class BatchGradRNN(BatchGradBase): + """Extension for RNN calculating grad_batch.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=RNNDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index e4f72a159..2639285e5 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -1,4 +1,10 @@ +"""Contains BatchL2Grad. + +Defines the backpropagation extension. +Within it, define the extension for each module. +""" from torch.nn import ( + RNN, Conv1d, Conv2d, Conv3d, @@ -18,6 +24,7 @@ convtranspose2d, convtranspose3d, linear, + rnn, ) @@ -37,7 +44,11 @@ class BatchL2Grad(BackpropExtension): """ def __init__(self): - super().__init__( + """Initialization. + + Define the extensions for each module. + """ + super(BatchL2Grad, self).__init__( savefield="batch_l2", fail_mode="WARNING", module_exts={ @@ -48,5 +59,6 @@ def __init__(self): ConvTranspose1d: convtranspose1d.BatchL2ConvTranspose1d(), ConvTranspose2d: convtranspose2d.BatchL2ConvTranspose2d(), ConvTranspose3d: convtranspose3d.BatchL2ConvTranspose3d(), + RNN: rnn.BatchL2RNN(), }, ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py new file mode 100644 index 000000000..4540eef9a --- /dev/null +++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py @@ -0,0 +1,70 @@ +"""Contains BatchL2RNN.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.extensions.firstorder.base import FirstOrderModuleExtension + +if TYPE_CHECKING: + from backpack.extensions import BatchL2Grad + + +class BatchL2RNN(FirstOrderModuleExtension): + """Extension for RNN, calculating batch_l2.""" + + def __init__(self): + """Initialization.""" + params = ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"] + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + super(BatchL2RNN, self).__init__(params=params) + self.derivatives: RNNDerivatives = RNNDerivatives() + + def _make_param_function( + self, param_str: str + ) -> Callable[[BatchL2Grad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: + """Creates a function that calculates batch_l2. + + Args: + param_str: name of parameter + + Returns: + function that calculates batch_l2 + """ + + def param_function( + ext: BatchL2Grad, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: None, + ) -> Tensor: + """Calculates batch_l2 with the help of derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + bpQuantities: additional quantities for second order + + Returns: + batch_l2 + """ + return ( + ( + getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=False + ) + ** 2 + ) + .flatten(start_dim=1) + .sum(1) + ) + + return param_function diff --git a/backpack/extensions/firstorder/gradient/__init__.py b/backpack/extensions/firstorder/gradient/__init__.py index 7a5228834..89c7cff43 100644 --- a/backpack/extensions/firstorder/gradient/__init__.py +++ b/backpack/extensions/firstorder/gradient/__init__.py @@ -1 +1,5 @@ +"""This package contains the gradient extension. + +It calculates the same result as torch backward(). +""" # TODO: Rewrite variance to not need this extension diff --git a/backpack/extensions/firstorder/gradient/rnn.py b/backpack/extensions/firstorder/gradient/rnn.py new file mode 100644 index 000000000..7924ca554 --- /dev/null +++ b/backpack/extensions/firstorder/gradient/rnn.py @@ -0,0 +1,15 @@ +"""Contains GradRNN.""" +from backpack.core.derivatives.rnn import RNNDerivatives + +from .base import GradBaseModule + + +class GradRNN(GradBaseModule): + """Extension for RNN, calculating gradient.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=RNNDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index c0c6a2327..0e5b0ec4c 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -1,4 +1,9 @@ +"""Contains backpropagation extension for sum_grad_squared: SumGradSquared. + +Defines module extension for each module. +""" from torch.nn import ( + RNN, Conv1d, Conv2d, Conv3d, @@ -18,6 +23,7 @@ convtranspose2d, convtranspose3d, linear, + rnn, ) @@ -36,6 +42,10 @@ class SumGradSquared(BackpropExtension): """ def __init__(self): + """Initialization. + + Defines module extension for each module. + """ super().__init__( savefield="sum_grad_squared", fail_mode="WARNING", @@ -47,5 +57,6 @@ def __init__(self): ConvTranspose1d: convtranspose1d.SGSConvTranspose1d(), ConvTranspose2d: convtranspose2d.SGSConvTranspose2d(), ConvTranspose3d: convtranspose3d.SGSConvTranspose3d(), + RNN: rnn.SGSRNN(), }, ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/rnn.py b/backpack/extensions/firstorder/sum_grad_squared/rnn.py new file mode 100644 index 000000000..61e96d698 --- /dev/null +++ b/backpack/extensions/firstorder/sum_grad_squared/rnn.py @@ -0,0 +1,14 @@ +"""Contains SGSRNN module.""" +from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase + + +class SGSRNN(SGSBase): + """Extension for RNN, calculating sum_gradient_squared.""" + + def __init__(self): + """Initialization.""" + super(SGSRNN, self).__init__( + derivatives=RNNDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py index e654e8016..48ab13734 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py +++ b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py @@ -1,20 +1,71 @@ +"""Contains SGSBase, the base module for sum_grad_squared extension.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.extensions.firstorder.base import FirstOrderModuleExtension +if TYPE_CHECKING: + from backpack.extensions import SumGradSquared + class SGSBase(FirstOrderModuleExtension): - def __init__(self, derivatives, params=None): - self.derivatives = derivatives - self.N_axis = 0 - super().__init__(params=params) - - def bias(self, ext, module, g_inp, g_out, bpQuantities): - grad_batch = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=False - ) - return (grad_batch ** 2).sum(self.N_axis) - - def weight(self, ext, module, g_inp, g_out, bpQuantities): - grad_batch = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, g_out[0], sum_batch=False - ) - return (grad_batch ** 2).sum(self.N_axis) + """Base class for extensions calculating sum_grad_squared.""" + + def __init__(self, derivatives: BaseParameterDerivatives, params: List[str] = None): + """Initialization. + + For each parameter a function is initialized that is named like the parameter + + Args: + derivatives: calculates the derivatives wrt parameters + params: list of parameter names + """ + self.derivatives: BaseParameterDerivatives = derivatives + self.N_axis: int = 0 + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + super(SGSBase, self).__init__(params=params) + + def _make_param_function( + self, param_str: str + ) -> Callable[[SumGradSquared, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: + """Creates a function that calculates sum_grad_squared. + + Args: + param_str: name of parameter + + Returns: + function that calculates sum_grad_squared + """ + + def param_function( + ext: SumGradSquared, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: None, + ) -> Tensor: + """Calculates sum_grad_squared with the help of derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + bpQuantities: additional quantities for second order + + Returns: + sum_grad_squared + """ + grad_batch = getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=False + ) + return (grad_batch ** 2).sum(self.N_axis) + + return param_function diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index e97c39205..cd69c13a3 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -1,4 +1,9 @@ +"""Defines backpropagation extension for variance: Variance. + +Defines module extension for each module. +""" from torch.nn import ( + RNN, Conv1d, Conv2d, Conv3d, @@ -18,6 +23,7 @@ convtranspose2d, convtranspose3d, linear, + rnn, ) @@ -36,6 +42,10 @@ class Variance(BackpropExtension): """ def __init__(self): + """Initialization. + + Defines module extension for each module. + """ super().__init__( savefield="variance", fail_mode="WARNING", @@ -47,5 +57,6 @@ def __init__(self): ConvTranspose1d: convtranspose1d.VarianceConvTranspose1d(), ConvTranspose2d: convtranspose2d.VarianceConvTranspose2d(), ConvTranspose3d: convtranspose3d.VarianceConvTranspose3d(), + RNN: rnn.VarianceRNN(), }, ) diff --git a/backpack/extensions/firstorder/variance/rnn.py b/backpack/extensions/firstorder/variance/rnn.py new file mode 100644 index 000000000..d41f6de8e --- /dev/null +++ b/backpack/extensions/firstorder/variance/rnn.py @@ -0,0 +1,20 @@ +"""Contains VarianceRNN.""" +from backpack.extensions.firstorder.gradient.rnn import GradRNN +from backpack.extensions.firstorder.sum_grad_squared.rnn import SGSRNN +from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule + + +class VarianceRNN(VarianceBaseModule): + """Extension for RNN, calculating variance.""" + + def __init__(self): + """Initialization.""" + super(VarianceRNN, self).__init__( + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + grad_extension=GradRNN(), + sgs_extension=SGSRNN(), + ) + + @staticmethod + def _get_axis_batch() -> int: + return 1 diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index 64d8c17e2..0341d490b 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -1,30 +1,89 @@ +"""Contains VarianceBaseModule.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Tuple + +from torch import Tensor +from torch.nn import Module + from backpack.extensions.firstorder.base import FirstOrderModuleExtension +if TYPE_CHECKING: + from backpack.extensions import Variance + from backpack.extensions.firstorder.gradient.base import GradBaseModule + from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase + class VarianceBaseModule(FirstOrderModuleExtension): - def __init__(self, params, grad_extension, sgs_extension): - super().__init__(params=params) - self.grad_ext = grad_extension - self.sgs_ext = sgs_extension + """Base class for extensions calculating variance.""" + + def __init__( + self, + params: List[str], + grad_extension: GradBaseModule, + sgs_extension: SGSBase, + ): + """Initialization. + + Creates a function named after each parameter. + + Args: + params: list of parameter names + grad_extension: the extension calculating grad. + sgs_extension: the extension calculating squared_grad_sum. + """ + self.grad_ext: GradBaseModule = grad_extension + self.sgs_ext: SGSBase = sgs_extension + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + super(VarianceBaseModule, self).__init__(params=params) @staticmethod - def variance_from(grad, sgs, N): + def _variance_from(grad: Tensor, sgs: Tensor, N: int) -> Tensor: avgg_squared = (grad / N) ** 2 avg_gsquared = sgs / N return avg_gsquared - avgg_squared - def bias(self, ext, module, g_inp, g_out, backproped): - N = g_out[0].shape[0] - return self.variance_from( - self.grad_ext.bias(ext, module, g_inp, g_out, backproped), - self.sgs_ext.bias(ext, module, g_inp, g_out, backproped), - N, - ) - - def weight(self, ext, module, g_inp, g_out, backproped): - N = g_out[0].shape[0] - return self.variance_from( - self.grad_ext.weight(ext, module, g_inp, g_out, backproped), - self.sgs_ext.weight(ext, module, g_inp, g_out, backproped), - N, - ) + @staticmethod + def _get_axis_batch() -> int: + return 0 + + def _make_param_function( + self, param: str + ) -> Callable[[Variance, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: + """Creates a function that calculates variance of grad_batch. + + Args: + param(str): name of parameter + + Returns: + function that calculates variance of grad_batch + """ + + def param_function( + ext: Variance, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: None, + ) -> Tensor: + """Calculates variance with the help of derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + bpQuantities: additional quantities for second order + + Returns: + variance of the batch + """ + return self._variance_from( + getattr(self.grad_ext, param)(ext, module, g_inp, g_out, bpQuantities), + getattr(self.sgs_ext, param)(ext, module, g_inp, g_out, bpQuantities), + g_out[0].shape[self._get_axis_batch()], + ) + + return param_function diff --git a/fully_documented.txt b/fully_documented.txt index b5bd63967..c2e57ea96 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,8 +1,28 @@ +setup.py + backpack/core/derivatives/basederivatives.py backpack/core/derivatives/rnn.py backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py +backpack/extensions/firstorder/gradient/base.py +backpack/extensions/firstorder/gradient/rnn.py +backpack/extensions/firstorder/gradient/__init__.py +backpack/extensions/firstorder/batch_grad/batch_grad_base.py +backpack/extensions/firstorder/batch_grad/rnn.py +backpack/extensions/firstorder/batch_grad/__init__.py +backpack/extensions/firstorder/variance/variance_base.py +backpack/extensions/firstorder/variance/rnn.py +backpack/extensions/firstorder/variance/__init__.py +backpack/extensions/firstorder/sum_grad_squared/sgs_base.py +backpack/extensions/firstorder/sum_grad_squared/rnn.py +backpack/extensions/firstorder/sum_grad_squared/__init__.py +backpack/extensions/firstorder/batch_l2_grad/rnn.py +backpack/extensions/firstorder/batch_l2_grad/__init__.py +backpack/extensions/backprop_extension.py + +backpack/custom_module/ + test/core/derivatives/derivatives_test.py test/core/derivatives/__init__.py test/core/derivatives/rnn_settings.py @@ -12,5 +32,4 @@ test/core/derivatives/implementation/ test/extensions/problem.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py - -# docs_src/examples/use_cases/example_custom_module +test/extensions/firstorder/batch_grad/batchgrad_settings.py diff --git a/setup.py b/setup.py index d6ffa7388..edf8054b8 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +"""Setup backpack.""" from os import path from setuptools import find_packages, setup @@ -43,5 +44,5 @@ license=LICENSE, packages=PACKAGES, zip_safe=False, - python_requires=">=3.6", + python_requires=">=3.7", ) diff --git a/test/extensions/firstorder/batch_grad/batchgrad_settings.py b/test/extensions/firstorder/batch_grad/batchgrad_settings.py index 862050ac7..3920b5b5a 100644 --- a/test/extensions/firstorder/batch_grad/batchgrad_settings.py +++ b/test/extensions/firstorder/batch_grad/batchgrad_settings.py @@ -1,14 +1,11 @@ -"""Test configurations to test batch_grad +"""Test configurations to test batch_grad. The tests are taken from `test.extensions.firstorder.firstorder_settings`, but additional custom tests can be defined here by appending it to the list. """ - from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS -BATCHGRAD_SETTINGS = [] - SHARED_SETTINGS = FIRSTORDER_SETTINGS -LOCAL_SETTING = [] +LOCAL_SETTINGS = [] -BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTING +BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 99206d86d..b5474ffe5 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -18,23 +18,26 @@ "device" [list(torch.device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. "seed" (int): seed for the random number for torch.rand - "axis_batch (int): specifies the batch axis. Defaults to zero """ - - from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import make_simple_cnn_setting import torch from torch.nn import ( + RNN, Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Flatten, + Sequential, ) +from backpack.custom_module.permute import Permute +from backpack.custom_module.reduce_tuple import ReduceTuple + FIRSTORDER_SETTINGS = [] ############################################################################### @@ -154,3 +157,33 @@ (3, 3, 2, 7, 7), ConvTranspose3d, (3, 2, 2, 4, 2, 0, 1, False) ), ] + +############################################################################### +# test setting: RNN Layers # +############################################################################### + +FIRSTORDER_SETTINGS += [ + { + "input_fn": lambda: torch.rand(8, 5, 6), + "module_fn": lambda: Sequential( + Permute(1, 0, 2), + RNN(input_size=6, hidden_size=3), + ReduceTuple(index=0), + Permute(1, 2, 0), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((8, 5), 3), + }, + { + "input_fn": lambda: torch.rand(8, 5, 6), + "module_fn": lambda: Sequential( + Permute(1, 0, 2), + RNN(input_size=6, hidden_size=3), + ReduceTuple(index=0), + Permute(1, 2, 0), + Flatten(), + ), + "loss_function_fn": lambda: torch.nn.MSELoss(), + "target_fn": lambda: regression_targets((8, 3 * 5)), + }, +] diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index f956090de..3739cfccd 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -16,7 +16,7 @@ def batch_grad(self): Returns: list[torch.Tensor]: batch_grads """ - N = self.problem.input.shape[self.problem.axis_batch] + N = self.problem.input.shape[0] batch_grads = [ torch.zeros(N, *p.size()).to(self.problem.device) for p in self.problem.model.parameters() diff --git a/test/extensions/problem.py b/test/extensions/problem.py index f25dcd53a..1e3e92b5a 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -48,7 +48,6 @@ def add_missing_defaults(setting): "id_prefix": "", "seed": 0, "device": get_available_devices(), - "axis_batch": 0, } for req in required: @@ -78,7 +77,6 @@ def __init__( device, seed, id_prefix, - axis_batch, ): """Collection of information required to test extensions. @@ -90,7 +88,6 @@ def __init__( device (torch.device): Device to run on. seed (int): Random seed. id_prefix (str): Extra string added to test id. - axis_batch (int): index of batch axis. Defaults to 0. """ self.module_fn = module_fn self.input_fn = input_fn @@ -100,7 +97,6 @@ def __init__( self.device = device self.seed = seed self.id_prefix = id_prefix - self.axis_batch = axis_batch def set_up(self): """Set up problem from settings.""" @@ -151,17 +147,9 @@ def forward_pass(self, sample_idx=None): target = self.target.clone().detach() else: target = self.target.clone()[sample_idx].unsqueeze(0).detach() - input = self.input.split(1, dim=self.axis_batch)[sample_idx].detach() + input = self.input.split(1, dim=0)[sample_idx].detach() output = self.model(input) - if isinstance(output, tuple): - output = output[0] - - if self.axis_batch != 0: - # Note: This inserts a new operation into the computation graph. - # In second order extensions, breaks backpropagation of additional - # information. - output = output.transpose(0, self.axis_batch) loss = self.loss_function(output, target) From 9d5db3a8d6e43660d806de4d445f5294c5ad2073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 10 Jun 2021 10:11:31 +0200 Subject: [PATCH 09/54] [extensions] {RNN, Permute, ReduceTuple} DiagGGN extensions (#159) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `DiagGGN` extension for `RNN` module - Refactor `DiagGGN` extension, introduce base class - Add Jacobians of `Permute` and `ReduceTuple`. Co-authored-by: f-dangel <48687646+f-dangel@users.noreply.github.com> Co-authored-by: Tim Schäfer --- backpack/core/derivatives/permute.py | 23 +++++ backpack/extensions/backprop_extension.py | 3 +- backpack/extensions/mat_to_mat_jac_base.py | 47 +++++++-- .../secondorder/diag_ggn/__init__.py | 98 ++++++++++++++----- .../secondorder/diag_ggn/diag_ggn_base.py | 70 ++++++++++++- .../secondorder/diag_ggn/permute.py | 11 +++ .../extensions/secondorder/diag_ggn/rnn.py | 27 +++++ fully_documented.txt | 15 ++- test/core/derivatives/__init__.py | 3 + test/core/derivatives/derivatives_test.py | 12 ++- test/core/derivatives/permute_settings.py | 33 +++++++ test/extensions/implementation/autograd.py | 20 +++- ...aggnn_settings.py => diag_ggn_settings.py} | 26 ++++- .../diag_ggn/test_batch_diag_ggn.py | 15 +-- .../secondorder/diag_ggn/test_diag_ggn.py | 15 +-- .../secondorder/secondorder_settings.py | 3 +- 16 files changed, 367 insertions(+), 54 deletions(-) create mode 100644 backpack/core/derivatives/permute.py create mode 100644 backpack/extensions/secondorder/diag_ggn/permute.py create mode 100644 backpack/extensions/secondorder/diag_ggn/rnn.py create mode 100644 test/core/derivatives/permute_settings.py rename test/extensions/secondorder/diag_ggn/{diaggnn_settings.py => diag_ggn_settings.py} (52%) diff --git a/backpack/core/derivatives/permute.py b/backpack/core/derivatives/permute.py new file mode 100644 index 000000000..a4654b9bf --- /dev/null +++ b/backpack/core/derivatives/permute.py @@ -0,0 +1,23 @@ +"""Module containing derivatives of Permute.""" +from typing import Tuple + +from torch import Tensor, argsort + +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.custom_module.permute import Permute + + +class PermuteDerivatives(BaseDerivatives): + """Derivatives of Permute.""" + + def _jac_t_mat_prod( + self, module: Permute, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + return mat.permute( + [0] + [element + 1 for element in argsort(Tensor(module.dims))] + ) + + def _jac_mat_prod( + self, module: Permute, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + return mat.permute([0] + [element + 1 for element in module.dims]) diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index e32d12185..06c614c3a 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -5,6 +5,7 @@ import torch.nn from torch.nn import Sequential +from backpack.custom_module.reduce_tuple import ReduceTuple from backpack.extensions.module_extension import ModuleExtension from backpack.utils.hooks import no_op @@ -78,7 +79,7 @@ def __get_module_extension(self, module): if module_extension is None: - if isinstance(module, Sequential): + if isinstance(module, (Sequential, ReduceTuple)): return no_op if self.__fail_mode is FAIL_ERROR: diff --git a/backpack/extensions/mat_to_mat_jac_base.py b/backpack/extensions/mat_to_mat_jac_base.py index ca9d214e0..c922f825b 100644 --- a/backpack/extensions/mat_to_mat_jac_base.py +++ b/backpack/extensions/mat_to_mat_jac_base.py @@ -1,23 +1,56 @@ +"""Contains base class for second order extensions.""" +from typing import List, Tuple, Union + +from torch import Tensor +from torch.nn import Module + +from ..core.derivatives.basederivatives import BaseDerivatives, BaseParameterDerivatives from .module_extension import ModuleExtension class MatToJacMat(ModuleExtension): - """ - Base class for backpropagating matrices by multiplying with Jacobians. - """ + """Base class for backpropagation of matrices by multiplying with Jacobians.""" - def __init__(self, derivatives, params=None): + def __init__( + self, + derivatives: Union[BaseDerivatives, BaseParameterDerivatives], + params: List[str] = None, + ): + """Initialization. + + Args: + derivatives: class containing derivatives + params: list of parameter names + """ super().__init__(params) self.derivatives = derivatives - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + def backpropagate( + self, + ext: ModuleExtension, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: Union[List[Tensor], Tensor], + ) -> Union[List[Tensor], Tensor]: + """Propagates second order information back. + + Args: + ext: extension + module: module through which to perform backpropagation + grad_inp: input gradients + grad_out: output gradients + backproped: backpropagation information + Returns: + derivative wrt input + """ if isinstance(backproped, list): - M_list = [ + M_list: List[Tensor] = [ self.derivatives.jac_t_mat_prod(module, grad_inp, grad_out, M) for M in backproped ] - return list(M_list) + return M_list else: return self.derivatives.jac_t_mat_prod( module, grad_inp, grad_out, backproped diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 98ab7c0ea..4c1accf36 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -1,5 +1,16 @@ +"""Module contains definitions of DiagGGN extensions. + +Contains: +DiagGGN(BackpropExtension) +DiagGGNExact(DiagGGN) +DiagGGNMC(DiagGGN) +BatchDiagGGN(BackpropExtension) +BatchDiagGGNExact(BatchDiagGGN) +BatchDiagGGNMC(BatchDiagGGN) +""" from torch.nn import ( ELU, + RNN, SELU, AvgPool1d, AvgPool2d, @@ -26,6 +37,7 @@ ZeroPad2d, ) +from backpack.custom_module.permute import Permute from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.secondorder.hbp import LossHessianStrategy @@ -42,7 +54,9 @@ linear, losses, padding, + permute, pooling, + rnn, ) @@ -54,7 +68,16 @@ class DiagGGN(BackpropExtension): LossHessianStrategy.SAMPLING, ] - def __init__(self, loss_hessian_strategy, savefield): + def __init__(self, loss_hessian_strategy: str, savefield: str): + """Initialization. + + Args: + loss_hessian_strategy: either LossHessianStrategy.EXACT or .SAMPLING + savefield: the field where to save the calculated property + + Raises: + ValueError: if chosen loss strategy is not valid. + """ if loss_hessian_strategy not in self.VALID_LOSS_HESSIAN_STRATEGIES: raise ValueError( "Unknown hessian strategy: {}".format(loss_hessian_strategy) @@ -91,13 +114,15 @@ def __init__(self, loss_hessian_strategy, savefield): LogSigmoid: activations.DiagGGNLogSigmoid(), ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), + RNN: rnn.DiagGGNRNN(), + Permute: permute.DiagGGNPermute(), }, ) class DiagGGNExact(DiagGGN): - """ - Diagonal of the Generalized Gauss-Newton/Fisher. + """Diagonal of the Generalized Gauss-Newton/Fisher. + Uses the exact Hessian of the loss w.r.t. the model output. Stores the output in :code:`diag_ggn_exact`, @@ -105,16 +130,16 @@ class DiagGGNExact(DiagGGN): For a faster but less precise alternative, see :py:meth:`backpack.extensions.DiagGGNMC`. - """ def __init__(self): + """Initialization. Chooses exact loss strategy and savefield diag_ggn_exact.""" super().__init__(LossHessianStrategy.EXACT, "diag_ggn_exact") class DiagGGNMC(DiagGGN): - """ - Diagonal of the Generalized Gauss-Newton/Fisher. + """Diagonal of the Generalized Gauss-Newton/Fisher. + Uses a Monte-Carlo approximation of the Hessian of the loss w.r.t. the model output. @@ -123,17 +148,23 @@ class DiagGGNMC(DiagGGN): For a more precise but slower alternative, see :py:meth:`backpack.extensions.DiagGGNExact`. - - Args: - mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``. - """ - def __init__(self, mc_samples=1): + def __init__(self, mc_samples: int = 1): + """Initialization. Chooses sampling loss strategy and savefield diag_ggn_mc. + + Args: + mc_samples: Number of Monte-Carlo samples. Default: ``1``. + """ self._mc_samples = mc_samples super().__init__(LossHessianStrategy.SAMPLING, "diag_ggn_mc") - def get_num_mc_samples(self): + def get_num_mc_samples(self) -> int: + """Returns number of Monte-Carlo samples. + + Returns: + number of Monte-Carlo samples + """ return self._mc_samples @@ -145,7 +176,16 @@ class BatchDiagGGN(BackpropExtension): LossHessianStrategy.SAMPLING, ] - def __init__(self, loss_hessian_strategy, savefield): + def __init__(self, loss_hessian_strategy: str, savefield: str): + """Initialization. + + Args: + loss_hessian_strategy: either LossHessianStrategy.EXACT or .SAMPLING + savefield: name of variable where to save calculated quantity + + Raises: + ValueError: if chosen loss strategy is not valid. + """ if loss_hessian_strategy not in self.VALID_LOSS_HESSIAN_STRATEGIES: raise ValueError( "Unknown hessian strategy: {}".format(loss_hessian_strategy) @@ -181,13 +221,15 @@ def __init__(self, loss_hessian_strategy, savefield): LogSigmoid: activations.DiagGGNLogSigmoid(), ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), + RNN: rnn.BatchDiagGGNRNN(), + Permute: permute.DiagGGNPermute(), }, ) class BatchDiagGGNExact(BatchDiagGGN): - """ - Individual diagonal of the Generalized Gauss-Newton/Fisher. + """Individual diagonal of the Generalized Gauss-Newton/Fisher. + Uses the exact Hessian of the loss w.r.t. the model output. Stores the output in ``diag_ggn_exact_batch`` as a ``[N x ...]`` tensor, @@ -195,6 +237,10 @@ class BatchDiagGGNExact(BatchDiagGGN): """ def __init__(self): + """Initialization. + + Chooses exact loss strategy and savefield diag_ggn_exact_batch. + """ super().__init__( loss_hessian_strategy=LossHessianStrategy.EXACT, savefield="diag_ggn_exact_batch", @@ -202,8 +248,8 @@ def __init__(self): class BatchDiagGGNMC(BatchDiagGGN): - """ - Individual diagonal of the Generalized Gauss-Newton/Fisher. + """Individual diagonal of the Generalized Gauss-Newton/Fisher. + Uses a Monte-Carlo approximation of the Hessian of the loss w.r.t. the model output. @@ -212,18 +258,26 @@ class BatchDiagGGNMC(BatchDiagGGN): For a more precise but slower alternative, see :py:meth:`backpack.extensions.BatchDiagGGNExact`. + """ - Args: - mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``. + def __init__(self, mc_samples: int = 1): + """Initialization. - """ + Chooses sampling loss strategy and savefield diag_ggn_mc_batch. - def __init__(self, mc_samples=1): + Args: + mc_samples: Number of Monte-Carlo samples. Default: ``1``. + """ self._mc_samples = mc_samples super().__init__( loss_hessian_strategy=LossHessianStrategy.SAMPLING, savefield="diag_ggn_mc_batch", ) - def get_num_mc_samples(self): + def get_num_mc_samples(self) -> int: + """Returns number of Monte-Carlo samples. + + Returns: + number of Monte-Carlo samples + """ return self._mc_samples diff --git a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py index e97334dea..5d23cc200 100644 --- a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py +++ b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py @@ -1,6 +1,74 @@ +"""Contains DiagGGN base class.""" +from typing import Callable, List, Tuple, Union + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import ( + BaseDerivatives, + BaseParameterDerivatives, +) from backpack.extensions.mat_to_mat_jac_base import MatToJacMat +from backpack.extensions.module_extension import ModuleExtension class DiagGGNBaseModule(MatToJacMat): - def __init__(self, derivatives, params=None): + """Base class for DiagGGN extension.""" + + def __init__( + self, + derivatives: Union[BaseDerivatives, BaseParameterDerivatives], + params: List[str] = None, + sum_batch: bool = None, + ): + """Initialization. + + If params and sum_batch is provided: + Creates a method named after parameter for each parameter. Checks if that + method is implemented, so a child class can implement a more efficient version. + + Args: + derivatives: class containing derivatives + params: list of parameter names. Defaults to None. + sum_batch: Specifies whether the created method sums along batch axis. + For BatchDiagGGNModule should be False. + For DiagGGNModule should be True. + Defaults to None. + """ + if params is not None and sum_batch is not None: + for param in params: + if not hasattr(self, param): + setattr(self, param, self._make_param_method(param, sum_batch)) super().__init__(derivatives, params=params) + + def _make_param_method( + self, param: str, sum_batch: bool + ) -> Callable[ + [ModuleExtension, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor + ]: + def _param( + ext: ModuleExtension, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: Tensor, + ) -> Tensor: + """Returns diagonal of GGN. + + Args: + ext: extension + module: module through which to backpropagate + grad_inp: input gradients + grad_out: output gradients + backproped: backpropagated information + + Returns: + diagonal + """ + JS: Tensor = getattr(self.derivatives, f"{param}_jac_t_mat_prod")( + module, grad_inp, grad_out, backproped, sum_batch=False + ) + axis: Tuple[int] = (0, 1) if sum_batch else (0,) + return (JS ** 2).sum(axis=axis) + + return _param diff --git a/backpack/extensions/secondorder/diag_ggn/permute.py b/backpack/extensions/secondorder/diag_ggn/permute.py new file mode 100644 index 000000000..7e7db118c --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/permute.py @@ -0,0 +1,11 @@ +"""Module defining DiagGGNPermute.""" +from backpack.core.derivatives.permute import PermuteDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNPermute(DiagGGNBaseModule): + """DiagGGN extension of Permute.""" + + def __init__(self): + """Initialize.""" + super().__init__(derivatives=PermuteDerivatives()) diff --git a/backpack/extensions/secondorder/diag_ggn/rnn.py b/backpack/extensions/secondorder/diag_ggn/rnn.py new file mode 100644 index 000000000..64bb4e6b7 --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/rnn.py @@ -0,0 +1,27 @@ +"""Module implementing GGN for RNN.""" +from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNRNN(DiagGGNBaseModule): + """Calculating diagonal of GGN.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=RNNDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + sum_batch=True, + ) + + +class BatchDiagGGNRNN(DiagGGNBaseModule): + """Calculating per-sample diagonal of GGN.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=RNNDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + sum_batch=False, + ) diff --git a/fully_documented.txt b/fully_documented.txt index c2e57ea96..bda381452 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -4,6 +4,10 @@ backpack/core/derivatives/basederivatives.py backpack/core/derivatives/rnn.py backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py +backpack/core/derivatives/permute.py + +backpack/extensions/backprop_extension.py +backpack/extensions/mat_to_mat_jac_base.py backpack/extensions/firstorder/gradient/base.py backpack/extensions/firstorder/gradient/rnn.py @@ -19,7 +23,11 @@ backpack/extensions/firstorder/sum_grad_squared/rnn.py backpack/extensions/firstorder/sum_grad_squared/__init__.py backpack/extensions/firstorder/batch_l2_grad/rnn.py backpack/extensions/firstorder/batch_l2_grad/__init__.py -backpack/extensions/backprop_extension.py + +backpack/extensions/secondorder/diag_ggn/__init__.py +backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py +backpack/extensions/secondorder/diag_ggn/rnn.py +backpack/extensions/secondorder/diag_ggn/permute.py backpack/custom_module/ @@ -28,8 +36,13 @@ test/core/derivatives/__init__.py test/core/derivatives/rnn_settings.py test/core/derivatives/utils.py test/core/derivatives/implementation/ +test/core/derivatives/permute_settings.py test/extensions/problem.py test/extensions/test_backprop_extension.py + test/extensions/firstorder/firstorder_settings.py test/extensions/firstorder/batch_grad/batchgrad_settings.py + +test/extensions/secondorder/secondorder_settings.py +test/extensions/secondorder/diag_ggn/ diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 3ba38952e..05748ce57 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -46,12 +46,14 @@ from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.core.derivatives.permute import PermuteDerivatives from backpack.core.derivatives.relu import ReLUDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.core.derivatives.selu import SELUDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives from backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives +from backpack.custom_module.permute import Permute derivatives_for = { Linear: LinearDerivatives, @@ -79,4 +81,5 @@ CrossEntropyLoss: CrossEntropyLossDerivatives, MSELoss: MSELossDerivatives, RNN: RNNDerivatives, + Permute: PermuteDerivatives, } diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 743b10b6a..ecb0c9b07 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -15,6 +15,7 @@ from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS +from test.core.derivatives.permute_settings import PERMUTE_SETTINGS from test.core.derivatives.problem import make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS @@ -50,9 +51,14 @@ RNN_PROBLEMS = make_test_problems(RNN_SETTINGS) RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS] +PERMUTE_PROBLEMS = make_test_problems(PERMUTE_SETTINGS) +PERMUTE_IDS = [problem.make_id() for problem in PERMUTE_PROBLEMS] + @pytest.mark.parametrize( - "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + "problem", + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS, ) def test_jac_mat_prod(problem, V=3): """Test the Jacobian-matrix product. @@ -72,7 +78,9 @@ def test_jac_mat_prod(problem, V=3): @pytest.mark.parametrize( - "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + "problem", + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS, ) def test_jac_t_mat_prod(problem, V=3): """Test the transposed Jacobian-matrix product. diff --git a/test/core/derivatives/permute_settings.py b/test/core/derivatives/permute_settings.py new file mode 100644 index 000000000..8f499a629 --- /dev/null +++ b/test/core/derivatives/permute_settings.py @@ -0,0 +1,33 @@ +"""Test configurations for `backpack.core.derivatives` Permute. + +Required entries: + "module_fn" (callable): Contains a model constructed from `torch.nn` layers + "input_fn" (callable): Used for specifying input function + +Optional entries: + "target_fn" (callable): Fetches the groundtruth/target classes + of regression/classification task + "loss_function_fn" (callable): Loss function used in the model + "device" [list(torch.device)]: List of devices to run the test on. + "id_prefix" (str): Prefix to be included in the test name. + "seed" (int): seed for the random number for torch.rand +""" + +import torch + +from backpack.custom_module.permute import Permute + +PERMUTE_SETTINGS = [ + { + "module_fn": lambda: Permute(0, 1, 2), + "input_fn": lambda: torch.rand(size=(1, 2, 3)), + }, + { + "module_fn": lambda: Permute(2, 0, 1), + "input_fn": lambda: torch.rand(size=(4, 3, 2)), + }, + { + "module_fn": lambda: Permute(3, 1, 0, 2), + "input_fn": lambda: torch.rand(size=(5, 4, 3, 2)), + }, +] diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 3739cfccd..8bedcfea4 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -95,10 +95,26 @@ def extract_ith_element_of_diag_ggn(i, p, loss, output): return diag_ggns def diag_ggn(self): - _, output, loss = self.problem.forward_pass() - return self._get_diag_ggn(loss, output) + try: + _, output, loss = self.problem.forward_pass() + return self._get_diag_ggn(loss, output) + except RuntimeError: + # torch does not implement cuda double-backwards pass on RNNs and + # recommends this workaround + with torch.backends.cudnn.flags(enabled=False): + _, output, loss = self.problem.forward_pass() + return self._get_diag_ggn(loss, output) def diag_ggn_batch(self): + try: + return self._diag_ggn_batch() + except RuntimeError: + # torch does not implement cuda double-backwards pass on RNNs and + # recommends this workaround + with torch.backends.cudnn.flags(enabled=False): + return self._diag_ggn_batch() + + def _diag_ggn_batch(self): batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() loss_list = torch.zeros(batch_size, device=self.problem.device) diff --git a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py similarity index 52% rename from test/extensions/secondorder/diag_ggn/diaggnn_settings.py rename to test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index e0d5b7bbc..11148cfae 100644 --- a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -1,19 +1,35 @@ -"""Test configurations to test diag_ggn +"""Test configurations to test diag_ggn. The tests are taken from `test.extensions.secondorder.secondorder_settings`, but additional custom tests can be defined here by appending it to the list. """ - +from test.core.derivatives.utils import regression_targets from test.extensions.automated_settings import make_simple_act_setting from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS -from torch.nn import ELU, SELU +import torch +from torch.nn import ELU, RNN, SELU, Flatten, Sequential -DiagGGN_SETTINGS = [] +from backpack.custom_module.permute import Permute +from backpack.custom_module.reduce_tuple import ReduceTuple SHARED_SETTINGS = SECONDORDER_SETTINGS -LOCAL_SETTINGS = [] +LOCAL_SETTINGS = [ + # RNN settings + { + "input_fn": lambda: torch.rand(8, 5, 6), + "module_fn": lambda: Sequential( + Permute(1, 0, 2), + RNN(input_size=6, hidden_size=3), + ReduceTuple(index=0), + Permute(1, 2, 0), + Flatten(), + ), + "loss_function_fn": lambda: torch.nn.MSELoss(), + "target_fn": lambda: regression_targets((8, 3 * 5)), + }, +] ############################################################################### # test setting: Activation Layers # diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index 378c0a2c7..c7a312bf6 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -1,8 +1,9 @@ +"""Test BatchDiagGGN extension.""" from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems -from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS +from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS import pytest @@ -12,7 +13,7 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn_batch(problem): - """Test the individual diagonal of Generalized Gauss-Newton/Fisher + """Test the individual diagonal of Generalized Gauss-Newton/Fisher. Args: problem (ExtensionsTestProblem): Problem for extension test. @@ -33,8 +34,9 @@ def test_diag_ggn_batch(problem): @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn_mc_batch_light(problem): - """Test the MC approximation of individual diagonal of - Generalized Gauss-Newton/Fisher with few mc_samples (light version) + """Test the MC approximation of individual diagonal. + + of Generalized Gauss-Newton/Fisher with few mc_samples (light version) Args: problem (ExtensionsTestProblem): Problem for extension test. @@ -54,8 +56,9 @@ def test_diag_ggn_mc_batch_light(problem): @pytest.mark.montecarlo @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn_mc_batch(problem): - """Test the MC approximation of individual diagonal of Gauss-Newton - with more samples (slow version) + """Test the MC approximation of individual diagonal. + + of generalized Gauss-Newton with more samples (slow version) Args: problem (ExtensionsTestProblem): Problem for extension test. diff --git a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py index 858e479b5..2a982fcd7 100644 --- a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py @@ -1,8 +1,9 @@ +"""Test DiagGGN extension.""" from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems -from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS +from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS import pytest @@ -12,7 +13,7 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn(problem): - """Test the diagonal of Gauss-Newton + """Test the diagonal of generalized Gauss-Newton. Args: problem (ExtensionsTestProblem): Problem for extension test. @@ -33,8 +34,9 @@ def test_diag_ggn(problem): @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn_mc_light(problem): - """Test the MC approximation of Diagonal of Gauss-Newton - with few mc_samples (light version) + """Test the MC approximation of Diagonal of generalized Gauss-Newton. + + with few mc_samples (light version) Args: problem (ExtensionsTestProblem): Problem for extension test. @@ -54,8 +56,9 @@ def test_diag_ggn_mc_light(problem): @pytest.mark.montecarlo @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) def test_diag_ggn_mc(problem): - """Test the MC approximation of Diagonal of Gauss-Newton - with more samples (slow version) + """Test the MC approximation of Diagonal of generalized Gauss-Newton. + + with more samples (slow version) Args: problem (ExtensionsTestProblem): Problem for extension test. diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index 6cd68463c..a3844dd39 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -1,4 +1,5 @@ -"""Test configurations for `backpack.core.extensions.secondorder` +"""Test configurations for `backpack.core.extensions.secondorder`. + that is shared among the following secondorder methods: - Diagonal of Gauss Newton - Diagonal of Hessian From 3d31b1d314d91c1827d21733ea478b19e3b0e543 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 16 Jun 2021 16:36:47 +0200 Subject: [PATCH 10/54] [FIX] PyPI release of `1.3.0` (#176) (#177) * [FIX] Remove quotes in url * [FIX] Move `pytorch_memlab` dependency to makefile It is not possible to have GitHub dependencies in `setup.cfg` when attempting to push to PyPI * [FIX] Exclude test package in installation https://stackoverflow.com/a/59686298 --- makefile | 1 + setup.cfg | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/makefile b/makefile index ee3c67d75..fbcd353f1 100644 --- a/makefile +++ b/makefile @@ -121,6 +121,7 @@ install-lint: install-test: @pip install -e .[test] + @pip install git+https://git@github.com/Stonesjtu/pytorch_memlab.git@6ab5fab#egg=pytorch_memlab install-docs: @pip install -e .[docs] diff --git a/setup.cfg b/setup.cfg index b9222cced..bf2aa19f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ [metadata] name = backpack-for-pytorch author = Felix Dangel, Frederik Kunstner -URL = "https://github.com/f-dangel/backpack" +url = https://github.com/f-dangel/backpack description = BackPACK: Packing more into backprop long_description = file: README.md long_description_content_type = text/markdown; charset=UTF-8; variant=GFM @@ -41,6 +41,9 @@ install_requires = # Require a specific Python version, e.g. Python 2.7 or >= 3.4 python_requires = >=3.6 +[options.packages.find] +exclude = test* + ############################################################################### # Development dependencies # ############################################################################### @@ -54,7 +57,6 @@ test = pytest-optional-tests >= 0.1.1 pytest-cov coveralls - pytorch_memlab @ git+https://git@github.com/Stonesjtu/pytorch_memlab.git@6ab5fab#egg=pytorch_memlab # Dependencies needed for linting (semicolon/line-separated) lint = From 08008ab9b5c1775baa0697d90652016d4859a258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 18 Jun 2021 11:12:11 +0200 Subject: [PATCH 11/54] Update RNN branch with development branch (#178) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [DOC] Mention PR documentation requirements * [FMT] Apply auto-formatting with new black release Use black 21.4b0, released on 2021-04-26 (https://github.com/psf/black/releases/tag/21.4b0) * [DOC] Extend first-order extension with custom module tutorial (#152) Modifies how external modules are supported: The old approach was based on class variables, which could lead to subtle bugs due to the lookup order. Switching to instance variables eliminates this potential source of bugs. Includes a self-contained walk through how to extend a custom layer to add support for first-order extensions. Resolves #149 Co-authored-by: Tim Schäfer Co-authored-by: Felix Dangel * [core] ConvTransposeNd _weight_jac_t_mat_prod support groups (#151) Resolves https://github.com/fKunstner/backpack-discuss/issues/84. * support weight_jac_t to ConvTranspose * PR fixes * fix spaces for PR * PR changes * fix black * fix black again * [DOC] Improve documentation Co-authored-by: Felix Dangel * [DOC] Explain `save_memory` for convolutions with use case (#142) * [DOC] Draft use case explaining how to enable save_memory in convolutions * [DOC] Polish save_memory use case and add links to benchmark * [CI] Add partial docstring checks (#157) Fully documented files are added to `fully_documented.txt`. Once the entire code base is documented, the partial checks can be removed, and the full tests be turned on. * [core] Support groups ≠ 1 in ConvNd weight_jac_t (#161) Resolves fKunstner/backpack-discuss#83. * [core] Support groups ≠ 1 in ConvTransposeNd weight_jac_mat_prod (#162) * [core] Support groups ≠ 1 in ConvTransposeNd weight_jac * [test] Merge group conv with regular settings * [format] Remove blank line * [utils] Support groups ≠ 1 in ConvNd extract_weight_diagonal (#163) * [test] Reproduce behavior that groups ≠ 1 is not supported * [utils] Support groups ≠ 1 in ConvNd extract_weight_diagonal Also refactor the code and replace and move a `transpose` inside an `einsum` operation. * [test] Add cases for grouped Conv2d and Conv3d * [fix] Fix Conv{2,3}d by grouping spatial dimensions * [ref, doc] Use inplace power function, add documentation * [doc] polish docstring Co-authored-by: Felix Dangel * [utils] Support groups ≠ 1 in ConvTransposeNd extract_weight_diagonal (#164) Fixes `groups ≠ 1` in diagonal curvature extensions for transpose convolutions. --- * [test] Reproduce behavior that groups ≠ 1 is not supported * [utils] Support groups ≠ 1 in ConvNdTranspose extract_weight_diagonal * [test] Add cases for grouped ConvTranspose{2,3}d * [test] Make ConvTranspose architectures smaller * [utils] Change shape convention of transpose convolution unfolding Align to convention of ``torch.nn.Unfold`` (the equivalent for normal convolutions) and document new behavior. * [ref, doc] Remove transposes, use inplace power, add doc * [refactor] Write kernel size elements more compactly Co-authored-by: Felix Dangel * [test] Increase tolerance for old hvp tests They sometimes fail as they are unseeded. See https://github.com/fKunstner/backpack-discuss/issues/103 where such events are documented to reduce such unwanted failures. * [doc] enforce auto-formatting & linting rules on examples (#166) Resolves https://github.com/fKunstner/backpack-discuss/issues/79. * [doc] Auto-format examples * [lint] Use max-line-length=88 in flake8 configuration * [fix] Make examples pass flake8 * [KFAC, KFLR, KFRA] Raise exception for group convolutions (#167) * [hbp] Raise exception if `groups ≠ 1`, test for KFAC * [hbp] Test KFLR raises exception for `groups ≠ 1` * [hbp] Test KFRA raises exception for `groups ≠ 1` * [ci] Replace fully documented files by parent directory * [doc] Improve test description * [del] Remove dead code Co-authored-by: Felix Dangel * [DiagHessian] Support `ELU` and `SELU` (#168) Fixes two bugs in the second-order derivatives of `ELU` and `SELU`. Resolves https://github.com/fKunstner/backpack-discuss/issues/69. * [core] Fix math of ELU second derivative * [ref] Rewrite ELU derivatives * [DiagHessian] Add support for ELU * [core] Fix bug in second SELU derivative * [DiagHessian] Add support for SELU * [ref] Make ELU and SELU derivatives more similar * [ref] Add blank lines before `return` statements * [doc] Improve documentation * [ADD] BatchDiagHessian extension (#137) Extension to compute per-sample (individual) Hessian diagonal. Resolves https://github.com/fKunstner/backpack-discuss/issues/81. Auxiliary: - Many refactorings in utility functions to extract diagonals and L2 norms - Aggregate terms for (per-sample and sample-mean) Hessian diagonals with inplace summation --- * [doc] Add docstrings in Hessian diagonal for Conv1d * [doc] Add docstrings for Hessian diagonal of Conv{2,3}d * [doc] Add missing docstring for DiagHessian Conv3d * [doc] Fully document {Batch}DiagHessian extensions * [DiagHessian, ConvND] Aggregate terms inplace * [BatchDiagHessian] Share device and dtype with parameter * [fix] Black formatting * [BatchDiagHessian, DiagHessian] Aggregate result inplace * Remove conv dimension from diagonal extraction utilities * Remove conv dimension from diagonal extractors * Remove redundant docstrings * Return unfolded input instead of function that evaluates it * [fix] pydocstyle * Remove get_bias_gradient_factors for convolution * fix bias l2 computation Co-authored-by: Felix Dangel * [REF] Remove dead utility code * [DOC] Add per-sample curvatures to example and doc (#170) Relevant parts copied from https://github.com/f-dangel/backpack/pull/146. Resolves https://github.com/fKunstner/backpack-discuss/issues/90. Resolves https://github.com/fKunstner/backpack-discuss/issues/91. * [DOC] Update supported layers and future features (#171) * [FIX] RTD build for `save_memory` example - Reduce batch size for Read the Docs build With batch size 128, builds fails with `Command killed due to excessive memory consumption` (https://readthedocs.org/projects/backpack/builds/14015934/). - Synchronize CUDA only if device is GPU Log of build from RTD: https://readthedocs.org/projects/backpack/builds/14016051/ * [CFG] Move install and dev tool configuration to `setup.cfg` (#172) - Move configurations of linters and other dev tools to `setup.cfg` - Move dependencies to `setup.cfg` - Move main library meta data and installation info to `setup.cfg`, replace manual with automatic versioning (merged `master` to bump the development version to `1.2.xxx`) Auxiliary: - Remove `pre-commit` from development tools - I originally wanted to automate publishing to PyPI for new tags. This requires setting up GitHub secrets for the PyPI user and password. According to the [doc](https://docs.github.com/en/actions/reference/encrypted-secrets), those secrets can be seen by anyone with 'collaborator' access. I don't think that sharing my credentials is a good solution, hence I will stick to the manual procedure. * [REF] Move linter configurations to a `setup.cfg` * [DEL] Remove pre-commit from workflow * [REF] Move installation and dependencies to `setup.cfg` Switch to automatic version detection. * [DOC] Justify pytest configuration separate from `setup.cfg` * [FIX] Don't lint `.eggs` directory with flake8 * [FIX] Dependencies for read the docs build * Prepare `1.3.0` release (#173) - Update changelog - Update API and improve existing documentation for RTD --- * [DOC] Update and improve docstrings * [DOC] Add `disable` context to API documentation * [DOC] Remove `AvgPool{1,3}d` from supported layers * [DOC] Update changelog for `1.3.0` release * [DOC] Use napoleon sphinx extension * [DOC] Adapt to google docstring * [FIX] Format of scaling issue warning * [DOC] Tweak headers, remove blank lines * [FIX] Add missing colon * [FIX] Make pydocstyle pass * [FIX] PyPI release of `1.3.0` (#176) * [FIX] Remove quotes in url * [FIX] Move `pytorch_memlab` dependency to makefile It is not possible to have GitHub dependencies in `setup.cfg` when attempting to push to PyPI * [FIX] Exclude test package in installation https://stackoverflow.com/a/59686298 * [FIX] PyPI release of `1.3.0` (#176) (#177) * [FIX] Remove quotes in url * [FIX] Move `pytorch_memlab` dependency to makefile It is not possible to have GitHub dependencies in `setup.cfg` when attempting to push to PyPI * [FIX] Exclude test package in installation https://stackoverflow.com/a/59686298 Co-authored-by: f-dangel <48687646+f-dangel@users.noreply.github.com> Co-authored-by: Felix Dangel Co-authored-by: Tim Schäfer Co-authored-by: Shrisha Bharadwaj Co-authored-by: Felix Dangel --- .conda_env.yml | 5 +- .darglint | 4 - .flake8 | 40 --- .github/workflows/test.yaml | 4 +- .isort.cfg | 7 - .pre-commit-config.yaml | 30 --- .pydocstyle | 6 - .readthedocs.yml | 6 +- README-dev.md | 3 +- backpack/__init__.py | 85 +++---- backpack/core/derivatives/conv_transposend.py | 31 ++- backpack/core/derivatives/convnd.py | 10 +- backpack/core/derivatives/elu.py | 25 +- backpack/core/derivatives/selu.py | 33 ++- backpack/extensions/__init__.py | 6 +- backpack/extensions/firstorder/__init__.py | 5 +- .../firstorder/batch_grad/__init__.py | 5 +- .../firstorder/batch_l2_grad/__init__.py | 5 +- .../firstorder/batch_l2_grad/convnd.py | 7 +- .../batch_l2_grad/convtransposend.py | 8 +- .../firstorder/sum_grad_squared/__init__.py | 5 +- .../firstorder/variance/__init__.py | 5 +- backpack/extensions/secondorder/__init__.py | 6 +- .../extensions/secondorder/diag_ggn/conv1d.py | 12 +- .../extensions/secondorder/diag_ggn/conv2d.py | 12 +- .../extensions/secondorder/diag_ggn/conv3d.py | 12 +- .../extensions/secondorder/diag_ggn/convnd.py | 28 +-- .../secondorder/diag_ggn/convtranspose1d.py | 8 +- .../secondorder/diag_ggn/convtranspose2d.py | 8 +- .../secondorder/diag_ggn/convtranspose3d.py | 8 +- .../secondorder/diag_ggn/convtransposend.py | 18 +- .../secondorder/diag_hessian/__init__.py | 60 ++++- .../secondorder/diag_hessian/activations.py | 16 ++ .../secondorder/diag_hessian/conv1d.py | 23 +- .../secondorder/diag_hessian/conv2d.py | 23 +- .../secondorder/diag_hessian/conv3d.py | 23 +- .../secondorder/diag_hessian/convnd.py | 58 ++++- .../diag_hessian/convtranspose1d.py | 12 +- .../diag_hessian/convtranspose2d.py | 12 +- .../diag_hessian/convtranspose3d.py | 12 +- .../diag_hessian/convtransposend.py | 56 ++++- .../secondorder/diag_hessian/linear.py | 53 +++- backpack/extensions/secondorder/hbp/conv2d.py | 12 +- backpack/utils/conv.py | 139 +++++----- backpack/utils/conv_transpose.py | 126 ++++++---- backpack/utils/unsqueeze.py | 38 --- black.toml | 3 +- changelog.md | 154 ++++++++++-- .../basic_usage/example_all_in_one.py | 25 +- .../examples/use_cases/example_cg_newton.py | 16 +- .../use_cases/example_custom_module.py | 5 +- .../use_cases/example_first_order_resnet.py | 6 +- .../use_cases/example_gradient_of_variance.py | 4 +- .../example_save_memory_convolutions.py | 238 ++++++++++++++++++ .../use_cases/example_trace_estimation.py | 57 +++-- docs_src/rtd/conf.py | 1 + docs_src/rtd/extensions.rst | 4 +- docs_src/rtd/good-to-know.rst | 5 - docs_src/rtd/main-api.rst | 2 +- docs_src/rtd/supported-layers.rst | 43 +++- fully_documented.txt | 7 + makefile | 34 +-- pytest.ini | 2 + requirements-dev.txt | 4 - requirements.txt | 3 - requirements/docs.txt | 2 - requirements/lint.txt | 12 - requirements/test.txt | 7 - setup.cfg | 140 +++++++++++ setup.py | 56 +---- test/automated_test.py | 6 +- test/core/derivatives/convolution_settings.py | 6 +- test/core/derivatives/derivatives_test.py | 76 +----- test/core/derivatives/problem.py | 10 - test/core/derivatives/settings.py | 24 +- test/extensions/implementation/autograd.py | 24 +- test/extensions/implementation/backpack.py | 32 +++ test/extensions/implementation/base.py | 34 +++ .../diag_hessian/test_diag_hessian.py | 16 ++ test/extensions/secondorder/hbp/__init__.py | 1 + .../secondorder/hbp/kfac_settings.py | 8 + .../secondorder/hbp/kflr_settings.py | 8 + .../secondorder/hbp/kfra_settings.py | 8 + test/extensions/secondorder/hbp/test_kfac.py | 25 ++ test/extensions/secondorder/hbp/test_kflr.py | 25 ++ test/extensions/secondorder/hbp/test_kfra.py | 25 ++ .../secondorder/secondorder_settings.py | 30 ++- test/utils/test_conv.py | 8 +- test/utils/test_conv_transpose.py | 4 +- 89 files changed, 1517 insertions(+), 793 deletions(-) delete mode 100644 .darglint delete mode 100644 .flake8 delete mode 100644 .isort.cfg delete mode 100644 .pre-commit-config.yaml delete mode 100644 .pydocstyle create mode 100644 docs_src/examples/use_cases/example_save_memory_convolutions.py delete mode 100644 requirements-dev.txt delete mode 100644 requirements.txt delete mode 100644 requirements/docs.txt delete mode 100644 requirements/lint.txt delete mode 100644 requirements/test.txt create mode 100644 setup.cfg create mode 100644 test/extensions/secondorder/hbp/__init__.py create mode 100644 test/extensions/secondorder/hbp/kfac_settings.py create mode 100644 test/extensions/secondorder/hbp/kflr_settings.py create mode 100644 test/extensions/secondorder/hbp/kfra_settings.py create mode 100644 test/extensions/secondorder/hbp/test_kfac.py create mode 100644 test/extensions/secondorder/hbp/test_kflr.py create mode 100644 test/extensions/secondorder/hbp/test_kfra.py diff --git a/.conda_env.yml b/.conda_env.yml index 639997685..1a9bcefc9 100644 --- a/.conda_env.yml +++ b/.conda_env.yml @@ -6,6 +6,7 @@ dependencies: - pip=19.3.1 - python=3.7.6 - pip: - - -r requirements.txt - - -r requirements-dev.txt - -e . + - -e .[lint] + - -e .[test] + - -e .[docs] diff --git a/.darglint b/.darglint deleted file mode 100644 index 901f5823e..000000000 --- a/.darglint +++ /dev/null @@ -1,4 +0,0 @@ -[darglint] -docstring_style = google -# short, long, full -strictness = full \ No newline at end of file diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 6b377e5a5..000000000 --- a/.flake8 +++ /dev/null @@ -1,40 +0,0 @@ -[flake8] -select = B,C,E,F,P,W,B9 -max-line-length = 80 -max-complexity = 10 -ignore = - # replaced by B950 (max-line-length + 10%) - E501, # max-line-length - # ignored because pytorch uses dict - C408, # use {} instead of dict() - # Not Black-compatible - E203, # whitespace before : - E231, # missing whitespace after ',' - W291, # trailing whitespace - W503, # line break before binary operator - W504, # line break after binary operator -exclude = docs, docs_src, build, .git - - -# Differences with pytorch -# -# Smaller max-line-length -# Enabled max-complexity -# No flake8-mypy (T4 range) -# -# Set of rules ignore by pytorch, probably to get around the C -# -# F401 (import unused in __init__.py) not ignored -# F403 'from module import *' used; unable to detect undefined names -# F405 Name may be undefined, or defined from star imports: module -# F821 Undefined name name -# F841 Local variable name is assigned to but never used -# -# Pytorch ignored rules that I don't see a reason to ignore (yet?): -# -# E305 Expected 2 blank lines after end of function or class -# E402 Module level import not at top of file -# E721 Do not compare types, use 'isinstance()' -# E741 Do not use variables named 'l', 'o', or 'i' -# E302 Expected 2 blank lines, found 0 -# E303 Too many blank lines (3) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d75c724ac..c235d3eea 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,9 +29,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -r requirements/test.txt - pip install . + make install-test - name: Run test if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref) run: | diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index be0f1bf27..000000000 --- a/.isort.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[settings] -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -line_length=88 -skip_glob=docs/*,docs_src/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index d019c8b32..000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,30 +0,0 @@ -repos: -- repo: https://github.com/psf/black - rev: stable - hooks: - - id: black - args: [--config=black.toml] -- repo: https://gitlab.com/pycqa/flake8 - rev: '3.7.9' - hooks: - - id: flake8 - additional_dependencies: [ - mccabe, - pycodestyle, - pyflakes, - pep8-naming, - flake8-bugbear, - flake8-comprehensions, - ] -- repo: https://github.com/pycqa/pydocstyle - rev: 5.0.2 - hooks: - - id: pydocstyle - args: - - --count -- repo: https://github.com/terrencepreilly/darglint - rev: master - hooks: - - id: darglint - args: - - --verbosity 2 diff --git a/.pydocstyle b/.pydocstyle deleted file mode 100644 index 018314871..000000000 --- a/.pydocstyle +++ /dev/null @@ -1,6 +0,0 @@ -[pydocstyle] -convention = google -match = .*\.py -# ignore = -# A, -# B, diff --git a/.readthedocs.yml b/.readthedocs.yml index 64229fb6e..5e0b148fd 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,8 +9,8 @@ sphinx: python: version: 3.7 install: - - requirements: requirements.txt - - requirements: requirements/docs.txt - method: pip path: . - system_packages: true \ No newline at end of file + extra_requirements: + - docs + system_packages: true diff --git a/README-dev.md b/README-dev.md index 95edfb963..e853e7bc8 100644 --- a/README-dev.md +++ b/README-dev.md @@ -27,7 +27,7 @@ git checkout development make conda-env conda activate backpack ``` -3. Install the development dependencies and `pre-commit` hooks +3. Install the development dependencies ```bash make install-dev ``` @@ -84,6 +84,5 @@ Code that is affected (has a `git diff`) by a pull request must satisfy the foll - Lint code: [`flake8`](http://flake8.pycqa.org/) ([`flake8` config](.flake8)) - Check docstring style: [`pydocstyle`](https://github.com/PyCQA/pydocstyle) ([`pydocstyle` config](.pydocstyle)) - Check docstring description matches definition: [`darglint`](https://github.com/terrencepreilly/darglint) ([`darglint` config](.darglint)) -- Optional [`pre-commit`](https://github.com/pre-commit/pre-commit) hooks [ `pre-commit` config ](.pre-commit-config.yaml) ###### _BackPACK is not endorsed by or affiliated with Facebook, Inc. PyTorch, the PyTorch logo and any related marks are trademarks of Facebook, Inc._ diff --git a/backpack/__init__.py b/backpack/__init__.py index 38ca4ff41..35e339778 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -11,57 +11,34 @@ class backpack: - """Activates Backpack Extensions. + """Activate BackPACK extensions. - Activates the BackPACK extensions passed as arguments for the - :code:`backward` calls in the current :code:`with` block. + Enables the BackPACK extensions passed as arguments in the + :code:`backward` calls inside the current :code:`with` block. + + Args: + exts ([BackpropExtension]): Extensions to activate in the backward pass. + extension_hook (function, optional): Function called on each module after + all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns + ``None``. Default: ``None`` (no operation will be formed). + + Can be used to reduce memory overhead if the goal is to compute + transformations of BackPACK quantities. Information can be compacted + during a backward pass and obsolete tensors be freed manually (``del``). + + .. note:: + + If the callable iterates over the ``module.parameters()``, the same + parameter may be seen multiple times across calls. This happens + if the parameters are part of multiple modules. + For example, the parameters of a `torch.nn.Linear` module in + ``model = torch.nn.Sequential(torch.nn.Linear(...))`` are part of + both the ``Linear`` and the ``Sequential``. + debug (bool, optional): Print debug messages during the backward pass. + Default: ``False``. """ def __init__(self, *exts: BackpropExtension, extension_hook=None, debug=False): - """Activate the Backpack extensions. - - Example usage: - ``` - X, Y, model, lossfunc = get_problem() - - backpack.extend(model) - backpack.extend(lossfunc) - - with backpack.backpack(backpack.extensions.Variance()): - lossfunc(model(X), Y).backward() - - for p in model.parameters(): - print(p.grad) - print(p.variance) - ``` - - .. warning :: - - The quantities computed by backPACK may be garbage collected when - exiting the `with` clause. Use them within the `with` clause or - assign them to another variable. - - Attributes: - args: [BackpropExtension] - The extensions to activate for the backward pass. - extension_hook: Callable, optional (default: None) - Function called on each module after all BackPACK extensions have run. - Takes a ``torch.nn.Module`` and returns ``None``. - - Can be used to reduce memory overhead if the goal is to compute - transformations of BackPACK quantities. Information can be compacted - during a backward pass and obsolete tensors be freed manually (``del``). - - Note: - If the callable iterates over the ``module.parameters()``, the same - parameter may be seen multiple times across calls. This happens - if the parameters are part of multiple modules. - For example, the parameters of a `torch.nn.Linear` module in - ``model = torch.nn.Sequential(torch.nn.Linear(...))`` are part of - both the ``Linear`` and the ``Sequential``. - debug: Bool, optional (default: False) - If true, will print debug messages during the backward pass. - """ for ext in exts: if not isinstance(ext, BackpropExtension): if inspect.isclass(ext) and issubclass(ext, BackpropExtension): @@ -95,7 +72,7 @@ def __exit__(self, type, value, traceback): class disable: - """Entirely disables BackPACK, including storage of input and output. + """Entirely disable BackPACK, including storage of input and output. To compute the additional quantities, BackPACK needs to know the input and output of the modules in the computation graph. It saves those by default. @@ -196,16 +173,18 @@ def run_extension_hook(module): def extend(module: torch.nn.Module, debug=False): - """Extends the ``module`` to make it backPACK-ready. + """Extends a ``module`` to make it BackPACK-ready. If the ``module`` has children, e.g. for a ``torch.nn.Sequential``, they will also be extended. Args: - module: torch.nn.Module - The module to extend - debug: Bool, optional (default: False) - If true, will print debug messages during the extension. + module (torch.nn.Module): The module to extend. + debug (bool, optional): Print debug messages during the extension. + Default: ``False``. + + Returns: + torch.nn.Module: Extended module. """ if debug: print("[DEBUG] Extending", module) diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index b296f6af9..9cf9f223d 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -1,3 +1,4 @@ +"""Partial derivatives for ``torch.nn.ConvTranspose{1,2,3}d``.""" from einops import rearrange from numpy import prod from torch import einsum @@ -17,7 +18,17 @@ class ConvTransposeNDDerivatives(BaseParameterDerivatives): + """Base class for partial derivatives of transpose convolution.""" + def __init__(self, N): + """Store convolution dimension and operations. + + Args: + N (int): Convolution dimension. Must be ``1``, ``2``, or ``3``. + + Raises: + ValueError: If convolution dimension is unsupported. + """ if N == 1: self.module = ConvTranspose1d self.conv_func = conv1d @@ -31,7 +42,7 @@ def __init__(self, N): self.conv_func = conv3d self.conv_transpose_func = conv_transpose3d else: - raise ValueError("{}-dimensional Conv. is not implemented.".format(N)) + raise ValueError(f"ConvTranspose{N}d not supported.") self.conv_dims = N def hessian_is_zero(self): @@ -55,31 +66,25 @@ def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): return jac_mat.expand(*expand_shape) def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - if module.groups != 1: - raise NotImplementedError("Groups greater than 1 are not supported yet") - V = mat.shape[0] G = module.groups C_in = module.input0.shape[1] N = module.output.shape[0] C_out = module.output.shape[1] - mat_reshape = mat.reshape(V, C_in, G, C_out // G, *module.weight.shape[2:]) + mat_reshape = mat.reshape(V, G, C_in // G, C_out // G, *module.weight.shape[2:]) u = unfold_by_conv_transpose(module.input0, module).reshape( - N, C_in // G, G, *module.weight.shape[2:], *module.output.shape[2:] + N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:] ) dims_kern = "xyz"[: self.conv_dims] dims_data = "abc"[: self.conv_dims] - einstr = "nig{0}{1},vigo{0}->vngo{1}".format(dims_kern, dims_data) + einstr = "ngi{0}{1},vgio{0}->vngo{1}".format(dims_kern, dims_data) jac_mat = einsum(einstr, u, mat_reshape) return self.reshape_like_output(jac_mat, module) def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - if module.groups != 1: - raise NotImplementedError("Groups greater than 1 are not supported yet") - V = mat.shape[0] G = module.groups C_in = module.input0.shape[1] @@ -89,13 +94,13 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): mat_reshape = mat.reshape(V, N, G, C_out // G, *module.output.shape[2:]) u = unfold_by_conv_transpose(module.input0, module).reshape( - N, C_in // G, G, *module.weight.shape[2:], *module.output.shape[2:] + N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:] ) dims_kern = "xyz"[: self.conv_dims] dims_data = "abc"[: self.conv_dims] - result_str = ("vigo" if sum_batch else "vnigo") + dims_kern - equation = "nig{0}{1},vngo{1}->{2}".format(dims_kern, dims_data, result_str) + result_str = ("vgio" if sum_batch else "vngio") + dims_kern + equation = f"ngi{dims_kern}{dims_data},vngo{dims_data}->{result_str}" final_shape = ( (V, *module.weight.shape) if sum_batch else (V, N, *module.weight.shape) diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index c48b08d70..e0c5bf8cf 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -121,12 +121,14 @@ def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): return mat.sum(axes) def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - if module.groups != 1: - raise NotImplementedError("Groups greater than 1 are not supported yet") + # separate output channel groups + jac_mat = rearrange(mat, "v (g o) i ... -> v g o (i ...)", g=module.groups) - jac_mat = rearrange(mat, "v o i ... -> v o (i ...)") X = self.get_unfolded_input(module) - jac_mat = einsum("nij,vki->vnkj", X, jac_mat) + # separate input channel groups + X = rearrange(X, "n (g i) j -> n g i j", g=module.groups) + jac_mat = einsum("ngij,vgki->vngkj", X, jac_mat) + return self.reshape_like_output(jac_mat, module) def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py index f06dd9db3..74092e883 100644 --- a/backpack/core/derivatives/elu.py +++ b/backpack/core/derivatives/elu.py @@ -1,19 +1,30 @@ -from torch import exp, gt +"""Partial derivatives for the ELU activation function.""" +from torch import exp, le, ones_like, zeros_like from backpack.core.derivatives.elementwise import ElementwiseDerivatives class ELUDerivatives(ElementwiseDerivatives): + """Implement first- and second-order partial derivatives of ELU.""" + def hessian_is_zero(self): """`ELU''(x) ≠ 0`.""" return False def df(self, module, g_inp, g_out): - """First ELU derivative: `ELU'(x) = alpha * e^x if x < 0 else 1`.""" - df_ELU = gt(module.input0, 0).float() - df_ELU[df_ELU == 0] = module.alpha * exp(module.input0[df_ELU == 0]) - return df_ELU + """First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`.""" + non_pos = le(module.input0, 0) + + result = ones_like(module.input0) + result[non_pos] = module.alpha * exp(module.input0[non_pos]) + + return result def d2f(self, module, g_inp, g_out): - """Second ELU derivative: `ELU''(x) = alpha * e^x if x < 0 else 1`.""" - return self.df(module, g_inp, g_out) + """Second ELU derivative: `ELU''(x) = alpha * e^x if x <= 0 else 0`.""" + non_pos = le(module.input0, 0) + + result = zeros_like(module.input0) + result[non_pos] = module.alpha * exp(module.input0[non_pos]) + + return result diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py index 1e0cf00f2..33c4a9ceb 100644 --- a/backpack/core/derivatives/selu.py +++ b/backpack/core/derivatives/selu.py @@ -1,10 +1,11 @@ -from torch import exp, gt +"""Partial derivatives for the SELU activation function.""" +from torch import exp, le, ones_like, zeros_like from backpack.core.derivatives.elementwise import ElementwiseDerivatives class SELUDerivatives(ElementwiseDerivatives): - """Alpha and scale are not input_kwargs""" + """Implement first- and second-order partial derivatives of SELU.""" alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 @@ -14,21 +15,19 @@ def hessian_is_zero(self): return False def df(self, module, g_inp, g_out): - """First SELU derivative: `SELU'(x) = scale if x < 0 else scale*alpha*e^x`.""" + """First SELU derivative: `SELU'(x) = scale if x > 0 else scale*alpha*e^x`.""" + non_pos = le(module.input0, 0) - df_SELU = gt(module.input0, 0).float() - df_SELU[df_SELU == 1] = self.scale - df_SELU[df_SELU == 0] = ( - self.scale * self.alpha * exp(module.input0[df_SELU == 0]) - ) - return df_SELU + result = self.scale * ones_like(module.input0) + result[non_pos] = self.scale * self.alpha * exp(module.input0[non_pos]) + + return result def d2f(self, module, g_inp, g_out): - """Second SELU derivative: `SELU''(x) = 0 if x < 0 else scale*alpha*e^x`.""" - - d2f_SELU = gt(module.input0, 0).float() - d2f_SELU[d2f_SELU == 1] = 0 - d2f_SELU[d2f_SELU == 0] = ( - self.scale * self.alpha * exp(module.input0[d2f_SELU == 0]) - ) - return d2f_SELU + """Second SELU derivative: `SELU''(x) = 0 if x > 0 else scale*alpha*e^x`.""" + non_pos = le(module.input0, 0) + + result = zeros_like(module.input0) + result[non_pos] = self.scale * self.alpha * exp(module.input0[non_pos]) + + return result diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index 8e66fdbe2..a84a64f71 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -1,6 +1,4 @@ -""" -BackPACK Extensions -""" +"""BackPACK extensions that can be passed into a ``with backpack(...)`` context.""" from .curvmatprod import GGNMP, HMP, PCHMP from .firstorder import BatchGrad, BatchL2Grad, SumGradSquared, Variance @@ -11,6 +9,7 @@ KFRA, BatchDiagGGNExact, BatchDiagGGNMC, + BatchDiagHessian, DiagGGNExact, DiagGGNMC, DiagHessian, @@ -33,4 +32,5 @@ "DiagGGNMC", "BatchDiagGGNMC", "DiagHessian", + "BatchDiagHessian", ] diff --git a/backpack/extensions/firstorder/__init__.py b/backpack/extensions/firstorder/__init__.py index 860ac4bff..61a6e098a 100644 --- a/backpack/extensions/firstorder/__init__.py +++ b/backpack/extensions/firstorder/__init__.py @@ -1,4 +1,4 @@ -"""First order extensions. +"""First order extensions =================================== First-order extensions make it easier to extract information from the gradients @@ -14,9 +14,6 @@ The variance of the individual gradients - :func:`BatchL2Grad ` The L2 norm of the individual gradients - - - """ from .batch_grad import BatchGrad diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 9417d8449..bee9c069a 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -35,7 +35,10 @@ class BatchGrad(BackpropExtension): Stores the output in ``grad_batch`` as a ``[N x ...]`` tensor, where ``N`` batch size and ``...`` is the shape of the gradient. - Note: beware of scaling issue + .. note:: + + Beware of scaling issue + The `individual gradients` depend on the scaling of the overall function. Let ``fᵢ`` be the loss of the ``i`` th sample, with gradient ``gᵢ``. ``BatchGrad`` will return diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 2639285e5..7d2d97ed1 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -34,7 +34,10 @@ class BatchL2Grad(BackpropExtension): Stores the output in ``batch_l2`` as a tensor of size ``[N]``, where ``N`` is the batch size. - Note: beware of scaling issue + .. note:: + + Beware of scaling issue + The individual L2 norm depends on the scaling of the overall function. Let ``fᵢ`` be the loss of the ``i`` th sample, with gradient ``gᵢ``. ``BatchL2Grad`` will return the L2 norm of diff --git a/backpack/extensions/firstorder/batch_l2_grad/convnd.py b/backpack/extensions/firstorder/batch_l2_grad/convnd.py index 0c8462a16..55c542a03 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convnd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convnd.py @@ -9,9 +9,12 @@ def __init__(self, N, params=None): super().__init__(params=params) self.N = N + # TODO Use bias Jacobian to compute `bias_gradient` def bias(self, ext, module, g_inp, g_out, backproped): - C_axis = 1 - return convUtils.get_bias_gradient_factors(g_out[0], C_axis, self.N) + spatial_dims = list(range(2, g_out[0].dim())) + channel_dim = 1 + + return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim) def weight(self, ext, module, g_inp, g_out, backproped): X, dE_dY = convUtils.get_weight_gradient_factors( diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py index a6b0522e6..9ceaa7881 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py @@ -1,7 +1,6 @@ from torch import einsum from backpack.extensions.firstorder.base import FirstOrderModuleExtension -from backpack.utils import conv as convUtils from backpack.utils import conv_transpose as convTransposeUtils @@ -10,9 +9,12 @@ def __init__(self, N, params=None): super().__init__(params=params) self.N = N + # TODO Use bias Jacobian to compute `bias_gradient` def bias(self, ext, module, g_inp, g_out, backproped): - C_axis = 1 - return convUtils.get_bias_gradient_factors(g_out[0], C_axis, self.N) + spatial_dims = list(range(2, g_out[0].dim())) + channel_dim = 1 + + return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim) def weight(self, ext, module, g_inp, g_out, backproped): X, dE_dY = convTransposeUtils.get_weight_gradient_factors( diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index 0e5b0ec4c..97de5c258 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -32,7 +32,10 @@ class SumGradSquared(BackpropExtension): Stores the output in ``sum_grad_squared``. Same dimension as the gradient. - Note: beware of scaling issue + .. note:: + + Beware of scaling issue + The second moment depends on the scaling of the overall function. Let ``fᵢ`` be the loss of the ``i`` th sample, with gradient ``gᵢ``. ``SumGradSquared`` will return the sum of the squared diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index cd69c13a3..5d8a5ebeb 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -32,7 +32,10 @@ class Variance(BackpropExtension): Stores the output in ``variance``. Same dimension as the gradient. - Note: beware of scaling issue + .. note:: + + Beware of scaling issue + The variance depends on the scaling of the overall function. Let ``fᵢ`` be the loss of the ``i`` th sample, with gradient ``gᵢ``. ``Variance`` will return the variance of the vectors diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index 09b31c8d0..afe9ebc2f 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -1,4 +1,5 @@ -"""Second order extensions. +# noqa: D205, D415 +"""Second order extensions ==================================== Second-order extensions propagate additional information through the graph @@ -19,7 +20,7 @@ """ from .diag_ggn import BatchDiagGGNExact, BatchDiagGGNMC, DiagGGNExact, DiagGGNMC -from .diag_hessian import DiagHessian +from .diag_hessian import BatchDiagHessian, DiagHessian from .hbp import HBP, KFAC, KFLR, KFRA __all__ = [ @@ -28,6 +29,7 @@ "DiagGGNMC", "BatchDiagGGNMC", "DiagHessian", + "BatchDiagHessian", "KFAC", "KFLR", "KFRA", diff --git a/backpack/extensions/secondorder/diag_ggn/conv1d.py b/backpack/extensions/secondorder/diag_ggn/conv1d.py index 82ffb7e93..4ce456f58 100644 --- a/backpack/extensions/secondorder/diag_ggn/conv1d.py +++ b/backpack/extensions/secondorder/diag_ggn/conv1d.py @@ -7,17 +7,9 @@ class DiagGGNConv1d(DiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv1DDerivatives(), - N=1, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) class BatchDiagGGNConv1d(BatchDiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv1DDerivatives(), - N=1, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_ggn/conv2d.py b/backpack/extensions/secondorder/diag_ggn/conv2d.py index b8ef5da4c..ed98e0ab8 100644 --- a/backpack/extensions/secondorder/diag_ggn/conv2d.py +++ b/backpack/extensions/secondorder/diag_ggn/conv2d.py @@ -7,17 +7,9 @@ class DiagGGNConv2d(DiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - N=2, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) class BatchDiagGGNConv2d(BatchDiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - N=2, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_ggn/conv3d.py b/backpack/extensions/secondorder/diag_ggn/conv3d.py index d1fe39854..5bda72276 100644 --- a/backpack/extensions/secondorder/diag_ggn/conv3d.py +++ b/backpack/extensions/secondorder/diag_ggn/conv3d.py @@ -7,17 +7,9 @@ class DiagGGNConv3d(DiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv3DDerivatives(), - N=3, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) class BatchDiagGGNConv3d(BatchDiagGGNConvND): def __init__(self): - super().__init__( - derivatives=Conv3DDerivatives(), - N=3, - params=["bias", "weight"], - ) + super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_ggn/convnd.py b/backpack/extensions/secondorder/diag_ggn/convnd.py index 36690a223..3c48f5fe9 100644 --- a/backpack/extensions/secondorder/diag_ggn/convnd.py +++ b/backpack/extensions/secondorder/diag_ggn/convnd.py @@ -3,42 +3,26 @@ class DiagGGNConvND(DiagGGNBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=params) - self.N = N - def bias(self, ext, module, grad_inp, grad_out, backproped): sqrt_ggn = backproped - return convUtils.extract_bias_diagonal(module, sqrt_ggn, self.N, sum_batch=True) + return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=True) def weight(self, ext, module, grad_inp, grad_out, backproped): - if self.N == 2: - X = convUtils.unfold_func(module)(module.input0) - else: - X = convUtils.unfold_by_conv(module.input0, module) + X = convUtils.unfold_input(module, module.input0) weight_diag = convUtils.extract_weight_diagonal( - module, X, backproped, self.N, sum_batch=True + module, X, backproped, sum_batch=True ) return weight_diag class BatchDiagGGNConvND(DiagGGNBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=params) - self.N = N - def bias(self, ext, module, grad_inp, grad_out, backproped): sqrt_ggn = backproped - return convUtils.extract_bias_diagonal( - module, sqrt_ggn, self.N, sum_batch=False - ) + return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=False) def weight(self, ext, module, grad_inp, grad_out, backproped): - if self.N == 2: - X = convUtils.unfold_func(module)(module.input0) - else: - X = convUtils.unfold_by_conv(module.input0, module) + X = convUtils.unfold_input(module, module.input0) weight_diag = convUtils.extract_weight_diagonal( - module, X, backproped, self.N, sum_batch=False + module, X, backproped, sum_batch=False ) return weight_diag diff --git a/backpack/extensions/secondorder/diag_ggn/convtranspose1d.py b/backpack/extensions/secondorder/diag_ggn/convtranspose1d.py index db0ce55cd..05cbfcdc9 100644 --- a/backpack/extensions/secondorder/diag_ggn/convtranspose1d.py +++ b/backpack/extensions/secondorder/diag_ggn/convtranspose1d.py @@ -8,16 +8,12 @@ class DiagGGNConvTranspose1d(DiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose1DDerivatives(), - N=1, - params=["bias", "weight"], + derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] ) class BatchDiagGGNConvTranspose1d(BatchDiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose1DDerivatives(), - N=1, - params=["bias", "weight"], + derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_ggn/convtranspose2d.py b/backpack/extensions/secondorder/diag_ggn/convtranspose2d.py index ffa6d75b6..37f7a6bfe 100644 --- a/backpack/extensions/secondorder/diag_ggn/convtranspose2d.py +++ b/backpack/extensions/secondorder/diag_ggn/convtranspose2d.py @@ -8,16 +8,12 @@ class DiagGGNConvTranspose2d(DiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose2DDerivatives(), - N=2, - params=["bias", "weight"], + derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] ) class BatchDiagGGNConvTranspose2d(BatchDiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose2DDerivatives(), - N=2, - params=["bias", "weight"], + derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_ggn/convtranspose3d.py b/backpack/extensions/secondorder/diag_ggn/convtranspose3d.py index 140d3b419..e9083fd03 100644 --- a/backpack/extensions/secondorder/diag_ggn/convtranspose3d.py +++ b/backpack/extensions/secondorder/diag_ggn/convtranspose3d.py @@ -8,16 +8,12 @@ class DiagGGNConvTranspose3d(DiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose3DDerivatives(), - N=3, - params=["bias", "weight"], + derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] ) class BatchDiagGGNConvTranspose3d(BatchDiagGGNConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose3DDerivatives(), - N=3, - params=["bias", "weight"], + derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_ggn/convtransposend.py b/backpack/extensions/secondorder/diag_ggn/convtransposend.py index 05ee01f33..2e83fae4d 100644 --- a/backpack/extensions/secondorder/diag_ggn/convtransposend.py +++ b/backpack/extensions/secondorder/diag_ggn/convtransposend.py @@ -3,36 +3,26 @@ class DiagGGNConvTransposeND(DiagGGNBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=params) - self.N = N - def bias(self, ext, module, grad_inp, grad_out, backproped): sqrt_ggn = backproped - return convUtils.extract_bias_diagonal(module, sqrt_ggn, self.N, sum_batch=True) + return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=True) def weight(self, ext, module, grad_inp, grad_out, backproped): X = convUtils.unfold_by_conv_transpose(module.input0, module) weight_diag = convUtils.extract_weight_diagonal( - module, X, backproped, self.N, sum_batch=True + module, X, backproped, sum_batch=True ) return weight_diag class BatchDiagGGNConvTransposeND(DiagGGNBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=params) - self.N = N - def bias(self, ext, module, grad_inp, grad_out, backproped): sqrt_ggn = backproped - return convUtils.extract_bias_diagonal( - module, sqrt_ggn, self.N, sum_batch=False - ) + return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=False) def weight(self, ext, module, grad_inp, grad_out, backproped): X = convUtils.unfold_by_conv_transpose(module.input0, module) weight_diag = convUtils.extract_weight_diagonal( - module, X, backproped, self.N, sum_batch=False + module, X, backproped, sum_batch=False ) return weight_diag diff --git a/backpack/extensions/secondorder/diag_hessian/__init__.py b/backpack/extensions/secondorder/diag_hessian/__init__.py index 010fd7a6c..6bb9933d8 100644 --- a/backpack/extensions/secondorder/diag_hessian/__init__.py +++ b/backpack/extensions/secondorder/diag_hessian/__init__.py @@ -1,4 +1,11 @@ +"""Define BackPACK extensions based on the Hessian diagonal. + +- Hessian diagonal +- Per-sample (individual) Hessian diagonal +""" from torch.nn import ( + ELU, + SELU, AvgPool1d, AvgPool2d, AvgPool3d, @@ -44,8 +51,7 @@ class DiagHessian(BackpropExtension): - """ - Diagonal of the Hessian. + """BackPACK extension that computes the Hessian diagonal. Stores the output in :code:`diag_h`, has the same dimensions as the gradient. @@ -56,6 +62,7 @@ class DiagHessian(BackpropExtension): """ def __init__(self): + """Store savefield and mappings between layers and module extensions.""" super().__init__( savefield="diag_h", fail_mode="ERROR", @@ -83,5 +90,54 @@ def __init__(self): Tanh: activations.DiagHTanh(), LeakyReLU: activations.DiagHLeakyReLU(), LogSigmoid: activations.DiagHLogSigmoid(), + ELU: activations.DiagHELU(), + SELU: activations.DiagHSELU(), + }, + ) + + +class BatchDiagHessian(BackpropExtension): + """BackPACK extensions that computes the per-sample (individual) Hessian diagonal. + + Stores the output in ``diag_h_batch`` as a ``[N x ...]`` tensor, + where ``N`` is the batch size and ``...`` is the parameter shape. + + .. warning:: + + Very expensive on networks with non-piecewise linear activations. + + """ + + def __init__(self): + """Store savefield and mappings between layers and module extensions.""" + super().__init__( + savefield="diag_h_batch", + fail_mode="ERROR", + module_exts={ + MSELoss: losses.DiagHMSELoss(), + CrossEntropyLoss: losses.DiagHCrossEntropyLoss(), + Linear: linear.BatchDiagHLinear(), + MaxPool1d: pooling.DiagHMaxPool1d(), + MaxPool2d: pooling.DiagHMaxPool2d(), + AvgPool1d: pooling.DiagHAvgPool1d(), + MaxPool3d: pooling.DiagHMaxPool3d(), + AvgPool2d: pooling.DiagHAvgPool2d(), + AvgPool3d: pooling.DiagHAvgPool3d(), + ZeroPad2d: padding.DiagHZeroPad2d(), + Conv1d: conv1d.BatchDiagHConv1d(), + Conv2d: conv2d.BatchDiagHConv2d(), + Conv3d: conv3d.BatchDiagHConv3d(), + ConvTranspose1d: convtranspose1d.BatchDiagHConvTranspose1d(), + ConvTranspose2d: convtranspose2d.BatchDiagHConvTranspose2d(), + ConvTranspose3d: convtranspose3d.BatchDiagHConvTranspose3d(), + Dropout: dropout.DiagHDropout(), + Flatten: flatten.DiagHFlatten(), + ReLU: activations.DiagHReLU(), + Sigmoid: activations.DiagHSigmoid(), + Tanh: activations.DiagHTanh(), + LeakyReLU: activations.DiagHLeakyReLU(), + LogSigmoid: activations.DiagHLogSigmoid(), + ELU: activations.DiagHELU(), + SELU: activations.DiagHSELU(), }, ) diff --git a/backpack/extensions/secondorder/diag_hessian/activations.py b/backpack/extensions/secondorder/diag_hessian/activations.py index ebd7ee6f6..a47d182e7 100644 --- a/backpack/extensions/secondorder/diag_hessian/activations.py +++ b/backpack/extensions/secondorder/diag_hessian/activations.py @@ -1,6 +1,8 @@ +from backpack.core.derivatives.elu import ELUDerivatives from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives from backpack.core.derivatives.relu import ReLUDerivatives +from backpack.core.derivatives.selu import SELUDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives from backpack.core.derivatives.tanh import TanhDerivatives from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule @@ -29,3 +31,17 @@ def __init__(self): class DiagHLogSigmoid(DiagHBaseModule): def __init__(self): super().__init__(derivatives=LogSigmoidDerivatives()) + + +class DiagHELU(DiagHBaseModule): + """Module extension that computes the Hessian diagonal for ``torch.nn.ELU``.""" + + def __init__(self): + super().__init__(derivatives=ELUDerivatives()) + + +class DiagHSELU(DiagHBaseModule): + """Module extension that computes the Hessian diagonal for ``torch.nn.SELU``.""" + + def __init__(self): + super().__init__(derivatives=SELUDerivatives()) diff --git a/backpack/extensions/secondorder/diag_hessian/conv1d.py b/backpack/extensions/secondorder/diag_hessian/conv1d.py index 2432b5f05..3f9a59e64 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv1d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv1d.py @@ -1,11 +1,22 @@ +"""Module extensions for diagonal Hessian properties of ``torch.nn.Conv1d``.""" from backpack.core.derivatives.conv1d import Conv1DDerivatives -from backpack.extensions.secondorder.diag_hessian.convnd import DiagHConvND +from backpack.extensions.secondorder.diag_hessian.convnd import ( + BatchDiagHConvND, + DiagHConvND, +) class DiagHConv1d(DiagHConvND): + """Module extension for the Hessian diagonal of ``torch.nn.Conv1d``.""" + + def __init__(self): + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) + + +class BatchDiagHConv1d(BatchDiagHConvND): + """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv1d``.""" + def __init__(self): - super().__init__( - derivatives=Conv1DDerivatives(), - N=1, - params=["bias", "weight"], - ) + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_hessian/conv2d.py b/backpack/extensions/secondorder/diag_hessian/conv2d.py index 35a80bb99..fe7d71a75 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv2d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv2d.py @@ -1,11 +1,22 @@ +"""Module extensions for diagonal Hessian properties of ``torch.nn.Conv2d``.""" from backpack.core.derivatives.conv2d import Conv2DDerivatives -from backpack.extensions.secondorder.diag_hessian.convnd import DiagHConvND +from backpack.extensions.secondorder.diag_hessian.convnd import ( + BatchDiagHConvND, + DiagHConvND, +) class DiagHConv2d(DiagHConvND): + """Module extension for the Hessian diagonal of ``torch.nn.Conv2d``.""" + + def __init__(self): + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) + + +class BatchDiagHConv2d(BatchDiagHConvND): + """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv2d``.""" + def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - N=2, - params=["bias", "weight"], - ) + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_hessian/conv3d.py b/backpack/extensions/secondorder/diag_hessian/conv3d.py index 8323b5aec..0b2c9e5eb 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv3d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv3d.py @@ -1,11 +1,22 @@ +"""Module extensions for diagonal Hessian properties of ``torch.nn.Conv3d``.""" from backpack.core.derivatives.conv3d import Conv3DDerivatives -from backpack.extensions.secondorder.diag_hessian.convnd import DiagHConvND +from backpack.extensions.secondorder.diag_hessian.convnd import ( + BatchDiagHConvND, + DiagHConvND, +) class DiagHConv3d(DiagHConvND): + """Module extension for the Hessian diagonal of ``torch.nn.Conv3d``.""" + + def __init__(self): + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) + + +class BatchDiagHConv3d(BatchDiagHConvND): + """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv3d``.""" + def __init__(self): - super().__init__( - derivatives=Conv3DDerivatives(), - N=3, - params=["bias", "weight"], - ) + """Store parameter names and derivatives object.""" + super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/diag_hessian/convnd.py b/backpack/extensions/secondorder/diag_hessian/convnd.py index 75408a54c..2f3544573 100644 --- a/backpack/extensions/secondorder/diag_hessian/convnd.py +++ b/backpack/extensions/secondorder/diag_hessian/convnd.py @@ -6,27 +6,67 @@ class DiagHConvND(DiagHBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=["bias", "weight"]) - self.N = N - def bias(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] h_diag = torch.zeros_like(module.bias) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = convUtils.extract_bias_diagonal(module, h_sqrt, self.N) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + convUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=True), + alpha=sign, + ) + return h_diag def weight(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] - X = convUtils.unfold_by_conv(module.input0, module) + X = convUtils.unfold_input(module, module.input0) h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = convUtils.extract_weight_diagonal(module, X, h_sqrt, self.N) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + convUtils.extract_weight_diagonal(module, X, h_sqrt, sum_batch=True), + alpha=sign, + ) + + return h_diag + + +class BatchDiagHConvND(DiagHBaseModule): + def bias(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + h_diag = torch.zeros( + N, *module.bias.shape, device=module.bias.device, dtype=module.bias.dtype + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + convUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=False), + alpha=sign, + ) + + return h_diag + + def weight(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + X = convUtils.unfold_input(module, module.input0) + h_diag = torch.zeros( + N, + *module.weight.shape, + device=module.weight.device, + dtype=module.weight.dtype, + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + convUtils.extract_weight_diagonal(module, X, h_sqrt, sum_batch=False), + alpha=sign, + ) + return h_diag diff --git a/backpack/extensions/secondorder/diag_hessian/convtranspose1d.py b/backpack/extensions/secondorder/diag_hessian/convtranspose1d.py index cf0816761..c8cf20f43 100644 --- a/backpack/extensions/secondorder/diag_hessian/convtranspose1d.py +++ b/backpack/extensions/secondorder/diag_hessian/convtranspose1d.py @@ -1,5 +1,6 @@ from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives from backpack.extensions.secondorder.diag_hessian.convtransposend import ( + BatchDiagHConvTransposeND, DiagHConvTransposeND, ) @@ -7,7 +8,12 @@ class DiagHConvTranspose1d(DiagHConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose1DDerivatives(), - N=1, - params=["bias", "weight"], + derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] + ) + + +class BatchDiagHConvTranspose1d(BatchDiagHConvTransposeND): + def __init__(self): + super().__init__( + derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_hessian/convtranspose2d.py b/backpack/extensions/secondorder/diag_hessian/convtranspose2d.py index 627a35538..a7ea11c51 100644 --- a/backpack/extensions/secondorder/diag_hessian/convtranspose2d.py +++ b/backpack/extensions/secondorder/diag_hessian/convtranspose2d.py @@ -1,5 +1,6 @@ from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives from backpack.extensions.secondorder.diag_hessian.convtransposend import ( + BatchDiagHConvTransposeND, DiagHConvTransposeND, ) @@ -7,7 +8,12 @@ class DiagHConvTranspose2d(DiagHConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose2DDerivatives(), - N=2, - params=["bias", "weight"], + derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] + ) + + +class BatchDiagHConvTranspose2d(BatchDiagHConvTransposeND): + def __init__(self): + super().__init__( + derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_hessian/convtranspose3d.py b/backpack/extensions/secondorder/diag_hessian/convtranspose3d.py index 142af6319..9b83eba9b 100644 --- a/backpack/extensions/secondorder/diag_hessian/convtranspose3d.py +++ b/backpack/extensions/secondorder/diag_hessian/convtranspose3d.py @@ -1,5 +1,6 @@ from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives from backpack.extensions.secondorder.diag_hessian.convtransposend import ( + BatchDiagHConvTransposeND, DiagHConvTransposeND, ) @@ -7,7 +8,12 @@ class DiagHConvTranspose3d(DiagHConvTransposeND): def __init__(self): super().__init__( - derivatives=ConvTranspose3DDerivatives(), - N=3, - params=["bias", "weight"], + derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] + ) + + +class BatchDiagHConvTranspose3d(BatchDiagHConvTransposeND): + def __init__(self): + super().__init__( + derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_hessian/convtransposend.py b/backpack/extensions/secondorder/diag_hessian/convtransposend.py index 370c8db18..eb7e6147e 100644 --- a/backpack/extensions/secondorder/diag_hessian/convtransposend.py +++ b/backpack/extensions/secondorder/diag_hessian/convtransposend.py @@ -6,18 +6,17 @@ class DiagHConvTransposeND(DiagHBaseModule): - def __init__(self, derivatives, N, params=None): - super().__init__(derivatives=derivatives, params=["bias", "weight"]) - self.N = N - def bias(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] h_diag = torch.zeros_like(module.bias) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = convUtils.extract_bias_diagonal(module, h_sqrt, self.N) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + convUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=True), + alpha=sign, + ) + return h_diag def weight(self, ext, module, g_inp, g_out, backproped): @@ -27,6 +26,47 @@ def weight(self, ext, module, g_inp, g_out, backproped): h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = convUtils.extract_weight_diagonal(module, X, h_sqrt, self.N) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + convUtils.extract_weight_diagonal(module, X, h_sqrt, sum_batch=True), + alpha=sign, + ) + + return h_diag + + +class BatchDiagHConvTransposeND(DiagHBaseModule): + def bias(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + h_diag = torch.zeros( + N, *module.bias.shape, device=module.bias.device, dtype=module.bias.dtype + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + convUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=False), + alpha=sign, + ) + + return h_diag + + def weight(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + X = convUtils.unfold_by_conv_transpose(module.input0, module) + h_diag = torch.zeros( + N, + *module.weight.shape, + device=module.weight.device, + dtype=module.weight.dtype, + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + convUtils.extract_weight_diagonal(module, X, h_sqrt, sum_batch=False), + alpha=sign, + ) + return h_diag diff --git a/backpack/extensions/secondorder/diag_hessian/linear.py b/backpack/extensions/secondorder/diag_hessian/linear.py index 9b6a823eb..be0d8b949 100644 --- a/backpack/extensions/secondorder/diag_hessian/linear.py +++ b/backpack/extensions/secondorder/diag_hessian/linear.py @@ -15,8 +15,11 @@ def bias(self, ext, module, g_inp, g_out, backproped): h_diag = torch.zeros_like(module.bias) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = LinUtils.extract_bias_diagonal(module, h_sqrt) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + LinUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=True), + alpha=sign, + ) + return h_diag def weight(self, ext, module, g_inp, g_out, backproped): @@ -25,6 +28,48 @@ def weight(self, ext, module, g_inp, g_out, backproped): h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = LinUtils.extract_weight_diagonal(module, h_sqrt) - h_diag.add_(sign * h_diag_curr) + h_diag.add_( + LinUtils.extract_weight_diagonal(module, h_sqrt, sum_batch=True), + alpha=sign, + ) + + return h_diag + + +class BatchDiagHLinear(DiagHBaseModule): + def __init__(self): + super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) + + def bias(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + h_diag = torch.zeros( + N, *module.bias.shape, device=module.bias.device, dtype=module.bias.dtype + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + LinUtils.extract_bias_diagonal(module, h_sqrt, sum_batch=False), + alpha=sign, + ) + + return h_diag + + def weight(self, ext, module, g_inp, g_out, backproped): + N = module.input0.shape[0] + sqrt_h_outs = backproped["matrices"] + sqrt_h_outs_signs = backproped["signs"] + h_diag = torch.zeros( + N, + *module.weight.shape, + device=module.weight.device, + dtype=module.weight.dtype, + ) + + for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): + h_diag.add_( + LinUtils.extract_weight_diagonal(module, h_sqrt, sum_batch=False), + alpha=sign, + ) return h_diag diff --git a/backpack/extensions/secondorder/hbp/conv2d.py b/backpack/extensions/secondorder/hbp/conv2d.py index c797b6e87..2def51e66 100644 --- a/backpack/extensions/secondorder/hbp/conv2d.py +++ b/backpack/extensions/secondorder/hbp/conv2d.py @@ -14,6 +14,13 @@ def __init__(self): super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): + + if module.groups != 1: + raise NotImplementedError( + f"groups ≠ 1 is not supported by {ext.__class__.__name__} " + + f"(got {module.groups})." + ) + bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): @@ -34,7 +41,7 @@ def _weight_for_sqrt(self, ext, module, backproped): return kron_factors def _factors_from_input(self, ext, module): - X = convUtils.unfold_func(module)(module.input0) + X = convUtils.unfold_input(module, module.input0) batch = X.size(0) ea_strategy = ext.get_ea_strategy() @@ -72,6 +79,3 @@ def _factor_from_batch_average(self, module, backproped): # sum over spatial coordinates result = backproped.view(out_c, out_pixels, out_c, out_pixels).sum([1, 3]) return result.contiguous() - - -EXTENSIONS = [HBPConv2d()] diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py index 157db3647..14d394f54 100644 --- a/backpack/utils/conv.py +++ b/backpack/utils/conv.py @@ -1,92 +1,109 @@ import torch from einops import rearrange from torch import einsum -from torch.nn import Unfold -from torch.nn.functional import conv1d, conv2d, conv3d +from torch.nn.functional import conv1d, conv2d, conv3d, unfold -def unfold_func(module): - return Unfold( - kernel_size=module.kernel_size, - dilation=module.dilation, - padding=module.padding, - stride=module.stride, - ) +def unfold_input(module, input): + """Return unfolded input to a convolution. + Use PyTorch's ``unfold`` operation for 2d convolutions (4d input tensors), + otherwise fall back to a custom implementation. -def get_weight_gradient_factors(input, grad_out, module, N): - # shape [N, C_in * K_x * K_y, H_out * W_out] - if N == 1: - X = unfold_by_conv(module.input0, module) - elif N == 2: - X = unfold_func(module)(input) - elif N == 3: - X = unfold_by_conv(module.input0, module) + Args: + module (torch.nn.Conv1d or torch.nn.Conv2d or torch.nn.Conv3d): Convolution + module whose hyperparameters are used for the unfold. + input (torch.Tensor): Input to convolution that will be unfolded. + + Returns: + torch.Tensor: Unfolded input. + """ + if input.dim() == 4: + return unfold( + input, + kernel_size=module.kernel_size, + dilation=module.dilation, + padding=module.padding, + stride=module.stride, + ) else: - raise ValueError("{}-dimensional Conv. is not implemented.".format(N)) + return unfold_by_conv(input, module) + +def get_weight_gradient_factors(input, grad_out, module, N): + X = unfold_input(module, input) dE_dY = rearrange(grad_out, "n c ... -> n c (...)") return X, dE_dY -def get_bias_gradient_factors(gradient, C_axis, N): - if N == 1: - bias_gradient = (einsum("ncl->nc", gradient) ** 2).sum(C_axis) - elif N == 2: - bias_gradient = (einsum("nchw->nc", gradient) ** 2).sum(C_axis) - elif N == 3: - bias_gradient = (einsum("ncdhw->nc", gradient) ** 2).sum(C_axis) - else: - raise ValueError("{}-dimensional Conv. is not implemented.".format(N)) - return bias_gradient - - def separate_channels_and_pixels(module, tensor): """Reshape (V, N, C, H, W) into (V, N, C, H * W).""" return rearrange(tensor, "v n c ... -> v n c (...)") -def extract_weight_diagonal(module, input, grad_output, N, sum_batch=True): - """ - input must be the unfolded input to the convolution (see unfold_func) - and grad_output the backpropagated gradient - """ - V_axis, N_axis = 0, 1 - grad_output_viewed = separate_channels_and_pixels(module, grad_output) - AX = einsum("nkl,vnml->vnkm", (input, grad_output_viewed)) - N = AX.shape[N_axis] - sum_dims = [V_axis, N_axis] if sum_batch else [V_axis] - transpose_dims = (V_axis, N_axis) if sum_batch else (V_axis + 1, N_axis + 1) - weight_diagonal = (AX ** 2).sum(sum_dims).transpose(*transpose_dims) - if sum_batch: - return weight_diagonal.view_as(module.weight) - else: - return weight_diagonal.reshape(N, *module.weight.shape) +def extract_weight_diagonal(module, unfolded_input, S, sum_batch=True): + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian. + Args: + module (torch.nn.Conv1d or torch.nn.Conv2d or torch.nn.Conv3d): Convolution + layer for which the diagonal is extracted w.r.t. the weight. + unfolded_input (torch.Tensor): Unfolded input to the convolution. Shape must + follow the conventions of ``torch.nn.Unfold``. + S (torch.Tensor): Backpropagated (symmetric factorization) of the loss Hessian. + Has shape ``(V, *module.output.shape)``. + sum_batch (bool, optional): Sum out the batch dimension of the weight diagonals. + Default value: ``True``. -def extract_bias_diagonal(module, sqrt, N, sum_batch=True): + Returns: + torch.Tensor: Per-sample weight diagonal if ``sum_batch=False`` (shape + ``(N, module.weight.shape)`` with batch size ``N``) or summed weight + diagonal if ``sum_batch=True`` (shape ``module.weight.shape``). """ - `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC) + S = rearrange(S, "v n (g c) ... -> v n g c (...)", g=module.groups) + unfolded_input = rearrange(unfolded_input, "n (g c) k -> n g c k", g=module.groups) + + JS = einsum("ngkl,vngml->vngmk", (unfolded_input, S)) + + sum_dims = [0, 1] if sum_batch else [0] + out_shape = ( + module.weight.shape if sum_batch else (JS.shape[1], *module.weight.shape) + ) + + weight_diagonal = JS.pow_(2).sum(sum_dims).reshape(out_shape) + + return weight_diagonal + + +# TODO This method applies the bias Jacobian, then squares and sums the result. Intro- +# duce base class for {Batch}DiagHessian and DiagGGN{Exact,MC} and remove this method +def extract_bias_diagonal(module, S, sum_batch=True): + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian. + + Args: + module (torch.nn.Conv1d or torch.nn.Conv2d or torch.nn.Conv3d): Convolution + layer for which the diagonal is extracted w.r.t. the bias. + S (torch.Tensor): Backpropagated (symmetric factorization) of the loss Hessian. + Has shape ``(V, *module.output.shape)``. + sum_batch (bool, optional): Sum out the batch dimension of the bias diagonals. + Default value: ``True``. + + Returns: + torch.Tensor: Per-sample bias diagonal if ``sum_batch=False`` (shape + ``(N, module.bias.shape)`` with batch size ``N``) or summed bias + diagonal if ``sum_batch=True`` (shape ``module.bias.shape``). """ - V_axis, N_axis = 0, 1 - - if N == 1: - einsum_eq = "vncl->vnc" - elif N == 2: - einsum_eq = "vnchw->vnc" - elif N == 3: - einsum_eq = "vncdhw->vnc" - else: - ValueError("{}-dimensional Conv. is not implemented.".format(N)) - sum_dims = [V_axis, N_axis] if sum_batch else [V_axis] - return (einsum(einsum_eq, sqrt) ** 2).sum(sum_dims) + start_spatial = 3 + sum_before = list(range(start_spatial, S.dim())) + sum_after = [0, 1] if sum_batch else [0] + + return S.sum(sum_before).pow_(2).sum(sum_after) def unfold_by_conv(input, module): """Return the unfolded input using convolution""" N, C_in = input.shape[0], input.shape[1] kernel_size = module.kernel_size - kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size))) + kernel_size_numel = module.weight.shape[2:].numel() def make_weight(): weight = torch.zeros(kernel_size_numel, 1, *kernel_size) diff --git a/backpack/utils/conv_transpose.py b/backpack/utils/conv_transpose.py index 777445d76..d1183e909 100644 --- a/backpack/utils/conv_transpose.py +++ b/backpack/utils/conv_transpose.py @@ -1,15 +1,16 @@ +"""Utility functions for extracting transpose convolution BackPACK quantities.""" + import torch from einops import rearrange from torch import einsum from torch.nn.functional import conv_transpose1d, conv_transpose2d, conv_transpose3d -from backpack.utils.conv import separate_channels_and_pixels +from backpack.utils.conv import extract_bias_diagonal as conv_extract_bias_diagonal def get_weight_gradient_factors(input, grad_out, module, N): M, C_in = input.shape[0], input.shape[1] - kernel_size = module.kernel_size - kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size))) + kernel_size_numel = module.weight.shape[2:].numel() X = unfold_by_conv_transpose(input, module).reshape(M, C_in * kernel_size_numel, -1) dE_dY = rearrange(grad_out, "n c ... -> n c (...)") @@ -17,59 +18,84 @@ def get_weight_gradient_factors(input, grad_out, module, N): return X, dE_dY -def extract_weight_diagonal(module, input, grad_output, N, sum_batch=True): - """ - input must be the unfolded input to the convolution (see unfold_func) - and grad_output the backpropagated gradient - """ - out_channels = module.weight.shape[0] - in_channels = module.weight.shape[1] - k = module.weight.shape[2:] - - M = input.shape[0] - output_shape = module.output.shape - spatial_out_size = output_shape[2:] - spatial_out_numel = spatial_out_size.numel() - - input_reshaped = input.reshape(M, -1, spatial_out_numel) - - V_axis, N_axis = 0, 1 - grad_output_viewed = separate_channels_and_pixels(module, grad_output) - AX = einsum("nkl,vnml->vnkm", (input_reshaped, grad_output_viewed)) - N = AX.shape[N_axis] - sum_dims = [V_axis, N_axis] if sum_batch else [V_axis] - transpose_dims = (V_axis, N_axis) if sum_batch else (V_axis + 1, N_axis + 1) - weight_diagonal = (AX ** 2).sum(sum_dims).transpose(*transpose_dims) - if sum_batch: - weight_diagonal = weight_diagonal.reshape(in_channels, out_channels, *k) - else: - weight_diagonal = weight_diagonal.reshape(N, in_channels, out_channels, *k) - return weight_diagonal.transpose(*transpose_dims) - - -def extract_bias_diagonal(module, sqrt, N, sum_batch=True): - """ - `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC) +def extract_weight_diagonal(module, unfolded_input, S, sum_batch=True): + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian. + + Args: + module (torch.nn.ConvTranspose1d or torch.nn.ConvTranspose2d or + torch.nn.ConvTranspose3d ): Convolution layer for which the diagonal is + extracted w.r.t. the weight. + unfolded_input (torch.Tensor): Unfolded input to the transpose convolution. + S (torch.Tensor): Backpropagated (symmetric factorization) of the loss Hessian. + Has shape ``(V, *module.output.shape)``. + sum_batch (bool, optional): Sum out the batch dimension of the weight diagonals. + Default value: ``True``. + + Returns: + torch.Tensor: Per-sample weight diagonal if ``sum_batch=False`` (shape + ``(N, module.weight.shape)`` with batch size ``N``) or summed weight + diagonal if ``sum_batch=True`` (shape ``module.weight.shape``). """ - V_axis, N_axis = 0, 1 + S = rearrange(S, "v n (g o) ... -> v n g o (...)", g=module.groups) + unfolded_input = rearrange( + unfolded_input, + "n (g c) (k x) -> n g c k x", + g=module.groups, + k=module.weight.shape[2:].numel(), + ) + + JS = einsum("ngckx,vngox->vngcok", (unfolded_input, S)) - if N == 1: - einsum_eq = "vncl->vnc" - elif N == 2: - einsum_eq = "vnchw->vnc" - elif N == 3: - einsum_eq = "vncdhw->vnc" - else: - ValueError("{}-dimensional ConvTranspose is not implemented.".format(N)) - sum_dims = [V_axis, N_axis] if sum_batch else [V_axis] - return (einsum(einsum_eq, sqrt) ** 2).sum(sum_dims) + sum_dims = [0, 1] if sum_batch else [0] + out_shape = ( + module.weight.shape if sum_batch else (JS.shape[1], *module.weight.shape) + ) + + weight_diagonal = JS.pow_(2).sum(sum_dims).reshape(out_shape) + + return weight_diagonal + + +# TODO This method applies the bias Jacobian, then squares and sums the result. Intro- +# duce base class for {Batch}DiagHessian and DiagGGN{Exact,MC} and remove this method +def extract_bias_diagonal(module, S, sum_batch=True): + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian. + + Args: + module (torch.nn.ConvTranspose1d or torch.nn.ConvTranspose2d or + torch.nn.ConvTranspose3d ): Convolution layer for which the diagonal is + extracted w.r.t. the bias. + unfolded_input (torch.Tensor): Unfolded input to the transpose convolution. + S (torch.Tensor): Backpropagated (symmetric factorization) of the loss Hessian. + Has shape ``(V, *module.output.shape)``. + sum_batch (bool, optional): Sum out the batch dimension of the bias diagonals. + Default value: ``True``. + + Returns: + torch.Tensor: Per-sample bias diagonal if ``sum_batch=False`` (shape + ``(N, module.bias.shape)`` with batch size ``N``) or summed bias + diagonal if ``sum_batch=True`` (shape ``module.bias.shape``). + """ + return conv_extract_bias_diagonal(module, S, sum_batch=sum_batch) def unfold_by_conv_transpose(input, module): - """Return the unfolded input using transpose convolution.""" + """Return the unfolded input using one-hot transpose convolution. + + Args: + input (torch.Tensor): Input to a transpose convolution. + module (torch.nn.ConvTranspose1d or torch.nn.ConvTranspose2d or + torch.nn.ConvTranspose3d): Transpose convolution layer that specifies + the hyperparameters for unfolding. + + Returns: + torch.Tensor: Unfolded input of shape ``(N, C, K * X)`` with + ``K = module.weight.shape[2:].numel()`` the number of kernel elements + and ``X = module.output.shape[2:].numel()`` the number of output pixels. + """ N, C_in = input.shape[0], input.shape[1] kernel_size = module.kernel_size - kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size))) + kernel_size_numel = module.weight.shape[2:].numel() def make_weight(): weight = torch.zeros(1, kernel_size_numel, *kernel_size) @@ -102,4 +128,4 @@ def get_conv_transpose(): groups=C_in, ) - return unfold.reshape(N, -1, kernel_size_numel) + return unfold.reshape(N, C_in, -1) diff --git a/backpack/utils/unsqueeze.py b/backpack/utils/unsqueeze.py index 21a5a16b0..d5782c2ed 100644 --- a/backpack/utils/unsqueeze.py +++ b/backpack/utils/unsqueeze.py @@ -1,44 +1,6 @@ import functools -def jmp_unsqueeze_if_missing_dim(mat_dim): - """Allow Jacobian-matrix routines to do Jacobian-vector products.""" - - def jmp_wrapper(jmp): - @functools.wraps(jmp) - def wrapped_jmp_support_jvp(self, module, g_inp, g_out, mat, **kwargs): - is_vec = len(mat.shape) == mat_dim - 1 - mat_used = mat.unsqueeze(-1) if is_vec else mat - result = jmp(self, module, g_inp, g_out, mat_used, **kwargs) - if is_vec: - return result.squeeze(-1) - else: - return result - - return wrapped_jmp_support_jvp - - return jmp_wrapper - - -def hmp_unsqueeze_if_missing_dim(mat_dim): - """Allow Hessian-matrix routines to do Hessian-vector products.""" - - def hmp_wrapper(hmp): - @functools.wraps(hmp) - def wrapped_hmp_support_hvp(mat): - is_vec = len(mat.shape) == mat_dim - 1 - mat_used = mat.unsqueeze(-1) if is_vec else mat - result = hmp(mat_used) - if is_vec: - return result.squeeze(-1) - else: - return result - - return wrapped_hmp_support_hvp - - return hmp_wrapper - - def kfacmp_unsqueeze_if_missing_dim(mat_dim): """ Allows Kronecker-factored matrix-matrix routines to do matrix-vector products. diff --git a/black.toml b/black.toml index 3c9ec678f..b1ddb483f 100644 --- a/black.toml +++ b/black.toml @@ -9,7 +9,8 @@ exclude = ''' | \.git | \.pytest_cache | \.benchmarks - | docs_src + | docs_src/rtd + | docs_src/rtd_output | docs | build | dist diff --git a/changelog.md b/changelog.md index 66e9fbe5e..6a00293f4 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,107 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.3.0] - 2021-06-16 + +Thanks to [@sbharadwajj](https://github.com/sbharadwajj) +and [@schaefertim](https://github.com/schaefertim) for +co-authoring many PRs shipped in this release. + +### Added +- New extensions + - `BatchDiagGGN{Exact,MC}`: Per sample diagonal of the GGN/Fisher, + exact or with a Monte-Carlo approximation + [[PR1](https://github.com/f-dangel/backpack/pull/135), + [PR2](https://github.com/f-dangel/backpack/pull/139), + [PR3](https://github.com/f-dangel/backpack/pull/170), + [example](https://docs.backpack.pt/en/1.3.0/basic_usage/example_all_in_one.html)] + - `BatchDiagHessian`: Per sample diagonal of the Hessian + [[PR1](https://github.com/f-dangel/backpack/pull/137), + [PR2](https://github.com/f-dangel/backpack/pull/170), + [example](https://docs.backpack.pt/en/1.3.0/basic_usage/example_all_in_one.html)] +- Support for more layers + ([[PR](https://github.com/f-dangel/backpack/pull/171), + [overview](https://docs.backpack.pt/en/1.3.0/supported-layers.html)]) + - `DiagGGN{Exact,MC}` extensions + - `Conv{1,3}d`, `ConvTranspose{1,2,3}d`, `LeakyReLU`, + `LogSigmoid`, `ELU`, `SELU` + [[PR](https://github.com/f-dangel/backpack/pull/113)] + - `MaxPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/125)] + - `AvgPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/128)] + - `DiagHessian` extension + - `Conv{1,3}d`, `ConvTranspose{1,2,3}d`, `LeakyReLU`, + `LogSigmoid` + [[PR](https://github.com/f-dangel/backpack/pull/115)] + - `MaxPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/124)] + - `AvgPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/127)] + - `ELU`, `SELU` + [[PR](https://github.com/f-dangel/backpack/pull/168)] + - `group` argument of (transpose) convolutions + - Full support for first-order diagonal curvature extensions + [[PR1](https://github.com/f-dangel/backpack/pull/151), + [PR2](https://github.com/f-dangel/backpack/pull/161), + [PR3](https://github.com/f-dangel/backpack/pull/162), + [PR4](https://github.com/f-dangel/backpack/pull/163)] + - No support (yet) for `KFAC`, `KFLR` and `KFRA` extensions + [[PR](https://github.com/f-dangel/backpack/pull/167)] +- Extension hook which allows to run code right after a BackPACK extension + [[PR](https://github.com/f-dangel/backpack/pull/120), + [example](https://docs.backpack.pt/en/development/use_cases/example_extension_hook.html)] +- Context to disable BackPACK + [[PR](https://github.com/f-dangel/backpack/pull/119)] +- Tutorial how to extend custom modules + [[PR](https://github.com/f-dangel/backpack/pull/152), + [example](https://docs.backpack.pt/en/development/use_cases/example_custom_module.html)] +- (Experimental) Alternative convolution weight Jacobian with option to save memory + [[PR](https://github.com/f-dangel/backpack/pull/142), + [example](https://docs.backpack.pt/en/development/use_cases/example_save_memory_convolutions.html)] + +### Fixed/Removed +- Remove hooks that save input/output shapes. This probably + resolves [#97](https://github.com/f-dangel/backpack/issues/97) + [[PR](https://github.com/f-dangel/backpack/pull/118)] +- Remove `DiagGGN` from API (use `DiagGGNExact` instead). It was + indented as abstract parent class for `DiagGGNExact` and `DiagGGNMC` + [[PR](https://github.com/f-dangel/backpack/pull/138)] + +### Internal +- CI + - Move tests from Travis to GitHub actions + [[PR](https://github.com/f-dangel/backpack/pull/118), + [small fix](https://github.com/f-dangel/backpack/pull/130)] + - Test `DiagHessian` with new test suite + [[PR](https://github.com/f-dangel/backpack/pull/114)] + - Test `DiagGGN` with new test suite, introduce 'light' and + 'full' tests + [[PR1](https://github.com/f-dangel/backpack/pull/112), + [PR2](https://github.com/f-dangel/backpack/pull/140)] + - Fix `isort` + [[PR](https://github.com/f-dangel/backpack/pull/144)] + - Add partial docstring checks + [[PR](https://github.com/f-dangel/backpack/pull/157)] + - Add docstrings to contribution guide lines + [[commit](https://github.com/f-dangel/backpack/commit/42897ca6dff1a5cd4a4d17d78dc9e309fa3ee178)] + - Auto-format and lint examples + [[PR](https://github.com/f-dangel/backpack/pull/167)] +- Refactoring + - Share code between `Conv{Transpose}{1,2,3}d` in `BatchL2Grad` + [[PR](https://github.com/f-dangel/backpack/pull/111)] + - Use `eingroup` package, remove custom `eingroup` utility + [[PR](https://github.com/f-dangel/backpack/pull/133)] +- Core + - Implement derivatives for `MaxPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/129)] + - Implement derivatives for `AvgPool{1,3}d` + [[PR](https://github.com/f-dangel/backpack/pull/126)] + - Support for `groups` in (transpose) convolutions + [[PR1](https://github.com/f-dangel/backpack/pull/151), + [PR2](https://github.com/f-dangel/backpack/pull/161), + [PR3](https://github.com/f-dangel/backpack/pull/162), + [PR4](https://github.com/f-dangel/backpack/pull/163)] ## [1.2.0] - 2020-10-26 @@ -36,8 +137,8 @@ co-authoring many PRs shipped in this release. [[PR](https://github.com/f-dangel/backpack/pull/98)] and Hessian-free optimization with CG [[PR](https://github.com/f-dangel/backpack/pull/99)] -### Fixed +### Fixed - Add missing `zero_grad` in the diagonal GGN second-order optimization example [[PR](https://github.com/f-dangel/backpack/pull/101)] @@ -67,28 +168,28 @@ co-authoring many PRs shipped in this release. ## [1.1.1] - 2020-04-29 ### Added -- Improved documentation, moved to [ReadTheDocs](https://docs.backpack.pt) - [[PR1](https://github.com/f-dangel/backpack/pull/57), +- Improved documentation, moved to [ReadTheDocs](https://docs.backpack.pt) + [[PR1](https://github.com/f-dangel/backpack/pull/57), [PR2](https://github.com/f-dangel/backpack/pull/58), [PR3](https://github.com/f-dangel/backpack/pull/66)] - Tested compatibility with PyTorch 1.5.0. -- Support 2nd-order backprop for vectors in `MSELoss` +- Support 2nd-order backprop for vectors in `MSELoss` [[PR](https://github.com/f-dangel/backpack/pull/61)] - Sanity checks to raise warnings if the following are used. - `inplace` modification + `inplace` modification [[PR](https://github.com/f-dangel/backpack/pull/59)], - unsupported loss parameters + unsupported loss parameters [[PR](https://github.com/f-dangel/backpack/pull/60)], - custom losses in 2nd-order backpropagation + custom losses in 2nd-order backpropagation [[PR](https://github.com/f-dangel/backpack/pull/60)] ### Fixed -- Removed `opt_einsum` dependency +- Removed `opt_einsum` dependency [[PR](https://github.com/f-dangel/backpack/pull/54)] -- Missing implementations and wrong backpropagation of KFRA - for `Conv2d`, `MaxPool2d`, and `AvgPool2d` +- Missing implementations and wrong backpropagation of KFRA + for `Conv2d`, `MaxPool2d`, and `AvgPool2d` [[PR](https://github.com/f-dangel/backpack/pull/53)] -- Remove `try_view` and use `reshape` to use PyTorch 1.4.0 improvements +- Remove `try_view` and use `reshape` to use PyTorch 1.4.0 improvements [[PR](https://github.com/f-dangel/backpack/pull/50)] ### Internal @@ -98,42 +199,43 @@ co-authoring many PRs shipped in this release. ## [1.1.0] - 2020-02-11 ### Added -- Support MC sampling +- Support MC sampling [[Issue](https://github.com/f-dangel/backpack/issues/21), [PR](https://github.com/f-dangel/backpack/pull/36)] -- Utilities to handle Kronecker factors +- Utilities to handle Kronecker factors [[PR](https://github.com/f-dangel/backpack/pull/17)] -- Examples +- Examples [[PR](https://github.com/f-dangel/backpack/pull/34)] - + ### Fixed -- Fixed documentation issue in `Batch l2` +- Fixed documentation issue in `Batch l2` [[PR](https://github.com/f-dangel/backpack/pull/33)] -- Added support for stride parameter in Conv2d - [[Issue](https://github.com/f-dangel/backpack/issues/30), +- Added support for stride parameter in Conv2d + [[Issue](https://github.com/f-dangel/backpack/issues/30), [PR](https://github.com/f-dangel/backpack/pull/31)] -- Pytorch `1.3.0` compatibility - [[PR](https://github.com/f-dangel/backpack/pull/8), +- Pytorch `1.3.0` compatibility + [[PR](https://github.com/f-dangel/backpack/pull/8), [PR](https://github.com/f-dangel/backpack/pull/9)] - + ### Internal -- Added +- Added continuous integration [[PR](https://github.com/f-dangel/backpack/pull/19)], test coverage [[PR](https://github.com/f-dangel/backpack/pull/25)], style guide enforcement [[PR](https://github.com/f-dangel/backpack/pull/27)] -- Changed internal shape conventions of backpropagated quantities for performance improvements +- Changed internal shape conventions of backpropagated quantities for performance improvements [[PR](https://github.com/f-dangel/backpack/pull/37)] ## [1.0.1] - 2019-09-05 ### Fixed -- Fixed PyPI installaton +- Fixed PyPI installaton -## [1.0.0] - 2019-10-03 +## [1.0.0] - 2019-10-03 Initial release -[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.1.0...HEAD +[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.3.0...HEAD +[1.3.0]: https://github.com/f-dangel/backpack/compare/1.3.0...1.2.0 [1.2.0]: https://github.com/f-dangel/backpack/compare/1.2.0...1.1.1 [1.1.1]: https://github.com/f-dangel/backpack/compare/1.1.0...1.1.1 [1.1.0]: https://github.com/f-dangel/backpack/compare/1.0.1...1.1.0 diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py index bb4e6d2d7..49a8f1e46 100644 --- a/docs_src/examples/basic_usage/example_all_in_one.py +++ b/docs_src/examples/basic_usage/example_all_in_one.py @@ -10,7 +10,7 @@ # %% # Let's start by loading some dummy data and extending the model -from torch import allclose, rand +from torch import rand from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential from backpack import backpack, extend @@ -21,6 +21,9 @@ KFLR, KFRA, PCHMP, + BatchDiagGGNExact, + BatchDiagGGNMC, + BatchDiagHessian, BatchGrad, BatchL2Grad, DiagGGNExact, @@ -111,7 +114,7 @@ # -------------------------- # %% -# Diagonal of the Gauss-Newton and its Monte-Carlo approximation +# Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation loss = lossfunc(model(X), y) with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)): @@ -123,6 +126,19 @@ print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape) print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape) +# %% +# Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation + +loss = lossfunc(model(X), y) +with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".diag_ggn_mc_batch.shape: ", param.diag_ggn_mc_batch.shape) + print(".diag_ggn_exact_batch.shape: ", param.diag_ggn_exact_batch.shape) + + # %% # KFAC, KFRA and KFLR @@ -138,16 +154,17 @@ print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra]) # %% -# Diagonal Hessian +# Diagonal Hessian and per-sample diagonal Hessian loss = lossfunc(model(X), y) -with backpack(DiagHessian()): +with backpack(DiagHessian(), BatchDiagHessian()): loss.backward() for name, param in model.named_parameters(): print(name) print(".grad.shape: ", param.grad.shape) print(".diag_h.shape: ", param.diag_h.shape) + print(".diag_h_batch.shape: ", param.diag_h_batch.shape) # %% # Block-diagonal curvature products diff --git a/docs_src/examples/use_cases/example_cg_newton.py b/docs_src/examples/use_cases/example_cg_newton.py index 799939290..35f872790 100644 --- a/docs_src/examples/use_cases/example_cg_newton.py +++ b/docs_src/examples/use_cases/example_cg_newton.py @@ -118,14 +118,14 @@ def get_accuracy(output, targets): class CGNOptimizer(torch.optim.Optimizer): def __init__( - self, - parameters, - bp_extension, - lr=0.1, - damping=1e-2, - maxiter=100, - tol=1e-1, - atol=1e-8, + self, + parameters, + bp_extension, + lr=0.1, + damping=1e-2, + maxiter=100, + tol=1e-1, + atol=1e-8, ): super().__init__( parameters, diff --git a/docs_src/examples/use_cases/example_custom_module.py b/docs_src/examples/use_cases/example_custom_module.py index e027d3608..e59b10c8d 100644 --- a/docs_src/examples/use_cases/example_custom_module.py +++ b/docs_src/examples/use_cases/example_custom_module.py @@ -3,10 +3,11 @@ This tutorial shows how to support a custom module in a simple fashion. We focus on `BackPACK's first-order extensions `_. -They don't backpropagate additional information and thus require less functionality be implemented. +They don't backpropagate additional information and thus require less functionality be +implemented. Let's get the imports out of our way. -""" +""" # noqa: B950 import torch diff --git a/docs_src/examples/use_cases/example_first_order_resnet.py b/docs_src/examples/use_cases/example_first_order_resnet.py index 0c25633b1..df3ded93d 100644 --- a/docs_src/examples/use_cases/example_first_order_resnet.py +++ b/docs_src/examples/use_cases/example_first_order_resnet.py @@ -7,11 +7,11 @@ # Let's get the imports, configuration and some helper functions out of the way first. import torch +import torch.nn.functional as F from backpack import backpack, extend from backpack.extensions import BatchGrad from backpack.utils.examples import load_one_batch_mnist -import torch.nn.functional as F BATCH_SIZE = 3 torch.manual_seed(0) @@ -95,7 +95,7 @@ def forward(self, x): loss = F.cross_entropy(model(x_to_check), y_to_check) loss.backward() -print("Do the individual gradient match?") -for param_id, (name, p) in enumerate(model.named_parameters()): +print("Do the individual gradients match?") +for name, p in model.named_parameters(): match = torch.allclose(p.grad_batch[sample_to_check, :], p.grad, atol=1e-7) print("{:<20} {}".format(name, match)) diff --git a/docs_src/examples/use_cases/example_gradient_of_variance.py b/docs_src/examples/use_cases/example_gradient_of_variance.py index 633dfc96d..782b03854 100644 --- a/docs_src/examples/use_cases/example_gradient_of_variance.py +++ b/docs_src/examples/use_cases/example_gradient_of_variance.py @@ -25,11 +25,11 @@ # Let's get the imports and configuration out of the way. import torch +from torch import nn from backpack import backpack, extend from backpack.extensions import Variance from backpack.utils.examples import load_one_batch_mnist -from torch import nn torch.manual_seed(0) DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -105,7 +105,7 @@ def individual_gradients_pytorch(x, y, model, lossfunc): grad_list_format = torch.autograd.grad( loss, model.parameters(), create_graph=True, retain_graph=True ) - grad_vector_format = torch.cat([g.view(-1,) for g in grad_list_format]) + grad_vector_format = torch.cat([g.view(-1) for g in grad_list_format]) grads_vector_format.append(grad_vector_format.clone()) return torch.stack(grads_vector_format) diff --git a/docs_src/examples/use_cases/example_save_memory_convolutions.py b/docs_src/examples/use_cases/example_save_memory_convolutions.py new file mode 100644 index 000000000..54ae04518 --- /dev/null +++ b/docs_src/examples/use_cases/example_save_memory_convolutions.py @@ -0,0 +1,238 @@ +"""Saving memory in convolutions +================================ + +There are different approaches to apply the Jacobian with respect to the kernel +of a convolution. They exhibit a non-trivial trade-off between run time and memory +consumption (see more details below). The default choice in BackPACK is a memory- +intensive implementation. This can lead to out-of-memory errors. + +Here, we show how to switch BackPACK's vector-Jacobian product algorithm for the kernel +(``weight``) of :py:class:`torch.nn.Conv2d` modules to a memory-saving variant +presented in `[Rochette, 2019] `_. + +This can be helpful if you are experiencing memory overflows with CNNs. + +.. note :: + This feature is experimental and may change in future releases. + +.. note :: + Currently, the savings only affect BackPACK's first-order extensions. + This may change in future releases. + +""" + +# %% +# Let's get the imports out of our way. + +import time + +import torch +from memory_profiler import memory_usage +from torch.nn import ( + Conv2d, + CrossEntropyLoss, + Flatten, + Linear, + MaxPool2d, + ReLU, + Sequential, +) + +from backpack import backpack, extend, extensions +from backpack.core.derivatives.convnd import weight_jac_t_save_memory +from backpack.utils.examples import load_one_batch_mnist + +# %% +# We start with the utility function for setting up an extended CNN, loss function, and +# input data from MNIST. + + +def setup(device): + """Load MNIST batch, create extended CNN and loss function. Load to device. + + Args: + device (torch.device): Device that all objects are transferred to. + + Returns: + inputs, labels, model, loss function + """ + X, y = load_one_batch_mnist(batch_size=64) + X, y = X.to(device), y.to(device) + + model = extend( + Sequential( + Conv2d(1, 128, 3, padding=1), + ReLU(), + MaxPool2d(3, stride=2), + Conv2d(128, 256, 3, padding=1), + ReLU(), + MaxPool2d(3, padding=1, stride=2), + Conv2d(256, 64, 3, padding=1), + ReLU(), + MaxPool2d(3, stride=2), + Conv2d(64, 32, 3, padding=1), + ReLU(), + MaxPool2d(3, stride=2), + Flatten(), + Linear(32, 10), + ).to(device) + ) + + lossfunc = extend(CrossEntropyLoss().to(device)) + + return X, y, model, lossfunc + + +# %% +# Let's demonstrate the differences between the vector-Jacobian methods. we benchmark +# the following function that computes individual gradients on the specified setup +# using BackPACK's :py:class:`BatchGrad ` extensions. + + +def compute_individual_gradients(device, seed=0): + """Compute individual gradients for the seeded problem specified in ``setup``. + + Args: + device (torch.device): Device that the computation should be performed on. + seed (int): Random seed to set before setting up the problem. + + Returns: + Dictionary with parameter name and individual gradients as key value pairs. + """ + torch.manual_seed(seed) + + X, y, model, lossfunc = setup(device) + + loss = lossfunc(model(X), y) + + with backpack(extensions.BatchGrad()): + loss.backward() + + return {name: param.grad_batch for name, param in model.named_parameters()} + + +# %% +# The memory-saving strategy is enabled by wrapping the backward pass with BackPACK +# inside :py:class:`weight_jac_t_save_memory` which accepts +# a boolean flag ``save_memory``. + +# %% +# Peak memory comparison +# ---------------------- +# Let's see the differences between both vector-Jacobian methods in terms of peak +# memory consumption. + + +def compare_peakmem(device): + """Print peak memory of different vector-Jacobian algorithms for convolution. + + Peak memory only makes sense when ``device`` is CPU as memory usage on GPU + cannot be tracked by this implementation. + + Args: + device (torch.device): Device that the computation should be performed on. + """ + print(f"Device: {device}") + + for save_memory in True, False: + + with weight_jac_t_save_memory(save_memory=save_memory): + + def work(): + return compute_individual_gradients(device) + + interval = 1e-3 + peakmem = max(memory_usage(work, interval=interval)) + + print(f"Save memory: {save_memory}\tPeak memory: {peakmem:.1f}") + + +compare_peakmem(torch.device("cpu")) + +# %% +# As expected, the backpropagation with ``save_memory=True`` requires less RAM. + + +# %% +# Run time comparison +# ------------------- +# Next, we inspect the run time of both strategies. + + +def compare_runtime(device): + """Print run time of different vector-Jacobian algorithms for convolution. + + Args: + device (torch.device): Device that the computation should be performed on. + """ + print(f"Device: {device}") + + for save_memory in True, False: + + with weight_jac_t_save_memory(save_memory=save_memory): + start = time.time() + + compute_individual_gradients(device) + + if str(device) == "cuda": + torch.cuda.synchronize() + + run_time = time.time() - start + + print(f"Save memory: {save_memory}\tRun time: {run_time:.3f}") + + +compare_runtime(torch.device("cpu")) + +# %% +# In this case, saving memory comes at the cost of reduced run time performance. +# +# If you have a GPU you will see a similar behavior, too: + +if torch.cuda.is_available(): + compare_runtime(torch.device("cuda")) + +# %% +# Let's quickly double-check that both algorithms computed the same result. +device = torch.device("cpu") + +with weight_jac_t_save_memory(save_memory=True): + individual_gradients = compute_individual_gradients(device) + +with weight_jac_t_save_memory(save_memory=False): + save_memory_individual_gradients = compute_individual_gradients(device) + +print(f"{'Parameter':<10}| Same individual gradients?") +for param_name in individual_gradients.keys(): + same = torch.allclose( + individual_gradients[param_name], + save_memory_individual_gradients[param_name], + atol=1e-7, + ) + msg = f"{param_name:<10}| {same}" + + if same: + print(msg) + else: + raise ValueError(msg) + +# %% +# When to enable save memory? +# --------------------------- +# If your program crashes because BackPACK tries to allocate too much memory, you +# should give it a try. Other than that, it is difficult to identify tendencies. +# The trend observed in this example (saving memory means slower run time) does not +# hold true in general, and you may want to compare both approaches for your specific +# setting, like we did here. +# +# You can also take a look at +# `backpack-benchmark `_, +# where BackPACK's run time and peak memory are continuously monitored for some neural +# nets from `DeepOBS `_. +# +# This benchmark can be inspected over the commit history. Commits between +# `567f079b `_ +# and +# `f72f666 `_ +# were performed with ``save_memory=True``. Compare them with any other commit +# benchmarked with ``save_memory=False`` to get an intuition how both algorithms differ. diff --git a/docs_src/examples/use_cases/example_trace_estimation.py b/docs_src/examples/use_cases/example_trace_estimation.py index 618b35d1c..86c24ae78 100644 --- a/docs_src/examples/use_cases/example_trace_estimation.py +++ b/docs_src/examples/use_cases/example_trace_estimation.py @@ -1,8 +1,10 @@ r"""Hutchinson Trace Estimation =============================== -This example illustrates the estimation the Hessian trace of a neural network using Hutchinson's `[Hutchinson, 1990] `_ -method, which is an algorithm to obtain such an an estimate from matrix-vector products: +This example illustrates the estimation the Hessian trace of a neural network using +Hutchinson's method +`[Hutchinson, 1990] `_, +which is an algorithm to obtain such an an estimate from matrix-vector products: .. math:: \text{Let } A \in \mathbb{R}^{D \times D} \text{ and } v \in \mathbb{R}^D @@ -11,9 +13,12 @@ .. math:: \mathrm{Tr}(A) = \mathbb{E}[v^TAv] = \frac{1}{V}\sum_{i=1}^{V}v_i^TAv_i. -We will draw v from a Rademacher Distribution and use Hessian-free multiplication. This can be done with plain autodiff, but -note that there is no dependency between sampled vectors, and the Hessian-vector product (HVP) could in principle be performed in parallel. We can use BackPACK's :code:`HMP` (Hessian-matrix product) extension to do so, and investigate the potential speedup. -""" +We will draw v from a Rademacher Distribution and use Hessian-free multiplication. This +can be done with plain autodiff, but note that there is no dependency between sampled +vectors, and the Hessian-vector product (HVP) could in principle be performed in +parallel. We can use BackPACK's :code:`HMP` (Hessian-matrix product) extension to do so, +and investigate the potential speedup. +""" # noqa: B950 # %% # Let's get the imports and define what a Rademacher distribution is @@ -44,7 +49,8 @@ def rademacher(shape, dtype=torch.float32, device=DEVICE): # Creating the model and loss # --------------------------- # -# We will use a small NN with 2 linear layers without bias (for a bias of size `d`, the exact trace can be obtained in `d` HVPs). +# We will use a small NN with 2 linear layers without bias (for a bias of size `d`, the +# exact trace can be obtained in `d` HVPs). model = torch.nn.Sequential( torch.nn.Flatten(), @@ -58,7 +64,9 @@ def rademacher(shape, dtype=torch.float32, device=DEVICE): loss_function = extend(loss_function) # %% -# In the following, we load a batch from MNIST, compute the loss and trigger the backward pass ``with(backpack(..))`` such that we have access to the extensions that we are going to use (``DiagHessian`` and ``HMP)``). +# In the following, we load a batch from MNIST, compute the loss and trigger the +# backward pass ``with(backpack(..))`` such that we have access to the extensions that +# we are going to use (``DiagHessian`` and ``HMP)``). x, y = load_one_batch_mnist(BATCH_SIZE) x, y = x.to(DEVICE), y.to(DEVICE) @@ -80,7 +88,8 @@ def forward_backward_with_backpack(): # %% # Exact trace computation # ----------------------- -# To make sure our implementation is fine, and to develop a feeling for the Hutchinson estimator quality, let's compute the exact trace by summing up the Hessian diagonal. +# To make sure our implementation is fine, and to develop a feeling for the Hutchinson +# estimator quality, let's compute the exact trace by summing up the Hessian diagonal. def exact_trace(): @@ -94,7 +103,11 @@ def exact_trace(): # %% # Trace estimation (BackPACK's :code:`HMP`) # ----------------------------------------- -# BackPACK's :code:`HMP` extension gives access to multiplication with the parameter Hessian, which is one diagonal block in the full Hessian whose trace we want to estimate. The multiplication can even handle multiple vectors at a time. Here is the implementation. The computation of :code:`V` HVPs, which might exceed our available memory, is chunked into batches of size :code:`V_batch`. +# BackPACK's :code:`HMP` extension gives access to multiplication with the parameter +# Hessian, which is one diagonal block in the full Hessian whose trace we want to +# estimate. The multiplication can even handle multiple vectors at a time. Here is the +# implementation. The computation of :code:`V` HVPs, which might exceed our available +# memory, is chunked into batches of size :code:`V_batch`. def hutchinson_trace_hmp(V, V_batch=1): @@ -129,7 +142,9 @@ def hutchinson_trace_hmp(V, V_batch=1): # %% # Trace estimation (autodiff, full Hessian) # ----------------------------------------- -# We can also use autodiff tricks to compute a single HVP at a time, provided by utility function :code:`hessian_vector_product` in BackPACK. Here is the implementation, and a test: +# We can also use autodiff tricks to compute a single HVP at a time, provided by utility +# function :code:`hessian_vector_product` in BackPACK. Here is the implementation, and a +# test: def hutchinson_trace_autodiff(V): @@ -154,7 +169,9 @@ def hutchinson_trace_autodiff(V): # %% # Trace estimation (autodiff, block-diagonal Hessian) # --------------------------------------------------- -# Since :code:`HMP` uses only the Hessian block-diagonal and not the full Hessian, here is the corresponding autodiff implementation using the same matrix as :code:`HMP`. We are going to reinvestigate it for benchmarking. +# Since :code:`HMP` uses only the Hessian block-diagonal and not the full Hessian, here +# is the corresponding autodiff implementation using the same matrix as :code:`HMP`. We +# are going to reinvestigate it for benchmarking. def hutchinson_trace_autodiff_blockwise(V): @@ -184,7 +201,9 @@ def hutchinson_trace_autodiff_blockwise(V): # %% # Trace approximation accuracy # ---------------------------- -# Next, let's observe how the approximation improves with the number of samples. Here, we plot multiple runs of the Hutchinson trace estimate, initialized at different random seeds. +# Next, let's observe how the approximation improves with the number of samples. Here, +# we plot multiple runs of the Hutchinson trace estimate, initialized at different +# random seeds. V_steps = 30 V_list = torch.logspace(1, 3, steps=V_steps).int() @@ -215,10 +234,13 @@ def hutchinson_trace_autodiff_blockwise(V): _ = plt.legend() -#%% +# %% # Runtime comparison # ------------------ -# Finally, we investigate if the trace estimation is sped up by vectorizing the HVPs. In particular, let's compare the estimations using autodiff HVPs (no parallelization), autodiff block-diagonal HVPs (no parallelization) and block-diagonal vectorized HVPs (:code:`HMP`). +# Finally, we investigate if the trace estimation is sped up by vectorizing the HVPs. +# In particular, let's compare the estimations using autodiff HVPs (no parallelization), +# autodiff block-diagonal HVPs (no parallelization) and block-diagonal vectorized HVPs +# (:code:`HMP`). V = 1000 @@ -269,7 +291,10 @@ def time_hutchinson_trace_hmp(V, V_batch): time_hutchinson_trace_hmp(V, V_batch=20) # %% -# Looks like the parallelized Hessian-vector products are able to speed up the computation. Nice. +# Looks like the parallel Hessian-vector products are able to speed up the +# computation. Nice. # %% -# Note that instead of the Hessian, we could have also used other interesting matrices, such as the generalized Gauss-Newton. BackPACK also offers a vectorized multiplication with the latter's block-diagonal (see the :code:`GGNMP` extension). +# Note that instead of the Hessian, we could have also used other interesting matrices, +# such as the generalized Gauss-Newton. BackPACK also offers a vectorized multiplication +# with the latter's block-diagonal (see the :code:`GGNMP` extension). diff --git a/docs_src/rtd/conf.py b/docs_src/rtd/conf.py index f7e0f5365..a95029434 100644 --- a/docs_src/rtd/conf.py +++ b/docs_src/rtd/conf.py @@ -30,6 +30,7 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + "sphinx.ext.napoleon", "sphinx.ext.autodoc", "sphinx.ext.coverage", "sphinx.ext.autosectionlabel", diff --git a/docs_src/rtd/extensions.rst b/docs_src/rtd/extensions.rst index 1e92cba67..2fa9e2369 100644 --- a/docs_src/rtd/extensions.rst +++ b/docs_src/rtd/extensions.rst @@ -18,10 +18,13 @@ Available Extensions .. autofunction:: backpack.extensions.DiagGGNMC .. autofunction:: backpack.extensions.DiagGGNExact +.. autofunction:: backpack.extensions.BatchDiagGGNMC +.. autofunction:: backpack.extensions.BatchDiagGGNExact .. autofunction:: backpack.extensions.KFAC .. autofunction:: backpack.extensions.KFLR .. autofunction:: backpack.extensions.KFRA .. autofunction:: backpack.extensions.DiagHessian +.. autofunction:: backpack.extensions.BatchDiagHessian ----- @@ -32,4 +35,3 @@ Available Extensions .. autofunction:: backpack.extensions.HMP .. autofunction:: backpack.extensions.GGNMP .. autofunction:: backpack.extensions.PCHMP - diff --git a/docs_src/rtd/good-to-know.rst b/docs_src/rtd/good-to-know.rst index 1463f9674..23911ea9e 100644 --- a/docs_src/rtd/good-to-know.rst +++ b/docs_src/rtd/good-to-know.rst @@ -81,14 +81,9 @@ We're working on how to handle those, as well as adding more Along those lines, some things that will (most likely) not work with BackPACK, but that we're trying to build support for: -- Inplace operations (e.g., using ``inplace=True`` for activation functions like - :py:class:`torch.nn.ReLU`. - Reusing the same parameters or module multiple time in the computation graph. For second order extensions, this also holds for any module, whether or not they have parameters. This sadly mean that BackPACK can't compute the individual gradients or second-order information of a L2-regularized loss, for example. - - - diff --git a/docs_src/rtd/main-api.rst b/docs_src/rtd/main-api.rst index 34ed51edf..67a4d0c52 100644 --- a/docs_src/rtd/main-api.rst +++ b/docs_src/rtd/main-api.rst @@ -67,4 +67,4 @@ and the :ref:`Supported models`. .. autofunction:: backpack.extend .. autofunction:: backpack.backpack - +.. autofunction:: backpack.disable diff --git a/docs_src/rtd/supported-layers.rst b/docs_src/rtd/supported-layers.rst index 624f0dcb6..149bb7dc2 100644 --- a/docs_src/rtd/supported-layers.rst +++ b/docs_src/rtd/supported-layers.rst @@ -1,10 +1,10 @@ Supported models ==================================== -BackPACK expects models to be -`sequences `_ +BackPACK expects models to be +`sequences `_ of `PyTorch NN modules `_. -For example, +For example, .. code-block:: python @@ -21,8 +21,8 @@ This page lists the layers currently supported by BackPACK. If the forward is not standard, the additional backward pass to compute second-order quantities will not match the actual function. First-order extensions that extract information might work outside of this framework, but it is not tested. -.. raw:: html - +.. raw:: html +
For first-order extensions @@ -53,7 +53,13 @@ BackPACK needs to know how to propagate second-order information. This is implemented for: +-------------------------------+---------------------------------------+ -| **Parametrized layers** | :py:class:`torch.nn.Conv2d` | +| **Parametrized layers** | :py:class:`torch.nn.Conv1d`, | +| | :py:class:`torch.nn.Conv2d`, | +| | :py:class:`torch.nn.Conv3d` | +| +---------------------------------------+ +| | :py:class:`torch.nn.ConvTranspose1d`, | +| | :py:class:`torch.nn.ConvTranspose2d`, | +| | :py:class:`torch.nn.ConvTranspose3d` | | +---------------------------------------+ | | :py:class:`torch.nn.Linear` | +-------------------------------+---------------------------------------+ @@ -61,15 +67,26 @@ This is implemented for: | +---------------------------------------+ | | :py:class:`torch.nn.CrossEntropyLoss` | +-------------------------------+---------------------------------------+ -| **Layers without parameters** | :py:class:`torch.nn.MaxPool2d` | -| | :py:class:`torch.nn.AvgPool2d` | +| **Layers without parameters** | :py:class:`torch.nn.MaxPool1d`, | +| | :py:class:`torch.nn.MaxPool2d`, | +| | :py:class:`torch.nn.MaxPool3d` | +| +---------------------------------------+ +| | :py:class:`torch.nn.AvgPool1d`, | +| | :py:class:`torch.nn.AvgPool2d`, | +| | :py:class:`torch.nn.AvgPool3d` | +| +---------------------------------------+ +| | :py:class:`torch.nn.ZeroPad2d`, | | +---------------------------------------+ | | :py:class:`torch.nn.Dropout` | | +---------------------------------------+ -| | :py:class:`torch.nn.ReLU` | -| | :py:class:`torch.nn.Sigmoid` | -| | :py:class:`torch.nn.Tanh` | +| | :py:class:`torch.nn.ReLU`, | +| | :py:class:`torch.nn.Sigmoid`, | +| | :py:class:`torch.nn.Tanh`, | +| | :py:class:`torch.nn.LeakyReLU`, | +| | :py:class:`torch.nn.LogSigmoid`, | +| | :py:class:`torch.nn.ELU`, | +| | :py:class:`torch.nn.SELU` | +-------------------------------+---------------------------------------+ -The other convolution layers (``Conv1d``, ``Conv3d``, and ``ConvTransposeNd``) -are not yet supported. +Some exotic hyperparameters are not fully supported, but feature requests +on the repository are welcome. diff --git a/fully_documented.txt b/fully_documented.txt index bda381452..480a24b3e 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -6,6 +6,7 @@ backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py backpack/core/derivatives/permute.py +backpack/extensions/__init__.py backpack/extensions/backprop_extension.py backpack/extensions/mat_to_mat_jac_base.py @@ -24,10 +25,15 @@ backpack/extensions/firstorder/sum_grad_squared/__init__.py backpack/extensions/firstorder/batch_l2_grad/rnn.py backpack/extensions/firstorder/batch_l2_grad/__init__.py +backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_ggn/__init__.py backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py backpack/extensions/secondorder/diag_ggn/rnn.py backpack/extensions/secondorder/diag_ggn/permute.py +backpack/extensions/secondorder/diag_hessian/__init__.py +backpack/extensions/secondorder/diag_hessian/conv1d.py +backpack/extensions/secondorder/diag_hessian/conv2d.py +backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/custom_module/ @@ -46,3 +52,4 @@ test/extensions/firstorder/batch_grad/batchgrad_settings.py test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/diag_ggn/ +test/extensions/secondorder/hbp diff --git a/makefile b/makefile index 9823af519..fbcd353f1 100644 --- a/makefile +++ b/makefile @@ -42,13 +42,13 @@ help: @echo "isort-check" @echo " Check if isort (sort imports) would change files" @echo "install-dev" - @echo " Install all development tools" + @echo " Install backpack and all development tools" @echo "install-lint" - @echo " Install only the linter tools (included in install-dev)" + @echo " Install backpack and linter tools (included in install-dev)" @echo "install-test" - @echo " Install only the testing tools (included in install-dev)" + @echo " Install backpack and testing tools (included in install-dev)" @echo "install-docs" - @echo " Install only the tools to build/view the docs (included in install-dev)" + @echo " Install backpack and tools to build/view the docs (included in install-dev)" @echo "conda-env" @echo " Create conda environment 'backpack' with dev setup" @echo "build-docs" @@ -114,31 +114,19 @@ format-check-partial: black-check isort-check flake8 pydocstyle-check-partial da # Installation install: - @pip install -r requirements.txt - @pip install . + @pip install -e . install-lint: - @pip install -r requirements/lint.txt + @pip install -e .[lint] install-test: - @pip install -r requirements/test.txt + @pip install -e .[test] + @pip install git+https://git@github.com/Stonesjtu/pytorch_memlab.git@6ab5fab#egg=pytorch_memlab install-docs: - @pip install -r requirements/docs.txt - -install-devtools: - @echo "Install dev tools..." - @pip install -r requirements-dev.txt - -install-dev: install-devtools - @echo "Install dependencies..." - @pip install -r requirements.txt - @echo "Uninstall existing version of backpack..." - @pip uninstall backpack-for-pytorch - @echo "Install backpack in editable mode..." - @pip install -e . - @echo "Install pre-commit hooks..." - @pre-commit install + @pip install -e .[docs] + +install-dev: install-lint install-test install-docs ### # Conda environment diff --git a/pytest.ini b/pytest.ini index 478d7ace8..3160620ef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ +# NOTE The documentation recommends to **not** configure pytest with setup.cfg +# (https://docs.pytest.org/en/6.2.x/customize.html#setup-cfg) [pytest] optional_tests: montecarlo: slow tests using low-precision allclose after Monte-Carlo sampling diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 4d01b95cc..000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,4 +0,0 @@ --r requirements/test.txt --r requirements/lint.txt --r requirements/docs.txt -pre-commit diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ceb807d40..000000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -torch >= 1.6.0, < 2.0.0 -torchvision >= 0.7.0, < 1.0.0 -einops >= 0.3.0, < 1.0.0 diff --git a/requirements/docs.txt b/requirements/docs.txt deleted file mode 100644 index faf1e94b6..000000000 --- a/requirements/docs.txt +++ /dev/null @@ -1,2 +0,0 @@ -matplotlib -sphinx-gallery diff --git a/requirements/lint.txt b/requirements/lint.txt deleted file mode 100644 index 84c4cca44..000000000 --- a/requirements/lint.txt +++ /dev/null @@ -1,12 +0,0 @@ -darglint -flake8 -mccabe -pycodestyle -pydocstyle -pyflakes -pep8-naming -flake8-bugbear -flake8-comprehensions -flake8-tidy-imports -black -isort \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt deleted file mode 100644 index 17b2e1b5b..000000000 --- a/requirements/test.txt +++ /dev/null @@ -1,7 +0,0 @@ -scipy -pytest >= 4.5.0, < 5.0.0 -pytest-benchmark >= 3.2.2, < 4.0.0 -pytest-optional-tests >= 0.1.1 -pytest-cov -coveralls -git+git://github.com/Stonesjtu/pytorch_memlab.git@6ab5fab#egg=pytorch_memlab diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..bf2aa19f0 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,140 @@ +# This file is used to configure your project. +# Read more about the various options under: +# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files + +############################################################################### +# Main library # +############################################################################### + +[metadata] +name = backpack-for-pytorch +author = Felix Dangel, Frederik Kunstner +url = https://github.com/f-dangel/backpack +description = BackPACK: Packing more into backprop +long_description = file: README.md +long_description_content_type = text/markdown; charset=UTF-8; variant=GFM +license = MIT +# Change if running only on Windows, Mac or Linux (comma-separated) +platforms = any +# Add all kinds of additional classifiers as defined under +# https://pypi.python.org/pypi?%3Aaction=list_classifiers +classifiers = + Development Status :: 4 - Beta + License :: OSI Approved :: MIT License + Operating System :: OS Independent + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + +[options] +zip_safe = False +packages = find: +include_package_data = True +setup_requires = + setuptools_scm +# Dependencies of the project (semicolon/line-separated): +install_requires = + torch >= 1.6.0, < 2.0.0 + torchvision >= 0.7.0, < 1.0.0 + einops >= 0.3.0, < 1.0.0 +# Require a specific Python version, e.g. Python 2.7 or >= 3.4 +python_requires = >=3.6 + +[options.packages.find] +exclude = test* + +############################################################################### +# Development dependencies # +############################################################################### + +[options.extras_require] +# Dependencies needed to run the tests (semicolon/line-separated) +test = + scipy + pytest >= 4.5.0, < 5.0.0 + pytest-benchmark >= 3.2.2, < 4.0.0 + pytest-optional-tests >= 0.1.1 + pytest-cov + coveralls + +# Dependencies needed for linting (semicolon/line-separated) +lint = + darglint + flake8 + mccabe + pycodestyle + pydocstyle + pyflakes + pep8-naming + flake8-bugbear + flake8-comprehensions + flake8-tidy-imports + black + isort + +# Dependencies needed to build/view the documentation (semicolon/line-separated) +docs = + matplotlib + sphinx-gallery + memory_profiler + +############################################################################### +# Development tool configurations # +############################################################################### + +[isort] +profile=black +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True + +[flake8] +select = B,C,E,F,P,W,B9 +max-line-length = 88 +max-complexity = 10 +ignore = + # replaced by B950 (max-line-length + 10%) + E501, # max-line-length + # ignored because pytorch uses dict + C408, # use {} instead of dict() + # Not Black-compatible + E203, # whitespace before : + E231, # missing whitespace after ',' + W291, # trailing whitespace + W503, # line break before binary operator + W504, # line break after binary operator +exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs + +# Differences with pytorch +# +# Smaller max-line-length +# Enabled max-complexity +# No flake8-mypy (T4 range) +# +# Set of rules ignore by pytorch, probably to get around the C +# +# F401 (import unused in __init__.py) not ignored +# F403 'from module import *' used; unable to detect undefined names +# F405 Name may be undefined, or defined from star imports: module +# F821 Undefined name name +# F841 Local variable name is assigned to but never used +# +# Pytorch ignored rules that I don't see a reason to ignore (yet?): +# +# E305 Expected 2 blank lines after end of function or class +# E402 Module level import not at top of file +# E721 Do not compare types, use 'isinstance()' +# E741 Do not use variables named 'l', 'o', or 'i' +# E302 Expected 2 blank lines, found 0 +# E303 Too many blank lines (3) + +[darglint] +docstring_style = google +# short, long, full +strictness = full + +[pydocstyle] +convention = google +match = .*\.py \ No newline at end of file diff --git a/setup.py b/setup.py index edf8054b8..3358418d6 100644 --- a/setup.py +++ b/setup.py @@ -1,48 +1,18 @@ -"""Setup backpack.""" -from os import path +"""Setup file for BackPACK. -from setuptools import find_packages, setup +Use ``setup.cfg`` for configuration. +""" +import sys -# META -############################################################################## -AUTHORS = "F. Dangel, F. Kunstner" -NAME = "backpack-for-pytorch" -PACKAGES = find_packages() +from pkg_resources import VersionConflict, require +from setuptools import setup -DESCRIPTION = "BackPACK: Packing more into backprop" -LONG_DESCR = """ - BackPACK is built on top of PyTorch. - It efficiently computes quantities other than the gradient. +try: + require("setuptools>=38.3") +except VersionConflict: + print("Error: version of setuptools is too old (<38.3)!") + sys.exit(1) - Website: https://backpack.pt - Code: https://github.com/f-dangel/backpack - Documentation: https://readthedocs.org/projects/backpack/ - Bug reports & feature requests: https://github.com/f-dangel/backpack/issues - """ -VERSION = "1.2.0" -URL = "https://github.com/f-dangel/backpack" -LICENSE = "MIT" - -# DEPENDENCIES -############################################################################## -REQUIREMENTS_FILE = "requirements.txt" -REQUIREMENTS_PATH = path.join(path.abspath(__file__), REQUIREMENTS_FILE) - -with open(REQUIREMENTS_FILE) as f: - requirements = f.read().splitlines() - -setup( - author=AUTHORS, - name=NAME, - version=VERSION, - description=DESCRIPTION, - long_description=LONG_DESCR, - long_description_content_type="text/markdown", - install_requires=requirements, - url=URL, - license=LICENSE, - packages=PACKAGES, - zip_safe=False, - python_requires=">=3.7", -) +if __name__ == "__main__": + setup(use_scm_version=True) diff --git a/test/automated_test.py b/test/automated_test.py index e0e2be0fd..9fb2e08ba 100644 --- a/test/automated_test.py +++ b/test/automated_test.py @@ -232,8 +232,7 @@ def test_hmp(problem, device): autograd_res = AutogradImpl(problem).hmp(matrices) check_sizes(autograd_res, backpack_res) - atol = 5e-4 - rtol = 5e-4 + atol, rtol = 5e-4, 5e-4 check_values(autograd_res, backpack_res, atol=atol, rtol=rtol) @@ -272,7 +271,8 @@ def test_hvp(problem, device): autograd_res = AutogradImpl(problem).hvp(vecs) check_sizes(autograd_res, backpack_res) - check_values(autograd_res, backpack_res) + atol, rtol = 5e-4, 5e-4 + check_values(autograd_res, backpack_res, atol=atol, rtol=rtol) @pytest.mark.parametrize( diff --git a/test/core/derivatives/convolution_settings.py b/test/core/derivatives/convolution_settings.py index 08d88c7b1..b1e14df40 100644 --- a/test/core/derivatives/convolution_settings.py +++ b/test/core/derivatives/convolution_settings.py @@ -378,7 +378,7 @@ }, ] -CONVOLUTION_FAIL_SETTINGS = [ +_CONVOLUTION_GROUP_SETTINGS = [ # groups - 2 { "module_fn": lambda: torch.nn.Conv1d( @@ -454,8 +454,9 @@ "id_prefix": "groups-3", }, ] +CONVOLUTION_SETTINGS += _CONVOLUTION_GROUP_SETTINGS -CONVOLUTION_TRANSPOSED_FAIL_SETTINGS = [ +_CONVOLUTION_TRANSPOSED_GROUP_SETTINGS = [ { "module_fn": lambda: torch.nn.ConvTranspose1d( in_channels=6, @@ -529,3 +530,4 @@ "id_prefix": "groups-2", }, ] +CONVOLUTION_SETTINGS += _CONVOLUTION_TRANSPOSED_GROUP_SETTINGS diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index ecb0c9b07..8ff0b168a 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -8,10 +8,6 @@ """ from test.automated_test import check_sizes_and_values -from test.core.derivatives.convolution_settings import ( - CONVOLUTION_FAIL_SETTINGS, - CONVOLUTION_TRANSPOSED_FAIL_SETTINGS, -) from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS @@ -38,16 +34,6 @@ LOSS_FAIL_PROBLEMS = make_test_problems(LOSS_FAIL_SETTINGS) LOSS_FAIL_IDS = [problem.make_id() for problem in LOSS_FAIL_PROBLEMS] -CONVOLUTION_FAIL_PROBLEMS = make_test_problems(CONVOLUTION_FAIL_SETTINGS) -CONVOLUTION_FAIL_IDS = [problem.make_id() for problem in CONVOLUTION_FAIL_PROBLEMS] - -CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS = make_test_problems( - CONVOLUTION_TRANSPOSED_FAIL_SETTINGS -) -CONVOLUTION_TRANSPOSED_FAIL_IDS = [ - problem.make_id() for problem in CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS -] - RNN_PROBLEMS = make_test_problems(RNN_SETTINGS) RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS] @@ -219,11 +205,7 @@ def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): [True, False], ids=["save_memory=True", "save_memory=False"], ) -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_WEIGHTS, - ids=IDS_WITH_WEIGHTS, -) +@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): """Test the transposed Jacobian-matrix product w.r.t. to the weights. @@ -246,19 +228,7 @@ def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): problem.tear_down() -PROBLEMS_WITH_WEIGHTS_NO_GROUPS = [] -IDS_WITH_WEIGHTS_NO_GROUPS = [] -for problem, problem_id in zip(PROBLEMS, IDS): - if problem.has_weight() and not problem.is_group_conv(): - PROBLEMS_WITH_WEIGHTS_NO_GROUPS.append(problem) - IDS_WITH_WEIGHTS_NO_GROUPS.append(problem_id) - - -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_WEIGHTS_NO_GROUPS, - ids=IDS_WITH_WEIGHTS_NO_GROUPS, -) +@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) def test_weight_jac_mat_prod(problem, V=3): """Test the Jacobian-matrix product w.r.t. to the weights. @@ -353,48 +323,6 @@ def test_sqrt_hessian_squared_equals_hessian(problem): problem.tear_down() -@pytest.mark.parametrize( - "problem", - CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS + CONVOLUTION_FAIL_PROBLEMS, - ids=CONVOLUTION_TRANSPOSED_FAIL_IDS + CONVOLUTION_FAIL_IDS, -) -def test_weight_jac_mat_prod_should_fail(problem): - """Tests weight_jac_mat_prod. - - Should fail. - - Args: - problem: test problem - """ - with pytest.raises(NotImplementedError): - test_weight_jac_mat_prod(problem) - - -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( - "save_memory", - [True, False], - ids=["save_memory=True", "save_memory=False"], -) -@pytest.mark.parametrize( - "problem", CONVOLUTION_TRANSPOSED_FAIL_PROBLEMS, ids=CONVOLUTION_TRANSPOSED_FAIL_IDS -) -def test_weight_jac_t_mat_prod_should_fail(problem, sum_batch, save_memory): - """Test weight_jac_t_mat_prod. - - Should fail. - - Args: - problem: problem - sum_batch: whether to sum along batch axis - save_memory: whether to save memory - """ - with pytest.raises(NotImplementedError): - test_weight_jac_t_mat_prod(problem, sum_batch, save_memory) - - @pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_should_fail(problem): """Test sqrt_hessian. Should fail. diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index 8c35aee4a..cb6f93120 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -185,13 +185,3 @@ def has_weight(self): def has_bias(self): module = self.make_module() return hasattr(module, "bias") and module.bias is not None - - def is_group_conv(self): - """Return whether module represents grouped convolution.""" - module = self.make_module() - group_conv = False - - if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)): - group_conv = module.groups > 1 - - return group_conv diff --git a/test/core/derivatives/settings.py b/test/core/derivatives/settings.py index fa9d616eb..0e1c33024 100644 --- a/test/core/derivatives/settings.py +++ b/test/core/derivatives/settings.py @@ -1,15 +1,13 @@ -"""Test configurations for `backpack.core.derivatives`. +"""Test cases for `backpack.core.derivatives`. -Required entries: - The tests for individual categories are - written in respective files and imported here. - Tests: - Activation layers - Convolutional layers - Linear Layers - Loss functions - Pooling layers - Padding layers +Cases are divided into the following layer categories: + +- Activations +- (Transposed) convolutions +- Linear +- Losses +- Padding +- Pooling """ from test.core.derivatives.activation_settings import ACTIVATION_SETTINGS @@ -19,9 +17,7 @@ from test.core.derivatives.padding_settings import PADDING_SETTINGS from test.core.derivatives.pooling_settings import POOLING_SETTINGS -SETTINGS = [] - -SETTINGS.extend( +SETTINGS = ( ACTIVATION_SETTINGS + CONVOLUTION_SETTINGS + LINEAR_SETTINGS diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 8bedcfea4..71655b239 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -131,9 +131,7 @@ def _diag_ggn_batch(self): params_batch_diag_ggn = list(zip(*batch_diag_ggn)) return [torch.stack(param) * factor for param in params_batch_diag_ggn] - def diag_h(self): - _, _, loss = self.problem.forward_pass() - + def _get_diag_h(self, loss): def hvp(df_dx, x, v): Hv = R_op(df_dx, x, v) return [j.detach() for j in Hv] @@ -158,5 +156,23 @@ def extract_ith_element_of_diag_h(i, p, df_dx): diag_h_p[parameter_index] = diag_value diag_hs.append(diag_h_p.view(p.size())) - return diag_hs + + def diag_h(self): + _, _, loss = self.problem.forward_pass() + return self._get_diag_h(loss) + + def diag_h_batch(self): + batch_size = self.problem.input.shape[0] + _, _, batch_loss = self.problem.forward_pass() + loss_list = torch.zeros(batch_size, device=self.problem.device) + + batch_diag_h = [] + for b in range(batch_size): + _, _, loss = self.problem.forward_pass(sample_idx=b) + loss_list[b] = loss + diag_h = self._get_diag_h(loss) + batch_diag_h.append(diag_h) + factor = self.problem.get_reduction_factor(batch_loss, loss_list) + params_batch_diag_h = list(zip(*batch_diag_h)) + return [torch.stack(param) * factor for param in params_batch_diag_h] diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index 1040947e5..ae121c6c5 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -158,3 +158,35 @@ def diag_h(self): loss.backward() diag_h = [p.diag_h for p in self.problem.model.parameters()] return diag_h + + def kfac(self, mc_samples=1): + with backpack(new_ext.KFAC(mc_samples=mc_samples)): + _, _, loss = self.problem.forward_pass() + loss.backward() + kfac = [p.kfac for p in self.problem.model.parameters()] + + return kfac + + def kflr(self): + with backpack(new_ext.KFLR()): + _, _, loss = self.problem.forward_pass() + loss.backward() + kflr = [p.kflr for p in self.problem.model.parameters()] + + return kflr + + def kfra(self): + with backpack(new_ext.KFRA()): + _, _, loss = self.problem.forward_pass() + loss.backward() + kfra = [p.kfra for p in self.problem.model.parameters()] + + return kfra + + def diag_h_batch(self): + with backpack(new_ext.BatchDiagHessian()): + _, _, loss = self.problem.forward_pass() + loss.backward() + diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()] + + return diag_h_batch diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 1e566f584..0b2eb73b4 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -39,3 +39,37 @@ def diag_ggn_mc_batch(self, mc_samples): def diag_h(self): """Diagonal of Hessian""" raise NotImplementedError + + def kfac(self, mc_samples=1): + """Kronecker-factored approximate curvature (KFAC). + + Args: + mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``. + + Returns: + list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + """ + raise NotImplementedError + + def kflr(self): + """Kronecker-factored low-rank approximation (KFLR). + + Returns: + list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + """ + raise NotImplementedError + + def kfra(self): + """Kronecker-factored recursive approximation (KFRA). + + Returns: + list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + """ + + def diag_h_batch(self): + """Per-sample Hessian diagonal. + + Returns: + list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. + """ + raise NotImplementedError diff --git a/test/extensions/secondorder/diag_hessian/test_diag_hessian.py b/test/extensions/secondorder/diag_hessian/test_diag_hessian.py index bdddf2036..16e4bd0b0 100644 --- a/test/extensions/secondorder/diag_hessian/test_diag_hessian.py +++ b/test/extensions/secondorder/diag_hessian/test_diag_hessian.py @@ -24,3 +24,19 @@ def test_diag_h(problem): check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() + + +@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) +def test_diag_h_batch(problem): + """Test Diagonal of Hessian + + Args: + problem (ExtensionsTestProblem): Problem for extension test. + """ + problem.set_up() + + backpack_res = BackpackExtensions(problem).diag_h_batch() + autograd_res = AutogradExtensions(problem).diag_h_batch() + + check_sizes_and_values(autograd_res, backpack_res) + problem.tear_down() diff --git a/test/extensions/secondorder/hbp/__init__.py b/test/extensions/secondorder/hbp/__init__.py new file mode 100644 index 000000000..0c4dcb653 --- /dev/null +++ b/test/extensions/secondorder/hbp/__init__.py @@ -0,0 +1 @@ +"""Tests for ``backpack.extensions.secondorder.hbp`` (Kronecker curvatures).""" diff --git a/test/extensions/secondorder/hbp/kfac_settings.py b/test/extensions/secondorder/hbp/kfac_settings.py new file mode 100644 index 000000000..895b99cc5 --- /dev/null +++ b/test/extensions/secondorder/hbp/kfac_settings.py @@ -0,0 +1,8 @@ +"""Define test cases for KFAC.""" + +from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS + +SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +LOCAL_NOT_SUPPORTED_SETTINGS = [] + +NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kflr_settings.py b/test/extensions/secondorder/hbp/kflr_settings.py new file mode 100644 index 000000000..6b74a2842 --- /dev/null +++ b/test/extensions/secondorder/hbp/kflr_settings.py @@ -0,0 +1,8 @@ +"""Define test cases for KFLR.""" + +from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS + +SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +LOCAL_NOT_SUPPORTED_SETTINGS = [] + +NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kfra_settings.py b/test/extensions/secondorder/hbp/kfra_settings.py new file mode 100644 index 000000000..5a28ab738 --- /dev/null +++ b/test/extensions/secondorder/hbp/kfra_settings.py @@ -0,0 +1,8 @@ +"""Define test cases for KFRA.""" + +from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS + +SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +LOCAL_NOT_SUPPORTED_SETTINGS = [] + +NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/test_kfac.py b/test/extensions/secondorder/hbp/test_kfac.py new file mode 100644 index 000000000..d6205e290 --- /dev/null +++ b/test/extensions/secondorder/hbp/test_kfac.py @@ -0,0 +1,25 @@ +"""Test BackPACK's KFAC extension.""" + +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import make_test_problems +from test.extensions.secondorder.hbp.kfac_settings import NOT_SUPPORTED_SETTINGS + +import pytest + +NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) +NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] + + +@pytest.mark.parametrize("problem", NOT_SUPPORTED_PROBLEMS, ids=NOT_SUPPORTED_IDS) +def test_kfac_not_supported(problem): + """Check that the KFAC extension does not allow specific hyperparameters/modules. + + Args: + problem (ExtensionsTestProblem): Test case. + """ + problem.set_up() + + with pytest.raises(NotImplementedError): + BackpackExtensions(problem).kfac() + + problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kflr.py b/test/extensions/secondorder/hbp/test_kflr.py new file mode 100644 index 000000000..79e46f186 --- /dev/null +++ b/test/extensions/secondorder/hbp/test_kflr.py @@ -0,0 +1,25 @@ +"""Test BackPACK's KFLR extension.""" + +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import make_test_problems +from test.extensions.secondorder.hbp.kflr_settings import NOT_SUPPORTED_SETTINGS + +import pytest + +NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) +NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] + + +@pytest.mark.parametrize("problem", NOT_SUPPORTED_PROBLEMS, ids=NOT_SUPPORTED_IDS) +def test_kflr_not_supported(problem): + """Check that the KFLR extension does not allow specific hyperparameters/modules. + + Args: + problem (ExtensionsTestProblem): Test case. + """ + problem.set_up() + + with pytest.raises(NotImplementedError): + BackpackExtensions(problem).kflr() + + problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kfra.py b/test/extensions/secondorder/hbp/test_kfra.py new file mode 100644 index 000000000..387438308 --- /dev/null +++ b/test/extensions/secondorder/hbp/test_kfra.py @@ -0,0 +1,25 @@ +"""Test BackPACK's KFRA extension.""" + +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import make_test_problems +from test.extensions.secondorder.hbp.kfra_settings import NOT_SUPPORTED_SETTINGS + +import pytest + +NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) +NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] + + +@pytest.mark.parametrize("problem", NOT_SUPPORTED_PROBLEMS, ids=NOT_SUPPORTED_IDS) +def test_kfra_not_supported(problem): + """Check that the KFRA extension does not allow specific hyperparameters/modules. + + Args: + problem (ExtensionsTestProblem): Test case. + """ + problem.set_up() + + with pytest.raises(NotImplementedError): + BackpackExtensions(problem).kfra() + + problem.tear_down() diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index a3844dd39..f90456134 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -1,11 +1,10 @@ -"""Test configurations for `backpack.core.extensions.secondorder`. +"""Shared test cases of BackPACK's second-order extensions. -that is shared among the following secondorder methods: -- Diagonal of Gauss Newton -- Diagonal of Hessian +- Exact diagonal of the generalized Gauss-Newton +- MC-approximated diagonal of the generalized Gauss-Newton +- Diagonal of the Hessian - MC Approximation of Diagonal of Gauss Newton - Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function @@ -29,6 +28,8 @@ import torch from torch.nn import ( + ELU, + SELU, AvgPool1d, AvgPool2d, AvgPool3d, @@ -98,7 +99,7 @@ ############################################################################### # test setting: Activation Layers # ############################################################################### -activations = [ReLU, Sigmoid, Tanh, LeakyReLU, LogSigmoid] +activations = [ReLU, Sigmoid, Tanh, LeakyReLU, LogSigmoid, ELU, SELU] for act in activations: for bias in [True, False]: @@ -215,3 +216,20 @@ (3, 3, 2, 7, 7), ConvTranspose3d, (3, 2, 2, 4, 2, 0, 1, False) ), ] + +GROUP_CONV_SETTINGS = [ + # last number is groups + make_simple_cnn_setting((3, 6, 7), Conv1d, (6, 4, 2, 1, 0, 1, 2)), + make_simple_cnn_setting((3, 6, 7, 5), Conv2d, (6, 3, 2, 1, 0, 1, 3)), + make_simple_cnn_setting((3, 4, 7, 5, 6), Conv3d, (4, 2, 2, 1, 0, 2, 2)), + # number before bool is groups + make_simple_cnn_setting((3, 6, 8), ConvTranspose1d, (6, 3, 2, 4, 2, 0, 3, True, 3)), + make_simple_cnn_setting( + (3, 4, 9, 9), ConvTranspose2d, (4, 2, 2, 1, 0, 0, 2, True, 2) + ), + make_simple_cnn_setting( + (3, 4, 3, 5, 5), ConvTranspose3d, (4, 2, (2, 2, 1), 2, 2, 0, 2, True, 2) + ), +] + +SECONDORDER_SETTINGS += GROUP_CONV_SETTINGS diff --git a/test/utils/test_conv.py b/test/utils/test_conv.py index 74f51fd33..b350c07f0 100644 --- a/test/utils/test_conv.py +++ b/test/utils/test_conv.py @@ -6,7 +6,7 @@ import pytest import torch -from backpack.utils.conv import unfold_by_conv, unfold_func +from backpack.utils.conv import unfold_by_conv, unfold_input from ..automated_test import check_sizes_and_values @@ -34,9 +34,7 @@ def get_output_shape(input, module): C_out = output_shape[1] spatial_out_size = output_shape[2:] spatial_out_numel = spatial_out_size.numel() - - kernel_size = module.kernel_size - kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size))) + kernel_size_numel = module.weight.shape[2:].numel() G = module.groups @@ -62,7 +60,7 @@ def test_unfold_by_conv(problem): problem.set_up() input = torch.rand(problem.input_shape).to(problem.device) - result_unfold = unfold_func(problem.module)(input) + result_unfold = unfold_input(problem.module, input) result_unfold_by_conv = unfold_by_conv(input, problem.module) check_sizes_and_values(result_unfold, result_unfold_by_conv) diff --git a/test/utils/test_conv_transpose.py b/test/utils/test_conv_transpose.py index 05d5a990e..1019f850b 100644 --- a/test/utils/test_conv_transpose.py +++ b/test/utils/test_conv_transpose.py @@ -27,9 +27,7 @@ def get_output_shape(input, module): C_out = output_shape[1] spatial_out_size = output_shape[2:] spatial_out_numel = spatial_out_size.numel() - - kernel_size = module.kernel_size - kernel_size_numel = int(torch.prod(torch.Tensor(kernel_size))) + kernel_size_numel = module.weight.shape[2:].numel() G = module.groups From 080d1f0538387b5ae12ca7788f2aca6bc91d5901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 18 Jun 2021 11:22:57 +0200 Subject: [PATCH 12/54] [core] LSTM derivatives (#169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * LSTM derivatives and test * change ValueError to NotImplementedError * improve lstm.py * improve lstm.py * improve lstm.py * finish settings * reduce memory consumption * simplify einsum Co-authored-by: Tim Schäfer --- backpack/core/derivatives/lstm.py | 346 ++++++++++++++++++++++ fully_documented.txt | 2 + test/core/derivatives/__init__.py | 3 + test/core/derivatives/derivatives_test.py | 28 +- test/core/derivatives/lstm_settings.py | 28 ++ 5 files changed, 399 insertions(+), 8 deletions(-) create mode 100644 backpack/core/derivatives/lstm.py create mode 100644 test/core/derivatives/lstm_settings.py diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py new file mode 100644 index 000000000..1ad3a79b8 --- /dev/null +++ b/backpack/core/derivatives/lstm.py @@ -0,0 +1,346 @@ +"""Partial derivatives for nn.LSTM.""" +from typing import Tuple + +from torch import Tensor, cat, einsum, sigmoid, tanh, zeros +from torch.nn import LSTM + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives + + +class LSTMDerivatives(BaseParameterDerivatives): + """Partial derivatives for nn.LSTM layer. + + Index conventions: + ------------------ + * t: Sequence dimension + * v: Free dimension + * n: Batch dimension + * h: Output dimension + * i: Input dimension + + LSTM forward pass (definition of variables): + see https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html + ifgo_tilde[t] = W_ih x[t] + b_ii + W_hh h[t-1] + b_hh + ifgo[t] = sigma(ifgo_tilde[t]) for i, f, o + ifgo[t] = tanh(ifgo_tilde[t]) for g + c[t] = f[t] c[t-1] + i[t] g[t] + h[t] = o[t] tanh(c[t]) + """ + + @staticmethod + def _check_parameters(module: LSTM) -> None: + """Check the parameters of module. + + Args: + module: module which to check + + Raises: + NotImplementedError: If any parameter of module does not match expectation + """ + if module.num_layers != 1: + raise NotImplementedError("only num_layers = 1 is supported") + if module.bias is not True: + raise NotImplementedError("only bias = True is supported") + if module.batch_first is not False: + raise NotImplementedError("only batch_first = False is supported") + if module.dropout != 0: + raise NotImplementedError("only dropout = 0 is supported") + if module.bidirectional is not False: + raise NotImplementedError("only bidirectional = False is supported") + if module.proj_size != 0: + raise NotImplementedError("only proj_size = 0 is supported") + + @staticmethod + def _forward_pass( + module: LSTM, mat: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """This performs an additional forward pass and returns the hidden variables. + + This is important because the PyTorch implementation does not grant access to + some of the hidden variables. Those are computed and returned. + + See also forward pass in class docstring. + + Args: + module: module + mat: matrix, used to extract device and shapes + + Returns: + ifgo, c, c_tanh, h + """ + T: int = mat.shape[1] + N: int = mat.shape[2] + H: int = module.hidden_size + H0: int = 0 * H + H1: int = 1 * H + H2: int = 2 * H + H3: int = 3 * H + H4: int = 4 * H + # forward pass and save i, f, g, o, c, c_tanh-> ifgo, c, c_tanh + ifgo: Tensor = zeros(T, N, 4 * H, device=mat.device, dtype=mat.dtype) + c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + for t in range(T): + ifgo[t] = ( + einsum("hi,ni->nh", module.weight_ih_l0, module.input0[t]) + + module.bias_ih_l0 + + module.bias_hh_l0 + ) + if t != 0: + ifgo[t] += einsum( + "hg,ng->nh", module.weight_hh_l0, module.output[t - 1] + ) + ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1]) + ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2]) + ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3]) + ifgo[t, :, H3:H4] = sigmoid(ifgo[t, :, H3:H4]) + c[t] = ifgo[t, :, H0:H1] * ifgo[t, :, H2:H3] + if t != 0: + c[t] += ifgo[t, :, H1:H2] * c[t - 1] + c_tanh[t] = tanh(c[t]) + h[t] = ifgo[t, :, H3:H4] * c_tanh[t] + + return ifgo, c, c_tanh, h + + @classmethod + def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor: + V: int = mat.shape[0] + T: int = mat.shape[1] + N: int = mat.shape[2] + H: int = module.hidden_size + H0: int = 0 * H + H1: int = 1 * H + H2: int = 2 * H + H3: int = 3 * H + H4: int = 4 * H + + ifgo, c, c_tanh, _ = cls._forward_pass(module, mat) + + # backward pass + H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + IFGO_prod: Tensor = zeros(V, T, N, 4 * H, device=mat.device, dtype=mat.dtype) + for t in reversed(range(T)): + # jac_t_mat_prod until node h + H_prod_t[:] = mat[:, t] + if t != (T - 1): + H_prod_t += einsum( + "vnh,hg->vng", + IFGO_prod[:, t + 1], + module.weight_hh_l0, + ) + + # C_prod_t = jac_t_mat_prod until node c + if t != (T - 1): + C_prod_old[:] = C_prod_t + C_prod_t[:] = einsum( + "vnh,nh->vnh", + H_prod_t, + ifgo[t, :, H3:H4] * (1 - c_tanh[t] ** 2), + ) + if t != (T - 1): + C_prod_t += einsum( + "vnh,nh->vnh", + C_prod_old, + ifgo[t + 1, :, H1:H2], + ) + + IFGO_prod[:, t, :, H3:H4] = einsum( + "vnh,nh->vnh", + H_prod_t, + c_tanh[t] * (ifgo[t, :, H3:H4] * (1 - ifgo[t, :, H3:H4])), + ) + IFGO_prod[:, t, :, H0:H1] = einsum( + "vnh,nh->vnh", + C_prod_t, + ifgo[t, :, H2:H3] * (ifgo[t, :, H0:H1] * (1 - ifgo[t, :, H0:H1])), + ) + if t >= 1: + IFGO_prod[:, t, :, H1:H2] = einsum( + "vnh,nh->vnh", + C_prod_t, + c[t - 1] * (ifgo[t, :, H1:H2] * (1 - ifgo[t, :, H1:H2])), + ) + IFGO_prod[:, t, :, H2:H3] = einsum( + "vnh,nh->vnh", + C_prod_t, + ifgo[t, :, H0:H1] * (1 - ifgo[t, :, H2:H3] ** 2), + ) + return IFGO_prod + + def _jac_mat_prod( + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + V: int = mat.shape[0] + T: int = mat.shape[1] + N: int = mat.shape[2] + H: int = module.hidden_size + H0: int = 0 * H + H1: int = 1 * H + H2: int = 2 * H + H3: int = 3 * H + H4: int = 4 * H + + ifgo, c, c_tanh, h = self._forward_pass(module, mat) + H_prod: Tensor = zeros(V, T, N, H, device=mat.device, dtype=mat.dtype) + C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + C_tanh_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) + IFGO_prod_t: Tensor = zeros(V, N, 4 * H, device=mat.device, dtype=mat.dtype) + for t in range(T): + # product until nodes ifgo + IFGO_prod_t[:] = einsum( + "hi,vni->vnh", + module.weight_ih_l0, + mat[:, t], + ) + if t != 0: + IFGO_prod_t[:] += einsum( + "hg,vng->vnh", + module.weight_hh_l0, + H_prod[:, t - 1], + ) + IFGO_prod_t[:, :, H0:H2] = einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H0:H2], + ifgo[t, :, H0:H2] * (1 - ifgo[t, :, H0:H2]), + ) + IFGO_prod_t[:, :, H3:H4] = einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H3:H4], + ifgo[t, :, H3:H4] * (1 - ifgo[t, :, H3:H4]), + ) + IFGO_prod_t[:, :, H2:H3] = einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H2:H3], + 1 - ifgo[t, :, H2:H3] ** 2, + ) + + # product until node c + if t >= 1: + C_prod_old[:] = C_prod_t + C_prod_t[:] = ( + einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H0:H1], + ifgo[t, :, H2:H3], + ) + + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[t, :, H0:H1]) + ) + if t >= 1: + C_prod_t += einsum( + "vnh,nh->vnh", + C_prod_old, + ifgo[t, :, H1:H2], + ) + einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H1:H2], + c[t - 1], + ) + + # product until node c_tanh + C_tanh_prod_t[:] = einsum( + "vnh,nh->vnh", + C_prod_t, + 1 - c_tanh[t] ** 2, + ) + + # product until node h + H_prod[:, t] = einsum( + "vnh,nh->vnh", + IFGO_prod_t[:, :, H3:H4], + c_tanh[t], + ) + einsum( + "vnh,nh->vnh", + C_tanh_prod_t, + ifgo[t, :, H3:H4], + ) + + return H_prod + + def _jac_t_mat_prod( + self, module: LSTM, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + + self._check_parameters(module) + + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + + X_prod: Tensor = einsum( + "vtnh,hi->vtni", + IFGO_prod, + module.weight_ih_l0, + ) + return X_prod + + def _bias_ih_l0_jac_t_mat_prod( + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + self._check_parameters(module) + + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + + return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod) + + def _bias_hh_l0_jac_t_mat_prod( + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + self._check_parameters(module) + + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + + return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod) + + def _weight_ih_l0_jac_t_mat_prod( + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + self._check_parameters(module) + + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + + return einsum( + f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, module.input0 + ) + + def _weight_hh_l0_jac_t_mat_prod( + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + self._check_parameters(module) + + N: int = mat.shape[2] + H: int = module.hidden_size + + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + + return einsum( + f"vtnh,tng->v{'' if sum_batch else 'n'}hg", + IFGO_prod, + cat([zeros(1, N, H, device=mat.device), module.output[0:-1]], dim=0), + ) diff --git a/fully_documented.txt b/fully_documented.txt index 480a24b3e..d86430c88 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -5,6 +5,7 @@ backpack/core/derivatives/rnn.py backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py backpack/core/derivatives/permute.py +backpack/core/derivatives/lstm.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -43,6 +44,7 @@ test/core/derivatives/rnn_settings.py test/core/derivatives/utils.py test/core/derivatives/implementation/ test/core/derivatives/permute_settings.py +test/core/derivatives/lstm_settings.py test/extensions/problem.py test/extensions/test_backprop_extension.py diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 05748ce57..b89d60d3d 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -1,6 +1,7 @@ """Test functionality of `backpack.core.derivatives` module.""" from torch.nn import ( ELU, + LSTM, RNN, SELU, AvgPool1d, @@ -42,6 +43,7 @@ from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives from backpack.core.derivatives.linear import LinearDerivatives from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives @@ -82,4 +84,5 @@ MSELoss: MSELossDerivatives, RNN: RNNDerivatives, Permute: PermuteDerivatives, + LSTM: LSTMDerivatives, } diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 8ff0b168a..f77871b1c 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -11,6 +11,7 @@ from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS +from test.core.derivatives.lstm_settings import LSTM_SETTINGS from test.core.derivatives.permute_settings import PERMUTE_SETTINGS from test.core.derivatives.problem import make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS @@ -37,14 +38,17 @@ RNN_PROBLEMS = make_test_problems(RNN_SETTINGS) RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS] +LSTM_PROBLEMS = make_test_problems(LSTM_SETTINGS) +LSTM_IDS = [problem.make_id() for problem in LSTM_PROBLEMS] + PERMUTE_PROBLEMS = make_test_problems(PERMUTE_SETTINGS) PERMUTE_IDS = [problem.make_id() for problem in PERMUTE_PROBLEMS] @pytest.mark.parametrize( "problem", - NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS, + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, ) def test_jac_mat_prod(problem, V=3): """Test the Jacobian-matrix product. @@ -65,8 +69,8 @@ def test_jac_mat_prod(problem, V=3): @pytest.mark.parametrize( "problem", - NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS, + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, ) def test_jac_t_mat_prod(problem, V=3): """Test the transposed Jacobian-matrix product. @@ -96,7 +100,9 @@ def test_jac_t_mat_prod(problem, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +@pytest.mark.parametrize( + "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS +) def test_bias_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): """Test the transposed Jacobian-matrix product w.r.t. to bias_ih_l0. @@ -122,7 +128,9 @@ def test_bias_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +@pytest.mark.parametrize( + "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS +) def test_bias_hh_l0_jac_t_mat_prod(problem, sum_batch, V=3): """Test the transposed Jacobian-matrix product w.r.t. to bias_hh_l0. @@ -148,7 +156,9 @@ def test_bias_hh_l0_jac_t_mat_prod(problem, sum_batch, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +@pytest.mark.parametrize( + "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS +) def test_weight_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): """Test the transposed Jacobian-matrix product w.r.t. to weight_ih_l0. @@ -174,7 +184,9 @@ def test_weight_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize("problem", RNN_PROBLEMS, ids=RNN_IDS) +@pytest.mark.parametrize( + "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS +) def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): """Test the transposed Jacobian-matrix product w.r.t. to weight_hh_l0. diff --git a/test/core/derivatives/lstm_settings.py b/test/core/derivatives/lstm_settings.py new file mode 100644 index 000000000..b3b8397f0 --- /dev/null +++ b/test/core/derivatives/lstm_settings.py @@ -0,0 +1,28 @@ +"""Test configurations for `backpack.core.derivatives` LSTM layers. + +Required entries: + "module_fn" (callable): Contains a model constructed from `torch.nn` layers + "input_fn" (callable): Used for specifying input function + +Optional entries: + "target_fn" (callable): Fetches the groundtruth/target classes + of regression/classification task + "loss_function_fn" (callable): Loss function used in the model + "device" [list(torch.device)]: List of devices to run the test on. + "id_prefix" (str): Prefix to be included in the test name. + "seed" (int): seed for the random number for torch.rand +""" + +import torch + +LSTM_SETTINGS = [] + +############################################################################### +# test settings # +############################################################################### +LSTM_SETTINGS += [ + { + "module_fn": lambda: torch.nn.LSTM(input_size=5, hidden_size=3, num_layers=1), + "input_fn": lambda: torch.rand(size=(10, 8, 5)), + }, +] From 7086f1b3d2f667f292c1141ee580b3ba4b93f0c5 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Mon, 21 Jun 2021 13:15:42 +0200 Subject: [PATCH 13/54] [ADD] `SqrtGGNExact` extension (#180) This PR adds an extension that computes the matrix square root of the generalized Gauss-Newton, using an exact factorization of the loss Hessian. For details, see Equation (3) of [this paper](https://arxiv.org/abs/2106.02624v1). * [ADD] SqrtGGNExact extension * [TEST] Compare GGN via square root with autograd * [TEST] Report explanations for skipped settings * [DEL] Remove validity check for loss Hessian strategy * [FMT] Replace `.format` with f-string * [DOC] Include `SqrtGGNExact` in all-in-one example * [FIX] Typo in docstring * [FMT] Remove blank line * [FIX] Docstring * [REF] Directly evaluate the Hessian sqrt without creating a function * [REF] Change ValueError into NotImplementedError * [ADD] Introduce public getter for `loss_hessian_strategy` * [FMT] Remove blank line, shorten * [DOC] Update default value * [FIX] Make condition and skip explanation more precise --- backpack/extensions/__init__.py | 2 + backpack/extensions/secondorder/__init__.py | 16 ++- .../secondorder/sqrt_ggn/__init__.py | 121 ++++++++++++++++++ .../secondorder/sqrt_ggn/activations.py | 65 ++++++++++ .../extensions/secondorder/sqrt_ggn/base.py | 76 +++++++++++ .../extensions/secondorder/sqrt_ggn/convnd.py | 29 +++++ .../secondorder/sqrt_ggn/convtransposend.py | 29 +++++ .../secondorder/sqrt_ggn/dropout.py | 11 ++ .../secondorder/sqrt_ggn/flatten.py | 47 +++++++ .../extensions/secondorder/sqrt_ggn/linear.py | 11 ++ .../extensions/secondorder/sqrt_ggn/losses.py | 72 +++++++++++ .../secondorder/sqrt_ggn/padding.py | 11 ++ .../secondorder/sqrt_ggn/pooling.py | 56 ++++++++ .../basic_usage/example_all_in_one.py | 13 ++ docs_src/rtd/extensions.rst | 1 + fully_documented.txt | 4 + makefile | 8 +- test/extensions/implementation/autograd.py | 24 +++- test/extensions/implementation/backpack.py | 17 +++ test/extensions/implementation/base.py | 30 +++-- .../secondorder/sqrt_ggn/__init__.py | 1 + .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 50 ++++++++ 22 files changed, 678 insertions(+), 16 deletions(-) create mode 100644 backpack/extensions/secondorder/sqrt_ggn/__init__.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/activations.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/base.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/convnd.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/convtransposend.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/dropout.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/flatten.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/linear.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/losses.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/padding.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/pooling.py create mode 100644 test/extensions/secondorder/sqrt_ggn/__init__.py create mode 100644 test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index a84a64f71..c66752590 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -13,6 +13,7 @@ DiagGGNExact, DiagGGNMC, DiagHessian, + SqrtGGNExact, ) __all__ = [ @@ -33,4 +34,5 @@ "BatchDiagGGNMC", "DiagHessian", "BatchDiagHessian", + "SqrtGGNExact", ] diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index afe9ebc2f..d498236fa 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -17,11 +17,20 @@ :func:`KFRA `, :func:`KFLR `. - The diagonal of the Hessian :func:`DiagHessian ` +- The symmetric (square root) factorization of the GGN/Fisher information, + using exact computation + (:func:`SqrtGGNExact `) """ -from .diag_ggn import BatchDiagGGNExact, BatchDiagGGNMC, DiagGGNExact, DiagGGNMC -from .diag_hessian import BatchDiagHessian, DiagHessian -from .hbp import HBP, KFAC, KFLR, KFRA +from backpack.extensions.secondorder.diag_ggn import ( + BatchDiagGGNExact, + BatchDiagGGNMC, + DiagGGNExact, + DiagGGNMC, +) +from backpack.extensions.secondorder.diag_hessian import BatchDiagHessian, DiagHessian +from backpack.extensions.secondorder.hbp import HBP, KFAC, KFLR, KFRA +from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact __all__ = [ "DiagGGNExact", @@ -34,4 +43,5 @@ "KFLR", "KFRA", "HBP", + "SqrtGGNExact", ] diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py new file mode 100644 index 000000000..5fdde4f80 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -0,0 +1,121 @@ +"""Defines base class and extensions for computing the GGN/Fisher matrix square root.""" + +from torch.nn import ( + ELU, + SELU, + AvgPool1d, + AvgPool2d, + AvgPool3d, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + CrossEntropyLoss, + Dropout, + Flatten, + LeakyReLU, + Linear, + LogSigmoid, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + +from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.hbp import LossHessianStrategy +from backpack.extensions.secondorder.sqrt_ggn import ( + activations, + convnd, + convtransposend, + dropout, + flatten, + linear, + losses, + padding, + pooling, +) + + +class SqrtGGN(BackpropExtension): + """Base class for extensions that compute the GGN/Fisher matrix square root.""" + + def __init__(self, loss_hessian_strategy: str, savefield: str): + """Store approximation for backpropagated object and where to save the result. + + Args: + loss_hessian_strategy: Which approximation is used for the backpropagated + loss Hessian. Must be ``'exact'`` or ``'sampling'``. + savefield: Attribute under which the quantity is saved in a parameter. + """ + self.loss_hessian_strategy = loss_hessian_strategy + super().__init__( + savefield=savefield, + fail_mode="ERROR", + module_exts={ + MSELoss: losses.SqrtGGNMSELoss(), + CrossEntropyLoss: losses.SqrtGGNCrossEntropyLoss(), + Linear: linear.SqrtGGNLinear(), + MaxPool1d: pooling.SqrtGGNMaxPool1d(), + MaxPool2d: pooling.SqrtGGNMaxPool2d(), + AvgPool1d: pooling.SqrtGGNAvgPool1d(), + MaxPool3d: pooling.SqrtGGNMaxPool3d(), + AvgPool2d: pooling.SqrtGGNAvgPool2d(), + AvgPool3d: pooling.SqrtGGNAvgPool3d(), + ZeroPad2d: padding.SqrtGGNZeroPad2d(), + Conv1d: convnd.SqrtGGNConv1d(), + Conv2d: convnd.SqrtGGNConv2d(), + Conv3d: convnd.SqrtGGNConv3d(), + ConvTranspose1d: convtransposend.SqrtGGNConvTranspose1d(), + ConvTranspose2d: convtransposend.SqrtGGNConvTranspose2d(), + ConvTranspose3d: convtransposend.SqrtGGNConvTranspose3d(), + Dropout: dropout.SqrtGGNDropout(), + Flatten: flatten.SqrtGGNFlatten(), + ReLU: activations.SqrtGGNReLU(), + Sigmoid: activations.SqrtGGNSigmoid(), + Tanh: activations.SqrtGGNTanh(), + LeakyReLU: activations.SqrtGGNLeakyReLU(), + LogSigmoid: activations.SqrtGGNLogSigmoid(), + ELU: activations.SqrtGGNELU(), + SELU: activations.SqrtGGNSELU(), + }, + ) + + def get_loss_hessian_strategy(self) -> str: + """Return the strategy used to represent the backpropagated loss Hessian. + + Returns: + Loss Hessian strategy. + """ + return self.loss_hessian_strategy + + +class SqrtGGNExact(SqrtGGN): + """Exact matrix square root of the generalized Gauss-Newton/Fisher. + + Uses the exact Hessian of the loss w.r.t. the model output. + + Stores the output in :code:`sqrt_ggn_exact`, has shape ``[C, N, param.shape]``, + where ``C`` is the model output dimension (number of classes for classification + problems) and ``N`` is the batch size. + + For a faster but less precise alternative, see + :py:meth:`backpack.extensions.SqrtGGNMC`. + + .. note:: + + (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_exact`` + can be viewed as a ``[C * N, param.numel()]`` matrix. Concatenating this + matrix over all parameters results in a matrix ``Vᵀ``, which + is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``. + """ + + def __init__(self): + """Use exact loss Hessian and set savefield to ``sqrt_ggn_exact``.""" + super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact") diff --git a/backpack/extensions/secondorder/sqrt_ggn/activations.py b/backpack/extensions/secondorder/sqrt_ggn/activations.py new file mode 100644 index 000000000..3aaf8fff2 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/activations.py @@ -0,0 +1,65 @@ +"""Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.elu import ELUDerivatives +from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives +from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives +from backpack.core.derivatives.relu import ReLUDerivatives +from backpack.core.derivatives.selu import SELUDerivatives +from backpack.core.derivatives.sigmoid import SigmoidDerivatives +from backpack.core.derivatives.tanh import TanhDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNReLU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ReLU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ReLU`` module.""" + super().__init__(ReLUDerivatives()) + + +class SqrtGGNSigmoid(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Sigmoid`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Sigmoid`` module.""" + super().__init__(SigmoidDerivatives()) + + +class SqrtGGNTanh(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Tanh`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Tanh`` module.""" + super().__init__(TanhDerivatives()) + + +class SqrtGGNELU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ELU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ELU`` module.""" + super().__init__(ELUDerivatives()) + + +class SqrtGGNSELU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.SELU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.SELU`` module.""" + super().__init__(SELUDerivatives()) + + +class SqrtGGNLeakyReLU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LeakyReLU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.LeakyReLU`` module.""" + super().__init__(LeakyReLUDerivatives()) + + +class SqrtGGNLogSigmoid(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LogSigmoid`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.LogSigmoid`` module.""" + super().__init__(LogSigmoidDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py new file mode 100644 index 000000000..d74949a71 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -0,0 +1,76 @@ +"""Contains base class for ``SqrtGGN{Exact, MC}`` module extensions.""" +from typing import Any, Callable, List, Tuple, Union + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import ( + BaseDerivatives, + BaseParameterDerivatives, +) +from backpack.extensions.mat_to_mat_jac_base import MatToJacMat + + +class SqrtGGNBaseModule(MatToJacMat): + """Base module extension for ``SqrtGGN{Exact, MC}``.""" + + def __init__( + self, + derivatives: Union[BaseParameterDerivatives, BaseDerivatives], + params: List[str] = None, + ): + """Store parameter names and derivatives. + + Sets up methods that extract the GGN/Fisher matrix square root for the + passed parameters, unless these methods are overwritten by a child class. + + Args: + derivatives: derivatives object. + params: List of parameter names. Defaults to None. + """ + if params is not None: + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + + super().__init__(derivatives, params=params) + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def _make_param_function( + self, param_str: str + ) -> Callable[[Any, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor]: + """Create a function that computes the GGN/Fisher square root for a parameter. + + Args: + param_str: name of parameter + + Returns: + Function that computes the GGN/Fisher matrix square root. + """ + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def param_function( + ext: Any, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + backproped: Tensor, + ) -> Tensor: + """Calculate the GGN/Fisher matrix square root with the derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + backproped: Backpropagated quantities from second-order extension. + + Returns: + GGN/Fisher matrix square root. + """ + return getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, backproped, sum_batch=False + ) + + return param_function diff --git a/backpack/extensions/secondorder/sqrt_ggn/convnd.py b/backpack/extensions/secondorder/sqrt_ggn/convnd.py new file mode 100644 index 000000000..74a88651c --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/convnd.py @@ -0,0 +1,29 @@ +"""Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.conv1d import Conv1DDerivatives +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.core.derivatives.conv3d import Conv3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNConv1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv1d`` module.""" + super().__init__(Conv1DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConv2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv2d`` module.""" + super().__init__(Conv2DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConv3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv3d`` module.""" + super().__init__(Conv3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py new file mode 100644 index 000000000..a18331976 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py @@ -0,0 +1,29 @@ +"""Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives +from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives +from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNConvTranspose1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose1d`` module.""" + super().__init__(ConvTranspose1DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConvTranspose2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose2d`` module.""" + super().__init__(ConvTranspose2DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConvTranspose3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose3d`` module.""" + super().__init__(ConvTranspose3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/dropout.py b/backpack/extensions/secondorder/sqrt_ggn/dropout.py new file mode 100644 index 000000000..2f03b8aa9 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/dropout.py @@ -0,0 +1,11 @@ +"""Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.dropout import DropoutDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNDropout(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Dropout`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Dropout`` module.""" + super().__init__(DropoutDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py new file mode 100644 index 000000000..23ce30e9c --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py @@ -0,0 +1,47 @@ +"""Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" +from typing import Any, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNFlatten(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Tanh`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Tanh`` module.""" + super().__init__(FlattenDerivatives()) + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def backpropagate( + self, + ext: Any, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: Tensor, + ) -> Tensor: + """Backpropagate only if flatten created a node in the computation graph. + + Otherwise, the backward hook will not be called at the right stage and + no action must be performed. + + Args: + ext: BackPACK extension calling out to the module extension. + module: Module that performed the forward pass. + grad_inp: Gradients w.r.t. the module inputs. + grad_out: Gradients w.r.t. the module outputs. + backproped: Backpropagated symmetric factorization of the loss Hessian + from the child module. + + Returns: + Symmetric loss Hessian factorization, backpropagated through the module. + """ + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/sqrt_ggn/linear.py b/backpack/extensions/secondorder/sqrt_ggn/linear.py new file mode 100644 index 000000000..4aecca6f5 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/linear.py @@ -0,0 +1,11 @@ +"""Contains extension for the linear layer used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNLinear(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Linear`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Linear`` module.""" + super().__init__(LinearDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py new file mode 100644 index 000000000..a088aed0a --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -0,0 +1,72 @@ +"""Contains base class and extensions for losses used by ``SqrtGGN{Exact, MC}``.""" +from typing import Any, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives +from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.extensions.secondorder.hbp import LossHessianStrategy +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNBaseLossModule(SqrtGGNBaseModule): + """Base class for losses used by ``SqrtGGN{Exact, MC}``.""" + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def backpropagate( + self, + ext: Any, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: None, + ) -> Tensor: + """Initialize the backpropagated quantity. + + Uses the exact loss Hessian square root, or a Monte-Carlo approximation + thereof. + + Args: + ext: BackPACK extension calling out to the module extension. + module: Module that performed the forward pass. + grad_inp: Gradients w.r.t. the module inputs. + grad_out: Gradients w.r.t. the module outputs. + backproped: Backpropagated information. Should be ``None``. + + Returns: + Symmetric factorization of the loss Hessian w.r.t. the module input. + + Raises: + NotImplementedError: For invalid strategies to represent the loss Hessian. + """ + loss_hessian_strategy = ext.get_loss_hessian_strategy() + + if loss_hessian_strategy == LossHessianStrategy.EXACT: + return self.derivatives.sqrt_hessian(module, grad_inp, grad_out) + elif loss_hessian_strategy == LossHessianStrategy.SAMPLING: + mc_samples = ext.get_num_mc_samples() + self.derivatives.sqrt_hessian_sampled( + module, grad_inp, grad_out, mc_samples=mc_samples + ) + else: + raise NotImplementedError( + f"Unknown hessian strategy {loss_hessian_strategy}" + ) + + +class SqrtGGNMSELoss(SqrtGGNBaseLossModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MSELoss`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MSELoss`` module.""" + super().__init__(MSELossDerivatives()) + + +class SqrtGGNCrossEntropyLoss(SqrtGGNBaseLossModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.CrossEntropyLoss`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.CrossEntropyLoss`` module.""" + super().__init__(CrossEntropyLossDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/padding.py b/backpack/extensions/secondorder/sqrt_ggn/padding.py new file mode 100644 index 000000000..18574f685 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/padding.py @@ -0,0 +1,11 @@ +"""Contains extensions for padding layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNZeroPad2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ZeroPad2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ZeroPad2d`` module.""" + super().__init__(ZeroPad2dDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/pooling.py b/backpack/extensions/secondorder/sqrt_ggn/pooling.py new file mode 100644 index 000000000..e19cfba2a --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/pooling.py @@ -0,0 +1,56 @@ +"""Contains extensions for pooling layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives +from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives +from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives +from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives +from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives +from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNMaxPool1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool1d`` module.""" + super().__init__(MaxPool1DDerivatives()) + + +class SqrtGGNMaxPool2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool2d`` module.""" + super().__init__(MaxPool2DDerivatives()) + + +class SqrtGGNMaxPool3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool3d`` module.""" + super().__init__(MaxPool3DDerivatives()) + + +class SqrtGGNAvgPool1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool1d`` module.""" + super().__init__(AvgPool1DDerivatives()) + + +class SqrtGGNAvgPool2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool2d`` module.""" + super().__init__(AvgPool2DDerivatives()) + + +class SqrtGGNAvgPool3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool3d`` module.""" + super().__init__(AvgPool3DDerivatives()) diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py index 49a8f1e46..7be359ef0 100644 --- a/docs_src/examples/basic_usage/example_all_in_one.py +++ b/docs_src/examples/basic_usage/example_all_in_one.py @@ -29,6 +29,7 @@ DiagGGNExact, DiagGGNMC, DiagHessian, + SqrtGGNExact, SumGradSquared, Variance, ) @@ -166,6 +167,18 @@ print(".diag_h.shape: ", param.diag_h.shape) print(".diag_h_batch.shape: ", param.diag_h_batch.shape) +# %% +# Matrix square root of the generalized Gauss-Newton + +loss = lossfunc(model(X), y) +with backpack(SqrtGGNExact()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape) + # %% # Block-diagonal curvature products # --------------------------------- diff --git a/docs_src/rtd/extensions.rst b/docs_src/rtd/extensions.rst index 2fa9e2369..c796294cb 100644 --- a/docs_src/rtd/extensions.rst +++ b/docs_src/rtd/extensions.rst @@ -25,6 +25,7 @@ Available Extensions .. autofunction:: backpack.extensions.KFRA .. autofunction:: backpack.extensions.DiagHessian .. autofunction:: backpack.extensions.BatchDiagHessian +.. autofunction:: backpack.extensions.SqrtGGNExact ----- diff --git a/fully_documented.txt b/fully_documented.txt index 3eaa0f12f..ce4380394 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -4,6 +4,8 @@ test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/hbp +test/extensions/secondorder/sqrt_ggn + backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_hessian/__init__.py @@ -12,3 +14,5 @@ backpack/extensions/secondorder/diag_hessian/conv2d.py backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/extensions/__init__.py + +backpack/extensions/secondorder/sqrt_ggn diff --git a/makefile b/makefile index fbcd353f1..b0d4726a3 100644 --- a/makefile +++ b/makefile @@ -56,16 +56,16 @@ help: ### # Test coverage test: - @pytest -vx --run-optional-tests=montecarlo --cov=backpack . + @pytest -vx -rs --run-optional-tests=montecarlo --cov=backpack . test-light: - @pytest -vx --cov=backpack . + @pytest -vx -rs --cov=backpack . test-no-gpu: - @pytest -k "not cuda" -vx --run-optional-tests=montecarlo --cov=backpack . + @pytest -k "not cuda" -vx -rs --run-optional-tests=montecarlo --cov=backpack . test-light-no-gpu: - @pytest -k "not cuda" -vx --cov=backpack . + @pytest -k "not cuda" -vx -rs --cov=backpack . ### # Linter and autoformatter diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index adedb7cf0..2fa1571e7 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -1,8 +1,9 @@ from test.extensions.implementation.base import ExtensionsImplementation import torch +from torch.nn.utils.convert_parameters import parameters_to_vector -from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist +from backpack.hessianfree.ggnvp import ggn_vector_product, ggn_vector_product_from_plist from backpack.hessianfree.rop import R_op from backpack.utils.convert_parameters import vector_to_parameter_list @@ -155,3 +156,24 @@ def diag_h_batch(self): factor = self.problem.get_reduction_factor(batch_loss, loss_list) params_batch_diag_h = list(zip(*batch_diag_h)) return [torch.stack(param) * factor for param in params_batch_diag_h] + + def ggn(self): + _, output, loss = self.problem.forward_pass() + model = self.problem.model + + num_params = sum(p.numel() for p in model.parameters()) + ggn = torch.zeros(num_params, num_params).to(self.problem.device) + + for i in range(num_params): + # GGN-vector product with i.th unit vector yields the i.th row + e_i = torch.zeros(num_params).to(self.problem.device) + e_i[i] = 1.0 + + # convert to model parameter shapes + e_i_list = vector_to_parameter_list(e_i, model.parameters()) + ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) + + ggn_i = parameters_to_vector(ggn_i_list) + ggn[i, :] = ggn_i + + return ggn diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index ae121c6c5..20416acfe 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -5,6 +5,8 @@ SumGradSquaredHook, ) +from torch import cat, einsum + import backpack.extensions as new_ext from backpack import backpack @@ -190,3 +192,18 @@ def diag_h_batch(self): diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()] return diag_h_batch + + def ggn(self): + with backpack(new_ext.SqrtGGNExact()): + _, _, loss = self.problem.forward_pass() + loss.backward() + + V = cat( + [ + p.sqrt_ggn_exact.flatten(start_dim=2) + for p in self.problem.model.parameters() + ], + dim=2, + ) + + return einsum("cni,cnj->ij", V, V) diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 0b2eb73b4..6fd1ce5f9 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -1,3 +1,9 @@ +"""Base class containing the functions to compare BackPACK and autograd.""" +from typing import List + +from torch import Tensor + + class ExtensionsImplementation: """Base class for autograd and BackPACK implementations of extensions.""" @@ -40,36 +46,44 @@ def diag_h(self): """Diagonal of Hessian""" raise NotImplementedError - def kfac(self, mc_samples=1): + def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: """Kronecker-factored approximate curvature (KFAC). Args: - mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``. + mc_samples: Number of Monte-Carlo samples. Default: ``1``. Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. """ raise NotImplementedError - def kflr(self): + def kflr(self) -> List[List[Tensor]]: """Kronecker-factored low-rank approximation (KFLR). Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. """ raise NotImplementedError - def kfra(self): + def kfra(self) -> List[List[Tensor]]: """Kronecker-factored recursive approximation (KFRA). Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. """ - def diag_h_batch(self): + def diag_h_batch(self) -> List[Tensor]: """Per-sample Hessian diagonal. Returns: list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. """ raise NotImplementedError + + def ggn(self) -> Tensor: + """Exact generalized Gauss-Newton/Fisher matrix. + + Returns: + Matrix representation of the exact GGN. + """ + raise NotImplementedError diff --git a/test/extensions/secondorder/sqrt_ggn/__init__.py b/test/extensions/secondorder/sqrt_ggn/__init__.py new file mode 100644 index 000000000..6741d2c13 --- /dev/null +++ b/test/extensions/secondorder/sqrt_ggn/__init__.py @@ -0,0 +1 @@ +"""Contains tests of ``backpack.extensions.secondorder.sqrt_ggn``.""" diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py new file mode 100644 index 000000000..99ba3689f --- /dev/null +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -0,0 +1,50 @@ +"""Tests BackPACK's ``SqrtGGNExact`` extension.""" + +from test.automated_test import check_sizes_and_values +from test.extensions.implementation.autograd import AutogradExtensions +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import ExtensionsTestProblem, make_test_problems +from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS + +from pytest import fixture, skip + +PROBLEMS = make_test_problems(SECONDORDER_SETTINGS) +IDS = [problem.make_id() for problem in PROBLEMS] + + +@fixture(params=PROBLEMS, ids=IDS) +def problem(request, max_num_params: int = 4000) -> ExtensionsTestProblem: + """Set seed, create tested model, loss, data. Finally clean up. + + Models with too many parameters are skipped. + + Args: + request (SubRequest): Request for the fixture from a test/fixture function. + max_num_params: Maximum number of model parameters to run the case. + Default: ``4000``. + + Yields: + Test case with deterministically constructed attributes. + """ + case = request.param + case.set_up() + + num_params = sum(p.numel() for p in case.model.parameters() if p.requires_grad) + if num_params <= max_num_params: + yield case + else: + skip(f"Model has too many parameters: {num_params} > {max_num_params}") + + case.tear_down() + + +def test_sqrt_ggn_exact(problem: ExtensionsTestProblem): + """Compare exact GGN from BackPACK's matrix square root with autograd. + + Args: + problem: Test case + """ + autograd_res = AutogradExtensions(problem).ggn() + backpack_res = BackpackExtensions(problem).ggn() + + check_sizes_and_values(autograd_res, backpack_res) From d11e0d834618b480080a618b160a69625e68bf84 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Mon, 21 Jun 2021 13:36:08 +0200 Subject: [PATCH 14/54] [TEST] Add setting where `Flatten` does not perform an operation (#181) Adds an architecture with a `torch.nn.Flatten` layer that does not introduce a node in the computation graph. In this case the backward hook for the `Flatten` module will be called at an unexpected stage, which must be addressed in the second-order extensions. [This](https://coveralls.io/builds/40739633/source?filename=backpack%2Fextensions%2Fsecondorder%2Fsqrt_ggn%2Fflatten.py#L45) coverage report indicates that the new test suite did not include such a test case, whereas it seems to be covered by the old tests (the coverage report for second-order extensions that are tested in the old suite covers the linked branch in the module extensions for `Flatten`). --- .../secondorder/secondorder_settings.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index f90456134..f50e4658e 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -233,3 +233,18 @@ ] SECONDORDER_SETTINGS += GROUP_CONV_SETTINGS + +SECONDORDER_SETTINGS += [ + { + # Flatten layer does not add a node in the computation graph and thus the + # backward hook will be called at an unexpected stage. This must explicitly + # be addressed in the `backpropagate` function of the flatten module extension. + "input_fn": lambda: torch.rand(3, 5), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Linear(5, 4), torch.nn.Flatten(), torch.nn.Linear(4, 2) + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 2), + "id_prefix": "flatten-no-op", + }, +] From 9569e6637265e33afc7eadf15275e0615c556a2a Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Mon, 21 Jun 2021 17:44:27 +0200 Subject: [PATCH 15/54] [ADD] `SqrtGGNMC` extension (#182) A Monte-Carlo approximation of the generalized Gauss-Newton/Fisher matrix square root. * [REF] Split instantiation and filtering of test cases * [ADD] `SqrtGGNMC` extension * [ADD] Make base implementation of extensions abstract, type annotations * [TEST] Share names of BackPACK/autograd implementation * [TEST] Compare MC with exact GGN using many samples * [FIX] Increase atol to make tests pass on GPU, too * [ADD] Integration test for `SqrtGGNMC`, fix docstring --- backpack/extensions/__init__.py | 2 + backpack/extensions/secondorder/__init__.py | 5 +- .../secondorder/sqrt_ggn/__init__.py | 38 +++++++ .../extensions/secondorder/sqrt_ggn/losses.py | 2 +- .../basic_usage/example_all_in_one.py | 6 +- docs_src/rtd/extensions.rst | 1 + fully_documented.txt | 2 + test/extensions/implementation/autograd.py | 20 +++- test/extensions/implementation/backpack.py | 36 ++++-- test/extensions/implementation/base.py | 106 +++++++++++++----- .../diag_ggn/test_batch_diag_ggn.py | 4 +- .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 76 ++++++++++--- 12 files changed, 234 insertions(+), 64 deletions(-) diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index c66752590..df0bd558d 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -14,6 +14,7 @@ DiagGGNMC, DiagHessian, SqrtGGNExact, + SqrtGGNMC, ) __all__ = [ @@ -35,4 +36,5 @@ "DiagHessian", "BatchDiagHessian", "SqrtGGNExact", + "SqrtGGNMC", ] diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index d498236fa..3d3de33c9 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -20,6 +20,8 @@ - The symmetric (square root) factorization of the GGN/Fisher information, using exact computation (:func:`SqrtGGNExact `) + or a Monte-Carlo (MC) approximation + (:func:`SqrtGGNMC`) """ from backpack.extensions.secondorder.diag_ggn import ( @@ -30,7 +32,7 @@ ) from backpack.extensions.secondorder.diag_hessian import BatchDiagHessian, DiagHessian from backpack.extensions.secondorder.hbp import HBP, KFAC, KFLR, KFRA -from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact +from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC __all__ = [ "DiagGGNExact", @@ -44,4 +46,5 @@ "KFRA", "HBP", "SqrtGGNExact", + "SqrtGGNMC", ] diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py index 5fdde4f80..d258b350f 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/__init__.py +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -119,3 +119,41 @@ class SqrtGGNExact(SqrtGGN): def __init__(self): """Use exact loss Hessian and set savefield to ``sqrt_ggn_exact``.""" super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact") + + +class SqrtGGNMC(SqrtGGN): + """Approximate matrix square root of the generalized Gauss-Newton/Fisher. + + Uses a Monte-Carlo (MC) approximation of the Hessian of the loss w.r.t. the model + output. + + Stores the output in :code:`sqrt_ggn_mc`, has shape ``[M, N, param.shape]``, + where ``M`` is the number of Monte-Carlo samples and ``N`` is the batch size. + + For a more precise but slower alternative, see + :py:meth:`backpack.extensions.SqrtGGNExact`. + + .. note:: + + (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_mc`` + can be viewed as a ``[M * N, param.numel()]`` matrix. Concatenating this + matrix over all parameters results in a matrix ``Vᵀ``, which + is the approximate GGN/Fisher's matrix square root, i.e. ``G ≈ V Vᵀ``. + """ + + def __init__(self, mc_samples: int = 1): + """Approximate loss Hessian via MC and set savefield to ``sqrt_ggn_mc``. + + Args: + mc_samples: Number of Monte-Carlo samples. Default: ``1``. + """ + self._mc_samples = mc_samples + super().__init__(LossHessianStrategy.SAMPLING, "sqrt_ggn_mc") + + def get_num_mc_samples(self) -> int: + """Return the number of MC samples used to approximate the loss Hessian. + + Returns: + Number of Monte-Carlo samples. + """ + return self._mc_samples diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py index a088aed0a..8e48cf318 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/losses.py +++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -47,7 +47,7 @@ def backpropagate( return self.derivatives.sqrt_hessian(module, grad_inp, grad_out) elif loss_hessian_strategy == LossHessianStrategy.SAMPLING: mc_samples = ext.get_num_mc_samples() - self.derivatives.sqrt_hessian_sampled( + return self.derivatives.sqrt_hessian_sampled( module, grad_inp, grad_out, mc_samples=mc_samples ) else: diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py index 7be359ef0..cb7aba42d 100644 --- a/docs_src/examples/basic_usage/example_all_in_one.py +++ b/docs_src/examples/basic_usage/example_all_in_one.py @@ -30,6 +30,7 @@ DiagGGNMC, DiagHessian, SqrtGGNExact, + SqrtGGNMC, SumGradSquared, Variance, ) @@ -168,16 +169,17 @@ print(".diag_h_batch.shape: ", param.diag_h_batch.shape) # %% -# Matrix square root of the generalized Gauss-Newton +# Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation loss = lossfunc(model(X), y) -with backpack(SqrtGGNExact()): +with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)): loss.backward() for name, param in model.named_parameters(): print(name) print(".grad.shape: ", param.grad.shape) print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape) + print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape) # %% # Block-diagonal curvature products diff --git a/docs_src/rtd/extensions.rst b/docs_src/rtd/extensions.rst index c796294cb..9eea8df02 100644 --- a/docs_src/rtd/extensions.rst +++ b/docs_src/rtd/extensions.rst @@ -26,6 +26,7 @@ Available Extensions .. autofunction:: backpack.extensions.DiagHessian .. autofunction:: backpack.extensions.BatchDiagHessian .. autofunction:: backpack.extensions.SqrtGGNExact +.. autofunction:: backpack.extensions.SqrtGGNMC ----- diff --git a/fully_documented.txt b/fully_documented.txt index ce4380394..e1fdfc81b 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -6,6 +6,8 @@ test/extensions/secondorder/hbp test/extensions/secondorder/sqrt_ggn +test/extensions/implementation/base.py + backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_hessian/__init__.py diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 2fa1571e7..f26979a38 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -94,7 +94,7 @@ def diag_ggn(self): _, output, loss = self.problem.forward_pass() return self._get_diag_ggn(loss, output) - def diag_ggn_batch(self): + def diag_ggn_exact_batch(self): batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() loss_list = torch.zeros(batch_size, device=self.problem.device) @@ -177,3 +177,21 @@ def ggn(self): ggn[i, :] = ggn_i return ggn + + def diag_ggn_mc(self, mc_samples): + raise NotImplementedError + + def diag_ggn_mc_batch(self, mc_samples): + raise NotImplementedError + + def ggn_mc(self, mc_samples, chunks=1): + raise NotImplementedError + + def kfac(self): + raise NotImplementedError + + def kflr(self): + raise NotImplementedError + + def kfra(self): + raise NotImplementedError diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index 20416acfe..04a2c7df8 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -4,8 +4,9 @@ ExtensionHookManager, SumGradSquaredHook, ) +from typing import List -from torch import cat, einsum +from torch import Tensor, cat, einsum import backpack.extensions as new_ext from backpack import backpack @@ -194,16 +195,33 @@ def diag_h_batch(self): return diag_h_batch def ggn(self): + return self._square_sqrt_ggn(self.sqrt_ggn()) + + def sqrt_ggn(self) -> List[Tensor]: with backpack(new_ext.SqrtGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - V = cat( - [ - p.sqrt_ggn_exact.flatten(start_dim=2) - for p in self.problem.model.parameters() - ], - dim=2, - ) + return [p.sqrt_ggn_exact for p in self.problem.model.parameters()] + + def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: + with backpack(new_ext.SqrtGGNMC(mc_samples=mc_samples)): + _, _, loss = self.problem.forward_pass() + loss.backward() + + return [p.sqrt_ggn_mc for p in self.problem.model.parameters()] + + def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: + samples = self.chunk_sizes(mc_samples, chunks) + weights = [samples / mc_samples for samples in samples] - return einsum("cni,cnj->ij", V, V) + return sum( + w * self._square_sqrt_ggn(self.sqrt_ggn_mc(s)) + for w, s in zip(weights, samples) + ) + + @staticmethod + def _square_sqrt_ggn(sqrt_ggn: List[Tensor]) -> Tensor: + """Utility function to concatenate and square the GGN factorization.""" + sqrt_mat = cat([s.flatten(start_dim=2) for s in sqrt_ggn], dim=2) + return einsum("cni,cnj->ij", sqrt_mat, sqrt_mat) diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 6fd1ce5f9..f3d73d91a 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -1,51 +1,76 @@ """Base class containing the functions to compare BackPACK and autograd.""" +from abc import ABC, abstractmethod +from test.extensions.problem import ExtensionsTestProblem from typing import List from torch import Tensor -class ExtensionsImplementation: +class ExtensionsImplementation(ABC): """Base class for autograd and BackPACK implementations of extensions.""" - def __init__(self, problem): + def __init__(self, problem: ExtensionsTestProblem): + """Store the test case. + + Args: + problem: Test case. + """ self.problem = problem - def batch_grad(self): + @abstractmethod + def batch_grad(self) -> List[Tensor]: """Individual gradients.""" - raise NotImplementedError + return - def batch_l2_grad(self): + @abstractmethod + def batch_l2_grad(self) -> List[Tensor]: """L2 norm of Individual gradients.""" - raise NotImplementedError + return + + @abstractmethod + def sgs(self) -> List[Tensor]: + """Sum of Square of Individual gradients.""" + return - def sgs(self): - """Sum of Square of Individual gradients""" - raise NotImplementedError + @abstractmethod + def variance(self) -> List[Tensor]: + """Variance of Individual gradients.""" + return - def variance(self): - """Variance of Individual gradients""" - raise NotImplementedError + @abstractmethod + def diag_ggn(self) -> List[Tensor]: + """Diagonal of Gauss Newton.""" + return - def diag_ggn(self): - """Diagonal of Gauss Newton""" - raise NotImplementedError + @abstractmethod + def diag_ggn_exact_batch(self) -> List[Tensor]: + """Individual diagonal of Generalized Gauss-Newton/Fisher.""" + return - def diag_ggn_batch(self): - """Individual diagonal of Generalized Gauss-Newton/Fisher""" - raise NotImplementedError + @abstractmethod + def diag_ggn_mc(self, mc_samples: int) -> List[Tensor]: + """MC approximation of the generalized Gauss-Newton/Fisher diagonal. - def diag_ggn_mc(self, mc_samples): - """MC approximation of Diagonal of Gauss Newton""" - raise NotImplementedError + Args: + mc_samples: Number of Monte-Carlo samples used for the approximation. + """ + return - def diag_ggn_mc_batch(self, mc_samples): - """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal.""" - raise NotImplementedError + @abstractmethod + def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: + """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal. - def diag_h(self): - """Diagonal of Hessian""" - raise NotImplementedError + Args: + mc_samples: Number of Monte-Carlo samples used for the approximation. + """ + return + @abstractmethod + def diag_h(self) -> List[Tensor]: + """Diagonal of Hessian.""" + return + + @abstractmethod def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: """Kronecker-factored approximate curvature (KFAC). @@ -55,35 +80,54 @@ def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: Returns: Parameter-wise lists of Kronecker factors. """ - raise NotImplementedError + return + @abstractmethod def kflr(self) -> List[List[Tensor]]: """Kronecker-factored low-rank approximation (KFLR). Returns: Parameter-wise lists of Kronecker factors. """ - raise NotImplementedError + return + @abstractmethod def kfra(self) -> List[List[Tensor]]: """Kronecker-factored recursive approximation (KFRA). Returns: Parameter-wise lists of Kronecker factors. """ + return + @abstractmethod def diag_h_batch(self) -> List[Tensor]: """Per-sample Hessian diagonal. Returns: list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. """ - raise NotImplementedError + return + @abstractmethod def ggn(self) -> Tensor: """Exact generalized Gauss-Newton/Fisher matrix. Returns: Matrix representation of the exact GGN. """ - raise NotImplementedError + return + + @abstractmethod + def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: + """Compute the MC-approximation of the GGN in chunks of MC samples. + + Args: + mc_samples: Number of Monte-Carlo samples. + chunks: Number of sequential portions to split the computation. + Default: ``1`` (no sequential split). + + Returns: + Matrix representation of the Monte-Carlo approximated GGN. + """ + return diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index 378c0a2c7..f0587b57b 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn_batch(problem): +def test_diag_ggn_exact_batch(problem): """Test the individual diagonal of Generalized Gauss-Newton/Fisher Args: @@ -20,7 +20,7 @@ def test_diag_ggn_batch(problem): problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() - autograd_res = AutogradExtensions(problem).diag_ggn_batch() + autograd_res = AutogradExtensions(problem).diag_ggn_exact_batch() check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py index 99ba3689f..96cde05c5 100644 --- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -1,4 +1,4 @@ -"""Tests BackPACK's ``SqrtGGNExact`` extension.""" +"""Tests BackPACK's ``SqrtGGNExact`` and ``SqrtGGNMC`` extension.""" from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions @@ -6,45 +6,87 @@ from test.extensions.problem import ExtensionsTestProblem, make_test_problems from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS -from pytest import fixture, skip +from pytest import fixture, mark, skip PROBLEMS = make_test_problems(SECONDORDER_SETTINGS) -IDS = [problem.make_id() for problem in PROBLEMS] -@fixture(params=PROBLEMS, ids=IDS) -def problem(request, max_num_params: int = 4000) -> ExtensionsTestProblem: +@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) +def instantiated_problem(request) -> ExtensionsTestProblem: """Set seed, create tested model, loss, data. Finally clean up. - Models with too many parameters are skipped. - Args: request (SubRequest): Request for the fixture from a test/fixture function. - max_num_params: Maximum number of model parameters to run the case. - Default: ``4000``. Yields: Test case with deterministically constructed attributes. """ case = request.param case.set_up() + yield case + case.tear_down() - num_params = sum(p.numel() for p in case.model.parameters() if p.requires_grad) + +@fixture +def small_problem( + instantiated_problem: ExtensionsTestProblem, max_num_params=4000 +) -> ExtensionsTestProblem: + """Skip architectures with too many parameters whose GGN is expensive to evaluate. + + Args: + instantiated_problem: Test case with instantiated model, data, etc. + max_num_params: Maximum number of model parameters to run the case. + Default: ``4000``. + + Yields: + Instantiated test case whose model's are small enough. + """ + num_params = sum( + p.numel() for p in instantiated_problem.model.parameters() if p.requires_grad + ) if num_params <= max_num_params: - yield case + yield instantiated_problem else: skip(f"Model has too many parameters: {num_params} > {max_num_params}") - case.tear_down() - -def test_sqrt_ggn_exact(problem: ExtensionsTestProblem): +def test_ggn_exact(small_problem: ExtensionsTestProblem): """Compare exact GGN from BackPACK's matrix square root with autograd. Args: - problem: Test case + small_problem: Test case with small network whose GGN can be evaluated. """ - autograd_res = AutogradExtensions(problem).ggn() - backpack_res = BackpackExtensions(problem).ggn() + autograd_res = AutogradExtensions(small_problem).ggn() + backpack_res = BackpackExtensions(small_problem).ggn() check_sizes_and_values(autograd_res, backpack_res) + + +def test_sqrt_ggn_mc_integration(small_problem: ExtensionsTestProblem): + """Check if MC-approximated GGN matrix square root code executes. + + Note: + This test does not perform correctness checks on the results, + which are expensive because a large number of samples is required. + Such a check is performed by `test_sqrt_ggn_mc`, which is run less + frequently. + + Args: + small_problem: Test case with small network whose GGN can be evaluated. + """ + BackpackExtensions(small_problem).sqrt_ggn_mc(mc_samples=1) + + +@mark.montecarlo +def test_ggn_mc(small_problem: ExtensionsTestProblem): + """Compare MC-approximated GGN from BackpACK's with exact version from autograd. + + Args: + small_problem: Test case with small network whose GGN can be evaluated. + """ + autograd_res = AutogradExtensions(small_problem).ggn() + atol, rtol = 1e-3, 1e-2 + mc_samples, chunks = 500000, 50 + backpack_res = BackpackExtensions(small_problem).ggn_mc(mc_samples, chunks=chunks) + + check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol) From aaa26d42d86c1f9a8c7a650b413b8e4a635c9dd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 23 Jun 2021 12:44:38 +0200 Subject: [PATCH 16/54] [REF] Introduce base class for `BatchL2Grad` extension (#175) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Base class automatically uses derivatives to compute individual gradients, then compute their L2 norm. Those default extraction functions can be overwritten if there exists a more efficient extraction without computing individual gradients (e.g. linear layers). --- * Refactor: base class for batch_l2 * format * simplify convnd: delete bias * refactor convtransposend.py * refactor convnd.py * refactor linear.py * type hint linear.py * refactor batch_l2_base.py * delete "params=" Co-authored-by: Tim Schäfer --- .../firstorder/batch_l2_grad/__init__.py | 23 ++---- .../firstorder/batch_l2_grad/batch_l2_base.py | 75 +++++++++++++++++++ .../firstorder/batch_l2_grad/conv1d.py | 6 -- .../firstorder/batch_l2_grad/conv2d.py | 6 -- .../firstorder/batch_l2_grad/conv3d.py | 6 -- .../firstorder/batch_l2_grad/convnd.py | 57 ++++++++++---- .../batch_l2_grad/convtranspose1d.py | 8 -- .../batch_l2_grad/convtranspose2d.py | 8 -- .../batch_l2_grad/convtranspose3d.py | 8 -- .../batch_l2_grad/convtransposend.py | 57 ++++++++++---- .../firstorder/batch_l2_grad/linear.py | 46 +++++++++--- .../firstorder/batch_l2_grad/rnn.py | 68 ++--------------- backpack/utils/conv.py | 2 +- backpack/utils/conv_transpose.py | 2 +- fully_documented.txt | 3 +- 15 files changed, 216 insertions(+), 159 deletions(-) create mode 100644 backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/conv1d.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/conv2d.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/conv3d.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py delete mode 100644 backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 7d2d97ed1..c68bfc820 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -16,16 +16,7 @@ from backpack.extensions.backprop_extension import BackpropExtension -from . import ( - conv1d, - conv2d, - conv3d, - convtranspose1d, - convtranspose2d, - convtranspose3d, - linear, - rnn, -) +from . import convnd, convtransposend, linear, rnn class BatchL2Grad(BackpropExtension): @@ -56,12 +47,12 @@ def __init__(self): fail_mode="WARNING", module_exts={ Linear: linear.BatchL2Linear(), - Conv1d: conv1d.BatchL2Conv1d(), - Conv2d: conv2d.BatchL2Conv2d(), - Conv3d: conv3d.BatchL2Conv3d(), - ConvTranspose1d: convtranspose1d.BatchL2ConvTranspose1d(), - ConvTranspose2d: convtranspose2d.BatchL2ConvTranspose2d(), - ConvTranspose3d: convtranspose3d.BatchL2ConvTranspose3d(), + Conv1d: convnd.BatchL2Conv1d(), + Conv2d: convnd.BatchL2Conv2d(), + Conv3d: convnd.BatchL2Conv3d(), + ConvTranspose1d: convtransposend.BatchL2ConvTranspose1d(), + ConvTranspose2d: convtransposend.BatchL2ConvTranspose2d(), + ConvTranspose3d: convtransposend.BatchL2ConvTranspose3d(), RNN: rnn.BatchL2RNN(), }, ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py new file mode 100644 index 000000000..6604f811e --- /dev/null +++ b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py @@ -0,0 +1,75 @@ +"""Contains Base class for batch_l2_grad.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.extensions.firstorder.base import FirstOrderModuleExtension + +if TYPE_CHECKING: + from backpack.extensions import BatchL2Grad + + +class BatchL2Base(FirstOrderModuleExtension): + """BaseExtension for batch_l2.""" + + def __init__(self, params: List[str], derivatives: BaseParameterDerivatives = None): + """Initialization. + + If derivatives object is provided initializes methods that compute batch_l2. + If there is an existent method in a child class it is not overwritten. + + Args: + params: parameter names + derivatives: derivatives object. Defaults to None. + """ + if derivatives is not None: + self.derivatives: BaseParameterDerivatives = derivatives + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + super().__init__(params=params) + + def _make_param_function( + self, param_str: str + ) -> Callable[[BatchL2Grad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: + """Creates a function that calculates batch_l2. + + Args: + param_str: name of parameter + + Returns: + function that calculates batch_l2 + """ + + def param_function( + ext: BatchL2Grad, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: None, + ) -> Tensor: + """Calculates batch_l2 with the help of derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + bpQuantities: additional quantities for second order + + Returns: + batch_l2 + """ + param_dims: List[int] = list(range(1, 1 + getattr(module, param_str).dim())) + return ( + getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=False + ) + ** 2 + ).sum(param_dims) + + return param_function diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv1d.py b/backpack/extensions/firstorder/batch_l2_grad/conv1d.py deleted file mode 100644 index 64eb36066..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/conv1d.py +++ /dev/null @@ -1,6 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND - - -class BatchL2Conv1d(BatchL2ConvND): - def __init__(self): - super().__init__(N=1, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py b/backpack/extensions/firstorder/batch_l2_grad/conv2d.py deleted file mode 100644 index 327c90598..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py +++ /dev/null @@ -1,6 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND - - -class BatchL2Conv2d(BatchL2ConvND): - def __init__(self): - super().__init__(N=2, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv3d.py b/backpack/extensions/firstorder/batch_l2_grad/conv3d.py deleted file mode 100644 index 369f6bb8a..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/conv3d.py +++ /dev/null @@ -1,6 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND - - -class BatchL2Conv3d(BatchL2ConvND): - def __init__(self): - super().__init__(N=3, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/convnd.py b/backpack/extensions/firstorder/batch_l2_grad/convnd.py index 55c542a03..991eb96e2 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convnd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convnd.py @@ -1,23 +1,54 @@ +"""batch_l2 extension for Conv.""" from torch import einsum -from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.core.derivatives.conv1d import Conv1DDerivatives +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.core.derivatives.conv3d import Conv3DDerivatives +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base from backpack.utils import conv as convUtils -class BatchL2ConvND(FirstOrderModuleExtension): - def __init__(self, N, params=None): - super().__init__(params=params) - self.N = N +class BatchL2ConvND(BatchL2Base): + """batch_l2 extension for Conv.""" - # TODO Use bias Jacobian to compute `bias_gradient` - def bias(self, ext, module, g_inp, g_out, backproped): - spatial_dims = list(range(2, g_out[0].dim())) - channel_dim = 1 + def weight(self, ext, module, g_inp, g_out, backproped): + """batch_l2 for weight. - return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim) + Args: + ext: extension + module: module + g_inp: input gradients + g_out: output gradients + backproped: backpropagation quantities - def weight(self, ext, module, g_inp, g_out, backproped): + Returns: + batch_l2 for weight + """ X, dE_dY = convUtils.get_weight_gradient_factors( - module.input0, g_out[0], module, self.N + module.input0, g_out[0], module ) - return einsum("nmi,nki,nmj,nkj->n", (dE_dY, X, dE_dY, X)) + return einsum("nmi,nki,nmj,nkj->n", dE_dY, X, dE_dY, X) + + +class BatchL2Conv1d(BatchL2ConvND): + """batch_l2 extension for Conv1d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=Conv1DDerivatives()) + + +class BatchL2Conv2d(BatchL2ConvND): + """batch_l2 extension for Conv2d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=Conv2DDerivatives()) + + +class BatchL2Conv3d(BatchL2ConvND): + """batch_l2 extension for Conv3d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=Conv3DDerivatives()) diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py deleted file mode 100644 index aad345c58..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py +++ /dev/null @@ -1,8 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convtransposend import ( - BatchL2ConvTransposeND, -) - - -class BatchL2ConvTranspose1d(BatchL2ConvTransposeND): - def __init__(self): - super().__init__(N=1, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py deleted file mode 100644 index 0d916fbed..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py +++ /dev/null @@ -1,8 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convtransposend import ( - BatchL2ConvTransposeND, -) - - -class BatchL2ConvTranspose2d(BatchL2ConvTransposeND): - def __init__(self): - super().__init__(N=2, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py deleted file mode 100644 index 8a1f5e257..000000000 --- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py +++ /dev/null @@ -1,8 +0,0 @@ -from backpack.extensions.firstorder.batch_l2_grad.convtransposend import ( - BatchL2ConvTransposeND, -) - - -class BatchL2ConvTranspose3d(BatchL2ConvTransposeND): - def __init__(self): - super().__init__(N=3, params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py index 9ceaa7881..3c54be1f5 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py @@ -1,23 +1,54 @@ +"""batch_l2 extension for ConvTranspose.""" from torch import einsum -from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives +from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives +from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base from backpack.utils import conv_transpose as convTransposeUtils -class BatchL2ConvTransposeND(FirstOrderModuleExtension): - def __init__(self, N, params=None): - super().__init__(params=params) - self.N = N +class BatchL2ConvTransposeND(BatchL2Base): + """batch_l2 extension for ConvTranspose.""" - # TODO Use bias Jacobian to compute `bias_gradient` - def bias(self, ext, module, g_inp, g_out, backproped): - spatial_dims = list(range(2, g_out[0].dim())) - channel_dim = 1 + def weight(self, ext, module, g_inp, g_out, backproped): + """batch_l2 for weight. - return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim) + Args: + ext: extension + module: module + g_inp: input gradients + g_out: output gradients + backproped: backpropagation quantities - def weight(self, ext, module, g_inp, g_out, backproped): + Returns: + batch_l2 for weight + """ X, dE_dY = convTransposeUtils.get_weight_gradient_factors( - module.input0, g_out[0], module, self.N + module.input0, g_out[0], module ) - return einsum("nmi,nki,nmj,nkj->n", (dE_dY, X, dE_dY, X)) + return einsum("nmi,nki,nmj,nkj->n", dE_dY, X, dE_dY, X) + + +class BatchL2ConvTranspose1d(BatchL2ConvTransposeND): + """batch_l2 extension for ConvTranspose1d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=ConvTranspose1DDerivatives()) + + +class BatchL2ConvTranspose2d(BatchL2ConvTransposeND): + """batch_l2 extension for ConvTranspose2d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=ConvTranspose2DDerivatives()) + + +class BatchL2ConvTranspose3d(BatchL2ConvTransposeND): + """batch_l2 extension for ConvTranspose3d.""" + + def __init__(self): + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=ConvTranspose3DDerivatives()) diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 6a9b3b73c..9fb17536c 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -1,15 +1,43 @@ -from torch import einsum +"""Contains batch_l2 extension for Linear.""" +from __future__ import annotations -from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from typing import TYPE_CHECKING, Tuple +from torch import Tensor, einsum +from torch.nn import Linear + +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base + +if TYPE_CHECKING: + from backpack.extensions import BatchL2Grad + + +class BatchL2Linear(BatchL2Base): + """batch_l2 extension for Linear.""" -class BatchL2Linear(FirstOrderModuleExtension): def __init__(self): - super().__init__(params=["bias", "weight"]) + """Initialization.""" + super().__init__(["bias", "weight"], derivatives=LinearDerivatives()) + + def weight( + self, + ext: BatchL2Grad, + module: Linear, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + backproped: None, + ) -> Tensor: + """batch_l2 for weight. - def bias(self, ext, module, g_inp, g_out, backproped): - C_axis = 1 - return (g_out[0] ** 2).sum(C_axis) + Args: + ext: extension + module: module + g_inp: input gradients + g_out: output gradients + backproped: backpropagation quantities - def weight(self, ext, module, g_inp, g_out, backproped): - return einsum("ni,nj->n", (g_out[0] ** 2, module.input0 ** 2)) + Returns: + batch_l2 for weight + """ + return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py index 4540eef9a..159be22e2 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py @@ -1,70 +1,14 @@ """Contains BatchL2RNN.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, Tuple - -from torch import Tensor -from torch.nn import Module - from backpack.core.derivatives.rnn import RNNDerivatives -from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base -if TYPE_CHECKING: - from backpack.extensions import BatchL2Grad - -class BatchL2RNN(FirstOrderModuleExtension): +class BatchL2RNN(BatchL2Base): """Extension for RNN, calculating batch_l2.""" def __init__(self): """Initialization.""" - params = ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"] - for param_str in params: - if not hasattr(self, param_str): - setattr(self, param_str, self._make_param_function(param_str)) - super(BatchL2RNN, self).__init__(params=params) - self.derivatives: RNNDerivatives = RNNDerivatives() - - def _make_param_function( - self, param_str: str - ) -> Callable[[BatchL2Grad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: - """Creates a function that calculates batch_l2. - - Args: - param_str: name of parameter - - Returns: - function that calculates batch_l2 - """ - - def param_function( - ext: BatchL2Grad, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - bpQuantities: None, - ) -> Tensor: - """Calculates batch_l2 with the help of derivatives object. - - Args: - ext: extension that is used - module: module that performed forward pass - g_inp: input gradient tensors - g_out: output gradient tensors - bpQuantities: additional quantities for second order - - Returns: - batch_l2 - """ - return ( - ( - getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=False - ) - ** 2 - ) - .flatten(start_dim=1) - .sum(1) - ) - - return param_function + super(BatchL2RNN, self).__init__( + ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + derivatives=RNNDerivatives(), + ) diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py index 14d394f54..e2838e027 100644 --- a/backpack/utils/conv.py +++ b/backpack/utils/conv.py @@ -30,7 +30,7 @@ def unfold_input(module, input): return unfold_by_conv(input, module) -def get_weight_gradient_factors(input, grad_out, module, N): +def get_weight_gradient_factors(input, grad_out, module): X = unfold_input(module, input) dE_dY = rearrange(grad_out, "n c ... -> n c (...)") return X, dE_dY diff --git a/backpack/utils/conv_transpose.py b/backpack/utils/conv_transpose.py index d1183e909..6b4d138ec 100644 --- a/backpack/utils/conv_transpose.py +++ b/backpack/utils/conv_transpose.py @@ -8,7 +8,7 @@ from backpack.utils.conv import extract_bias_diagonal as conv_extract_bias_diagonal -def get_weight_gradient_factors(input, grad_out, module, N): +def get_weight_gradient_factors(input, grad_out, module): M, C_in = input.shape[0], input.shape[1] kernel_size_numel = module.weight.shape[2:].numel() diff --git a/fully_documented.txt b/fully_documented.txt index d86430c88..8b38fd07e 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -23,8 +23,7 @@ backpack/extensions/firstorder/variance/__init__.py backpack/extensions/firstorder/sum_grad_squared/sgs_base.py backpack/extensions/firstorder/sum_grad_squared/rnn.py backpack/extensions/firstorder/sum_grad_squared/__init__.py -backpack/extensions/firstorder/batch_l2_grad/rnn.py -backpack/extensions/firstorder/batch_l2_grad/__init__.py +backpack/extensions/firstorder/batch_l2_grad/ backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_ggn/__init__.py From 8680bebbe0822fa74c337bd37a39cc4abe19c7cf Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 23 Jun 2021 14:01:40 +0200 Subject: [PATCH 17/54] [core] Support additional dimensions in input to `Linear` (#185) Addssupport for additional axes to the `LinearDerivatives` in the `core`. --- * [TEST] Preserve batch axis of individual samples To avoid PyTorch from accidentally identifying the additional axis of an individual sample fed through a linear layer as batch axis, explicitly keeps the batch axis during slicing out individual samples of a mini-batch. * [core] Support additional dimensions in Linear jac_t_mat_prod * [core] Support additional dimensions in Linear jac_mat_prod * [core] Support additional dimensions in Linear weight_jac_t * [core] Support additional dimensions in Linear weight_jac * [core] Support additional dimensions in Linear bias_jac_t * [core] Support additional dimensions in linear bias_jac * [core] Support additional dimensions in Linear ea_jac_t_mat_jac * [TEST] Merge additional dimension cases with regular cases * [DOC] Fully document LinearDerivatives --- backpack/core/derivatives/linear.py | 227 +++++++++++++++--- fully_documented.txt | 2 + test/core/derivatives/derivatives_test.py | 76 +++--- .../derivatives/implementation/autograd.py | 4 +- test/core/derivatives/linear_settings.py | 21 +- 5 files changed, 255 insertions(+), 75 deletions(-) diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 0e1f5f32b..2b540b413 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -1,4 +1,8 @@ -from torch import einsum +"""Contains partial derivatives for the ``torch.nn.Linear`` layer.""" +from typing import Any + +from torch import Size, Tensor, einsum +from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives @@ -14,43 +18,202 @@ class LinearDerivatives(BaseParameterDerivatives): * i: Input dimension """ - def hessian_is_zero(self): + def hessian_is_zero(self) -> bool: + """Linear layer output is linear w.r.t. to its input. + + Returns: + True + """ return True - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - """Apply transposed Jacobian of the output w.r.t. the input.""" - d_input = module.weight.data - return einsum("oi,vno->vni", (d_input, mat)) + def _jac_t_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the input. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. + + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *, in_features]``. + """ + return einsum("oi,vn...o->vn...i", module.weight.data, mat) + + def _jac_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the input. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer input + (``[N, *, in_features]``) to which the output-input Jacobian is applied. + Has shape ``[V, N, *, in_features]``. + + Returns: + Batched Jacobian vector products. Has shape ``[V, N, *, out_features]``. + """ + return einsum("oi,vn...i->vn...o", module.weight.data, mat) + + def ea_jac_t_mat_jac_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor + ) -> Tensor: + """Expectation approximation of outer product with input-output Jacobian. + + Used for KFRA backpropagation: ``mat ← E(Jₙᵀ mat Jₙ) = 1/N ∑ₙ Jₙᵀ mat Jₙ``. - def _jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the input.""" - d_input = module.weight.data - return einsum("oi,vni->vno", (d_input, mat)) + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Matrix of shape + ``[module.output.numel() // N, module.output.numel() // N]``. + + Returns: + Matrix of shape + ``[module.input0.numel() // N, module.input0.numel() // N]``. + """ + add_features = self._get_additional_dims(module).numel() + in_features, out_features = module.in_features, module.out_features - def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): jac = module.weight.data - return einsum("ik,ij,jl->kl", (jac, mat, jac)) - def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the weight.""" - d_weight = module.input0 - return einsum("ni,voi->vno", (d_weight, mat)) + result = mat.reshape(add_features, out_features, add_features, out_features) + result = einsum("ik,xiyj,jl->xkyl", (jac, result, jac)) + + return result.reshape(in_features * add_features, in_features * add_features) + + def _weight_jac_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the weight. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of shape ``module.weight.shape`` to which the + transposed output-input Jacobian is applied. Has shape + ``[V, *module.weight.shape]``. - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Apply transposed Jacobian of the output w.r.t. the weight.""" + Returns: + Batched Jacobian vector products. Has shape + ``[V, N, *module.output.shape]``. + """ + return einsum("n...i,voi->vn...o", module.input0, mat) + + def _weight_jac_t_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor, sum_batch: int = True + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the weight. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. + sum_batch: Sum the result's batch axis. Default: ``True``. + + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. + """ d_weight = module.input0 - contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi" - return einsum(contract, (mat, d_weight)) - - def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the bias.""" - N = module.input0.size(0) - return mat.unsqueeze(1).expand(-1, N, -1) - - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Apply transposed Jacobian of the output w.r.t. the bias.""" - if sum_batch: - N_axis = 1 - return mat.sum(N_axis) + + if self._has_additional_dims(module): + # Flatten additional dimensions because they cannot be represented as + # ellipsis. WAITING https://github.com/pytorch/pytorch/issues/45854 + d_weight = d_weight.flatten(start_dim=1, end_dim=-2) + mat = mat.flatten(start_dim=2, end_dim=-2) + equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi" else: - return mat + equation = f"vno,ni->v{'' if sum_batch else 'n'}oi" + + return einsum(equation, mat, d_weight) + + def _bias_jac_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the bias. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of shape ``module.bias.shape`` to which the + transposed output-input Jacobian is applied. Has shape + ``[V, *module.bias.shape]``. + + Returns: + Batched Jacobian vector products. Has shape + ``[V, N, *module.output.shape]``. + """ + N = module.input0.shape[0] + additional_dims = list(self._get_additional_dims(module)) + + for _ in range(len(additional_dims) + 1): + mat = mat.unsqueeze(1) + + expand = [-1, N] + additional_dims + [-1] + + return mat.expand(*expand) + + def _bias_jac_t_mat_prod( + self, module: Linear, g_inp: Any, g_out: Any, mat: Tensor, sum_batch: int = True + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the bias. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. + sum_batch: Sum the result's batch axis. Default: ``True``. + + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *module.bias.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. + """ + equation = f"vn...o->v{'' if sum_batch else 'n'}o" + + return einsum(equation, mat) + + def _has_additional_dims(self, module: Linear) -> bool: + """Return whether the input to a linear layer has additional (>1) dimensions. + + The input to a linear layer may have shape ``[N, *, out_features]``. + It has additional dimensions if ``*`` is non-empty. + + Args: + module: Linear layer. + + Returns: + Whether the input has hidden dimensions. + """ + return len(self._get_additional_dims(module)) != 0 + + def _get_additional_dims(self, module: Linear) -> Size: + """Return the shape of additional dimensions in the input to a linear layer. + + Args: + module: A linear layer. + + Returns: + Shape of the additional dimensions. Corresponds to ``*`` in the + input shape ``[N, *, out_features]``. + """ + return module.input0.shape[1:-1] diff --git a/fully_documented.txt b/fully_documented.txt index e1fdfc81b..ca5f6ec86 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -18,3 +18,5 @@ backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/extensions/__init__.py backpack/extensions/secondorder/sqrt_ggn + +backpack/core/derivatives/linear.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 776edcd03..1192490cd 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -11,7 +11,7 @@ from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS -from test.core.derivatives.problem import make_test_problems +from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.settings import SETTINGS import pytest @@ -34,12 +34,12 @@ @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_jac_mat_prod(problem, V=3): +def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized Jacobian-vector products. + problem: Test case. + V (int): Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.input_shape).to(problem.device) @@ -52,12 +52,12 @@ def test_jac_mat_prod(problem, V=3): @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_jac_t_mat_prod(problem, V=3): +def test_jac_t_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the transposed Jacobian-matrix product. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) @@ -86,14 +86,16 @@ def test_jac_t_mat_prod(problem, V=3): ids=["save_memory=True", "save_memory=False"], ) @pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) -def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): - """Test the transposed Jacobian-matrix product w.r.t. to the weights. +def test_weight_jac_t_mat_prod( + problem: DerivativesTestProblem, sum_batch: bool, save_memory: bool, V: int = 3 +) -> None: + """Test the transposed Jacobian-matrix product w.r.t. to the weight. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - save_memory (bool): Use Owkin implementation to save memory. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + sum_batch: Sum out the batch dimension. + save_memory: Use Owkin implementation in convolutions to save memory. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) @@ -109,12 +111,12 @@ def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): @pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) -def test_weight_jac_mat_prod(problem, V=3): - """Test the Jacobian-matrix product w.r.t. to the weights. +def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + """Test the Jacobian-matrix product w.r.t. to the weight. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.module.weight.shape).to(problem.device) @@ -137,18 +139,16 @@ def test_weight_jac_mat_prod(problem, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_BIAS, - ids=IDS_WITH_BIAS, -) -def test_bias_jac_t_mat_prod(problem, sum_batch, V=3): - """Test the transposed Jacobian-matrix product w.r.t. to the biass. +@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +def test_bias_jac_t_mat_prod( + problem: DerivativesTestProblem, sum_batch: bool, V: int = 3 +) -> None: + """Test the transposed Jacobian-matrix product w.r.t. to the bias. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + sum_batch: Sum out the batch dimension. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) @@ -160,17 +160,13 @@ def test_bias_jac_t_mat_prod(problem, sum_batch, V=3): problem.tear_down() -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_BIAS, - ids=IDS_WITH_BIAS, -) -def test_bias_jac_mat_prod(problem, V=3): - """Test the Jacobian-matrix product w.r.t. to the biass. +@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + """Test the Jacobian-matrix product w.r.t. to the bias. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.module.bias.shape).to(problem.device) @@ -259,8 +255,8 @@ def test_sum_hessian_should_fail(problem): @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_ea_jac_t_mat_jac_prod(problem): - """Test KFRA backpropagation +def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem) -> None: + """Test KFRA backpropagation. H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ @@ -271,10 +267,10 @@ def test_ea_jac_t_mat_jac_prod(problem): as `Dropout` is not deterministic. Args: - problem (DerivativesProblem): Problem for derivative test. + problem: Test case. """ problem.set_up() - out_features = torch.prod(torch.tensor(problem.output_shape[1:])) + out_features = problem.output_shape[1:].numel() mat = torch.rand(out_features, out_features).to(problem.device) backpack_res = BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index e7a60f6b2..1cf7fa96a 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -49,8 +49,8 @@ def param_jac_t_vec_prod(self, name, vec, sum_batch): else: N = input.shape[0] - sample_outputs = [output[n] for n in range(N)] - sample_vecs = [vec[n] for n in range(N)] + sample_outputs = [output[[n]] for n in range(N)] + sample_vecs = [vec[[n]] for n in range(N)] jac_t_sample_prods = [ transposed_jacobian_vector_product(n_out, param, n_vec)[0] diff --git a/test/core/derivatives/linear_settings.py b/test/core/derivatives/linear_settings.py index df7132d99..a98df7aab 100644 --- a/test/core/derivatives/linear_settings.py +++ b/test/core/derivatives/linear_settings.py @@ -6,7 +6,7 @@ "input_fn" (callable): Used for specifying input function Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on. @@ -55,3 +55,22 @@ ), }, ] + +# additional dimensions +LINEAR_SETTINGS += [ + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 4)), + "id_prefix": "one-additional", + }, + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 3, 4)), + "id_prefix": "two-additional", + }, + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 3, 5, 4)), + "id_prefix": "three-additional", + }, +] From 4968d60c73ea5f2d6683fb3ca6b26ed83beb7597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 23 Jun 2021 18:47:42 +0200 Subject: [PATCH 18/54] [core] Add derivatives for `AdaptiveAvgPoolNd` (#165) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `AdaptiveAvgPoolNd` is supported through `AvgPoolNd`, and therefor only in cases where such an equivalent unadaptive version exists. Slight refactoring of `AvgPoolNd` to reduce branching with `N`. --- * Create shape check for AdaptiveAvgPoolNDDerivatives and test it * AdaptiveAvgPool: make derivatives and test * refactor check_parameters, improve test * isort * AdaptiveAvgPool: allow None in output_size * simplify types * merge get_equivalent_parameters into get_parameters * remove _get_shape_target * simplify problem.py * improve test coverage * improve adaptive_avg_pool_nd.py * remove example * improve problem.py * improve problem.py * fix AdaptiveAvgPool3d test * improve test * delete TODO * change get_parameters to get_avg_pool_parameters * format * refactor avgpoolnd.py Co-authored-by: Tim Schäfer --- .../core/derivatives/adaptive_avg_pool_nd.py | 105 ++++++++++ backpack/core/derivatives/avgpoolnd.py | 77 ++++---- fully_documented.txt | 4 + test/adaptive_avg_pool/__init__.py | 4 + test/adaptive_avg_pool/problem.py | 179 ++++++++++++++++++ .../settings_adaptive_avg_pool_nd.py | 49 +++++ .../test_adaptive_avg_pool_nd.py | 28 +++ test/core/derivatives/__init__.py | 11 ++ test/core/derivatives/derivatives_test.py | 21 +- .../derivatives/pooling_adaptive_settings.py | 36 ++++ test/core/derivatives/settings.py | 2 + 11 files changed, 473 insertions(+), 43 deletions(-) create mode 100644 backpack/core/derivatives/adaptive_avg_pool_nd.py create mode 100644 test/adaptive_avg_pool/__init__.py create mode 100644 test/adaptive_avg_pool/problem.py create mode 100644 test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py create mode 100644 test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py create mode 100644 test/core/derivatives/pooling_adaptive_settings.py diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py new file mode 100644 index 000000000..167d92495 --- /dev/null +++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py @@ -0,0 +1,105 @@ +"""Implements the derivatives for AdaptiveAvgPool.""" +from typing import List, Tuple, Union +from warnings import warn + +from torch import Size +from torch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d + +from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives + + +class AdaptiveAvgPoolNDDerivatives(AvgPoolNDDerivatives): + """Implements the derivatives for AdaptiveAvgPool.""" + + def check_parameters( + self, module: Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d] + ) -> None: + """Checks if the parameters are supported. + + Specifically checks if input shape is multiple of output shape. + In this case, there are parameters for AvgPoolND that are equivalent. + + https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993 # noqa: B950 + + Args: + module: module to check + + Raises: + NotImplementedError: if the given shapes do not match + """ + if ("cuda" in str(module.input0.device)) and (self.N == 3): + warn( + "Be careful when computing gradients of AdaptiveAvgPool3d. " + "There is a bug using autograd.grad on cuda with AdaptiveAvgPool3d. " + "https://discuss.pytorch.org/t/bug-report-autograd-grad-adaptiveavgpool3d-cuda/124614 " # noqa: B950 + "BackPACK derivatives are correct." + ) + + shape_input: Size = module.input0.shape + shape_output: Size = module.output.shape + + # check length of input shape + if not len(shape_input) == (self.N + 2): + raise NotImplementedError( + f"input must be (batch_size, C, ...) with ... {self.N} dimensions" + ) + + # check if input shape is multiple of output shape + if any([shape_input[2 + n] % shape_output[2 + n] != 0 for n in range(self.N)]): + raise NotImplementedError( + f"No equivalent AvgPool (unadaptive): Input shape ({shape_input}) " + f"must be multiple of output shape ({shape_output})." + ) + + def get_avg_pool_parameters( + self, module: Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d] + ) -> Tuple[List[int], List[int], List[int]]: + """Return parameters for an equivalent AvgPool. + + Assumes that check_parameters has been run before. + Therefore, does not check parameters. + + Args: + module: module to compute on + + Returns: + stride, kernel_size, padding as lists of length self.N + """ + shape_input: Size = module.input0.shape + shape_target: Size = module.output.shape + + # calculate equivalent AvgPoolND parameters + stride: List[int] = [] + kernel_size: List[int] = [] + for n in range(self.N): + in_dim: int = shape_input[2 + n] + out_dim: int = shape_target[2 + n] + stride.append(in_dim // out_dim) + kernel_size.append(in_dim - (out_dim - 1) * stride[n]) + padding: List[int] = [0 for _ in range(self.N)] + + return stride, kernel_size, padding + + +class AdaptiveAvgPool1dDerivatives(AdaptiveAvgPoolNDDerivatives): + """Derivatives for AdaptiveAvgPool1d.""" + + def __init__(self): + """Initialization.""" + super().__init__(N=1) + + +class AdaptiveAvgPool2dDerivatives(AdaptiveAvgPoolNDDerivatives): + """Derivatives for AdaptiveAvgPool2d.""" + + def __init__(self): + """Initialization.""" + super().__init__(N=2) + + +class AdaptiveAvgPool3dDerivatives(AdaptiveAvgPoolNDDerivatives): + """Derivatives for AdaptiveAvgPool3d.""" + + def __init__(self): + """Initialization.""" + super().__init__(N=3) diff --git a/backpack/core/derivatives/avgpoolnd.py b/backpack/core/derivatives/avgpoolnd.py index e40d1f0af..618d8af81 100644 --- a/backpack/core/derivatives/avgpoolnd.py +++ b/backpack/core/derivatives/avgpoolnd.py @@ -3,6 +3,7 @@ Average pooling can be expressed as convolution over grouped channels with a constant kernel. """ +from typing import Any, Tuple import torch.nn from einops import rearrange @@ -13,6 +14,7 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Module, ) from backpack.core.derivatives.basederivatives import BaseDerivatives @@ -31,29 +33,34 @@ def __init__(self, N): self.conv = Conv3d self.convt = ConvTranspose3d + def check_parameters(self, module: Module) -> None: + assert module.count_include_pad, ( + "Might not work for exotic hyperparameters of AvgPool2d, " + + "like count_include_pad=False" + ) + + def get_avg_pool_parameters(self, module) -> Tuple[Any, Any, Any]: + """Return the parameters of the module. + + Args: + module: module + + Returns: + stride, kernel_size, padding + """ + return module.stride, module.kernel_size, module.padding + def hessian_is_zero(self): return True def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): """Use fact that average pooling can be implemented as conv.""" - if self.N == 1: - _, C, L_in = module.input0.size() - _, _, L_out = module.output.size() - in_features = C * L_in - out_features = C * L_out - shape_out = (1, L_out) - elif self.N == 2: - _, C, H_in, W_in = module.input0.size() - _, _, H_out, W_out = module.output.size() - in_features = C * H_in * W_in - out_features = C * H_out * W_out - shape_out = (1, H_out, W_out) - elif self.N == 3: - _, C, D_in, H_in, W_in = module.input0.size() - _, _, D_out, H_out, W_out = module.output.size() - in_features = C * D_in * H_in * W_in - out_features = C * D_out * H_out * W_out - shape_out = (1, D_out, H_out, W_out) + self.check_parameters(module) + + C = module.input0.shape[1] + shape_out = (1,) + tuple(module.output.shape[2:]) + in_features = module.input0.shape[1:].numel() + out_features = module.output.shape[1:].numel() mat = mat.reshape(out_features * C, *shape_out) jac_t_mat = self.__apply_jacobian_t_of(module, mat).reshape( @@ -66,14 +73,8 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): return jac_t_mat_t_jac.t() - def check_exotic_parameters(self, module): - assert module.count_include_pad, ( - "Might not work for exotic hyperparameters of AvgPool2d, " - + "like count_include_pad=False" - ) - def _jac_mat_prod(self, module, g_inp, g_out, mat): - self.check_exotic_parameters(module) + self.check_parameters(module) mat_as_pool = self.__make_single_channel(mat, module) jmp_as_pool = self.__apply_jacobian_of(module, mat_as_pool) @@ -89,12 +90,13 @@ class and channel dimension.""" return result.unsqueeze(C_axis) def __apply_jacobian_of(self, module, mat): + stride, kernel_size, padding = self.get_avg_pool_parameters(module) convnd = self.conv( in_channels=1, out_channels=1, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, + kernel_size=kernel_size, + stride=stride, + padding=padding, bias=False, ).to(module.input0.device) @@ -117,7 +119,7 @@ def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module): assert jmp_as_pool.shape == (V * N * C_out, 1, D_out, H_out, W_out) def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - self.check_exotic_parameters(module) + self.check_parameters(module) mat_as_pool = self.__make_single_channel(mat, module) jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool) @@ -126,14 +128,15 @@ def _jac_t_mat_prod(self, module, g_inp, g_out, mat): return self.reshape_like_input(jmp_as_pool, module) def __apply_jacobian_t_of(self, module, mat): + stride, kernel_size, padding = self.get_avg_pool_parameters(module) C_for_conv_t = 1 convnd_t = self.convt( in_channels=C_for_conv_t, out_channels=C_for_conv_t, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, + kernel_size=kernel_size, + stride=stride, + padding=padding, bias=False, ).to(module.input0.device) @@ -142,15 +145,7 @@ def __apply_jacobian_t_of(self, module, mat): convnd_t.weight.data = avg_kernel V_N_C_in = mat.size(0) - if self.N == 1: - _, _, L_in = module.input0.size() - output_size = (V_N_C_in, C_for_conv_t, L_in) - elif self.N == 2: - _, _, H_in, W_in = module.input0.size() - output_size = (V_N_C_in, C_for_conv_t, H_in, W_in) - elif self.N == 3: - _, _, D_in, H_in, W_in = module.input0.size() - output_size = (V_N_C_in, C_for_conv_t, D_in, H_in, W_in) + output_size = (V_N_C_in, C_for_conv_t) + tuple(module.input0.shape[2:]) return convnd_t(mat, output_size=output_size) diff --git a/fully_documented.txt b/fully_documented.txt index 8b38fd07e..41bb83b5b 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -6,6 +6,7 @@ backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py backpack/core/derivatives/permute.py backpack/core/derivatives/lstm.py +backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -44,6 +45,7 @@ test/core/derivatives/utils.py test/core/derivatives/implementation/ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py +test/core/derivatives/pooling_adaptive_settings.py test/extensions/problem.py test/extensions/test_backprop_extension.py @@ -54,3 +56,5 @@ test/extensions/firstorder/batch_grad/batchgrad_settings.py test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/diag_ggn/ test/extensions/secondorder/hbp + +test/adaptive_avg_pool/ diff --git a/test/adaptive_avg_pool/__init__.py b/test/adaptive_avg_pool/__init__.py new file mode 100644 index 000000000..d9ed32847 --- /dev/null +++ b/test/adaptive_avg_pool/__init__.py @@ -0,0 +1,4 @@ +"""Module tests AdaptiveAvgPoolNDDerivatives. + +Especially the shape checker for equivalence with AvgPoolND. +""" diff --git a/test/adaptive_avg_pool/problem.py b/test/adaptive_avg_pool/problem.py new file mode 100644 index 000000000..e96713ed4 --- /dev/null +++ b/test/adaptive_avg_pool/problem.py @@ -0,0 +1,179 @@ +"""Test problems for the AdaptiveAvgPool shape checker.""" +from __future__ import annotations + +import copy +from test.automated_test import check_sizes_and_values +from test.core.derivatives.utils import get_available_devices +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch import Tensor, randn +from torch.nn import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + Module, +) + +from backpack import extend +from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives + + +def make_test_problems(settings: List[Dict[str, Any]]) -> List[AdaptiveAvgPoolProblem]: + """Creates the test problem from settings. + + Args: + settings: list of dictionaries with settings + + Returns: + a list of the test problems + """ + problem_dicts: List[Dict[str, Any]] = [] + + for setting in settings: + setting = add_missing_defaults(setting) + devices = setting["device"] + + for dev in devices: + problem = copy.deepcopy(setting) + problem["device"] = dev + problem_dicts.append(problem) + + return [AdaptiveAvgPoolProblem(**p) for p in problem_dicts] + + +def add_missing_defaults(setting: Dict[str, Any]) -> Dict[str, Any]: + """Add missing entries in settings such that the new format works. + + Args: + setting: dictionary with required settings and some optional settings + + Returns: + complete settings including the default values for missing optional settings + + Raises: + ValueError: if the settings do not work + """ + required = ["N", "shape_input", "shape_target", "works"] + optional = { + "id_prefix": "", + "device": get_available_devices(), + "seed": 0, + } + + for req in required: + if req not in setting.keys(): + raise ValueError(f"Missing configuration entry for {req}") + + for opt, default in optional.items(): + if opt not in setting.keys(): + setting[opt] = default + + for s in setting.keys(): + if s not in required and s not in optional.keys(): + raise ValueError(f"Unknown config: {s}") + + return setting + + +class AdaptiveAvgPoolProblem: + """Test problem for testing AdaptiveAvgPoolNDDerivatives.check_parameters().""" + + def __init__( + self, + N: int, + shape_input: Any, + shape_target: Tuple[int], + works: bool, + device, + seed: int, + id_prefix: str, + ): + """Initialization. + + Args: + N: number of dimensions + shape_input: input shape + shape_target: target shape + works: whether the test should run without errors + device: device + seed: seed for torch + id_prefix: prefix for problem id + + Raises: + NotImplementedError: if N is not in [1, 2, 3] + """ + if N not in [1, 2, 3]: + raise NotImplementedError(f"N={N} not implemented in test suite.") + self.N = N + self.shape_input = shape_input + self.shape_target = shape_target + self.works = works + self.device = device + self.seed = seed + self.id_prefix = id_prefix + + def make_id(self) -> str: + """Create an id from problem parameters. + + Returns: + problem id + """ + prefix = (self.id_prefix + "-") if self.id_prefix != "" else "" + return ( + prefix + f"dev={self.device}-N={self.N}-in={self.shape_input}-" + f"out={self.shape_target}-works={self.works}" + ) + + def set_up(self) -> None: + """Set up problem and do one forward pass.""" + torch.manual_seed(self.seed) + self.module = self._make_module() + self.input = randn(self.shape_input) + self.output = self.module(self.input) + + def tear_down(self): + """Delete created torch variables.""" + del self.module + del self.input + del self.output + + def _make_module( + self, + ) -> Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d]: + map_class = {1: AdaptiveAvgPool1d, 2: AdaptiveAvgPool2d, 3: AdaptiveAvgPool3d} + module = map_class[self.N](output_size=self.shape_target) + return extend(module.to(device=self.device)) + + def check_parameters(self) -> None: + """Key method for test. + + Run the AdaptiveAvgPoolNDDerivatives.check_parameters() method. + """ + self._get_derivatives().check_parameters(module=self.module) + + def _get_derivatives(self) -> AdaptiveAvgPoolNDDerivatives: + return AdaptiveAvgPoolNDDerivatives(N=self.N) + + def check_equivalence(self) -> None: + """Check if the given parameters lead to the same output. + + Checks the sizes and values. + """ + stride, kernel_size, _ = self._get_derivatives().get_avg_pool_parameters( + self.module + ) + module_equivalent: Module = self._make_module_equivalent(stride, kernel_size) + output_equivalent: Tensor = module_equivalent(self.input) + + check_sizes_and_values(self.output, output_equivalent) + + def _make_module_equivalent( + self, stride: List[int], kernel_size: List[int] + ) -> Union[AvgPool1d, AvgPool2d, AvgPool3d]: + map_class = {1: AvgPool1d, 2: AvgPool2d, 3: AvgPool3d} + module = map_class[self.N](kernel_size=kernel_size, stride=stride) + return module.to(self.device) diff --git a/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py new file mode 100644 index 000000000..58fdaaebb --- /dev/null +++ b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py @@ -0,0 +1,49 @@ +"""Settings to run test_adaptive_avg_pool_nd.""" +from typing import Any, Dict, List + +from torch import Size + +SETTINGS: List[Dict[str, Any]] = [ + { + "N": 1, + "shape_target": 2, + "shape_input": (1, 5, 8), + "works": True, + }, + { + "N": 1, + "shape_target": 2, + "shape_input": (1, 8, 7), + "works": False, + }, + { + "N": 2, + "shape_target": Size((4, 3)), + "shape_input": (1, 64, 8, 9), + "works": True, + }, + { + "N": 2, + "shape_target": 2, + "shape_input": (1, 64, 8, 10), + "works": True, + }, + { + "N": 2, + "shape_target": 2, + "shape_input": (1, 64, 8, 9), + "works": False, + }, + { + "N": 2, + "shape_target": (5, 2), + "shape_input": (1, 64, 64, 10), + "works": False, + }, + { + "N": 3, + "shape_target": (None, 2, None), + "shape_input": (1, 64, 7, 10, 5), + "works": True, + }, +] diff --git a/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py new file mode 100644 index 000000000..eebf1931b --- /dev/null +++ b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py @@ -0,0 +1,28 @@ +"""Test the shape checker of AdaptiveAvgPoolNDDerivatives.""" +from test.adaptive_avg_pool.problem import AdaptiveAvgPoolProblem, make_test_problems +from test.adaptive_avg_pool.settings_adaptive_avg_pool_nd import SETTINGS +from typing import List + +import pytest + +PROBLEMS: List[AdaptiveAvgPoolProblem] = make_test_problems(SETTINGS) +IDS: List[str] = [problem.make_id() for problem in PROBLEMS] + + +@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) +def test_adaptive_avg_pool_check_parameters(problem: AdaptiveAvgPoolProblem): + """Test AdaptiveAvgPoolNDDerivatives.check_parameters(). + + Additionally check if returned parameters are indeed equivalent. + + Args: + problem: test problem + """ + problem.set_up() + if problem.works: + problem.check_parameters() + problem.check_equivalence() + else: + with pytest.raises(NotImplementedError): + problem.check_parameters() + problem.tear_down() diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index b89d60d3d..c63ec0b7d 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -4,6 +4,9 @@ LSTM, RNN, SELU, + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, AvgPool1d, AvgPool2d, AvgPool3d, @@ -28,6 +31,11 @@ ZeroPad2d, ) +from backpack.core.derivatives.adaptive_avg_pool_nd import ( + AdaptiveAvgPool1dDerivatives, + AdaptiveAvgPool2dDerivatives, + AdaptiveAvgPool3dDerivatives, +) from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives @@ -85,4 +93,7 @@ RNN: RNNDerivatives, Permute: PermuteDerivatives, LSTM: LSTMDerivatives, + AdaptiveAvgPool1d: AdaptiveAvgPool1dDerivatives, + AdaptiveAvgPool2d: AdaptiveAvgPool2dDerivatives, + AdaptiveAvgPool3d: AdaptiveAvgPool3dDerivatives, } diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index f77871b1c..032d2dd5e 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -72,16 +72,24 @@ def test_jac_mat_prod(problem, V=3): NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, ) -def test_jac_t_mat_prod(problem, V=3): +def test_jac_t_mat_prod(problem, request, V=3): """Test the transposed Jacobian-matrix product. Args: problem (DerivativesProblem): Problem for derivative test. + request: Pytest request, used for getting id. V (int): Number of vectorized transposed Jacobian-vector products. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) + if all( + [string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"]] + ): + with pytest.warns(UserWarning): + BackpackDerivatives(problem).jac_t_mat_prod(mat) + problem.tear_down() + return backpack_res = BackpackDerivatives(problem).jac_t_mat_prod(mat) autograd_res = AutogradDerivatives(problem).jac_t_mat_prod(mat) @@ -407,7 +415,7 @@ def test_sum_hessian_should_fail(problem): @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_ea_jac_t_mat_jac_prod(problem): +def test_ea_jac_t_mat_jac_prod(problem, request): """Test KFRA backpropagation. H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ @@ -420,11 +428,20 @@ def test_ea_jac_t_mat_jac_prod(problem): Args: problem (DerivativesProblem): Problem for derivative test. + request: PyTest request, used to get test id. """ problem.set_up() out_features = torch.prod(torch.tensor(problem.output_shape[1:])) mat = torch.rand(out_features, out_features).to(problem.device) + if all( + [string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"]] + ): + with pytest.warns(UserWarning): + BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) + problem.tear_down() + return + backpack_res = BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) autograd_res = AutogradDerivatives(problem).ea_jac_t_mat_jac_prod(mat) diff --git a/test/core/derivatives/pooling_adaptive_settings.py b/test/core/derivatives/pooling_adaptive_settings.py new file mode 100644 index 000000000..0222f4c4d --- /dev/null +++ b/test/core/derivatives/pooling_adaptive_settings.py @@ -0,0 +1,36 @@ +"""Test configurations for `backpack.core.derivatives` for adaptive pooling layers. + +Required entries: + "module_fn" (callable): Contains a model constructed from `torch.nn` layers + "input_fn" (callable): Used for specifying input function + +Optional entries: + "target_fn" (callable): Fetches the groundtruth/target classes + of regression/classification task + "loss_function_fn" (callable): Loss function used in the model + "device" [list(torch.device)]: List of devices to run the test on. + "id_prefix" (str): Prefix to be included in the test name. + "seed" (int): seed for the random number for torch.rand +""" + +import torch + +POOLING_ADAPTIVE_SETTINGS = [] + +############################################################################### +# test settings # +############################################################################### +POOLING_ADAPTIVE_SETTINGS += [ + { + "module_fn": lambda: torch.nn.AdaptiveAvgPool1d(output_size=(3,)), + "input_fn": lambda: torch.rand(size=(1, 4, 9)), + }, + { + "module_fn": lambda: torch.nn.AdaptiveAvgPool2d(output_size=(3, 5)), + "input_fn": lambda: torch.rand(size=(2, 3, 9, 20)), + }, + { + "module_fn": lambda: torch.nn.AdaptiveAvgPool3d(output_size=(2, 2, 2)), + "input_fn": lambda: torch.rand(size=(1, 3, 4, 8, 8)), + }, +] diff --git a/test/core/derivatives/settings.py b/test/core/derivatives/settings.py index 0e1c33024..7379f5af7 100644 --- a/test/core/derivatives/settings.py +++ b/test/core/derivatives/settings.py @@ -15,6 +15,7 @@ from test.core.derivatives.linear_settings import LINEAR_SETTINGS from test.core.derivatives.loss_settings import LOSS_SETTINGS from test.core.derivatives.padding_settings import PADDING_SETTINGS +from test.core.derivatives.pooling_adaptive_settings import POOLING_ADAPTIVE_SETTINGS from test.core.derivatives.pooling_settings import POOLING_SETTINGS SETTINGS = ( @@ -24,4 +25,5 @@ + LOSS_SETTINGS + PADDING_SETTINGS + POOLING_SETTINGS + + POOLING_ADAPTIVE_SETTINGS ) From ed1fc13f2f966d3580090f023dbcc0cf2f9fc7c8 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 24 Jun 2021 11:38:45 +0200 Subject: [PATCH 19/54] [core] Add test of `hessian_is_zero` property (#183) Adds a test that checks `BaseDerivatives.hessian_is_zero` and improves documentation. --- * [TEST] Make test for `hessian_is_zero` work * [REF] Use fixture for `test_hessian_is_zero` * [FIX] flake8 * [TEST] Restrict Hessian tests to small inputs * [TEST] Enable `hessian_is_diagonal` test * [DEL] All Hessian property tests except `test_hessian_is_zero` * [FIX] flake8 Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 31 +++++- test/core/derivatives/derivatives_test.py | 94 +++++++++++++------ .../derivatives/implementation/autograd.py | 22 +++-- .../derivatives/implementation/backpack.py | 8 +- test/core/derivatives/implementation/base.py | 8 ++ 5 files changed, 113 insertions(+), 50 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 1911d9adf..2fd604986 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -1,5 +1,9 @@ """Base classes for more flexible Jacobians and second-order information.""" import warnings +from typing import Tuple + +from torch import Tensor +from torch.nn import Module from backpack.core.derivatives import shape_check @@ -124,17 +128,36 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): """ raise NotImplementedError - def hessian_is_zero(self): + def hessian_is_zero(self) -> bool: + """Whether ``∂²output[i] / ∂input[j] ∂input[k] = 0 ∀ i,j,k``.""" raise NotImplementedError - def hessian_is_diagonal(self): - """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`.""" + def hessian_is_diagonal(self) -> bool: + """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`. + + The Hessian diagonal is only defined for layers that preserve the size + of their input. + + Must be implemented by descendants that don't implement ``hessian_is_zero``. + """ raise NotImplementedError - def hessian_diagonal(self): + # FIXME Currently returns `∂²output[i] / ∂input[i]² * g_out[0][i]`, + # which s the residual matrix diagonal, rather than the Hessian diagonal + def hessian_diagonal( + self, module: Module, g_in: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: """Return `∂²output[i] / ∂input[i]²`. Only required if `hessian_is_diagonal` returns `True`. + + Args: + module: Module whose output-input Hessian diagonal is computed. + g_in: Gradients w.r.t. the module input. + g_out: Gradients w.r.t. the module output. + + Returns: + Hessian diagonal. Has same shape as module input. """ raise NotImplementedError diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 1192490cd..0ed46a3a4 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -13,9 +13,11 @@ from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.settings import SETTINGS +from warnings import warn import pytest import torch +from pytest import fixture, skip from backpack.core.derivatives.convnd import weight_jac_t_save_memory @@ -280,47 +282,79 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem) -> None: problem.tear_down() -@pytest.mark.skip("[WAITING] Autograd issue with Hessian-vector products") -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_zero(problem): - """Check if the input-output Hessian is (non-)zero.""" - problem.set_up() +@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) +def instantiated_problem(request) -> DerivativesTestProblem: + """Set seed, create tested layer and data. Finally clean up. - backpack_res = BackpackDerivatives(problem).hessian_is_zero() - autograd_res = AutogradDerivatives(problem).hessian_is_zero() + Args: + request (SubRequest): Request for the fixture from a test/fixture function. - assert backpack_res == autograd_res - problem.tear_down() + Yields: + Test case with deterministically constructed attributes. + """ + case = request.param + case.set_up() + yield case + case.tear_down() -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_diagonal(problem): - problem.set_up() +@fixture +def small_input_problem( + instantiated_problem: DerivativesTestProblem, max_input_numel: int = 100 +) -> DerivativesTestProblem: + """Skip cases with large inputs. - # TODO - raise NotImplementedError + Args: + max_input_numel: Maximum input size. Default: ``100``. - problem.tear_down() + Yields: + Instantiated test case with small input. + """ + if instantiated_problem.input.numel() > max_input_numel: + skip( + "Input is too large:" + + f" {instantiated_problem.input.numel()} > {max_input_numel}" + ) + else: + yield instantiated_problem -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_diagonal(problem): - problem.set_up() +@fixture +def no_loss_problem( + small_input_problem: DerivativesTestProblem, +) -> DerivativesTestProblem: + """Skip cases that are loss functions or have to large inputs. - # TODO - raise NotImplementedError + Args: + small_input_problem: Test case with small input. - problem.tear_down() + Yields: + Instantiated test case that is not a loss layer. + """ + if small_input_problem.is_loss(): + skip("Only required for non-loss layers.") + else: + yield small_input_problem -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_psd(problem): - problem.set_up() +def test_hessian_is_zero(no_loss_problem: DerivativesTestProblem): + """Check if the input-output Hessian is (non-)zero. - # TODO - raise NotImplementedError + Note: + `hessian_is_zero` is a global statement that assumes arbitrary inputs. + It can thus happen that the Hessian diagonal is zero for the current + input, but not in general. - problem.tear_down() + Args: + no_loss_problem: Test case whose module is not a loss. + """ + backpack_res = BackpackDerivatives(no_loss_problem).hessian_is_zero() + autograd_res = AutogradDerivatives(no_loss_problem).hessian_is_zero() + + if autograd_res and not backpack_res: + warn( + "Autograd Hessian diagonal is zero for this input " + " while BackPACK implementation implies inputs with non-zero Hessian." + ) + else: + assert backpack_res == autograd_res diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 1cf7fa96a..c40a7de34 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -1,6 +1,7 @@ from test.core.derivatives.implementation.base import DerivativesImplementation import torch +from torch import zeros_like from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -160,10 +161,18 @@ def hessian(self, loss, x): def elementwise_hessian(self, tensor, x): """Yield the Hessian of each element in `tensor` w.r.t `x`. - Hessians are returned in the order of elements in the flattened tensor. + If ``tensor`` is linear in ``x``, autograd raises a ``RuntimeError``. + If ``tensor`` does not depend on ``x``, autograd raises an ``AttributeError``. + In both cases, a Hessian of zeros is created manually and returned. + + Yields: + Hessians in the order of elements in the flattened tensor. """ for t in tensor.flatten(): - yield self.hessian(t, x) + try: + yield self.hessian(t, x) + except (RuntimeError, AttributeError): + yield torch.zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype) def tensor_hessian(self, tensor, x): """Return the Hessian of a tensor `tensor` w.r.t. a tensor `x`. @@ -185,18 +194,13 @@ def tensor_hessian(self, tensor, x): return torch.cat(list(self.elementwise_hessian(tensor, x))).reshape(shape) - def hessian_is_zero(self): - """Return whether the input-output Hessian is zero. - - Returns: - bool: `True`, if Hessian is zero, else `False`. - """ + def hessian_is_zero(self) -> bool: input, output, _ = self.problem.forward_pass(input_requires_grad=True) zero = None for hessian in self.elementwise_hessian(output, input): if zero is None: - zero = torch.zeros_like(hessian) + zero = zeros_like(hessian) if not torch.allclose(hessian, zero): return False diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index c36d94d9f..0b82431d2 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -60,7 +60,6 @@ def sum_hessian(self): return self.problem.derivative.sum_hessian(self.problem.module, None, None) def input_hessian_via_sqrt_hessian(self, mc_samples=None): - # MC_SAMPLES = 100000 self.store_forward_io() if mc_samples is not None: @@ -78,12 +77,7 @@ def input_hessian_via_sqrt_hessian(self, mc_samples=None): individual_hessians, self.problem.module.input0 ) - def hessian_is_zero(self): - """Return whether the input-output Hessian is zero. - - Returns: - bool: `True`, if Hessian is zero, else `False`. - """ + def hessian_is_zero(self) -> bool: return self.problem.derivative.hessian_is_zero() def _sample_hessians_from_sqrt(self, sqrt): diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 9edaa8194..a33eac8ed 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -21,3 +21,11 @@ def weight_jac_mat_prod(self, mat): def bias_jac_mat_prod(self, mat): raise NotImplementedError + + def hessian_is_zero(self) -> bool: + """Return whether the input-output Hessian is zero. + + Returns: + `True`, if Hessian is zero, else `False`. + """ + raise NotImplementedError From b6e661b235addcf8a84bddaafd5a922bfc5416b2 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 24 Jun 2021 15:22:46 +0200 Subject: [PATCH 20/54] [extensions] Support additional dims in Linear (#186) Adds support for first-order extensions, diagonal second-order extensions, and an error message for Kronecker curvatures. Improve documentation and clean up (white space, replace `torch.nn.` by direct imports). Had to slightly adapt the Monte-Carlo test tolerances to make the new settings pass. --- * [DOC] Fully document first-order settings - White space cleanup - Docstring improvements - Make imports shorter * [ADD] Support additional dims in Linear first-order extensions * [TEST] Slightly increase atol for Variance, improve documentation * [TEST] Remove redundant local settings and fix file name * [FMT] Shorten imports, white space clean up * [ADD] Support additional dims in Linear diagonal extensions * [ADD] Only allow 2d inputs to Linear Kronecker curvatures * [DOC] Improve docstring, create TODO for einsum performance * [REF] Remove unnecessary parentheses --- .../firstorder/batch_l2_grad/linear.py | 22 ++- .../firstorder/sum_grad_squared/linear.py | 11 +- backpack/extensions/secondorder/hbp/linear.py | 19 +++ backpack/utils/__init__.py | 1 + backpack/utils/linear.py | 65 +++++++-- fully_documented.txt | 7 + .../firstorder/firstorder_settings.py | 127 ++++++++++++------ .../firstorder/variance/__init__.py | 1 + .../firstorder/variance/test_variance.py | 20 +-- .../firstorder/variance/variance_settings.py | 6 +- .../secondorder/diag_ggn/diagggn_settings.py | 18 +++ .../secondorder/diag_ggn/diaggnn_settings.py | 27 ---- .../diag_ggn/test_batch_diag_ggn.py | 4 +- .../secondorder/diag_ggn/test_diag_ggn.py | 2 +- .../secondorder/hbp/kfac_settings.py | 9 +- .../secondorder/hbp/kflr_settings.py | 9 +- .../secondorder/hbp/kfra_settings.py | 9 +- .../secondorder/secondorder_settings.py | 127 ++++++++++++------ 18 files changed, 342 insertions(+), 142 deletions(-) create mode 100644 test/extensions/secondorder/diag_ggn/diagggn_settings.py delete mode 100644 test/extensions/secondorder/diag_ggn/diaggnn_settings.py diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 6a9b3b73c..40d90e760 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -8,8 +8,26 @@ def __init__(self): super().__init__(params=["bias", "weight"]) def bias(self, ext, module, g_inp, g_out, backproped): + add_axes = list(range(1, g_out[0].dim() - 1)) + + if add_axes: + grad_batch = g_out[0].sum(add_axes) + else: + grad_batch = g_out[0] + C_axis = 1 - return (g_out[0] ** 2).sum(C_axis) + + return (grad_batch ** 2).sum(C_axis) def weight(self, ext, module, g_inp, g_out, backproped): - return einsum("ni,nj->n", (g_out[0] ** 2, module.input0 ** 2)) + add_axes = list(range(1, g_out[0].dim() - 1)) + + if add_axes: + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) + X = module.input0.flatten(start_dim=1, end_dim=-2) + return einsum("nmi,nmj,nki,nkj->n", dE_dY, X, dE_dY, X) + + else: + return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py index 4cf75db1d..2aa2a549a 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/linear.py +++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py @@ -18,4 +18,13 @@ def weight(self, ext, module, g_inp, g_out, backproped): For details, see page 12 (paragraph about "second moment") of the paper (https://arxiv.org/pdf/1912.10985.pdf). """ - return einsum("ni,nj->ij", (g_out[0] ** 2, module.input0 ** 2)) + add_axes = list(range(1, g_out[0].dim() - 1)) + + if add_axes: + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) + X = module.input0.flatten(start_dim=1, end_dim=-2) + return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X) + + return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/secondorder/hbp/linear.py b/backpack/extensions/secondorder/hbp/linear.py index 779459e14..c89791a3c 100644 --- a/backpack/extensions/secondorder/hbp/linear.py +++ b/backpack/extensions/secondorder/hbp/linear.py @@ -1,4 +1,5 @@ from torch import einsum +from torch.nn import Linear from backpack.core.derivatives.linear import LinearDerivatives from backpack.extensions.secondorder.hbp.hbp_options import ( @@ -13,6 +14,7 @@ def __init__(self): super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): + self.check_parameters(ext, module) bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): @@ -44,6 +46,7 @@ def _factor_from_sqrt(self, backproped): return [einsum("vni,vnj->ij", (backproped, backproped))] def bias(self, ext, module, g_inp, g_out, backproped): + self.check_parameters(ext, module) bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): @@ -61,3 +64,19 @@ def __mean_input_outer(self, module): N = module.input0.size(0) flat_input = module.input0.reshape(N, -1) return einsum("ni,nj->ij", (flat_input, flat_input)) / N + + def check_parameters(self, ext, module: Linear) -> None: + """Raise an exception if module parameters are not supported. + + Args: + ext (KFAC or KFRA or KFLR): Extension calling out to the module. + module: Linear layer. + + Raises: + NotImplementedError: If the setting is not implemented. + """ + if module.input0.dim() != 2: + raise NotImplementedError( + f"Only 2d inputs are supported by {ext.__class__.__name__} " + + f"(got {module.input0.dim()})." + ) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index e69de29bb..39f7fa2b1 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -0,0 +1 @@ +"""Contains utility functions.""" diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py index b3a2453b3..e2b825323 100644 --- a/backpack/utils/linear.py +++ b/backpack/utils/linear.py @@ -1,17 +1,60 @@ -from torch import einsum +"""Contains utility functions to extract the GGN diagonal for linear layers.""" +from torch import Tensor, einsum +from torch.nn import Linear -def extract_weight_diagonal(module, backproped, sum_batch=True): - if sum_batch: - equation = "vno,ni->oi" +def extract_weight_diagonal( + module: Linear, S: Tensor, sum_batch: bool = True +) -> Tensor: + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian. + + Args: + module: Linear layer for which the diagonal is extracted w.r.t. the weight. + S: Backpropagated symmetric factorization of the loss Hessian. Has shape + ``(V, *module.output.shape)``. + sum_batch: Sum out the weight diagonal's batch dimension. Default: ``True``. + + Returns: + Per-sample weight diagonal if ``sum_batch=False`` (shape + ``(N, module.weight.shape)`` with batch size ``N``) or summed weight diagonal + if ``sum_batch=True`` (shape ``module.weight.shape``). + """ + add_axes = list(range(1, module.input0.dim() - 1)) + + if add_axes: + S_flat = S.flatten(start_dim=2, end_dim=-2) + X_flat = module.input0.flatten(start_dim=1, end_dim=-2) + equation = f"vnmo,nmi,vnko,nki->{'' if sum_batch else 'n'}oi" + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + return einsum(equation, S_flat, X_flat, S_flat, X_flat) + else: - equation = "vno,ni->noi" - return einsum(equation, (backproped ** 2, module.input0 ** 2)) + equation = f"vno,ni->{'' if sum_batch else 'n'}oi" + return einsum(equation, S ** 2, module.input0 ** 2) + +def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) -> Tensor: + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian. -def extract_bias_diagonal(module, backproped, sum_batch=True): - if sum_batch: - equation = "vno->o" + Args: + module: Linear layer for which the diagonal is extracted w.r.t. the bias. + S: Backpropagated symmetric factorization of the loss Hessian. Has shape + ``(V, *module.output.shape)``. + sum_batch: Sum out the bias diagonal's batch dimension. Default: ``True``. + + Returns: + Per-sample bias diagonal if ``sum_batch=False`` (shape + ``(N, module.bias.shape)`` with batch size ``N``) or summed bias diagonal + if ``sum_batch=True`` (shape ``module.bias.shape``). + """ + add_axes = list(range(2, module.input0.dim())) + + if add_axes: + JS = S.sum(add_axes) else: - equation = "vno->no" - return einsum(equation, backproped ** 2) + JS = S + + equation = f"vno->{'' if sum_batch else 'n'}o" + + return einsum(equation, JS ** 2) diff --git a/fully_documented.txt b/fully_documented.txt index ca5f6ec86..49626aaa1 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,6 +1,10 @@ test/extensions/test_backprop_extension.py +test/extensions/firstorder/firstorder_settings.py +test/extensions/firstorder/variance + test/extensions/secondorder/secondorder_settings.py +test/extensions/secondorder/diag_ggn/diagggn_settings.py test/extensions/secondorder/hbp @@ -20,3 +24,6 @@ backpack/extensions/__init__.py backpack/extensions/secondorder/sqrt_ggn backpack/core/derivatives/linear.py + +backpack/utils/linear.py +backpack/utils/__init__.py diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 0385e3c4e..dba39f392 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -1,29 +1,29 @@ -"""Test configurations for `backpack.core.extensions.firstorder` -that is shared among the following firstorder methods: -- batch_grad -- batch_l2_grad -- sum_grad_sqaured -- variance +"""Shared test cases for BackPACK's first-order extensions. +Shared by the tests of: +- ``BatchGrad`` +- ``BatchL2Grad`` +- ``SumGradSquared`` +- ``Variance`` Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model Optional entries: - "device" [list(torch.device)]: List of devices to run the test on. + "device" [list(device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. - "seed" (int): seed for the random number for torch.rand + "seed" (int): seed set before initializing a case. """ from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import make_simple_cnn_setting -import torch +from torch import device, rand from torch.nn import ( Conv1d, Conv2d, @@ -31,6 +31,13 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + CrossEntropyLoss, + Flatten, + Linear, + MSELoss, + ReLU, + Sequential, + Sigmoid, ) FIRSTORDER_SETTINGS = [] @@ -40,11 +47,11 @@ ############################################################################### example = { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), - "device": [torch.device("cpu")], + "device": [device("cpu")], "seed": 0, "id_prefix": "example", } @@ -57,47 +64,91 @@ FIRSTORDER_SETTINGS += [ # classification { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, - # Regression + # regression { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((3, 5)), }, ] +# linear with additional dimension +FIRSTORDER_SETTINGS += [ + # regression + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 4, 2)), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Sigmoid(), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 4, 2, 2)), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="sum"), + "target_fn": lambda: regression_targets((3, 4, 2, 3, 2)), + "id_prefix": "three-additional", + }, + # classification + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 8), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 16), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 48), + "id_prefix": "three-additional", + }, +] + ############################################################################### # test setting: Convolutional Layers # """ -Syntax with default parameters: - - `torch.nn.ConvNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, dilation=1, - groups=1, bias=True, padding_mode='zeros)` +Syntax with default parameters: + - `torch.nn.ConvNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode='zeros)` - - `torch.nn.ConvTransposeNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, output_padding=0, + - `torch.nn.ConvTransposeNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros)` -Note: There are 5 tests added to each `torch.nn.layers`. +Note: There are 5 tests added to each `torch.nn.layers`. For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d` -only 3 tests are added because they are very memory intensive. +only 3 tests are added because they are very memory intensive. """ ############################################################################### diff --git a/test/extensions/firstorder/variance/__init__.py b/test/extensions/firstorder/variance/__init__.py index e69de29bb..7ca2b624c 100644 --- a/test/extensions/firstorder/variance/__init__.py +++ b/test/extensions/firstorder/variance/__init__.py @@ -0,0 +1 @@ +"""Contains tests for BackPACK's ``Variance`` extension.""" diff --git a/test/extensions/firstorder/variance/test_variance.py b/test/extensions/firstorder/variance/test_variance.py index 50cf21675..8680c2086 100644 --- a/test/extensions/firstorder/variance/test_variance.py +++ b/test/extensions/firstorder/variance/test_variance.py @@ -1,16 +1,9 @@ -"""Test class for module variance -from `backpack.core.extensions.firstorder` - -Test variances for the following layers: -- variance of linear layers -- variance of convolutional layers - -""" +"""Test BackPACK's ``Variance`` extension.""" from test.automated_test import check_sizes_and_values from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions -from test.extensions.problem import make_test_problems +from test.extensions.problem import ExtensionsTestProblem, make_test_problems import pytest @@ -19,16 +12,17 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_variance(problem): - """Test variance of individual gradients +def test_variance(problem: ExtensionsTestProblem) -> None: + """Test variance of individual gradients. Args: - problem (ExtensionsTestProblem): Problem for extension test. + problem: Test case. """ problem.set_up() backpack_res = BackpackExtensions(problem).variance() autograd_res = AutogradExtensions(problem).variance() - check_sizes_and_values(autograd_res, backpack_res) + rtol = 5e-5 + check_sizes_and_values(autograd_res, backpack_res, rtol=rtol) problem.tear_down() diff --git a/test/extensions/firstorder/variance/variance_settings.py b/test/extensions/firstorder/variance/variance_settings.py index 61c39d0d6..c8a8de2da 100644 --- a/test/extensions/firstorder/variance/variance_settings.py +++ b/test/extensions/firstorder/variance/variance_settings.py @@ -1,7 +1,7 @@ -"""Test configurations to test variance +"""Test cases for ``Variance`` extension. -The tests are taken from `test.extensions.firstorder.firstorder_settings`, -but additional custom tests can be defined here by appending it to the list. +Uses shared test cases from `test.extensions.firstorder.firstorder_settings`, +and the local cases defined in this file. """ from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/diagggn_settings.py b/test/extensions/secondorder/diag_ggn/diagggn_settings.py new file mode 100644 index 000000000..d0d806671 --- /dev/null +++ b/test/extensions/secondorder/diag_ggn/diagggn_settings.py @@ -0,0 +1,18 @@ +"""Test cases for BackPACK extensions for the GGN diagonal. + +Includes +- ``DiagGGNExact`` +- ``DiagGGNMC`` +- ``BatchDiagGGNExact`` +- ``BatchDiagGGNMC`` + +Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. +Additional local cases can be defined here through ``LOCAL_SETTINGS``. +""" + +from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS + +SHARED_SETTINGS = SECONDORDER_SETTINGS +LOCAL_SETTINGS = [] + +DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py b/test/extensions/secondorder/diag_ggn/diaggnn_settings.py deleted file mode 100644 index e0d5b7bbc..000000000 --- a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Test configurations to test diag_ggn - -The tests are taken from `test.extensions.secondorder.secondorder_settings`, -but additional custom tests can be defined here by appending it to the list. -""" - -from test.extensions.automated_settings import make_simple_act_setting -from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS - -from torch.nn import ELU, SELU - -DiagGGN_SETTINGS = [] - -SHARED_SETTINGS = SECONDORDER_SETTINGS - -LOCAL_SETTINGS = [] - -############################################################################### -# test setting: Activation Layers # -############################################################################### -activations = [ELU, SELU] - -for act in activations: - for bias in [True, False]: - LOCAL_SETTINGS.append(make_simple_act_setting(act, bias=bias)) - -DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index f0587b57b..ca0d9c4df 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -2,7 +2,7 @@ from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems -from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS +from test.extensions.secondorder.diag_ggn.diagggn_settings import DiagGGN_SETTINGS import pytest @@ -42,7 +42,7 @@ def test_diag_ggn_mc_batch_light(problem): problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() - mc_samples = 5000 + mc_samples = 6000 backpack_res_mc_avg = BackpackExtensions(problem).diag_ggn_mc_batch(mc_samples) check_sizes_and_values( diff --git a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py index 858e479b5..2d9d7fcf5 100644 --- a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py @@ -2,7 +2,7 @@ from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems -from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS +from test.extensions.secondorder.diag_ggn.diagggn_settings import DiagGGN_SETTINGS import pytest diff --git a/test/extensions/secondorder/hbp/kfac_settings.py b/test/extensions/secondorder/hbp/kfac_settings.py index 895b99cc5..a6d0ece1d 100644 --- a/test/extensions/secondorder/hbp/kfac_settings.py +++ b/test/extensions/secondorder/hbp/kfac_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFAC.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kflr_settings.py b/test/extensions/secondorder/hbp/kflr_settings.py index 6b74a2842..de61c5b3b 100644 --- a/test/extensions/secondorder/hbp/kflr_settings.py +++ b/test/extensions/secondorder/hbp/kflr_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFLR.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kfra_settings.py b/test/extensions/secondorder/hbp/kfra_settings.py index 5a28ab738..94e65c2b7 100644 --- a/test/extensions/secondorder/hbp/kfra_settings.py +++ b/test/extensions/secondorder/hbp/kfra_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFRA.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index f50e4658e..44a4483b3 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -8,14 +8,14 @@ Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model Optional entries: "device" [list(torch.device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. - "seed" (int): seed for the random number for torch.rand + "seed" (int): seed for the random number for rand """ @@ -26,7 +26,7 @@ make_simple_pooling_setting, ) -import torch +from torch import device, rand from torch.nn import ( ELU, SELU, @@ -39,12 +39,17 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + CrossEntropyLoss, + Flatten, LeakyReLU, + Linear, LogSigmoid, MaxPool1d, MaxPool2d, MaxPool3d, + MSELoss, ReLU, + Sequential, Sigmoid, Tanh, ) @@ -56,11 +61,11 @@ ############################################################################### example = { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((3,), 5), - "device": [torch.device("cpu")], + "device": [device("cpu")], "seed": 0, "id_prefix": "example", } @@ -70,28 +75,22 @@ SECONDORDER_SETTINGS += [ # classification { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, # Regression { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((3, 5)), }, ] @@ -109,8 +108,8 @@ ############################################################################### # test setting: Pooling Layers # """ -Syntax with default parameters: - - `torch.nn.MaxPoolNd(kernel_size, stride, padding, dilation, +Syntax with default parameters: + - `MaxPoolNd(kernel_size, stride, padding, dilation, return_indices, ceil_mode)` """ ############################################################################### @@ -150,18 +149,18 @@ ############################################################################### # test setting: Convolutional Layers # """ -Syntax with default parameters: - - `torch.nn.ConvNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, dilation=1, - groups=1, bias=True, padding_mode='zeros)` +Syntax with default parameters: + - `ConvNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode='zeros)` - - `torch.nn.ConvTransposeNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, output_padding=0, + - `ConvTransposeNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros)` -Note: There are 5 tests added to each `torch.nn.layers`. -For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d` -only 3 tests are added because they are very memory intensive. +Note: There are 5 tests added to each `layers`. +For `ConvTranspose2d` and `ConvTranspose3d` +only 3 tests are added because they are very memory intensive. """ ############################################################################### @@ -239,12 +238,64 @@ # Flatten layer does not add a node in the computation graph and thus the # backward hook will be called at an unexpected stage. This must explicitly # be addressed in the `backpropagate` function of the flatten module extension. - "input_fn": lambda: torch.rand(3, 5), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(5, 4), torch.nn.Flatten(), torch.nn.Linear(4, 2) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: rand(3, 5), + "module_fn": lambda: Sequential(Linear(5, 4), Flatten(), Linear(4, 2)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 2), "id_prefix": "flatten-no-op", }, ] + +# linear with additional dimension +LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS = [ + # regression + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 8)), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 16)), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: MSELoss(reduction="sum"), + "target_fn": lambda: regression_targets((3, 48)), + "id_prefix": "three-additional", + }, + # classification + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 8), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 16), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 48), + "id_prefix": "three-additional", + }, +] + +SECONDORDER_SETTINGS += LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS From 9bbeebc39142111c110502392153413359364146 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 24 Jun 2021 16:47:21 +0200 Subject: [PATCH 21/54] [FIX] Tweak tolerances for `SqrtGGNMC` tests, reduce cases Squash --- test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py index 96cde05c5..6aa48cece 100644 --- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -29,7 +29,7 @@ def instantiated_problem(request) -> ExtensionsTestProblem: @fixture def small_problem( - instantiated_problem: ExtensionsTestProblem, max_num_params=4000 + instantiated_problem: ExtensionsTestProblem, max_num_params=1000 ) -> ExtensionsTestProblem: """Skip architectures with too many parameters whose GGN is expensive to evaluate. @@ -85,8 +85,8 @@ def test_ggn_mc(small_problem: ExtensionsTestProblem): small_problem: Test case with small network whose GGN can be evaluated. """ autograd_res = AutogradExtensions(small_problem).ggn() - atol, rtol = 1e-3, 1e-2 - mc_samples, chunks = 500000, 50 + atol, rtol = 5e-2, 1e-2 + mc_samples, chunks = 300000, 30 backpack_res = BackpackExtensions(small_problem).ggn_mc(mc_samples, chunks=chunks) check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol) From 172f2219fe1fcdfe4b0b0b4284dca72216529bd9 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 29 Jun 2021 09:25:16 +0200 Subject: [PATCH 22/54] Update from `development` into `rnn` (#188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [DOC] Mention PR documentation requirements * [FMT] Apply auto-formatting with new black release Use black 21.4b0, released on 2021-04-26 (https://github.com/psf/black/releases/tag/21.4b0) * [DOC] Extend first-order extension with custom module tutorial (#152) Modifies how external modules are supported: The old approach was based on class variables, which could lead to subtle bugs due to the lookup order. Switching to instance variables eliminates this potential source of bugs. Includes a self-contained walk through how to extend a custom layer to add support for first-order extensions. Resolves #149 Co-authored-by: Tim Schäfer Co-authored-by: Felix Dangel * [core] ConvTransposeNd _weight_jac_t_mat_prod support groups (#151) Resolves https://github.com/fKunstner/backpack-discuss/issues/84. * support weight_jac_t to ConvTranspose * PR fixes * fix spaces for PR * PR changes * fix black * fix black again * [DOC] Improve documentation Co-authored-by: Felix Dangel * [DOC] Explain `save_memory` for convolutions with use case (#142) * [DOC] Draft use case explaining how to enable save_memory in convolutions * [DOC] Polish save_memory use case and add links to benchmark * [CI] Add partial docstring checks (#157) Fully documented files are added to `fully_documented.txt`. Once the entire code base is documented, the partial checks can be removed, and the full tests be turned on. * [core] Support groups ≠ 1 in ConvNd weight_jac_t (#161) Resolves fKunstner/backpack-discuss#83. * [core] Support groups ≠ 1 in ConvTransposeNd weight_jac_mat_prod (#162) * [core] Support groups ≠ 1 in ConvTransposeNd weight_jac * [test] Merge group conv with regular settings * [format] Remove blank line * [utils] Support groups ≠ 1 in ConvNd extract_weight_diagonal (#163) * [test] Reproduce behavior that groups ≠ 1 is not supported * [utils] Support groups ≠ 1 in ConvNd extract_weight_diagonal Also refactor the code and replace and move a `transpose` inside an `einsum` operation. * [test] Add cases for grouped Conv2d and Conv3d * [fix] Fix Conv{2,3}d by grouping spatial dimensions * [ref, doc] Use inplace power function, add documentation * [doc] polish docstring Co-authored-by: Felix Dangel * [utils] Support groups ≠ 1 in ConvTransposeNd extract_weight_diagonal (#164) Fixes `groups ≠ 1` in diagonal curvature extensions for transpose convolutions. --- * [test] Reproduce behavior that groups ≠ 1 is not supported * [utils] Support groups ≠ 1 in ConvNdTranspose extract_weight_diagonal * [test] Add cases for grouped ConvTranspose{2,3}d * [test] Make ConvTranspose architectures smaller * [utils] Change shape convention of transpose convolution unfolding Align to convention of ``torch.nn.Unfold`` (the equivalent for normal convolutions) and document new behavior. * [ref, doc] Remove transposes, use inplace power, add doc * [refactor] Write kernel size elements more compactly Co-authored-by: Felix Dangel * [test] Increase tolerance for old hvp tests They sometimes fail as they are unseeded. See https://github.com/fKunstner/backpack-discuss/issues/103 where such events are documented to reduce such unwanted failures. * [doc] enforce auto-formatting & linting rules on examples (#166) Resolves https://github.com/fKunstner/backpack-discuss/issues/79. * [doc] Auto-format examples * [lint] Use max-line-length=88 in flake8 configuration * [fix] Make examples pass flake8 * [KFAC, KFLR, KFRA] Raise exception for group convolutions (#167) * [hbp] Raise exception if `groups ≠ 1`, test for KFAC * [hbp] Test KFLR raises exception for `groups ≠ 1` * [hbp] Test KFRA raises exception for `groups ≠ 1` * [ci] Replace fully documented files by parent directory * [doc] Improve test description * [del] Remove dead code Co-authored-by: Felix Dangel * [DiagHessian] Support `ELU` and `SELU` (#168) Fixes two bugs in the second-order derivatives of `ELU` and `SELU`. Resolves https://github.com/fKunstner/backpack-discuss/issues/69. * [core] Fix math of ELU second derivative * [ref] Rewrite ELU derivatives * [DiagHessian] Add support for ELU * [core] Fix bug in second SELU derivative * [DiagHessian] Add support for SELU * [ref] Make ELU and SELU derivatives more similar * [ref] Add blank lines before `return` statements * [doc] Improve documentation * [ADD] BatchDiagHessian extension (#137) Extension to compute per-sample (individual) Hessian diagonal. Resolves https://github.com/fKunstner/backpack-discuss/issues/81. Auxiliary: - Many refactorings in utility functions to extract diagonals and L2 norms - Aggregate terms for (per-sample and sample-mean) Hessian diagonals with inplace summation --- * [doc] Add docstrings in Hessian diagonal for Conv1d * [doc] Add docstrings for Hessian diagonal of Conv{2,3}d * [doc] Add missing docstring for DiagHessian Conv3d * [doc] Fully document {Batch}DiagHessian extensions * [DiagHessian, ConvND] Aggregate terms inplace * [BatchDiagHessian] Share device and dtype with parameter * [fix] Black formatting * [BatchDiagHessian, DiagHessian] Aggregate result inplace * Remove conv dimension from diagonal extraction utilities * Remove conv dimension from diagonal extractors * Remove redundant docstrings * Return unfolded input instead of function that evaluates it * [fix] pydocstyle * Remove get_bias_gradient_factors for convolution * fix bias l2 computation Co-authored-by: Felix Dangel * [REF] Remove dead utility code * [DOC] Add per-sample curvatures to example and doc (#170) Relevant parts copied from https://github.com/f-dangel/backpack/pull/146. Resolves https://github.com/fKunstner/backpack-discuss/issues/90. Resolves https://github.com/fKunstner/backpack-discuss/issues/91. * [DOC] Update supported layers and future features (#171) * [FIX] RTD build for `save_memory` example - Reduce batch size for Read the Docs build With batch size 128, builds fails with `Command killed due to excessive memory consumption` (https://readthedocs.org/projects/backpack/builds/14015934/). - Synchronize CUDA only if device is GPU Log of build from RTD: https://readthedocs.org/projects/backpack/builds/14016051/ * [CFG] Move install and dev tool configuration to `setup.cfg` (#172) - Move configurations of linters and other dev tools to `setup.cfg` - Move dependencies to `setup.cfg` - Move main library meta data and installation info to `setup.cfg`, replace manual with automatic versioning (merged `master` to bump the development version to `1.2.xxx`) Auxiliary: - Remove `pre-commit` from development tools - I originally wanted to automate publishing to PyPI for new tags. This requires setting up GitHub secrets for the PyPI user and password. According to the [doc](https://docs.github.com/en/actions/reference/encrypted-secrets), those secrets can be seen by anyone with 'collaborator' access. I don't think that sharing my credentials is a good solution, hence I will stick to the manual procedure. * [REF] Move linter configurations to a `setup.cfg` * [DEL] Remove pre-commit from workflow * [REF] Move installation and dependencies to `setup.cfg` Switch to automatic version detection. * [DOC] Justify pytest configuration separate from `setup.cfg` * [FIX] Don't lint `.eggs` directory with flake8 * [FIX] Dependencies for read the docs build * Prepare `1.3.0` release (#173) - Update changelog - Update API and improve existing documentation for RTD --- * [DOC] Update and improve docstrings * [DOC] Add `disable` context to API documentation * [DOC] Remove `AvgPool{1,3}d` from supported layers * [DOC] Update changelog for `1.3.0` release * [DOC] Use napoleon sphinx extension * [DOC] Adapt to google docstring * [FIX] Format of scaling issue warning * [DOC] Tweak headers, remove blank lines * [FIX] Add missing colon * [FIX] Make pydocstyle pass * [FIX] PyPI release of `1.3.0` (#176) * [FIX] Remove quotes in url * [FIX] Move `pytorch_memlab` dependency to makefile It is not possible to have GitHub dependencies in `setup.cfg` when attempting to push to PyPI * [FIX] Exclude test package in installation https://stackoverflow.com/a/59686298 * [FIX] PyPI release of `1.3.0` (#176) (#177) * [FIX] Remove quotes in url * [FIX] Move `pytorch_memlab` dependency to makefile It is not possible to have GitHub dependencies in `setup.cfg` when attempting to push to PyPI * [FIX] Exclude test package in installation https://stackoverflow.com/a/59686298 * [ADD] `SqrtGGNExact` extension (#180) This PR adds an extension that computes the matrix square root of the generalized Gauss-Newton, using an exact factorization of the loss Hessian. For details, see Equation (3) of [this paper](https://arxiv.org/abs/2106.02624v1). * [ADD] SqrtGGNExact extension * [TEST] Compare GGN via square root with autograd * [TEST] Report explanations for skipped settings * [DEL] Remove validity check for loss Hessian strategy * [FMT] Replace `.format` with f-string * [DOC] Include `SqrtGGNExact` in all-in-one example * [FIX] Typo in docstring * [FMT] Remove blank line * [FIX] Docstring * [REF] Directly evaluate the Hessian sqrt without creating a function * [REF] Change ValueError into NotImplementedError * [ADD] Introduce public getter for `loss_hessian_strategy` * [FMT] Remove blank line, shorten * [DOC] Update default value * [FIX] Make condition and skip explanation more precise * [TEST] Add setting where `Flatten` does not perform an operation (#181) Adds an architecture with a `torch.nn.Flatten` layer that does not introduce a node in the computation graph. In this case the backward hook for the `Flatten` module will be called at an unexpected stage, which must be addressed in the second-order extensions. [This](https://coveralls.io/builds/40739633/source?filename=backpack%2Fextensions%2Fsecondorder%2Fsqrt_ggn%2Fflatten.py#L45) coverage report indicates that the new test suite did not include such a test case, whereas it seems to be covered by the old tests (the coverage report for second-order extensions that are tested in the old suite covers the linked branch in the module extensions for `Flatten`). * [ADD] `SqrtGGNMC` extension (#182) A Monte-Carlo approximation of the generalized Gauss-Newton/Fisher matrix square root. * [REF] Split instantiation and filtering of test cases * [ADD] `SqrtGGNMC` extension * [ADD] Make base implementation of extensions abstract, type annotations * [TEST] Share names of BackPACK/autograd implementation * [TEST] Compare MC with exact GGN using many samples * [FIX] Increase atol to make tests pass on GPU, too * [ADD] Integration test for `SqrtGGNMC`, fix docstring * [core] Support additional dimensions in input to `Linear` (#185) Addssupport for additional axes to the `LinearDerivatives` in the `core`. --- * [TEST] Preserve batch axis of individual samples To avoid PyTorch from accidentally identifying the additional axis of an individual sample fed through a linear layer as batch axis, explicitly keeps the batch axis during slicing out individual samples of a mini-batch. * [core] Support additional dimensions in Linear jac_t_mat_prod * [core] Support additional dimensions in Linear jac_mat_prod * [core] Support additional dimensions in Linear weight_jac_t * [core] Support additional dimensions in Linear weight_jac * [core] Support additional dimensions in Linear bias_jac_t * [core] Support additional dimensions in linear bias_jac * [core] Support additional dimensions in Linear ea_jac_t_mat_jac * [TEST] Merge additional dimension cases with regular cases * [DOC] Fully document LinearDerivatives * [core] Add test of `hessian_is_zero` property (#183) Adds a test that checks `BaseDerivatives.hessian_is_zero` and improves documentation. --- * [TEST] Make test for `hessian_is_zero` work * [REF] Use fixture for `test_hessian_is_zero` * [FIX] flake8 * [TEST] Restrict Hessian tests to small inputs * [TEST] Enable `hessian_is_diagonal` test * [DEL] All Hessian property tests except `test_hessian_is_zero` * [FIX] flake8 Co-authored-by: Felix Dangel * [extensions] Support additional dims in Linear (#186) Adds support for first-order extensions, diagonal second-order extensions, and an error message for Kronecker curvatures. Improve documentation and clean up (white space, replace `torch.nn.` by direct imports). Had to slightly adapt the Monte-Carlo test tolerances to make the new settings pass. --- * [DOC] Fully document first-order settings - White space cleanup - Docstring improvements - Make imports shorter * [ADD] Support additional dims in Linear first-order extensions * [TEST] Slightly increase atol for Variance, improve documentation * [TEST] Remove redundant local settings and fix file name * [FMT] Shorten imports, white space clean up * [ADD] Support additional dims in Linear diagonal extensions * [ADD] Only allow 2d inputs to Linear Kronecker curvatures * [DOC] Improve docstring, create TODO for einsum performance * [REF] Remove unnecessary parentheses * [FIX] Tweak tolerances for `SqrtGGNMC` tests * [FIX] Tweak tolerances for `SqrtGGNMC` tests, reduce cases Squash * [FIX] Incorporate feedback from PR review * [REF] Raise NotImplementedError in interface functions The original solution used `return` statements to satisfy darglint. But from `darglint>=1.8`, `raise` is also supported. Hence changing to the more comprehensible `NotImplementedError` solution. * [TEST] Add types, ignore missing docstrings (parent class) * [TEST] Add types, ignore/add missing docstrings (parent class) * Specify output type * [FIX] Change types of `g_inp, g_out` from `Any` to `Tuple[Tensor]` * [REF] Replace `weight.data` by `weight` * [CI] Unify directory format * [REF] Spell out `add_axes`, replace list with bool * [FIX] Remove redundant type * [DEL] Remove redundant print statement and type * [FIX] flake8 * [REF] Change `List[Tensor]` into `Tuple[Tensor]` Co-authored-by: Tim Schäfer Co-authored-by: Tim Schäfer Co-authored-by: Shrisha Bharadwaj Co-authored-by: Felix Dangel --- .../core/derivatives/adaptive_avg_pool_nd.py | 2 +- backpack/core/derivatives/basederivatives.py | 25 +- backpack/core/derivatives/linear.py | 239 +++++++++++++++--- backpack/extensions/__init__.py | 4 + .../firstorder/batch_l2_grad/linear.py | 12 +- .../firstorder/sum_grad_squared/linear.py | 11 +- backpack/extensions/secondorder/__init__.py | 19 +- backpack/extensions/secondorder/hbp/linear.py | 19 ++ .../secondorder/sqrt_ggn/__init__.py | 159 ++++++++++++ .../secondorder/sqrt_ggn/activations.py | 65 +++++ .../extensions/secondorder/sqrt_ggn/base.py | 69 +++++ .../extensions/secondorder/sqrt_ggn/convnd.py | 29 +++ .../secondorder/sqrt_ggn/convtransposend.py | 29 +++ .../secondorder/sqrt_ggn/dropout.py | 11 + .../secondorder/sqrt_ggn/flatten.py | 47 ++++ .../extensions/secondorder/sqrt_ggn/linear.py | 11 + .../extensions/secondorder/sqrt_ggn/losses.py | 72 ++++++ .../secondorder/sqrt_ggn/padding.py | 11 + .../secondorder/sqrt_ggn/pooling.py | 56 ++++ backpack/utils/__init__.py | 1 + backpack/utils/linear.py | 65 ++++- .../basic_usage/example_all_in_one.py | 15 ++ docs_src/rtd/extensions.rst | 2 + fully_documented.txt | 19 +- makefile | 8 +- test/core/derivatives/derivatives_test.py | 177 +++++++------ .../derivatives/implementation/autograd.py | 50 ++-- .../derivatives/implementation/backpack.py | 12 +- test/core/derivatives/implementation/base.py | 9 + test/core/derivatives/linear_settings.py | 21 +- .../firstorder/firstorder_settings.py | 134 ++++++---- .../firstorder/variance/__init__.py | 1 + .../firstorder/variance/test_variance.py | 20 +- .../firstorder/variance/variance_settings.py | 6 +- test/extensions/implementation/autograd.py | 72 ++++-- test/extensions/implementation/backpack.py | 144 +++++++++-- test/extensions/implementation/base.py | 118 +++++++-- .../secondorder/diag_ggn/diag_ggn_settings.py | 26 +- .../diag_ggn/test_batch_diag_ggn.py | 6 +- .../secondorder/hbp/kfac_settings.py | 9 +- .../secondorder/hbp/kflr_settings.py | 9 +- .../secondorder/hbp/kfra_settings.py | 9 +- .../secondorder/secondorder_settings.py | 132 +++++++--- .../secondorder/sqrt_ggn/__init__.py | 1 + .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 92 +++++++ 45 files changed, 1676 insertions(+), 372 deletions(-) create mode 100644 backpack/extensions/secondorder/sqrt_ggn/__init__.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/activations.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/base.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/convnd.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/convtransposend.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/dropout.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/flatten.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/linear.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/losses.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/padding.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/pooling.py create mode 100644 test/extensions/secondorder/sqrt_ggn/__init__.py create mode 100644 test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py index 167d92495..71529a267 100644 --- a/backpack/core/derivatives/adaptive_avg_pool_nd.py +++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py @@ -45,7 +45,7 @@ def check_parameters( ) # check if input shape is multiple of output shape - if any([shape_input[2 + n] % shape_output[2 + n] != 0 for n in range(self.N)]): + if any(shape_input[2 + n] % shape_output[2 + n] != 0 for n in range(self.N)): raise NotImplementedError( f"No equivalent AvgPool (unadaptive): Input shape ({shape_input}) " f"must be multiple of output shape ({shape_output})." diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 14f26a626..cf5635261 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -141,11 +141,13 @@ def ea_jac_t_mat_jac_prod( raise NotImplementedError def hessian_is_zero(self) -> bool: - """Returns whether hessian is zero. + """Returns whether Hessian is zero. + + I.e. whether ``∂²output[i] / ∂input[j] ∂input[k] = 0 ∀ i,j,k``. # noqa: DAR202 Returns: - whether hessian is zero + whether Hessian is zero Raises: NotImplementedError: if not overwritten @@ -157,21 +159,32 @@ def hessian_is_diagonal(self) -> bool: # noqa: DAR202 Returns: - whether hessian is diagonal + whether Hessian is diagonal Raises: NotImplementedError: if not overwritten """ raise NotImplementedError - def hessian_diagonal(self) -> Tensor: - """Return `∂²output[i] / ∂input[i]²`. + # FIXME Currently returns `∂²output[i] / ∂input[i]² * g_out[0][i]`, + # which s the residual matrix diagonal, rather than the Hessian diagonal + def hessian_diagonal( + self, module: Module, g_in: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: + """Return the Hessian diagonal `∂²output[i] / ∂input[i]²`. Only required if `hessian_is_diagonal` returns `True`. + The Hessian diagonal is only defined for layers that preserve the size + of their input. + + Args: + module: Module whose output-input Hessian diagonal is computed. + g_in: Gradients w.r.t. the module input. + g_out: Gradients w.r.t. the module output. # noqa: DAR202 Returns: - hessian diagonal + Hessian diagonal. Has same shape as module input. Raises: NotImplementedError: if not overwritten diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 0e1f5f32b..a895d01eb 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -1,4 +1,8 @@ -from torch import einsum +"""Contains partial derivatives for the ``torch.nn.Linear`` layer.""" +from typing import Tuple + +from torch import Size, Tensor, einsum +from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives @@ -14,43 +18,212 @@ class LinearDerivatives(BaseParameterDerivatives): * i: Input dimension """ - def hessian_is_zero(self): + def hessian_is_zero(self) -> bool: + """Linear layer output is linear w.r.t. to its input. + + Returns: + True + """ return True - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - """Apply transposed Jacobian of the output w.r.t. the input.""" - d_input = module.weight.data - return einsum("oi,vno->vni", (d_input, mat)) + def _jac_t_mat_prod( + self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the input. - def _jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the input.""" - d_input = module.weight.data - return einsum("oi,vni->vno", (d_input, mat)) + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. - def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): - jac = module.weight.data - return einsum("ik,ij,jl->kl", (jac, mat, jac)) + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *, in_features]``. + """ + return einsum("oi,vn...o->vn...i", module.weight, mat) - def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the weight.""" - d_weight = module.input0 - return einsum("ni,voi->vno", (d_weight, mat)) + def _jac_mat_prod( + self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the input. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer input + (``[N, *, in_features]``) to which the output-input Jacobian is applied. + Has shape ``[V, N, *, in_features]``. + + Returns: + Batched Jacobian vector products. Has shape ``[V, N, *, out_features]``. + """ + return einsum("oi,vn...i->vn...o", module.weight, mat) + + def ea_jac_t_mat_jac_prod( + self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + """Expectation approximation of outer product with input-output Jacobian. + + Used for KFRA backpropagation: ``mat ← E(Jₙᵀ mat Jₙ) = 1/N ∑ₙ Jₙᵀ mat Jₙ``. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Matrix of shape + ``[module.output.numel() // N, module.output.numel() // N]``. + + Returns: + Matrix of shape + ``[module.input0.numel() // N, module.input0.numel() // N]``. + """ + add_features = self._get_additional_dims(module).numel() + in_features, out_features = module.in_features, module.out_features - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Apply transposed Jacobian of the output w.r.t. the weight.""" + result = mat.reshape(add_features, out_features, add_features, out_features) + result = einsum("ik,xiyj,jl->xkyl", module.weight, result, module.weight) + + return result.reshape(in_features * add_features, in_features * add_features) + + def _weight_jac_mat_prod( + self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the weight. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of shape ``module.weight.shape`` to which the + transposed output-input Jacobian is applied. Has shape + ``[V, *module.weight.shape]``. + + Returns: + Batched Jacobian vector products. Has shape + ``[V, N, *module.output.shape]``. + """ + return einsum("n...i,voi->vn...o", module.input0, mat) + + def _weight_jac_t_mat_prod( + self, + module: Linear, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: int = True, + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the weight. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. + sum_batch: Sum the result's batch axis. Default: ``True``. + + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. + """ d_weight = module.input0 - contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi" - return einsum(contract, (mat, d_weight)) - - def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): - """Apply Jacobian of the output w.r.t. the bias.""" - N = module.input0.size(0) - return mat.unsqueeze(1).expand(-1, N, -1) - - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Apply transposed Jacobian of the output w.r.t. the bias.""" - if sum_batch: - N_axis = 1 - return mat.sum(N_axis) + + if self._has_additional_dims(module): + # Flatten additional dimensions because they cannot be represented as + # ellipsis. WAITING https://github.com/pytorch/pytorch/issues/45854 + d_weight = d_weight.flatten(start_dim=1, end_dim=-2) + mat = mat.flatten(start_dim=2, end_dim=-2) + equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi" else: - return mat + equation = f"vno,ni->v{'' if sum_batch else 'n'}oi" + + return einsum(equation, mat, d_weight) + + def _bias_jac_mat_prod( + self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + ) -> Tensor: + """Batch-apply Jacobian of the output w.r.t. the bias. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of shape ``module.bias.shape`` to which the + transposed output-input Jacobian is applied. Has shape + ``[V, *module.bias.shape]``. + + Returns: + Batched Jacobian vector products. Has shape + ``[V, N, *module.output.shape]``. + """ + N = module.input0.shape[0] + additional_dims = list(self._get_additional_dims(module)) + + for _ in range(len(additional_dims) + 1): + mat = mat.unsqueeze(1) + + expand = [-1, N] + additional_dims + [-1] + + return mat.expand(*expand) + + def _bias_jac_t_mat_prod( + self, + module: Linear, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: int = True, + ) -> Tensor: + """Batch-apply transposed Jacobian of the output w.r.t. the bias. + + Args: + module: Linear layer. + g_inp: Gradients w.r.t. module input. Not required by the implementation. + g_out: Gradients w.r.t. module output. Not required by the implementation. + mat: Batch of ``V`` vectors of same shape as the layer output + (``[N, *, out_features]``) to which the transposed output-input Jacobian + is applied. Has shape ``[V, N, *, out_features]``. + sum_batch: Sum the result's batch axis. Default: ``True``. + + Returns: + Batched transposed Jacobian vector products. Has shape + ``[V, N, *module.bias.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. + """ + equation = f"vn...o->v{'' if sum_batch else 'n'}o" + + return einsum(equation, mat) + + @classmethod + def _has_additional_dims(cls, module: Linear) -> bool: + """Return whether the input to a linear layer has additional (>1) dimensions. + + The input to a linear layer may have shape ``[N, *, out_features]``. + It has additional dimensions if ``*`` is non-empty. + + Args: + module: Linear layer. + + Returns: + Whether the input has hidden dimensions. + """ + return len(cls._get_additional_dims(module)) != 0 + + @staticmethod + def _get_additional_dims(module: Linear) -> Size: + """Return the shape of additional dimensions in the input to a linear layer. + + Args: + module: A linear layer. + + Returns: + Shape of the additional dimensions. Corresponds to ``*`` in the + input shape ``[N, *, out_features]``. + """ + return module.input0.shape[1:-1] diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index a84a64f71..df0bd558d 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -13,6 +13,8 @@ DiagGGNExact, DiagGGNMC, DiagHessian, + SqrtGGNExact, + SqrtGGNMC, ) __all__ = [ @@ -33,4 +35,6 @@ "BatchDiagGGNMC", "DiagHessian", "BatchDiagHessian", + "SqrtGGNExact", + "SqrtGGNMC", ] diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 9fb17536c..4978fd059 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -40,4 +40,14 @@ def weight( Returns: batch_l2 for weight """ - return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2) + has_additional_axes = g_out[0].dim() > 2 + + if has_additional_axes: + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) + X = module.input0.flatten(start_dim=1, end_dim=-2) + return einsum("nmi,nmj,nki,nkj->n", dE_dY, X, dE_dY, X) + + else: + return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py index 4cf75db1d..8239da936 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/linear.py +++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py @@ -18,4 +18,13 @@ def weight(self, ext, module, g_inp, g_out, backproped): For details, see page 12 (paragraph about "second moment") of the paper (https://arxiv.org/pdf/1912.10985.pdf). """ - return einsum("ni,nj->ij", (g_out[0] ** 2, module.input0 ** 2)) + has_additional_axes = g_out[0].dim() > 2 + + if has_additional_axes: + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) + X = module.input0.flatten(start_dim=1, end_dim=-2) + return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X) + + return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index afe9ebc2f..3d3de33c9 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -17,11 +17,22 @@ :func:`KFRA `, :func:`KFLR `. - The diagonal of the Hessian :func:`DiagHessian ` +- The symmetric (square root) factorization of the GGN/Fisher information, + using exact computation + (:func:`SqrtGGNExact `) + or a Monte-Carlo (MC) approximation + (:func:`SqrtGGNMC`) """ -from .diag_ggn import BatchDiagGGNExact, BatchDiagGGNMC, DiagGGNExact, DiagGGNMC -from .diag_hessian import BatchDiagHessian, DiagHessian -from .hbp import HBP, KFAC, KFLR, KFRA +from backpack.extensions.secondorder.diag_ggn import ( + BatchDiagGGNExact, + BatchDiagGGNMC, + DiagGGNExact, + DiagGGNMC, +) +from backpack.extensions.secondorder.diag_hessian import BatchDiagHessian, DiagHessian +from backpack.extensions.secondorder.hbp import HBP, KFAC, KFLR, KFRA +from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC __all__ = [ "DiagGGNExact", @@ -34,4 +45,6 @@ "KFLR", "KFRA", "HBP", + "SqrtGGNExact", + "SqrtGGNMC", ] diff --git a/backpack/extensions/secondorder/hbp/linear.py b/backpack/extensions/secondorder/hbp/linear.py index 779459e14..c89791a3c 100644 --- a/backpack/extensions/secondorder/hbp/linear.py +++ b/backpack/extensions/secondorder/hbp/linear.py @@ -1,4 +1,5 @@ from torch import einsum +from torch.nn import Linear from backpack.core.derivatives.linear import LinearDerivatives from backpack.extensions.secondorder.hbp.hbp_options import ( @@ -13,6 +14,7 @@ def __init__(self): super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): + self.check_parameters(ext, module) bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): @@ -44,6 +46,7 @@ def _factor_from_sqrt(self, backproped): return [einsum("vni,vnj->ij", (backproped, backproped))] def bias(self, ext, module, g_inp, g_out, backproped): + self.check_parameters(ext, module) bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): @@ -61,3 +64,19 @@ def __mean_input_outer(self, module): N = module.input0.size(0) flat_input = module.input0.reshape(N, -1) return einsum("ni,nj->ij", (flat_input, flat_input)) / N + + def check_parameters(self, ext, module: Linear) -> None: + """Raise an exception if module parameters are not supported. + + Args: + ext (KFAC or KFRA or KFLR): Extension calling out to the module. + module: Linear layer. + + Raises: + NotImplementedError: If the setting is not implemented. + """ + if module.input0.dim() != 2: + raise NotImplementedError( + f"Only 2d inputs are supported by {ext.__class__.__name__} " + + f"(got {module.input0.dim()})." + ) diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py new file mode 100644 index 000000000..d258b350f --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -0,0 +1,159 @@ +"""Defines base class and extensions for computing the GGN/Fisher matrix square root.""" + +from torch.nn import ( + ELU, + SELU, + AvgPool1d, + AvgPool2d, + AvgPool3d, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + CrossEntropyLoss, + Dropout, + Flatten, + LeakyReLU, + Linear, + LogSigmoid, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + +from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.hbp import LossHessianStrategy +from backpack.extensions.secondorder.sqrt_ggn import ( + activations, + convnd, + convtransposend, + dropout, + flatten, + linear, + losses, + padding, + pooling, +) + + +class SqrtGGN(BackpropExtension): + """Base class for extensions that compute the GGN/Fisher matrix square root.""" + + def __init__(self, loss_hessian_strategy: str, savefield: str): + """Store approximation for backpropagated object and where to save the result. + + Args: + loss_hessian_strategy: Which approximation is used for the backpropagated + loss Hessian. Must be ``'exact'`` or ``'sampling'``. + savefield: Attribute under which the quantity is saved in a parameter. + """ + self.loss_hessian_strategy = loss_hessian_strategy + super().__init__( + savefield=savefield, + fail_mode="ERROR", + module_exts={ + MSELoss: losses.SqrtGGNMSELoss(), + CrossEntropyLoss: losses.SqrtGGNCrossEntropyLoss(), + Linear: linear.SqrtGGNLinear(), + MaxPool1d: pooling.SqrtGGNMaxPool1d(), + MaxPool2d: pooling.SqrtGGNMaxPool2d(), + AvgPool1d: pooling.SqrtGGNAvgPool1d(), + MaxPool3d: pooling.SqrtGGNMaxPool3d(), + AvgPool2d: pooling.SqrtGGNAvgPool2d(), + AvgPool3d: pooling.SqrtGGNAvgPool3d(), + ZeroPad2d: padding.SqrtGGNZeroPad2d(), + Conv1d: convnd.SqrtGGNConv1d(), + Conv2d: convnd.SqrtGGNConv2d(), + Conv3d: convnd.SqrtGGNConv3d(), + ConvTranspose1d: convtransposend.SqrtGGNConvTranspose1d(), + ConvTranspose2d: convtransposend.SqrtGGNConvTranspose2d(), + ConvTranspose3d: convtransposend.SqrtGGNConvTranspose3d(), + Dropout: dropout.SqrtGGNDropout(), + Flatten: flatten.SqrtGGNFlatten(), + ReLU: activations.SqrtGGNReLU(), + Sigmoid: activations.SqrtGGNSigmoid(), + Tanh: activations.SqrtGGNTanh(), + LeakyReLU: activations.SqrtGGNLeakyReLU(), + LogSigmoid: activations.SqrtGGNLogSigmoid(), + ELU: activations.SqrtGGNELU(), + SELU: activations.SqrtGGNSELU(), + }, + ) + + def get_loss_hessian_strategy(self) -> str: + """Return the strategy used to represent the backpropagated loss Hessian. + + Returns: + Loss Hessian strategy. + """ + return self.loss_hessian_strategy + + +class SqrtGGNExact(SqrtGGN): + """Exact matrix square root of the generalized Gauss-Newton/Fisher. + + Uses the exact Hessian of the loss w.r.t. the model output. + + Stores the output in :code:`sqrt_ggn_exact`, has shape ``[C, N, param.shape]``, + where ``C`` is the model output dimension (number of classes for classification + problems) and ``N`` is the batch size. + + For a faster but less precise alternative, see + :py:meth:`backpack.extensions.SqrtGGNMC`. + + .. note:: + + (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_exact`` + can be viewed as a ``[C * N, param.numel()]`` matrix. Concatenating this + matrix over all parameters results in a matrix ``Vᵀ``, which + is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``. + """ + + def __init__(self): + """Use exact loss Hessian and set savefield to ``sqrt_ggn_exact``.""" + super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact") + + +class SqrtGGNMC(SqrtGGN): + """Approximate matrix square root of the generalized Gauss-Newton/Fisher. + + Uses a Monte-Carlo (MC) approximation of the Hessian of the loss w.r.t. the model + output. + + Stores the output in :code:`sqrt_ggn_mc`, has shape ``[M, N, param.shape]``, + where ``M`` is the number of Monte-Carlo samples and ``N`` is the batch size. + + For a more precise but slower alternative, see + :py:meth:`backpack.extensions.SqrtGGNExact`. + + .. note:: + + (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_mc`` + can be viewed as a ``[M * N, param.numel()]`` matrix. Concatenating this + matrix over all parameters results in a matrix ``Vᵀ``, which + is the approximate GGN/Fisher's matrix square root, i.e. ``G ≈ V Vᵀ``. + """ + + def __init__(self, mc_samples: int = 1): + """Approximate loss Hessian via MC and set savefield to ``sqrt_ggn_mc``. + + Args: + mc_samples: Number of Monte-Carlo samples. Default: ``1``. + """ + self._mc_samples = mc_samples + super().__init__(LossHessianStrategy.SAMPLING, "sqrt_ggn_mc") + + def get_num_mc_samples(self) -> int: + """Return the number of MC samples used to approximate the loss Hessian. + + Returns: + Number of Monte-Carlo samples. + """ + return self._mc_samples diff --git a/backpack/extensions/secondorder/sqrt_ggn/activations.py b/backpack/extensions/secondorder/sqrt_ggn/activations.py new file mode 100644 index 000000000..3aaf8fff2 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/activations.py @@ -0,0 +1,65 @@ +"""Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.elu import ELUDerivatives +from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives +from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives +from backpack.core.derivatives.relu import ReLUDerivatives +from backpack.core.derivatives.selu import SELUDerivatives +from backpack.core.derivatives.sigmoid import SigmoidDerivatives +from backpack.core.derivatives.tanh import TanhDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNReLU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ReLU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ReLU`` module.""" + super().__init__(ReLUDerivatives()) + + +class SqrtGGNSigmoid(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Sigmoid`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Sigmoid`` module.""" + super().__init__(SigmoidDerivatives()) + + +class SqrtGGNTanh(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Tanh`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Tanh`` module.""" + super().__init__(TanhDerivatives()) + + +class SqrtGGNELU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ELU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ELU`` module.""" + super().__init__(ELUDerivatives()) + + +class SqrtGGNSELU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.SELU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.SELU`` module.""" + super().__init__(SELUDerivatives()) + + +class SqrtGGNLeakyReLU(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LeakyReLU`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.LeakyReLU`` module.""" + super().__init__(LeakyReLUDerivatives()) + + +class SqrtGGNLogSigmoid(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LogSigmoid`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.LogSigmoid`` module.""" + super().__init__(LogSigmoidDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py new file mode 100644 index 000000000..336325b08 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -0,0 +1,69 @@ +"""Contains base class for ``SqrtGGN{Exact, MC}`` module extensions.""" +from typing import Any, Callable, List, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.extensions.mat_to_mat_jac_base import MatToJacMat + + +class SqrtGGNBaseModule(MatToJacMat): + """Base module extension for ``SqrtGGN{Exact, MC}``.""" + + def __init__(self, derivatives: BaseDerivatives, params: List[str] = None): + """Store parameter names and derivatives. + + Sets up methods that extract the GGN/Fisher matrix square root for the + passed parameters, unless these methods are overwritten by a child class. + + Args: + derivatives: derivatives object. + params: List of parameter names. Defaults to None. + """ + if params is not None: + for param_str in params: + if not hasattr(self, param_str): + setattr(self, param_str, self._make_param_function(param_str)) + + super().__init__(derivatives, params=params) + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def _make_param_function( + self, param_str: str + ) -> Callable[[Any, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor]: + """Create a function that computes the GGN/Fisher square root for a parameter. + + Args: + param_str: name of parameter + + Returns: + Function that computes the GGN/Fisher matrix square root. + """ + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def param_function( + ext: Any, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + backproped: Tensor, + ) -> Tensor: + """Calculate the GGN/Fisher matrix square root with the derivatives object. + + Args: + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + backproped: Backpropagated quantities from second-order extension. + + Returns: + GGN/Fisher matrix square root. + """ + return getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, backproped, sum_batch=False + ) + + return param_function diff --git a/backpack/extensions/secondorder/sqrt_ggn/convnd.py b/backpack/extensions/secondorder/sqrt_ggn/convnd.py new file mode 100644 index 000000000..74a88651c --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/convnd.py @@ -0,0 +1,29 @@ +"""Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.conv1d import Conv1DDerivatives +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.core.derivatives.conv3d import Conv3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNConv1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv1d`` module.""" + super().__init__(Conv1DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConv2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv2d`` module.""" + super().__init__(Conv2DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConv3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Conv3d`` module.""" + super().__init__(Conv3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py new file mode 100644 index 000000000..a18331976 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py @@ -0,0 +1,29 @@ +"""Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives +from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives +from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNConvTranspose1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose1d`` module.""" + super().__init__(ConvTranspose1DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConvTranspose2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose2d`` module.""" + super().__init__(ConvTranspose2DDerivatives(), params=["bias", "weight"]) + + +class SqrtGGNConvTranspose3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ConvTranspose3d`` module.""" + super().__init__(ConvTranspose3DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/dropout.py b/backpack/extensions/secondorder/sqrt_ggn/dropout.py new file mode 100644 index 000000000..2f03b8aa9 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/dropout.py @@ -0,0 +1,11 @@ +"""Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.dropout import DropoutDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNDropout(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Dropout`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Dropout`` module.""" + super().__init__(DropoutDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py new file mode 100644 index 000000000..c2f4eb5a3 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py @@ -0,0 +1,47 @@ +"""Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" +from typing import Any, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNFlatten(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Flatten`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Flatten`` module.""" + super().__init__(FlattenDerivatives()) + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def backpropagate( + self, + ext: Any, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: Tensor, + ) -> Tensor: + """Backpropagate only if flatten created a node in the computation graph. + + Otherwise, the backward hook will not be called at the right stage and + no action must be performed. + + Args: + ext: BackPACK extension calling out to the module extension. + module: Module that performed the forward pass. + grad_inp: Gradients w.r.t. the module inputs. + grad_out: Gradients w.r.t. the module outputs. + backproped: Backpropagated symmetric factorization of the loss Hessian + from the child module. + + Returns: + Symmetric loss Hessian factorization, backpropagated through the module. + """ + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/sqrt_ggn/linear.py b/backpack/extensions/secondorder/sqrt_ggn/linear.py new file mode 100644 index 000000000..4aecca6f5 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/linear.py @@ -0,0 +1,11 @@ +"""Contains extension for the linear layer used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNLinear(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Linear`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Linear`` module.""" + super().__init__(LinearDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py new file mode 100644 index 000000000..8e48cf318 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -0,0 +1,72 @@ +"""Contains base class and extensions for losses used by ``SqrtGGN{Exact, MC}``.""" +from typing import Any, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives +from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.extensions.secondorder.hbp import LossHessianStrategy +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNBaseLossModule(SqrtGGNBaseModule): + """Base class for losses used by ``SqrtGGN{Exact, MC}``.""" + + # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] + # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def backpropagate( + self, + ext: Any, + module: Module, + grad_inp: Tuple[Tensor], + grad_out: Tuple[Tensor], + backproped: None, + ) -> Tensor: + """Initialize the backpropagated quantity. + + Uses the exact loss Hessian square root, or a Monte-Carlo approximation + thereof. + + Args: + ext: BackPACK extension calling out to the module extension. + module: Module that performed the forward pass. + grad_inp: Gradients w.r.t. the module inputs. + grad_out: Gradients w.r.t. the module outputs. + backproped: Backpropagated information. Should be ``None``. + + Returns: + Symmetric factorization of the loss Hessian w.r.t. the module input. + + Raises: + NotImplementedError: For invalid strategies to represent the loss Hessian. + """ + loss_hessian_strategy = ext.get_loss_hessian_strategy() + + if loss_hessian_strategy == LossHessianStrategy.EXACT: + return self.derivatives.sqrt_hessian(module, grad_inp, grad_out) + elif loss_hessian_strategy == LossHessianStrategy.SAMPLING: + mc_samples = ext.get_num_mc_samples() + return self.derivatives.sqrt_hessian_sampled( + module, grad_inp, grad_out, mc_samples=mc_samples + ) + else: + raise NotImplementedError( + f"Unknown hessian strategy {loss_hessian_strategy}" + ) + + +class SqrtGGNMSELoss(SqrtGGNBaseLossModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MSELoss`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MSELoss`` module.""" + super().__init__(MSELossDerivatives()) + + +class SqrtGGNCrossEntropyLoss(SqrtGGNBaseLossModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.CrossEntropyLoss`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.CrossEntropyLoss`` module.""" + super().__init__(CrossEntropyLossDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/padding.py b/backpack/extensions/secondorder/sqrt_ggn/padding.py new file mode 100644 index 000000000..18574f685 --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/padding.py @@ -0,0 +1,11 @@ +"""Contains extensions for padding layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNZeroPad2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ZeroPad2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.ZeroPad2d`` module.""" + super().__init__(ZeroPad2dDerivatives()) diff --git a/backpack/extensions/secondorder/sqrt_ggn/pooling.py b/backpack/extensions/secondorder/sqrt_ggn/pooling.py new file mode 100644 index 000000000..e19cfba2a --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/pooling.py @@ -0,0 +1,56 @@ +"""Contains extensions for pooling layers used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives +from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives +from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives +from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives +from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives +from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNMaxPool1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool1d`` module.""" + super().__init__(MaxPool1DDerivatives()) + + +class SqrtGGNMaxPool2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool2d`` module.""" + super().__init__(MaxPool2DDerivatives()) + + +class SqrtGGNMaxPool3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.MaxPool3d`` module.""" + super().__init__(MaxPool3DDerivatives()) + + +class SqrtGGNAvgPool1d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool1d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool1d`` module.""" + super().__init__(AvgPool1DDerivatives()) + + +class SqrtGGNAvgPool2d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool2d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool2d`` module.""" + super().__init__(AvgPool2DDerivatives()) + + +class SqrtGGNAvgPool3d(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool3d`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.AvgPool3d`` module.""" + super().__init__(AvgPool3DDerivatives()) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index e69de29bb..39f7fa2b1 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -0,0 +1 @@ +"""Contains utility functions.""" diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py index b3a2453b3..b61c72cab 100644 --- a/backpack/utils/linear.py +++ b/backpack/utils/linear.py @@ -1,17 +1,60 @@ -from torch import einsum +"""Contains utility functions to extract the GGN diagonal for linear layers.""" +from torch import Tensor, einsum +from torch.nn import Linear -def extract_weight_diagonal(module, backproped, sum_batch=True): - if sum_batch: - equation = "vno,ni->oi" +def extract_weight_diagonal( + module: Linear, S: Tensor, sum_batch: bool = True +) -> Tensor: + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian. + + Args: + module: Linear layer for which the diagonal is extracted w.r.t. the weight. + S: Backpropagated symmetric factorization of the loss Hessian. Has shape + ``(V, *module.output.shape)``. + sum_batch: Sum out the weight diagonal's batch dimension. Default: ``True``. + + Returns: + Per-sample weight diagonal if ``sum_batch=False`` (shape + ``(N, module.weight.shape)`` with batch size ``N``) or summed weight diagonal + if ``sum_batch=True`` (shape ``module.weight.shape``). + """ + has_additional_axes = module.input0.dim() > 2 + + if has_additional_axes: + S_flat = S.flatten(start_dim=2, end_dim=-2) + X_flat = module.input0.flatten(start_dim=1, end_dim=-2) + equation = f"vnmo,nmi,vnko,nki->{'' if sum_batch else 'n'}oi" + # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class + # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 + return einsum(equation, S_flat, X_flat, S_flat, X_flat) + else: - equation = "vno,ni->noi" - return einsum(equation, (backproped ** 2, module.input0 ** 2)) + equation = f"vno,ni->{'' if sum_batch else 'n'}oi" + return einsum(equation, S ** 2, module.input0 ** 2) + +def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) -> Tensor: + """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian. -def extract_bias_diagonal(module, backproped, sum_batch=True): - if sum_batch: - equation = "vno->o" + Args: + module: Linear layer for which the diagonal is extracted w.r.t. the bias. + S: Backpropagated symmetric factorization of the loss Hessian. Has shape + ``(V, *module.output.shape)``. + sum_batch: Sum out the bias diagonal's batch dimension. Default: ``True``. + + Returns: + Per-sample bias diagonal if ``sum_batch=False`` (shape + ``(N, module.bias.shape)`` with batch size ``N``) or summed bias diagonal + if ``sum_batch=True`` (shape ``module.bias.shape``). + """ + additional_axes = list(range(2, module.input0.dim())) + + if additional_axes: + JS = S.sum(additional_axes) else: - equation = "vno->no" - return einsum(equation, backproped ** 2) + JS = S + + equation = f"vno->{'' if sum_batch else 'n'}o" + + return einsum(equation, JS ** 2) diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py index 49a8f1e46..cb7aba42d 100644 --- a/docs_src/examples/basic_usage/example_all_in_one.py +++ b/docs_src/examples/basic_usage/example_all_in_one.py @@ -29,6 +29,8 @@ DiagGGNExact, DiagGGNMC, DiagHessian, + SqrtGGNExact, + SqrtGGNMC, SumGradSquared, Variance, ) @@ -166,6 +168,19 @@ print(".diag_h.shape: ", param.diag_h.shape) print(".diag_h_batch.shape: ", param.diag_h_batch.shape) +# %% +# Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation + +loss = lossfunc(model(X), y) +with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape) + print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape) + # %% # Block-diagonal curvature products # --------------------------------- diff --git a/docs_src/rtd/extensions.rst b/docs_src/rtd/extensions.rst index 2fa9e2369..9eea8df02 100644 --- a/docs_src/rtd/extensions.rst +++ b/docs_src/rtd/extensions.rst @@ -25,6 +25,8 @@ Available Extensions .. autofunction:: backpack.extensions.KFRA .. autofunction:: backpack.extensions.DiagHessian .. autofunction:: backpack.extensions.BatchDiagHessian +.. autofunction:: backpack.extensions.SqrtGGNExact +.. autofunction:: backpack.extensions.SqrtGGNMC ----- diff --git a/fully_documented.txt b/fully_documented.txt index 41bb83b5b..a37063729 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,11 +1,14 @@ setup.py +backpack/custom_module/ + backpack/core/derivatives/basederivatives.py backpack/core/derivatives/rnn.py backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py backpack/core/derivatives/permute.py backpack/core/derivatives/lstm.py +backpack/core/derivatives/linear.py backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/extensions/__init__.py @@ -35,8 +38,12 @@ backpack/extensions/secondorder/diag_hessian/__init__.py backpack/extensions/secondorder/diag_hessian/conv1d.py backpack/extensions/secondorder/diag_hessian/conv2d.py backpack/extensions/secondorder/diag_hessian/conv3d.py +backpack/extensions/secondorder/sqrt_ggn/ -backpack/custom_module/ +backpack/utils/linear.py +backpack/utils/__init__.py + +test/adaptive_avg_pool/ test/core/derivatives/derivatives_test.py test/core/derivatives/__init__.py @@ -47,14 +54,18 @@ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py -test/extensions/problem.py test/extensions/test_backprop_extension.py +test/extensions/problem.py test/extensions/firstorder/firstorder_settings.py +test/extensions/firstorder/variance/ test/extensions/firstorder/batch_grad/batchgrad_settings.py test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/diag_ggn/ -test/extensions/secondorder/hbp +test/extensions/secondorder/hbp/ +test/extensions/secondorder/sqrt_ggn/ -test/adaptive_avg_pool/ +test/extensions/implementation/base.py +test/extensions/implementation/autograd.py +test/extensions/implementation/backpack.py diff --git a/makefile b/makefile index fbcd353f1..b0d4726a3 100644 --- a/makefile +++ b/makefile @@ -56,16 +56,16 @@ help: ### # Test coverage test: - @pytest -vx --run-optional-tests=montecarlo --cov=backpack . + @pytest -vx -rs --run-optional-tests=montecarlo --cov=backpack . test-light: - @pytest -vx --cov=backpack . + @pytest -vx -rs --cov=backpack . test-no-gpu: - @pytest -k "not cuda" -vx --run-optional-tests=montecarlo --cov=backpack . + @pytest -k "not cuda" -vx -rs --run-optional-tests=montecarlo --cov=backpack . test-light-no-gpu: - @pytest -k "not cuda" -vx --cov=backpack . + @pytest -k "not cuda" -vx -rs --cov=backpack . ### # Linter and autoformatter diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 032d2dd5e..b21babde2 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -13,12 +13,14 @@ from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS from test.core.derivatives.lstm_settings import LSTM_SETTINGS from test.core.derivatives.permute_settings import PERMUTE_SETTINGS -from test.core.derivatives.problem import make_test_problems +from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS +from warnings import warn import pytest import torch +from pytest import fixture, skip from backpack.core.derivatives.convnd import weight_jac_t_save_memory @@ -50,12 +52,12 @@ NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, ) -def test_jac_mat_prod(problem, V=3): +def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized Jacobian-vector products. + problem: Test case. + V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.input_shape).to(problem.device) @@ -72,19 +74,19 @@ def test_jac_mat_prod(problem, V=3): NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, ) -def test_jac_t_mat_prod(problem, request, V=3): +def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> None: """Test the transposed Jacobian-matrix product. Args: - problem (DerivativesProblem): Problem for derivative test. + problem: Problem for derivative test. request: Pytest request, used for getting id. - V (int): Number of vectorized transposed Jacobian-vector products. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) if all( - [string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"]] + string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] ): with pytest.warns(UserWarning): BackpackDerivatives(problem).jac_t_mat_prod(mat) @@ -226,14 +228,16 @@ def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): ids=["save_memory=True", "save_memory=False"], ) @pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) -def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): - """Test the transposed Jacobian-matrix product w.r.t. to the weights. +def test_weight_jac_t_mat_prod( + problem: DerivativesTestProblem, sum_batch: bool, save_memory: bool, V: int = 3 +) -> None: + """Test the transposed Jacobian-matrix product w.r.t. to the weight. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - save_memory (bool): Use Owkin implementation to save memory. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + sum_batch: Sum out the batch dimension. + save_memory: Use Owkin implementation in convolutions to save memory. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) @@ -249,12 +253,12 @@ def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3): @pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) -def test_weight_jac_mat_prod(problem, V=3): - """Test the Jacobian-matrix product w.r.t. to the weights. +def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + """Test the Jacobian-matrix product w.r.t. to the weight. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.module.weight.shape).to(problem.device) @@ -277,18 +281,16 @@ def test_weight_jac_mat_prod(problem, V=3): @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_BIAS, - ids=IDS_WITH_BIAS, -) -def test_bias_jac_t_mat_prod(problem, sum_batch, V=3): - """Test the transposed Jacobian-matrix product w.r.t. to the biass. +@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +def test_bias_jac_t_mat_prod( + problem: DerivativesTestProblem, sum_batch: bool, V: int = 3 +) -> None: + """Test the transposed Jacobian-matrix product w.r.t. to the bias. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + sum_batch: Sum out the batch dimension. + V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) @@ -300,17 +302,13 @@ def test_bias_jac_t_mat_prod(problem, sum_batch, V=3): problem.tear_down() -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_BIAS, - ids=IDS_WITH_BIAS, -) -def test_bias_jac_mat_prod(problem, V=3): - """Test the Jacobian-matrix product w.r.t. to the biass. +@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + """Test the Jacobian-matrix product w.r.t. to the bias. Args: - problem (DerivativesProblem): Problem for derivative test. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Test case. + V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() mat = torch.rand(V, *problem.module.bias.shape).to(problem.device) @@ -336,9 +334,6 @@ def test_sqrt_hessian_squared_equals_hessian(problem): backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian() autograd_res = AutogradDerivatives(problem).input_hessian() - print(backpack_res.device) - print(autograd_res.device) - check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() @@ -415,7 +410,7 @@ def test_sum_hessian_should_fail(problem): @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_ea_jac_t_mat_jac_prod(problem, request): +def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None: """Test KFRA backpropagation. H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ @@ -427,15 +422,15 @@ def test_ea_jac_t_mat_jac_prod(problem, request): as `Dropout` is not deterministic. Args: - problem (DerivativesProblem): Problem for derivative test. + problem: Test case. request: PyTest request, used to get test id. """ problem.set_up() - out_features = torch.prod(torch.tensor(problem.output_shape[1:])) + out_features = problem.output_shape[1:].numel() mat = torch.rand(out_features, out_features).to(problem.device) if all( - [string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"]] + string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] ): with pytest.warns(UserWarning): BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) @@ -449,56 +444,80 @@ def test_ea_jac_t_mat_jac_prod(problem, request): problem.tear_down() -@pytest.mark.skip("[WAITING] Autograd issue with Hessian-vector products") -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_zero(problem): - """Check if the input-output Hessian is (non-)zero. +@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) +def instantiated_problem(request) -> DerivativesTestProblem: + """Set seed, create tested layer and data. Finally clean up. Args: - problem: test problem + request (SubRequest): Request for the fixture from a test/fixture function. + + Yields: + Test case with deterministically constructed attributes. """ - problem.set_up() + case = request.param + case.set_up() + yield case + case.tear_down() - backpack_res = BackpackDerivatives(problem).hessian_is_zero() - autograd_res = AutogradDerivatives(problem).hessian_is_zero() - assert backpack_res == autograd_res - problem.tear_down() +@fixture +def small_input_problem( + instantiated_problem: DerivativesTestProblem, max_input_numel: int = 100 +) -> DerivativesTestProblem: + """Skip cases with large inputs. + Args: + instantiated_problem: Test case with constructed attributes. + max_input_numel: Maximum input size. Default: ``100``. + + Yields: + Instantiated test case with small input. + """ + if instantiated_problem.input.numel() > max_input_numel: + skip( + "Input is too large:" + + f" {instantiated_problem.input.numel()} > {max_input_numel}" + ) + else: + yield instantiated_problem -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_diagonal(problem): - """Test whether hessian is diagonal. + +@fixture +def no_loss_problem( + small_input_problem: DerivativesTestProblem, +) -> DerivativesTestProblem: + """Skip cases that are loss functions. Args: - problem: test problem + small_input_problem: Test case with small input. - Raises: - NotImplementedError: . + Yields: + Instantiated test case that is not a loss layer. """ - problem.set_up() - - # TODO - raise NotImplementedError + if small_input_problem.is_loss(): + skip("Only required for non-loss layers.") + else: + yield small_input_problem - problem.tear_down() +def test_hessian_is_zero(no_loss_problem: DerivativesTestProblem) -> None: + """Check if the input-output Hessian is (non-)zero. -@pytest.mark.skip -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_hessian_is_psd(problem): - """Test whether hessian is semi positive definite. + Note: + `hessian_is_zero` is a global statement that assumes arbitrary inputs. + It can thus happen that the Hessian diagonal is zero for the current + input, but not in general. Args: - problem: test problem - - Raises: - NotImplementedError: . + no_loss_problem: Test case whose module is not a loss. """ - problem.set_up() - - # TODO - raise NotImplementedError + backpack_res = BackpackDerivatives(no_loss_problem).hessian_is_zero() + autograd_res = AutogradDerivatives(no_loss_problem).hessian_is_zero() - problem.tear_down() + if autograd_res and not backpack_res: + warn( + "Autograd Hessian diagonal is zero for this input " + " while BackPACK implementation implies inputs with non-zero Hessian." + ) + else: + assert backpack_res == autograd_res diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 061db5cf4..b4d1b43d0 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -2,7 +2,7 @@ from test.core.derivatives.implementation.base import DerivativesImplementation import torch -from torch import Tensor +from torch import Tensor, zeros_like from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -81,7 +81,7 @@ def param_jac_t_vec_prod(self, name, vec, sum_batch, axis_batch=0): Returns: torch.Tensor: product of jac_t and vec """ - input, output, named_params = self.problem.forward_pass() + _, output, named_params = self.problem.forward_pass() param = named_params[name] if sum_batch: @@ -226,53 +226,39 @@ def _hessian(self, loss: Tensor, x: Tensor) -> Tensor: return hessian_vec_x.reshape(final_shape) - def _elementwise_hessian(self, tensor, x: Tensor): + def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor: """Computes the Hessian of each element in `tensor` w.r.t `x`. - Hessians are returned in the order of elements in the flattened tensor. - - Args: - tensor: . - x: Tensor used in the computation graph of `loss`. - - Yields: - hessian of each element - """ - for t in tensor.flatten(): - yield self._hessian(t, x) - - def _tensor_hessian(self, tensor, x): - """Return the Hessian of a tensor `tensor` w.r.t. a tensor `x`. - Given a `tensor` of shape `[A, B, C]` and another tensor `x` with shape `[D, E]` used in the computation of `tensor`, the generalized Hessian has shape [A, B, C, D, E, D, E]. Let `hessian` denote this generalized Hessian. Then, `hessian[a, b, c]` contains the Hessian of the scalar entry `tensor[a, b, c]` w.r.t. `x[a, b, c]`. + If ``tensor`` is linear in ``x``, autograd raises a ``RuntimeError``. + If ``tensor`` does not depend on ``x``, autograd raises an ``AttributeError``. + In both cases, a Hessian of zeros is created manually and returned. + Arguments: - tensor (torch.Tensor): An arbitrary tensor. - x (torch.Tensor): Tensor used in the computation graph of `tensor`. + tensor: An arbitrary tensor. + x: Tensor used in the computation graph of `tensor`. - Returns: - torch.Tensor: Generalized Hessian of `tensor` w.r.t. `x`. + Yields: + Hessians in the order of elements in the flattened tensor. """ - shape = (*tensor.shape, *x.shape, *x.shape) - - return torch.cat(list(self._elementwise_hessian(tensor, x))).reshape(shape) - - def hessian_is_zero(self): - """Return whether the input-output Hessian is zero. + for t in tensor.flatten(): + try: + yield self._hessian(t, x) + except (RuntimeError, AttributeError): + yield torch.zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype) - Returns: - bool: `True`, if Hessian is zero, else `False`. - """ + def hessian_is_zero(self) -> bool: # noqa: D102 input, output, _ = self.problem.forward_pass(input_requires_grad=True) zero = None for hessian in self._elementwise_hessian(output, input): if zero is None: - zero = torch.zeros_like(hessian) + zero = zeros_like(hessian) if not torch.allclose(hessian, zero): return False diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index e3de0a402..3ac44cf3b 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -99,8 +99,7 @@ def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: Args: mc_samples: If int, uses an MC approximation with the specified - number of samples. - If None, uses the exact hessian. Defaults to None. + number of samples. If None, uses the exact hessian. Defaults to None. Returns: hessian @@ -122,12 +121,7 @@ def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: individual_hessians, self.problem.module.input0 ) - def hessian_is_zero(self): - """Return whether the input-output Hessian is zero. - - Returns: - bool: `True`, if Hessian is zero, else `False`. - """ + def hessian_is_zero(self) -> bool: # noqa: D102 return self.problem.derivative.hessian_is_zero() def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor: @@ -140,7 +134,7 @@ def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor: individual full matrix Raises: - ValueError: if input is not 2d + ValueError: if input is not 3d """ equation = None num_axes = len(sqrt.shape) diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 0c2e5580f..43aa48868 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -163,3 +163,12 @@ def sum_hessian(self) -> Tensor: the sum of hessians """ raise NotImplementedError + + @abstractmethod + def hessian_is_zero(self) -> bool: + """Return whether the input-output Hessian is zero. + + Returns: + `True`, if Hessian is zero, else `False`. + """ + raise NotImplementedError diff --git a/test/core/derivatives/linear_settings.py b/test/core/derivatives/linear_settings.py index df7132d99..a98df7aab 100644 --- a/test/core/derivatives/linear_settings.py +++ b/test/core/derivatives/linear_settings.py @@ -6,7 +6,7 @@ "input_fn" (callable): Used for specifying input function Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on. @@ -55,3 +55,22 @@ ), }, ] + +# additional dimensions +LINEAR_SETTINGS += [ + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 4)), + "id_prefix": "one-additional", + }, + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 3, 4)), + "id_prefix": "two-additional", + }, + { + "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True), + "input_fn": lambda: torch.rand(size=(3, 2, 3, 5, 4)), + "id_prefix": "three-additional", + }, +] diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index b5474ffe5..eb03c1788 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -1,28 +1,27 @@ -"""Test configurations for `backpack.core.extensions.firstorder`. - -It is shared among the following firstorder methods: -- batch_grad -- batch_l2_grad -- sum_grad_sqaured -- variance +"""Shared test cases for BackPACK's first-order extensions. +Shared by the tests of: +- ``BatchGrad`` +- ``BatchL2Grad`` +- ``SumGradSquared`` +- ``Variance`` Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model Optional entries: - "device" [list(torch.device)]: List of devices to run the test on. + "device" [list(device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. - "seed" (int): seed for the random number for torch.rand + "seed" (int): seed set before initializing a case. """ from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import make_simple_cnn_setting -import torch +from torch import device, rand from torch.nn import ( RNN, Conv1d, @@ -31,8 +30,13 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + CrossEntropyLoss, Flatten, + Linear, + MSELoss, + ReLU, Sequential, + Sigmoid, ) from backpack.custom_module.permute import Permute @@ -45,11 +49,11 @@ ############################################################################### example = { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), - "device": [torch.device("cpu")], + "device": [device("cpu")], "seed": 0, "id_prefix": "example", } @@ -62,47 +66,91 @@ FIRSTORDER_SETTINGS += [ # classification { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, - # Regression + # regression { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((3, 5)), }, ] +# linear with additional dimension +FIRSTORDER_SETTINGS += [ + # regression + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 4, 2)), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Sigmoid(), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 4, 2, 2)), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)), + "loss_function_fn": lambda: MSELoss(reduction="sum"), + "target_fn": lambda: regression_targets((3, 4, 2, 3, 2)), + "id_prefix": "three-additional", + }, + # classification + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 8), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 16), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 48), + "id_prefix": "three-additional", + }, +] + ############################################################################### # test setting: Convolutional Layers # """ -Syntax with default parameters: - - `torch.nn.ConvNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, dilation=1, - groups=1, bias=True, padding_mode='zeros)` +Syntax with default parameters: + - `torch.nn.ConvNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode='zeros)` - - `torch.nn.ConvTransposeNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, output_padding=0, + - `torch.nn.ConvTransposeNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros)` -Note: There are 5 tests added to each `torch.nn.layers`. +Note: There are 5 tests added to each `torch.nn.layers`. For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d` -only 3 tests are added because they are very memory intensive. +only 3 tests are added because they are very memory intensive. """ ############################################################################### @@ -164,18 +212,18 @@ FIRSTORDER_SETTINGS += [ { - "input_fn": lambda: torch.rand(8, 5, 6), + "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( Permute(1, 0, 2), RNN(input_size=6, hidden_size=3), ReduceTuple(index=0), Permute(1, 2, 0), ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((8, 5), 3), }, { - "input_fn": lambda: torch.rand(8, 5, 6), + "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( Permute(1, 0, 2), RNN(input_size=6, hidden_size=3), @@ -183,7 +231,7 @@ Permute(1, 2, 0), Flatten(), ), - "loss_function_fn": lambda: torch.nn.MSELoss(), + "loss_function_fn": lambda: MSELoss(), "target_fn": lambda: regression_targets((8, 3 * 5)), }, ] diff --git a/test/extensions/firstorder/variance/__init__.py b/test/extensions/firstorder/variance/__init__.py index e69de29bb..7ca2b624c 100644 --- a/test/extensions/firstorder/variance/__init__.py +++ b/test/extensions/firstorder/variance/__init__.py @@ -0,0 +1 @@ +"""Contains tests for BackPACK's ``Variance`` extension.""" diff --git a/test/extensions/firstorder/variance/test_variance.py b/test/extensions/firstorder/variance/test_variance.py index 50cf21675..8680c2086 100644 --- a/test/extensions/firstorder/variance/test_variance.py +++ b/test/extensions/firstorder/variance/test_variance.py @@ -1,16 +1,9 @@ -"""Test class for module variance -from `backpack.core.extensions.firstorder` - -Test variances for the following layers: -- variance of linear layers -- variance of convolutional layers - -""" +"""Test BackPACK's ``Variance`` extension.""" from test.automated_test import check_sizes_and_values from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions -from test.extensions.problem import make_test_problems +from test.extensions.problem import ExtensionsTestProblem, make_test_problems import pytest @@ -19,16 +12,17 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_variance(problem): - """Test variance of individual gradients +def test_variance(problem: ExtensionsTestProblem) -> None: + """Test variance of individual gradients. Args: - problem (ExtensionsTestProblem): Problem for extension test. + problem: Test case. """ problem.set_up() backpack_res = BackpackExtensions(problem).variance() autograd_res = AutogradExtensions(problem).variance() - check_sizes_and_values(autograd_res, backpack_res) + rtol = 5e-5 + check_sizes_and_values(autograd_res, backpack_res, rtol=rtol) problem.tear_down() diff --git a/test/extensions/firstorder/variance/variance_settings.py b/test/extensions/firstorder/variance/variance_settings.py index 61c39d0d6..c8a8de2da 100644 --- a/test/extensions/firstorder/variance/variance_settings.py +++ b/test/extensions/firstorder/variance/variance_settings.py @@ -1,7 +1,7 @@ -"""Test configurations to test variance +"""Test cases for ``Variance`` extension. -The tests are taken from `test.extensions.firstorder.firstorder_settings`, -but additional custom tests can be defined here by appending it to the list. +Uses shared test cases from `test.extensions.firstorder.firstorder_settings`, +and the local cases defined in this file. """ from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 71655b239..c86350ebd 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -1,8 +1,12 @@ +"""Autograd implementation of BackPACK's extensions.""" from test.extensions.implementation.base import ExtensionsImplementation +from typing import List import torch +from torch import Tensor +from torch.nn.utils.convert_parameters import parameters_to_vector -from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist +from backpack.hessianfree.ggnvp import ggn_vector_product, ggn_vector_product_from_plist from backpack.hessianfree.rop import R_op from backpack.utils.convert_parameters import vector_to_parameter_list @@ -10,12 +14,7 @@ class AutogradExtensions(ExtensionsImplementation): """Extension implementations with autograd.""" - def batch_grad(self): - """Scaled individual gradients computed by BackPACK's BatchGrad extension. - - Returns: - list[torch.Tensor]: batch_grads - """ + def batch_grad(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] batch_grads = [ torch.zeros(N, *p.size()).to(self.problem.device) @@ -39,12 +38,12 @@ def batch_grad(self): return batch_grads - def batch_l2_grad(self): + def batch_l2_grad(self) -> List[Tensor]: # noqa: D102 batch_grad = self.batch_grad() batch_l2_grads = [(g ** 2).flatten(start_dim=1).sum(1) for g in batch_grad] return batch_l2_grads - def sgs(self): + def sgs(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] sgs = [ torch.zeros(*p.size()).to(self.problem.device) @@ -67,7 +66,7 @@ def sgs(self): sgs[idx] += (g.detach() * factor) ** 2 return sgs - def variance(self): + def variance(self) -> List[Tensor]: # noqa: D102 batch_grad = self.batch_grad() variances = [torch.var(g, dim=0, unbiased=False) for g in batch_grad] return variances @@ -94,7 +93,7 @@ def extract_ith_element_of_diag_ggn(i, p, loss, output): diag_ggns.append(diag_ggn_p.view(p.size())) return diag_ggns - def diag_ggn(self): + def diag_ggn(self) -> List[Tensor]: # noqa: D102 try: _, output, loss = self.problem.forward_pass() return self._get_diag_ggn(loss, output) @@ -105,16 +104,16 @@ def diag_ggn(self): _, output, loss = self.problem.forward_pass() return self._get_diag_ggn(loss, output) - def diag_ggn_batch(self): + def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa: D102 try: - return self._diag_ggn_batch() + return self._diag_ggn_exact_batch() except RuntimeError: # torch does not implement cuda double-backwards pass on RNNs and # recommends this workaround with torch.backends.cudnn.flags(enabled=False): - return self._diag_ggn_batch() + return self._diag_ggn_exact_batch() - def _diag_ggn_batch(self): + def _diag_ggn_exact_batch(self): batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() loss_list = torch.zeros(batch_size, device=self.problem.device) @@ -158,11 +157,11 @@ def extract_ith_element_of_diag_h(i, p, df_dx): diag_hs.append(diag_h_p.view(p.size())) return diag_hs - def diag_h(self): + def diag_h(self) -> List[Tensor]: # noqa: D102 _, _, loss = self.problem.forward_pass() return self._get_diag_h(loss) - def diag_h_batch(self): + def diag_h_batch(self) -> List[Tensor]: # noqa: D102 batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() loss_list = torch.zeros(batch_size, device=self.problem.device) @@ -176,3 +175,42 @@ def diag_h_batch(self): factor = self.problem.get_reduction_factor(batch_loss, loss_list) params_batch_diag_h = list(zip(*batch_diag_h)) return [torch.stack(param) * factor for param in params_batch_diag_h] + + def ggn(self) -> Tensor: # noqa: D102 + _, output, loss = self.problem.forward_pass() + model = self.problem.model + + num_params = sum(p.numel() for p in model.parameters()) + ggn = torch.zeros(num_params, num_params).to(self.problem.device) + + for i in range(num_params): + # GGN-vector product with i.th unit vector yields the i.th row + e_i = torch.zeros(num_params).to(self.problem.device) + e_i[i] = 1.0 + + # convert to model parameter shapes + e_i_list = vector_to_parameter_list(e_i, model.parameters()) + ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) + + ggn_i = parameters_to_vector(ggn_i_list) + ggn[i, :] = ggn_i + + return ggn + + def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa: D102 + raise NotImplementedError + + def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: # noqa: D102 + raise NotImplementedError + + def ggn_mc(self, mc_samples: int, chunks: int = 1): # noqa: D102 + raise NotImplementedError + + def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa: D102 + raise NotImplementedError + + def kflr(self) -> List[List[Tensor]]: # noqa: D102 + raise NotImplementedError + + def kfra(self) -> List[List[Tensor]]: # noqa: D102 + raise NotImplementedError diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index ae121c6c5..c8d4d90bf 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -1,9 +1,14 @@ +"""Extension implementations with BackPACK.""" from test.extensions.implementation.base import ExtensionsImplementation from test.extensions.implementation.hooks import ( BatchL2GradHook, ExtensionHookManager, SumGradSquaredHook, ) +from test.extensions.problem import ExtensionsTestProblem +from typing import List + +from torch import Tensor, cat, einsum import backpack.extensions as new_ext from backpack import backpack @@ -12,26 +17,35 @@ class BackpackExtensions(ExtensionsImplementation): """Extension implementations with BackPACK.""" - def __init__(self, problem): + def __init__(self, problem: ExtensionsTestProblem): + """Add BackPACK functionality to, and store, the test case. + + Args: + problem: Test case + """ problem.extend() super().__init__(problem) - def batch_grad(self): + def batch_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchGrad()): _, _, loss = self.problem.forward_pass() loss.backward() batch_grads = [p.grad_batch for p in self.problem.model.parameters()] return batch_grads - def batch_l2_grad(self): + def batch_l2_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchL2Grad()): _, _, loss = self.problem.forward_pass() loss.backward() batch_l2_grad = [p.batch_l2 for p in self.problem.model.parameters()] return batch_l2_grad - def batch_l2_grad_extension_hook(self): - """Individual gradient squared ℓ₂ norms via extension hook.""" + def batch_l2_grad_extension_hook(self) -> List[Tensor]: + """Individual gradient squared ℓ₂ norms via extension hook. + + Returns: + Parameter-wise individual gradient norms. + """ hook = ExtensionHookManager(BatchL2GradHook()) with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() @@ -39,15 +53,19 @@ def batch_l2_grad_extension_hook(self): batch_l2_grad = [p.batch_l2_hook for p in self.problem.model.parameters()] return batch_l2_grad - def sgs(self): + def sgs(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.SumGradSquared()): _, _, loss = self.problem.forward_pass() loss.backward() sgs = [p.sum_grad_squared for p in self.problem.model.parameters()] return sgs - def sgs_extension_hook(self): - """Individual gradient second moment via extension hook.""" + def sgs_extension_hook(self) -> List[Tensor]: + """Individual gradient second moment via extension hook. + + Returns: + Parameter-wise individual gradient second moment. + """ hook = ExtensionHookManager(SumGradSquaredHook()) with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() @@ -55,21 +73,21 @@ def sgs_extension_hook(self): sgs = [p.sum_grad_squared_hook for p in self.problem.model.parameters()] return sgs - def variance(self): + def variance(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.Variance()): _, _, loss = self.problem.forward_pass() loss.backward() variances = [p.variance for p in self.problem.model.parameters()] return variances - def diag_ggn(self): + def diag_ggn(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() diag_ggn = [p.diag_ggn_exact for p in self.problem.model.parameters()] return diag_ggn - def diag_ggn_exact_batch(self): + def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() @@ -78,14 +96,14 @@ def diag_ggn_exact_batch(self): ] return diag_ggn_exact_batch - def diag_ggn_mc(self, mc_samples): + def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() diag_ggn_mc = [p.diag_ggn_mc for p in self.problem.model.parameters()] return diag_ggn_mc - def diag_ggn_mc_batch(self, mc_samples): + def diag_ggn_mc_batch(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() @@ -94,8 +112,16 @@ def diag_ggn_mc_batch(self, mc_samples): ] return diag_ggn_mc_batch - def diag_ggn_mc_chunk(self, mc_samples, chunks=10): - """Like ``diag_ggn_mc``, but handles larger number of samples by chunking.""" + def diag_ggn_mc_chunk(self, mc_samples: int, chunks: int = 10) -> List[Tensor]: + """Like ``diag_ggn_mc``, but can handle more samples by chunking. + + Args: + mc_samples: Number of Monte-Carlo samples. + chunks: Maximum sequential split of the computation. Default: ``10``. + + Returns: + Parameter-wise MC-approximation of the GGN diagonal. + """ chunk_samples = self.chunk_sizes(mc_samples, chunks) chunk_weights = [samples / mc_samples for samples in chunk_samples] @@ -113,9 +139,17 @@ def diag_ggn_mc_chunk(self, mc_samples, chunks=10): return diag_ggn_mc - def diag_ggn_mc_batch_chunk(self, mc_samples, chunks=10): - """ - Like ``diag_ggn_mc_batch``, but handles larger number of samples by chunking. + def diag_ggn_mc_batch_chunk( + self, mc_samples: int, chunks: int = 10 + ) -> List[Tensor]: + """Like ``diag_ggn_mc_batch``, but can handle more samples by chunking. + + Args: + mc_samples: Number of Monte-Carlo samples. + chunks: Maximum sequential split of the computation. Default: ``10``. + + Returns: + Parameter-wise MC-approximation of the per-sample GGN diagonals. """ chunk_samples = self.chunk_sizes(mc_samples, chunks) chunk_weights = [samples / mc_samples for samples in chunk_samples] @@ -137,8 +171,16 @@ def diag_ggn_mc_batch_chunk(self, mc_samples, chunks=10): return diag_ggn_mc_batch @staticmethod - def chunk_sizes(total_size, num_chunks): - """Return list containing the sizes of chunks.""" + def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: + """Return list containing the sizes of chunks. + + Args: + total_size: Total computation work. + num_chunks: Maximum number of chunks the work will be split into. + + Returns: + List of chunks with split work. + """ chunk_size = max(total_size // num_chunks, 1) if chunk_size == 1: @@ -152,14 +194,14 @@ def chunk_sizes(total_size, num_chunks): return sizes - def diag_h(self): + def diag_h(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() diag_h = [p.diag_h for p in self.problem.model.parameters()] return diag_h - def kfac(self, mc_samples=1): + def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFAC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() @@ -167,7 +209,7 @@ def kfac(self, mc_samples=1): return kfac - def kflr(self): + def kflr(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFLR()): _, _, loss = self.problem.forward_pass() loss.backward() @@ -175,7 +217,7 @@ def kflr(self): return kflr - def kfra(self): + def kfra(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFRA()): _, _, loss = self.problem.forward_pass() loss.backward() @@ -183,10 +225,62 @@ def kfra(self): return kfra - def diag_h_batch(self): + def diag_h_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()] return diag_h_batch + + def ggn(self) -> Tensor: # noqa:D102 + return self._square_sqrt_ggn(self.sqrt_ggn()) + + def sqrt_ggn(self) -> List[Tensor]: + """Compute the matrix square root of the exact generalized Gauss-Newton. + + Returns: + Parameter-wise matrix square root of the exact GGN. + """ + with backpack(new_ext.SqrtGGNExact()): + _, _, loss = self.problem.forward_pass() + loss.backward() + + return [p.sqrt_ggn_exact for p in self.problem.model.parameters()] + + def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: + """Compute the approximate matrix square root of the generalized Gauss-Newton. + + Args: + mc_samples: Number of Monte-Carlo samples. + + Returns: + Parameter-wise approximate matrix square root of the exact GGN. + """ + with backpack(new_ext.SqrtGGNMC(mc_samples=mc_samples)): + _, _, loss = self.problem.forward_pass() + loss.backward() + + return [p.sqrt_ggn_mc for p in self.problem.model.parameters()] + + def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: # noqa:D102 + samples = self.chunk_sizes(mc_samples, chunks) + weights = [samples / mc_samples for samples in samples] + + return sum( + w * self._square_sqrt_ggn(self.sqrt_ggn_mc(s)) + for w, s in zip(weights, samples) + ) + + @staticmethod + def _square_sqrt_ggn(sqrt_ggn: List[Tensor]) -> Tensor: + """Utility function to concatenate and square the GGN factorization. + + Args: + sqrt_ggn: Parameter-wise matrix square root of the GGN. + + Returns: + Matrix representation of the GGN. + """ + sqrt_mat = cat([s.flatten(start_dim=2) for s in sqrt_ggn], dim=2) + return einsum("cni,cnj->ij", sqrt_mat, sqrt_mat) diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 0b2eb73b4..6c677128c 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -1,75 +1,137 @@ -class ExtensionsImplementation: +"""Base class containing the functions to compare BackPACK and autograd.""" +from abc import ABC, abstractmethod +from test.extensions.problem import ExtensionsTestProblem +from typing import List + +from torch import Tensor + + +class ExtensionsImplementation(ABC): """Base class for autograd and BackPACK implementations of extensions.""" - def __init__(self, problem): + def __init__(self, problem: ExtensionsTestProblem): + """Store the test case. + + Args: + problem: Test case. + """ self.problem = problem - def batch_grad(self): + @abstractmethod + def batch_grad(self) -> List[Tensor]: """Individual gradients.""" raise NotImplementedError - def batch_l2_grad(self): + @abstractmethod + def batch_l2_grad(self) -> List[Tensor]: """L2 norm of Individual gradients.""" raise NotImplementedError - def sgs(self): - """Sum of Square of Individual gradients""" + @abstractmethod + def sgs(self) -> List[Tensor]: + """Sum of Square of Individual gradients.""" raise NotImplementedError - def variance(self): - """Variance of Individual gradients""" + @abstractmethod + def variance(self) -> List[Tensor]: + """Variance of Individual gradients.""" raise NotImplementedError - def diag_ggn(self): - """Diagonal of Gauss Newton""" + @abstractmethod + def diag_ggn(self) -> List[Tensor]: + """Diagonal of Gauss Newton.""" raise NotImplementedError - def diag_ggn_batch(self): - """Individual diagonal of Generalized Gauss-Newton/Fisher""" + @abstractmethod + def diag_ggn_exact_batch(self) -> List[Tensor]: + """Individual diagonal of Generalized Gauss-Newton/Fisher.""" raise NotImplementedError - def diag_ggn_mc(self, mc_samples): - """MC approximation of Diagonal of Gauss Newton""" - raise NotImplementedError + @abstractmethod + def diag_ggn_mc(self, mc_samples: int) -> List[Tensor]: + """MC approximation of the generalized Gauss-Newton/Fisher diagonal. - def diag_ggn_mc_batch(self, mc_samples): - """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal.""" + Args: + mc_samples: Number of Monte-Carlo samples used for the approximation. + """ raise NotImplementedError - def diag_h(self): - """Diagonal of Hessian""" + @abstractmethod + def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: + """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal. + + Args: + mc_samples: Number of Monte-Carlo samples used for the approximation. + """ raise NotImplementedError - def kfac(self, mc_samples=1): + @abstractmethod + def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: """Kronecker-factored approximate curvature (KFAC). Args: - mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``. + mc_samples: Number of Monte-Carlo samples. Default: ``1``. Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. """ raise NotImplementedError - def kflr(self): + @abstractmethod + def kflr(self) -> List[List[Tensor]]: """Kronecker-factored low-rank approximation (KFLR). Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. """ raise NotImplementedError - def kfra(self): + @abstractmethod + def kfra(self) -> List[List[Tensor]]: """Kronecker-factored recursive approximation (KFRA). Returns: - list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors. + Parameter-wise lists of Kronecker factors. + """ + raise NotImplementedError + + @abstractmethod + def diag_h(self) -> List[Tensor]: + """Diagonal of Hessian. + + Returns: + Hessian diagonal for each parameter. """ + raise NotImplementedError - def diag_h_batch(self): + @abstractmethod + def diag_h_batch(self) -> List[Tensor]: """Per-sample Hessian diagonal. Returns: - list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. + Parameter-wise per-sample Hessian diagonal. + """ + raise NotImplementedError + + @abstractmethod + def ggn(self) -> Tensor: + """Exact generalized Gauss-Newton/Fisher matrix. + + Returns: + Matrix representation of the exact GGN. + """ + raise NotImplementedError + + @abstractmethod + def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: + """Compute the MC-approximation of the GGN in chunks of MC samples. + + Args: + mc_samples: Number of Monte-Carlo samples. + chunks: Number of sequential portions to split the computation. + Default: ``1`` (no sequential split). + + Returns: + Matrix representation of the Monte-Carlo approximated GGN. """ raise NotImplementedError diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 11148cfae..ef50ba716 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -1,20 +1,24 @@ -"""Test configurations to test diag_ggn. +"""Test cases for BackPACK extensions for the GGN diagonal. -The tests are taken from `test.extensions.secondorder.secondorder_settings`, -but additional custom tests can be defined here by appending it to the list. +Includes +- ``DiagGGNExact`` +- ``DiagGGNMC`` +- ``BatchDiagGGNExact`` +- ``BatchDiagGGNMC`` + +Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. +Additional local cases can be defined here through ``LOCAL_SETTINGS``. """ from test.core.derivatives.utils import regression_targets -from test.extensions.automated_settings import make_simple_act_setting from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS import torch -from torch.nn import ELU, RNN, SELU, Flatten, Sequential +from torch.nn import RNN, Flatten, Sequential from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple SHARED_SETTINGS = SECONDORDER_SETTINGS - LOCAL_SETTINGS = [ # RNN settings { @@ -30,14 +34,4 @@ "target_fn": lambda: regression_targets((8, 3 * 5)), }, ] - -############################################################################### -# test setting: Activation Layers # -############################################################################### -activations = [ELU, SELU] - -for act in activations: - for bias in [True, False]: - LOCAL_SETTINGS.append(make_simple_act_setting(act, bias=bias)) - DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index c7a312bf6..c35100623 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn_batch(problem): +def test_diag_ggn_exact_batch(problem): """Test the individual diagonal of Generalized Gauss-Newton/Fisher. Args: @@ -21,7 +21,7 @@ def test_diag_ggn_batch(problem): problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() - autograd_res = AutogradExtensions(problem).diag_ggn_batch() + autograd_res = AutogradExtensions(problem).diag_ggn_exact_batch() check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() @@ -44,7 +44,7 @@ def test_diag_ggn_mc_batch_light(problem): problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() - mc_samples = 5000 + mc_samples = 6000 backpack_res_mc_avg = BackpackExtensions(problem).diag_ggn_mc_batch(mc_samples) check_sizes_and_values( diff --git a/test/extensions/secondorder/hbp/kfac_settings.py b/test/extensions/secondorder/hbp/kfac_settings.py index 895b99cc5..a6d0ece1d 100644 --- a/test/extensions/secondorder/hbp/kfac_settings.py +++ b/test/extensions/secondorder/hbp/kfac_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFAC.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kflr_settings.py b/test/extensions/secondorder/hbp/kflr_settings.py index 6b74a2842..de61c5b3b 100644 --- a/test/extensions/secondorder/hbp/kflr_settings.py +++ b/test/extensions/secondorder/hbp/kflr_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFLR.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/hbp/kfra_settings.py b/test/extensions/secondorder/hbp/kfra_settings.py index 5a28ab738..94e65c2b7 100644 --- a/test/extensions/secondorder/hbp/kfra_settings.py +++ b/test/extensions/secondorder/hbp/kfra_settings.py @@ -1,8 +1,13 @@ """Define test cases for KFRA.""" -from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS +from test.extensions.secondorder.secondorder_settings import ( + GROUP_CONV_SETTINGS, + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, +) -SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS +SHARED_NOT_SUPPORTED_SETTINGS = ( + GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS +) LOCAL_NOT_SUPPORTED_SETTINGS = [] NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index f90456134..44a4483b3 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -8,14 +8,14 @@ Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model Optional entries: "device" [list(torch.device)]: List of devices to run the test on. "id_prefix" (str): Prefix to be included in the test name. - "seed" (int): seed for the random number for torch.rand + "seed" (int): seed for the random number for rand """ @@ -26,7 +26,7 @@ make_simple_pooling_setting, ) -import torch +from torch import device, rand from torch.nn import ( ELU, SELU, @@ -39,12 +39,17 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + CrossEntropyLoss, + Flatten, LeakyReLU, + Linear, LogSigmoid, MaxPool1d, MaxPool2d, MaxPool3d, + MSELoss, ReLU, + Sequential, Sigmoid, Tanh, ) @@ -56,11 +61,11 @@ ############################################################################### example = { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((3,), 5), - "device": [torch.device("cpu")], + "device": [device("cpu")], "seed": 0, "id_prefix": "example", } @@ -70,28 +75,22 @@ SECONDORDER_SETTINGS += [ # classification { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, # Regression { - "input_fn": lambda: torch.rand(3, 10), - "module_fn": lambda: torch.nn.Sequential( - torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5) - ), - "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"), + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)), + "loss_function_fn": lambda: MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((3, 5)), }, ] @@ -109,8 +108,8 @@ ############################################################################### # test setting: Pooling Layers # """ -Syntax with default parameters: - - `torch.nn.MaxPoolNd(kernel_size, stride, padding, dilation, +Syntax with default parameters: + - `MaxPoolNd(kernel_size, stride, padding, dilation, return_indices, ceil_mode)` """ ############################################################################### @@ -150,18 +149,18 @@ ############################################################################### # test setting: Convolutional Layers # """ -Syntax with default parameters: - - `torch.nn.ConvNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, dilation=1, - groups=1, bias=True, padding_mode='zeros)` +Syntax with default parameters: + - `ConvNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode='zeros)` - - `torch.nn.ConvTransposeNd(in_channels, out_channels, - kernel_size, stride=1, padding=0, output_padding=0, + - `ConvTransposeNd(in_channels, out_channels, + kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros)` -Note: There are 5 tests added to each `torch.nn.layers`. -For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d` -only 3 tests are added because they are very memory intensive. +Note: There are 5 tests added to each `layers`. +For `ConvTranspose2d` and `ConvTranspose3d` +only 3 tests are added because they are very memory intensive. """ ############################################################################### @@ -233,3 +232,70 @@ ] SECONDORDER_SETTINGS += GROUP_CONV_SETTINGS + +SECONDORDER_SETTINGS += [ + { + # Flatten layer does not add a node in the computation graph and thus the + # backward hook will be called at an unexpected stage. This must explicitly + # be addressed in the `backpropagate` function of the flatten module extension. + "input_fn": lambda: rand(3, 5), + "module_fn": lambda: Sequential(Linear(5, 4), Flatten(), Linear(4, 2)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 2), + "id_prefix": "flatten-no-op", + }, +] + +# linear with additional dimension +LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS = [ + # regression + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 8)), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: MSELoss(reduction="mean"), + "target_fn": lambda: regression_targets((3, 16)), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: MSELoss(reduction="sum"), + "target_fn": lambda: regression_targets((3, 48)), + "id_prefix": "three-additional", + }, + # classification + { + "input_fn": lambda: rand(3, 4, 5), + "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 8), + "id_prefix": "one-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 5), + "module_fn": lambda: Sequential( + Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten() + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 16), + "id_prefix": "two-additional", + }, + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 48), + "id_prefix": "three-additional", + }, +] + +SECONDORDER_SETTINGS += LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS diff --git a/test/extensions/secondorder/sqrt_ggn/__init__.py b/test/extensions/secondorder/sqrt_ggn/__init__.py new file mode 100644 index 000000000..6741d2c13 --- /dev/null +++ b/test/extensions/secondorder/sqrt_ggn/__init__.py @@ -0,0 +1 @@ +"""Contains tests of ``backpack.extensions.secondorder.sqrt_ggn``.""" diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py new file mode 100644 index 000000000..ba3a0436c --- /dev/null +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -0,0 +1,92 @@ +"""Tests BackPACK's ``SqrtGGNExact`` and ``SqrtGGNMC`` extension.""" + +from test.automated_test import check_sizes_and_values +from test.extensions.implementation.autograd import AutogradExtensions +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import ExtensionsTestProblem, make_test_problems +from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS + +from pytest import fixture, mark, skip + +PROBLEMS = make_test_problems(SECONDORDER_SETTINGS) + + +@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) +def instantiated_problem(request) -> ExtensionsTestProblem: + """Set seed, create tested model, loss, data. Finally clean up. + + Args: + request (SubRequest): Request for the fixture from a test/fixture function. + + Yields: + Test case with deterministically constructed attributes. + """ + case = request.param + case.set_up() + yield case + case.tear_down() + + +@fixture +def small_problem( + instantiated_problem: ExtensionsTestProblem, max_num_params=1000 +) -> ExtensionsTestProblem: + """Skip architectures with too many parameters whose GGN is expensive to evaluate. + + Args: + instantiated_problem: Test case with instantiated model, data, etc. + max_num_params: Maximum number of model parameters to run the case. + Default: ``1000``. + + Yields: + Instantiated test case whose model's are small enough. + """ + num_params = sum( + p.numel() for p in instantiated_problem.model.parameters() if p.requires_grad + ) + if num_params <= max_num_params: + yield instantiated_problem + else: + skip(f"Model has too many parameters: {num_params} > {max_num_params}") + + +def test_ggn_exact(small_problem: ExtensionsTestProblem): + """Compare exact GGN from BackPACK's matrix square root with autograd. + + Args: + small_problem: Test case with small network whose GGN can be evaluated. + """ + autograd_res = AutogradExtensions(small_problem).ggn() + backpack_res = BackpackExtensions(small_problem).ggn() + + check_sizes_and_values(autograd_res, backpack_res) + + +def test_sqrt_ggn_mc_integration(small_problem: ExtensionsTestProblem): + """Check if MC-approximated GGN matrix square root code executes. + + Note: + This test does not perform correctness checks on the results, + which are expensive because a large number of samples is required. + Such a check is performed by `test_sqrt_ggn_mc`, which is run less + frequently. + + Args: + small_problem: Test case with small network whose GGN can be evaluated. + """ + BackpackExtensions(small_problem).sqrt_ggn_mc(mc_samples=1) + + +@mark.montecarlo +def test_ggn_mc(small_problem: ExtensionsTestProblem): + """Compare MC-approximated GGN from BackpACK's with exact version from autograd. + + Args: + small_problem: Test case with small network whose GGN can be evaluated. + """ + autograd_res = AutogradExtensions(small_problem).ggn() + atol, rtol = 5e-2, 1e-2 + mc_samples, chunks = 300000, 30 + backpack_res = BackpackExtensions(small_problem).ggn_mc(mc_samples, chunks=chunks) + + check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol) From f4defa55901de8d1b9aea9819dd3153765750b43 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 29 Jun 2021 17:33:19 +0200 Subject: [PATCH 23/54] [REQ] Fully deprecate python3.6 (#190) * [REQ] Bump to python3.7+ * [DOC] Add annotations with cyclic imports (python3.7+) --- README-dev.md | 4 ++-- README.md | 2 +- .../extensions/secondorder/sqrt_ggn/base.py | 19 ++++++++++++------- .../secondorder/sqrt_ggn/flatten.py | 11 +++++++---- .../extensions/secondorder/sqrt_ggn/losses.py | 11 +++++++---- setup.cfg | 3 +-- test/test___init__.py | 16 +--------------- 7 files changed, 31 insertions(+), 35 deletions(-) diff --git a/README-dev.md b/README-dev.md index e853e7bc8..913b035e5 100644 --- a/README-dev.md +++ b/README-dev.md @@ -1,7 +1,7 @@ # BackPACK BackPACK developer manual -## General standards -- Python version: support 3.6+, use 3.7 for development +## General standards +- Python version: support 3.7+, use 3.7 for development - `git` [branching model](https://nvie.com/posts/a-successful-git-branching-model/) - Docstring style: [Google](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) - Test runner: [`pytest`](https://docs.pytest.org/en/latest/) diff --git a/README.md b/README.md index f3f179d15..14cdeea8c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack) [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) -[![Python 3.6+](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360/) +[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/) BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient. diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py index 336325b08..d5e20e721 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/base.py +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -1,5 +1,7 @@ """Contains base class for ``SqrtGGN{Exact, MC}`` module extensions.""" -from typing import Any, Callable, List, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Tuple, Union from torch import Tensor from torch.nn import Module @@ -7,6 +9,9 @@ from backpack.core.derivatives.basederivatives import BaseDerivatives from backpack.extensions.mat_to_mat_jac_base import MatToJacMat +if TYPE_CHECKING: + from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC + class SqrtGGNBaseModule(MatToJacMat): """Base module extension for ``SqrtGGN{Exact, MC}``.""" @@ -28,11 +33,12 @@ def __init__(self, derivatives: BaseDerivatives, params: List[str] = None): super().__init__(derivatives, params=params) - # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] - # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) def _make_param_function( self, param_str: str - ) -> Callable[[Any, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor]: + ) -> Callable[ + [Union[SqrtGGNExact, SqrtGGNMC], Module, Tuple[Tensor], Tuple[Tensor], Tensor], + Tensor, + ]: """Create a function that computes the GGN/Fisher square root for a parameter. Args: @@ -41,10 +47,9 @@ def _make_param_function( Returns: Function that computes the GGN/Fisher matrix square root. """ - # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] - # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) + def param_function( - ext: Any, + ext: Union[SqrtGGNExact, SqrtGGNMC], module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py index c2f4eb5a3..bae59b402 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/flatten.py +++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py @@ -1,5 +1,7 @@ """Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" -from typing import Any, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, Union from torch import Tensor from torch.nn import Module @@ -7,6 +9,9 @@ from backpack.core.derivatives.flatten import FlattenDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule +if TYPE_CHECKING: + from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC + class SqrtGGNFlatten(SqrtGGNBaseModule): """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Flatten`` module.""" @@ -15,11 +20,9 @@ def __init__(self): """Pass derivatives for ``torch.nn.Flatten`` module.""" super().__init__(FlattenDerivatives()) - # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] - # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) def backpropagate( self, - ext: Any, + ext: Union[SqrtGGNExact, SqrtGGNMC], module: Module, grad_inp: Tuple[Tensor], grad_out: Tuple[Tensor], diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py index 8e48cf318..30f561396 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/losses.py +++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -1,5 +1,7 @@ """Contains base class and extensions for losses used by ``SqrtGGN{Exact, MC}``.""" -from typing import Any, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, Union from torch import Tensor from torch.nn import Module @@ -9,15 +11,16 @@ from backpack.extensions.secondorder.hbp import LossHessianStrategy from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule +if TYPE_CHECKING: + from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC + class SqrtGGNBaseLossModule(SqrtGGNBaseModule): """Base class for losses used by ``SqrtGGN{Exact, MC}``.""" - # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC] - # WAITING Deprecation of python3.6 (cyclic imports caused by annotations) def backpropagate( self, - ext: Any, + ext: Union[SqrtGGNExact, SqrtGGNMC], module: Module, grad_inp: Tuple[Tensor], grad_out: Tuple[Tensor], diff --git a/setup.cfg b/setup.cfg index bf2aa19f0..32ae54570 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,6 @@ classifiers = Development Status :: 4 - Beta License :: OSI Approved :: MIT License Operating System :: OS Independent - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 @@ -39,7 +38,7 @@ install_requires = torchvision >= 0.7.0, < 1.0.0 einops >= 0.3.0, < 1.0.0 # Require a specific Python version, e.g. Python 2.7 or >= 3.4 -python_requires = >=3.6 +python_requires = >=3.7 [options.packages.find] exclude = test* diff --git a/test/test___init__.py b/test/test___init__.py index 600b82b0f..16abef1fc 100644 --- a/test/test___init__.py +++ b/test/test___init__.py @@ -1,5 +1,6 @@ """Tests for `backpack.__init__.py`.""" +from contextlib import nullcontext from test import pytorch_current_memory_usage from test.core.derivatives.utils import classification_targets, get_available_devices @@ -12,21 +13,6 @@ DEVICES_ID = [str(dev) for dev in DEVICES] -# TODO Use contextlib.nullcontext after dropping Python 3.6 support -class nullcontext: - """Empty context. - - ``contextlib.nullcontext`` is available from Python 3.7 onwards. - The tests are also executed on Python 3.6. - """ - - def __enter__(self): - pass - - def __exit__(self, type, value, traceback): - pass - - def test_no_io(): """Check IO is not tracked.""" torch.manual_seed(0) From 5726cc1ab3e6c31d94192b6fe6b256911245a01f Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 30 Jun 2021 16:35:48 +0200 Subject: [PATCH 24/54] Resolve merge conflicts of `rnn` with `development` (#191) * [FIX] Resolve merge conflicts with `development` * [CI] Add `setup.py` to fully documented again --- backpack/core/derivatives/basederivatives.py | 5 +++ .../firstorder/batch_l2_grad/linear.py | 1 - .../firstorder/sum_grad_squared/linear.py | 4 +- backpack/utils/conv.py | 14 +++--- backpack/utils/linear.py | 2 + fully_documented.txt | 27 +++++------- test/core/derivatives/derivatives_test.py | 2 +- .../derivatives/implementation/autograd.py | 4 ++ .../derivatives/implementation/backpack.py | 2 +- test/extensions/implementation/base.py | 44 +++++++++---------- 10 files changed, 53 insertions(+), 52 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index cf5635261..1c583d0a2 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -157,6 +157,11 @@ def hessian_is_zero(self) -> bool: def hessian_is_diagonal(self) -> bool: """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`. + The Hessian diagonal is only defined for layers that preserve the size + of their input. + + Must be implemented by descendants that don't implement ``hessian_is_zero``. + # noqa: DAR202 Returns: whether Hessian is diagonal diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 4978fd059..1c5c10d7c 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -48,6 +48,5 @@ def weight( dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) X = module.input0.flatten(start_dim=1, end_dim=-2) return einsum("nmi,nmj,nki,nkj->n", dE_dY, X, dE_dY, X) - else: return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py index 8239da936..99bb73c40 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/linear.py +++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py @@ -26,5 +26,5 @@ def weight(self, ext, module, g_inp, g_out, backproped): dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) X = module.input0.flatten(start_dim=1, end_dim=-2) return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X) - - return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2) + else: + return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2) diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py index e2838e027..3aa4b9a2e 100644 --- a/backpack/utils/conv.py +++ b/backpack/utils/conv.py @@ -1,22 +1,24 @@ +from typing import Union + import torch from einops import rearrange -from torch import einsum +from torch import Tensor, einsum +from torch.nn import Conv1d, Conv2d, Conv3d from torch.nn.functional import conv1d, conv2d, conv3d, unfold -def unfold_input(module, input): +def unfold_input(module: Union[Conv1d, Conv2d, Conv3d], input: Tensor) -> Tensor: """Return unfolded input to a convolution. Use PyTorch's ``unfold`` operation for 2d convolutions (4d input tensors), otherwise fall back to a custom implementation. Args: - module (torch.nn.Conv1d or torch.nn.Conv2d or torch.nn.Conv3d): Convolution - module whose hyperparameters are used for the unfold. - input (torch.Tensor): Input to convolution that will be unfolded. + module: Convolution module whose hyperparameters are used for the unfold. + input: Input to convolution that will be unfolded. Returns: - torch.Tensor: Unfolded input. + Unfolded input. """ if input.dim() == 4: return unfold( diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py index b61c72cab..60912e5ac 100644 --- a/backpack/utils/linear.py +++ b/backpack/utils/linear.py @@ -34,6 +34,8 @@ def extract_weight_diagonal( return einsum(equation, S ** 2, module.input0 ** 2) +# TODO This method applies the bias Jacobian, then squares and sums the result. Intro- +# duce base class for {Batch}DiagHessian and DiagGGN{Exact,MC} and remove this method def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) -> Tensor: """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian. diff --git a/fully_documented.txt b/fully_documented.txt index a37063729..f2e5646bf 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -14,7 +14,6 @@ backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py backpack/extensions/mat_to_mat_jac_base.py - backpack/extensions/firstorder/gradient/base.py backpack/extensions/firstorder/gradient/rnn.py backpack/extensions/firstorder/gradient/__init__.py @@ -28,7 +27,6 @@ backpack/extensions/firstorder/sum_grad_squared/sgs_base.py backpack/extensions/firstorder/sum_grad_squared/rnn.py backpack/extensions/firstorder/sum_grad_squared/__init__.py backpack/extensions/firstorder/batch_l2_grad/ - backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_ggn/__init__.py backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py @@ -43,29 +41,24 @@ backpack/extensions/secondorder/sqrt_ggn/ backpack/utils/linear.py backpack/utils/__init__.py -test/adaptive_avg_pool/ - -test/core/derivatives/derivatives_test.py -test/core/derivatives/__init__.py -test/core/derivatives/rnn_settings.py -test/core/derivatives/utils.py -test/core/derivatives/implementation/ -test/core/derivatives/permute_settings.py -test/core/derivatives/lstm_settings.py -test/core/derivatives/pooling_adaptive_settings.py - test/extensions/test_backprop_extension.py -test/extensions/problem.py - test/extensions/firstorder/firstorder_settings.py test/extensions/firstorder/variance/ test/extensions/firstorder/batch_grad/batchgrad_settings.py - test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/diag_ggn/ test/extensions/secondorder/hbp/ test/extensions/secondorder/sqrt_ggn/ - test/extensions/implementation/base.py test/extensions/implementation/autograd.py test/extensions/implementation/backpack.py +test/adaptive_avg_pool/ +test/core/derivatives/derivatives_test.py +test/core/derivatives/__init__.py +test/core/derivatives/rnn_settings.py +test/core/derivatives/utils.py +test/core/derivatives/implementation/ +test/core/derivatives/permute_settings.py +test/core/derivatives/lstm_settings.py +test/core/derivatives/pooling_adaptive_settings.py +test/extensions/problem.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index b21babde2..50ed3f321 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -467,7 +467,7 @@ def small_input_problem( """Skip cases with large inputs. Args: - instantiated_problem: Test case with constructed attributes. + instantiated_problem: Test case with deterministically constructed attributes. max_input_numel: Maximum input size. Default: ``100``. Yields: diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index b4d1b43d0..cbc2043a0 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -229,6 +229,10 @@ def _hessian(self, loss: Tensor, x: Tensor) -> Tensor: def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor: """Computes the Hessian of each element in `tensor` w.r.t `x`. + If ``tensor`` is linear in ``x``, autograd raises a ``RuntimeError``. + If ``tensor`` does not depend on ``x``, autograd raises an ``AttributeError``. + In both cases, a Hessian of zeros is created manually and returned. + Given a `tensor` of shape `[A, B, C]` and another tensor `x` with shape `[D, E]` used in the computation of `tensor`, the generalized Hessian has shape [A, B, C, D, E, D, E]. Let `hessian` denote this generalized Hessian. Then, diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 3ac44cf3b..a31db1f80 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -124,7 +124,7 @@ def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: def hessian_is_zero(self) -> bool: # noqa: D102 return self.problem.derivative.hessian_is_zero() - def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor: + def _sample_hessians_from_sqrt(self, sqrt): """Convert individual matrix square root into individual full matrix. Args: diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 6c677128c..f3d73d91a 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -20,32 +20,32 @@ def __init__(self, problem: ExtensionsTestProblem): @abstractmethod def batch_grad(self) -> List[Tensor]: """Individual gradients.""" - raise NotImplementedError + return @abstractmethod def batch_l2_grad(self) -> List[Tensor]: """L2 norm of Individual gradients.""" - raise NotImplementedError + return @abstractmethod def sgs(self) -> List[Tensor]: """Sum of Square of Individual gradients.""" - raise NotImplementedError + return @abstractmethod def variance(self) -> List[Tensor]: """Variance of Individual gradients.""" - raise NotImplementedError + return @abstractmethod def diag_ggn(self) -> List[Tensor]: """Diagonal of Gauss Newton.""" - raise NotImplementedError + return @abstractmethod def diag_ggn_exact_batch(self) -> List[Tensor]: """Individual diagonal of Generalized Gauss-Newton/Fisher.""" - raise NotImplementedError + return @abstractmethod def diag_ggn_mc(self, mc_samples: int) -> List[Tensor]: @@ -54,7 +54,7 @@ def diag_ggn_mc(self, mc_samples: int) -> List[Tensor]: Args: mc_samples: Number of Monte-Carlo samples used for the approximation. """ - raise NotImplementedError + return @abstractmethod def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: @@ -63,7 +63,12 @@ def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: Args: mc_samples: Number of Monte-Carlo samples used for the approximation. """ - raise NotImplementedError + return + + @abstractmethod + def diag_h(self) -> List[Tensor]: + """Diagonal of Hessian.""" + return @abstractmethod def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: @@ -75,7 +80,7 @@ def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: Returns: Parameter-wise lists of Kronecker factors. """ - raise NotImplementedError + return @abstractmethod def kflr(self) -> List[List[Tensor]]: @@ -84,7 +89,7 @@ def kflr(self) -> List[List[Tensor]]: Returns: Parameter-wise lists of Kronecker factors. """ - raise NotImplementedError + return @abstractmethod def kfra(self) -> List[List[Tensor]]: @@ -93,25 +98,16 @@ def kfra(self) -> List[List[Tensor]]: Returns: Parameter-wise lists of Kronecker factors. """ - raise NotImplementedError - - @abstractmethod - def diag_h(self) -> List[Tensor]: - """Diagonal of Hessian. - - Returns: - Hessian diagonal for each parameter. - """ - raise NotImplementedError + return @abstractmethod def diag_h_batch(self) -> List[Tensor]: """Per-sample Hessian diagonal. Returns: - Parameter-wise per-sample Hessian diagonal. + list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. """ - raise NotImplementedError + return @abstractmethod def ggn(self) -> Tensor: @@ -120,7 +116,7 @@ def ggn(self) -> Tensor: Returns: Matrix representation of the exact GGN. """ - raise NotImplementedError + return @abstractmethod def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: @@ -134,4 +130,4 @@ def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: Returns: Matrix representation of the Monte-Carlo approximated GGN. """ - raise NotImplementedError + return From f3bbc892a7142c3bef036e5c2e40db6b6766942a Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 30 Jun 2021 16:52:07 +0200 Subject: [PATCH 25/54] [FIX] Check `LSTM` `proj_size` only for `torch>=1.8.0` (#192) Resolves https://github.com/fKunstner/backpack-discuss/issues/112. --- backpack/core/derivatives/lstm.py | 6 ++++-- backpack/utils/__init__.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 1ad3a79b8..08888066a 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -5,6 +5,7 @@ from torch.nn import LSTM from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils import TORCH_VERSION, VERSION_1_8_0 class LSTMDerivatives(BaseParameterDerivatives): @@ -47,8 +48,9 @@ def _check_parameters(module: LSTM) -> None: raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: raise NotImplementedError("only bidirectional = False is supported") - if module.proj_size != 0: - raise NotImplementedError("only proj_size = 0 is supported") + if TORCH_VERSION >= VERSION_1_8_0: + if module.proj_size != 0: + raise NotImplementedError("only proj_size = 0 is supported") @staticmethod def _forward_pass( diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index 39f7fa2b1..7d98a7c33 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -1 +1,6 @@ """Contains utility functions.""" + +from pkg_resources import get_distribution, packaging + +TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) +VERSION_1_8_0 = packaging.version.parse("1.8.0") From 8ea0fb575b1a54aae3cef95a1746e8d1d0502da8 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 1 Jul 2021 15:53:06 +0200 Subject: [PATCH 26/54] [REF] Incorporate feedback from PR reviews --- .github/workflows/lint.yaml | 1 - .github/workflows/test.yaml | 1 - backpack/core/derivatives/lstm.py | 17 ++++---- backpack/core/derivatives/rnn.py | 13 ++++-- backpack/custom_module/permute.py | 2 +- backpack/custom_module/reduce_tuple.py | 2 +- .../firstorder/batch_l2_grad/__init__.py | 10 +++-- .../firstorder/batch_l2_grad/rnn.py | 2 +- .../extensions/firstorder/gradient/rnn.py | 3 +- .../firstorder/sum_grad_squared/rnn.py | 2 +- .../firstorder/sum_grad_squared/sgs_base.py | 12 +++--- .../extensions/firstorder/variance/rnn.py | 2 +- .../firstorder/variance/variance_base.py | 4 +- backpack/extensions/mat_to_mat_jac_base.py | 13 ++---- .../secondorder/diag_ggn/__init__.py | 10 +---- .../secondorder/diag_ggn/diag_ggn_base.py | 10 +++-- fully_documented.txt | 2 +- .../derivatives/implementation/autograd.py | 42 ++++++------------- test/extensions/problem.py | 2 +- 19 files changed, 68 insertions(+), 82 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 1cec88cf5..f77039b29 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -9,7 +9,6 @@ on: - development - master - release - - rnn jobs: flake8: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c235d3eea..2f63340eb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -9,7 +9,6 @@ on: - development - master - release - - rnn jobs: tests: diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 08888066a..1360103e6 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -270,7 +270,6 @@ def _jac_mat_prod( def _jac_t_mat_prod( self, module: LSTM, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: - self._check_parameters(module) IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) @@ -304,11 +303,9 @@ def _bias_hh_l0_jac_t_mat_prod( mat: Tensor, sum_batch: bool = True, ) -> Tensor: - self._check_parameters(module) - - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) - - return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod) + return self._bias_ih_l0_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) def _weight_ih_l0_jac_t_mat_prod( self, @@ -344,5 +341,11 @@ def _weight_hh_l0_jac_t_mat_prod( return einsum( f"vtnh,tng->v{'' if sum_batch else 'n'}hg", IFGO_prod, - cat([zeros(1, N, H, device=mat.device), module.output[0:-1]], dim=0), + cat( + [ + zeros(1, N, H, device=mat.device, dtype=mat.dtype), + module.output[0:-1], + ], + dim=0, + ), ) diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 98e6d4529..434b987b9 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -67,7 +67,7 @@ def _a_jac_t_mat_prod( T: int = mat.shape[1] H: int = mat.shape[3] a_jac_t_mat_prod: Tensor = zeros(V, T, N, H, device=mat.device) - for t in range(T)[::-1]: + for t in reversed(range(T)): if t == (T - 1): a_jac_t_mat_prod[:, t, ...] = einsum( "vnh,nh->vnh", @@ -90,6 +90,7 @@ def _a_jac_t_mat_prod( def _jac_t_mat_prod( self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: + self._check_parameters(module) return torch.einsum( "vtnh,hk->vtnk", self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), @@ -99,6 +100,7 @@ def _jac_t_mat_prod( def _jac_mat_prod( self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: + self._check_parameters(module) V: int = mat.shape[0] N: int = mat.shape[2] T: int = mat.shape[1] @@ -174,7 +176,6 @@ def _bias_hh_l0_jac_t_mat_prod( Returns: product """ - # identical to bias_ih_l0 return self._bias_ih_l0_jac_t_mat_prod( module, g_inp, g_out, mat, sum_batch=sum_batch ) @@ -232,5 +233,11 @@ def _weight_hh_l0_jac_t_mat_prod( return einsum( "vtnh,tnk->" + ("vhk" if sum_batch else "vnhk"), self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), - cat([zeros(1, N, H, device=mat.device), module.output[0:-1]], dim=0), + cat( + [ + zeros(1, N, H, device=mat.device, dtype=mat.dtype), + module.output[0:-1], + ], + dim=0, + ), ) diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py index 3e71dde80..04b4ea619 100644 --- a/backpack/custom_module/permute.py +++ b/backpack/custom_module/permute.py @@ -14,7 +14,7 @@ def __init__(self, *dims: Any): Args: dims: The desired ordering of dimensions. """ - super(Permute, self).__init__() + super().__init__() self.dims = dims def forward(self, input: Tensor) -> Tensor: diff --git a/backpack/custom_module/reduce_tuple.py b/backpack/custom_module/reduce_tuple.py index 5b95179bb..02fa9f5cc 100644 --- a/backpack/custom_module/reduce_tuple.py +++ b/backpack/custom_module/reduce_tuple.py @@ -14,7 +14,7 @@ def __init__(self, index: int = 0): Args: index: which element to choose """ - super(ReduceTuple, self).__init__() + super().__init__() self.index = index def forward(self, input: tuple) -> Union[tuple, Tensor]: diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index c68bfc820..5864022c3 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -15,8 +15,12 @@ ) from backpack.extensions.backprop_extension import BackpropExtension - -from . import convnd, convtransposend, linear, rnn +from backpack.extensions.firstorder.batch_l2_grad import ( + convnd, + convtransposend, + linear, + rnn, +) class BatchL2Grad(BackpropExtension): @@ -42,7 +46,7 @@ def __init__(self): Define the extensions for each module. """ - super(BatchL2Grad, self).__init__( + super().__init__( savefield="batch_l2", fail_mode="WARNING", module_exts={ diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py index 159be22e2..efdbc4320 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py @@ -8,7 +8,7 @@ class BatchL2RNN(BatchL2Base): def __init__(self): """Initialization.""" - super(BatchL2RNN, self).__init__( + super().__init__( ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], derivatives=RNNDerivatives(), ) diff --git a/backpack/extensions/firstorder/gradient/rnn.py b/backpack/extensions/firstorder/gradient/rnn.py index 7924ca554..c99130efc 100644 --- a/backpack/extensions/firstorder/gradient/rnn.py +++ b/backpack/extensions/firstorder/gradient/rnn.py @@ -1,7 +1,6 @@ """Contains GradRNN.""" from backpack.core.derivatives.rnn import RNNDerivatives - -from .base import GradBaseModule +from backpack.extensions.firstorder.gradient.base import GradBaseModule class GradRNN(GradBaseModule): diff --git a/backpack/extensions/firstorder/sum_grad_squared/rnn.py b/backpack/extensions/firstorder/sum_grad_squared/rnn.py index 61e96d698..7388dd746 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/rnn.py +++ b/backpack/extensions/firstorder/sum_grad_squared/rnn.py @@ -8,7 +8,7 @@ class SGSRNN(SGSBase): def __init__(self): """Initialization.""" - super(SGSRNN, self).__init__( + super().__init__( derivatives=RNNDerivatives(), params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py index 48ab13734..09f3d80f1 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py +++ b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py @@ -30,7 +30,7 @@ def __init__(self, derivatives: BaseParameterDerivatives, params: List[str] = No for param_str in params: if not hasattr(self, param_str): setattr(self, param_str, self._make_param_function(param_str)) - super(SGSBase, self).__init__(params=params) + super().__init__(params=params) def _make_param_function( self, param_str: str @@ -63,9 +63,11 @@ def param_function( Returns: sum_grad_squared """ - grad_batch = getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=False - ) - return (grad_batch ** 2).sum(self.N_axis) + return ( + getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( + module, g_inp, g_out, g_out[0], sum_batch=False + ) + ** 2 + ).sum(self.N_axis) return param_function diff --git a/backpack/extensions/firstorder/variance/rnn.py b/backpack/extensions/firstorder/variance/rnn.py index d41f6de8e..93342b32c 100644 --- a/backpack/extensions/firstorder/variance/rnn.py +++ b/backpack/extensions/firstorder/variance/rnn.py @@ -9,7 +9,7 @@ class VarianceRNN(VarianceBaseModule): def __init__(self): """Initialization.""" - super(VarianceRNN, self).__init__( + super().__init__( params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], grad_extension=GradRNN(), sgs_extension=SGSRNN(), diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index 0341d490b..47e383d46 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -37,7 +37,7 @@ def __init__( for param_str in params: if not hasattr(self, param_str): setattr(self, param_str, self._make_param_function(param_str)) - super(VarianceBaseModule, self).__init__(params=params) + super().__init__(params=params) @staticmethod def _variance_from(grad: Tensor, sgs: Tensor, N: int) -> Tensor: @@ -55,7 +55,7 @@ def _make_param_function( """Creates a function that calculates variance of grad_batch. Args: - param(str): name of parameter + param: name of parameter Returns: function that calculates variance of grad_batch diff --git a/backpack/extensions/mat_to_mat_jac_base.py b/backpack/extensions/mat_to_mat_jac_base.py index c922f825b..23fd71d9e 100644 --- a/backpack/extensions/mat_to_mat_jac_base.py +++ b/backpack/extensions/mat_to_mat_jac_base.py @@ -4,18 +4,14 @@ from torch import Tensor from torch.nn import Module -from ..core.derivatives.basederivatives import BaseDerivatives, BaseParameterDerivatives -from .module_extension import ModuleExtension +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.extensions.module_extension import ModuleExtension class MatToJacMat(ModuleExtension): """Base class for backpropagation of matrices by multiplying with Jacobians.""" - def __init__( - self, - derivatives: Union[BaseDerivatives, BaseParameterDerivatives], - params: List[str] = None, - ): + def __init__(self, derivatives: BaseDerivatives, params: List[str] = None): """Initialization. Args: @@ -46,11 +42,10 @@ def backpropagate( derivative wrt input """ if isinstance(backproped, list): - M_list: List[Tensor] = [ + return [ self.derivatives.jac_t_mat_prod(module, grad_inp, grad_out, M) for M in backproped ] - return M_list else: return self.derivatives.jac_t_mat_prod( module, grad_inp, grad_out, backproped diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 4c1accf36..25539c548 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -241,10 +241,7 @@ def __init__(self): Chooses exact loss strategy and savefield diag_ggn_exact_batch. """ - super().__init__( - loss_hessian_strategy=LossHessianStrategy.EXACT, - savefield="diag_ggn_exact_batch", - ) + super().__init__(LossHessianStrategy.EXACT, "diag_ggn_exact_batch") class BatchDiagGGNMC(BatchDiagGGN): @@ -269,10 +266,7 @@ def __init__(self, mc_samples: int = 1): mc_samples: Number of Monte-Carlo samples. Default: ``1``. """ self._mc_samples = mc_samples - super().__init__( - loss_hessian_strategy=LossHessianStrategy.SAMPLING, - savefield="diag_ggn_mc_batch", - ) + super().__init__(LossHessianStrategy.SAMPLING, "diag_ggn_mc_batch") def get_num_mc_samples(self) -> int: """Returns number of Monte-Carlo samples. diff --git a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py index 5d23cc200..4b445a63b 100644 --- a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py +++ b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py @@ -65,10 +65,12 @@ def _param( Returns: diagonal """ - JS: Tensor = getattr(self.derivatives, f"{param}_jac_t_mat_prod")( - module, grad_inp, grad_out, backproped, sum_batch=False - ) axis: Tuple[int] = (0, 1) if sum_batch else (0,) - return (JS ** 2).sum(axis=axis) + return ( + getattr(self.derivatives, f"{param}_jac_t_mat_prod")( + module, grad_inp, grad_out, backproped, sum_batch=False + ) + ** 2 + ).sum(axis=axis) return _param diff --git a/fully_documented.txt b/fully_documented.txt index f2e5646bf..8937eb7ae 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -41,6 +41,7 @@ backpack/extensions/secondorder/sqrt_ggn/ backpack/utils/linear.py backpack/utils/__init__.py +test/extensions/problem.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py test/extensions/firstorder/variance/ @@ -61,4 +62,3 @@ test/core/derivatives/implementation/ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py -test/extensions/problem.py diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index cbc2043a0..bdc281f51 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -2,7 +2,7 @@ from test.core.derivatives.implementation.base import DerivativesImplementation import torch -from torch import Tensor, zeros_like +from torch import Tensor, stack, zeros_like from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -25,31 +25,21 @@ def jac_vec_prod(self, vec) -> Tensor: return jacobian_vector_product(output, input, vec)[0] def jac_mat_prod(self, mat): # noqa: D102 - V = mat.shape[0] - - vecs = [mat[v] for v in range(V)] try: - jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs] + return stack([self.jac_vec_prod(vec) for vec in mat]) except RuntimeError: # A RuntimeError is thrown for RNNs on CUDA, # because PyTorch does not support double-backwards pass for them. # This is the recommended workaround. with torch.backends.cudnn.flags(enabled=False): - jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs] - - return torch.stack(jac_vec_prods) + return stack([self.jac_vec_prod(vec) for vec in mat]) def jac_t_vec_prod(self, vec): # noqa: D102 input, output, _ = self.problem.forward_pass(input_requires_grad=True) return transposed_jacobian_vector_product(output, input, vec)[0] def jac_t_mat_prod(self, mat): # noqa: D102 - V = mat.shape[0] - - vecs = [mat[v] for v in range(V)] - jac_t_vec_prods = [self.jac_t_vec_prod(vec) for vec in vecs] - - return torch.stack(jac_t_vec_prods) + return stack([self.jac_t_vec_prod(vec) for vec in mat]) def weight_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("weight", mat, sum_batch) @@ -95,7 +85,7 @@ def param_jac_t_vec_prod(self, name, vec, sum_batch, axis_batch=0): for n_out, n_vec in zip(sample_outputs, sample_vecs) ] - return torch.stack(jac_t_sample_prods) + return stack(jac_t_sample_prods) def param_jac_t_mat_prod(self, name, mat, sum_batch, axis_batch=0): """Compute the product of jac_t and the given matrix. @@ -110,15 +100,12 @@ def param_jac_t_mat_prod(self, name, mat, sum_batch, axis_batch=0): Returns: torch.Tensor: product of jac_t and mat """ - V = mat.shape[0] - - vecs = [mat[v] for v in range(V)] - jac_t_vec_prods = [ - self.param_jac_t_vec_prod(name, vec, sum_batch, axis_batch=axis_batch) - for vec in vecs - ] - - return torch.stack(jac_t_vec_prods) + return stack( + [ + self.param_jac_t_vec_prod(name, vec, sum_batch, axis_batch=axis_batch) + for vec in mat + ] + ) def weight_jac_mat_prod(self, mat) -> Tensor: """Product of jacobian and matrix. @@ -149,12 +136,7 @@ def _param_jac_vec_prod(self, name, vec): return jacobian_vector_product(output, param, vec)[0] def _param_jac_mat_prod(self, name, mat): - V = mat.shape[0] - - vecs = [mat[v] for v in range(V)] - jac_vec_prods = [self._param_jac_vec_prod(name, vec) for vec in vecs] - - return torch.stack(jac_vec_prods) + return stack([self._param_jac_vec_prod(name, vec) for vec in mat]) def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 def _sample_jac_t_mat_jac_prod(sample_idx, mat): diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 1e3e92b5a..05f25afe1 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -146,7 +146,7 @@ def forward_pass(self, sample_idx=None): input = self.input.clone().detach() target = self.target.clone().detach() else: - target = self.target.clone()[sample_idx].unsqueeze(0).detach() + target = self.target.split(1, dim=0)[sample_idx].detach() input = self.input.split(1, dim=0)[sample_idx].detach() output = self.model(input) From 0a819a801eaa341a5531287f07cbc867f2f98091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Mon, 5 Jul 2021 17:57:45 +0200 Subject: [PATCH 27/54] [core] BatchNormNd derivatives and tests (#179) * experiments with BatchNorm1d * experiments with BatchNorm1d * simplify * BatchNorm: test _jac_mat_prod and _jac_t_mat_prod in derivatives_test.py * BatchNorm for any dimensionality * refactor * check parameters * start work on training=False * training=False: jac_t_mat_prod and jac_mat_prod * BatchNorm: _weight_jac_t_mat_prod * BatchNorm: _bias_jac_t_mat_prod * format * clean up * fix * add old function _residual_mat_prod * fix docstring * BatchNorm: all derivatives and tests * delete batchnorm1d.py * Test: train BatchNorm before evaluation * Fix weight derivative * clean up * residual_mat_prod: checker, tech report title, move to parent class * simplify batchnorm_nd.py * align index names with index convention * refactor batchnorm_nd.py * refactor dictionaries * remove init, remove child classes * check module parameters * move running stats initialization from problem.py to batch_norm_settings.py * add docstring to _unsqueeze_free_axis * change interface of is_hessian_{zero|diagonal} * change interface of is_hessian_{zero|diagonal} * fix seed for BatchNorm3d * fix batch_norm_settings.py * format * [REF] Remove redundant `is True` * link tech report * remove dictionaries * format Co-authored-by: Felix Dangel --- backpack/core/derivatives/avgpoolnd.py | 2 +- backpack/core/derivatives/basederivatives.py | 10 +- backpack/core/derivatives/batchnorm1d.py | 122 ------- backpack/core/derivatives/batchnorm_nd.py | 310 ++++++++++++++++++ backpack/core/derivatives/conv_transposend.py | 2 +- backpack/core/derivatives/convnd.py | 2 +- backpack/core/derivatives/dropout.py | 2 +- backpack/core/derivatives/elementwise.py | 2 +- backpack/core/derivatives/elu.py | 2 +- backpack/core/derivatives/flatten.py | 2 +- backpack/core/derivatives/leakyrelu.py | 2 +- backpack/core/derivatives/linear.py | 5 +- backpack/core/derivatives/logsigmoid.py | 2 +- backpack/core/derivatives/maxpoolnd.py | 2 +- backpack/core/derivatives/relu.py | 2 +- backpack/core/derivatives/selu.py | 2 +- backpack/core/derivatives/sigmoid.py | 2 +- backpack/core/derivatives/tanh.py | 2 +- backpack/core/derivatives/zeropad2d.py | 2 +- .../curvmatprod/ggnmp/batchnorm1d.py | 4 +- .../extensions/curvmatprod/hmp/batchnorm1d.py | 4 +- .../extensions/curvmatprod/hmp/hmpbase.py | 2 +- .../extensions/curvmatprod/pchmp/pchmpbase.py | 8 +- .../firstorder/batch_grad/batchnorm1d.py | 4 +- .../firstorder/gradient/batchnorm1d.py | 4 +- .../secondorder/diag_hessian/diag_h_base.py | 4 +- .../extensions/secondorder/hbp/hbpbase.py | 4 +- fully_documented.txt | 2 + test/core/derivatives/__init__.py | 7 + test/core/derivatives/batch_norm_settings.py | 69 ++++ test/core/derivatives/derivatives_test.py | 44 ++- .../derivatives/implementation/backpack.py | 2 +- 32 files changed, 469 insertions(+), 166 deletions(-) delete mode 100644 backpack/core/derivatives/batchnorm1d.py create mode 100644 backpack/core/derivatives/batchnorm_nd.py create mode 100644 test/core/derivatives/batch_norm_settings.py diff --git a/backpack/core/derivatives/avgpoolnd.py b/backpack/core/derivatives/avgpoolnd.py index 618d8af81..d95e561c8 100644 --- a/backpack/core/derivatives/avgpoolnd.py +++ b/backpack/core/derivatives/avgpoolnd.py @@ -50,7 +50,7 @@ def get_avg_pool_parameters(self, module) -> Tuple[Any, Any, Any]: """ return module.stride, module.kernel_size, module.padding - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 1c583d0a2..f0115834e 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -140,11 +140,14 @@ def ea_jac_t_mat_jac_prod( """ raise NotImplementedError - def hessian_is_zero(self) -> bool: + def hessian_is_zero(self, module: Module) -> bool: """Returns whether Hessian is zero. I.e. whether ``∂²output[i] / ∂input[j] ∂input[k] = 0 ∀ i,j,k``. + Args: + module: current module to evaluate + # noqa: DAR202 Returns: whether Hessian is zero @@ -154,7 +157,7 @@ def hessian_is_zero(self) -> bool: """ raise NotImplementedError - def hessian_is_diagonal(self) -> bool: + def hessian_is_diagonal(self, module: Module) -> bool: """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`. The Hessian diagonal is only defined for layers that preserve the size @@ -162,6 +165,9 @@ def hessian_is_diagonal(self) -> bool: Must be implemented by descendants that don't implement ``hessian_is_zero``. + Args: + module: current module to evaluate + # noqa: DAR202 Returns: whether Hessian is diagonal diff --git a/backpack/core/derivatives/batchnorm1d.py b/backpack/core/derivatives/batchnorm1d.py deleted file mode 100644 index cbdf8cc82..000000000 --- a/backpack/core/derivatives/batchnorm1d.py +++ /dev/null @@ -1,122 +0,0 @@ -from warnings import warn - -from torch import einsum - -from backpack.core.derivatives.basederivatives import BaseParameterDerivatives - - -class BatchNorm1dDerivatives(BaseParameterDerivatives): - def hessian_is_zero(self): - return False - - def hessian_is_diagonal(self): - return False - - def _jac_mat_prod(self, module, g_inp, g_out, mat): - return self._jac_t_mat_prod(module, g_inp, g_out, mat) - - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - """ - Note: - ----- - The Jacobian is *not independent* among the batch dimension, i.e. - D z_i = D z_i(x_1, ..., x_B). - - This structure breaks the computation of the GGN diagonal, - for curvature-matrix products it should still work. - - References: - ----------- - https://kevinzakka.github.io/2016/09/14/batch_normalization/ - https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html - """ - assert module.affine is True - - N = module.input0.size(0) - x_hat, var = self.get_normalized_input_and_var(module) - ivar = 1.0 / (var + module.eps).sqrt() - - dx_hat = einsum("vni,i->vni", (mat, module.weight)) - - jac_t_mat = N * dx_hat - jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat) - jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat)) - jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N)) - - return jac_t_mat - - def get_normalized_input_and_var(self, module): - input = module.input0 - mean = input.mean(dim=0) - var = input.var(dim=0, unbiased=False) - return (input - mean) / (var + module.eps).sqrt(), var - - def _residual_mat_prod(self, module, g_inp, g_out, mat): - """Multiply with BatchNorm1d residual-matrix. - - Paul Fischer (GitHub: @paulkogni) contributed this code during a research - project in winter 2019. - - Details are described in - - - `TODO: Add tech report title` - _ - by Paul Fischer, 2020. - """ - N = module.input0.size(0) - x_hat, var = self.get_normalized_input_and_var(module) - gamma = module.weight - eps = module.eps - - factor = gamma / (N * (var + eps)) - - sum_127 = einsum("nc,vnc->vc", (x_hat, mat)) - sum_24 = einsum("nc->c", g_out[0]) - sum_3 = einsum("nc,vnc->vc", (g_out[0], mat)) - sum_46 = einsum("vnc->vc", mat) - sum_567 = einsum("nc,nc->c", (x_hat, g_out[0])) - - r_mat = -einsum("nc,vc->vnc", (g_out[0], sum_127)) - r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_24, sum_127)).unsqueeze(1).expand( - -1, N, -1 - ) - r_mat -= einsum("nc,vc->vnc", (x_hat, sum_3)) - r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", (x_hat, sum_24, sum_46)) - - r_mat -= einsum("vnc,c->vnc", (mat, sum_567)) - r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_567, sum_46)).unsqueeze(1).expand( - -1, N, -1 - ) - r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", (x_hat, sum_127, sum_567)) - - return einsum("c,vnc->vnc", (factor, r_mat)) - - def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): - x_hat, _ = self.get_normalized_input_and_var(module) - return einsum("ni,vi->vni", (x_hat, mat)) - - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch): - if not sum_batch: - warn( - "BatchNorm batch summation disabled." - "This may not compute meaningful quantities" - ) - x_hat, _ = self.get_normalized_input_and_var(module) - equation = "vni,ni->v{}i".format("" if sum_batch is True else "n") - operands = [mat, x_hat] - return einsum(equation, operands) - - def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): - N = module.input0.size(0) - return mat.unsqueeze(1).repeat(1, N, 1) - - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - if not sum_batch: - warn( - "BatchNorm batch summation disabled." - "This may not compute meaningful quantities" - ) - return mat - else: - N_axis = 1 - return mat.sum(N_axis) diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py new file mode 100644 index 000000000..4375c64e2 --- /dev/null +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -0,0 +1,310 @@ +"""Contains derivatives for BatchNorm.""" +from typing import List, Tuple, Union + +from torch import Size, Tensor, einsum +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives + + +class BatchNormNdDerivatives(BaseParameterDerivatives): + """Derivatives for BatchNorm1d, 2d and 3d. + + If training=False: saved statistics are used. + If training=True: statistics of current batch are used. + + Index convention: + n: batch axis + c: category axis + {empty}/l/hw/dhw: dimension axis for 0/1/2/3-dimensions (alternatively using xyz) + ...: usually for the remaining dimension axis (same as dhw) + + Links to PyTorch docs: + https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html + https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html + + As a starting point for derivative computation, see these references: + https://kevinzakka.github.io/2016/09/14/batch_normalization/ + https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html + """ + + def _check_parameters( + self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + ) -> None: + if module.affine is False: + raise NotImplementedError("Only implemented for affine=True") + if module.track_running_stats is False: + raise NotImplementedError("Only implemented for track_running_stats=True") + + def hessian_is_zero( + self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + ) -> bool: + """Whether hessian is zero. + + Args: + module: current module to evaluate + + Returns: + whether hessian is zero + + Raises: + NotImplementedError: if module is in evaluation mode + """ + if module.training: + return False + else: + raise NotImplementedError( + "hessian_is_zero is not tested for BatchNorm. Create an issue if you need it." + ) + + def hessian_is_diagonal( + self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + ) -> bool: + """Whether hessian is diagonal. + + Args: + module: current module to evaluate + + Returns: + whether hessian is diagonal + + Raises: + NotImplementedError: if module is in evaluation mode + """ + if module.training: + return False + else: + raise NotImplementedError( + "hessian_is_diagonal is not tested for BatchNorm. " + "Create an issue if you need it." + ) + + def _jac_mat_prod( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: + return self._jac_t_mat_prod(module, g_inp, g_out, mat) + + def _jac_t_mat_prod( + self, + module: BatchNorm1d, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: + self._check_parameters(module) + N: int = self._get_n_axis(module) + if module.training: + denominator: int = self._get_denominator(module) + x_hat, var = self._get_normalized_input_and_var(module) + ivar = 1.0 / (var + module.eps).sqrt() + + dx_hat: Tensor = einsum("vnc...,c->vnc...", mat, module.weight) + jac_t_mat = denominator * dx_hat + jac_t_mat -= dx_hat.sum( + self._get_free_axes(module), + keepdim=True, + ).expand_as(jac_t_mat) + equation = "nc...,vmcx,mcx->vnc...".replace( + "x", + { + 0: "", + 1: "x", + 2: "xy", + 3: "xyz", + }[N], + ) + jac_t_mat -= einsum(equation, x_hat, dx_hat, x_hat) + jac_t_mat = einsum("vnc...,c->vnc...", jac_t_mat, ivar / denominator) + return jac_t_mat + else: + return einsum( + "c,vnc...->vnc...", + ((module.running_var + module.eps) ** (-0.5)) * module.weight, + mat, + ) + + def _weight_jac_mat_prod( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: + x_hat, _ = self._get_normalized_input_and_var(module) + return einsum("nc...,vc->vnc...", x_hat, mat) + + def _weight_jac_t_mat_prod( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + x_hat, _ = self._get_normalized_input_and_var(module) + return einsum(f"vnc...,nc...->v{'' if sum_batch else 'n'}c", mat, x_hat) + + def _bias_jac_mat_prod( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: + out = self._unsqueeze_free_axis(module, mat, 1) + dim_expand: List[int] = [-1, module.input0.shape[0], -1] + for n in range(self._get_n_axis(module)): + dim_expand.append(module.input0.shape[2 + n]) + return out.expand(*dim_expand) + + def _bias_jac_t_mat_prod( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + ) -> Tensor: + axis_sum: Tuple[int] = self._get_free_axes(module, with_batch_axis=sum_batch) + return mat.sum(dim=axis_sum) if axis_sum else mat + + def _residual_mat_prod( + self, + module: BatchNorm1d, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: + """Multiply with BatchNorm1d residual-matrix. + + Paul Fischer (GitHub: @paulkogni) contributed this code during a research + project in winter 2019. + + Details are described in + + `HESSIAN BACKPROPAGATION FOR BATCHNORM` + + by Paul Fischer, 2020. + + Args: + module: module + g_inp: input gradients + g_out: output gradients + mat: matrix to multiply + + Returns: + product + + Raises: + NotImplementedError: if used with a not supported mode or input + """ # noqa: B950 + self._check_parameters(module) + if module.training is False: + raise NotImplementedError("residual_mat_prod works only for training mode.") + if module.input0.dim() != 2: + raise NotImplementedError( + "residual_mat_prod is implemented only for 0 dimensions. " + "If you need more dimension make a feature request." + ) + + N = module.input0.size(0) + x_hat, var = self._get_normalized_input_and_var(module) + gamma = module.weight + eps = module.eps + + factor = gamma / (N * (var + eps)) + + sum_127 = einsum("nc,vnc->vc", x_hat, mat) + sum_24 = einsum("nc->c", g_out[0]) + sum_3 = einsum("nc,vnc->vc", g_out[0], mat) + sum_46 = einsum("vnc->vc", mat) + sum_567 = einsum("nc,nc->c", x_hat, g_out[0]) + + r_mat = -einsum("nc,vc->vnc", g_out[0], sum_127) + r_mat += (1.0 / N) * einsum("c,vc->vc", sum_24, sum_127).unsqueeze(1).expand( + -1, N, -1 + ) + r_mat -= einsum("nc,vc->vnc", x_hat, sum_3) + r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", x_hat, sum_24, sum_46) + + r_mat -= einsum("vnc,c->vnc", mat, sum_567) + r_mat += (1.0 / N) * einsum("c,vc->vc", sum_567, sum_46).unsqueeze(1).expand( + -1, N, -1 + ) + r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", x_hat, sum_127, sum_567) + + return einsum("c,vnc->vnc", factor, r_mat) + + ############################################################### + # HELPER FUNCTIONS ### + ############################################################### + def _get_normalized_input_and_var( + self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + ) -> Tuple[Tensor, Tensor]: + input: Tensor = module.input0 + if module.training: + dim: Tuple[int] = self._get_free_axes(module, index_batch=0) + mean: Tensor = input.mean(dim=dim) + var: Tensor = input.var(dim=dim, unbiased=False) + else: + mean: Tensor = module.running_mean + var: Tensor = module.running_var + mean: Tensor = self._unsqueeze_free_axis(module, mean, 0) + var_expanded: Tensor = self._unsqueeze_free_axis(module, var, 0) + return (input - mean) / (var_expanded + module.eps).sqrt(), var + + def _get_denominator( + self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + ) -> int: + shape_input: Size = module.input0.shape + free_axes: Tuple[int] = self._get_free_axes(module, index_batch=0) + denominator: int = 1 + for index in free_axes: + denominator *= shape_input[index] + return denominator + + @staticmethod + def _get_n_axis(module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]) -> int: + return module.input0.dim() - 2 + + def _unsqueeze_free_axis( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + tensor: Tensor, + index_batch: int, + ) -> Tensor: + """Unsqueezes the free dimensions. + + This function is useful to avoid broadcasting. + Also useful when applying .expand(self._get_free_axes()) afterwards. + + Args: + module: extended module + tensor: the tensor to operate on + index_batch: the batch axes index + + Returns: + tensor with the free dimensions unsqueezed. + """ + out = tensor.unsqueeze(index_batch) + for _ in range(self._get_n_axis(module)): + out = out.unsqueeze(-1) + return out + + def _get_free_axes( + self, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + with_batch_axis: bool = True, + index_batch: int = 1, + ) -> Tuple[int]: + free_axes: List[int] = [] + if with_batch_axis: + free_axes.append(index_batch) + for n in range(self._get_n_axis(module)): + free_axes.append(index_batch + n + 2) + return tuple(free_axes) diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index 9cf9f223d..8491ebbd0 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -45,7 +45,7 @@ def __init__(self, N): raise ValueError(f"ConvTranspose{N}d not supported.") self.conv_dims = N - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index e0c5bf8cf..9f928f438 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -54,7 +54,7 @@ def __init__(self, N): raise ValueError("{}-dimensional Conv. is not implemented.".format(N)) self.conv_dims = N - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def get_unfolded_input(self, module): diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py index 964deaa77..2ddc2aa23 100644 --- a/backpack/core/derivatives/dropout.py +++ b/backpack/core/derivatives/dropout.py @@ -4,7 +4,7 @@ class DropoutDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def df(self, module, g_inp, g_out): diff --git a/backpack/core/derivatives/elementwise.py b/backpack/core/derivatives/elementwise.py index d6f10bce5..602706f27 100644 --- a/backpack/core/derivatives/elementwise.py +++ b/backpack/core/derivatives/elementwise.py @@ -65,7 +65,7 @@ def hessian_diagonal(self, module, g_inp, g_out): return self.d2f(module, g_inp, g_out) * g_out[0] - def hessian_is_diagonal(self): + def hessian_is_diagonal(self, module): """Elementwise activation function Hessians are diagonal. Returns: diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py index 74092e883..99cafd0b5 100644 --- a/backpack/core/derivatives/elu.py +++ b/backpack/core/derivatives/elu.py @@ -7,7 +7,7 @@ class ELUDerivatives(ElementwiseDerivatives): """Implement first- and second-order partial derivatives of ELU.""" - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`ELU''(x) ≠ 0`.""" return False diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py index 8366c770e..4f8dd72d0 100644 --- a/backpack/core/derivatives/flatten.py +++ b/backpack/core/derivatives/flatten.py @@ -2,7 +2,7 @@ class FlattenDerivatives(BaseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): diff --git a/backpack/core/derivatives/leakyrelu.py b/backpack/core/derivatives/leakyrelu.py index 60a650c93..89c2f9ee4 100644 --- a/backpack/core/derivatives/leakyrelu.py +++ b/backpack/core/derivatives/leakyrelu.py @@ -4,7 +4,7 @@ class LeakyReLUDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`LeakyReLU''(x) = 0`.""" return True diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index a895d01eb..5d8238d91 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -18,9 +18,12 @@ class LinearDerivatives(BaseParameterDerivatives): * i: Input dimension """ - def hessian_is_zero(self) -> bool: + def hessian_is_zero(self, module: Linear) -> bool: """Linear layer output is linear w.r.t. to its input. + Args: + module: current module + Returns: True """ diff --git a/backpack/core/derivatives/logsigmoid.py b/backpack/core/derivatives/logsigmoid.py index f9e016c11..78a882241 100644 --- a/backpack/core/derivatives/logsigmoid.py +++ b/backpack/core/derivatives/logsigmoid.py @@ -4,7 +4,7 @@ class LogSigmoidDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`logsigmoid''(x) ≠ 0`.""" return False diff --git a/backpack/core/derivatives/maxpoolnd.py b/backpack/core/derivatives/maxpoolnd.py index b3960459b..ac39765a2 100644 --- a/backpack/core/derivatives/maxpoolnd.py +++ b/backpack/core/derivatives/maxpoolnd.py @@ -32,7 +32,7 @@ def get_pooling_idx(self, module): ) return pool_idx - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py index eae9d5ebf..1aa775fbb 100644 --- a/backpack/core/derivatives/relu.py +++ b/backpack/core/derivatives/relu.py @@ -4,7 +4,7 @@ class ReLUDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`ReLU''(x) = 0`.""" return True diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py index 33c4a9ceb..fda077a5d 100644 --- a/backpack/core/derivatives/selu.py +++ b/backpack/core/derivatives/selu.py @@ -10,7 +10,7 @@ class SELUDerivatives(ElementwiseDerivatives): alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`SELU''(x) != 0`.""" return False diff --git a/backpack/core/derivatives/sigmoid.py b/backpack/core/derivatives/sigmoid.py index 5b45a114b..c5b5437fe 100644 --- a/backpack/core/derivatives/sigmoid.py +++ b/backpack/core/derivatives/sigmoid.py @@ -2,7 +2,7 @@ class SigmoidDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): """`σ''(x) ≠ 0`.""" return False diff --git a/backpack/core/derivatives/tanh.py b/backpack/core/derivatives/tanh.py index 525cb3aa2..4bf99c849 100644 --- a/backpack/core/derivatives/tanh.py +++ b/backpack/core/derivatives/tanh.py @@ -2,7 +2,7 @@ class TanhDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): return False def df(self, module, g_inp, g_out): diff --git a/backpack/core/derivatives/zeropad2d.py b/backpack/core/derivatives/zeropad2d.py index 197566461..6e121e099 100644 --- a/backpack/core/derivatives/zeropad2d.py +++ b/backpack/core/derivatives/zeropad2d.py @@ -5,7 +5,7 @@ class ZeroPad2dDerivatives(BaseDerivatives): - def hessian_is_zero(self): + def hessian_is_zero(self, module): return True def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): diff --git a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py index 03c05ffa5..aeb6f3877 100644 --- a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py +++ b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py @@ -1,11 +1,11 @@ -from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase class GGNMPBatchNorm1d(GGNMPBase): def __init__(self): super().__init__( - derivatives=BatchNorm1dDerivatives(), params=["weight", "bias"] + derivatives=BatchNormNdDerivatives(), params=["weight", "bias"] ) def weight(self, ext, module, g_inp, g_out, backproped): diff --git a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py index d441a82c4..388180ab1 100644 --- a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py +++ b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py @@ -1,11 +1,11 @@ -from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase class HMPBatchNorm1d(HMPBase): def __init__(self): super().__init__( - derivatives=BatchNorm1dDerivatives(), params=["weight", "bias"] + derivatives=BatchNormNdDerivatives(), params=["weight", "bias"] ) def weight(self, ext, module, g_inp, g_out, backproped): diff --git a/backpack/extensions/curvmatprod/hmp/hmpbase.py b/backpack/extensions/curvmatprod/hmp/hmpbase.py index 013582c06..459be8fa4 100644 --- a/backpack/extensions/curvmatprod/hmp/hmpbase.py +++ b/backpack/extensions/curvmatprod/hmp/hmpbase.py @@ -29,7 +29,7 @@ def h_in_mat_prod(mat): result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result) # Multiply with the residual term: mat → [∑ᵢ Hzᵢ(x) δzᵢ] mat. - if not self.derivatives.hessian_is_zero(): + if not self.derivatives.hessian_is_zero(module): result += self.derivatives.residual_mat_prod(module, g_inp, g_out, mat) return result diff --git a/backpack/extensions/curvmatprod/pchmp/pchmpbase.py b/backpack/extensions/curvmatprod/pchmp/pchmpbase.py index 4c6cd1a07..8438c3750 100644 --- a/backpack/extensions/curvmatprod/pchmp/pchmpbase.py +++ b/backpack/extensions/curvmatprod/pchmp/pchmpbase.py @@ -20,9 +20,9 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped): Given mat → ℋz(x) mat, backpropagate mat → ℋx mat. """ - diagonal_or_zero_residual = ( - self.derivatives.hessian_is_zero() or self.derivatives.hessian_is_diagonal() - ) + diagonal_or_zero_residual = self.derivatives.hessian_is_zero( + module + ) or self.derivatives.hessian_is_diagonal(module) if not diagonal_or_zero_residual: raise ValueError("Only linear or element-wise operations supported.") @@ -45,7 +45,7 @@ def h_in_mat_prod(mat): result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result) # Multiply with the residual term: mat → [∑ᵢ Hzᵢ(x) δzᵢ] mat. - if not self.derivatives.hessian_is_zero(): + if not self.derivatives.hessian_is_zero(module): result += self.modified_residual_mat_prod( ext, module, g_inp, g_out, mat, modify ) diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py index 74d8737d9..46a41b739 100644 --- a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py +++ b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py @@ -1,9 +1,9 @@ -from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase class BatchGradBatchNorm1d(BatchGradBase): def __init__(self): super().__init__( - derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"] + derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/firstorder/gradient/batchnorm1d.py b/backpack/extensions/firstorder/gradient/batchnorm1d.py index 5e0f3b6fd..92dbce28c 100644 --- a/backpack/extensions/firstorder/gradient/batchnorm1d.py +++ b/backpack/extensions/firstorder/gradient/batchnorm1d.py @@ -1,4 +1,4 @@ -from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives from .base import GradBaseModule @@ -6,5 +6,5 @@ class GradBatchNorm1d(GradBaseModule): def __init__(self): super().__init__( - derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"] + derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) diff --git a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py index 58c15ee60..acf58e718 100644 --- a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py +++ b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py @@ -24,9 +24,9 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped): return {"matrices": bp_matrices, "signs": bp_signs} def __local_curvatures(self, module, g_inp, g_out): - if self.derivatives.hessian_is_zero(): + if self.derivatives.hessian_is_zero(module): return [] - if not self.derivatives.hessian_is_diagonal(): + if not self.derivatives.hessian_is_diagonal(module): raise NotImplementedError def positive_part(sign, H): diff --git a/backpack/extensions/secondorder/hbp/hbpbase.py b/backpack/extensions/secondorder/hbp/hbpbase.py index 6bf2647a8..e6258b79b 100644 --- a/backpack/extensions/secondorder/hbp/hbpbase.py +++ b/backpack/extensions/secondorder/hbp/hbpbase.py @@ -34,10 +34,10 @@ def backpropagate_batch_average(self, ext, module, g_inp, g_out, H): return ggn def second_order_module_effects(self, module, g_inp, g_out): - if self.derivatives.hessian_is_zero(): + if self.derivatives.hessian_is_zero(module): return None - elif not self.derivatives.hessian_is_diagonal(): + elif not self.derivatives.hessian_is_diagonal(module): raise NotImplementedError( "Residual terms are only supported for elementwise functions" ) diff --git a/fully_documented.txt b/fully_documented.txt index 8937eb7ae..608c3a7ef 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -10,6 +10,7 @@ backpack/core/derivatives/permute.py backpack/core/derivatives/lstm.py backpack/core/derivatives/linear.py backpack/core/derivatives/adaptive_avg_pool_nd.py +backpack/core/derivatives/batchnorm_nd.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -62,3 +63,4 @@ test/core/derivatives/implementation/ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py +test/core/derivatives/batch_norm_settings.py diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index c63ec0b7d..28fdd6b9b 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -10,6 +10,9 @@ AvgPool1d, AvgPool2d, AvgPool3d, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -39,6 +42,7 @@ from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives from backpack.core.derivatives.conv1d import Conv1DDerivatives from backpack.core.derivatives.conv2d import Conv2DDerivatives from backpack.core.derivatives.conv3d import Conv3DDerivatives @@ -96,4 +100,7 @@ AdaptiveAvgPool1d: AdaptiveAvgPool1dDerivatives, AdaptiveAvgPool2d: AdaptiveAvgPool2dDerivatives, AdaptiveAvgPool3d: AdaptiveAvgPool3dDerivatives, + BatchNorm1d: BatchNormNdDerivatives, + BatchNorm2d: BatchNormNdDerivatives, + BatchNorm3d: BatchNormNdDerivatives, } diff --git a/test/core/derivatives/batch_norm_settings.py b/test/core/derivatives/batch_norm_settings.py new file mode 100644 index 000000000..621cd7332 --- /dev/null +++ b/test/core/derivatives/batch_norm_settings.py @@ -0,0 +1,69 @@ +"""Test configurations for `backpack.core.derivatives` BatchNorm layers. + +Required entries: + "module_fn" (callable): Contains a model constructed from `torch.nn` layers + "input_fn" (callable): Used for specifying input function + +Optional entries: + "target_fn" (callable): Fetches the groundtruth/target classes + of regression/classification task + "loss_function_fn" (callable): Loss function used in the model + "device" [list(torch.device)]: List of devices to run the test on. + "id_prefix" (str): Prefix to be included in the test name. + "seed" (int): seed for the random number for torch.rand +""" +from typing import Union + +from torch import rand, rand_like +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + + +def _initialize_training_false( + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] +) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]: + module.running_mean = rand_like(module.running_mean) + module.running_var = rand_like(module.running_var) + module.weight.data = rand_like(module.weight) + module.bias.data = rand_like(module.bias) + return module.train(False) + + +BATCH_NORM_SETTINGS = [ + { + "module_fn": lambda: BatchNorm1d(num_features=7), + "input_fn": lambda: rand(size=(5, 7)), + }, + { + "module_fn": lambda: BatchNorm1d(num_features=7), + "input_fn": lambda: rand(size=(5, 7, 4)), + }, + { + "module_fn": lambda: BatchNorm2d(num_features=7), + "input_fn": lambda: rand(size=(5, 7, 3, 4)), + }, + { + "module_fn": lambda: BatchNorm3d(num_features=7), + "input_fn": lambda: rand(size=(5, 7, 3, 4, 2)), + "seed": 1, + }, + { + "module_fn": lambda: _initialize_training_false(BatchNorm1d(num_features=7)), + "input_fn": lambda: rand(size=(5, 7)), + "id_prefix": "training=False", + }, + { + "module_fn": lambda: _initialize_training_false(BatchNorm1d(num_features=7)), + "input_fn": lambda: rand(size=(5, 7, 4)), + "id_prefix": "training=False", + }, + { + "module_fn": lambda: _initialize_training_false(BatchNorm2d(num_features=7)), + "input_fn": lambda: rand(size=(5, 7, 3, 4)), + "id_prefix": "training=False", + }, + { + "module_fn": lambda: _initialize_training_false(BatchNorm3d(num_features=7)), + "input_fn": lambda: rand(size=(5, 7, 3, 4, 2)), + "id_prefix": "training=False", + }, +] diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 50ed3f321..7e23e187d 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -8,6 +8,7 @@ """ from test.automated_test import check_sizes_and_values +from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS @@ -46,11 +47,18 @@ PERMUTE_PROBLEMS = make_test_problems(PERMUTE_SETTINGS) PERMUTE_IDS = [problem.make_id() for problem in PERMUTE_PROBLEMS] +BATCH_NORM_PROBLEMS = make_test_problems(BATCH_NORM_SETTINGS) +BATCH_NORM_IDS = [problem.make_id() for problem in BATCH_NORM_PROBLEMS] + @pytest.mark.parametrize( "problem", - NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, + NO_LOSS_PROBLEMS + + RNN_PROBLEMS + + PERMUTE_PROBLEMS + + LSTM_PROBLEMS + + BATCH_NORM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS + BATCH_NORM_IDS, ) def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product. @@ -71,8 +79,12 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: @pytest.mark.parametrize( "problem", - NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + LSTM_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS, + NO_LOSS_PROBLEMS + + RNN_PROBLEMS + + PERMUTE_PROBLEMS + + LSTM_PROBLEMS + + BATCH_NORM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS + BATCH_NORM_IDS, ) def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> None: """Test the transposed Jacobian-matrix product. @@ -227,7 +239,11 @@ def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): [True, False], ids=["save_memory=True", "save_memory=False"], ) -@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) +@pytest.mark.parametrize( + "problem", + PROBLEMS_WITH_WEIGHTS + BATCH_NORM_PROBLEMS, + ids=IDS_WITH_WEIGHTS + BATCH_NORM_IDS, +) def test_weight_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, save_memory: bool, V: int = 3 ) -> None: @@ -252,7 +268,11 @@ def test_weight_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS) +@pytest.mark.parametrize( + "problem", + PROBLEMS_WITH_WEIGHTS + BATCH_NORM_PROBLEMS, + ids=IDS_WITH_WEIGHTS + BATCH_NORM_IDS, +) def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product w.r.t. to the weight. @@ -281,7 +301,11 @@ def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> Non @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +@pytest.mark.parametrize( + "problem", + PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS, + ids=IDS_WITH_BIAS + BATCH_NORM_IDS, +) def test_bias_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, V: int = 3 ) -> None: @@ -302,7 +326,11 @@ def test_bias_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize("problem", PROBLEMS_WITH_BIAS, ids=IDS_WITH_BIAS) +@pytest.mark.parametrize( + "problem", + PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS, + ids=IDS_WITH_BIAS + BATCH_NORM_IDS, +) def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product w.r.t. to the bias. diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index a31db1f80..97206c16f 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -122,7 +122,7 @@ def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: ) def hessian_is_zero(self) -> bool: # noqa: D102 - return self.problem.derivative.hessian_is_zero() + return self.problem.derivative.hessian_is_zero(self.problem.module) def _sample_hessians_from_sqrt(self, sqrt): """Convert individual matrix square root into individual full matrix. From 49c0aaa8e57b8fbe82f496750b8a80b755d476e2 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 6 Jul 2021 09:48:15 +0200 Subject: [PATCH 28/54] [CI] Test multiple torch versions, fix compatibility with `torch==1.6.0` (#198) Notes on excluded settings: - `torch==1.6.0` not available for `python3.9` - `torch==1.7.1`, `python3.9`: segmentation fault in GitHub action after successfully running the tests --- * [CI] Run tests on `torch==1.6.0` and `1.9.0` * [FIX] No `torch==1.6.0` for Python 3.9 * [CI] Improve workflow name * [FIX] Try fix torch version * [CI] Add `torch==1.7.1,1.8.0` to test * [FIX] Make `einsum` equation compatible with `torch>=1.6.0` * [FIX] flake8 * [CI] Shorten job names * [TEST] Increase rtol * [FIX] Increase rtol * [FIX] Make not supported exception version sensitive * [TEST] Exclude py3.9 torch1.7.1 (segmentation fault) https://github.com/f-dangel/backpack/pull/198/checks?check_run_id=2996074583#step:6:2994 * [REF] Make einsum version-specific, deprecation notes * [FIX] f-string Co-authored-by: Felix Dangel --- .github/workflows/test.yaml | 13 +++++++-- backpack/core/derivatives/batchnorm_nd.py | 29 +++++++++++++------- backpack/core/derivatives/linear.py | 18 +++++++----- backpack/utils/__init__.py | 2 ++ test/core/derivatives/derivatives_test.py | 2 +- test/extensions/secondorder/hbp/test_kfac.py | 5 +++- test/extensions/secondorder/hbp/test_kflr.py | 5 +++- test/extensions/secondorder/hbp/test_kfra.py | 5 +++- 8 files changed, 55 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2f63340eb..3a0d4871a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,23 +12,30 @@ on: jobs: tests: - name: "Python ${{ matrix.python-version }}" + name: "py${{ matrix.python-version }} torch${{ matrix.pytorch-version}}" runs-on: ubuntu-latest env: USING_COVERAGE: '3.7,3.9' strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: [3.7, 3.8, 3.9] + pytorch-version: [1.6.0, 1.7.1, 1.8.0, 1.9.0] + exclude: + - pytorch-version: 1.6.0 + python-version: 3.9 + - pytorch-version: 1.7.1 + python-version: 3.9 steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 with: - python-version: "${{ matrix.python-version }}" + python-version: ${{ matrix.python-version }} - name: Install Dependencies run: | python -m pip install --upgrade pip make install-test + pip install torch==${{ matrix.pytorch-version }} torchvision - name: Run test if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref) run: | diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index 4375c64e2..56fc3a1fd 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -5,6 +5,7 @@ from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils import TORCH_VERSION, VERSION_1_9_0 class BatchNormNdDerivatives(BaseParameterDerivatives): @@ -109,16 +110,13 @@ def _jac_t_mat_prod( self._get_free_axes(module), keepdim=True, ).expand_as(jac_t_mat) - equation = "nc...,vmcx,mcx->vnc...".replace( - "x", - { - 0: "", - 1: "x", - 2: "xy", - 3: "xyz", - }[N], + spatial_dims = "xyz"[:N] + jac_t_mat -= einsum( + f"nc...,vmc{spatial_dims},mc{spatial_dims}->vnc...", + x_hat, + dx_hat, + x_hat, ) - jac_t_mat -= einsum(equation, x_hat, dx_hat, x_hat) jac_t_mat = einsum("vnc...,c->vnc...", jac_t_mat, ivar / denominator) return jac_t_mat else: @@ -147,7 +145,18 @@ def _weight_jac_t_mat_prod( sum_batch: bool = True, ) -> Tensor: x_hat, _ = self._get_normalized_input_and_var(module) - return einsum(f"vnc...,nc...->v{'' if sum_batch else 'n'}c", mat, x_hat) + + if TORCH_VERSION >= VERSION_1_9_0: + equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c" + # TODO Remove else-branch after deprecating torch<1.9.0 + else: + N: int = self._get_n_axis(module) + spatial_dims = "xyz"[:N] + equation = ( + f"vnc{spatial_dims},nc{spatial_dims}->v{'' if sum_batch else 'n'}c" + ) + + return einsum(equation, mat, x_hat) def _bias_jac_mat_prod( self, diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 5d8238d91..a41fbb47a 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -5,6 +5,7 @@ from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils import TORCH_VERSION, VERSION_1_9_0 class LinearDerivatives(BaseParameterDerivatives): @@ -137,14 +138,16 @@ def _weight_jac_t_mat_prod( """ d_weight = module.input0 - if self._has_additional_dims(module): - # Flatten additional dimensions because they cannot be represented as - # ellipsis. WAITING https://github.com/pytorch/pytorch/issues/45854 - d_weight = d_weight.flatten(start_dim=1, end_dim=-2) - mat = mat.flatten(start_dim=2, end_dim=-2) - equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi" + if TORCH_VERSION >= VERSION_1_9_0: + equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi" + # TODO Remove else-branch after deprecating torch<1.9.0 else: - equation = f"vno,ni->v{'' if sum_batch else 'n'}oi" + if self._has_additional_dims(module): + d_weight = d_weight.flatten(start_dim=1, end_dim=-2) + mat = mat.flatten(start_dim=2, end_dim=-2) + equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi" + else: + equation = f"vno,ni->v{'' if sum_batch else 'n'}oi" return einsum(equation, mat, d_weight) @@ -203,6 +206,7 @@ def _bias_jac_t_mat_prod( return einsum(equation, mat) + # TODO Remove after deprecating torch<1.9.0 @classmethod def _has_additional_dims(cls, module: Linear) -> bool: """Return whether the input to a linear layer has additional (>1) dimensions. diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index 7d98a7c33..afc0887d3 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -3,4 +3,6 @@ from pkg_resources import get_distribution, packaging TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) +VERSION_1_9_0 = packaging.version.parse("1.9.0") VERSION_1_8_0 = packaging.version.parse("1.8.0") +VERSION_1_6_0 = packaging.version.parse("1.6.0") diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 7e23e187d..b5c1e83f6 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -264,7 +264,7 @@ def test_weight_jac_t_mat_prod( ) autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod(mat, sum_batch) - check_sizes_and_values(autograd_res, backpack_res) + check_sizes_and_values(autograd_res, backpack_res, rtol=5e-5) problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kfac.py b/test/extensions/secondorder/hbp/test_kfac.py index d6205e290..ec494a5de 100644 --- a/test/extensions/secondorder/hbp/test_kfac.py +++ b/test/extensions/secondorder/hbp/test_kfac.py @@ -6,6 +6,8 @@ import pytest +from backpack.utils import TORCH_VERSION, VERSION_1_6_0 + NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -19,7 +21,8 @@ def test_kfac_not_supported(problem): """ problem.set_up() - with pytest.raises(NotImplementedError): + exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError + with pytest.raises(exception): BackpackExtensions(problem).kfac() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kflr.py b/test/extensions/secondorder/hbp/test_kflr.py index 79e46f186..0a4e04e46 100644 --- a/test/extensions/secondorder/hbp/test_kflr.py +++ b/test/extensions/secondorder/hbp/test_kflr.py @@ -6,6 +6,8 @@ import pytest +from backpack.utils import TORCH_VERSION, VERSION_1_6_0 + NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -19,7 +21,8 @@ def test_kflr_not_supported(problem): """ problem.set_up() - with pytest.raises(NotImplementedError): + exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError + with pytest.raises(exception): BackpackExtensions(problem).kflr() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kfra.py b/test/extensions/secondorder/hbp/test_kfra.py index 387438308..943fd4e41 100644 --- a/test/extensions/secondorder/hbp/test_kfra.py +++ b/test/extensions/secondorder/hbp/test_kfra.py @@ -6,6 +6,8 @@ import pytest +from backpack.utils import TORCH_VERSION, VERSION_1_6_0 + NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -19,7 +21,8 @@ def test_kfra_not_supported(problem): """ problem.set_up() - with pytest.raises(NotImplementedError): + exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError + with pytest.raises(exception): BackpackExtensions(problem).kfra() problem.tear_down() From 1d61d9649f10d6486d8788e33cf938cdbcc8ee1f Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 6 Jul 2021 10:52:22 +0200 Subject: [PATCH 29/54] [ADD] Application of `weight_jac_t` for a subset of samples (#195) Auxiliary: Implement `hessian_is_zero` for BN eval mode (tests from #183) --- * [ADD] Extend `weight_jac_t` by subsampling argument * [ADD] Sub-sampling for batch normalization, adapt interface * [REF] Separate fixtures for filtering weights and matrix generation * [REF] Shorten instantiation fixture name * [TEST] Add BN setting to `weight_jac_t` tests * [DOC] Add batch norm test setting to fully-documented * Merge branch to test multiple PyTorch versions * [DOC] Fix grammar * [DEL] Batchnorm1d file --- backpack/core/derivatives/basederivatives.py | 16 ++- backpack/core/derivatives/batchnorm_nd.py | 29 +++-- backpack/core/derivatives/conv_transposend.py | 23 +++- backpack/core/derivatives/convnd.py | 56 ++++++-- backpack/core/derivatives/linear.py | 16 ++- backpack/core/derivatives/shape_check.py | 9 +- backpack/utils/subsampling.py | 28 ++++ fully_documented.txt | 1 + test/core/derivatives/derivatives_test.py | 120 ++++++++++++++---- .../derivatives/implementation/autograd.py | 79 ++++++++---- .../derivatives/implementation/backpack.py | 9 +- test/core/derivatives/implementation/base.py | 8 +- 12 files changed, 300 insertions(+), 94 deletions(-) create mode 100644 backpack/utils/subsampling.py diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index f0115834e..19c449400 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -1,7 +1,7 @@ """Base classes for more flexible Jacobians and second-order information.""" import warnings from abc import ABC -from typing import Callable, Tuple +from typing import Callable, List, Tuple from torch import Tensor from torch.nn import Module @@ -397,6 +397,7 @@ def weight_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight to a matrix. @@ -405,16 +406,20 @@ def weight_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. + Has shape ``[V, *module.output.shape]``; but if used with + sub-sampling, the batch dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, N, C_w, H_w, ...] if `sum_batch == False`. - Has shape [V, C_w, H_w, ...] if `sum_batch == True`. + If ``sum_batch=False``, has shape ``[V, N, *module.weight.shape]``. + If ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. + If sub-sampling is used, ``N`` is replaced by ``len(subsampling)``. """ return self._weight_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_jac_t_mat_prod( @@ -424,6 +429,7 @@ def _weight_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index 56fc3a1fd..b6d0dcff7 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -1,11 +1,13 @@ """Contains derivatives for BatchNorm.""" from typing import List, Tuple, Union +from warnings import warn from torch import Size, Tensor, einsum from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import TORCH_VERSION, VERSION_1_9_0 +from backpack.utils.subsampling import subsample class BatchNormNdDerivatives(BaseParameterDerivatives): @@ -48,16 +50,8 @@ def hessian_is_zero( Returns: whether hessian is zero - - Raises: - NotImplementedError: if module is in evaluation mode """ - if module.training: - return False - else: - raise NotImplementedError( - "hessian_is_zero is not tested for BatchNorm. Create an issue if you need it." - ) + return not module.training def hessian_is_diagonal( self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] @@ -143,8 +137,11 @@ def _weight_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: + self._maybe_warn_no_batch_summation(sum_batch) x_hat, _ = self._get_normalized_input_and_var(module) + x_hat = subsample(x_hat, subsampling=subsampling) if TORCH_VERSION >= VERSION_1_9_0: equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c" @@ -179,6 +176,7 @@ def _bias_jac_t_mat_prod( mat: Tensor, sum_batch: bool = True, ) -> Tensor: + self._maybe_warn_no_batch_summation(sum_batch) axis_sum: Tuple[int] = self._get_free_axes(module, with_batch_axis=sum_batch) return mat.sum(dim=axis_sum) if axis_sum else mat @@ -317,3 +315,16 @@ def _get_free_axes( for n in range(self._get_n_axis(module)): free_axes.append(index_batch + n + 2) return tuple(free_axes) + + @staticmethod + def _maybe_warn_no_batch_summation(sum_batch: bool) -> None: + """Warn that Jacobians w.r.t. single components are not per-sample gradients. + + Args: + sum_batch: Whether to sum out the batch dimension. + """ + if not sum_batch: + warn( + "BatchNorm batch summation disabled." + "This may not compute meaningful quantities" + ) diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index 8491ebbd0..e00907386 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -1,7 +1,9 @@ """Partial derivatives for ``torch.nn.ConvTranspose{1,2,3}d``.""" +from typing import List, Tuple, Union + from einops import rearrange from numpy import prod -from torch import einsum +from torch import Tensor, einsum from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from torch.nn.functional import ( conv1d, @@ -15,6 +17,7 @@ from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils.conv_transpose import unfold_by_conv_transpose +from backpack.utils.subsampling import subsample class ConvTransposeNDDerivatives(BaseParameterDerivatives): @@ -84,18 +87,26 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): return self.reshape_like_output(jac_mat, module) - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + def _weight_jac_t_mat_prod( + self, + module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + subsampling: List[int] = None, + ) -> Tensor: V = mat.shape[0] G = module.groups C_in = module.input0.shape[1] - N = module.output.shape[0] + N = module.output.shape[0] if subsampling is None else len(subsampling) C_out = module.output.shape[1] mat_reshape = mat.reshape(V, N, G, C_out // G, *module.output.shape[2:]) - u = unfold_by_conv_transpose(module.input0, module).reshape( - N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:] - ) + u = unfold_by_conv_transpose( + subsample(module.input0, subsampling=subsampling), module + ).reshape(N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:]) dims_kern = "xyz"[: self.conv_dims] dims_data = "abc"[: self.conv_dims] diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index 9f928f438..149f75b43 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -1,8 +1,9 @@ import warnings +from typing import List, Tuple, Union from einops import rearrange, reduce from numpy import prod -from torch import einsum +from torch import Tensor, einsum from torch.nn import Conv1d, Conv2d, Conv3d from torch.nn.functional import ( conv1d, @@ -16,6 +17,7 @@ from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import conv as convUtils +from backpack.utils.subsampling import subsample class weight_jac_t_save_memory: @@ -131,11 +133,21 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): return self.reshape_like_output(jac_mat, module) - def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + def _weight_jac_t_mat_prod( + self, + module: Union[Conv1d, Conv2d, Conv3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + subsampling: List[int] = None, + ) -> Tensor: save_memory = weight_jac_t_save_memory._SAVE_MEMORY if save_memory and self.conv_dims in [1, 2]: - return self.__higher_conv_weight_jac_t(module, mat, sum_batch) + return self.__higher_conv_weight_jac_t( + module, mat, sum_batch, subsampling=subsampling + ) else: @@ -147,13 +159,22 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): ) ) - return self.__same_conv_weight_jac_t(module, mat, sum_batch) - - def __same_conv_weight_jac_t(self, module, mat, sum_batch): + return self.__same_conv_weight_jac_t( + module, mat, sum_batch, subsampling=subsampling + ) + + def __same_conv_weight_jac_t( + self, + module: Union[Conv1d, Conv2d, Conv3d], + mat: Tensor, + sum_batch: bool, + subsampling: List[int] = None, + ) -> Tensor: """Uses convolution of same order.""" G = module.groups V = mat.shape[0] - N, C_out = module.output.shape[0], module.output.shape[1] + C_out = module.output.shape[1] + N = module.output.shape[0] if subsampling is None else len(subsampling) C_in = module.input0.shape[1] C_in_axis = 1 N_axis = 0 @@ -165,7 +186,9 @@ def __same_conv_weight_jac_t(self, module, mat, sum_batch): mat = rearrange(mat, "a b ... -> (a b) ...") mat = mat.unsqueeze(C_in_axis) - input = rearrange(module.input0, "n c ... -> (n c) ...") + input = rearrange( + subsample(module.input0, subsampling=subsampling), "n c ... -> (n c) ..." + ) input = input.unsqueeze(N_axis) repeat_pattern = [1, V] + [1 for _ in range(self.conv_dims)] input = input.repeat(*repeat_pattern) @@ -191,7 +214,13 @@ def __same_conv_weight_jac_t(self, module, mat, sum_batch): else: return rearrange(grad_weight, "(v n g i o) ... -> v n (g o) i ...", **dim) - def __higher_conv_weight_jac_t(self, module, mat, sum_batch): + def __higher_conv_weight_jac_t( + self, + module: Union[Conv1d, Conv2d, Conv3d], + mat: Tensor, + sum_batch: bool, + subsampling: List[int] = None, + ) -> Tensor: """Requires higher-order convolution. The algorithm is proposed in: @@ -201,7 +230,8 @@ def __higher_conv_weight_jac_t(self, module, mat, sum_batch): """ G = module.groups V = mat.shape[0] - N, C_out = module.output.shape[0], module.output.shape[1] + C_out = module.output.shape[1] + N = module.output.shape[0] if subsampling is None else len(subsampling) C_in = module.input0.shape[1] if self.conv_dims == 1: @@ -223,8 +253,10 @@ def __higher_conv_weight_jac_t(self, module, mat, sum_batch): # Reshape to extract groups from the convolutional layer # Channels are seen as an extra spatial dimension with kernel size 1 - input_conv = module.input0.reshape(1, N * G, *spatial_dim).repeat( - *spatial_dim_axis + input_conv = ( + subsample(module.input0, subsampling=subsampling) + .reshape(1, N * G, *spatial_dim) + .repeat(*spatial_dim_axis) ) # Compute convolution between input and output; the batchsize is seen # as channels, taking advantage of the `groups` argument diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index a41fbb47a..1b6012a82 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -1,11 +1,12 @@ """Contains partial derivatives for the ``torch.nn.Linear`` layer.""" -from typing import Tuple +from typing import List, Tuple from torch import Size, Tensor, einsum from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import TORCH_VERSION, VERSION_1_9_0 +from backpack.utils.subsampling import subsample class LinearDerivatives(BaseParameterDerivatives): @@ -119,6 +120,7 @@ def _weight_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: int = True, + subsampling: List[int] = None, ) -> Tensor: """Batch-apply transposed Jacobian of the output w.r.t. the weight. @@ -128,15 +130,19 @@ def _weight_jac_t_mat_prod( g_out: Gradients w.r.t. module output. Not required by the implementation. mat: Batch of ``V`` vectors of same shape as the layer output (``[N, *, out_features]``) to which the transposed output-input Jacobian - is applied. Has shape ``[V, N, *, out_features]``. + is applied. Has shape ``[V, N, *, out_features]`` if subsampling is not + used, otherwise ``N`` must be ``len(subsampling)`` instead. sum_batch: Sum the result's batch axis. Default: ``True``. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Batched transposed Jacobian vector products. Has shape - ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With - ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. + ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. If sub- + sampling is used, ``N`` must be ``len(subsampling)`` instead. """ - d_weight = module.input0 + d_weight = subsample(module.input0, subsampling=subsampling) if TORCH_VERSION >= VERSION_1_9_0: equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi" diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index 931c0ef0c..a9ba070c8 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -8,6 +8,8 @@ from torch import Tensor from torch.nn import Module +from backpack.utils.subsampling import subsample + ############################################################################### # Utility functions # @@ -59,7 +61,12 @@ def _check_same_V_dim(mat1, mat2): def _check_like(mat, module, name, diff=1, *args, **kwargs): - return check_shape(mat, getattr(module, name), diff=diff) + if name == "output" and "subsampling" in kwargs.keys(): + compare = subsample(module.output, subsampling=kwargs["subsampling"]) + else: + compare = getattr(module, name) + + return check_shape(mat, compare, diff=diff) def _check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py new file mode 100644 index 000000000..f429d580b --- /dev/null +++ b/backpack/utils/subsampling.py @@ -0,0 +1,28 @@ +"""Utility functions to enable mini-batch subsampling in extensions.""" +from typing import List + +from torch import Tensor + + +def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: + """Select samples from a tensor along a dimension. + + Args: + tensor: Tensor to select from. + dim: Selection dimension. Defaults to ``0``. + subsampling: Indices of samples that are sliced along the dimension. + Defaults to ``None`` (use all samples). + + Returns: + Tensor of same rank that is sub-sampled along the dimension. + + Raises: + NotImplementedError: If dimension differs from ``0``. + """ + if subsampling is None: + return tensor + else: + if dim == 0: + return tensor[subsampling] + else: + raise NotImplementedError(f"Only supports dim = 0. Got {dim}.") diff --git a/fully_documented.txt b/fully_documented.txt index 608c3a7ef..15528fdbb 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -40,6 +40,7 @@ backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/extensions/secondorder/sqrt_ggn/ backpack/utils/linear.py +backpack/utils/subsampling.py backpack/utils/__init__.py test/extensions/problem.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index b5c1e83f6..e74e0bf3e 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -17,11 +17,13 @@ from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS +from typing import List, Tuple, Union from warnings import warn import pytest import torch from pytest import fixture, skip +from torch import Size, Tensor from backpack.core.derivatives.convnd import weight_jac_t_save_memory @@ -50,6 +52,9 @@ BATCH_NORM_PROBLEMS = make_test_problems(BATCH_NORM_SETTINGS) BATCH_NORM_IDS = [problem.make_id() for problem in BATCH_NORM_PROBLEMS] +SUBSAMPLINGS = [None, [0, 0], [2, 0]] +SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] + @pytest.mark.parametrize( "problem", @@ -239,33 +244,54 @@ def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): [True, False], ids=["save_memory=True", "save_memory=False"], ) -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_WEIGHTS + BATCH_NORM_PROBLEMS, - ids=IDS_WITH_WEIGHTS + BATCH_NORM_IDS, -) def test_weight_jac_t_mat_prod( - problem: DerivativesTestProblem, sum_batch: bool, save_memory: bool, V: int = 3 + problem_weight_jac_t_mat, + sum_batch: bool, + save_memory: bool, ) -> None: """Test the transposed Jacobian-matrix product w.r.t. to the weight. Args: - problem: Test case. + problem_weight_jac_t_mat: Instantiated test case, subsampling, and + input for weight_jac_t sum_batch: Sum out the batch dimension. save_memory: Use Owkin implementation in convolutions to save memory. - V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ - problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + problem, subsampling, mat = problem_weight_jac_t_mat with weight_jac_t_save_memory(save_memory): backpack_res = BackpackDerivatives(problem).weight_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) - autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod(mat, sum_batch) - + autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod( + mat, sum_batch, subsampling=subsampling + ) check_sizes_and_values(autograd_res, backpack_res, rtol=5e-5) - problem.tear_down() + + +def rand_mat_like_output( + V: int, output_shape: Size, subsampling: List[int] = None +) -> Tensor: + """Generate random matrix whose columns are shaped like the layer output. + + Can be used to generate random inputs to functions that act on tensors + shaped like the module output (like ``*_jac_t_mat_prod``). + + Args: + V: Number of rows. + output_shape: Shape of the module output. + subsampling: Indices of samples used by sub-sampling. + + Returns: + Random matrix with (subsampled) output shape. + """ + subsample_shape = list(output_shape) + + if subsampling is not None: + N_axis = 0 + subsample_shape[N_axis] = len(subsampling) + + return torch.rand(V, *subsample_shape) @pytest.mark.parametrize( @@ -472,8 +498,8 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None problem.tear_down() -@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) -def instantiated_problem(request) -> DerivativesTestProblem: +@fixture(params=PROBLEMS + BATCH_NORM_PROBLEMS, ids=lambda p: p.make_id()) +def problem(request) -> DerivativesTestProblem: """Set seed, create tested layer and data. Finally clean up. Args: @@ -488,26 +514,72 @@ def instantiated_problem(request) -> DerivativesTestProblem: case.tear_down() +@fixture +def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: + """Filter out cases that don't have a weight parameter. + + Args: + problem: Test case with deterministically constructed attributes. + + Yields: + Instantiated cases that have a weight parameter. + """ + has_weight = hasattr(problem.module, "weight") and problem.module.weight is not None + if has_weight: + yield problem + else: + skip("Test case has no weight parameter.") + + +@fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def problem_weight_jac_t_mat( + request, problem_weight: DerivativesTestProblem +) -> Tuple[DerivativesTestProblem, Union[None, List[int]], Tensor]: + """Create matrix that will be multiplied by the weight Jacobian. + + Skip if there is a conflict where the subsampling indices exceed the number of + samples in the input. + + Args: + request (SubRequest): Request for the fixture from a test/fixture function. + problem_weight: Test case with weight parameter. + + Yields: + problem with weight, subsampling, matrix for weight_jac_t + """ + subsampling: Union[None, List[int]] = request.param + N = problem_weight.input_shape[0] + enough_samples = subsampling is None or N >= max(subsampling) + + if not enough_samples: + skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + + V = 3 + mat = rand_mat_like_output( + V, problem_weight.output_shape, subsampling=subsampling + ).to(problem_weight.device) + + yield (problem_weight, subsampling, mat) + del mat + + @fixture def small_input_problem( - instantiated_problem: DerivativesTestProblem, max_input_numel: int = 100 + problem: DerivativesTestProblem, max_input_numel: int = 100 ) -> DerivativesTestProblem: """Skip cases with large inputs. Args: - instantiated_problem: Test case with deterministically constructed attributes. + problem: Test case with deterministically constructed attributes. max_input_numel: Maximum input size. Default: ``100``. Yields: Instantiated test case with small input. """ - if instantiated_problem.input.numel() > max_input_numel: - skip( - "Input is too large:" - + f" {instantiated_problem.input.numel()} > {max_input_numel}" - ) + if problem.input.numel() > max_input_numel: + skip("Input is too large:" + f" {problem.input.numel()} > {max_input_numel}") else: - yield instantiated_problem + yield problem @fixture diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index bdc281f51..f19941236 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -1,5 +1,6 @@ """Derivatives computed with PyTorch's autograd.""" from test.core.derivatives.implementation.base import DerivativesImplementation +from typing import List import torch from torch import Tensor, stack, zeros_like @@ -41,8 +42,10 @@ def jac_t_vec_prod(self, vec): # noqa: D102 def jac_t_mat_prod(self, mat): # noqa: D102 return stack([self.jac_t_vec_prod(vec) for vec in mat]) - def weight_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("weight", mat, sum_batch) + def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "weight", mat, sum_batch, subsampling=subsampling + ) def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("bias", mat, sum_batch) @@ -59,50 +62,70 @@ def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("weight_hh_l0", mat, sum_batch, axis_batch=1) - def param_jac_t_vec_prod(self, name, vec, sum_batch, axis_batch=0): + def param_jac_t_vec_prod( + self, + name: str, + vec: Tensor, + sum_batch: bool, + axis_batch: int = 0, + subsampling: List[int] = None, + ) -> Tensor: """Compute the product of jac_t and the given vector. Args: - name (str): name of parameter for derivative - vec (torch.Tensor): vectors which to multiply - sum_batch (boolean): whether to sum along batch axis - axis_batch (int, optional): index of batch axis. Defaults to 0. + name: name of parameter for derivative + vec: vectors which to multiply + sum_batch: whether to sum along batch axis + axis_batch: index of batch axis. Defaults to 0. + subsampling: Indices of active samples. Default: ``None`` (all). Returns: - torch.Tensor: product of jac_t and vec + product of jac_t and vec """ - _, output, named_params = self.problem.forward_pass() + input, output, named_params = self.problem.forward_pass() param = named_params[name] - if sum_batch: - return transposed_jacobian_vector_product(output, param, vec)[0] - else: - sample_outputs = output.split(1, dim=axis_batch) - sample_vecs = vec.split(1, dim=axis_batch) - - jac_t_sample_prods = [ - transposed_jacobian_vector_product(n_out, param, n_vec)[0] - for n_out, n_vec in zip(sample_outputs, sample_vecs) - ] + samples = range(input.shape[axis_batch]) if subsampling is None else subsampling + sample_outputs = output.split(1, dim=axis_batch) + sample_vecs = vec.split(1, dim=axis_batch) - return stack(jac_t_sample_prods) + jac_t_sample_prods = stack( + [ + transposed_jacobian_vector_product(sample_outputs[n], param, vec_n)[0] + for n, vec_n in zip(samples, sample_vecs) + ], + ) - def param_jac_t_mat_prod(self, name, mat, sum_batch, axis_batch=0): + if sum_batch: + jac_t_sample_prods = jac_t_sample_prods.sum(0) + + return jac_t_sample_prods + + def param_jac_t_mat_prod( + self, + name: str, + mat: Tensor, + sum_batch: bool, + axis_batch: int = 0, + subsampling: List[int] = None, + ) -> Tensor: """Compute the product of jac_t and the given matrix. Args: - name (str): name of parameter for derivative - mat (torch.Tensor): matrix which to multiply - sum_batch (boolean): whether to sum along batch axis - axis_batch (int, optional): index of batch axis. This is counted - without the first axis. Defaults to 0. + name: name of parameter for derivative + mat: matrix which to multiply + sum_batch: whether to sum along batch axis + axis_batch: Batch axis, counted without the first axis. Defaults to 0. + subsampling: Indices of active samples. Default: ``None`` (all). Returns: - torch.Tensor: product of jac_t and mat + product of jac_t and mat """ return stack( [ - self.param_jac_t_vec_prod(name, vec, sum_batch, axis_batch=axis_batch) + self.param_jac_t_vec_prod( + name, vec, sum_batch, axis_batch=axis_batch, subsampling=subsampling + ) for vec in mat ] ) diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 97206c16f..907be6015 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -36,10 +36,15 @@ def jac_t_mat_prod(self, mat): # noqa: D102 self.problem.module, None, None, mat ) - def weight_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 43aa48868..002268b89 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -1,5 +1,6 @@ """Contains DerivativesImplementation, the base class for autograd and backpack.""" from abc import ABC, abstractmethod +from typing import List from torch import Tensor @@ -42,12 +43,15 @@ def jac_t_mat_prod(self, mat: Tensor) -> Tensor: raise NotImplementedError @abstractmethod - def weight_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: - """Product of jacobian and matrix. + def weight_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: + """Matrix-Jacobian products w.r.t. the weight. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product From 68870eddacf38c6966e569f8806564d077c85421 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 6 Jul 2021 18:32:22 +0200 Subject: [PATCH 30/54] [ADD] Application of `bias_jac_t` for a subset of samples (#196) Prepares the `core` functionality to support #12. * [ADD] Extend `weight_jac_t` by subsampling argument * [ADD] Sub-sampling for batch normalization, adapt interface * [REF] Separate fixtures for filtering weights and matrix generation * [REF] Shorten instantiation fixture name * [TEST] Add BN setting to `weight_jac_t` tests * [DOC] Add batch norm test setting to fully-documented * [ADD] Support subsampling in `bias_jac_t` * [DOC] Fix docstring * [TEST] Re-introduce missing BN setting Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 16 +++- backpack/core/derivatives/batchnorm_nd.py | 1 + backpack/core/derivatives/conv_transposend.py | 10 +-- backpack/core/derivatives/convnd.py | 10 +-- backpack/core/derivatives/linear.py | 9 ++- test/core/derivatives/derivatives_test.py | 78 +++++++++++++++---- .../derivatives/implementation/autograd.py | 6 +- .../derivatives/implementation/backpack.py | 9 ++- test/core/derivatives/implementation/base.py | 5 +- 9 files changed, 107 insertions(+), 37 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 19c449400..524a72fd7 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -335,6 +335,7 @@ def bias_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias to a matrix. @@ -343,15 +344,21 @@ def bias_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. + Has shape ``[V, *module.output.shape]``; but if used with + sub-sampling, the batch dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, N, C_b, ...] if `sum_batch == False`. - Has shape [V, C_b, ...] if `sum_batch == True`. + If ``sum_batch=False``, has shape ``[V, N, *module.bias.shape]``. + If ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. + If sub-sampling is used, ``N`` is replaced by ``len(subsampling)``. """ - return self._bias_jac_t_mat_prod(module, g_inp, g_out, mat, sum_batch=sum_batch) + return self._bias_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling + ) def _bias_jac_t_mat_prod( self, @@ -360,6 +367,7 @@ def _bias_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index b6d0dcff7..bc6a3b083 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -175,6 +175,7 @@ def _bias_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._maybe_warn_no_batch_summation(sum_batch) axis_sum: Tuple[int] = self._get_free_axes(module, with_batch_axis=sum_batch) diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index e00907386..fc8dd7844 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -51,11 +51,11 @@ def __init__(self, N): def hessian_is_zero(self, module): return True - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - axes = list(range(3, len(module.output.shape) + 1)) - if sum_batch: - axes = [1] + axes - return mat.sum(axes) + def _bias_jac_t_mat_prod( + self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None + ): + equation = f"vnc...->v{'' if sum_batch else 'n'}c" + return einsum(equation, mat) def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): # Expand batch dimension diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index 149f75b43..2e8dedaf3 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -116,11 +116,11 @@ def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): return jac_mat.expand(*expand_shape) - def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - axes = list(range(3, len(module.output.shape) + 1)) - if sum_batch: - axes = [1] + axes - return mat.sum(axes) + def _bias_jac_t_mat_prod( + self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None + ): + equation = f"vnc...->v{'' if sum_batch else 'n'}c" + return einsum(equation, mat) def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): # separate output channel groups diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 1b6012a82..e6a8fe586 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -191,6 +191,7 @@ def _bias_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: int = True, + subsampling: List[int] = None, ) -> Tensor: """Batch-apply transposed Jacobian of the output w.r.t. the bias. @@ -202,14 +203,16 @@ def _bias_jac_t_mat_prod( (``[N, *, out_features]``) to which the transposed output-input Jacobian is applied. Has shape ``[V, N, *, out_features]``. sum_batch: Sum the result's batch axis. Default: ``True``. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Batched transposed Jacobian vector products. Has shape - ``[V, N, *module.bias.shape]`` when ``sum_batch`` is ``False``. With - ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. + ``[V, N, *module.bias.shape]`` when ``sum_batch`` is ``False``. With + ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. If sub- + sampling is used, ``N`` is replaced by ``len(subsampling)``. """ equation = f"vn...o->v{'' if sum_batch else 'n'}o" - return einsum(equation, mat) # TODO Remove after deprecating torch<1.9.0 diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index e74e0bf3e..d4475f2a7 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -245,7 +245,7 @@ def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): ids=["save_memory=True", "save_memory=False"], ) def test_weight_jac_t_mat_prod( - problem_weight_jac_t_mat, + problem_weight_jac_t_mat: Tuple[DerivativesTestProblem, List[int], Tensor], sum_batch: bool, save_memory: bool, ) -> None: @@ -266,6 +266,7 @@ def test_weight_jac_t_mat_prod( autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod( mat, sum_batch, subsampling=subsampling ) + check_sizes_and_values(autograd_res, backpack_res, rtol=5e-5) @@ -327,29 +328,27 @@ def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> Non @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) -@pytest.mark.parametrize( - "problem", - PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS, - ids=IDS_WITH_BIAS + BATCH_NORM_IDS, -) def test_bias_jac_t_mat_prod( - problem: DerivativesTestProblem, sum_batch: bool, V: int = 3 + problem_bias_jac_t_mat: Tuple[DerivativesTestProblem, List[int], Tensor], + sum_batch: bool, ) -> None: """Test the transposed Jacobian-matrix product w.r.t. to the bias. Args: - problem: Test case. + problem_bias_jac_t_mat: Instantiated test case, subsampling, and + input for bias_jac_t sum_batch: Sum out the batch dimension. - V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ - problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + problem, subsampling, mat = problem_bias_jac_t_mat - backpack_res = BackpackDerivatives(problem).bias_jac_t_mat_prod(mat, sum_batch) - autograd_res = AutogradDerivatives(problem).bias_jac_t_mat_prod(mat, sum_batch) + backpack_res = BackpackDerivatives(problem).bias_jac_t_mat_prod( + mat, sum_batch, subsampling=subsampling + ) + autograd_res = AutogradDerivatives(problem).bias_jac_t_mat_prod( + mat, sum_batch, subsampling=subsampling + ) check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() @pytest.mark.parametrize( @@ -524,7 +523,7 @@ def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a weight parameter. """ - has_weight = hasattr(problem.module, "weight") and problem.module.weight is not None + has_weight = getattr(problem.module, "weight", None) is not None if has_weight: yield problem else: @@ -563,6 +562,55 @@ def problem_weight_jac_t_mat( del mat +@fixture +def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: + """Filter out cases that don't have a bias parameter. + + Args: + problem: Test case with deterministically constructed attributes. + + Yields: + Instantiated cases that have a bias parameter. + """ + has_bias = getattr(problem.module, "bias", None) is not None + if has_bias: + yield problem + else: + skip("Test case has no bias parameter.") + + +@fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def problem_bias_jac_t_mat( + request, problem_bias: DerivativesTestProblem +) -> Tuple[DerivativesTestProblem, Union[None, List[int]], Tensor]: + """Create matrix that will be multiplied by the bias Jacobian. + + Skip if there is a conflict where the subsampling indices exceed the number of + samples in the input. + + Args: + request (SubRequest): Request for the fixture from a test/fixture function. + problem_bias: Test case with bias parameter. + + Yields: + problem with bias, subsampling, matrix for bias_jac_t + """ + subsampling: Union[None, List[int]] = request.param + N = problem_bias.input_shape[0] + enough_samples = subsampling is None or N >= max(subsampling) + + if not enough_samples: + skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + + V = 3 + mat = rand_mat_like_output( + V, problem_bias.output_shape, subsampling=subsampling + ).to(problem_bias.device) + + yield (problem_bias, subsampling, mat) + del mat + + @fixture def small_input_problem( problem: DerivativesTestProblem, max_input_numel: int = 100 diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index f19941236..d17f8a1e3 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -47,8 +47,10 @@ def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 "weight", mat, sum_batch, subsampling=subsampling ) - def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("bias", mat, sum_batch) + def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "bias", mat, sum_batch, subsampling=subsampling + ) def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 907be6015..2662c6847 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -47,10 +47,15 @@ def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 subsampling=subsampling, ) - def bias_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) def weight_jac_mat_prod(self, mat): # noqa: D102 diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 002268b89..b6eedb421 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -59,12 +59,15 @@ def weight_jac_t_mat_prod( raise NotImplementedError @abstractmethod - def bias_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def bias_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product From cbee344c1fa5fec91cab0ccd102ded81bdb19ceb Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Tue, 6 Jul 2021 19:57:06 +0200 Subject: [PATCH 31/54] [ADD] Application of RNN/LSTM `param_jac_t` for a subset of samples (#197) Prepares the `core` functionality to support #12. * [ADD] Extend `weight_jac_t` by subsampling argument * [ADD] Sub-sampling for batch normalization, adapt interface * [REF] Separate fixtures for filtering weights and matrix generation * [REF] Shorten instantiation fixture name * [TEST] Add BN setting to `weight_jac_t` tests * [DOC] Add batch norm test setting to fully-documented * [ADD] Support subsampling in `bias_jac_t` * [ADD] Support sub-sampling in RNN/LSTM `param_weight_jac_t` * [FIX] flake8 * [FMT] Squeeze some lines * [FIX] Typo in exception message * [DOC] Correct shapes Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 60 +++++-- backpack/core/derivatives/lstm.py | 55 ++++-- backpack/core/derivatives/rnn.py | 41 +++-- backpack/core/derivatives/shape_check.py | 6 +- backpack/utils/subsampling.py | 25 ++- test/core/derivatives/derivatives_test.py | 164 +++++++++++------- .../derivatives/implementation/autograd.py | 28 ++- .../derivatives/implementation/backpack.py | 40 ++++- test/core/derivatives/implementation/base.py | 20 ++- .../secondorder/diag_ggn/diagggn_settings.py | 18 -- 10 files changed, 313 insertions(+), 144 deletions(-) delete mode 100644 test/extensions/secondorder/diag_ggn/diagggn_settings.py diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 524a72fd7..35ca27d30 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -450,6 +450,7 @@ def bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix. @@ -458,16 +459,21 @@ def bias_ih_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H] if `sum_batch == False`. - Has shape [V, T, H] if `sum_batch == True`. + Has shape [V, N, *module.bias_ih_l0.shape] if ``sum_batch == False``; but if + used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.bias_ih_l0.shape] if + ``sum_batch == True``. """ return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _bias_ih_l0_jac_t_mat_prod( @@ -477,6 +483,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -489,6 +496,7 @@ def bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix. @@ -497,16 +505,21 @@ def bias_hh_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H] if `sum_batch == False`. - Has shape [V, T, H] if `sum_batch == True`. + Has shape [V, N, *module.bias_hh_l0.shape] if ``sum_batch == False``; but if + used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.bias_hh_l0.shape] if + ``sum_batch == True``. """ return self._bias_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _bias_hh_l0_jac_t_mat_prod( @@ -516,6 +529,7 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -528,6 +542,7 @@ def weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix. @@ -536,16 +551,21 @@ def weight_ih_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H, I] if `sum_batch == False`. - Has shape [V, T, H, I] if `sum_batch == True`. + Has shape [V, N, *module.weight_ih_l0.shape] if ``sum_batch == False``; but + if used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.weight_ih_l0.shape] if + ``sum_batch == True``. """ return self._weight_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -555,6 +575,7 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -567,6 +588,7 @@ def weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix. @@ -575,16 +597,21 @@ def weight_hh_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H, I] if `sum_batch == False`. - Has shape [V, T, H, I] if `sum_batch == True`. + Has shape [V, N, *module.weight_hh_l0.shape] if ``sum_batch == False``; but + if used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.weight_hh_l0.shape] if + ``sum_batch == True``. """ return self._weight_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_hh_l0_jac_t_mat_prod( @@ -594,6 +621,7 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 1360103e6..8eb88386d 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -1,11 +1,12 @@ """Partial derivatives for nn.LSTM.""" -from typing import Tuple +from typing import List, Tuple from torch import Tensor, cat, einsum, sigmoid, tanh, zeros from torch.nn import LSTM from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import TORCH_VERSION, VERSION_1_8_0 +from backpack.utils.subsampling import get_batch_axis, subsample class LSTMDerivatives(BaseParameterDerivatives): @@ -54,7 +55,7 @@ def _check_parameters(module: LSTM) -> None: @staticmethod def _forward_pass( - module: LSTM, mat: Tensor + module: LSTM, mat: Tensor, subsampling: List[int] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """This performs an additional forward pass and returns the hidden variables. @@ -65,7 +66,8 @@ def _forward_pass( Args: module: module - mat: matrix, used to extract device and shapes + mat: matrix, used to extract device and shapes. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: ifgo, c, c_tanh, h @@ -83,16 +85,19 @@ def _forward_pass( c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + + N_axis = get_batch_axis(module) + input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling) + output = subsample(module.output, dim=N_axis, subsampling=subsampling) + for t in range(T): ifgo[t] = ( - einsum("hi,ni->nh", module.weight_ih_l0, module.input0[t]) + einsum("hi,ni->nh", module.weight_ih_l0, input0[t]) + module.bias_ih_l0 + module.bias_hh_l0 ) if t != 0: - ifgo[t] += einsum( - "hg,ng->nh", module.weight_hh_l0, module.output[t - 1] - ) + ifgo[t] += einsum("hg,ng->nh", module.weight_hh_l0, output[t - 1]) ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1]) ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2]) ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3]) @@ -106,7 +111,9 @@ def _forward_pass( return ifgo, c, c_tanh, h @classmethod - def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor: + def _ifgo_jac_t_mat_prod( + cls, module: LSTM, mat: Tensor, subsampling: List[int] = None + ) -> Tensor: V: int = mat.shape[0] T: int = mat.shape[1] N: int = mat.shape[2] @@ -117,7 +124,7 @@ def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor: H3: int = 3 * H H4: int = 4 * H - ifgo, c, c_tanh, _ = cls._forward_pass(module, mat) + ifgo, c, c_tanh, _ = cls._forward_pass(module, mat, subsampling=subsampling) # backward pass H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) @@ -288,10 +295,13 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod) @@ -302,9 +312,10 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -314,13 +325,20 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum( - f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, module.input0 + f"vtnh,tni->v{'' if sum_batch else 'n'}hi", + IFGO_prod, + subsample( + module.input0, dim=get_batch_axis(module), subsampling=subsampling + ), ) def _weight_hh_l0_jac_t_mat_prod( @@ -330,13 +348,16 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) N: int = mat.shape[2] H: int = module.hidden_size - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum( f"vtnh,tng->v{'' if sum_batch else 'n'}hg", @@ -344,7 +365,11 @@ def _weight_hh_l0_jac_t_mat_prod( cat( [ zeros(1, N, H, device=mat.device, dtype=mat.dtype), - module.output[0:-1], + subsample( + module.output, + dim=get_batch_axis(module), + subsampling=subsampling, + )[0:-1], ], dim=0, ), diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 434b987b9..2277c0380 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -6,6 +6,7 @@ from torch.nn import RNN from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils.subsampling import get_batch_axis, subsample class RNNDerivatives(BaseParameterDerivatives): @@ -134,6 +135,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_ih_l0. @@ -143,6 +145,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product @@ -152,9 +155,13 @@ def _bias_ih_l0_jac_t_mat_prod( dim: List[int] = [1, 2] else: dim: int = 1 - return self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat).sum( - dim=dim - ) + return self._a_jac_t_mat_prod( + subsample( + module.output, dim=get_batch_axis(module), subsampling=subsampling + ), + module.weight_hh_l0, + mat, + ).sum(dim=dim) def _bias_hh_l0_jac_t_mat_prod( self, @@ -163,6 +170,7 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_hh_l0. @@ -172,12 +180,13 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product """ return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -187,6 +196,7 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_ih_l0. @@ -196,15 +206,21 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product """ self._check_parameters(module) + N_axis = get_batch_axis(module) return einsum( "vtnh,tnj->" + ("vhj" if sum_batch else "vnhj"), - self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), - module.input0, + self._a_jac_t_mat_prod( + subsample(module.output, dim=N_axis, subsampling=subsampling), + module.weight_hh_l0, + mat, + ), + subsample(module.input0, dim=N_axis, subsampling=subsampling), ) def _weight_hh_l0_jac_t_mat_prod( @@ -214,6 +230,7 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_hh_l0. @@ -223,21 +240,21 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product """ self._check_parameters(module) - N: int = mat.shape[2] + N_axis = get_batch_axis(module) + N: int = mat.shape[N_axis + 1] H: int = mat.shape[3] + output = subsample(module.output, dim=N_axis, subsampling=subsampling) return einsum( "vtnh,tnk->" + ("vhk" if sum_batch else "vnhk"), - self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + self._a_jac_t_mat_prod(output, module.weight_hh_l0, mat), cat( - [ - zeros(1, N, H, device=mat.device, dtype=mat.dtype), - module.output[0:-1], - ], + [zeros(1, N, H, device=mat.device, dtype=mat.dtype), output[0:-1]], dim=0, ), ) diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index a9ba070c8..c5bdc6e2b 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.nn import Module -from backpack.utils.subsampling import subsample +from backpack.utils.subsampling import get_batch_axis, subsample ############################################################################### @@ -62,7 +62,9 @@ def _check_same_V_dim(mat1, mat2): def _check_like(mat, module, name, diff=1, *args, **kwargs): if name == "output" and "subsampling" in kwargs.keys(): - compare = subsample(module.output, subsampling=kwargs["subsampling"]) + compare = subsample( + module.output, dim=get_batch_axis(module), subsampling=kwargs["subsampling"] + ) else: compare = getattr(module, name) diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index f429d580b..68a02c34f 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -2,6 +2,7 @@ from typing import List from torch import Tensor +from torch.nn import LSTM, RNN, Module def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: @@ -17,12 +18,32 @@ def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Te Tensor of same rank that is sub-sampled along the dimension. Raises: - NotImplementedError: If dimension differs from ``0``. + NotImplementedError: If dimension differs from ``0`` or ``1``. """ if subsampling is None: return tensor else: if dim == 0: return tensor[subsampling] + elif dim == 1: + return tensor[:, subsampling] else: - raise NotImplementedError(f"Only supports dim = 0. Got {dim}.") + raise NotImplementedError(f"Only supports dim = 0,1. Got {dim}.") + + +def get_batch_axis(module: Module) -> int: + """Return the batch axis assumed by the module. + + Args: + module: A module. + + Returns: + Batch axis + """ + if isinstance(module, (RNN, LSTM)): + if module.batch_first: + return 0 + else: + return 1 + else: + return 0 diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index d4475f2a7..3478c6e64 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -23,9 +23,10 @@ import pytest import torch from pytest import fixture, skip -from torch import Size, Tensor +from torch import Tensor from backpack.core.derivatives.convnd import weight_jac_t_save_memory +from backpack.utils.subsampling import get_batch_axis PROBLEMS = make_test_problems(SETTINGS) IDS = [problem.make_id() for problem in PROBLEMS] @@ -124,112 +125,144 @@ def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> IDS_WITH_WEIGHTS.append(problem_id) +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_bias_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_bias_ih_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to bias_ih_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_bias_hh_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_bias_hh_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to bias_hh_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_weight_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_weight_ih_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to weight_ih_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): +def test_weight_hh_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to weight_hh_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) @@ -271,7 +304,7 @@ def test_weight_jac_t_mat_prod( def rand_mat_like_output( - V: int, output_shape: Size, subsampling: List[int] = None + V: int, problem: DerivativesTestProblem, subsampling: List[int] = None ) -> Tensor: """Generate random matrix whose columns are shaped like the layer output. @@ -280,16 +313,16 @@ def rand_mat_like_output( Args: V: Number of rows. - output_shape: Shape of the module output. + problem: Test case. subsampling: Indices of samples used by sub-sampling. Returns: Random matrix with (subsampled) output shape. """ - subsample_shape = list(output_shape) + subsample_shape = list(problem.output_shape) if subsampling is not None: - N_axis = 0 + N_axis = get_batch_axis(problem.module) subsample_shape[N_axis] = len(subsampling) return torch.rand(V, *subsample_shape) @@ -523,11 +556,8 @@ def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a weight parameter. """ - has_weight = getattr(problem.module, "weight", None) is not None - if has_weight: - yield problem - else: - skip("Test case has no weight parameter.") + _skip_if_no_param(problem, "weight") + yield problem @fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @@ -547,21 +577,44 @@ def problem_weight_jac_t_mat( problem with weight, subsampling, matrix for weight_jac_t """ subsampling: Union[None, List[int]] = request.param - N = problem_weight.input_shape[0] - enough_samples = subsampling is None or N >= max(subsampling) - - if not enough_samples: - skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + _skip_if_subsampling_conflict(problem_weight, subsampling) V = 3 - mat = rand_mat_like_output( - V, problem_weight.output_shape, subsampling=subsampling - ).to(problem_weight.device) + mat = rand_mat_like_output(V, problem_weight, subsampling=subsampling).to( + problem_weight.device + ) yield (problem_weight, subsampling, mat) del mat +def _skip_if_subsampling_conflict( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip if some samples in subsampling are not contained in input. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + N = problem.input_shape[get_batch_axis(problem.module)] + enough_samples = subsampling is None or N >= max(subsampling) + if not enough_samples: + skip("Not enough samples.") + + +def _skip_if_no_param(problem: DerivativesTestProblem, param_str: str) -> None: + """Skip if test case does not contain the parameter. + + Args: + problem: Test case. + param_str: Parameter name. + """ + has_param = getattr(problem.module, param_str, None) is not None + if not has_param: + skip(f"Test case has no {param_str} parameter.") + + @fixture def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: """Filter out cases that don't have a bias parameter. @@ -572,11 +625,8 @@ def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a bias parameter. """ - has_bias = getattr(problem.module, "bias", None) is not None - if has_bias: - yield problem - else: - skip("Test case has no bias parameter.") + _skip_if_no_param(problem, "bias") + yield problem @fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @@ -596,16 +646,12 @@ def problem_bias_jac_t_mat( problem with bias, subsampling, matrix for bias_jac_t """ subsampling: Union[None, List[int]] = request.param - N = problem_bias.input_shape[0] - enough_samples = subsampling is None or N >= max(subsampling) - - if not enough_samples: - skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + _skip_if_subsampling_conflict(problem_bias, subsampling) V = 3 - mat = rand_mat_like_output( - V, problem_bias.output_shape, subsampling=subsampling - ).to(problem_bias.device) + mat = rand_mat_like_output(V, problem_bias, subsampling=subsampling).to( + problem_bias.device + ) yield (problem_bias, subsampling, mat) del mat diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index d17f8a1e3..68060e990 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -52,17 +52,29 @@ def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 "bias", mat, sum_batch, subsampling=subsampling ) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("weight_ih_l0", mat, sum_batch, axis_batch=1) + def weight_ih_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 + return self.param_jac_t_mat_prod( + "weight_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("weight_hh_l0", mat, sum_batch, axis_batch=1) + def weight_hh_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 + return self.param_jac_t_mat_prod( + "weight_hh_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) def param_jac_t_vec_prod( self, diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 2662c6847..cced2fc93 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -70,28 +70,52 @@ def bias_jac_mat_prod(self, mat): # noqa: D102 self.problem.module, None, None, mat ) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_ih_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_hh_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def weight_ih_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_ih_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def weight_hh_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_hh_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index b6eedb421..f64115357 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -99,12 +99,15 @@ def bias_jac_mat_prod(self, mat: Tensor) -> Tensor: raise NotImplementedError @abstractmethod - def bias_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def bias_ih_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -112,12 +115,15 @@ def bias_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def bias_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def bias_hh_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -125,12 +131,15 @@ def bias_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def weight_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def weight_ih_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -138,12 +147,15 @@ def weight_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def weight_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def weight_hh_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product diff --git a/test/extensions/secondorder/diag_ggn/diagggn_settings.py b/test/extensions/secondorder/diag_ggn/diagggn_settings.py deleted file mode 100644 index d0d806671..000000000 --- a/test/extensions/secondorder/diag_ggn/diagggn_settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Test cases for BackPACK extensions for the GGN diagonal. - -Includes -- ``DiagGGNExact`` -- ``DiagGGNMC`` -- ``BatchDiagGGNExact`` -- ``BatchDiagGGNMC`` - -Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. -Additional local cases can be defined here through ``LOCAL_SETTINGS``. -""" - -from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS - -SHARED_SETTINGS = SECONDORDER_SETTINGS -LOCAL_SETTINGS = [] - -DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS From f3296513b81244ae7b5ff06e37bf86c3f79adf6e Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 7 Jul 2021 09:34:12 +0200 Subject: [PATCH 32/54] [TEST] Reduce run time (#199) - make automated settings' last linear layers non-differentiable to save time - update new test suite, as it assumes all model parameters are differentiable - update `ggn_vector_product` to work with models that contain non-differentiable parameters and fully-document `ggnvp.py` --- * [DOC] Fully-document GGNVP and ignore non-differentiable parameters * [TEST] Make automated settings' last linear layers non-differentiable The final layers can have large that make the computation of second-order quantities expensive. Disabling their `requires_grad` speeds up the tests. * [TEST] Check non-differentiable parameters while collecting results * [DOC] Fully document automated test setting helpers --- backpack/hessianfree/ggnvp.py | 80 ++++---- fully_documented.txt | 3 + test/extensions/automated_settings.py | 179 ++++++++++-------- test/extensions/implementation/autograd.py | 17 +- test/extensions/implementation/backpack.py | 59 ++---- test/extensions/problem.py | 38 ++++ .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 4 +- 7 files changed, 213 insertions(+), 167 deletions(-) diff --git a/backpack/hessianfree/ggnvp.py b/backpack/hessianfree/ggnvp.py index 92c083636..165aff253 100644 --- a/backpack/hessianfree/ggnvp.py +++ b/backpack/hessianfree/ggnvp.py @@ -1,42 +1,56 @@ -from .hvp import hessian_vector_product -from .lop import L_op -from .rop import R_op - - -def ggn_vector_product(loss, output, model, v): +"""Autodiff-only matrix-free multiplication by the generalized Gauss-Newton/Fisher.""" +from typing import List, Tuple + +from torch import Tensor +from torch.nn import Module +from torch.nn.parameter import Parameter + +from backpack.hessianfree.hvp import hessian_vector_product +from backpack.hessianfree.lop import L_op +from backpack.hessianfree.rop import R_op + + +def ggn_vector_product( + loss: Tensor, output: Tensor, model: Module, v: List[Tensor] +) -> Tuple[Tensor]: + """Multiply a vector with the generalized Gauss-Newton/Fisher. + + Note: + ``G v = J.T @ H @ J @ v`` where ``J`` is the Jacobian of ``output`` w.r.t. + ``model``'s trainable parameters and `H` is the Hessian of `loss` w.r.t. + ``output``. ``v`` is the flattened and concatenated version of the passed + list of vectors. + + Args: + loss: Scalar tensor that represents the loss. + output: Model output. + model: The model. + v: Vector specified as list of tensors matching the trainable parameters. + + Returns: + GGN-vector product in list format, i.e. as list that matches the sizes + of trainable model parameters. """ - Multiplies the vector `v` with the Generalized Gauss-Newton, - `ggn_v = J.T @ H @ J @ v` - - where `J` is the Jacobian of `output` w.r.t. `model.parameters()` - and `H` is the Hessian of `loss` w.r.t. `output`. + return ggn_vector_product_from_plist( + loss, output, [p for p in model.parameters() if p.requires_grad], v + ) - Example usage: - ``` - X, Y = data() - model = torch.nn.Linear(784, 10) - lossfunc = torch.nn.CrossEntropyLoss() - output = model(X) - loss = lossfunc(output, Y) +def ggn_vector_product_from_plist( + loss: Tensor, output: Tensor, plist: List[Parameter], v: List[Tensor] +) -> Tuple[Tensor]: + """Multiply a vector with a sub-block of the generalized Gauss-Newton/Fisher. - v = list([torch.randn_like(p) for p in model.parameters]) + Args: + loss: Scalar tensor that represents the loss. + output: Model output. + plist: List of trainable parameters whose GGN block is used for multiplication. + v: Vector specified as list of tensors matching the sizes of ``plist``. - GGNv = ggn_vector_product(loss, output, model, v) - ``` - - Parameters: - ----------- - loss: torch.Tensor - output: torch.Tensor - model: torch.nn.Module - v: [torch.Tensor] - List of tensors matching the sizes of model.parameters() + Returns: + GGN-vector product in list format, i.e. as list that matches the sizes of + ``plist``. """ - return ggn_vector_product_from_plist(loss, output, list(model.parameters()), v) - - -def ggn_vector_product_from_plist(loss, output, plist, v): Jv = R_op(output, plist, v) HJv = hessian_vector_product(loss, output, Jv) JTHJv = L_op(output, plist, HJv) diff --git a/fully_documented.txt b/fully_documented.txt index 15528fdbb..f3542299e 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -39,10 +39,13 @@ backpack/extensions/secondorder/diag_hessian/conv2d.py backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/extensions/secondorder/sqrt_ggn/ +backpack/hessianfree/ggnvp.py + backpack/utils/linear.py backpack/utils/subsampling.py backpack/utils/__init__.py +test/extensions/automated_settings.py test/extensions/problem.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py index 58fbec8ac..f2334c515 100644 --- a/test/extensions/automated_settings.py +++ b/test/extensions/automated_settings.py @@ -1,35 +1,45 @@ +"""Contains helpers to create CNN test cases.""" from test.core.derivatives.utils import classification_targets +from typing import Any, Tuple, Type -import torch +from torch import Tensor, rand +from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, Module, ReLU, Sequential -### -# Helpers -### +def set_requires_grad(model: Module, new_requires_grad: bool) -> None: + """Set the ``requires_grad`` attribute of the model parameters. -def make_simple_act_setting(act_cls, bias): + Args: + model: Network or layer. + new_requires_grad: New value for ``requires_grad``. """ - input: Activation function & Bias setting - return: simple CNN Network + for p in model.parameters(): + p.requires_grad = new_requires_grad - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different activation functions. - It is used to test `test.extensions`. + +def make_simple_act_setting(act_cls: Type[Module], bias: bool) -> dict: + """Create a simple CNN with activation as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + act_cls: Class of the activation function. + bias: Use bias in the convolution. + + Returns: + Dictionary representation of the simple CNN test case. """ - def make_simple_cnn(act_cls, bias): - return torch.nn.Sequential( - torch.nn.Conv2d(3, 2, 2, bias=bias), - act_cls(), - torch.nn.Flatten(), - torch.nn.Linear(72, 5), - ) + def _make_simple_cnn(act_cls: Type[Module], bias: bool) -> Sequential: + linear = Linear(72, 5) + set_requires_grad(linear, False) + + return Sequential(Conv2d(3, 2, 2, bias=bias), act_cls(), Flatten(), linear) dict_setting = { - "input_fn": lambda: torch.rand(3, 3, 7, 7), - "module_fn": lambda: make_simple_cnn(act_cls, bias), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "input_fn": lambda: rand(3, 3, 7, 7), + "module_fn": lambda: _make_simple_cnn(act_cls, bias), + "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn-act", } @@ -37,40 +47,37 @@ def make_simple_cnn(act_cls, bias): return dict_setting -def make_simple_cnn_setting(input_size, conv_class, conv_params): - """ - input_size: tuple of input size of (N*C*Image Size) - conv_class: convolutional class - conv_params: configurations for convolutional class - return: simple CNN Network - - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different convolutional layers. - It is used to test `test.extensions`. +def make_simple_cnn_setting( + input_size: Tuple[int], conv_cls: Type[Module], conv_params: Tuple[Any] +) -> dict: + """Create ReLU CNN with convolution hyperparameters as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + input_size: Input shape ``[N, C_in, ...]``. + conv_cls: Class of convolution layer. + conv_params: Convolution hyperparameters. + + Returns: + Dictionary representation of the test case. """ - def make_cnn(conv_class, output_size, conv_params): - """Note: output class size is assumed to be 5""" - return torch.nn.Sequential( - conv_class(*conv_params), - torch.nn.ReLU(), - torch.nn.Flatten(), - torch.nn.Linear(output_size, 5), - ) + def _make_cnn( + conv_cls: Type[Module], output_dim: int, conv_params: Tuple + ) -> Sequential: + linear = Linear(output_dim, 5) + set_requires_grad(linear, False) - def get_output_shape(module, module_params, input): - """Returns the output shape for a given layer.""" - output = module(*module_params)(input) - return output.numel() // output.shape[0] + return Sequential(conv_cls(*conv_params), ReLU(), Flatten(), linear) - input = torch.rand(input_size) - output_size = get_output_shape(conv_class, conv_params, input) + input = rand(input_size) + output_dim = _get_output_dim(conv_cls(*conv_params), input) dict_setting = { - "input_fn": lambda: torch.rand(input_size), - "module_fn": lambda: make_cnn(conv_class, output_size, conv_params), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(input_size), + "module_fn": lambda: _make_cnn(conv_cls, output_dim, conv_params), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn", } @@ -78,49 +85,59 @@ def get_output_shape(module, module_params, input): return dict_setting -def make_simple_pooling_setting(input_size, conv_class, pool_cls, pool_params): - """ - input_size: tuple of input size of (N*C*Image Size) - conv_class: convolutional class - conv_params: configurations for convolutional class - return: simple CNN Network - - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different convolutional layers. - It is used to test `test.extensions`. +def make_simple_pooling_setting( + input_size: Tuple[int], + conv_cls: Type[Module], + pool_cls: Type[Module], + pool_params: Tuple[Any], +) -> dict: + """Create CNN with convolution and pooling layer as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + input_size: Input shape ``[N, C_in, ...]``. + conv_cls: Class of convolution layer. + pool_cls: Class of pooling layer. + pool_params: Pooling hyperparameters. + + Returns: + Dictionary representation of the test case. """ - def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params): - """Note: output class size is assumed to be 5""" - return torch.nn.Sequential( - conv_class(*conv_params), - torch.nn.ReLU(), - pool_cls(*pool_params), - torch.nn.Flatten(), - torch.nn.Linear(output_size, 5), + def _make_cnn( + conv_cls: Type[Module], + output_size: int, + conv_params: Tuple[Any], + pool_cls: Type[Module], + pool_params: Tuple[Any], + ) -> Sequential: + linear = Linear(output_size, 5) + set_requires_grad(linear, False) + + return Sequential( + conv_cls(*conv_params), ReLU(), pool_cls(*pool_params), Flatten(), linear ) - def get_output_shape(module, module_params, input, pool, pool_params): - """Returns the output shape for a given layer.""" - output_1 = module(*module_params)(input) - output = pool_cls(*pool_params)(output_1) - return output.numel() // output.shape[0] - conv_params = (3, 2, 2) - input = torch.rand(input_size) - output_size = get_output_shape( - conv_class, conv_params, input, pool_cls, pool_params + input = rand(input_size) + output_dim = _get_output_dim( + Sequential(conv_cls(*conv_params), pool_cls(*pool_params)), input ) dict_setting = { - "input_fn": lambda: torch.rand(input_size), - "module_fn": lambda: make_cnn( - conv_class, output_size, conv_params, pool_cls, pool_params + "input_fn": lambda: rand(input_size), + "module_fn": lambda: _make_cnn( + conv_cls, output_dim, conv_params, pool_cls, pool_params ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn", } return dict_setting + + +def _get_output_dim(module: Module, input: Tensor) -> int: + output = module(input) + return output.numel() // output.shape[0] diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index c86350ebd..2edcf9f1e 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -18,14 +18,14 @@ def batch_grad(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] batch_grads = [ torch.zeros(N, *p.size()).to(self.problem.device) - for p in self.problem.model.parameters() + for p in self.problem.trainable_parameters() ] loss_list = torch.zeros((N)) gradients_list = [] for b in range(N): _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.model.parameters()) + gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) gradients_list.append(gradients) loss_list[b] = loss @@ -47,14 +47,14 @@ def sgs(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] sgs = [ torch.zeros(*p.size()).to(self.problem.device) - for p in self.problem.model.parameters() + for p in self.problem.trainable_parameters() ] loss_list = torch.zeros((N)) gradients_list = [] for b in range(N): _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.model.parameters()) + gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) loss_list[b] = loss gradients_list.append(gradients) @@ -81,7 +81,7 @@ def extract_ith_element_of_diag_ggn(i, p, loss, output): return GGN_v[i] diag_ggns = [] - for p in list(self.problem.model.parameters()): + for p in list(self.problem.trainable_parameters()): diag_ggn_p = torch.zeros_like(p).view(-1) for parameter_index in range(p.numel()): @@ -146,7 +146,7 @@ def extract_ith_element_of_diag_h(i, p, df_dx): return Hv[i] diag_hs = [] - for p in list(self.problem.model.parameters()): + for p in list(self.problem.trainable_parameters()): diag_h_p = torch.zeros_like(p).view(-1) df_dx = torch.autograd.grad(loss, [p], create_graph=True, retain_graph=True) @@ -179,8 +179,9 @@ def diag_h_batch(self) -> List[Tensor]: # noqa: D102 def ggn(self) -> Tensor: # noqa: D102 _, output, loss = self.problem.forward_pass() model = self.problem.model + params = list(self.problem.trainable_parameters()) - num_params = sum(p.numel() for p in model.parameters()) + num_params = sum(p.numel() for p in params) ggn = torch.zeros(num_params, num_params).to(self.problem.device) for i in range(num_params): @@ -189,7 +190,7 @@ def ggn(self) -> Tensor: # noqa: D102 e_i[i] = 1.0 # convert to model parameter shapes - e_i_list = vector_to_parameter_list(e_i, model.parameters()) + e_i_list = vector_to_parameter_list(e_i, params) ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) ggn_i = parameters_to_vector(ggn_i_list) diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index c8d4d90bf..3340654e6 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -30,15 +30,13 @@ def batch_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchGrad()): _, _, loss = self.problem.forward_pass() loss.backward() - batch_grads = [p.grad_batch for p in self.problem.model.parameters()] - return batch_grads + return self.problem.collect_data("grad_batch") def batch_l2_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchL2Grad()): _, _, loss = self.problem.forward_pass() loss.backward() - batch_l2_grad = [p.batch_l2 for p in self.problem.model.parameters()] - return batch_l2_grad + return self.problem.collect_data("batch_l2") def batch_l2_grad_extension_hook(self) -> List[Tensor]: """Individual gradient squared ℓ₂ norms via extension hook. @@ -50,15 +48,13 @@ def batch_l2_grad_extension_hook(self) -> List[Tensor]: with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() loss.backward() - batch_l2_grad = [p.batch_l2_hook for p in self.problem.model.parameters()] - return batch_l2_grad + return self.problem.collect_data("batch_l2_hook") def sgs(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.SumGradSquared()): _, _, loss = self.problem.forward_pass() loss.backward() - sgs = [p.sum_grad_squared for p in self.problem.model.parameters()] - return sgs + return self.problem.collect_data("sum_grad_squared") def sgs_extension_hook(self) -> List[Tensor]: """Individual gradient second moment via extension hook. @@ -70,47 +66,37 @@ def sgs_extension_hook(self) -> List[Tensor]: with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() loss.backward() - sgs = [p.sum_grad_squared_hook for p in self.problem.model.parameters()] - return sgs + return self.problem.collect_data("sum_grad_squared_hook") def variance(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.Variance()): _, _, loss = self.problem.forward_pass() loss.backward() - variances = [p.variance for p in self.problem.model.parameters()] - return variances + return self.problem.collect_data("variance") def diag_ggn(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn = [p.diag_ggn_exact for p in self.problem.model.parameters()] - return diag_ggn + return self.problem.collect_data("diag_ggn_exact") def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_exact_batch = [ - p.diag_ggn_exact_batch for p in self.problem.model.parameters() - ] - return diag_ggn_exact_batch + return self.problem.collect_data("diag_ggn_exact_batch") def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_mc = [p.diag_ggn_mc for p in self.problem.model.parameters()] - return diag_ggn_mc + return self.problem.collect_data("diag_ggn_mc") def diag_ggn_mc_batch(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_mc_batch = [ - p.diag_ggn_mc_batch for p in self.problem.model.parameters() - ] - return diag_ggn_mc_batch + return self.problem.collect_data("diag_ggn_mc_batch") def diag_ggn_mc_chunk(self, mc_samples: int, chunks: int = 10) -> List[Tensor]: """Like ``diag_ggn_mc``, but can handle more samples by chunking. @@ -198,40 +184,31 @@ def diag_h(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_h = [p.diag_h for p in self.problem.model.parameters()] - return diag_h + return self.problem.collect_data("diag_h") def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFAC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - kfac = [p.kfac for p in self.problem.model.parameters()] - - return kfac + return self.problem.collect_data("kfac") def kflr(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFLR()): _, _, loss = self.problem.forward_pass() loss.backward() - kflr = [p.kflr for p in self.problem.model.parameters()] - - return kflr + return self.problem.collect_data("kflr") def kfra(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFRA()): _, _, loss = self.problem.forward_pass() loss.backward() - kfra = [p.kfra for p in self.problem.model.parameters()] - - return kfra + return self.problem.collect_data("kfra") def diag_h_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()] - - return diag_h_batch + return self.problem.collect_data("diag_h_batch") def ggn(self) -> Tensor: # noqa:D102 return self._square_sqrt_ggn(self.sqrt_ggn()) @@ -245,8 +222,7 @@ def sqrt_ggn(self) -> List[Tensor]: with backpack(new_ext.SqrtGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - - return [p.sqrt_ggn_exact for p in self.problem.model.parameters()] + return self.problem.collect_data("sqrt_ggn_exact") def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: """Compute the approximate matrix square root of the generalized Gauss-Newton. @@ -260,8 +236,7 @@ def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: with backpack(new_ext.SqrtGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - - return [p.sqrt_ggn_mc for p in self.problem.model.parameters()] + return self.problem.collect_data("sqrt_ggn_mc") def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: # noqa:D102 samples = self.chunk_sizes(mc_samples, chunks) diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 05f25afe1..0d940cec6 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -2,8 +2,10 @@ import copy from test.core.derivatives.utils import get_available_devices +from typing import Any, Iterator, List import torch +from torch.nn.parameter import Parameter from backpack import extend @@ -191,3 +193,39 @@ def get_reduction_factor(self, loss, unreduced_loss): f"'mean': {mean_loss}, 'sum': {sum_loss}, loss: {loss}", ) return factor + + def trainable_parameters(self) -> Iterator[Parameter]: + """Yield the model's trainable parameters. + + Yields: + Model parameter with gradients enabled. + """ + for p in self.model.parameters(): + if p.requires_grad: + yield p + + def collect_data(self, savefield: str) -> List[Any]: + """Collect BackPACK attributes from trainable parameters. + + Args: + savefield: Attribute name. + + Returns: + List of attributes saved under the trainable model parameters. + + Raises: + RuntimeError: If a non-differentiable parameter with the attribute is + encountered. + """ + data = [] + + for p in self.model.parameters(): + if p.requires_grad: + data.append(getattr(p, savefield)) + else: + if hasattr(p, savefield): + raise RuntimeError( + f"Found non-differentiable parameter with attribute '{savefield}'." + ) + + return data diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py index ba3a0436c..75b1dbee4 100644 --- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -41,9 +41,7 @@ def small_problem( Yields: Instantiated test case whose model's are small enough. """ - num_params = sum( - p.numel() for p in instantiated_problem.model.parameters() if p.requires_grad - ) + num_params = sum(p.numel() for p in instantiated_problem.trainable_parameters()) if num_params <= max_num_params: yield instantiated_problem else: From 69a83a45aaa2530c759f226f8b463d8821d8ad05 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 7 Jul 2021 19:36:26 +0200 Subject: [PATCH 33/54] [ADD] Sub-sampling in `BatchGrad` extension (#200) Extends BackPACK's individual gradient extension and allows specifying the mini-batch samples for which the quantity will be extracted (#12), rather than evaluating individual gradients for every sample in the mini-batch. Auxiliary: - Improve type annotations and code of abstract module extension class - Support sub-sampling in the test suite and reduce duplication - Rewrite `BatchGrad` tests using `pytest fixtures` --- * [ADD] Support sub-sampling in ``BatchGrad`` extension * [REF] Reduce duplication, remove redundant function calls * [REF] Reduce duplication of GGNVPs * [FIX] flake8 Co-authored-by: Felix Dangel --- .../firstorder/batch_grad/__init__.py | 21 ++- .../firstorder/batch_grad/batch_grad_base.py | 66 ++++++--- fully_documented.txt | 4 +- ...rad_settings.py => batch_grad_settings.py} | 6 +- .../firstorder/batch_grad/test_batch_grad.py | 51 +++++++ .../firstorder/batch_grad/test_batchgrad.py | 34 ----- test/extensions/implementation/autograd.py | 137 +++++++----------- test/extensions/implementation/backpack.py | 4 +- test/extensions/implementation/base.py | 10 +- test/extensions/utils.py | 23 +++ 10 files changed, 207 insertions(+), 149 deletions(-) rename test/extensions/firstorder/batch_grad/{batchgrad_settings.py => batch_grad_settings.py} (54%) create mode 100644 test/extensions/firstorder/batch_grad/test_batch_grad.py delete mode 100644 test/extensions/firstorder/batch_grad/test_batchgrad.py create mode 100644 test/extensions/utils.py diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index bee9c069a..897170860 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -2,6 +2,8 @@ It defines the module extension for each module. """ +from typing import List, Union + from torch.nn import ( RNN, BatchNorm1d, @@ -35,6 +37,9 @@ class BatchGrad(BackpropExtension): Stores the output in ``grad_batch`` as a ``[N x ...]`` tensor, where ``N`` batch size and ``...`` is the shape of the gradient. + If ``subsampling`` is specified, ``N`` is replaced by the number of active + samples. + .. note:: Beware of scaling issue @@ -50,10 +55,14 @@ class BatchGrad(BackpropExtension): objective is a sum of independent functions (no batchnorm). """ - def __init__(self): + def __init__(self, subsampling: List[int] = None): """Initialization. Defines extension for each module. + + Args: + subsampling: Indices of samples in the mini-batch for which individual + gradients will be computed. Defaults to ``None`` (use all samples). """ super().__init__( savefield="grad_batch", @@ -70,3 +79,13 @@ def __init__(self): RNN: rnn.BatchGradRNN(), }, ) + self._subsampling = subsampling + + def get_subsampling(self) -> Union[List[int], None]: + """Get the indices of samples for which individual gradients are requested. + + Returns: + List of indices containing the active samples in the mini-batch. ``None`` + means all samples will be considered. + """ + return self._subsampling diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index d6bd89822..2ba72cce4 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -1,5 +1,17 @@ """Calculates the batch_grad derivative.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Tuple + +from torch import Tensor +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils.subsampling import get_batch_axis, subsample + +if TYPE_CHECKING: + from backpack.extensions.firstorder import BatchGrad class BatchGradBase(FirstOrderModuleExtension): @@ -18,48 +30,64 @@ class BatchGradBase(FirstOrderModuleExtension): In this case, the method is not overwritten by this class. """ - def __init__(self, derivatives, params): + def __init__( + self, derivatives: BaseParameterDerivatives, params: List[str] + ) -> None: """Initializes all methods. If the param method has already been defined, it is left unchanged. Args: - derivatives(backpack.core.derivatives.basederivatives.BaseParameterDerivatives): # noqa: B950 - Derivatives object assigned to self.derivatives. - params (list[str]): list of strings with parameter names. - For each, a method is assigned. + derivatives: Derivatives object used to apply parameter Jacobians. + params: List of parameter names. """ - self.derivatives = derivatives + self._derivatives = derivatives for param_str in params: if not hasattr(self, param_str): setattr(self, param_str, self._make_param_function(param_str)) super().__init__(params=params) - def _make_param_function(self, param): - """Creates a function that calculates batch_grad wrt param. + def _make_param_function( + self, param_str: str + ) -> Callable[[BatchGrad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: + """Creates a function that calculates batch_grad w.r.t. param. Args: - param(str): name of parameter + param_str: Parameter name. Returns: - function: function that calculates batch_grad wrt param + Function that calculates batch_grad wrt param """ - def param_function(ext, module, g_inp, g_out, bpQuantities): + def param_function( + ext: BatchGrad, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: None, + ) -> Tensor: """Calculates batch_grad with the help of derivatives object. Args: - ext(backpack.extensions.BatchGrad): extension that is used - module(torch.nn.Module): module that performed forward pass - g_inp(tuple[torch.Tensor]): input gradient tensors - g_out(tuple[torch.Tensor]): output gradient tensors - bpQuantities(None): additional quantities for second order + ext: extension that is used + module: module that performed forward pass + g_inp: input gradient tensors + g_out: output gradient tensors + bpQuantities: additional quantities for second order Returns: - torch.Tensor: scaled individual gradients + Scaled individual gradients """ - return getattr(self.derivatives, f"{param}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=False + subsampling = ext.get_subsampling() + return getattr(self._derivatives, f"{param_str}_jac_t_mat_prod")( + module, + g_inp, + g_out, + subsample( + g_out[0], dim=get_batch_axis(module), subsampling=subsampling + ), + sum_batch=False, + subsampling=subsampling, ) return param_function diff --git a/fully_documented.txt b/fully_documented.txt index f3542299e..d6e752a97 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -47,10 +47,12 @@ backpack/utils/__init__.py test/extensions/automated_settings.py test/extensions/problem.py +test/extensions/utils.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py test/extensions/firstorder/variance/ -test/extensions/firstorder/batch_grad/batchgrad_settings.py +test/extensions/firstorder/batch_grad/batch_grad_settings.py +test/extensions/firstorder/batch_grad/test_batch_grad.py test/extensions/secondorder/secondorder_settings.py test/extensions/secondorder/diag_ggn/ test/extensions/secondorder/hbp/ diff --git a/test/extensions/firstorder/batch_grad/batchgrad_settings.py b/test/extensions/firstorder/batch_grad/batch_grad_settings.py similarity index 54% rename from test/extensions/firstorder/batch_grad/batchgrad_settings.py rename to test/extensions/firstorder/batch_grad/batch_grad_settings.py index 3920b5b5a..7b1926d63 100644 --- a/test/extensions/firstorder/batch_grad/batchgrad_settings.py +++ b/test/extensions/firstorder/batch_grad/batch_grad_settings.py @@ -1,6 +1,6 @@ -"""Test configurations to test batch_grad. +"""Test cases for BackPACK's ``BatchGrad`` extension. -The tests are taken from `test.extensions.firstorder.firstorder_settings`, +The tests are taken from ``test.extensions.firstorder.firstorder_settings``, but additional custom tests can be defined here by appending it to the list. """ from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS @@ -8,4 +8,4 @@ SHARED_SETTINGS = FIRSTORDER_SETTINGS LOCAL_SETTINGS = [] -BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS +BATCH_GRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/firstorder/batch_grad/test_batch_grad.py b/test/extensions/firstorder/batch_grad/test_batch_grad.py new file mode 100644 index 000000000..7c916568e --- /dev/null +++ b/test/extensions/firstorder/batch_grad/test_batch_grad.py @@ -0,0 +1,51 @@ +"""Test BackPACK's ``BatchGrad`` extension.""" +from test.automated_test import check_sizes_and_values +from test.extensions.firstorder.batch_grad.batch_grad_settings import ( + BATCH_GRAD_SETTINGS, +) +from test.extensions.implementation.autograd import AutogradExtensions +from test.extensions.implementation.backpack import BackpackExtensions +from test.extensions.problem import ExtensionsTestProblem, make_test_problems +from test.extensions.utils import skip_if_subsampling_conflict +from typing import List, Union + +from pytest import fixture, mark + +PROBLEMS = make_test_problems(BATCH_GRAD_SETTINGS) + +SUBSAMPLINGS = [None, [0, 0], [2, 0]] +SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] + + +@fixture(params=PROBLEMS, ids=lambda p: p.make_id()) +def problem(request) -> ExtensionsTestProblem: + """Set up and tear down a test case. + + Args: + request: Pytest request. + + Yields: + Instantiated test case. + """ + problem = request.param + problem.set_up() + yield problem + problem.tear_down() + + +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def test_batch_grad( + problem: ExtensionsTestProblem, subsampling: Union[List[int], None] +) -> None: + """Test individual gradients. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + skip_if_subsampling_conflict(problem, subsampling) + + backpack_res = BackpackExtensions(problem).batch_grad(subsampling) + autograd_res = AutogradExtensions(problem).batch_grad(subsampling) + + check_sizes_and_values(autograd_res, backpack_res) diff --git a/test/extensions/firstorder/batch_grad/test_batchgrad.py b/test/extensions/firstorder/batch_grad/test_batchgrad.py deleted file mode 100644 index c93ae69df..000000000 --- a/test/extensions/firstorder/batch_grad/test_batchgrad.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Test class for module Batch_grad (batch gradients) -from `backpack.core.extensions.firstorder` - -Test individual gradients for the following layers: -- batch gradients of linear layers -- batch gradients of convolutional layers - -""" -from test.automated_test import check_sizes_and_values -from test.extensions.firstorder.batch_grad.batchgrad_settings import BATCHGRAD_SETTINGS -from test.extensions.implementation.autograd import AutogradExtensions -from test.extensions.implementation.backpack import BackpackExtensions -from test.extensions.problem import make_test_problems - -import pytest - -PROBLEMS = make_test_problems(BATCHGRAD_SETTINGS) -IDS = [problem.make_id() for problem in PROBLEMS] - - -@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_batch_grad(problem): - """Test individual gradients - - Args: - problem (ExtensionsTestProblem): Problem for extension test. - """ - problem.set_up() - - backpack_res = BackpackExtensions(problem).batch_grad() - autograd_res = AutogradExtensions(problem).batch_grad() - - check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index 2edcf9f1e..ccee35a0a 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -1,97 +1,68 @@ """Autograd implementation of BackPACK's extensions.""" from test.extensions.implementation.base import ExtensionsImplementation -from typing import List +from typing import Iterator, List, Union -import torch -from torch import Tensor +from torch import Tensor, autograd, backends, cat, stack, var, zeros, zeros_like from torch.nn.utils.convert_parameters import parameters_to_vector -from backpack.hessianfree.ggnvp import ggn_vector_product, ggn_vector_product_from_plist +from backpack.hessianfree.ggnvp import ggn_vector_product from backpack.hessianfree.rop import R_op from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.subsampling import get_batch_axis class AutogradExtensions(ExtensionsImplementation): """Extension implementations with autograd.""" - def batch_grad(self) -> List[Tensor]: # noqa: D102 - N = self.problem.input.shape[0] - batch_grads = [ - torch.zeros(N, *p.size()).to(self.problem.device) - for p in self.problem.trainable_parameters() - ] + def batch_grad( + self, subsampling: Union[List[int], None] + ) -> List[Tensor]: # noqa: D102 + N = self.problem.input.shape[get_batch_axis(self.problem.model)] + samples = list(range(N)) if subsampling is None else subsampling - loss_list = torch.zeros((N)) + loss_list = zeros(N) gradients_list = [] for b in range(N): _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) + gradients = autograd.grad(loss, self.problem.trainable_parameters()) gradients_list.append(gradients) loss_list[b] = loss _, _, batch_loss = self.problem.forward_pass() factor = self.problem.get_reduction_factor(batch_loss, loss_list) - for b, gradients in zip(range(N), gradients_list): - for idx, g in enumerate(gradients): - batch_grads[idx][b, :] = g.detach() * factor + batch_grads = [ + zeros(len(samples), *p.size()).to(self.problem.device) + for p in self.problem.trainable_parameters() + ] + + for out_idx, sample in enumerate(samples): + for param_idx, sample_g in enumerate(gradients_list[sample]): + batch_grads[param_idx][out_idx, :] = sample_g.detach() * factor return batch_grads def batch_l2_grad(self) -> List[Tensor]: # noqa: D102 - batch_grad = self.batch_grad() - batch_l2_grads = [(g ** 2).flatten(start_dim=1).sum(1) for g in batch_grad] - return batch_l2_grads - - def sgs(self) -> List[Tensor]: # noqa: D102 - N = self.problem.input.shape[0] - sgs = [ - torch.zeros(*p.size()).to(self.problem.device) - for p in self.problem.trainable_parameters() + return [ + (g ** 2).flatten(start_dim=1).sum(1) + for g in self.batch_grad(subsampling=None) ] - loss_list = torch.zeros((N)) - gradients_list = [] - for b in range(N): - _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) - loss_list[b] = loss - gradients_list.append(gradients) - - _, _, batch_loss = self.problem.forward_pass() - factor = self.problem.get_reduction_factor(batch_loss, loss_list) - - for _, gradients in zip(range(N), gradients_list): - for idx, g in enumerate(gradients): - sgs[idx] += (g.detach() * factor) ** 2 - return sgs + def sgs(self) -> List[Tensor]: # noqa: D102 + return [(g ** 2).sum(0) for g in self.batch_grad(subsampling=None)] def variance(self) -> List[Tensor]: # noqa: D102 - batch_grad = self.batch_grad() - variances = [torch.var(g, dim=0, unbiased=False) for g in batch_grad] - return variances - - def _get_diag_ggn(self, loss, output): - def extract_ith_element_of_diag_ggn(i, p, loss, output): - v = torch.zeros(p.numel()).to(self.problem.device) - v[i] = 1.0 - vs = vector_to_parameter_list(v, [p]) - GGN_vs = ggn_vector_product_from_plist(loss, output, [p], vs) - GGN_v = torch.cat([g.detach().view(-1) for g in GGN_vs]) - return GGN_v[i] - - diag_ggns = [] - for p in list(self.problem.trainable_parameters()): - diag_ggn_p = torch.zeros_like(p).view(-1) - - for parameter_index in range(p.numel()): - diag_value = extract_ith_element_of_diag_ggn( - parameter_index, p, loss, output - ) - diag_ggn_p[parameter_index] = diag_value + return [ + var(g, dim=0, unbiased=False) for g in self.batch_grad(subsampling=None) + ] - diag_ggns.append(diag_ggn_p.view(p.size())) - return diag_ggns + def _get_diag_ggn(self, loss: Tensor, output: Tensor) -> List[Tensor]: + diag_ggn_flat = cat( + [col[[i]] for i, col in enumerate(self._ggn_columns(loss, output))] + ) + return vector_to_parameter_list( + diag_ggn_flat, list(self.problem.trainable_parameters()) + ) def diag_ggn(self) -> List[Tensor]: # noqa: D102 try: @@ -100,7 +71,7 @@ def diag_ggn(self) -> List[Tensor]: # noqa: D102 except RuntimeError: # torch does not implement cuda double-backwards pass on RNNs and # recommends this workaround - with torch.backends.cudnn.flags(enabled=False): + with backends.cudnn.flags(enabled=False): _, output, loss = self.problem.forward_pass() return self._get_diag_ggn(loss, output) @@ -110,13 +81,13 @@ def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa: D102 except RuntimeError: # torch does not implement cuda double-backwards pass on RNNs and # recommends this workaround - with torch.backends.cudnn.flags(enabled=False): + with backends.cudnn.flags(enabled=False): return self._diag_ggn_exact_batch() def _diag_ggn_exact_batch(self): batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() - loss_list = torch.zeros(batch_size, device=self.problem.device) + loss_list = zeros(batch_size, device=self.problem.device) # batch_diag_ggn has entries [sample_idx][param_idx] batch_diag_ggn = [] @@ -128,28 +99,24 @@ def _diag_ggn_exact_batch(self): factor = self.problem.get_reduction_factor(batch_loss, loss_list) # params_batch_diag_ggn has entries [param_idx][sample_idx] params_batch_diag_ggn = list(zip(*batch_diag_ggn)) - return [torch.stack(param) * factor for param in params_batch_diag_ggn] + return [stack(param) * factor for param in params_batch_diag_ggn] def _get_diag_h(self, loss): - def hvp(df_dx, x, v): - Hv = R_op(df_dx, x, v) - return [j.detach() for j in Hv] - def extract_ith_element_of_diag_h(i, p, df_dx): - v = torch.zeros(p.numel()).to(self.problem.device) + v = zeros_like(p).flatten() v[i] = 1.0 vs = vector_to_parameter_list(v, [p]) - Hvs = hvp(df_dx, [p], vs) - Hv = torch.cat([g.detach().view(-1) for g in Hvs]) + Hvs = R_op(df_dx, [p], vs) + Hv = cat([g.flatten() for g in Hvs]) return Hv[i] diag_hs = [] for p in list(self.problem.trainable_parameters()): - diag_h_p = torch.zeros_like(p).view(-1) + diag_h_p = zeros_like(p).flatten() - df_dx = torch.autograd.grad(loss, [p], create_graph=True, retain_graph=True) + df_dx = autograd.grad(loss, [p], create_graph=True, retain_graph=True) for parameter_index in range(p.numel()): diag_value = extract_ith_element_of_diag_h(parameter_index, p, df_dx) diag_h_p[parameter_index] = diag_value @@ -164,7 +131,7 @@ def diag_h(self) -> List[Tensor]: # noqa: D102 def diag_h_batch(self) -> List[Tensor]: # noqa: D102 batch_size = self.problem.input.shape[0] _, _, batch_loss = self.problem.forward_pass() - loss_list = torch.zeros(batch_size, device=self.problem.device) + loss_list = zeros(batch_size, device=self.problem.device) batch_diag_h = [] for b in range(batch_size): @@ -174,29 +141,27 @@ def diag_h_batch(self) -> List[Tensor]: # noqa: D102 batch_diag_h.append(diag_h) factor = self.problem.get_reduction_factor(batch_loss, loss_list) params_batch_diag_h = list(zip(*batch_diag_h)) - return [torch.stack(param) * factor for param in params_batch_diag_h] + return [stack(param) * factor for param in params_batch_diag_h] def ggn(self) -> Tensor: # noqa: D102 _, output, loss = self.problem.forward_pass() - model = self.problem.model - params = list(self.problem.trainable_parameters()) + return stack(list(self._ggn_columns(loss, output)), dim=1) + def _ggn_columns(self, loss: Tensor, output: Tensor) -> Iterator[Tensor]: + params = list(self.problem.trainable_parameters()) num_params = sum(p.numel() for p in params) - ggn = torch.zeros(num_params, num_params).to(self.problem.device) + model = self.problem.model for i in range(num_params): # GGN-vector product with i.th unit vector yields the i.th row - e_i = torch.zeros(num_params).to(self.problem.device) + e_i = zeros(num_params).to(self.problem.device) e_i[i] = 1.0 # convert to model parameter shapes e_i_list = vector_to_parameter_list(e_i, params) ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) - ggn_i = parameters_to_vector(ggn_i_list) - ggn[i, :] = ggn_i - - return ggn + yield parameters_to_vector(ggn_i_list) def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa: D102 raise NotImplementedError diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index 3340654e6..ff475913d 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -26,8 +26,8 @@ def __init__(self, problem: ExtensionsTestProblem): problem.extend() super().__init__(problem) - def batch_grad(self) -> List[Tensor]: # noqa:D102 - with backpack(new_ext.BatchGrad()): + def batch_grad(self, subsampling) -> List[Tensor]: # noqa:D102 + with backpack(new_ext.BatchGrad(subsampling=subsampling)): _, _, loss = self.problem.forward_pass() loss.backward() return self.problem.collect_data("grad_batch") diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index f3d73d91a..638c780df 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -1,7 +1,7 @@ """Base class containing the functions to compare BackPACK and autograd.""" from abc import ABC, abstractmethod from test.extensions.problem import ExtensionsTestProblem -from typing import List +from typing import List, Union from torch import Tensor @@ -18,8 +18,12 @@ def __init__(self, problem: ExtensionsTestProblem): self.problem = problem @abstractmethod - def batch_grad(self) -> List[Tensor]: - """Individual gradients.""" + def batch_grad(self, subsampling: Union[List[int], None]) -> List[Tensor]: + """Individual gradients. + + Args: + subsampling: List of active samples. ``None`` means all samples. + """ return @abstractmethod diff --git a/test/extensions/utils.py b/test/extensions/utils.py new file mode 100644 index 000000000..0965f55d0 --- /dev/null +++ b/test/extensions/utils.py @@ -0,0 +1,23 @@ +"""Utility functions for testing BackPACK's extensions.""" + +from test.extensions.problem import ExtensionsTestProblem +from typing import List, Union + +from pytest import skip + +from backpack.utils.subsampling import get_batch_axis + + +def skip_if_subsampling_conflict( + problem: ExtensionsTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip if some samples in subsampling are not contained in input. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + N = problem.input.shape[get_batch_axis(problem.model)] + enough_samples = subsampling is None or N >= max(subsampling) + if not enough_samples: + skip(f"Not enough samples: N={N}, subsampling={subsampling}") From 36c6bb3c51dbf4e555600c9789989a3feaa9e579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 21 Jul 2021 14:33:51 +0200 Subject: [PATCH 34/54] [ADD] First-order and DiagGGN extensions for BatchNormNd & AdaptiveAvgPoolNd (#201) - First-order extensions for `BatchNorm{1,2,3}d` - Test correctness in eval mode, warn user in train mode - `{Batch}DiagGGN` extensions for `BatchNorm{1,2,3}d` - Test correctness in eval mode, raise error in train mode - `{Batch}DiagGGN` extensions for `AdaptiveAvgPool{1,2,3]d` + tests --- * DiagGGN AdaptiveAvgPool * BatchNorm: first order extensions * BatchNorm: DiagGGN extension * fix subsampling test * fix test extension autograd.py * delete requires_grad, improve tests, skip AdaptiveAvgPool3d on cuda * split settings into sections * align names * AdaptiveAvgPool3d on cuda, assume solved for torch >= 1.9.1 * revert support for extensions tests without parameters * add tests with multiple BatchNorm layers * format * incorporate suggestions * BatchNorm extensions: raise error if training * BatchNorm, batch_grad: don't raise error in training mode * Improve error message Co-authored-by: Felix Dangel --- .../core/derivatives/adaptive_avg_pool_nd.py | 7 +- .../firstorder/batch_grad/__init__.py | 8 +- .../firstorder/batch_grad/batchnorm1d.py | 9 -- .../firstorder/batch_grad/batchnorm_nd.py | 18 ++++ .../firstorder/batch_l2_grad/__init__.py | 7 ++ .../firstorder/batch_l2_grad/batchnorm_nd.py | 16 +++ .../firstorder/gradient/batchnorm1d.py | 10 -- .../firstorder/gradient/batchnorm_nd.py | 19 ++++ .../firstorder/sum_grad_squared/__init__.py | 7 ++ .../sum_grad_squared/batchnorm_nd.py | 16 +++ .../firstorder/variance/__init__.py | 7 ++ .../firstorder/variance/batchnorm_nd.py | 17 +++ .../secondorder/diag_ggn/__init__.py | 20 ++++ .../diag_ggn/adaptive_avg_pool_nd.py | 15 +++ .../secondorder/diag_ggn/batchnorm_nd.py | 28 +++++ backpack/utils/__init__.py | 2 + backpack/utils/errors.py | 29 +++++ fully_documented.txt | 9 ++ test/core/derivatives/batch_norm_settings.py | 23 ++-- test/core/derivatives/derivatives_test.py | 20 +--- .../firstorder/firstorder_settings.py | 48 ++++++++- test/extensions/problem.py | 8 +- .../secondorder/diag_ggn/diag_ggn_settings.py | 101 +++++++++++++++++- .../diag_ggn/test_batch_diag_ggn.py | 5 +- .../secondorder/diag_ggn/test_diag_ggn.py | 5 +- test/extensions/utils.py | 2 +- test/utils/evaluation_mode.py | 40 +++++++ test/utils/skip_test.py | 23 ++++ 28 files changed, 452 insertions(+), 67 deletions(-) delete mode 100644 backpack/extensions/firstorder/batch_grad/batchnorm1d.py create mode 100644 backpack/extensions/firstorder/batch_grad/batchnorm_nd.py create mode 100644 backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py delete mode 100644 backpack/extensions/firstorder/gradient/batchnorm1d.py create mode 100644 backpack/extensions/firstorder/gradient/batchnorm_nd.py create mode 100644 backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py create mode 100644 backpack/extensions/firstorder/variance/batchnorm_nd.py create mode 100644 backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py create mode 100644 backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py create mode 100644 backpack/utils/errors.py create mode 100644 test/utils/evaluation_mode.py create mode 100644 test/utils/skip_test.py diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py index 71529a267..2c0a4ed0a 100644 --- a/backpack/core/derivatives/adaptive_avg_pool_nd.py +++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py @@ -6,6 +6,7 @@ from torch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 class AdaptiveAvgPoolNDDerivatives(AvgPoolNDDerivatives): @@ -27,7 +28,11 @@ def check_parameters( Raises: NotImplementedError: if the given shapes do not match """ - if ("cuda" in str(module.input0.device)) and (self.N == 3): + if ( + TORCH_VERSION_AT_LEAST_1_9_1 is False + and module.input0.is_cuda + and (self.N == 3) + ): warn( "Be careful when computing gradients of AdaptiveAvgPool3d. " "There is a bug using autograd.grad on cuda with AdaptiveAvgPool3d. " diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 897170860..7fff33789 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -7,6 +7,8 @@ from torch.nn import ( RNN, BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -19,7 +21,7 @@ from backpack.extensions.backprop_extension import BackpropExtension from . import ( - batchnorm1d, + batchnorm_nd, conv1d, conv2d, conv3d, @@ -75,7 +77,9 @@ def __init__(self, subsampling: List[int] = None): ConvTranspose1d: conv_transpose1d.BatchGradConvTranspose1d(), ConvTranspose2d: conv_transpose2d.BatchGradConvTranspose2d(), ConvTranspose3d: conv_transpose3d.BatchGradConvTranspose3d(), - BatchNorm1d: batchnorm1d.BatchGradBatchNorm1d(), + BatchNorm1d: batchnorm_nd.BatchGradBatchNormNd(), + BatchNorm2d: batchnorm_nd.BatchGradBatchNormNd(), + BatchNorm3d: batchnorm_nd.BatchGradBatchNormNd(), RNN: rnn.BatchGradRNN(), }, ) diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py deleted file mode 100644 index 46a41b739..000000000 --- a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py +++ /dev/null @@ -1,9 +0,0 @@ -from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives -from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase - - -class BatchGradBatchNorm1d(BatchGradBase): - def __init__(self): - super().__init__( - derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] - ) diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py new file mode 100644 index 000000000..cb0637ea2 --- /dev/null +++ b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py @@ -0,0 +1,18 @@ +"""Contains grad_batch extension for BatchNorm.""" +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.utils.errors import batch_norm_raise_error_if_train + + +class BatchGradBatchNormNd(BatchGradBase): + """BatchGrad extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] + ) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module, raise_error=False) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 5864022c3..4f643fb48 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -5,6 +5,9 @@ """ from torch.nn import ( RNN, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -16,6 +19,7 @@ from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.firstorder.batch_l2_grad import ( + batchnorm_nd, convnd, convtransposend, linear, @@ -58,5 +62,8 @@ def __init__(self): ConvTranspose2d: convtransposend.BatchL2ConvTranspose2d(), ConvTranspose3d: convtransposend.BatchL2ConvTranspose3d(), RNN: rnn.BatchL2RNN(), + BatchNorm1d: batchnorm_nd.BatchL2BatchNorm(), + BatchNorm2d: batchnorm_nd.BatchL2BatchNorm(), + BatchNorm3d: batchnorm_nd.BatchL2BatchNorm(), }, ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py new file mode 100644 index 000000000..42d6e2426 --- /dev/null +++ b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py @@ -0,0 +1,16 @@ +"""Contains batch_l2 extension for BatchNorm.""" +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base +from backpack.utils.errors import batch_norm_raise_error_if_train + + +class BatchL2BatchNorm(BatchL2Base): + """batch_l2 extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__(["weight", "bias"], BatchNormNdDerivatives()) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/gradient/batchnorm1d.py b/backpack/extensions/firstorder/gradient/batchnorm1d.py deleted file mode 100644 index 92dbce28c..000000000 --- a/backpack/extensions/firstorder/gradient/batchnorm1d.py +++ /dev/null @@ -1,10 +0,0 @@ -from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives - -from .base import GradBaseModule - - -class GradBatchNorm1d(GradBaseModule): - def __init__(self): - super().__init__( - derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] - ) diff --git a/backpack/extensions/firstorder/gradient/batchnorm_nd.py b/backpack/extensions/firstorder/gradient/batchnorm_nd.py new file mode 100644 index 000000000..02dd78412 --- /dev/null +++ b/backpack/extensions/firstorder/gradient/batchnorm_nd.py @@ -0,0 +1,19 @@ +"""Gradient extension for BatchNorm.""" +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.utils.errors import batch_norm_raise_error_if_train + +from .base import GradBaseModule + + +class GradBatchNormNd(GradBaseModule): + """Gradient extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] + ) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index 97de5c258..a22f5b73f 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -4,6 +4,9 @@ """ from torch.nn import ( RNN, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -16,6 +19,7 @@ from backpack.extensions.backprop_extension import BackpropExtension from . import ( + batchnorm_nd, conv1d, conv2d, conv3d, @@ -61,5 +65,8 @@ def __init__(self): ConvTranspose2d: convtranspose2d.SGSConvTranspose2d(), ConvTranspose3d: convtranspose3d.SGSConvTranspose3d(), RNN: rnn.SGSRNN(), + BatchNorm1d: batchnorm_nd.SGSBatchNormNd(), + BatchNorm2d: batchnorm_nd.SGSBatchNormNd(), + BatchNorm3d: batchnorm_nd.SGSBatchNormNd(), }, ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py new file mode 100644 index 000000000..866c95736 --- /dev/null +++ b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py @@ -0,0 +1,16 @@ +"""SGS extension for BatchNorm.""" +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase +from backpack.utils.errors import batch_norm_raise_error_if_train + + +class SGSBatchNormNd(SGSBase): + """SGS extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__(BatchNormNdDerivatives(), ["weight", "bias"]) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index 5d8a5ebeb..f3e3085d2 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -4,6 +4,9 @@ """ from torch.nn import ( RNN, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -16,6 +19,7 @@ from backpack.extensions.backprop_extension import BackpropExtension from . import ( + batchnorm_nd, conv1d, conv2d, conv3d, @@ -61,5 +65,8 @@ def __init__(self): ConvTranspose2d: convtranspose2d.VarianceConvTranspose2d(), ConvTranspose3d: convtranspose3d.VarianceConvTranspose3d(), RNN: rnn.VarianceRNN(), + BatchNorm1d: batchnorm_nd.VarianceBatchNormNd(), + BatchNorm2d: batchnorm_nd.VarianceBatchNormNd(), + BatchNorm3d: batchnorm_nd.VarianceBatchNormNd(), }, ) diff --git a/backpack/extensions/firstorder/variance/batchnorm_nd.py b/backpack/extensions/firstorder/variance/batchnorm_nd.py new file mode 100644 index 000000000..6bcaa9386 --- /dev/null +++ b/backpack/extensions/firstorder/variance/batchnorm_nd.py @@ -0,0 +1,17 @@ +"""Variance extension for BatchNorm.""" +from backpack.extensions.firstorder.gradient.batchnorm_nd import GradBatchNormNd +from backpack.extensions.firstorder.sum_grad_squared.batchnorm_nd import SGSBatchNormNd +from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule +from backpack.utils.errors import batch_norm_raise_error_if_train + + +class VarianceBatchNormNd(VarianceBaseModule): + """Variance extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__(["weight", "bias"], GradBatchNormNd(), SGSBatchNormNd()) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 25539c548..45848b5e2 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -12,9 +12,15 @@ ELU, RNN, SELU, + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, AvgPool1d, AvgPool2d, AvgPool3d, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -43,6 +49,8 @@ from . import ( activations, + adaptive_avg_pool_nd, + batchnorm_nd, conv1d, conv2d, conv3d, @@ -116,6 +124,12 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): SELU: activations.DiagGGNSELU(), RNN: rnn.DiagGGNRNN(), Permute: permute.DiagGGNPermute(), + AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), + AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2), + AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(3), + BatchNorm1d: batchnorm_nd.DiagGGNBatchNormNd(), + BatchNorm2d: batchnorm_nd.DiagGGNBatchNormNd(), + BatchNorm3d: batchnorm_nd.DiagGGNBatchNormNd(), }, ) @@ -223,6 +237,12 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): SELU: activations.DiagGGNSELU(), RNN: rnn.BatchDiagGGNRNN(), Permute: permute.DiagGGNPermute(), + AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), + AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2), + AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(3), + BatchNorm1d: batchnorm_nd.BatchDiagGGNBatchNormNd(), + BatchNorm2d: batchnorm_nd.BatchDiagGGNBatchNormNd(), + BatchNorm3d: batchnorm_nd.BatchDiagGGNBatchNormNd(), }, ) diff --git a/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py new file mode 100644 index 000000000..b2cfceb46 --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py @@ -0,0 +1,15 @@ +"""DiagGGN extension for AdaptiveAvgPool.""" +from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNAdaptiveAvgPoolNd(DiagGGNBaseModule): + """DiagGGN extension for AdaptiveAvgPool.""" + + def __init__(self, N: int): + """Initialization. + + Args: + N: number of free dimensions, e.g. use N=1 for AdaptiveAvgPool1d + """ + super().__init__(derivatives=AdaptiveAvgPoolNDDerivatives(N=N)) diff --git a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py new file mode 100644 index 000000000..e48dbc68d --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py @@ -0,0 +1,28 @@ +"""DiagGGN extension for BatchNorm.""" +from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule +from backpack.utils.errors import batch_norm_raise_error_if_train + + +class DiagGGNBatchNormNd(DiagGGNBaseModule): + """DiagGGN extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=True) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) + + +class BatchDiagGGNBatchNormNd(DiagGGNBaseModule): + """BatchDiagGGN extension for BatchNorm.""" + + def __init__(self): + """Initialization.""" + super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=False) + + def apply(self, ext, module, g_inp, g_out): # noqa: D102 + batch_norm_raise_error_if_train(module) + super().apply(ext, module, g_inp, g_out) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index afc0887d3..8f5caf781 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -3,6 +3,8 @@ from pkg_resources import get_distribution, packaging TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) +VERSION_1_9_1 = packaging.version.parse("1.9.1") VERSION_1_9_0 = packaging.version.parse("1.9.0") VERSION_1_8_0 = packaging.version.parse("1.8.0") VERSION_1_6_0 = packaging.version.parse("1.6.0") +TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= VERSION_1_9_1 diff --git a/backpack/utils/errors.py b/backpack/utils/errors.py new file mode 100644 index 000000000..690dc451b --- /dev/null +++ b/backpack/utils/errors.py @@ -0,0 +1,29 @@ +"""Contains errors for BackPACK.""" +from typing import Union +from warnings import warn + +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + + +def batch_norm_raise_error_if_train( + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], raise_error: bool = True +) -> None: + """Check if BatchNorm module is in training mode. + + Args: + module: BatchNorm module to check + raise_error: whether to raise an error, alternatively warn. Default: True. + + Raises: + NotImplementedError: if module is in training mode + """ + if module.training: + message = ( + "Encountered BatchNorm module in training mode. BackPACK's computation " + "will pass, but results like individual gradients may not be meaningful, " + "as BatchNorm mixes samples. Only proceed if you know what you are doing." + ) + if raise_error: + raise NotImplementedError(message) + else: + warn(message) diff --git a/fully_documented.txt b/fully_documented.txt index d6e752a97..47756b78b 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -18,21 +18,27 @@ backpack/extensions/mat_to_mat_jac_base.py backpack/extensions/firstorder/gradient/base.py backpack/extensions/firstorder/gradient/rnn.py backpack/extensions/firstorder/gradient/__init__.py +backpack/extensions/firstorder/gradient/batchnorm_nd.py backpack/extensions/firstorder/batch_grad/batch_grad_base.py backpack/extensions/firstorder/batch_grad/rnn.py backpack/extensions/firstorder/batch_grad/__init__.py +backpack/extensions/firstorder/batch_grad/batchnorm_nd.py backpack/extensions/firstorder/variance/variance_base.py backpack/extensions/firstorder/variance/rnn.py backpack/extensions/firstorder/variance/__init__.py +backpack/extensions/firstorder/variance/batchnorm_nd.py backpack/extensions/firstorder/sum_grad_squared/sgs_base.py backpack/extensions/firstorder/sum_grad_squared/rnn.py backpack/extensions/firstorder/sum_grad_squared/__init__.py +backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py backpack/extensions/firstorder/batch_l2_grad/ backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_ggn/__init__.py backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py backpack/extensions/secondorder/diag_ggn/rnn.py backpack/extensions/secondorder/diag_ggn/permute.py +backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py +backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py backpack/extensions/secondorder/diag_hessian/__init__.py backpack/extensions/secondorder/diag_hessian/conv1d.py backpack/extensions/secondorder/diag_hessian/conv2d.py @@ -43,6 +49,7 @@ backpack/hessianfree/ggnvp.py backpack/utils/linear.py backpack/utils/subsampling.py +backpack/utils/errors.py backpack/utils/__init__.py test/extensions/automated_settings.py @@ -70,3 +77,5 @@ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py test/core/derivatives/batch_norm_settings.py +test/utils/evaluation_mode.py +test/utils/skip_test.py diff --git a/test/core/derivatives/batch_norm_settings.py b/test/core/derivatives/batch_norm_settings.py index 621cd7332..6c8f0deeb 100644 --- a/test/core/derivatives/batch_norm_settings.py +++ b/test/core/derivatives/batch_norm_settings.py @@ -12,22 +12,11 @@ "id_prefix" (str): Prefix to be included in the test name. "seed" (int): seed for the random number for torch.rand """ -from typing import Union +from test.utils.evaluation_mode import initialize_batch_norm_eval -from torch import rand, rand_like +from torch import rand from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d - -def _initialize_training_false( - module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] -) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]: - module.running_mean = rand_like(module.running_mean) - module.running_var = rand_like(module.running_var) - module.weight.data = rand_like(module.weight) - module.bias.data = rand_like(module.bias) - return module.train(False) - - BATCH_NORM_SETTINGS = [ { "module_fn": lambda: BatchNorm1d(num_features=7), @@ -47,22 +36,22 @@ def _initialize_training_false( "seed": 1, }, { - "module_fn": lambda: _initialize_training_false(BatchNorm1d(num_features=7)), + "module_fn": lambda: initialize_batch_norm_eval(BatchNorm1d(num_features=7)), "input_fn": lambda: rand(size=(5, 7)), "id_prefix": "training=False", }, { - "module_fn": lambda: _initialize_training_false(BatchNorm1d(num_features=7)), + "module_fn": lambda: initialize_batch_norm_eval(BatchNorm1d(num_features=7)), "input_fn": lambda: rand(size=(5, 7, 4)), "id_prefix": "training=False", }, { - "module_fn": lambda: _initialize_training_false(BatchNorm2d(num_features=7)), + "module_fn": lambda: initialize_batch_norm_eval(BatchNorm2d(num_features=7)), "input_fn": lambda: rand(size=(5, 7, 3, 4)), "id_prefix": "training=False", }, { - "module_fn": lambda: _initialize_training_false(BatchNorm3d(num_features=7)), + "module_fn": lambda: initialize_batch_norm_eval(BatchNorm3d(num_features=7)), "input_fn": lambda: rand(size=(5, 7, 3, 4, 2)), "id_prefix": "training=False", }, diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 3478c6e64..1f3ebb9b2 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -17,6 +17,7 @@ from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS +from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda from typing import List, Tuple, Union from warnings import warn @@ -100,16 +101,11 @@ def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> request: Pytest request, used for getting id. V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ + skip_adaptive_avg_pool3d_cuda(request) + problem.set_up() mat = torch.rand(V, *problem.output_shape).to(problem.device) - if all( - string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] - ): - with pytest.warns(UserWarning): - BackpackDerivatives(problem).jac_t_mat_prod(mat) - problem.tear_down() - return backpack_res = BackpackDerivatives(problem).jac_t_mat_prod(mat) autograd_res = AutogradDerivatives(problem).jac_t_mat_prod(mat) @@ -511,18 +507,12 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None problem: Test case. request: PyTest request, used to get test id. """ + skip_adaptive_avg_pool3d_cuda(request) + problem.set_up() out_features = problem.output_shape[1:].numel() mat = torch.rand(out_features, out_features).to(problem.device) - if all( - string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] - ): - with pytest.warns(UserWarning): - BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) - problem.tear_down() - return - backpack_res = BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) autograd_res = AutogradDerivatives(problem).ea_jac_t_mat_jac_prod(mat) diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index eb03c1788..55ca07eb3 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -20,10 +20,14 @@ """ from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import make_simple_cnn_setting +from test.utils.evaluation_mode import initialize_training_false_recursive from torch import device, rand from torch.nn import ( RNN, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, Conv1d, Conv2d, Conv3d, @@ -206,10 +210,52 @@ ), ] +############################################################################### +# test setting: BatchNorm # +############################################################################### +FIRSTORDER_SETTINGS += [ + { + "input_fn": lambda: rand(2, 3, 4), + "module_fn": lambda: initialize_training_false_recursive( + BatchNorm1d(num_features=3) + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((2, 4), 3), + }, + { + "input_fn": lambda: rand(3, 2, 4, 3), + "module_fn": lambda: initialize_training_false_recursive( + BatchNorm2d(num_features=2) + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3, 4, 3), 2), + }, + { + "input_fn": lambda: rand(3, 3, 4, 1, 2), + "module_fn": lambda: initialize_training_false_recursive( + BatchNorm3d(num_features=3) + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3, 4, 1, 2), 3), + }, + { + "input_fn": lambda: rand(3, 3, 4, 1, 2), + "module_fn": lambda: initialize_training_false_recursive( + Sequential( + BatchNorm3d(num_features=3), + Linear(2, 3), + BatchNorm3d(num_features=3), + ReLU(), + BatchNorm3d(num_features=3), + ) + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3, 4, 1, 3), 3), + }, +] ############################################################################### # test setting: RNN Layers # ############################################################################### - FIRSTORDER_SETTINGS += [ { "input_fn": lambda: rand(8, 5, 6), diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 0d940cec6..dcb620458 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -145,11 +145,11 @@ def forward_pass(self, sample_idx=None): input, output, loss, each with batch axis first """ if sample_idx is None: - input = self.input.clone().detach() - target = self.target.clone().detach() + input = self.input.clone() + target = self.target.clone() else: - target = self.target.split(1, dim=0)[sample_idx].detach() - input = self.input.split(1, dim=0)[sample_idx].detach() + target = self.target.split(1, dim=0)[sample_idx] + input = self.input.split(1, dim=0)[sample_idx] output = self.model(input) diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index ef50ba716..596d68bf2 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -11,18 +11,36 @@ """ from test.core.derivatives.utils import regression_targets from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS +from test.utils.evaluation_mode import initialize_training_false_recursive -import torch -from torch.nn import RNN, Flatten, Sequential +from torch import rand +from torch.nn import ( + RNN, + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + Flatten, + Linear, + MSELoss, + ReLU, + Sequential, +) from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple SHARED_SETTINGS = SECONDORDER_SETTINGS -LOCAL_SETTINGS = [ +LOCAL_SETTINGS = [] +################################################################## +# RNN settings # +################################################################## +LOCAL_SETTINGS += [ # RNN settings { - "input_fn": lambda: torch.rand(8, 5, 6), + "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( Permute(1, 0, 2), RNN(input_size=6, hidden_size=3), @@ -30,8 +48,81 @@ Permute(1, 2, 0), Flatten(), ), - "loss_function_fn": lambda: torch.nn.MSELoss(), + "loss_function_fn": lambda: MSELoss(), "target_fn": lambda: regression_targets((8, 3 * 5)), }, ] +################################################################## +# AdaptiveAvgPool settings # +################################################################## +LOCAL_SETTINGS += [ + { + "input_fn": lambda: rand(2, 2, 9), + "module_fn": lambda: Sequential( + Linear(9, 9), AdaptiveAvgPool1d((3,)), Flatten() + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 2 * 3)), + }, + { + "input_fn": lambda: rand(2, 2, 6, 8), + "module_fn": lambda: Sequential( + Linear(8, 8), AdaptiveAvgPool2d((3, 4)), Flatten() + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 2 * 3 * 4)), + }, + { + "input_fn": lambda: rand(2, 2, 9, 5, 4), + "module_fn": lambda: Sequential( + Linear(4, 4), AdaptiveAvgPool3d((3, 5, 2)), Flatten() + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 2 * 3 * 5 * 2)), + }, +] +################################################################## +# BatchNorm settings # +################################################################## +LOCAL_SETTINGS += [ + { + "input_fn": lambda: rand(2, 3, 4), + "module_fn": lambda: initialize_training_false_recursive( + Sequential(BatchNorm1d(num_features=3), Flatten()) + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 4 * 3)), + }, + { + "input_fn": lambda: rand(3, 2, 4, 3), + "module_fn": lambda: initialize_training_false_recursive( + Sequential(BatchNorm2d(num_features=2), Flatten()) + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((3, 2 * 4 * 3)), + }, + { + "input_fn": lambda: rand(3, 3, 4, 1, 2), + "module_fn": lambda: initialize_training_false_recursive( + Sequential(BatchNorm3d(num_features=3), Flatten()) + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((3, 3 * 4 * 1 * 2)), + }, + { + "input_fn": lambda: rand(3, 3, 4, 1, 2), + "module_fn": lambda: initialize_training_false_recursive( + Sequential( + BatchNorm3d(num_features=3), + Linear(2, 3), + BatchNorm3d(num_features=3), + ReLU(), + BatchNorm3d(num_features=3), + Flatten(), + ) + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((3, 4 * 1 * 3 * 3)), + }, +] DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index c35100623..0030fa95b 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -4,6 +4,7 @@ from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS +from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda import pytest @@ -12,12 +13,14 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn_exact_batch(problem): +def test_diag_ggn_exact_batch(problem, request): """Test the individual diagonal of Generalized Gauss-Newton/Fisher. Args: problem (ExtensionsTestProblem): Problem for extension test. + request: problem request """ + skip_adaptive_avg_pool3d_cuda(request) problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() diff --git a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py index 2a982fcd7..0b7ba0469 100644 --- a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py @@ -4,6 +4,7 @@ from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS +from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda import pytest @@ -12,12 +13,14 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn(problem): +def test_diag_ggn(problem, request): """Test the diagonal of generalized Gauss-Newton. Args: problem (ExtensionsTestProblem): Problem for extension test. + request: problem request """ + skip_adaptive_avg_pool3d_cuda(request) problem.set_up() backpack_res = BackpackExtensions(problem).diag_ggn() diff --git a/test/extensions/utils.py b/test/extensions/utils.py index 0965f55d0..2705e92a8 100644 --- a/test/extensions/utils.py +++ b/test/extensions/utils.py @@ -18,6 +18,6 @@ def skip_if_subsampling_conflict( subsampling: Indices of active samples. """ N = problem.input.shape[get_batch_axis(problem.model)] - enough_samples = subsampling is None or N >= max(subsampling) + enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip(f"Not enough samples: N={N}, subsampling={subsampling}") diff --git a/test/utils/evaluation_mode.py b/test/utils/evaluation_mode.py new file mode 100644 index 000000000..f4e57be77 --- /dev/null +++ b/test/utils/evaluation_mode.py @@ -0,0 +1,40 @@ +"""Tools for initializing in evaluation mode, especially BatchNorm.""" +from typing import Union + +from torch import rand_like +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, Module + + +def initialize_training_false_recursive(module: Module) -> Module: + """Initializes a module recursively in evaluation mode. + + Args: + module: the module to initialize + + Returns: + initialized module in evaluation mode + """ + if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)): + initialize_batch_norm_eval(module) + else: + for module_child in module.children(): + initialize_training_false_recursive(module_child) + return module.train(False) + + +def initialize_batch_norm_eval( + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] +) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]: + """Initializes a BatchNorm module in evaluation mode. + + Args: + module: BatchNorm module + + Returns: + the initialized BatchNorm module in evaluation mode + """ + module.running_mean = rand_like(module.running_mean) + module.running_var = rand_like(module.running_var) + module.weight.data = rand_like(module.weight) + module.bias.data = rand_like(module.bias) + return module.train(False) diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py new file mode 100644 index 000000000..a9744526c --- /dev/null +++ b/test/utils/skip_test.py @@ -0,0 +1,23 @@ +"""Skip specific tests.""" +import pytest + +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 + + +def skip_adaptive_avg_pool3d_cuda(request) -> None: + """Skips test if AdaptiveAvgPool3d and cuda. + + Args: + request: problem request + """ + if TORCH_VERSION_AT_LEAST_1_9_1: + pass + else: + if all( + string in request.node.callspec.id + for string in ["AdaptiveAvgPool3d", "cuda"] + ): + pytest.skip( + "Skip test because AdaptiveAvgPool3d does not work on cuda. " + "Is fixed in torch 1.9.1." + ) From 2c632c9fd2e54119595fc32745eaec4d3babdb4b Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 22 Jul 2021 15:13:18 +0200 Subject: [PATCH 35/54] [core] Support sub-sampling in `jac_t_mat_prod` (#205) Add `subsampling` argument to `jac_t_mat_prod` and extend tests Auxiliary: - Refactor `if N==1: ... elif N==2: ...` patterns in Nd derivatives - Other refactorings by introducing helpers functions --- * [REF] Reduce `if N == ...` branches * [ADD} Prototypes of all `jac_t` derivatives with `subsampling` * [REF] Reduce if-else branches, test #30 with save_memory * [REF] Extract shareable parts from if...else * [FIX] Remove unused import * [REF] Make cast data type more specific * [DOC] Reduce redundancy * [FMT] Remove blank line * [REF] Make cast data type more specific * [REF] Specify dtype * [REF] Specify dtype * [FMT] Save one line * [REF] Move test skip helpers, shorten imports --- backpack/core/derivatives/avgpoolnd.py | 73 +++---- backpack/core/derivatives/basederivatives.py | 48 +++-- backpack/core/derivatives/batchnorm_nd.py | 7 + backpack/core/derivatives/conv_transposend.py | 57 +++--- backpack/core/derivatives/convnd.py | 90 +++------ backpack/core/derivatives/dropout.py | 21 +- backpack/core/derivatives/elementwise.py | 50 +++-- backpack/core/derivatives/elu.py | 23 ++- backpack/core/derivatives/flatten.py | 25 ++- backpack/core/derivatives/leakyrelu.py | 23 ++- backpack/core/derivatives/linear.py | 16 +- backpack/core/derivatives/logsigmoid.py | 18 +- backpack/core/derivatives/lstm.py | 11 +- backpack/core/derivatives/maxpoolnd.py | 120 ++++++------ backpack/core/derivatives/permute.py | 9 +- backpack/core/derivatives/relu.py | 18 +- backpack/core/derivatives/rnn.py | 15 +- backpack/core/derivatives/selu.py | 21 +- backpack/core/derivatives/shape_check.py | 6 +- backpack/core/derivatives/sigmoid.py | 18 +- backpack/core/derivatives/tanh.py | 18 +- backpack/core/derivatives/zeropad2d.py | 15 +- backpack/utils/conv.py | 46 ++++- backpack/utils/conv_transpose.py | 45 ++++- test/bugfixes_test.py | 12 +- test/core/derivatives/derivatives_test.py | 183 +++++++----------- .../derivatives/implementation/autograd.py | 31 ++- .../derivatives/implementation/backpack.py | 7 +- test/core/derivatives/implementation/base.py | 3 +- test/utils/skip_test.py | 67 ++++++- 30 files changed, 675 insertions(+), 421 deletions(-) diff --git a/backpack/core/derivatives/avgpoolnd.py b/backpack/core/derivatives/avgpoolnd.py index d95e561c8..b51d21480 100644 --- a/backpack/core/derivatives/avgpoolnd.py +++ b/backpack/core/derivatives/avgpoolnd.py @@ -3,35 +3,22 @@ Average pooling can be expressed as convolution over grouped channels with a constant kernel. """ -from typing import Any, Tuple +from typing import Any, List, Tuple -import torch.nn from einops import rearrange -from torch.nn import ( - Conv1d, - Conv2d, - Conv3d, - ConvTranspose1d, - ConvTranspose2d, - ConvTranspose3d, - Module, -) +from torch import Tensor, ones_like +from torch.nn import Module from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.conv import get_conv_module +from backpack.utils.conv_transpose import get_conv_transpose_module class AvgPoolNDDerivatives(BaseDerivatives): - def __init__(self, N): + def __init__(self, N: int): + self.conv = get_conv_module(N) + self.convt = get_conv_transpose_module(N) self.N = N - if self.N == 1: - self.conv = Conv1d - self.convt = ConvTranspose1d - elif self.N == 2: - self.conv = Conv2d - self.convt = ConvTranspose2d - elif self.N == 3: - self.conv = Conv3d - self.convt = ConvTranspose3d def check_parameters(self, module: Module) -> None: assert module.count_include_pad, ( @@ -101,31 +88,31 @@ def __apply_jacobian_of(self, module, mat): ).to(module.input0.device) convnd.weight.requires_grad = False - avg_kernel = torch.ones_like(convnd.weight) / convnd.weight.numel() + avg_kernel = ones_like(convnd.weight) / convnd.weight.numel() convnd.weight.data = avg_kernel return convnd(mat) def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module): - V = mat.size(0) - if self.N == 1: - N, C_out, L_out = module.output.shape - assert jmp_as_pool.shape == (V * N * C_out, 1, L_out) - elif self.N == 2: - N, C_out, H_out, W_out = module.output.shape - assert jmp_as_pool.shape == (V * N * C_out, 1, H_out, W_out) - elif self.N == 3: - N, C_out, D_out, H_out, W_out = module.output.shape - assert jmp_as_pool.shape == (V * N * C_out, 1, D_out, H_out, W_out) - - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + V = mat.shape[0] + N, C_out = module.output.shape[:2] + + assert jmp_as_pool.shape == (V * N * C_out, 1) + module.output.shape[2:] + + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: self.check_parameters(module) mat_as_pool = self.__make_single_channel(mat, module) jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool) - self.__check_jmp_in_as_pool(mat, jmp_as_pool, module) - return self.reshape_like_input(jmp_as_pool, module) + return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling) def __apply_jacobian_t_of(self, module, mat): stride, kernel_size, padding = self.get_avg_pool_parameters(module) @@ -141,22 +128,10 @@ def __apply_jacobian_t_of(self, module, mat): ).to(module.input0.device) convnd_t.weight.requires_grad = False - avg_kernel = torch.ones_like(convnd_t.weight) / convnd_t.weight.numel() + avg_kernel = ones_like(convnd_t.weight) / convnd_t.weight.numel() convnd_t.weight.data = avg_kernel V_N_C_in = mat.size(0) output_size = (V_N_C_in, C_for_conv_t) + tuple(module.input0.shape[2:]) return convnd_t(mat, output_size=output_size) - - def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module): - V = mat.size(0) - if self.N == 1: - N, C_in, L_in = module.input0.size() - assert jmp_as_pool.shape == (V * N * C_in, 1, L_in) - elif self.N == 2: - N, C_in, H_in, W_in = module.input0.size() - assert jmp_as_pool.shape == (V * N * C_in, 1, H_in, W_in) - elif self.N == 3: - N, C_in, D_in, H_in, W_in = module.input0.size() - assert jmp_as_pool.shape == (V * N * C_in, 1, D_in, H_in, W_in) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 35ca27d30..523955bdb 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -7,6 +7,7 @@ from torch.nn import Module from backpack.core.derivatives import shape_check +from backpack.utils.subsampling import get_batch_axis class BaseDerivatives(ABC): @@ -80,7 +81,12 @@ def _jac_mat_prod( @shape_check.jac_t_mat_prod_accept_vectors @shape_check.jac_t_mat_prod_check_shapes def jac_t_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed input-ouput Jacobian of module output to a matrix. @@ -93,16 +99,25 @@ def jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, N, C_out, H_out, ...]. + Must have shape ``[V, *module.output.shape]``; but if used with + sub-sampling, the batch dimension is replaced by ``len(subsampling)``. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Transposed Jacobian-matrix product. - Has shape [V, N, C_in, H_in, ...]. + Has shape ``[V, *module.input0.shape]``; but if used with sub-sampling, + the batch dimension is replaced by ``len(subsampling)``. """ - return self._jac_t_mat_prod(module, g_inp, g_out, mat) + return self._jac_t_mat_prod(module, g_inp, g_out, mat, subsampling=subsampling) def _jac_t_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -244,34 +259,39 @@ def _residual_mat_prod( raise NotImplementedError @staticmethod - def _reshape_like(mat: Tensor, like: Tensor) -> Tensor: + def _reshape_like(mat: Tensor, shape: Tuple[int]) -> Tensor: """Reshape as like with trailing and additional 0th dimension. If like is [N, C, H, ...], returns shape [-1, N, C, H, ...] Args: - mat: matrix to reshape - like: matrix with target shape + mat: Matrix to reshape. + shape: Trailing target shape. Returns: reshaped matrix """ - V = -1 - shape = (V, *like.shape) - return mat.reshape(shape) + return mat.reshape(-1, *shape) @classmethod - def reshape_like_input(cls, mat: Tensor, module: Module) -> Tensor: + def reshape_like_input( + cls, mat: Tensor, module: Module, subsampling: List[int] = None + ) -> Tensor: """Reshapes matrix according to input. Args: mat: matrix to reshape module: module which input shape is used + subsampling: Indices of active samples. ``None`` means use all samples. Returns: reshaped matrix """ - return cls._reshape_like(mat, module.input0) + shape = list(module.input0.shape) + if subsampling is not None: + shape[get_batch_axis(module)] = len(subsampling) + + return cls._reshape_like(mat, shape) @classmethod def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor: @@ -284,7 +304,7 @@ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor: Returns: reshaped matrix """ - return cls._reshape_like(mat, module.output) + return cls._reshape_like(mat, module.output.shape) class BaseParameterDerivatives(BaseDerivatives, ABC): diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index bc6a3b083..7643ad22f 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -90,10 +90,17 @@ def _jac_t_mat_prod( g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) N: int = self._get_n_axis(module) if module.training: + + if subsampling is not None: + raise NotImplementedError( + "BatchNorm VJP sub-sampling is not defined in train mode." + ) + denominator: int = self._get_denominator(module) x_hat, var = self._get_normalized_input_and_var(module) ivar = 1.0 / (var + module.eps).sqrt() diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index fc8dd7844..f0046337f 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -4,48 +4,29 @@ from einops import rearrange from numpy import prod from torch import Tensor, einsum -from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d -from torch.nn.functional import ( - conv1d, - conv2d, - conv3d, - conv_transpose1d, - conv_transpose2d, - conv_transpose3d, -) +from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module from torch.nn.grad import _grad_input_padding from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils.conv_transpose import unfold_by_conv_transpose +from backpack.utils.conv import get_conv_function +from backpack.utils.conv_transpose import ( + get_conv_transpose_function, + unfold_by_conv_transpose, +) from backpack.utils.subsampling import subsample class ConvTransposeNDDerivatives(BaseParameterDerivatives): """Base class for partial derivatives of transpose convolution.""" - def __init__(self, N): - """Store convolution dimension and operations. + def __init__(self, N: int): + """Store transpose convolution dimension and operations. Args: - N (int): Convolution dimension. Must be ``1``, ``2``, or ``3``. - - Raises: - ValueError: If convolution dimension is unsupported. + N: Transpose convolution dimension. """ - if N == 1: - self.module = ConvTranspose1d - self.conv_func = conv1d - self.conv_transpose_func = conv_transpose1d - elif N == 2: - self.module = ConvTranspose2d - self.conv_func = conv2d - self.conv_transpose_func = conv_transpose2d - elif N == 3: - self.module = ConvTranspose3d - self.conv_func = conv3d - self.conv_transpose_func = conv_transpose3d - else: - raise ValueError(f"ConvTranspose{N}d not supported.") + self.conv_func = get_conv_function(N) + self.conv_transpose_func = get_conv_transpose_function(N) self.conv_dims = N def hessian_is_zero(self, module): @@ -150,7 +131,7 @@ def __jac(self, module, mat): dilation=module.dilation, ) - jac_t_mat = conv_transpose1d( + jac_t_mat = self.conv_transpose_func( input=mat, weight=module.weight, bias=None, @@ -160,14 +141,22 @@ def __jac(self, module, mat): groups=module.groups, dilation=module.dilation, ) + return jac_t_mat - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...") jmp_as_conv = self.__jac_t(module, mat_as_conv) - return self.reshape_like_input(jmp_as_conv, module) + return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling) - def __jac_t(self, module, mat): + def __jac_t(self, module: Module, mat: Tensor) -> Tensor: jac_t = self.conv_func( mat, module.weight, diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index 2e8dedaf3..167a5ad75 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -1,22 +1,15 @@ -import warnings from typing import List, Tuple, Union +from warnings import warn from einops import rearrange, reduce from numpy import prod from torch import Tensor, einsum -from torch.nn import Conv1d, Conv2d, Conv3d -from torch.nn.functional import ( - conv1d, - conv2d, - conv3d, - conv_transpose1d, - conv_transpose2d, - conv_transpose3d, -) +from torch.nn import Conv1d, Conv2d, Conv3d, Module from torch.nn.grad import _grad_input_padding from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import conv as convUtils +from backpack.utils.conv import get_conv_function, unfold_by_conv +from backpack.utils.conv_transpose import get_conv_transpose_function from backpack.utils.subsampling import subsample @@ -40,27 +33,15 @@ def __exit__(self, type, value, traceback): class ConvNDDerivatives(BaseParameterDerivatives): def __init__(self, N): - if N == 1: - self.module = Conv1d - self.conv_func = conv1d - self.conv_transpose_func = conv_transpose1d - elif N == 2: - self.module = Conv2d - self.conv_func = conv2d - self.conv_transpose_func = conv_transpose2d - elif N == 3: - self.module = Conv3d - self.conv_func = conv3d - self.conv_transpose_func = conv_transpose3d - else: - raise ValueError("{}-dimensional Conv. is not implemented.".format(N)) + self.conv_func = get_conv_function(N) + self.conv_transpose_func = get_conv_transpose_function(N) self.conv_dims = N def hessian_is_zero(self, module): return True def get_unfolded_input(self, module): - return convUtils.unfold_by_conv(module.input0, module) + return unfold_by_conv(module.input0, module) def _jac_mat_prod(self, module, g_inp, g_out, mat): mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...") @@ -74,10 +55,17 @@ def _jac_mat_prod(self, module, g_inp, g_out, mat): ) return self.reshape_like_output(jmp_as_conv, module) - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...") jmp_as_conv = self.__jac_t(module, mat_as_conv) - return self.reshape_like_input(jmp_as_conv, module) + return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling) def __jac_t(self, module, mat): input_size = list(module.input0.size()) @@ -145,23 +133,16 @@ def _weight_jac_t_mat_prod( save_memory = weight_jac_t_save_memory._SAVE_MEMORY if save_memory and self.conv_dims in [1, 2]: - return self.__higher_conv_weight_jac_t( - module, mat, sum_batch, subsampling=subsampling - ) - + weight_jac_t_func = self.__higher_conv_weight_jac_t else: - if save_memory and self.conv_dims == 3: - warnings.warn( - UserWarning( - "Conv3d: Cannot save memory as there is no Conv4d." - + " Fallback to more memory-intense method." - ) + warn( + "Conv3d: Cannot save memory as there is no Conv4d." + + " Fallback to more memory-intense method." ) + weight_jac_t_func = self.__same_conv_weight_jac_t - return self.__same_conv_weight_jac_t( - module, mat, sum_batch, subsampling=subsampling - ) + return weight_jac_t_func(module, mat, sum_batch, subsampling=subsampling) def __same_conv_weight_jac_t( self, @@ -234,22 +215,11 @@ def __higher_conv_weight_jac_t( N = module.output.shape[0] if subsampling is None else len(subsampling) C_in = module.input0.shape[1] - if self.conv_dims == 1: - _, _, L_in = module.input0.size() - higher_conv_func = conv2d - K_L_axis = 2 - K_L = module.kernel_size[0] - spatial_dim = (C_in // G, L_in) - spatial_dim_axis = (1, V, 1, 1) - spatial_dim_new = (C_in // G, K_L) - else: - _, _, H_in, W_in = module.input0.size() - higher_conv_func = conv3d - K_H_axis, K_W_axis = 2, 3 - K_H, K_W = module.kernel_size - spatial_dim = (C_in // G, H_in, W_in) - spatial_dim_axis = (1, V, 1, 1, 1) - spatial_dim_new = (C_in // G, K_H, K_W) + higher_conv_func = get_conv_function(self.conv_dims + 1) + + spatial_dim = (C_in // G,) + module.input0.shape[2:] + spatial_dim_axis = (1, V) + tuple([1] * (self.conv_dims + 1)) + spatial_dim_new = (C_in // G,) + module.weight.shape[2:] # Reshape to extract groups from the convolutional layer # Channels are seen as an extra spatial dimension with kernel size 1 @@ -277,10 +247,8 @@ def __higher_conv_weight_jac_t( # Because of rounding shapes when using non-default stride or dilation, # convolution result must be truncated to convolution kernel size - if self.conv_dims == 1: - conv = conv.narrow(K_L_axis, 0, K_L) - else: - conv = conv.narrow(K_H_axis, 0, K_H).narrow(K_W_axis, 0, K_W) + for axis in range(2, 2 + self.conv_dims): + conv = conv.narrow(axis, 0, module.weight.shape[axis]) new_shape = [V, N, C_out, *spatial_dim_new] weight_grad = conv.reshape(*new_shape) diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py index 2ddc2aa23..5ca3d4a02 100644 --- a/backpack/core/derivatives/dropout.py +++ b/backpack/core/derivatives/dropout.py @@ -1,13 +1,26 @@ -from torch import eq +"""Partial derivatives for the dropout layer.""" +from typing import List, Tuple + +from torch import Tensor, eq +from torch.nn import Dropout from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class DropoutDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self, module): + def hessian_is_zero(self, module: Dropout) -> bool: + """``Dropout''(x) = 0``.""" return True - def df(self, module, g_inp, g_out): + def df( + self, + module: Dropout, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: + output = subsample(module.output, subsampling=subsampling) scaling = 1 / (1 - module.p) - mask = 1 - eq(module.output, 0.0).float() + mask = 1 - eq(output, 0.0).to(output.dtype) return mask * scaling diff --git a/backpack/core/derivatives/elementwise.py b/backpack/core/derivatives/elementwise.py index 602706f27..e7ec3de29 100644 --- a/backpack/core/derivatives/elementwise.py +++ b/backpack/core/derivatives/elementwise.py @@ -1,6 +1,9 @@ """Base class for more flexible Jacobians/Hessians of activation functions.""" -from torch import einsum +from typing import List, Tuple + +from torch import Tensor, einsum +from torch.nn import Module from backpack.core.derivatives.basederivatives import BaseDerivatives @@ -20,18 +23,24 @@ class ElementwiseDerivatives(BaseDerivatives): - If the activation is piece-wise linear: `hessian_is_zero`, else `d2f`. """ - def df(self, module, g_inp, g_out): + def df( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ): """Elementwise first derivative. Args: - module (torch.nn.Module): PyTorch activation function module. - g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. - g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. + module: PyTorch activation module. + g_inp: Gradients of the module w.r.t. its inputs. + g_out: Gradients of the module w.r.t. its outputs. + subsampling: Indices of active samples. ``None`` means all samples. Returns: - (torch.Tensor): Tensor containing the derivatives `f'(input[i]) ∀ i`. + Tensor containing the derivatives `f'(input[i]) ∀ i`. """ - raise NotImplementedError("First derivatives not implemented") def d2f(self, module, g_inp, g_out): @@ -40,14 +49,13 @@ def d2f(self, module, g_inp, g_out): Only needs to be implemented for non piece-wise linear functions. Args: - module (torch.nn.Module): PyTorch activation function module. + module (torch.nn.Module): PyTorch activation module. g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. Returns: (torch.Tensor): Tensor containing the derivatives `f''(input[i]) ∀ i`. """ - raise NotImplementedError("Second derivatives not implemented") def hessian_diagonal(self, module, g_inp, g_out): @@ -57,7 +65,7 @@ def hessian_diagonal(self, module, g_inp, g_out): - Only required if `hessian_is_diagonal` returns `True`. Args: - module (torch.nn.Module): PyTorch activation function module. + module (torch.nn.Module): PyTorch activation module. g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. """ @@ -73,11 +81,18 @@ def hessian_is_diagonal(self, module): """ return True - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: self._no_inplace(module) - df_elementwise = self.df(module, g_inp, g_out) - return einsum("...,v...->v...", (df_elementwise, mat)) + df_elementwise = self.df(module, g_inp, g_out, subsampling=subsampling) + return einsum("...,v...->v...", df_elementwise, mat) def _jac_mat_prod(self, module, g_inp, g_out, mat): self._no_inplace(module) @@ -89,21 +104,22 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): N = module.input0.size(0) df_flat = self.df(module, g_inp, g_out).reshape(N, -1) - return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / N + return einsum("ni,nj,ij->ij", df_flat, df_flat, mat) / N def _residual_mat_prod(self, module, g_inp, g_out, mat): residual = self.d2f(module, g_inp, g_out) * g_out[0] - return einsum("...,v...->v...", (residual, mat)) + return einsum("...,v...->v...", residual, mat) + # TODO Deprecate after supporting torch >= 1.8.0 and full_backward_hook @staticmethod - def _no_inplace(module): + def _no_inplace(module: Module): """Do not support inplace modification. Jacobians/Hessians might be computed using the modified input instead of the original. Args: - module (torch.nn.Module): Elementwise activation module. + module: Elementwise activation module. Raises: NotImplementedError: If `module` has inplace option enabled. diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py index 99cafd0b5..5d3778223 100644 --- a/backpack/core/derivatives/elu.py +++ b/backpack/core/derivatives/elu.py @@ -1,22 +1,33 @@ """Partial derivatives for the ELU activation function.""" -from torch import exp, le, ones_like, zeros_like +from typing import List, Tuple + +from torch import Tensor, exp, le, ones_like, zeros_like +from torch.nn import ELU from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class ELUDerivatives(ElementwiseDerivatives): """Implement first- and second-order partial derivatives of ELU.""" - def hessian_is_zero(self, module): + def hessian_is_zero(self, module: ELU) -> bool: """`ELU''(x) ≠ 0`.""" return False - def df(self, module, g_inp, g_out): + def df( + self, + module: ELU, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ): """First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`.""" - non_pos = le(module.input0, 0) + input0 = subsample(module.input0, subsampling=subsampling) + non_pos = le(input0, 0) - result = ones_like(module.input0) - result[non_pos] = module.alpha * exp(module.input0[non_pos]) + result = ones_like(input0) + result[non_pos] = module.alpha * exp(input0[non_pos]) return result diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py index 4f8dd72d0..a0af28da1 100644 --- a/backpack/core/derivatives/flatten.py +++ b/backpack/core/derivatives/flatten.py @@ -1,3 +1,9 @@ +"""Partial derivatives of the flatten layer.""" +from typing import List, Tuple + +from torch import Tensor +from torch.nn import Flatten + from backpack.core.derivatives.basederivatives import BaseDerivatives @@ -8,10 +14,23 @@ def hessian_is_zero(self, module): def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): return mat - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): - return self.reshape_like_input(mat, module) + def _jac_t_mat_prod( + self, + module: Flatten, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: + return self.reshape_like_input(mat, module, subsampling=subsampling) - def _jac_mat_prod(self, module, g_inp, g_out, mat): + def _jac_mat_prod( + self, + module: Flatten, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + ) -> Tensor: return self.reshape_like_output(mat, module) def is_no_op(self, module): diff --git a/backpack/core/derivatives/leakyrelu.py b/backpack/core/derivatives/leakyrelu.py index 89c2f9ee4..7cb0dfa1e 100644 --- a/backpack/core/derivatives/leakyrelu.py +++ b/backpack/core/derivatives/leakyrelu.py @@ -1,16 +1,27 @@ -from torch import gt +"""Partial derivatives for the leaky ReLU layer.""" +from typing import List, Tuple + +from torch import Tensor, gt +from torch.nn import LeakyReLU from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class LeakyReLUDerivatives(ElementwiseDerivatives): - def hessian_is_zero(self, module): + def hessian_is_zero(self, module: LeakyReLU) -> bool: """`LeakyReLU''(x) = 0`.""" return True - def df(self, module, g_inp, g_out): - """First LeakyReLU derivative: - `LeakyReLU'(x) = negative_slope if x < 0 else 1`.""" - df_leakyrelu = gt(module.input0, 0).float() + def df( + self, + module: LeakyReLU, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: + """``LeakyReLU'(x) = negative_slope if x < 0 else 1``.""" + input0 = subsample(module.input0, subsampling=subsampling) + df_leakyrelu = gt(input0, 0).to(input0.dtype) df_leakyrelu[df_leakyrelu == 0] = module.negative_slope return df_leakyrelu diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index e6a8fe586..aae3da1a5 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -32,7 +32,12 @@ def hessian_is_zero(self, module: Linear) -> bool: return True def _jac_t_mat_prod( - self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: Linear, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: """Batch-apply transposed Jacobian of the output w.r.t. the input. @@ -42,13 +47,16 @@ def _jac_t_mat_prod( g_out: Gradients w.r.t. module output. Not required by the implementation. mat: Batch of ``V`` vectors of same shape as the layer output (``[N, *, out_features]``) to which the transposed output-input Jacobian - is applied. Has shape ``[V, N, *, out_features]``. + is applied. Has shape ``[V, N, *, out_features]``; but if used with + sub-sampling, ``N`` is replaced by ``len(subsampling)``. + subsampling: Indices of active samples. ``None`` means all samples. Returns: Batched transposed Jacobian vector products. Has shape - ``[V, N, *, in_features]``. + ``[V, N, *, in_features]``. If used with sub-sampling, ``N`` is replaced + by ``len(subsampling)``. """ - return einsum("oi,vn...o->vn...i", module.weight, mat) + return einsum("vn...o,oi->vn...i", mat, module.weight) def _jac_mat_prod( self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor diff --git a/backpack/core/derivatives/logsigmoid.py b/backpack/core/derivatives/logsigmoid.py index 78a882241..917784010 100644 --- a/backpack/core/derivatives/logsigmoid.py +++ b/backpack/core/derivatives/logsigmoid.py @@ -1,6 +1,11 @@ -from torch import exp +"""Contains partial derivatives for the ``torch.nn.LogSigmoid`` layer.""" +from typing import List, Tuple + +from torch import Tensor, exp +from torch.nn import LogSigmoid from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class LogSigmoidDerivatives(ElementwiseDerivatives): @@ -8,9 +13,16 @@ def hessian_is_zero(self, module): """`logsigmoid''(x) ≠ 0`.""" return False - def df(self, module, g_inp, g_out): + def df( + self, + module: LogSigmoid, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: """First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1) `.""" - return 1 / (exp(module.input0) + 1) + input0 = subsample(module.input0, subsampling=subsampling) + return 1 / (exp(input0) + 1) def d2f(self, module, g_inp, g_out): """Second Logsigmoid derivative: `logsigmoid''(x) = - e^x / (e^x + 1)^2`.""" diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 8eb88386d..de774c5fa 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -275,11 +275,18 @@ def _jac_mat_prod( return H_prod def _jac_t_mat_prod( - self, module: LSTM, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: LSTM, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) X_prod: Tensor = einsum( "vtnh,hi->vtni", diff --git a/backpack/core/derivatives/maxpoolnd.py b/backpack/core/derivatives/maxpoolnd.py index ac39765a2..dc1c6af59 100644 --- a/backpack/core/derivatives/maxpoolnd.py +++ b/backpack/core/derivatives/maxpoolnd.py @@ -1,28 +1,31 @@ +from typing import List, Tuple, Union + from einops import rearrange -from torch import zeros +from torch import Tensor, zeros +from torch.nn import MaxPool1d, MaxPool2d, MaxPool3d from torch.nn.functional import max_pool1d, max_pool2d, max_pool3d from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.subsampling import subsample class MaxPoolNDDerivatives(BaseDerivatives): - def __init__(self, N): + def __init__(self, N: int): self.N = N - if self.N == 1: - self.maxpool = max_pool1d - elif self.N == 2: - self.maxpool = max_pool2d - elif self.N == 3: - self.maxpool = max_pool3d - else: - raise ValueError( - "{}-dimensional Maxpool. is not implemented.".format(self.N) - ) + self.maxpool = { + 1: max_pool1d, + 2: max_pool2d, + 3: max_pool3d, + }[N] # TODO: Do not recompute but get from forward pass of module - def get_pooling_idx(self, module): + def get_pooling_idx( + self, + module: Union[MaxPool1d, MaxPool2d, MaxPool3d], + subsampling: List[int] = None, + ) -> Tensor: _, pool_idx = self.maxpool( - module.input0, + subsample(module.input0, subsampling=subsampling), kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, @@ -48,22 +51,9 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): """ device = mat.device - if self.N == 1: - N, C, L_in = module.input0.size() - _, _, L_out = module.output.size() - in_pixels = L_in - out_pixels = L_out - elif self.N == 2: - N, C, H_in, W_in = module.input0.size() - _, _, H_out, W_out = module.output.size() - in_pixels = H_in * W_in - out_pixels = H_out * W_out - elif self.N == 3: - N, C, D_in, H_in, W_in = module.input0.size() - _, _, D_out, H_out, W_out = module.output.size() - in_pixels = D_in * H_in * W_in - out_pixels = D_out * H_out * W_out - + N, C = module.input0.shape[:2] + in_pixels = module.input0.shape[2:].numel() + out_pixels = module.output.shape[2:].numel() in_features = C * in_pixels pool_idx = self.get_pooling_idx(module).view(N, C, out_pixels) @@ -102,46 +92,56 @@ def __apply_jacobian_of(self, module, mat): pool_idx = self.__pool_idx_for_jac(module, V) return mat.gather(N_axis, pool_idx) - def __pool_idx_for_jac(self, module, V): + def __pool_idx_for_jac( + self, + module: Union[MaxPool1d, MaxPool2d, MaxPool3d], + V: int, + subsampling: List[int] = None, + ) -> Tensor: """Manipulated pooling indices ready-to-use in jac(t).""" - pool_idx = self.get_pooling_idx(module) + pool_idx = self.get_pooling_idx(module, subsampling=subsampling) pool_idx = rearrange(pool_idx, "n c ... -> n c (...)") - V_axis = 0 - - return pool_idx.unsqueeze(V_axis).expand(V, -1, -1, -1) + return pool_idx.unsqueeze(0).expand(V, -1, -1, -1) - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod( + self, + module: Union[MaxPool1d, MaxPool2d, MaxPool3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: mat_as_pool = rearrange(mat, "v n c ... -> v n c (...)") - jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool) - return self.reshape_like_input(jmp_as_pool, module) - - def __apply_jacobian_t_of(self, module, mat): + jmp_as_pool = self.__apply_jacobian_t_of( + module, mat_as_pool, subsampling=subsampling + ) + return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling) + + def __apply_jacobian_t_of( + self, + module: Union[MaxPool1d, MaxPool2d, MaxPool3d], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: V = mat.shape[0] - result = self.__zero_for_jac_t(module, V, mat.device) - pool_idx = self.__pool_idx_for_jac(module, V) + result = self.__zero_for_jac_t(module, V, subsampling=subsampling) + pool_idx = self.__pool_idx_for_jac(module, V, subsampling=subsampling) N_axis = 3 result.scatter_add_(N_axis, pool_idx, mat) return result - def __zero_for_jac_t(self, module, V, device): - if self.N == 1: - N, C_out, _ = module.output.shape - _, _, L_in = module.input0.size() - - shape = (V, N, C_out, L_in) - - elif self.N == 2: - N, C_out, _, _ = module.output.shape - _, _, H_in, W_in = module.input0.size() - - shape = (V, N, C_out, H_in * W_in) - - elif self.N == 3: - N, C_out, _, _, _ = module.output.shape - _, _, D_in, H_in, W_in = module.input0.size() + def __zero_for_jac_t( + self, + module: Union[MaxPool1d, MaxPool2d, MaxPool3d], + V: int, + subsampling: List[int] = None, + ) -> Tensor: + N, C_out = module.output.shape[:2] + in_pixels = module.input0.shape[2:].numel() + N = N if subsampling is None else len(subsampling) - shape = (V, N, C_out, D_in * H_in * W_in) + shape = (V, N, C_out, in_pixels) - return zeros(shape, device=device) + return zeros(shape, device=module.output.device, dtype=module.output.dtype) diff --git a/backpack/core/derivatives/permute.py b/backpack/core/derivatives/permute.py index a4654b9bf..396803876 100644 --- a/backpack/core/derivatives/permute.py +++ b/backpack/core/derivatives/permute.py @@ -1,5 +1,5 @@ """Module containing derivatives of Permute.""" -from typing import Tuple +from typing import List, Tuple from torch import Tensor, argsort @@ -11,7 +11,12 @@ class PermuteDerivatives(BaseDerivatives): """Derivatives of Permute.""" def _jac_t_mat_prod( - self, module: Permute, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: Permute, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: return mat.permute( [0] + [element + 1 for element in argsort(Tensor(module.dims))] diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py index 1aa775fbb..18dab75fa 100644 --- a/backpack/core/derivatives/relu.py +++ b/backpack/core/derivatives/relu.py @@ -1,6 +1,11 @@ -from torch import gt +"""Partial derivatives for the ReLU activation function.""" +from typing import List, Tuple + +from torch import Tensor, gt +from torch.nn import ReLU from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class ReLUDerivatives(ElementwiseDerivatives): @@ -8,6 +13,13 @@ def hessian_is_zero(self, module): """`ReLU''(x) = 0`.""" return True - def df(self, module, g_inp, g_out): + def df( + self, + module: ReLU, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`.""" - return gt(module.input0, 0).float() + input0 = subsample(module.input0, subsampling=subsampling) + return gt(input0, 0).to(input0.dtype) diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 2277c0380..59867da3a 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -89,12 +89,23 @@ def _a_jac_t_mat_prod( return a_jac_t_mat_prod def _jac_t_mat_prod( - self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor + self, + module: RNN, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) return torch.einsum( "vtnh,hk->vtnk", - self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + self._a_jac_t_mat_prod( + subsample( + module.output, dim=get_batch_axis(module), subsampling=subsampling + ), + module.weight_hh_l0, + mat, + ), module.weight_ih_l0, ) diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py index fda077a5d..b6e1c6852 100644 --- a/backpack/core/derivatives/selu.py +++ b/backpack/core/derivatives/selu.py @@ -1,7 +1,11 @@ """Partial derivatives for the SELU activation function.""" -from torch import exp, le, ones_like, zeros_like +from typing import List, Tuple + +from torch import Tensor, exp, le, ones_like, zeros_like +from torch.nn import SELU from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class SELUDerivatives(ElementwiseDerivatives): @@ -14,12 +18,19 @@ def hessian_is_zero(self, module): """`SELU''(x) != 0`.""" return False - def df(self, module, g_inp, g_out): + def df( + self, + module: SELU, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: """First SELU derivative: `SELU'(x) = scale if x > 0 else scale*alpha*e^x`.""" - non_pos = le(module.input0, 0) + input0 = subsample(module.input0, subsampling=subsampling) + non_pos = le(input0, 0) - result = self.scale * ones_like(module.input0) - result[non_pos] = self.scale * self.alpha * exp(module.input0[non_pos]) + result = self.scale * ones_like(input0) + result[non_pos] = self.scale * self.alpha * exp(input0[non_pos]) return result diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index c5bdc6e2b..e9d9ee056 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -61,9 +61,11 @@ def _check_same_V_dim(mat1, mat2): def _check_like(mat, module, name, diff=1, *args, **kwargs): - if name == "output" and "subsampling" in kwargs.keys(): + if name in ["output", "input0"] and "subsampling" in kwargs.keys(): compare = subsample( - module.output, dim=get_batch_axis(module), subsampling=kwargs["subsampling"] + getattr(module, name), + dim=get_batch_axis(module), + subsampling=kwargs["subsampling"], ) else: compare = getattr(module, name) diff --git a/backpack/core/derivatives/sigmoid.py b/backpack/core/derivatives/sigmoid.py index c5b5437fe..f03573e57 100644 --- a/backpack/core/derivatives/sigmoid.py +++ b/backpack/core/derivatives/sigmoid.py @@ -1,4 +1,11 @@ +"""Partial derivatives for the Sigmoid activation function.""" +from typing import List, Tuple + +from torch import Tensor +from torch.nn import Sigmoid + from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class SigmoidDerivatives(ElementwiseDerivatives): @@ -6,9 +13,16 @@ def hessian_is_zero(self, module): """`σ''(x) ≠ 0`.""" return False - def df(self, module, g_inp, g_out): + def df( + self, + module: Sigmoid, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: """First sigmoid derivative: `σ'(x) = σ(x) (1 - σ(x))`.""" - return module.output * (1.0 - module.output) + output = subsample(module.output, subsampling=subsampling) + return output * (1.0 - output) def d2f(self, module, g_inp, g_out): """Second sigmoid derivative: `σ''(x) = σ(x) (1 - σ(x)) (1 - 2 σ(x))`.""" diff --git a/backpack/core/derivatives/tanh.py b/backpack/core/derivatives/tanh.py index 4bf99c849..1fd6c9c14 100644 --- a/backpack/core/derivatives/tanh.py +++ b/backpack/core/derivatives/tanh.py @@ -1,12 +1,26 @@ +"""Partial derivatives for the Tanh activation function.""" +from typing import List, Tuple + +from torch import Tensor +from torch.nn import Tanh + from backpack.core.derivatives.elementwise import ElementwiseDerivatives +from backpack.utils.subsampling import subsample class TanhDerivatives(ElementwiseDerivatives): def hessian_is_zero(self, module): return False - def df(self, module, g_inp, g_out): - return 1.0 - module.output ** 2 + def df( + self, + module: Tanh, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: + output = subsample(module.output, subsampling=subsampling) + return 1.0 - output ** 2 def d2f(self, module, g_inp, g_out): return -2.0 * module.output * (1.0 - module.output ** 2) diff --git a/backpack/core/derivatives/zeropad2d.py b/backpack/core/derivatives/zeropad2d.py index 6e121e099..07af6c95e 100644 --- a/backpack/core/derivatives/zeropad2d.py +++ b/backpack/core/derivatives/zeropad2d.py @@ -1,5 +1,9 @@ +"""Partial derivatives for the ZeroPad2d function.""" +from typing import List, Tuple + from einops import rearrange -from torch.nn import functional +from torch import Tensor +from torch.nn import ZeroPad2d, functional from backpack.core.derivatives.basederivatives import BaseDerivatives @@ -27,7 +31,14 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): return result.view(in_features, in_features) - def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod( + self, + module: ZeroPad2d, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: (W_top, W_bottom), (H_bottom, H_top) = self.__unpad_indices(module) return mat[:, :, :, W_top:W_bottom, H_bottom:H_top] diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py index 3aa4b9a2e..2cc1c5adb 100644 --- a/backpack/utils/conv.py +++ b/backpack/utils/conv.py @@ -1,12 +1,44 @@ -from typing import Union +from typing import Callable, Type, Union import torch from einops import rearrange from torch import Tensor, einsum -from torch.nn import Conv1d, Conv2d, Conv3d +from torch.nn import Conv1d, Conv2d, Conv3d, Module from torch.nn.functional import conv1d, conv2d, conv3d, unfold +def get_conv_module(N: int) -> Type[Module]: + """Return the PyTorch module class of N-dimensional convolution. + + Args: + N: Convolution dimension. + + Returns: + Convolution class. + """ + return { + 1: Conv1d, + 2: Conv2d, + 3: Conv3d, + }[N] + + +def get_conv_function(N: int) -> Callable: + """Return the PyTorch function of N-dimensional convolution. + + Args: + N: Convolution dimension. + + Returns: + Convolution function. + """ + return { + 1: conv1d, + 2: conv2d, + 3: conv3d, + }[N] + + def unfold_input(module: Union[Conv1d, Conv2d, Conv3d], input: Tensor) -> Tensor: """Return unfolded input to a convolution. @@ -118,15 +150,9 @@ def make_weight(): repeat = [C_in, 1] + [1 for _ in kernel_size] return weight.repeat(*repeat) - def get_conv(): - functional_for_module_cls = { - torch.nn.Conv1d: conv1d, - torch.nn.Conv2d: conv2d, - torch.nn.Conv3d: conv3d, - } - return functional_for_module_cls[module.__class__] + conv_dim = input.dim() - 2 + conv = get_conv_function(conv_dim) - conv = get_conv() unfold = conv( input, make_weight().to(input.device), diff --git a/backpack/utils/conv_transpose.py b/backpack/utils/conv_transpose.py index 6b4d138ec..3c90834be 100644 --- a/backpack/utils/conv_transpose.py +++ b/backpack/utils/conv_transpose.py @@ -1,13 +1,48 @@ """Utility functions for extracting transpose convolution BackPACK quantities.""" +from typing import Callable, Type + import torch from einops import rearrange from torch import einsum +from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module from torch.nn.functional import conv_transpose1d, conv_transpose2d, conv_transpose3d from backpack.utils.conv import extract_bias_diagonal as conv_extract_bias_diagonal +def get_conv_transpose_module(N: int) -> Type[Module]: + """Return the PyTorch module class of N-dimensional transpose convolution. + + Args: + N: Transpose convolution dimension. + + Returns: + Transpose convolution class. + """ + return { + 1: ConvTranspose1d, + 2: ConvTranspose2d, + 3: ConvTranspose3d, + }[N] + + +def get_conv_transpose_function(N: int) -> Callable: + """Return the PyTorch function of N-dimensional transpose convolution. + + Args: + N: Transpose convolution dimension. + + Returns: + Transpose convolution function. + """ + return { + 1: conv_transpose1d, + 2: conv_transpose2d, + 3: conv_transpose3d, + }[N] + + def get_weight_gradient_factors(input, grad_out, module): M, C_in = input.shape[0], input.shape[1] kernel_size_numel = module.weight.shape[2:].numel() @@ -109,15 +144,9 @@ def make_weight(): weight = weight.repeat(*repeat) return weight.to(module.weight.device) - def get_conv_transpose(): - functional_for_module_cls = { - torch.nn.ConvTranspose1d: conv_transpose1d, - torch.nn.ConvTranspose2d: conv_transpose2d, - torch.nn.ConvTranspose3d: conv_transpose3d, - } - return functional_for_module_cls[module.__class__] + conv_dim = input.dim() - 2 + conv_transpose = get_conv_transpose_function(conv_dim) - conv_transpose = get_conv_transpose() unfold = conv_transpose( input, make_weight().to(module.weight.device), diff --git a/test/bugfixes_test.py b/test/bugfixes_test.py index 088ecd907..ba4ef8a19 100644 --- a/test/bugfixes_test.py +++ b/test/bugfixes_test.py @@ -4,6 +4,7 @@ import torch import backpack +from backpack.core.derivatives.convnd import weight_jac_t_save_memory def parameters_issue_30(): @@ -31,7 +32,12 @@ def parameters_issue_30(): @pytest.mark.parametrize("params", **parameters_issue_30()) -def test_convolutions_stride_issue_30(params): +@pytest.mark.parametrize( + "save_memory", + [True, False], + ids=["save_memory=True", "save_memory=False"], +) +def test_convolutions_stride_issue_30(params, save_memory): """ https://github.com/f-dangel/backpack/issues/30 @@ -51,7 +57,9 @@ def test_convolutions_stride_issue_30(params): backpack.extend(mod) x = torch.randn(size=(params["N"], params["C_in"], params["W"], params["H"])) - with backpack.backpack(backpack.extensions.BatchGrad()): + with weight_jac_t_save_memory(save_memory), backpack.backpack( + backpack.extensions.BatchGrad() + ): loss = torch.sum(mod(x)) loss.backward() diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 1f3ebb9b2..4e67a5882 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -17,14 +17,18 @@ from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS from test.core.derivatives.settings import SETTINGS -from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda +from test.utils.skip_test import ( + skip_adaptive_avg_pool3d_cuda, + skip_batch_norm_train_mode_with_subsampling, + skip_no_param, + skip_permute_with_subsampling, + skip_subsampling_conflict, +) from typing import List, Tuple, Union from warnings import warn -import pytest -import torch -from pytest import fixture, skip -from torch import Tensor +from pytest import fixture, mark, raises, skip +from torch import Tensor, rand from backpack.core.derivatives.convnd import weight_jac_t_save_memory from backpack.utils.subsampling import get_batch_axis @@ -58,7 +62,7 @@ SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] -@pytest.mark.parametrize( +@mark.parametrize( "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS @@ -75,7 +79,7 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() - mat = torch.rand(V, *problem.input_shape).to(problem.device) + mat = rand(V, *problem.input_shape).to(problem.device) backpack_res = BackpackDerivatives(problem).jac_mat_prod(mat) autograd_res = AutogradDerivatives(problem).jac_mat_prod(mat) @@ -84,7 +88,8 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: problem.tear_down() -@pytest.mark.parametrize( +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize( "problem", NO_LOSS_PROBLEMS + RNN_PROBLEMS @@ -93,21 +98,34 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + BATCH_NORM_PROBLEMS, ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS + BATCH_NORM_IDS, ) -def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> None: +def test_jac_t_mat_prod( + problem: DerivativesTestProblem, + subsampling: Union[None, List[int]], + request, + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product. Args: problem: Problem for derivative test. + subsampling: Indices of active samples. request: Pytest request, used for getting id. V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ skip_adaptive_avg_pool3d_cuda(request) problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + skip_permute_with_subsampling(problem, subsampling) + skip_batch_norm_train_mode_with_subsampling(problem, subsampling) + skip_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling) - backpack_res = BackpackDerivatives(problem).jac_t_mat_prod(mat) - autograd_res = AutogradDerivatives(problem).jac_t_mat_prod(mat) + backpack_res = BackpackDerivatives(problem).jac_t_mat_prod( + mat, subsampling=subsampling + ) + autograd_res = AutogradDerivatives(problem).jac_t_mat_prod( + mat, subsampling=subsampling + ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() @@ -121,13 +139,9 @@ def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> IDS_WITH_WEIGHTS.append(problem_id) -@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( - "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS -) +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) def test_bias_ih_l0_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, @@ -143,8 +157,8 @@ def test_bias_ih_l0_jac_t_mat_prod( V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - _skip_if_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) + skip_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling) autograd_res = AutogradDerivatives(problem).bias_ih_l0_jac_t_mat_prod( mat, sum_batch, subsampling=subsampling @@ -157,13 +171,9 @@ def test_bias_ih_l0_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( - "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS -) +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) def test_bias_hh_l0_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, @@ -179,8 +189,8 @@ def test_bias_hh_l0_jac_t_mat_prod( V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - _skip_if_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) + skip_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling) autograd_res = AutogradDerivatives(problem).bias_hh_l0_jac_t_mat_prod( mat, sum_batch, subsampling=subsampling @@ -193,13 +203,9 @@ def test_bias_hh_l0_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( - "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS -) +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) def test_weight_ih_l0_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, @@ -215,8 +221,8 @@ def test_weight_ih_l0_jac_t_mat_prod( V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - _skip_if_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) + skip_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling) autograd_res = AutogradDerivatives(problem).weight_ih_l0_jac_t_mat_prod( mat, sum_batch, subsampling=subsampling @@ -229,13 +235,9 @@ def test_weight_ih_l0_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( - "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS -) +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) def test_weight_hh_l0_jac_t_mat_prod( problem: DerivativesTestProblem, sum_batch: bool, @@ -251,8 +253,8 @@ def test_weight_hh_l0_jac_t_mat_prod( V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - _skip_if_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) + skip_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling) autograd_res = AutogradDerivatives(problem).weight_hh_l0_jac_t_mat_prod( mat, sum_batch, subsampling=subsampling @@ -265,10 +267,8 @@ def test_weight_hh_l0_jac_t_mat_prod( problem.tear_down() -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) -@pytest.mark.parametrize( +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +@mark.parametrize( "save_memory", [True, False], ids=["save_memory=True", "save_memory=False"], @@ -321,10 +321,10 @@ def rand_mat_like_output( N_axis = get_batch_axis(problem.module) subsample_shape[N_axis] = len(subsampling) - return torch.rand(V, *subsample_shape) + return rand(V, *subsample_shape, device=problem.device) -@pytest.mark.parametrize( +@mark.parametrize( "problem", PROBLEMS_WITH_WEIGHTS + BATCH_NORM_PROBLEMS, ids=IDS_WITH_WEIGHTS + BATCH_NORM_IDS, @@ -337,7 +337,7 @@ def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> Non V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() - mat = torch.rand(V, *problem.module.weight.shape).to(problem.device) + mat = rand(V, *problem.module.weight.shape).to(problem.device) backpack_res = BackpackDerivatives(problem).weight_jac_mat_prod(mat) autograd_res = AutogradDerivatives(problem).weight_jac_mat_prod(mat) @@ -354,9 +354,7 @@ def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> Non IDS_WITH_BIAS.append(problem_id) -@pytest.mark.parametrize( - "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] -) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) def test_bias_jac_t_mat_prod( problem_bias_jac_t_mat: Tuple[DerivativesTestProblem, List[int], Tensor], sum_batch: bool, @@ -380,7 +378,7 @@ def test_bias_jac_t_mat_prod( check_sizes_and_values(autograd_res, backpack_res) -@pytest.mark.parametrize( +@mark.parametrize( "problem", PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS, ids=IDS_WITH_BIAS + BATCH_NORM_IDS, @@ -393,7 +391,7 @@ def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() - mat = torch.rand(V, *problem.module.bias.shape).to(problem.device) + mat = rand(V, *problem.module.bias.shape).to(problem.device) backpack_res = BackpackDerivatives(problem).bias_jac_mat_prod(mat) autograd_res = AutogradDerivatives(problem).bias_jac_mat_prod(mat) @@ -402,7 +400,7 @@ def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: problem.tear_down() -@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) +@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) def test_sqrt_hessian_squared_equals_hessian(problem): """Test the sqrt decomposition of the input Hessian. @@ -420,18 +418,18 @@ def test_sqrt_hessian_squared_equals_hessian(problem): problem.tear_down() -@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) +@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_should_fail(problem): """Test sqrt_hessian. Should fail. Args: problem: test problem """ - with pytest.raises(ValueError): + with raises(ValueError): test_sqrt_hessian_squared_equals_hessian(problem) -@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) +@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=100000): """Test the MC-sampled sqrt decomposition of the input Hessian. @@ -453,18 +451,18 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=1 problem.tear_down() -@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) +@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_sampled_should_fail(problem): """Test sqrt_hessian. Should fail. Args: problem: test problem """ - with pytest.raises(ValueError): + with raises(ValueError): test_sqrt_hessian_sampled_squared_approximates_hessian(problem) -@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) +@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) def test_sum_hessian(problem): """Test the summed Hessian. @@ -480,18 +478,18 @@ def test_sum_hessian(problem): problem.tear_down() -@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) +@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sum_hessian_should_fail(problem): """Test sum_hessian, should fail. Args: problem: test problem """ - with pytest.raises(ValueError): + with raises(ValueError): test_sum_hessian(problem) -@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) +@mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None: """Test KFRA backpropagation. @@ -511,7 +509,7 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None problem.set_up() out_features = problem.output_shape[1:].numel() - mat = torch.rand(out_features, out_features).to(problem.device) + mat = rand(out_features, out_features).to(problem.device) backpack_res = BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat) autograd_res = AutogradDerivatives(problem).ea_jac_t_mat_jac_prod(mat) @@ -546,7 +544,7 @@ def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a weight parameter. """ - _skip_if_no_param(problem, "weight") + skip_no_param(problem, "weight") yield problem @@ -567,44 +565,15 @@ def problem_weight_jac_t_mat( problem with weight, subsampling, matrix for weight_jac_t """ subsampling: Union[None, List[int]] = request.param - _skip_if_subsampling_conflict(problem_weight, subsampling) + skip_subsampling_conflict(problem_weight, subsampling) V = 3 - mat = rand_mat_like_output(V, problem_weight, subsampling=subsampling).to( - problem_weight.device - ) + mat = rand_mat_like_output(V, problem_weight, subsampling=subsampling) yield (problem_weight, subsampling, mat) del mat -def _skip_if_subsampling_conflict( - problem: DerivativesTestProblem, subsampling: Union[List[int], None] -) -> None: - """Skip if some samples in subsampling are not contained in input. - - Args: - problem: Test case. - subsampling: Indices of active samples. - """ - N = problem.input_shape[get_batch_axis(problem.module)] - enough_samples = subsampling is None or N >= max(subsampling) - if not enough_samples: - skip("Not enough samples.") - - -def _skip_if_no_param(problem: DerivativesTestProblem, param_str: str) -> None: - """Skip if test case does not contain the parameter. - - Args: - problem: Test case. - param_str: Parameter name. - """ - has_param = getattr(problem.module, param_str, None) is not None - if not has_param: - skip(f"Test case has no {param_str} parameter.") - - @fixture def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: """Filter out cases that don't have a bias parameter. @@ -615,7 +584,7 @@ def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a bias parameter. """ - _skip_if_no_param(problem, "bias") + skip_no_param(problem, "bias") yield problem @@ -636,12 +605,10 @@ def problem_bias_jac_t_mat( problem with bias, subsampling, matrix for bias_jac_t """ subsampling: Union[None, List[int]] = request.param - _skip_if_subsampling_conflict(problem_bias, subsampling) + skip_subsampling_conflict(problem_bias, subsampling) V = 3 - mat = rand_mat_like_output(V, problem_bias, subsampling=subsampling).to( - problem_bias.device - ) + mat = rand_mat_like_output(V, problem_bias, subsampling=subsampling) yield (problem_bias, subsampling, mat) del mat diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 68060e990..230298c10 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -3,11 +3,12 @@ from typing import List import torch -from torch import Tensor, stack, zeros_like +from torch import Tensor, cat, stack, zeros_like from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product from backpack.hessianfree.rop import jacobian_vector_product +from backpack.utils.subsampling import get_batch_axis, subsample class AutogradDerivatives(DerivativesImplementation): @@ -35,12 +36,32 @@ def jac_mat_prod(self, mat): # noqa: D102 with torch.backends.cudnn.flags(enabled=False): return stack([self.jac_vec_prod(vec) for vec in mat]) - def jac_t_vec_prod(self, vec): # noqa: D102 + def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor: # noqa: D102 input, output, _ = self.problem.forward_pass(input_requires_grad=True) - return transposed_jacobian_vector_product(output, input, vec)[0] - def jac_t_mat_prod(self, mat): # noqa: D102 - return stack([self.jac_t_vec_prod(vec) for vec in mat]) + if subsampling is None: + return transposed_jacobian_vector_product(output, input, vec)[0] + else: + # for each sample, multiply by full input Jacobian, slice out result: + # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n] + batch_axis = get_batch_axis(self.problem.module) + output = subsample(output, dim=batch_axis, subsampling=subsampling) + output = output.split(1, dim=batch_axis) + vec = vec.split(1, dim=batch_axis) + + vjps: List[Tensor] = [] + + for sample_idx, out, v in zip(subsampling, output, vec): + vjp = transposed_jacobian_vector_product(out, input, v)[0] + vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx]) + vjps.append(vjp) + + return cat(vjps, dim=batch_axis) + + def jac_t_mat_prod( + self, mat: Tensor, subsampling: List[int] = None + ) -> Tensor: # noqa: D102 + return stack([self.jac_t_vec_prod(vec, subsampling=subsampling) for vec in mat]) def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 return self.param_jac_t_mat_prod( diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index cced2fc93..5daeb01e9 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -1,5 +1,6 @@ """Contains derivative calculation with BackPACK.""" from test.core.derivatives.implementation.base import DerivativesImplementation +from typing import List import torch from torch import Tensor @@ -30,10 +31,12 @@ def jac_mat_prod(self, mat): # noqa: D102 self.problem.module, None, None, mat ) - def jac_t_mat_prod(self, mat): # noqa: D102 + def jac_t_mat_prod( + self, mat: Tensor, subsampling: List[int] + ) -> Tensor: # noqa: D102 self.store_forward_io() return self.problem.derivative.jac_t_mat_prod( - self.problem.module, None, None, mat + self.problem.module, None, None, mat, subsampling=subsampling ) def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index f64115357..cc875b409 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -30,11 +30,12 @@ def jac_mat_prod(self, mat: Tensor) -> Tensor: raise NotImplementedError @abstractmethod - def jac_t_mat_prod(self, mat: Tensor) -> Tensor: + def jac_t_mat_prod(self, mat: Tensor, subsampling: List[int] = None) -> Tensor: """Vectorized product of transposed jacobian and matrix. Args: mat: matrix: the vectors along its leading dimension will be multiplied. + subsampling: Active samples in the output. Default: ``None`` (all). Returns: Tensor representing the result of Jacobian-vector product. diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index a9744526c..f300a4b25 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -1,7 +1,14 @@ """Skip specific tests.""" -import pytest +from test.core.derivatives.problem import DerivativesTestProblem +from typing import List, Union + +from pytest import skip +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + +from backpack.custom_module.permute import Permute from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 +from backpack.utils.subsampling import get_batch_axis def skip_adaptive_avg_pool3d_cuda(request) -> None: @@ -17,7 +24,63 @@ def skip_adaptive_avg_pool3d_cuda(request) -> None: string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] ): - pytest.skip( + skip( "Skip test because AdaptiveAvgPool3d does not work on cuda. " "Is fixed in torch 1.9.1." ) + + +def skip_permute_with_subsampling( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip Permute module when sub-sampling is turned on. + + Permute does not assume a batch axis. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + if isinstance(problem.module, Permute) and subsampling is not None: + skip(f"Skipping Permute with sub-sampling: {subsampling}") + + +def skip_batch_norm_train_mode_with_subsampling( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip BatchNorm in train mode when sub-sampling is turned on. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + if isinstance(problem.module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)): + if problem.module.train and subsampling is not None: + skip(f"Skipping BatchNorm in train mode with sub-sampling: {subsampling}") + + +def skip_subsampling_conflict( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip if some samples in subsampling are not contained in input. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + N = problem.input_shape[get_batch_axis(problem.module)] + enough_samples = subsampling is None or N > max(subsampling) + if not enough_samples: + skip("Not enough samples.") + + +def skip_no_param(problem: DerivativesTestProblem, param_str: str) -> None: + """Skip if test case does not contain the parameter. + + Args: + problem: Test case. + param_str: Parameter name. + """ + has_param = getattr(problem.module, param_str, None) is not None + if not has_param: + skip(f"Test case has no {param_str} parameter.") From cd043f740c5f89963486eef58511d6b4a5e36a23 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 29 Jul 2021 08:51:11 +0200 Subject: [PATCH 36/54] [core] Support sub-sampling in loss Hessian factorization (#207) Prepares the core for #12. Adds `subsampling` argument to loss Hessian factorizations in the core derivatives. Auxiliary: - `test`: Make MC tolerances stricter, add chunking to deal with more samples - `core`: Slightly rewrite `sqrt_hessian` functions for `MSELoss` --- * [ADD] Support sub-sampling in loss Hessian factorization Make tolerances for MC sampling more strict and include chunked processing. * [FIX] Remove unused imports or make them more explicit * [FIX] Docstrings * [REF] Fix dtypes, slighly improve performance of sqrt_hessian * [REF] Fix dtype * [REF] Minor rewrite --- backpack/core/derivatives/basederivatives.py | 61 ++++++++++---- backpack/core/derivatives/crossentropyloss.py | 43 ++++++++-- backpack/core/derivatives/mseloss.py | 72 +++++++++-------- fully_documented.txt | 1 + test/core/derivatives/derivatives_test.py | 64 ++++++++++----- .../derivatives/implementation/autograd.py | 74 +++++++++++++---- .../derivatives/implementation/backpack.py | 79 ++++++++++++------- test/extensions/implementation/backpack.py | 31 +------- test/utils/__init__.py | 27 +++++++ 9 files changed, 308 insertions(+), 144 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 523955bdb..6f4173f1e 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -651,23 +651,41 @@ class BaseLossDerivatives(BaseDerivatives, ABC): # TODO Add shape check def sqrt_hessian( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, ) -> Tensor: """Symmetric factorization ('sqrt') of the loss Hessian. + The Hessian factorization is returned in format ``Hs = [D, N, D]``, where + ``Hs[:, n, :]`` is the Hessian factorization for the ``n``th sample, i.e. + ``Hs[:, n, :]ᵀ Hs[:, n, :]`` is the Hessian w.r.t. to the ``n``th sample. + Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients + module: Loss layer whose factorized Hessian will be computed. + g_inp: Gradients w.r.t. module input. + g_out: Gradients w.r.t. module output. + subsampling: Indices of data samples to be considered. Default of ``None`` + uses all data in the mini-batch. Returns: - square root of hessian + Symmetric factorization of the loss Hessian for each sample. If the input + to the loss has shape ``[N, D]``, this is a tensor of shape ``[D, N, D]``; + if used with sub-sampling, ``N`` is replaced by ``len(subsampling)``. + For fixed ``n``, squaring the matrix implied by the slice ``[:, n, :]`` + results in the loss Hessian w.r.t. to sample ``n``. """ self._check_2nd_order_make_sense(module, g_out) - return self._sqrt_hessian(module, g_inp, g_out) + return self._sqrt_hessian(module, g_inp, g_out, subsampling=subsampling) def _sqrt_hessian( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -678,20 +696,34 @@ def sqrt_hessian_sampled( g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mc_samples: int = 1, + subsampling: List[int] = None, ) -> Tensor: - """Monte-Carlo sampled symmetric factorization of the loss Hessian. + """A Monte-Carlo sampled symmetric factorization of the loss Hessian. + + The Hessian factorization is returned in format ``Hs = [M, N, D]``, where + ``Hs[:, n, :]`` approximates the Hessian factorization for the ``n``th sample, + i.e. ``Hs[:, n, :]ᵀ Hs[:, n, :]ᵀ`` approximates the Hessian w.r.t. to sample + ``n``. Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mc_samples: number of monte carlo samples. Defaults to 1. + module: Loss layer whose factorized Hessian will be computed. + g_inp: Gradients w.r.t. module input. + g_out: Gradients w.r.t. module output. + mc_samples: Number of samples used for MC approximation. + subsampling: Indices of data samples to be considered. Default of ``None`` + uses all data in the mini-batch. Returns: - square root of hessian + Symmetric factorization of the loss Hessian for each sample. If the input + to the loss has shape ``[N, D]``, this is a tensor of shape ``[M, N, D]`` + when using ``M`` MC samples; if used with sub-sampling, ``N`` is replaced + by ``len(subsampling)``. For fixed ``n``, squaring the matrix implied by the + slice ``[:, n, :]`` approximates the loss Hessian w.r.t. to sample ``n``. """ self._check_2nd_order_make_sense(module, g_out) - return self._sqrt_hessian_sampled(module, g_inp, g_out, mc_samples=mc_samples) + return self._sqrt_hessian_sampled( + module, g_inp, g_out, mc_samples=mc_samples, subsampling=subsampling + ) def _sqrt_hessian_sampled( self, @@ -699,6 +731,7 @@ def _sqrt_hessian_sampled( g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mc_samples: int = 1, + subsampling=None, ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py index 8ead9f404..1a596b39f 100644 --- a/backpack/core/derivatives/crossentropyloss.py +++ b/backpack/core/derivatives/crossentropyloss.py @@ -1,11 +1,14 @@ """Partial derivatives for cross-entropy loss.""" from math import sqrt +from typing import List, Tuple -from torch import diag, diag_embed, einsum, multinomial, ones_like, softmax +from torch import Tensor, diag, diag_embed, einsum, multinomial, ones_like, softmax from torch import sqrt as torchsqrt +from torch.nn import CrossEntropyLoss from torch.nn.functional import one_hot from backpack.core.derivatives.basederivatives import BaseLossDerivatives +from backpack.utils.subsampling import subsample class CrossEntropyLossDerivatives(BaseLossDerivatives): @@ -15,10 +18,16 @@ class CrossEntropyLossDerivatives(BaseLossDerivatives): and negative log-likelihood. """ - def _sqrt_hessian(self, module, g_inp, g_out): + def _sqrt_hessian( + self, + module: CrossEntropyLoss, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: # noqa: D102 self._check_2nd_order_parameters(module) - probs = self._get_probs(module) + probs = self._get_probs(module, subsampling=subsampling) tau = torchsqrt(probs) V_dim, C_dim = 0, 2 Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim) @@ -31,13 +40,20 @@ def _sqrt_hessian(self, module, g_inp, g_out): return sqrt_H - def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): + def _sqrt_hessian_sampled( + self, + module: CrossEntropyLoss, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mc_samples: int = 1, + subsampling: List[int] = None, + ) -> Tensor: # noqa: D102 self._check_2nd_order_parameters(module) M = mc_samples C = module.input0.shape[1] - probs = self._get_probs(module) + probs = self._get_probs(module, subsampling=subsampling) V_dim = 0 probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1) @@ -88,8 +104,21 @@ def hessian_is_psd(self): """Return whether cross-entropy loss Hessian is positive semi-definite.""" return True - def _get_probs(self, module): - return softmax(module.input0, dim=1) + def _get_probs( + self, module: CrossEntropyLoss, subsampling: List[int] = None + ) -> Tensor: + """Compute the softmax probabilities from the module input. + + Args: + module: cross-entropy loss with I/O. + subsampling: Indices of samples to be considered. Default of ``None`` uses + the full mini-batch. + + Returns: + Softmax probabilites + """ + input0 = subsample(module.input0, subsampling=subsampling) + return softmax(input0, dim=1) def _check_2nd_order_parameters(self, module): """Verify that the parameters are supported by 2nd-order quantities. diff --git a/backpack/core/derivatives/mseloss.py b/backpack/core/derivatives/mseloss.py index e18c12b97..f09750052 100644 --- a/backpack/core/derivatives/mseloss.py +++ b/backpack/core/derivatives/mseloss.py @@ -1,8 +1,10 @@ """Derivatives of the MSE Loss.""" from math import sqrt +from typing import List, Tuple -from torch import einsum, eye, normal +from torch import Tensor, eye, normal, ones +from torch.nn import MSELoss from backpack.core.derivatives.basederivatives import BaseLossDerivatives @@ -16,50 +18,52 @@ class MSELossDerivatives(BaseLossDerivatives): `∑ᵢ₌₁ⁿ ‖X[i,∶] − Y[i,∶]‖²`. If `reduce=mean`, the result is divided by `nd`. """ - def _sqrt_hessian(self, module, g_inp, g_out): - """Square-root of the hessian of the MSE for each minibatch elements. - - Returns the Hessian in format `Hs = [D, N, D]`, where - `Hs[:, n, :]` is the Hessian for the `n`th element. - - Attributes: - module: (torch.nn.MSELoss) module - g_inp: Gradient of loss w.r.t. input - g_out: Gradient of loss w.r.t. output - - Returns: - Batch of hessians, in format [D, N, D] - """ + def _sqrt_hessian( + self, + module: MSELoss, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + subsampling: List[int] = None, + ) -> Tensor: # noqa: D102 self.check_input_dims(module) - N, D = module.input0.shape - sqrt_H = sqrt(2) * eye(D, device=module.input0.device) # [D, D] - sqrt_H = sqrt_H.unsqueeze(0).repeat(N, 1, 1) # [N, D, D] - sqrt_H = einsum("nab->anb", sqrt_H) # [D, N, D] + input0: Tensor = module.input0 + N, D = input0.shape + N_active = N if subsampling is None else len(subsampling) + scale = sqrt(2) if module.reduction == "mean": - sqrt_H /= sqrt(module.input0.numel()) + scale /= sqrt(input0.numel()) - return sqrt_H + sqrt_H_diag = scale * ones(D, device=input0.device, dtype=input0.dtype) + sqrt_H = sqrt_H_diag.diag().unsqueeze(1).expand(-1, N_active, -1) - def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): - """A Monte-Carlo estimate of the square-root of the Hessian. + return sqrt_H - Attributes: - module: (torch.nn.MSELoss) module. - g_inp: Gradient of loss w.r.t. input. - g_out: Gradient of loss w.r.t. output. - mc_samples: (int, optional) Number of MC samples to use. Default: 1. + def _sqrt_hessian_sampled( + self, + module: MSELoss, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mc_samples: int = 1, + subsampling: List[int] = None, + ) -> Tensor: + self.check_input_dims(module) - Returns: - tensor: - """ - N, D = module.input0.shape - samples = normal(0, 1, size=[mc_samples, N, D], device=module.input0.device) + input0: Tensor = module.input0 + N, D = input0.shape + N_active = N if subsampling is None else len(subsampling) + samples = normal( + 0, + 1, + size=[mc_samples, N_active, D], + device=input0.device, + dtype=input0.dtype, + ) samples *= sqrt(2) / sqrt(mc_samples) if module.reduction == "mean": - samples /= sqrt(module.input0.numel()) + samples /= sqrt(input0.numel()) return samples diff --git a/fully_documented.txt b/fully_documented.txt index 47756b78b..0c9ed68d8 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -79,3 +79,4 @@ test/core/derivatives/pooling_adaptive_settings.py test/core/derivatives/batch_norm_settings.py test/utils/evaluation_mode.py test/utils/skip_test.py +test/utils/__init__.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 4e67a5882..a4366dedd 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -400,66 +400,90 @@ def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: problem.tear_down() +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) -def test_sqrt_hessian_squared_equals_hessian(problem): +def test_sqrt_hessian_squared_equals_hessian( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: """Test the sqrt decomposition of the input Hessian. Args: - problem (DerivativesProblem): Problem for derivative test. + problem: Test case. + subsampling: Indices of active samples. Compares the Hessian to reconstruction from individual Hessian sqrt. """ problem.set_up() + skip_subsampling_conflict(problem, subsampling) - backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian() - autograd_res = AutogradDerivatives(problem).input_hessian() + backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian( + subsampling=subsampling + ) + autograd_res = AutogradDerivatives(problem).input_hessian(subsampling=subsampling) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) -def test_sqrt_hessian_should_fail(problem): - """Test sqrt_hessian. Should fail. +def test_sqrt_hessian_should_fail( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Test that sqrt_hessian fails. Args: - problem: test problem + problem: Test case. + subsampling: Indices of active samples. """ with raises(ValueError): - test_sqrt_hessian_squared_equals_hessian(problem) + test_sqrt_hessian_squared_equals_hessian(problem, subsampling) +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) -def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=100000): +def test_sqrt_hessian_sampled_squared_approximates_hessian( + problem: DerivativesTestProblem, + subsampling: Union[List[int], None], + mc_samples: int = 1000000, + chunks: int = 10, +) -> None: """Test the MC-sampled sqrt decomposition of the input Hessian. - Args: - problem (DerivativesProblem): Problem for derivative test. - mc_samples: number of samples. Defaults to 100000. - Compares the Hessian to reconstruction from individual Hessian MC-sampled sqrt. + + Args: + problem: Test case. + subsampling: Indices of active samples. + mc_samples: number of samples. Defaults to 1000000. + chunks: Number of passes the MC samples will be processed sequentially. """ problem.set_up() + skip_subsampling_conflict(problem, subsampling) backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian( - mc_samples=mc_samples + mc_samples=mc_samples, chunks=chunks, subsampling=subsampling ) - autograd_res = AutogradDerivatives(problem).input_hessian() + autograd_res = AutogradDerivatives(problem).input_hessian(subsampling=subsampling) - RTOL, ATOL = 1e-2, 2e-2 + RTOL, ATOL = 1e-2, 7e-3 check_sizes_and_values(autograd_res, backpack_res, rtol=RTOL, atol=ATOL) problem.tear_down() +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) -def test_sqrt_hessian_sampled_should_fail(problem): - """Test sqrt_hessian. Should fail. +def test_sqrt_hessian_sampled_should_fail( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Test that sqrt_hessian_samples fails. Args: - problem: test problem + problem: Test case. + subsampling: Indices of active samples. """ with raises(ValueError): - test_sqrt_hessian_sampled_squared_approximates_hessian(problem) + test_sqrt_hessian_sampled_squared_approximates_hessian(problem, subsampling) @mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 230298c10..6281ab4ad 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -2,8 +2,7 @@ from test.core.derivatives.implementation.base import DerivativesImplementation from typing import List -import torch -from torch import Tensor, cat, stack, zeros_like +from torch import Tensor, allclose, backends, cat, stack, zeros, zeros_like from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -33,7 +32,7 @@ def jac_mat_prod(self, mat): # noqa: D102 # A RuntimeError is thrown for RNNs on CUDA, # because PyTorch does not support double-backwards pass for them. # This is the recommended workaround. - with torch.backends.cudnn.flags(enabled=False): + with backends.cudnn.flags(enabled=False): return stack([self.jac_vec_prod(vec) for vec in mat]) def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor: # noqa: D102 @@ -205,7 +204,7 @@ def _sample_jac_t_mat_prod(sample_idx, mat): input_requires_grad=True, sample_idx=sample_idx ) - result = torch.zeros(sample.numel(), mat.size(1), device=sample.device) + result = zeros(sample.numel(), mat.size(1), device=sample.device) for col in range(mat.size(1)): column = mat[:, col].reshape(output.shape) @@ -225,7 +224,7 @@ def _sample_jac_t_mat_prod(sample_idx, mat): N = self.problem.input.shape[0] input_features = self.problem.input.shape.numel() // N - result = torch.zeros(input_features, input_features).to(self.problem.device) + result = zeros(input_features, input_features).to(self.problem.device) for n in range(N): result += _sample_jac_t_mat_jac_prod(n, mat) @@ -252,11 +251,11 @@ def _hessian(self, loss: Tensor, x: Tensor) -> Tensor: vectorized_shape = (x.numel(), x.numel()) final_shape = (*x.shape, *x.shape) - hessian_vec_x = torch.zeros(vectorized_shape).to(loss.device) + hessian_vec_x = zeros(vectorized_shape).to(loss.device) num_cols = hessian_vec_x.shape[1] for column_idx in range(num_cols): - unit = torch.zeros(num_cols).to(loss.device) + unit = zeros(num_cols).to(loss.device) unit[column_idx] = 1.0 unit = unit.view_as(x) @@ -294,7 +293,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor: try: yield self._hessian(t, x) except (RuntimeError, AttributeError): - yield torch.zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype) + yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype) def hessian_is_zero(self) -> bool: # noqa: D102 input, output, _ = self.problem.forward_pass(input_requires_grad=True) @@ -304,19 +303,66 @@ def hessian_is_zero(self) -> bool: # noqa: D102 if zero is None: zero = zeros_like(hessian) - if not torch.allclose(hessian, zero): + if not allclose(hessian, zero): return False return True - def input_hessian(self) -> Tensor: + def input_hessian(self, subsampling: List[int] = None) -> Tensor: """Compute the Hessian of the module output w.r.t. the input. + Args: + subsampling: Indices of active samples. ``None`` uses all samples. + Returns: hessian """ input, output, _ = self.problem.forward_pass(input_requires_grad=True) - return self._hessian(output, input) + hessian = self._hessian(output, input) + return self._subsample_input_hessian(hessian, input, subsampling=subsampling) + + @staticmethod + def _subsample_input_hessian( + hessian: Tensor, input: Tensor, subsampling: List[int] = None + ) -> Tensor: + """Slice sub-samples out of Hessian w.r.t the full input. + + If ``subsampling`` is set to ``None``, leaves the Hessian unchanged. + + Args: + hessian: The Hessian w.r.t. the module input. + input: Module input. + subsampling: List of active samples. Default of ``None`` uses all samples. + + Returns: + Sub-sampled Hessian of shape ``[N, *, N, *]`` where ``N`` denotes the + number of sub-samples, and ``*`` is the input feature shape. + """ + N, D_shape = input.shape[0], input.shape[1:] + D = input.numel() // N + + subsampled_hessian = hessian.reshape(N, D, N, D)[subsampling, :, :, :][ + :, :, subsampling, : + ] + + has_duplicates = subsampling is not None and len(set(subsampling)) != len( + subsampling + ) + if has_duplicates: + # For duplicates in `subsampling`, the above slicing is not sufficient. + # and off-diagonal blocks need to be zeroed. E.g. if subsampling is [0, 0] + # then the sliced input Hessian has non-zero off-diagonal blocks (1, 0) and + # (0, 1), which should be zero as the samples are considered independent. + for idx1, sample1 in enumerate(subsampling[:-1]): + for idx2, sample2 in enumerate(subsampling[idx1 + 1 :], start=idx1 + 1): + if sample1 == sample2: + subsampled_hessian[idx1, :, idx2, :] = 0 + subsampled_hessian[idx2, :, idx1, :] = 0 + + N_active = N if subsampling is None else len(subsampling) + out_shape = [N_active, *D_shape, N_active, *D_shape] + + return subsampled_hessian.reshape(out_shape) def sum_hessian(self) -> Tensor: """Compute the Hessian of a loss module w.r.t. its input. @@ -351,9 +397,9 @@ def _sum_hessian_blocks(self, hessian: Tensor) -> Tensor: N = input.shape[0] num_features = input.numel() // N - sum_hessian = torch.zeros(num_features, num_features, device=input.device) + sum_hessian = zeros(num_features, num_features, device=input.device) - hessian_different_samples = torch.zeros( + hessian_different_samples = zeros( num_features, num_features, device=input.device ) for n_1 in range(N): @@ -364,6 +410,6 @@ def _sum_hessian_blocks(self, hessian: Tensor) -> Tensor: sum_hessian += block else: - assert torch.allclose(block, hessian_different_samples) + assert allclose(block, hessian_different_samples) return sum_hessian diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 5daeb01e9..c40a2d230 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -1,9 +1,11 @@ """Contains derivative calculation with BackPACK.""" from test.core.derivatives.implementation.base import DerivativesImplementation +from test.utils import chunk_sizes from typing import List -import torch -from torch import Tensor +from torch import Tensor, einsum, zeros + +from backpack.utils.subsampling import subsample class BackpackDerivatives(DerivativesImplementation): @@ -131,32 +133,48 @@ def sum_hessian(self): # noqa: D102 self.store_forward_io() return self.problem.derivative.sum_hessian(self.problem.module, None, None) - def input_hessian_via_sqrt_hessian(self, mc_samples=None) -> Tensor: - """Computes the input hessian. + def input_hessian_via_sqrt_hessian( + self, mc_samples: int = None, chunks: int = 1, subsampling: List[int] = None + ) -> Tensor: + """Computes the Hessian w.r.t. to the input from its matrix square root. Args: mc_samples: If int, uses an MC approximation with the specified number of samples. If None, uses the exact hessian. Defaults to None. + chunks: Maximum sequential split of the computation. Default: ``1``. + Only used if mc_samples is specified. + subsampling: Indices of active samples. ``None`` uses all samples. Returns: - hessian + Hessian with respect to the input. """ self.store_forward_io() if mc_samples is not None: - sqrt_hessian = self.problem.derivative.sqrt_hessian_sampled( - self.problem.module, None, None, mc_samples=mc_samples + chunk_samples = chunk_sizes(mc_samples, chunks) + chunk_weights = [samples / mc_samples for samples in chunk_samples] + + individual_hessians: Tensor = sum( + weight + * self._sample_hessians_from_sqrt( + self.problem.derivative.sqrt_hessian_sampled( + self.problem.module, + None, + None, + mc_samples=samples, + subsampling=subsampling, + ) + ) + for weight, samples in zip(chunk_weights, chunk_samples) ) else: sqrt_hessian = self.problem.derivative.sqrt_hessian( - self.problem.module, None, None + self.problem.module, None, None, subsampling=subsampling ) + individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian) - individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian) - - return self._embed_sample_hessians( - individual_hessians, self.problem.module.input0 - ) + input0 = subsample(self.problem.module.input0, subsampling=subsampling) + return self._embed_sample_hessians(individual_hessians, input0) def hessian_is_zero(self) -> bool: # noqa: D102 return self.problem.derivative.hessian_is_zero(self.problem.module) @@ -171,30 +189,35 @@ def _sample_hessians_from_sqrt(self, sqrt): individual full matrix Raises: - ValueError: if input is not 3d + ValueError: if input is not 2d """ - equation = None - num_axes = len(sqrt.shape) - # TODO improve readability - if num_axes == 3: - equation = "vni,vnj->nij" + if sqrt.dim() == 3: + return einsum("vni,vnj->nij", sqrt, sqrt) else: raise ValueError("Only 2D inputs are currently supported.") - return torch.einsum(equation, sqrt, sqrt) + def _embed_sample_hessians( + self, individual_hessians: Tensor, input: Tensor + ) -> Tensor: + """Embed Hessians w.r.t. individual samples into Hessian w.r.t. all samples. - def _embed_sample_hessians(self, individual_hessians, input): - hessian_shape = (*input.shape, *input.shape) - hessian = torch.zeros(hessian_shape, device=input.device) + Args: + individual_hessians: Hessians w.r.t. individual samples in the input. + input: Inputs for the individual Hessians. - N = input.shape[0] + Returns: + Hessian that contains the individual Hessians as diagonal blocks. - for n in range(N): - num_axes = len(input.shape) + Raises: + ValueError: if input is not 2d + """ + hessian_shape = (*input.shape, *input.shape) + hessian = zeros(hessian_shape, device=input.device, dtype=input.dtype) - if num_axes == 2: - hessian[n, :, n, :] = individual_hessians[n] + for idx in range(input.shape[0]): + if input.dim() == 2: + hessian[idx, :, idx, :] = individual_hessians[idx] else: raise ValueError("Only 2D inputs are currently supported.") diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index ff475913d..2784bfc03 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -6,6 +6,7 @@ SumGradSquaredHook, ) from test.extensions.problem import ExtensionsTestProblem +from test.utils import chunk_sizes from typing import List from torch import Tensor, cat, einsum @@ -108,7 +109,7 @@ def diag_ggn_mc_chunk(self, mc_samples: int, chunks: int = 10) -> List[Tensor]: Returns: Parameter-wise MC-approximation of the GGN diagonal. """ - chunk_samples = self.chunk_sizes(mc_samples, chunks) + chunk_samples = chunk_sizes(mc_samples, chunks) chunk_weights = [samples / mc_samples for samples in chunk_samples] diag_ggn_mc = None @@ -137,7 +138,7 @@ def diag_ggn_mc_batch_chunk( Returns: Parameter-wise MC-approximation of the per-sample GGN diagonals. """ - chunk_samples = self.chunk_sizes(mc_samples, chunks) + chunk_samples = chunk_sizes(mc_samples, chunks) chunk_weights = [samples / mc_samples for samples in chunk_samples] diag_ggn_mc_batch = None @@ -156,30 +157,6 @@ def diag_ggn_mc_batch_chunk( return diag_ggn_mc_batch - @staticmethod - def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: - """Return list containing the sizes of chunks. - - Args: - total_size: Total computation work. - num_chunks: Maximum number of chunks the work will be split into. - - Returns: - List of chunks with split work. - """ - chunk_size = max(total_size // num_chunks, 1) - - if chunk_size == 1: - sizes = total_size * [chunk_size] - else: - equal, rest = divmod(total_size, chunk_size) - sizes = equal * [chunk_size] - - if rest != 0: - sizes.append(rest) - - return sizes - def diag_h(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagHessian()): _, _, loss = self.problem.forward_pass() @@ -239,7 +216,7 @@ def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: return self.problem.collect_data("sqrt_ggn_mc") def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: # noqa:D102 - samples = self.chunk_sizes(mc_samples, chunks) + samples = chunk_sizes(mc_samples, chunks) weights = [samples / mc_samples for samples in samples] return sum( diff --git a/test/utils/__init__.py b/test/utils/__init__.py index e69de29bb..40711349c 100644 --- a/test/utils/__init__.py +++ b/test/utils/__init__.py @@ -0,0 +1,27 @@ +"""Helper functions for tests.""" + +from typing import List + + +def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: + """Return list containing the sizes of chunks. + + Args: + total_size: Total computation work. + num_chunks: Maximum number of chunks the work will be split into. + + Returns: + List of chunks with split work. + """ + chunk_size = max(total_size // num_chunks, 1) + + if chunk_size == 1: + sizes = total_size * [chunk_size] + else: + equal, rest = divmod(total_size, chunk_size) + sizes = equal * [chunk_size] + + if rest != 0: + sizes.append(rest) + + return sizes From 75c73c2809b63c347f5172d04dfc262aa2db6cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 29 Jul 2021 21:04:06 +0200 Subject: [PATCH 37/54] [REF] Shared matrix parameter-Jacobian products (#203) Share `{weight, bias, ...}_jac_t_mat_prod` through `param_mjp`. Reduces duplication introduced by the additional parameters `weight_ih_l0, bias_ih_l0` etc. of recurrent nets. Progress on https://github.com/fKunstner/backpack-discuss/issues/105 --- * [core] Design prototype for transpose JVPs, enable for bias * param_mjp: shape check: accept vectors * include subsampling * param_mjp: shape check: accept vectors * delete {param}_jac_t_mat_prod methods * delete private methods * introduce single test test_param_jac_t_mat_prod * seed * seed * include RNN and LSTM in test_param_jac_t_mat_prod * include save_memory in test_param_jac_t_mat_prod * include save_memory in test_param_jac_t_mat_prod * refactor test derivatives implementation * move shape check inside method * docstring * format * rename methods * format * format * use get_batch_axis * format * simplify: skip subsampling * [REF] Remove unnecessary device transfer * [REF] Inline call to `get_batch_axis` * [DEL] Remove `skip_no_param` condition * [REF] tidy up test_param_mjp * [REF] tidy up test_param_mjp * [DEL] Remove explicit delete statements * [TEST] Add LSTM with small input to test hessian_is_zero * [REF] More explicit imports Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 334 ++++-------------- backpack/core/derivatives/lstm.py | 3 + backpack/core/derivatives/rnn.py | 3 + backpack/core/derivatives/shape_check.py | 111 +++--- .../curvmatprod/ggnmp/batchnorm1d.py | 6 +- .../extensions/curvmatprod/ggnmp/conv2d.py | 6 +- .../extensions/curvmatprod/ggnmp/linear.py | 6 +- .../extensions/curvmatprod/hmp/batchnorm1d.py | 6 +- backpack/extensions/curvmatprod/hmp/conv2d.py | 6 +- backpack/extensions/curvmatprod/hmp/linear.py | 6 +- .../extensions/curvmatprod/pchmp/conv2d.py | 6 +- .../extensions/curvmatprod/pchmp/linear.py | 6 +- .../firstorder/batch_grad/batch_grad_base.py | 3 +- .../firstorder/batch_l2_grad/batch_l2_base.py | 4 +- .../extensions/firstorder/gradient/base.py | 8 +- .../firstorder/sum_grad_squared/sgs_base.py | 4 +- .../secondorder/diag_ggn/diag_ggn_base.py | 6 +- .../extensions/secondorder/sqrt_ggn/base.py | 4 +- test/benchmark/jvp.py | 4 +- test/core/derivatives/batch_norm_settings.py | 5 +- test/core/derivatives/derivatives_test.py | 333 +++-------------- .../derivatives/implementation/autograd.py | 81 +---- .../derivatives/implementation/backpack.py | 70 +--- test/core/derivatives/implementation/base.py | 91 +---- test/core/derivatives/lstm_settings.py | 11 +- test/utils/skip_test.py | 12 - 26 files changed, 253 insertions(+), 882 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 6f4173f1e..dd1205d45 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -322,66 +322,10 @@ class BaseParameterDerivatives(BaseDerivatives, ABC): For most layers, these shapes correspond to shapes of the module input or output. """ - @shape_check.bias_jac_mat_prod_accept_vectors - @shape_check.bias_jac_mat_prod_check_shapes - def bias_jac_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor - ) -> Tensor: - """Apply Jacobian of the output w.r.t. bias to a matrix. - - Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the Jacobian will be applied to. - Must have shape [V, C_b, ...]. - - Returns: - Jacobian-matrix product. Has shape [V, N, C_out, H_out, ...]. - """ - return self._bias_jac_mat_prod(module, g_inp, g_out, mat) - - def _bias_jac_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor - ) -> Tensor: - raise NotImplementedError - - @shape_check.bias_jac_t_mat_prod_accept_vectors - @shape_check.bias_jac_t_mat_prod_check_shapes - def bias_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. bias to a matrix. - - Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Has shape ``[V, *module.output.shape]``; but if used with - sub-sampling, the batch dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. - subsampling: Indices of samples along the output's batch dimension that - should be considered. Defaults to ``None`` (use all samples). - - Returns: - Jacobian-matrix product. - If ``sum_batch=False``, has shape ``[V, N, *module.bias.shape]``. - If ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. - If sub-sampling is used, ``N`` is replaced by ``len(subsampling)``. - """ - return self._bias_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling - ) - - def _bias_jac_t_mat_prod( + @shape_check.param_mjp_accept_vectors + def param_mjp( self, + param_str: str, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], @@ -389,259 +333,105 @@ def _bias_jac_t_mat_prod( sum_batch: bool = True, subsampling: List[int] = None, ) -> Tensor: - raise NotImplementedError - - @shape_check.weight_jac_mat_prod_accept_vectors - @shape_check.weight_jac_mat_prod_check_shapes - def weight_jac_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor - ) -> Tensor: - """Apply Jacobian of the output w.r.t. weight to a matrix. + """Compute matrix-Jacobian products (MJPs) of the module w.r.t. a parameter. - Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the Jacobian will be applied to. - Must have shape [V, C_w, H_w, ...]. - - Returns: - Jacobian-matrix product. - Has shape [V, N, C_out, H_out, ...]. - """ - return self._weight_jac_mat_prod(module, g_inp, g_out, mat) - - def _weight_jac_mat_prod( - self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor - ) -> Tensor: - raise NotImplementedError + Handles both vector and matrix inputs. Preserves input format in output. - @shape_check.weight_jac_t_mat_prod_accept_vectors - @shape_check.weight_jac_t_mat_prod_check_shapes - def weight_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. weight to a matrix. + Internally calls out to ``_{param_str}_jac_t_mat_prod`` function that must be + implemented by descendants. It follows the same signature, but does not have + the ``param_str`` argument. Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Has shape ``[V, *module.output.shape]``; but if used with - sub-sampling, the batch dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. + param_str: Attribute name under which the parameter is stored in the module. + module: Module whose Jacobian will be applied. Must provide access to IO. + g_inp: Gradients w.r.t. module input. + g_out: Gradients w.r.t. module output. + mat: Matrix the Jacobian will be applied to. Has shape + ``[V, *module.output.shape]`` (matrix case) or same shape as + ``module.output`` (vector case). If used with subsampling, has dimension + len(subsampling) instead of batch size along the batch axis. + sum_batch: Sum out the MJP's batch axis. Default: ``True``. subsampling: Indices of samples along the output's batch dimension that should be considered. Defaults to ``None`` (use all samples). Returns: - Jacobian-matrix product. - If ``sum_batch=False``, has shape ``[V, N, *module.weight.shape]``. - If ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. - If sub-sampling is used, ``N`` is replaced by ``len(subsampling)``. - """ - return self._weight_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling - ) - - def _weight_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - raise NotImplementedError - - @shape_check.bias_jac_t_mat_prod_accept_vectors - @shape_check.bias_rnn_jac_t_mat_prod_check_shapes - def bias_ih_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix. - - Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]; but if used with sub-sampling, the batch - dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. - subsampling: Indices of samples along the output's batch dimension that - should be considered. Defaults to ``None`` (use all samples). + Matrix-Jacobian products. Has shape ``[V, *param_shape]`` when batch + summation is enabled (same shape as parameter in the vector case). Without + batch summation, the result has shape ``[V, N, *param_shape]`` (vector case + has shape ``[N, *param_shape]``). If used with subsampling, the batch size N + is replaced by len(subsampling). - Returns: - Jacobian-matrix product. - Has shape [V, N, *module.bias_ih_l0.shape] if ``sum_batch == False``; but if - used with sub-sampling, the batch dimension is replaced by - ``len(subsampling)``. Has shape [V, *module.bias_ih_l0.shape] if - ``sum_batch == True``. + Raises: + NotImplementedError: if required method is not implemented by derivatives class """ - return self._bias_ih_l0_jac_t_mat_prod( + # input check + shape_check.shape_like_output(mat, module, subsampling=subsampling) + + method_name = f"_{param_str}_jac_t_mat_prod" + mjp = getattr(self, method_name, None) + if mjp is None: + raise NotImplementedError( + f"Computation requires implementation of {method_name}, but {self} " + f"(defining derivatives of {module}) does not implement it." + ) + mjp_out = mjp( module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) - def _bias_ih_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - raise NotImplementedError - - @shape_check.bias_jac_t_mat_prod_accept_vectors - @shape_check.bias_rnn_jac_t_mat_prod_check_shapes - def bias_hh_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix. - - Args: - module: module to perform derivatives on - g_inp: input gradients - g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]; but if used with sub-sampling, the batch - dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. - subsampling: Indices of samples along the output's batch dimension that - should be considered. Defaults to ``None`` (use all samples). - - Returns: - Jacobian-matrix product. - Has shape [V, N, *module.bias_hh_l0.shape] if ``sum_batch == False``; but if - used with sub-sampling, the batch dimension is replaced by - ``len(subsampling)``. Has shape [V, *module.bias_hh_l0.shape] if - ``sum_batch == True``. - """ - return self._bias_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling + # output check + shape_check.check_like_with_sum_batch( + mjp_out, module, param_str, sum_batch=sum_batch ) + shape_check.check_same_V_dim(mjp_out, mat) - def _bias_hh_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, - ) -> Tensor: - raise NotImplementedError + return mjp_out - @shape_check.weight_jac_t_mat_prod_accept_vectors - @shape_check.weight_ih_jac_t_mat_prod_check_shapes - def weight_ih_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, + @shape_check.bias_jac_mat_prod_accept_vectors + @shape_check.bias_jac_mat_prod_check_shapes + def bias_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix. + """Apply Jacobian of the output w.r.t. bias to a matrix. Args: module: module to perform derivatives on g_inp: input gradients g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]; but if used with sub-sampling, the batch - dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. - subsampling: Indices of samples along the output's batch dimension that - should be considered. Defaults to ``None`` (use all samples). + mat: Matrix the Jacobian will be applied to. + Must have shape [V, C_b, ...]. Returns: - Jacobian-matrix product. - Has shape [V, N, *module.weight_ih_l0.shape] if ``sum_batch == False``; but - if used with sub-sampling, the batch dimension is replaced by - ``len(subsampling)``. Has shape [V, *module.weight_ih_l0.shape] if - ``sum_batch == True``. + Jacobian-matrix product. Has shape [V, N, C_out, H_out, ...]. """ - return self._weight_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling - ) + return self._bias_jac_mat_prod(module, g_inp, g_out, mat) - def _weight_ih_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, + def _bias_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: raise NotImplementedError - @shape_check.weight_jac_t_mat_prod_accept_vectors - @shape_check.weight_hh_jac_t_mat_prod_check_shapes - def weight_hh_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, + @shape_check.weight_jac_mat_prod_accept_vectors + @shape_check.weight_jac_mat_prod_check_shapes + def weight_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: - """Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix. + """Apply Jacobian of the output w.r.t. weight to a matrix. Args: module: module to perform derivatives on g_inp: input gradients g_out: output gradients - mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]; but if used with sub-sampling, the batch - dimension is replaced by ``len(subsampling)``. - sum_batch: Whether to sum over the batch dimension on the fly. - subsampling: Indices of samples along the output's batch dimension that - should be considered. Defaults to ``None`` (use all samples). + mat: Matrix the Jacobian will be applied to. + Must have shape [V, C_w, H_w, ...]. Returns: Jacobian-matrix product. - Has shape [V, N, *module.weight_hh_l0.shape] if ``sum_batch == False``; but - if used with sub-sampling, the batch dimension is replaced by - ``len(subsampling)``. Has shape [V, *module.weight_hh_l0.shape] if - ``sum_batch == True``. + Has shape [V, N, C_out, H_out, ...]. """ - return self._weight_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling - ) + return self._weight_jac_mat_prod(module, g_inp, g_out, mat) - def _weight_hh_l0_jac_t_mat_prod( - self, - module: Module, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mat: Tensor, - sum_batch: bool = True, - subsampling: List[int] = None, + def _weight_jac_mat_prod( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index de774c5fa..1d31638e7 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -179,6 +179,9 @@ def _ifgo_jac_t_mat_prod( ) return IFGO_prod + def hessian_is_zero(self, module: LSTM) -> bool: # noqa: D102 + return False + def _jac_mat_prod( self, module: LSTM, diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 59867da3a..422abd93c 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -47,6 +47,9 @@ def _check_parameters(module: RNN) -> None: if module.bidirectional is not False: raise NotImplementedError("only bidirectional = False is supported") + def hessian_is_zero(self, module: RNN) -> bool: # noqa: D102 + return False + @staticmethod def _a_jac_t_mat_prod( output: Tensor, diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index e9d9ee056..fcf84634b 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -54,7 +54,16 @@ def check_shape(mat: Tensor, like: Tensor, diff: int = 1) -> None: ) -def _check_same_V_dim(mat1, mat2): +def check_same_V_dim(mat1, mat2): + """Check whether V dim (first dim) matches. + + Args: + mat1: first tensor + mat2: second tensor + + Raises: + RuntimeError: if V dim (first dim) doesn't match + """ V1, V2 = mat1.shape[0], mat2.shape[0] if V1 != V2: raise RuntimeError("Number of vectors changed. Got {} and {}".format(V1, V2)) @@ -73,9 +82,19 @@ def _check_like(mat, module, name, diff=1, *args, **kwargs): return check_shape(mat, compare, diff=diff) -def _check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): +def check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): + """Checks shape, considers sum_batch. + + Args: + mat: matrix to multiply + module: module + name: parameter to operate on: module.name + sum_batch: whether to consider with or without sum + *args: ignored + **kwargs: ignored + """ diff = 1 if sum_batch else 2 - return check_shape(mat, getattr(module, name), diff=diff) + check_shape(mat, getattr(module, name), diff=diff) def _same_dim_as(mat, module, name, *args, **kwargs): @@ -133,15 +152,6 @@ def _wrapped_mat_prod_accept_vectors( vec_criterion=same_dim_as_output, ) -weight_jac_t_mat_prod_accept_vectors = functools.partial( - _mat_prod_accept_vectors, - vec_criterion=same_dim_as_output, -) -bias_jac_t_mat_prod_accept_vectors = functools.partial( - _mat_prod_accept_vectors, - vec_criterion=same_dim_as_output, -) - jac_mat_prod_accept_vectors = functools.partial( _mat_prod_accept_vectors, vec_criterion=same_dim_as_input, @@ -181,7 +191,7 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar in_check(mat, module, *args, **kwargs) mat_out = mat_prod(self, module, g_inp, g_out, mat, *args, **kwargs) out_check(mat_out, module, *args, **kwargs) - _check_same_V_dim(mat_out, mat) + check_same_V_dim(mat_out, mat) return mat_out @@ -193,21 +203,6 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar shape_like_input = functools.partial(_check_like, name="input0") shape_like_weight = functools.partial(_check_like, name="weight") shape_like_bias = functools.partial(_check_like, name="bias") -shape_like_weight_with_sum_batch = functools.partial( - _check_like_with_sum_batch, name="weight" -) -shape_like_bias_with_sum_batch = functools.partial( - _check_like_with_sum_batch, name="bias" -) -shape_like_bias_rnn_with_sum_batch = functools.partial( - _check_like_with_sum_batch, name="bias_ih_l0" -) -shape_like_weight_ih_with_sum_batch = functools.partial( - _check_like_with_sum_batch, name="weight_ih_l0" -) -shape_like_weight_hh_with_sum_batch = functools.partial( - _check_like_with_sum_batch, name="weight_hh_l0" -) # decorators for shape checking jac_mat_prod_check_shapes = functools.partial( @@ -226,33 +221,6 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar mat_prod_check_shapes, in_check=shape_like_output, out_check=shape_like_input ) - -weight_jac_t_mat_prod_check_shapes = functools.partial( - mat_prod_check_shapes, - in_check=shape_like_output, - out_check=shape_like_weight_with_sum_batch, -) -bias_jac_t_mat_prod_check_shapes = functools.partial( - mat_prod_check_shapes, - in_check=shape_like_output, - out_check=shape_like_bias_with_sum_batch, -) -bias_rnn_jac_t_mat_prod_check_shapes = functools.partial( - mat_prod_check_shapes, - in_check=shape_like_output, - out_check=shape_like_bias_rnn_with_sum_batch, -) -weight_ih_jac_t_mat_prod_check_shapes = functools.partial( - mat_prod_check_shapes, - in_check=shape_like_output, - out_check=shape_like_weight_ih_with_sum_batch, -) -weight_hh_jac_t_mat_prod_check_shapes = functools.partial( - mat_prod_check_shapes, - in_check=shape_like_output, - out_check=shape_like_weight_hh_with_sum_batch, -) - ############################################################################### # Wrapper for second-order extensions # ############################################################################### @@ -327,3 +295,36 @@ def _new_hessian_mat_prod(mat): return _new_hessian_mat_prod return _wrapped_make_hessian_mat_prod + + +def param_mjp_accept_vectors(mat_prod: Callable[..., Tensor]) -> Callable[..., Tensor]: + """Add support for vectors to matrix products. + + vec_criterion(mat, module) returns if mat is a vector. + + Args: + mat_prod: Function that processes multiple vectors in format of a matrix. + + Returns: + Wrapped ``mat_prod`` function that processes multiple vectors in format of + a matrix, and supports vector-shaped inputs which are internally converted + to the correct format. + Preserves format of input: + If the input format is a vector, the output format is a vector. + If the input format is a matrix, the output format is a matrix. + """ + + @functools.wraps(mat_prod) + def _wrapped_mat_prod_accept_vectors( + self, param_str, module, g_inp, g_out, mat, *args, **kwargs + ): + is_vec = same_dim_as_output(mat, module) + mat_in = mat if not is_vec else _add_V_dim(mat) + mat_out = mat_prod( + self, param_str, module, g_inp, g_out, mat_in, *args, **kwargs + ) + mat_out = mat_out if not is_vec else _remove_V_dim(mat_out) + + return mat_out + + return _wrapped_mat_prod_accept_vectors diff --git a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py index aeb6f3877..831117e99 100644 --- a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py +++ b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py @@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_ggnmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_ggnmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/ggnmp/conv2d.py b/backpack/extensions/curvmatprod/ggnmp/conv2d.py index afd9785b6..825d88038 100644 --- a/backpack/extensions/curvmatprod/ggnmp/conv2d.py +++ b/backpack/extensions/curvmatprod/ggnmp/conv2d.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_ggnmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_ggnmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/ggnmp/linear.py b/backpack/extensions/curvmatprod/ggnmp/linear.py index 21d92b685..dfb35d3d8 100644 --- a/backpack/extensions/curvmatprod/ggnmp/linear.py +++ b/backpack/extensions/curvmatprod/ggnmp/linear.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_ggnmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_ggnmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py index 388180ab1..aa9f70f3d 100644 --- a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py +++ b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py @@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_hmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_hmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/hmp/conv2d.py b/backpack/extensions/curvmatprod/hmp/conv2d.py index 69c74cfa8..7430c042a 100644 --- a/backpack/extensions/curvmatprod/hmp/conv2d.py +++ b/backpack/extensions/curvmatprod/hmp/conv2d.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_hmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_hmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/hmp/linear.py b/backpack/extensions/curvmatprod/hmp/linear.py index 11428f18f..7917dfa1a 100644 --- a/backpack/extensions/curvmatprod/hmp/linear.py +++ b/backpack/extensions/curvmatprod/hmp/linear.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_hmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_hmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/pchmp/conv2d.py b/backpack/extensions/curvmatprod/pchmp/conv2d.py index 620fb2a6d..213601e6c 100644 --- a/backpack/extensions/curvmatprod/pchmp/conv2d.py +++ b/backpack/extensions/curvmatprod/pchmp/conv2d.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_pchmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_pchmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/curvmatprod/pchmp/linear.py b/backpack/extensions/curvmatprod/pchmp/linear.py index 37dbc49d7..d38539622 100644 --- a/backpack/extensions/curvmatprod/pchmp/linear.py +++ b/backpack/extensions/curvmatprod/pchmp/linear.py @@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped): def weight_pchmp(mat): result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, result - ) + result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) return result @@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): def bias_pchmp(mat): result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) result = h_out_mat_prod(result) - result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result) + result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) return result diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index 2ba72cce4..c0ee49729 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -79,7 +79,8 @@ def param_function( Scaled individual gradients """ subsampling = ext.get_subsampling() - return getattr(self._derivatives, f"{param_str}_jac_t_mat_prod")( + return self._derivatives.param_mjp( + param_str, module, g_inp, g_out, diff --git a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py index 6604f811e..f7b4f79dd 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py +++ b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py @@ -66,8 +66,8 @@ def param_function( """ param_dims: List[int] = list(range(1, 1 + getattr(module, param_str).dim())) return ( - getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=False + self.derivatives.param_mjp( + param_str, module, g_inp, g_out, g_out[0], sum_batch=False ) ** 2 ).sum(param_dims) diff --git a/backpack/extensions/firstorder/gradient/base.py b/backpack/extensions/firstorder/gradient/base.py index eb45c8da7..b9f198855 100644 --- a/backpack/extensions/firstorder/gradient/base.py +++ b/backpack/extensions/firstorder/gradient/base.py @@ -35,11 +35,11 @@ def __init__(self, derivatives, params): setattr(self, param_str, self._make_param_function(param_str)) super().__init__(params=params) - def _make_param_function(self, param): + def _make_param_function(self, param_str): """Creates a function that calculates gradient wrt param. Args: - param(str): name of parameter + param_str: name of parameter Returns: function: function that calculates gradient wrt param @@ -58,8 +58,8 @@ def param_function(ext, module, g_inp, g_out, bpQuantities): Returns: torch.Tensor: gradient of the batch, similar to autograd """ - return getattr(self.derivatives, f"{param}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=True + return self.derivatives.param_mjp( + param_str, module, g_inp, g_out, g_out[0], sum_batch=True ) return param_function diff --git a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py index 09f3d80f1..3e1d171ab 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py +++ b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py @@ -64,8 +64,8 @@ def param_function( sum_grad_squared """ return ( - getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( - module, g_inp, g_out, g_out[0], sum_batch=False + self.derivatives.param_mjp( + param_str, module, g_inp, g_out, g_out[0], sum_batch=False ) ** 2 ).sum(self.N_axis) diff --git a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py index 4b445a63b..203b8ebd6 100644 --- a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py +++ b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py @@ -42,7 +42,7 @@ def __init__( super().__init__(derivatives, params=params) def _make_param_method( - self, param: str, sum_batch: bool + self, param_str: str, sum_batch: bool ) -> Callable[ [ModuleExtension, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor ]: @@ -67,8 +67,8 @@ def _param( """ axis: Tuple[int] = (0, 1) if sum_batch else (0,) return ( - getattr(self.derivatives, f"{param}_jac_t_mat_prod")( - module, grad_inp, grad_out, backproped, sum_batch=False + self.derivatives.param_mjp( + param_str, module, grad_inp, grad_out, backproped, sum_batch=False ) ** 2 ).sum(axis=axis) diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py index d5e20e721..f625e38b1 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/base.py +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -67,8 +67,8 @@ def param_function( Returns: GGN/Fisher matrix square root. """ - return getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")( - module, g_inp, g_out, backproped, sum_batch=False + return self.derivatives.param_mjp( + param_str, module, g_inp, g_out, backproped, sum_batch=False ) return param_function diff --git a/test/benchmark/jvp.py b/test/benchmark/jvp.py index e7115d152..1a1a1464b 100644 --- a/test/benchmark/jvp.py +++ b/test/benchmark/jvp.py @@ -130,7 +130,7 @@ def bp_jtv_weight_func(module, vin): def f(): r = ( derivatives_for[module.__class__]() - .weight_jac_t_mat_prod(module, None, None, vin) + .param_mjp("weight", module, None, None, vin) .contiguous() ) if vin.is_cuda: @@ -160,7 +160,7 @@ def bp_jtv_bias_func(module, vin): def f(): r = ( derivatives_for[module.__class__]() - .bias_jac_t_mat_prod(module, None, None, vin.unsqueeze(2)) + .param_mjp("bias", module, None, None, vin.unsqueeze(2)) .contiguous() ) if vin.is_cuda: diff --git a/test/core/derivatives/batch_norm_settings.py b/test/core/derivatives/batch_norm_settings.py index 6c8f0deeb..7994e1716 100644 --- a/test/core/derivatives/batch_norm_settings.py +++ b/test/core/derivatives/batch_norm_settings.py @@ -31,9 +31,8 @@ "input_fn": lambda: rand(size=(5, 7, 3, 4)), }, { - "module_fn": lambda: BatchNorm3d(num_features=7), - "input_fn": lambda: rand(size=(5, 7, 3, 4, 2)), - "seed": 1, + "module_fn": lambda: BatchNorm3d(num_features=3), + "input_fn": lambda: rand(size=(5, 3, 3, 4, 2)), }, { "module_fn": lambda: initialize_batch_norm_eval(BatchNorm1d(num_features=7)), diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index a4366dedd..c4e9942fa 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -6,7 +6,7 @@ - Jacobian-matrix products with respect to layer parameters - Transposed Jacobian-matrix products with respect to layer parameters """ - +from contextlib import nullcontext from test.automated_test import check_sizes_and_values from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS from test.core.derivatives.implementation.autograd import AutogradDerivatives @@ -20,11 +20,10 @@ from test.utils.skip_test import ( skip_adaptive_avg_pool3d_cuda, skip_batch_norm_train_mode_with_subsampling, - skip_no_param, skip_permute_with_subsampling, skip_subsampling_conflict, ) -from typing import List, Tuple, Union +from typing import List, Union from warnings import warn from pytest import fixture, mark, raises, skip @@ -47,11 +46,9 @@ LOSS_FAIL_IDS = [problem.make_id() for problem in LOSS_FAIL_PROBLEMS] RNN_PROBLEMS = make_test_problems(RNN_SETTINGS) +RNN_PROBLEMS += make_test_problems(LSTM_SETTINGS) RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS] -LSTM_PROBLEMS = make_test_problems(LSTM_SETTINGS) -LSTM_IDS = [problem.make_id() for problem in LSTM_PROBLEMS] - PERMUTE_PROBLEMS = make_test_problems(PERMUTE_SETTINGS) PERMUTE_IDS = [problem.make_id() for problem in PERMUTE_PROBLEMS] @@ -62,14 +59,50 @@ SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) +def test_param_mjp( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: List[int] or None, + request, +) -> None: + """Test all parameter derivatives. + + Args: + problem: test problem + sum_batch: whether to sum along batch axis + subsampling: subsampling indices + request: problem request + """ + skip_subsampling_conflict(problem, subsampling) + test_save_memory: bool = "Conv" in request.node.callspec.id + V = 3 + + for param_str, _ in problem.module.named_parameters(): + print(f"testing derivative wrt {param_str}") + for save_memory in [True, False] if test_save_memory else [None]: + if test_save_memory: + print(f"testing with save_memory={save_memory}") + + mat = rand_mat_like_output(V, problem, subsampling=subsampling) + with weight_jac_t_save_memory( + save_memory=save_memory + ) if test_save_memory else nullcontext(): + backpack_res = BackpackDerivatives(problem).param_mjp( + param_str, mat, sum_batch, subsampling=subsampling + ) + autograd_res = AutogradDerivatives(problem).param_mjp( + param_str, mat, sum_batch, subsampling=subsampling + ) + + check_sizes_and_values(autograd_res, backpack_res) + + @mark.parametrize( "problem", - NO_LOSS_PROBLEMS - + RNN_PROBLEMS - + PERMUTE_PROBLEMS - + LSTM_PROBLEMS - + BATCH_NORM_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS + BATCH_NORM_IDS, + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + BATCH_NORM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS, ) def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: """Test the Jacobian-matrix product. @@ -91,12 +124,8 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: @mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize( "problem", - NO_LOSS_PROBLEMS - + RNN_PROBLEMS - + PERMUTE_PROBLEMS - + LSTM_PROBLEMS - + BATCH_NORM_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + LSTM_IDS + BATCH_NORM_IDS, + NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + BATCH_NORM_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS, ) def test_jac_t_mat_prod( problem: DerivativesTestProblem, @@ -139,166 +168,6 @@ def test_jac_t_mat_prod( IDS_WITH_WEIGHTS.append(problem_id) -@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) -def test_bias_ih_l0_jac_t_mat_prod( - problem: DerivativesTestProblem, - sum_batch: bool, - subsampling: Union[List[int], None], - V: int = 3, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to bias_ih_l0. - - Args: - problem: Problem for derivative test. - sum_batch: Sum results over the batch dimension. - subsampling: Indices of active samples. - V: Number of vectorized transposed Jacobian-vector products. - """ - problem.set_up() - skip_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling) - - autograd_res = AutogradDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - backpack_res = BackpackDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() - - -@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) -def test_bias_hh_l0_jac_t_mat_prod( - problem: DerivativesTestProblem, - sum_batch: bool, - subsampling: Union[List[int], None], - V: int = 3, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to bias_hh_l0. - - Args: - problem: Problem for derivative test. - sum_batch: Sum results over the batch dimension. - subsampling: Indices of active samples. - V: Number of vectorized transposed Jacobian-vector products. - """ - problem.set_up() - skip_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling) - - autograd_res = AutogradDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - backpack_res = BackpackDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() - - -@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) -def test_weight_ih_l0_jac_t_mat_prod( - problem: DerivativesTestProblem, - sum_batch: bool, - subsampling: Union[List[int], None], - V: int = 3, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to weight_ih_l0. - - Args: - problem: Problem for derivative test. - sum_batch: Sum results over the batch dimension. - subsampling: Indices of active samples. - V: Number of vectorized transposed Jacobian-vector products. - """ - problem.set_up() - skip_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling) - - autograd_res = AutogradDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - backpack_res = BackpackDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() - - -@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -@mark.parametrize("problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS) -def test_weight_hh_l0_jac_t_mat_prod( - problem: DerivativesTestProblem, - sum_batch: bool, - subsampling: Union[List[int], None], - V: int = 3, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to weight_hh_l0. - - Args: - problem: Problem for derivative test. - sum_batch: Sum results over the batch dimension. - subsampling: Indices of active samples. - V: Number of vectorized transposed Jacobian-vector products. - """ - problem.set_up() - skip_subsampling_conflict(problem, subsampling) - mat = rand_mat_like_output(V, problem, subsampling=subsampling) - - autograd_res = AutogradDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - backpack_res = BackpackDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res) - problem.tear_down() - - -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -@mark.parametrize( - "save_memory", - [True, False], - ids=["save_memory=True", "save_memory=False"], -) -def test_weight_jac_t_mat_prod( - problem_weight_jac_t_mat: Tuple[DerivativesTestProblem, List[int], Tensor], - sum_batch: bool, - save_memory: bool, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to the weight. - - Args: - problem_weight_jac_t_mat: Instantiated test case, subsampling, and - input for weight_jac_t - sum_batch: Sum out the batch dimension. - save_memory: Use Owkin implementation in convolutions to save memory. - """ - problem, subsampling, mat = problem_weight_jac_t_mat - - with weight_jac_t_save_memory(save_memory): - backpack_res = BackpackDerivatives(problem).weight_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res, rtol=5e-5) - - def rand_mat_like_output( V: int, problem: DerivativesTestProblem, subsampling: List[int] = None ) -> Tensor: @@ -354,30 +223,6 @@ def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> Non IDS_WITH_BIAS.append(problem_id) -@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]) -def test_bias_jac_t_mat_prod( - problem_bias_jac_t_mat: Tuple[DerivativesTestProblem, List[int], Tensor], - sum_batch: bool, -) -> None: - """Test the transposed Jacobian-matrix product w.r.t. to the bias. - - Args: - problem_bias_jac_t_mat: Instantiated test case, subsampling, and - input for bias_jac_t - sum_batch: Sum out the batch dimension. - """ - problem, subsampling, mat = problem_bias_jac_t_mat - - backpack_res = BackpackDerivatives(problem).bias_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - autograd_res = AutogradDerivatives(problem).bias_jac_t_mat_prod( - mat, sum_batch, subsampling=subsampling - ) - - check_sizes_and_values(autograd_res, backpack_res) - - @mark.parametrize( "problem", PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS, @@ -542,7 +387,9 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None problem.tear_down() -@fixture(params=PROBLEMS + BATCH_NORM_PROBLEMS, ids=lambda p: p.make_id()) +@fixture( + params=PROBLEMS + BATCH_NORM_PROBLEMS + RNN_PROBLEMS, ids=lambda p: p.make_id() +) def problem(request) -> DerivativesTestProblem: """Set seed, create tested layer and data. Finally clean up. @@ -558,86 +405,6 @@ def problem(request) -> DerivativesTestProblem: case.tear_down() -@fixture -def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: - """Filter out cases that don't have a weight parameter. - - Args: - problem: Test case with deterministically constructed attributes. - - Yields: - Instantiated cases that have a weight parameter. - """ - skip_no_param(problem, "weight") - yield problem - - -@fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -def problem_weight_jac_t_mat( - request, problem_weight: DerivativesTestProblem -) -> Tuple[DerivativesTestProblem, Union[None, List[int]], Tensor]: - """Create matrix that will be multiplied by the weight Jacobian. - - Skip if there is a conflict where the subsampling indices exceed the number of - samples in the input. - - Args: - request (SubRequest): Request for the fixture from a test/fixture function. - problem_weight: Test case with weight parameter. - - Yields: - problem with weight, subsampling, matrix for weight_jac_t - """ - subsampling: Union[None, List[int]] = request.param - skip_subsampling_conflict(problem_weight, subsampling) - - V = 3 - mat = rand_mat_like_output(V, problem_weight, subsampling=subsampling) - - yield (problem_weight, subsampling, mat) - del mat - - -@fixture -def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: - """Filter out cases that don't have a bias parameter. - - Args: - problem: Test case with deterministically constructed attributes. - - Yields: - Instantiated cases that have a bias parameter. - """ - skip_no_param(problem, "bias") - yield problem - - -@fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -def problem_bias_jac_t_mat( - request, problem_bias: DerivativesTestProblem -) -> Tuple[DerivativesTestProblem, Union[None, List[int]], Tensor]: - """Create matrix that will be multiplied by the bias Jacobian. - - Skip if there is a conflict where the subsampling indices exceed the number of - samples in the input. - - Args: - request (SubRequest): Request for the fixture from a test/fixture function. - problem_bias: Test case with bias parameter. - - Yields: - problem with bias, subsampling, matrix for bias_jac_t - """ - subsampling: Union[None, List[int]] = request.param - skip_subsampling_conflict(problem_bias, subsampling) - - V = 3 - mat = rand_mat_like_output(V, problem_bias, subsampling=subsampling) - - yield (problem_bias, subsampling, mat) - del mat - - @fixture def small_input_problem( problem: DerivativesTestProblem, max_input_numel: int = 100 diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 6281ab4ad..c52b2831b 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -62,41 +62,27 @@ def jac_t_mat_prod( ) -> Tensor: # noqa: D102 return stack([self.jac_t_vec_prod(vec, subsampling=subsampling) for vec in mat]) - def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - return self.param_jac_t_mat_prod( - "weight", mat, sum_batch, subsampling=subsampling - ) - - def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - return self.param_jac_t_mat_prod( - "bias", mat, sum_batch, subsampling=subsampling - ) - - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - return self.param_jac_t_mat_prod( - "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling - ) - - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - return self.param_jac_t_mat_prod( - "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling - ) - - def weight_ih_l0_jac_t_mat_prod( - self, mat, sum_batch, subsampling=None - ): # noqa: D102 - return self.param_jac_t_mat_prod( - "weight_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling - ) - - def weight_hh_l0_jac_t_mat_prod( - self, mat, sum_batch, subsampling=None - ): # noqa: D102 - return self.param_jac_t_mat_prod( - "weight_hh_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + def param_mjp( + self, + param_str: str, + mat: Tensor, + sum_batch: bool, + subsampling: List[int] = None, + ) -> Tensor: # noqa: D102 + return stack( + [ + self._param_vjp( + param_str, + vec, + sum_batch, + axis_batch=get_batch_axis(self.problem.module), + subsampling=subsampling, + ) + for vec in mat + ] ) - def param_jac_t_vec_prod( + def _param_vjp( self, name: str, vec: Tensor, @@ -135,35 +121,6 @@ def param_jac_t_vec_prod( return jac_t_sample_prods - def param_jac_t_mat_prod( - self, - name: str, - mat: Tensor, - sum_batch: bool, - axis_batch: int = 0, - subsampling: List[int] = None, - ) -> Tensor: - """Compute the product of jac_t and the given matrix. - - Args: - name: name of parameter for derivative - mat: matrix which to multiply - sum_batch: whether to sum along batch axis - axis_batch: Batch axis, counted without the first axis. Defaults to 0. - subsampling: Indices of active samples. Default: ``None`` (all). - - Returns: - product of jac_t and mat - """ - return stack( - [ - self.param_jac_t_vec_prod( - name, vec, sum_batch, axis_batch=axis_batch, subsampling=subsampling - ) - for vec in mat - ] - ) - def weight_jac_mat_prod(self, mat) -> Tensor: """Product of jacobian and matrix. diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index c40a2d230..fb7e789e3 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -41,20 +41,16 @@ def jac_t_mat_prod( self.problem.module, None, None, mat, subsampling=subsampling ) - def weight_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - self.store_forward_io() - return self.problem.derivative.weight_jac_t_mat_prod( - self.problem.module, - None, - None, - mat, - sum_batch=sum_batch, - subsampling=subsampling, - ) - - def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + def param_mjp( + self, + param_str: str, + mat: Tensor, + sum_batch: bool, + subsampling: List[int] = None, + ) -> Tensor: # noqa: D102 self.store_forward_io() - return self.problem.derivative.bias_jac_t_mat_prod( + return self.problem.derivative.param_mjp( + param_str, self.problem.module, None, None, @@ -75,54 +71,6 @@ def bias_jac_mat_prod(self, mat): # noqa: D102 self.problem.module, None, None, mat ) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - self.store_forward_io() - return self.problem.derivative.bias_ih_l0_jac_t_mat_prod( - self.problem.module, - None, - None, - mat, - sum_batch=sum_batch, - subsampling=subsampling, - ) - - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 - self.store_forward_io() - return self.problem.derivative.bias_hh_l0_jac_t_mat_prod( - self.problem.module, - None, - None, - mat, - sum_batch=sum_batch, - subsampling=subsampling, - ) - - def weight_ih_l0_jac_t_mat_prod( - self, mat, sum_batch, subsampling=None - ): # noqa: D102 - self.store_forward_io() - return self.problem.derivative.weight_ih_l0_jac_t_mat_prod( - self.problem.module, - None, - None, - mat, - sum_batch=sum_batch, - subsampling=subsampling, - ) - - def weight_hh_l0_jac_t_mat_prod( - self, mat, sum_batch, subsampling=None - ): # noqa: D102 - self.store_forward_io() - return self.problem.derivative.weight_hh_l0_jac_t_mat_prod( - self.problem.module, - None, - None, - mat, - sum_batch=sum_batch, - subsampling=subsampling, - ) - def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 self.store_forward_io() return self.problem.derivative.ea_jac_t_mat_jac_prod( diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index cc875b409..84edba5b4 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -44,28 +44,17 @@ def jac_t_mat_prod(self, mat: Tensor, subsampling: List[int] = None) -> Tensor: raise NotImplementedError @abstractmethod - def weight_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + def param_mjp( + self, + param_str: str, + mat: Tensor, + sum_batch: bool, + subsampling: List[int] = None, ) -> Tensor: - """Matrix-Jacobian products w.r.t. the weight. - - Args: - mat: matrix - sum_batch: whether to sum along batch axis - subsampling: Active samples in the output. Default: ``None`` (all). - - Returns: - product - """ - raise NotImplementedError - - @abstractmethod - def bias_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None - ) -> Tensor: - """Product of jacobian and matrix. + """Matrix-Jacobian products w.r.t. the parameter. Args: + param_str: parameter name mat: matrix sum_batch: whether to sum along batch axis subsampling: Active samples in the output. Default: ``None`` (all). @@ -99,70 +88,6 @@ def bias_jac_mat_prod(self, mat: Tensor) -> Tensor: """ raise NotImplementedError - @abstractmethod - def bias_ih_l0_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None - ) -> Tensor: - """Product of jacobian and matrix. - - Args: - mat: matrix - sum_batch: whether to sum along batch axis - subsampling: Active samples in the output. Default: ``None`` (all). - - Returns: - product - """ - raise NotImplementedError - - @abstractmethod - def bias_hh_l0_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None - ) -> Tensor: - """Product of jacobian and matrix. - - Args: - mat: matrix - sum_batch: whether to sum along batch axis - subsampling: Active samples in the output. Default: ``None`` (all). - - Returns: - product - """ - raise NotImplementedError - - @abstractmethod - def weight_ih_l0_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None - ) -> Tensor: - """Product of jacobian and matrix. - - Args: - mat: matrix - sum_batch: whether to sum along batch axis - subsampling: Active samples in the output. Default: ``None`` (all). - - Returns: - product - """ - raise NotImplementedError - - @abstractmethod - def weight_hh_l0_jac_t_mat_prod( - self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None - ) -> Tensor: - """Product of jacobian and matrix. - - Args: - mat: matrix - sum_batch: whether to sum along batch axis - subsampling: Active samples in the output. Default: ``None`` (all). - - Returns: - product - """ - raise NotImplementedError - @abstractmethod def ea_jac_t_mat_jac_prod(self, mat: Tensor) -> Tensor: """Product of ea jacobian with matrix. diff --git a/test/core/derivatives/lstm_settings.py b/test/core/derivatives/lstm_settings.py index b3b8397f0..348129871 100644 --- a/test/core/derivatives/lstm_settings.py +++ b/test/core/derivatives/lstm_settings.py @@ -13,7 +13,8 @@ "seed" (int): seed for the random number for torch.rand """ -import torch +from torch import rand +from torch.nn import LSTM LSTM_SETTINGS = [] @@ -22,7 +23,11 @@ ############################################################################### LSTM_SETTINGS += [ { - "module_fn": lambda: torch.nn.LSTM(input_size=5, hidden_size=3, num_layers=1), - "input_fn": lambda: torch.rand(size=(10, 8, 5)), + "module_fn": lambda: LSTM(input_size=4, hidden_size=3, num_layers=1), + "input_fn": lambda: rand(size=(5, 3, 4)), + }, + { + "module_fn": lambda: LSTM(input_size=5, hidden_size=3, num_layers=1), + "input_fn": lambda: rand(size=(10, 8, 5)), }, ] diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index f300a4b25..9d02819bf 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -72,15 +72,3 @@ def skip_subsampling_conflict( enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip("Not enough samples.") - - -def skip_no_param(problem: DerivativesTestProblem, param_str: str) -> None: - """Skip if test case does not contain the parameter. - - Args: - problem: Test case. - param_str: Parameter name. - """ - has_param = getattr(problem.module, param_str, None) is not None - if not has_param: - skip(f"Test case has no {param_str} parameter.") From c789dcbcd2719229334c7c40616b411806a57f8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Mon, 2 Aug 2021 11:19:56 +0200 Subject: [PATCH 38/54] Use `register_full_backward_hook` for `torch>=1.9.0` (#194) Major changes how information is backpropagated in BackPACK: - Use `register_full_backward_hook` for `torch>=1.9.0`, else `register_backward_hook`: `full_backward_hook` is already introduced in `torch==1.8.0`, but it won't be fired on modules whose input has `requires_grad=False` (e.g. first layer of a net) before `torch==1.9.0`. - Modify which modules are extended: - Container modules built-in to PyTorch: **Old:** Have hooks that store I/O in the forward, and free I/O in the backward pass. **New:** Don't have hooks. - As a consequence of the above: **Old:** Post extension hooks can be executed on built-in PyTorch containers and their submodules. **New:** Post extension hooks only execute on submodules of built-in containers. - Modify backpropagation for second-order quantities: - **Old:** Call `backpropagate` from a module extension even if `module.input0` is end of backpropagation (does not require gradient). **New:** Only backpropagate if `module.input0` requires gradient. - Storage of additional backpropagated information: **Old:** Fetch information from `module.output`, backpropagate through module and append result to `module.input0`. This is not possible anymore with the `full_backward_hook` because of (information attached to python objects is lost). **New:** Keep a separate dictionary that uses the `data_ptr()` of `module.output` and `module.input0` as identifier for `output` and `input`. This property is preserved, i.e. `layer2.input0.data_ptr() == layer1.output.data_ptr()`. - `Flatten` layers that act as no op: **Old:** Second-order module extensions for `Flatten` check if it was a no op and the hook is therefore called at the wrong time. **New:** With `register_full_backward_hook` the execution order gets fixed. With `register_backward_hook` check in module extension whether the execution order is wrong and skip backpropagation only if the `Flatten` module was a no op. No extra logic in the `Flatten` second-order module extensions. Resolves https://github.com/fKunstner/backpack-discuss/issues/94. --- * stash * try saving on grad_inp and grad_out * working: save on .sum() as hash value * working: use .data_ptr as hash * refactoring, docstrings, type hints * docstring * add TODO: discussion * Meaningful Error Message: bp_quantitiy is None * incorporate suggestions * use new hook for torch>=1.9.0 * remove requires_grad=True from tests * fix example_first_order_resnet.py * fix example_trace_estimation.py for cuda * all versions: change reference tensor to input-output * delete use_legacy * fix import * fix import * dirty fix: overwrite existing quantities * docstring backpack init * clear quantities after backward * document context.py * rename post_extension_hook * use cpu() * remove detach() * fix import * introduce FirstOrderBackpropExtension * introduce SecondOrderBackpropExtension * fix import * incorporate suggestions format inline get_post_extension_hook __exit__ typing document firstorder/base.py document utils/hooks.py * fix post extension hook call * incorporate suggestions * use __call__ instead of apply() * incorporate suggestions * use FAIL_WARN * Revert "use FAIL_WARN" This reverts commit 9a9df1b6 * use FAIL_WARN * [REF] Improve names, fix docstring typo * [ADD] option fail_mode in FirsOrderBackpropExtension to enable custom choice * [ADD] option fail_mode in FirsOrderBackpropExtension to enable custom choice * [REF] Improve structure to clarify which modules get hooks * [DEL] delete clear quantities * [FIX] docstring * [ADD] should_backpropagate->bool * [ADD] should_backpropagate->bool * [REF] rename VERSION_HIGHER_THAN->VERSION_AT_LEAST * [TEST] that graph is clear after backward pass * [ADD] reintroduce retain_grad * [REF] Improve consistency: `post_extension_hook` vs `extension_hook` * [ADD] Debug message for extension hook execution * [REF] Introduce variable for flag that marks extended modules * [REF] make fail_mode private * [ADD] reintroduce test_hooks.py * [DEL] remove print * [DOC] Improve docstring * [ADD] introduce test_extension_hook_param_before_savefield_exists * [REF] delete import torch * [ADD] catch different error depending on torch version * [ADD] catch different error depending on torch version * [REF] introduce exception-conversion, unify handling of torch versions * [DOC] add new tests to fully_documented.txt * [REF] Introduce variable for batch size * [REF] Reduce duplication * [REF] Improve type annotation, reformat * [ADD] test settings with flatten * [DOC] fix docstring * [ADD] skip module extension if flatten is no_op * [DOC] clarify that data_ptr() is used * [REF] Check that Flatten without bp_quantity is a no op * [FMT] Fix black * [FIX] Rename `apply` into `__call__` * [DOC] Add noqa to `__call__` * [DEL] Remove backpropagate of FirstOrderModuleExtension It will never be called because no quantities are backpropagated. * [DEL] remove no_op logic in Flatten module extensions * [REF] Explicit import, docstring correction * [DOC] Correct docstring Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel <48687646+f-dangel@users.noreply.github.com> --- .coveragerc | 1 + backpack/__init__.py | 190 ++++++++----- backpack/context.py | 91 ++++-- backpack/core/derivatives/batchnorm_nd.py | 4 +- backpack/core/derivatives/flatten.py | 10 - backpack/core/derivatives/linear.py | 4 +- backpack/core/derivatives/lstm.py | 4 +- backpack/extensions/backprop_extension.py | 106 ++++--- .../extensions/curvmatprod/ggnmp/__init__.py | 4 +- .../extensions/curvmatprod/ggnmp/flatten.py | 6 - .../extensions/curvmatprod/hmp/__init__.py | 4 +- .../extensions/curvmatprod/hmp/flatten.py | 6 - .../extensions/curvmatprod/pchmp/__init__.py | 4 +- .../extensions/curvmatprod/pchmp/flatten.py | 6 - backpack/extensions/firstorder/base.py | 24 +- .../firstorder/batch_grad/__init__.py | 5 +- .../firstorder/batch_grad/batchnorm_nd.py | 4 +- .../firstorder/batch_l2_grad/__init__.py | 5 +- .../firstorder/batch_l2_grad/batchnorm_nd.py | 4 +- .../firstorder/gradient/batchnorm_nd.py | 4 +- .../firstorder/sum_grad_squared/__init__.py | 5 +- .../sum_grad_squared/batchnorm_nd.py | 4 +- .../firstorder/variance/__init__.py | 5 +- .../firstorder/variance/batchnorm_nd.py | 4 +- backpack/extensions/module_extension.py | 267 +++++++++++++----- backpack/extensions/saved_quantities.py | 48 ++++ backpack/extensions/secondorder/base.py | 9 + .../secondorder/diag_ggn/__init__.py | 6 +- .../secondorder/diag_ggn/batchnorm_nd.py | 8 +- .../secondorder/diag_ggn/flatten.py | 6 - .../secondorder/diag_hessian/__init__.py | 6 +- .../secondorder/diag_hessian/flatten.py | 6 - .../extensions/secondorder/hbp/__init__.py | 4 +- .../extensions/secondorder/hbp/flatten.py | 6 - .../secondorder/sqrt_ggn/__init__.py | 4 +- .../secondorder/sqrt_ggn/flatten.py | 39 --- backpack/utils/__init__.py | 28 +- backpack/utils/hooks.py | 9 +- backpack/utils/module_classification.py | 30 ++ .../use_cases/example_first_order_resnet.py | 2 +- .../use_cases/example_trace_estimation.py | 2 +- fully_documented.txt | 9 + test/core/derivatives/problem.py | 7 +- test/core/derivatives/utils.py | 12 - test/extensions/graph_clear_test.py | 61 ++++ test/extensions/implementation/hooks.py | 2 +- test/extensions/secondorder/hbp/test_kfac.py | 5 +- test/extensions/secondorder/hbp/test_kflr.py | 5 +- test/extensions/secondorder/hbp/test_kfra.py | 5 +- .../secondorder/secondorder_settings.py | 8 +- test/extensions/test_hooks.py | 196 ++++++++++--- test/test_simple_resnet.py | 144 ---------- 52 files changed, 870 insertions(+), 568 deletions(-) create mode 100644 backpack/extensions/saved_quantities.py create mode 100644 backpack/extensions/secondorder/base.py create mode 100644 backpack/utils/module_classification.py create mode 100644 test/extensions/graph_clear_test.py delete mode 100644 test/test_simple_resnet.py diff --git a/.coveragerc b/.coveragerc index 28f4ee9fe..e186905f2 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,6 +7,7 @@ exclude_lines = # Don't complain if tests don't hit defensive assertion code: raise NotImplementedError + raise AssertionError # TYPE_CHECKING block is never executed during pytest run if TYPE_CHECKING: diff --git a/backpack/__init__.py b/backpack/__init__.py index 35e339778..2ae09a779 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -1,44 +1,50 @@ """BackPACK.""" import inspect +from types import TracebackType +from typing import Callable, Optional, Tuple, Type, Union import torch +from torch import Tensor +from torch.nn import Module from backpack.extensions.backprop_extension import BackpropExtension from backpack.utils.hooks import no_op from . import extensions from .context import CTX +from .utils import FULL_BACKWARD_HOOK +from .utils.module_classification import is_no_op class backpack: - """Activate BackPACK extensions. + """Activate BackPACK extensions.""" - Enables the BackPACK extensions passed as arguments in the - :code:`backward` calls inside the current :code:`with` block. + def __init__( + self, + *exts: BackpropExtension, + extension_hook: Callable[[Module], None] = None, + debug: bool = False + ): + """Activate BackPACK extensions. - Args: - exts ([BackpropExtension]): Extensions to activate in the backward pass. - extension_hook (function, optional): Function called on each module after - all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns - ``None``. Default: ``None`` (no operation will be formed). + Enables the BackPACK extensions passed as arguments in the + :code:`backward` calls inside the current :code:`with` block. - Can be used to reduce memory overhead if the goal is to compute + Args: + exts: Extensions to activate in the backward pass. + extension_hook: Function called on each module after + all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns + ``None``. Default: ``None`` (no operation will be performed). + debug: Print debug messages during the backward pass. Default: ``False``. + + .. note:: + extension_hook can be used to reduce memory overhead if the goal is to compute transformations of BackPACK quantities. Information can be compacted during a backward pass and obsolete tensors be freed manually (``del``). - .. note:: - - If the callable iterates over the ``module.parameters()``, the same - parameter may be seen multiple times across calls. This happens - if the parameters are part of multiple modules. - For example, the parameters of a `torch.nn.Linear` module in - ``model = torch.nn.Sequential(torch.nn.Linear(...))`` are part of - both the ``Linear`` and the ``Sequential``. - debug (bool, optional): Print debug messages during the backward pass. - Default: ``False``. - """ - - def __init__(self, *exts: BackpropExtension, extension_hook=None, debug=False): + Raises: + ValueError: if extensions are not valid + """ for ext in exts: if not isinstance(ext, BackpropExtension): if inspect.isclass(ext) and issubclass(ext, BackpropExtension): @@ -53,11 +59,14 @@ def __init__(self, *exts: BackpropExtension, extension_hook=None, debug=False): + " but received [{}].".format(ext) ) - self.exts = exts - self.debug = debug - self.extension_hook = no_op if extension_hook is None else extension_hook + self.exts: Tuple[BackpropExtension, ...] = exts + self.debug: bool = debug + self.extension_hook: Callable[[Module], None] = ( + no_op if extension_hook is None else extension_hook + ) def __enter__(self): + """Setup backpack environment.""" self.old_CTX = CTX.get_active_exts() self.old_debug = CTX.get_debug() self.old_extension_hook = CTX.get_extension_hook() @@ -65,7 +74,19 @@ def __enter__(self): CTX.set_debug(self.debug) CTX.set_extension_hook(self.extension_hook) - def __exit__(self, type, value, traceback): + def __exit__( + self, + __exc_type: Optional[Type[BaseException]], + __exc_value: Optional[BaseException], + __traceback: Optional[TracebackType], + ): + """Leave backpack environment. + + Args: + __exc_type: exception type + __exc_value: exception value + __traceback: exception traceback + """ CTX.set_active_exts(self.old_CTX) CTX.set_debug(self.old_debug) CTX.set_extension_hook(self.old_extension_hook) @@ -91,33 +112,50 @@ class disable: even if the forward pass is carried out in ``with backpack(...)``. """ - store_io = True + store_io: bool = True def __enter__(self): """Disable input/output storing.""" - self.old_store_io = disable.store_io + self.old_store_io: bool = disable.store_io disable.store_io = False - def __exit__(self, type, value, traceback): - """Set input/output storing to old value.""" + def __exit__( + self, + __exc_type: Optional[Type[BaseException]], + __exc_value: Optional[BaseException], + __traceback: Optional[TracebackType], + ): + """Leave backpack environment. + + Args: + __exc_type: exception type + __exc_value: exception value + __traceback: exception traceback + """ disable.store_io = self.old_store_io @staticmethod - def should_store_io(): - """Return whether input and output should be stored.""" + def should_store_io() -> bool: + """Return whether input and output should be stored during forward pass. + + Returns: + whether input and output should be stored during forward pass + """ return disable.store_io -def hook_store_io(module, input, output): +def hook_store_io( + module: Module, input: Tuple[Tensor], output: Union[Tensor, Tuple[Tensor]] +) -> None: """Saves the input and output as attributes of the module. The list of inputs with index i is saved as module.input[i] The output is reduced to single output tensor and saved as module.output Args: - module (torch.nn.Module): the module on which to save the params - input (list): List of input tensors - output (torch.Tensor or tuple): result of module(input) + module: the module on which to save the inputs/outputs + input: List of input tensors + output: result of module(input) """ if disable.should_store_io() and torch.is_grad_enabled(): for i in range(len(input)): @@ -129,10 +167,13 @@ def hook_store_io(module, input, output): module.output = output -def memory_cleanup(module): +def memory_cleanup(module) -> None: """Remove I/O stored by backpack during the forward pass. Deletes the attributes created by `hook_store_io`. + + Args: + module: current module """ if hasattr(module, "output"): delattr(module, "output") @@ -142,13 +183,27 @@ def memory_cleanup(module): i += 1 -def hook_run_extensions(module, g_inp, g_out): +def hook_run_extensions( + module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] +) -> None: + """The backward hook function. + + It executes all BackPACK operations during the backward pass. + + Args: + module: current module + g_inp: input gradients + g_out: output gradients + """ + debug = CTX.get_debug() for backpack_extension in CTX.get_active_exts(): - if CTX.get_debug(): + if debug: print("[DEBUG] Running extension", backpack_extension, "on", module) - backpack_extension.apply(module, g_inp, g_out) + backpack_extension(module, g_inp, g_out) - run_extension_hook(module) + if debug: + print("[DEBUG] Running extension hook on", module) + CTX.get_extension_hook()(module) if not ( CTX.is_extension_active( @@ -160,31 +215,18 @@ def hook_run_extensions(module, g_inp, g_out): memory_cleanup(module) -def run_extension_hook(module): - """Execute the post extensions hook on a module after all BackPACK extensions. - - See the `post_backward_hook` argument of the `backpack` context manager for details. - """ - try: - CTX.get_extension_hook()(module) - except Exception as e: - message = getattr(e, "message", repr(e)) - raise RuntimeError(f"Post extensions hook failed: {message}") - - -def extend(module: torch.nn.Module, debug=False): - """Extends a ``module`` to make it BackPACK-ready. +def extend(module: Module, debug: bool = False) -> Module: + """Recursively extend a ``module`` to make it BackPACK-ready. - If the ``module`` has children, e.g. for a ``torch.nn.Sequential``, - they will also be extended. + Modules that do not represent an operation in the computation graph (for instance + containers like ``Sequential``) will not explicitly be extended. Args: - module (torch.nn.Module): The module to extend. - debug (bool, optional): Print debug messages during the extension. - Default: ``False``. + module: The module to extend. + debug: Print debug messages during the extension. Default: ``False``. Returns: - torch.nn.Module: Extended module. + Extended module. """ if debug: print("[DEBUG] Extending", module) @@ -192,10 +234,26 @@ def extend(module: torch.nn.Module, debug=False): for child in module.children(): extend(child, debug=debug) - module_was_already_extended = getattr(module, "_backpack_extend", False) - if not module_was_already_extended: - CTX.add_hook_handle(module.register_forward_hook(hook_store_io)) - CTX.add_hook_handle(module.register_backward_hook(hook_run_extensions)) - module._backpack_extend = True + extended_flag = "_backpack_extend" + already_extended = getattr(module, extended_flag, False) + if not (already_extended or is_no_op(module)): + _register_hooks(module) + setattr(module, extended_flag, True) return module + + +def _register_hooks(module: Module) -> None: + """Install forward and backward hooks on a module. + + Args: + module: module that is going to be extended + """ + CTX.add_hook_handle(module.register_forward_hook(hook_store_io)) + + if FULL_BACKWARD_HOOK: + register_backward_hook_fn = module.register_full_backward_hook + else: + register_backward_hook_fn = module.register_backward_hook + + CTX.add_hook_handle(register_backward_hook_fn(hook_run_extensions)) diff --git a/backpack/context.py b/backpack/context.py index 9da73faa9..433ad474f 100644 --- a/backpack/context.py +++ b/backpack/context.py @@ -1,56 +1,99 @@ +"""Context class for BackPACK.""" +from typing import Callable, Iterable, List, Tuple, Type + +from torch.nn import Module +from torch.utils.hooks import RemovableHandle + +from backpack.extensions.backprop_extension import BackpropExtension from backpack.utils.hooks import no_op class CTX: - """ - Global Class holding the configuration of the backward pass - """ + """Global Class holding the configuration of the backward pass.""" - active_exts = tuple() - debug = False - extension_hook = no_op + active_exts: Tuple[BackpropExtension] = tuple() + debug: bool = False + extension_hook: Callable[[Module], None] = no_op + hook_handles: List[RemovableHandle] = [] @staticmethod - def set_active_exts(active_exts): - CTX.active_exts = tuple() - for act_ext in active_exts: - CTX.active_exts += (act_ext,) + def set_active_exts(active_exts: Iterable[BackpropExtension]) -> None: + """Set the active backpack extensions. + + Args: + active_exts: the extensions + """ + CTX.active_exts = tuple(active_exts) @staticmethod - def get_active_exts(): + def get_active_exts() -> Tuple[BackpropExtension]: + """Get the currently active extensions. + + Returns: + active extensions + """ return CTX.active_exts @staticmethod - def add_hook_handle(hook_handle): - if getattr(CTX, "hook_handles", None) is None: - CTX.hook_handles = [] + def add_hook_handle(hook_handle: RemovableHandle) -> None: + """Add the hook handle to internal variable hook_handles. + + Args: + hook_handle: the removable handle + """ CTX.hook_handles.append(hook_handle) @staticmethod - def remove_hooks(): + def remove_hooks() -> None: + """Remove all hooks.""" for handle in CTX.hook_handles: handle.remove() CTX.hook_handles = [] @staticmethod - def is_extension_active(*extension_classes): - for backpack_ext in CTX.get_active_exts(): - if isinstance(backpack_ext, extension_classes): - return True - return False + def is_extension_active(*extension_classes: Type[BackpropExtension]) -> bool: + """Returns whether the specified class is currently active. + + Args: + *extension_classes: classes to test + + Returns: + whether at least one of the specified extensions is active + """ + return any(isinstance(ext, extension_classes) for ext in CTX.get_active_exts()) @staticmethod - def get_debug(): + def get_debug() -> bool: + """Whether debug mode is active. + + Returns: + whether debug mode is active + """ return CTX.debug @staticmethod - def set_debug(debug): + def set_debug(debug: bool) -> None: + """Set debug mode. + + Args: + debug: the mode to set + """ CTX.debug = debug @staticmethod - def get_extension_hook(): + def get_extension_hook() -> Callable[[Module], None]: + """Return the current extension hook to be run after all other extensions. + + Returns: + current extension hook + """ return CTX.extension_hook @staticmethod - def set_extension_hook(extension_hook): + def set_extension_hook(extension_hook: Callable[[Module], None]) -> None: + """Set the current extension hook. + + Args: + extension_hook: the extension hook to run after all other extensions + """ CTX.extension_hook = extension_hook diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index 7643ad22f..5d9a611b3 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -6,7 +6,7 @@ from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION, VERSION_1_9_0 +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 from backpack.utils.subsampling import subsample @@ -150,7 +150,7 @@ def _weight_jac_t_mat_prod( x_hat, _ = self._get_normalized_input_and_var(module) x_hat = subsample(x_hat, subsampling=subsampling) - if TORCH_VERSION >= VERSION_1_9_0: + if TORCH_VERSION_AT_LEAST_1_9_0: equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c" # TODO Remove else-branch after deprecating torch<1.9.0 else: diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py index a0af28da1..aac7f7992 100644 --- a/backpack/core/derivatives/flatten.py +++ b/backpack/core/derivatives/flatten.py @@ -32,13 +32,3 @@ def _jac_mat_prod( mat: Tensor, ) -> Tensor: return self.reshape_like_output(mat, module) - - def is_no_op(self, module): - """Does flatten add an operation to the computational graph. - - If the input is already flattened, no operation will be added for - the `Flatten` layer. This can lead to an intuitive order of backward - hook execution, see the discussion at https://discuss.pytorch.org/t/ - backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4 . - """ - return tuple(module.input0.shape) == tuple(module.output.shape) diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index aae3da1a5..27743e3b2 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -5,7 +5,7 @@ from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION, VERSION_1_9_0 +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 from backpack.utils.subsampling import subsample @@ -152,7 +152,7 @@ def _weight_jac_t_mat_prod( """ d_weight = subsample(module.input0, subsampling=subsampling) - if TORCH_VERSION >= VERSION_1_9_0: + if TORCH_VERSION_AT_LEAST_1_9_0: equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi" # TODO Remove else-branch after deprecating torch<1.9.0 else: diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 1d31638e7..6d68dd764 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -5,7 +5,7 @@ from torch.nn import LSTM from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION, VERSION_1_8_0 +from backpack.utils import TORCH_VERSION_AT_LEAST_1_8_0 from backpack.utils.subsampling import get_batch_axis, subsample @@ -49,7 +49,7 @@ def _check_parameters(module: LSTM) -> None: raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: raise NotImplementedError("only bidirectional = False is supported") - if TORCH_VERSION >= VERSION_1_8_0: + if TORCH_VERSION_AT_LEAST_1_8_0: if module.proj_size != 0: raise NotImplementedError("only proj_size = 0 is supported") diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index 06c614c3a..0b584ade9 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -1,20 +1,23 @@ """Implements the backpropagation mechanism.""" +from __future__ import annotations + +import abc import warnings -from typing import Type +from abc import ABC +from typing import Dict, Tuple, Type, Union -import torch.nn -from torch.nn import Sequential +from torch import Tensor +from torch.nn import Module -from backpack.custom_module.reduce_tuple import ReduceTuple from backpack.extensions.module_extension import ModuleExtension -from backpack.utils.hooks import no_op +from backpack.extensions.saved_quantities import SavedQuantities FAIL_ERROR = "ERROR" -FAIL_WARN = "WARN" +FAIL_WARN = "WARNING" FAIL_SILENT = "SILENT" -class BackpropExtension: +class BackpropExtension(ABC): """Base class for the BackPACK extensions. Descendants of this class need to @@ -31,28 +34,36 @@ class BackpropExtension: ``` """ - def __init__(self, savefield, module_exts, fail_mode=FAIL_ERROR): + def __init__( + self, + savefield: str, + module_exts: Dict[Type[Module], ModuleExtension], + fail_mode: str = FAIL_ERROR, + ): """Initializes parameters. Args: - savefield(str): Where to save results - module_exts(dict): Maps module classes to `ModuleExtension` instances - fail_mode(str, optional): Behavior when encountering an unknown layer. + savefield: Where to save results + module_exts: Maps module classes to `ModuleExtension` instances + fail_mode: Behavior when encountering an unknown layer. Can be - "ERROR": raise a NotImplementedError - "WARN": raise a UserWarning - "SILENT": skip the module silently Defaults to FAIL_ERROR = "ERROR" + + Raises: + AssertionError: if fail_mode is not valid """ - self.savefield = savefield - self.__module_extensions = module_exts - self.__fail_mode = fail_mode + if fail_mode not in (FAIL_WARN, FAIL_ERROR, FAIL_SILENT): + raise AssertionError(f"no valid fail mode: {fail_mode}") + self.saved_quantities: SavedQuantities = SavedQuantities() + self.savefield: str = savefield + self.__module_extensions: Dict[Type[Module], ModuleExtension] = module_exts + self._fail_mode: str = fail_mode def set_module_extension( - self, - module: Type[torch.nn.Module], - extension: ModuleExtension, - overwrite: bool = False, + self, module: Type[Module], extension: ModuleExtension, overwrite: bool = False ) -> None: """Adds a module mapping to module_extensions. @@ -74,38 +85,47 @@ def set_module_extension( ) self.__module_extensions[module] = extension - def __get_module_extension(self, module): + def __get_module_extension(self, module: Module) -> Union[ModuleExtension, None]: module_extension = self.__module_extensions.get(module.__class__) if module_extension is None: - - if isinstance(module, (Sequential, ReduceTuple)): - return no_op - - if self.__fail_mode is FAIL_ERROR: + if self._fail_mode is FAIL_ERROR: + # PyTorch converts this Error into a RuntimeError for torch<1.7.0 raise NotImplementedError( - "Extension saving to {} ".format(self.savefield) - + "does not have an extension for " - + "Module {}".format(module.__class__) - ) - elif self.__fail_mode == FAIL_WARN: - warnings.warn( - "Extension saving to {} ".format(self.savefield) - + "does not have an extension for " - + "Module {}".format(module.__class__) + f"Extension saving to {self.savefield} " + "does not have an extension for " + f"Module {module.__class__}" ) - - return no_op - - return module_extension.apply - - def apply(self, module, g_inp, g_out): + elif self._fail_mode == FAIL_WARN: + for _ in module.parameters(): + warnings.warn( + f"Extension saving to {self.savefield} does not have an " + f"extension for Module {module.__class__} " + f"although the module has parameters" + ) + break + + return module_extension + + def __call__( + self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> None: """Applies backpropagation. Args: - module(torch.nn.module): module to perform backpropagation on - g_inp(tuple[torch.Tensor]): input gradient - g_out(tuple[torch.Tensor]): output gradient + module: module to perform backpropagation on + g_inp: input gradient + g_out: output gradient """ module_extension = self.__get_module_extension(module) - module_extension(self, module, g_inp, g_out) + if module_extension is not None: + module_extension(self, module, g_inp, g_out) + + @abc.abstractmethod + def expects_backpropagation_quantities(self) -> bool: + """Whether the extension uses additional backpropagation quantities. + + Returns: + Whether the extension uses additional backpropagation quantities. + """ + return diff --git a/backpack/extensions/curvmatprod/ggnmp/__init__.py b/backpack/extensions/curvmatprod/ggnmp/__init__.py index c80f3c264..0813e367b 100644 --- a/backpack/extensions/curvmatprod/ggnmp/__init__.py +++ b/backpack/extensions/curvmatprod/ggnmp/__init__.py @@ -18,7 +18,7 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from . import ( activations, @@ -33,7 +33,7 @@ ) -class GGNMP(BackpropExtension): +class GGNMP(SecondOrderBackpropExtension): """ Matrix-free Multiplication with the block-diagonal generalized Gauss-Newton/Fisher. diff --git a/backpack/extensions/curvmatprod/ggnmp/flatten.py b/backpack/extensions/curvmatprod/ggnmp/flatten.py index 532c24cb6..d47d9d84d 100644 --- a/backpack/extensions/curvmatprod/ggnmp/flatten.py +++ b/backpack/extensions/curvmatprod/ggnmp/flatten.py @@ -5,9 +5,3 @@ class GGNMPFlatten(GGNMPBase): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/curvmatprod/hmp/__init__.py b/backpack/extensions/curvmatprod/hmp/__init__.py index a358b7f67..f94965b35 100644 --- a/backpack/extensions/curvmatprod/hmp/__init__.py +++ b/backpack/extensions/curvmatprod/hmp/__init__.py @@ -16,7 +16,7 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from . import ( activations, @@ -31,7 +31,7 @@ ) -class HMP(BackpropExtension): +class HMP(SecondOrderBackpropExtension): """Matrix-free multiplication with the block-diagonal Hessian. Stores the multiplication function in :code:`hmp`. diff --git a/backpack/extensions/curvmatprod/hmp/flatten.py b/backpack/extensions/curvmatprod/hmp/flatten.py index ab94fd5e2..300f669cf 100644 --- a/backpack/extensions/curvmatprod/hmp/flatten.py +++ b/backpack/extensions/curvmatprod/hmp/flatten.py @@ -5,9 +5,3 @@ class HMPFlatten(HMPBase): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/curvmatprod/pchmp/__init__.py b/backpack/extensions/curvmatprod/pchmp/__init__.py index 11dd33e03..058f1c762 100644 --- a/backpack/extensions/curvmatprod/pchmp/__init__.py +++ b/backpack/extensions/curvmatprod/pchmp/__init__.py @@ -17,12 +17,12 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling -class PCHMP(BackpropExtension): +class PCHMP(SecondOrderBackpropExtension): """ Matrix-free multiplication with the block-diagonal positive-curvature Hessian (PCH). diff --git a/backpack/extensions/curvmatprod/pchmp/flatten.py b/backpack/extensions/curvmatprod/pchmp/flatten.py index 29403437c..1cbaedce1 100644 --- a/backpack/extensions/curvmatprod/pchmp/flatten.py +++ b/backpack/extensions/curvmatprod/pchmp/flatten.py @@ -5,9 +5,3 @@ class PCHMPFlatten(PCHMPBase): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/firstorder/base.py b/backpack/extensions/firstorder/base.py index e3c62d08b..af761cd18 100644 --- a/backpack/extensions/firstorder/base.py +++ b/backpack/extensions/firstorder/base.py @@ -1,6 +1,26 @@ +"""Base class for first order extensions.""" +from typing import Dict, Type + +from torch.nn import Module + +from backpack.extensions.backprop_extension import FAIL_WARN, BackpropExtension from backpack.extensions.module_extension import ModuleExtension class FirstOrderModuleExtension(ModuleExtension): - def backpropagate(self, ext, module, g_inp, g_out, bpQuantities): - return None + """Base class for first order module extensions.""" + + +class FirstOrderBackpropExtension(BackpropExtension): + """Base backpropagation extension for first order.""" + + def __init__( + self, + savefield: str, + module_exts: Dict[Type[Module], ModuleExtension], + fail_mode: str = FAIL_WARN, + ): # noqa: D107 + super().__init__(savefield, module_exts, fail_mode=fail_mode) + + def expects_backpropagation_quantities(self) -> bool: # noqa: D102 + return False diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 7fff33789..7edf6176e 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -18,7 +18,7 @@ Linear, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.firstorder.base import FirstOrderBackpropExtension from . import ( batchnorm_nd, @@ -33,7 +33,7 @@ ) -class BatchGrad(BackpropExtension): +class BatchGrad(FirstOrderBackpropExtension): """Individual gradients for each sample in a minibatch. Stores the output in ``grad_batch`` as a ``[N x ...]`` tensor, @@ -68,7 +68,6 @@ def __init__(self, subsampling: List[int] = None): """ super().__init__( savefield="grad_batch", - fail_mode="WARNING", module_exts={ Linear: linear.BatchGradLinear(), Conv1d: conv1d.BatchGradConv1d(), diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py index cb0637ea2..ea13772cb 100644 --- a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py @@ -13,6 +13,6 @@ def __init__(self): derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module, raise_error=False) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 4f643fb48..6a5e770b1 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -17,7 +17,7 @@ Linear, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.firstorder.base import FirstOrderBackpropExtension from backpack.extensions.firstorder.batch_l2_grad import ( batchnorm_nd, convnd, @@ -27,7 +27,7 @@ ) -class BatchL2Grad(BackpropExtension): +class BatchL2Grad(FirstOrderBackpropExtension): """The squared L2 norm of individual gradients in the minibatch. Stores the output in ``batch_l2`` as a tensor of size ``[N]``, @@ -52,7 +52,6 @@ def __init__(self): """ super().__init__( savefield="batch_l2", - fail_mode="WARNING", module_exts={ Linear: linear.BatchL2Linear(), Conv1d: convnd.BatchL2Conv1d(), diff --git a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py index 42d6e2426..5d6d09bf6 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py @@ -11,6 +11,6 @@ def __init__(self): """Initialization.""" super().__init__(["weight", "bias"], BatchNormNdDerivatives()) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/gradient/batchnorm_nd.py b/backpack/extensions/firstorder/gradient/batchnorm_nd.py index 02dd78412..00322852a 100644 --- a/backpack/extensions/firstorder/gradient/batchnorm_nd.py +++ b/backpack/extensions/firstorder/gradient/batchnorm_nd.py @@ -14,6 +14,6 @@ def __init__(self): derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index a22f5b73f..8dea8e44c 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -16,7 +16,7 @@ Linear, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.firstorder.base import FirstOrderBackpropExtension from . import ( batchnorm_nd, @@ -31,7 +31,7 @@ ) -class SumGradSquared(BackpropExtension): +class SumGradSquared(FirstOrderBackpropExtension): """The sum of individual-gradients-squared, or second moment of the gradient. Stores the output in ``sum_grad_squared``. Same dimension as the gradient. @@ -55,7 +55,6 @@ def __init__(self): """ super().__init__( savefield="sum_grad_squared", - fail_mode="WARNING", module_exts={ Linear: linear.SGSLinear(), Conv1d: conv1d.SGSConv1d(), diff --git a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py index 866c95736..891ca2ee3 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py +++ b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py @@ -11,6 +11,6 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"]) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index f3e3085d2..831c07e40 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -16,7 +16,7 @@ Linear, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.firstorder.base import FirstOrderBackpropExtension from . import ( batchnorm_nd, @@ -31,7 +31,7 @@ ) -class Variance(BackpropExtension): +class Variance(FirstOrderBackpropExtension): """Estimates the variance of the gradient using the samples in the minibatch. Stores the output in ``variance``. Same dimension as the gradient. @@ -55,7 +55,6 @@ def __init__(self): """ super().__init__( savefield="variance", - fail_mode="WARNING", module_exts={ Linear: linear.VarianceLinear(), Conv1d: conv1d.VarianceConv1d(), diff --git a/backpack/extensions/firstorder/variance/batchnorm_nd.py b/backpack/extensions/firstorder/variance/batchnorm_nd.py index 6bcaa9386..00d516515 100644 --- a/backpack/extensions/firstorder/variance/batchnorm_nd.py +++ b/backpack/extensions/firstorder/variance/batchnorm_nd.py @@ -12,6 +12,6 @@ def __init__(self): """Initialization.""" super().__init__(["weight", "bias"], GradBatchNormNd(), SGSBatchNormNd()) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index 1b78b78c7..0c5425e3b 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -1,9 +1,21 @@ -import warnings +"""Contains base class for BackPACK module extensions.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Tuple +from warnings import warn + +from torch import Tensor +from torch.nn import Flatten, Module + +from backpack.utils import FULL_BACKWARD_HOOK +from backpack.utils.module_classification import is_loss + +if TYPE_CHECKING: + from backpack import BackpropExtension class ModuleExtension: - """ - Base class for a Module Extension for BackPACK. + """Base class for a Module Extension for BackPACK. Descendants of this class need to - define what parameters of the Module need to be treated (weight, bias) @@ -12,98 +24,209 @@ class ModuleExtension: needs to be propagated through the graph. """ - def __init__(self, params=None): - """ - Parameters - ---------- - params: [str] - List of module parameters that need special treatment. - for each param `p` in the list, instances of the extended module `m` - need to have a field `m.p` and the class extending `ModuleExtension` - need to provide a method with the same signature as the `backprop` - method. - The result of this method will be saved in the savefield of `m.p`. - """ - if params is None: - params = [] + def __init__(self, params: List[str] = None): + """Initialization. - self.__params = params + Args: + params: List of module parameters that need special treatment. + For each param `p` in the list, instances of the extended module `m` + need to have a field `m.p` and the class extending `ModuleExtension` + needs to provide a method with the same signature as the `backpropagate` + method. + The result of this method will be saved in the savefield of `m.p`. - for param in self.__params: - extFunc = getattr(self, param, None) - if extFunc is None: - raise NotImplementedError - - def backpropagate(self, ext, module, g_inp, g_out, bpQuantities): + Raises: + NotImplementedError: if child class doesn't have a method for each parameter """ - Main method to extend to backpropagate additional information through - the graph. - - Parameters - ---------- - ext: BackpropExtension - Instance of the extension currently running - module: torch.nn.Module - Instance of the extended module - g_inp: [Tensor] - Gradient of the loss w.r.t. the inputs - g_out: Tensor - Gradient of the loss w.r.t. the output - bpQuantities: - Quantities backpropagated w.r.t. the output + self.__params: List[str] = [] if params is None else params + + for param in self.__params: + if not hasattr(self, param): + raise NotImplementedError( + f"The module extension {self} is missing an implementation " + f"of how to calculate the quantity for {param}. " + f"This should be realized in a function " + f"{param}(extension, module, g_inp, g_out, bpQuantities) -> Any." + ) + + def backpropagate( + self, + extension: BackpropExtension, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + bpQuantities: Any, + ) -> Any: + """Backpropagation of additional information through the graph. + + Args: + extension: Instance of the extension currently running + module: Instance of the extended module + g_inp: Gradient of the loss w.r.t. the inputs + g_out: Gradient of the loss w.r.t. the output + bpQuantities: Quantities backpropagated w.r.t. the output Returns - ------- - bpQuantities: Quantities backpropagated w.r.t. the input """ - warnings.warn("Backpropagate has not been overwritten") + warn("Backpropagate has not been overwritten") + + def __call__( + self, + extension: BackpropExtension, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: + """Apply all actions required by the extension. - def apply(self, ext, module, g_inp, g_out): - """ Fetch backpropagated quantities from module output, apply backpropagation - rule, and attach the result to module input(s). + rule, and store as backpropagated quantities for the module input(s). + + Args: + extension: current backpropagation extension + module: current module + g_inp: input gradients + g_out: output gradients + + Raises: + AssertionError: if there is no saved quantity although extension expects one, + or if a backpropagated quantity is expected, but there is None and the old + backward hook is used and the module is not a Flatten no op. """ - inp = module.input0 - out = module.output - - bpQuantities = self.__backproped_quantities(ext, out) + delete_old_quantities = not self.__should_retain_backproped_quantities(module) + bp_quantity = self.__get_backproped_quantity( + extension, module.output, delete_old_quantities + ) + if ( + extension.expects_backpropagation_quantities() + and bp_quantity is None + and not is_loss(module) + ): + if not FULL_BACKWARD_HOOK and isinstance(module, Flatten): + # Flatten layers whose input is already flat do not add a node to the + # graph. This leads to unintuitive order of backward hook execution: + # https://discuss.pytorch.org/t/backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4. # noqa: B950 + # Skip everything below if this scenario is encountered. + no_op = module.input0.shape == module.output.shape + if not no_op: + raise AssertionError( + "Expected no op Flatten module. Got " + + f"{module.input0.shape} -> {module.output.shape}" + ) + return + else: + raise AssertionError( + "BackPACK extension expects a backpropagation quantity but it is None. " + f"Module: {module}, Extension: {extension}." + ) for param in self.__params: if self.__param_exists_and_requires_grad(module, param): extFunc = getattr(self, param) - extValue = extFunc(ext, module, g_inp, g_out, bpQuantities) - self.__save(extValue, ext, module, param) - - bpQuantities = self.backpropagate(ext, module, g_inp, g_out, bpQuantities) + extValue = extFunc(extension, module, g_inp, g_out, bp_quantity) + self.__save_value_on_parameter(extValue, extension, module, param) - self.__backprop_quantities(ext, inp, out, bpQuantities) + if self.__should_backpropagate(extension, module): + bp_quantity = self.backpropagate( + extension, module, g_inp, g_out, bp_quantity + ) + self.__save_backproped_quantity(extension, module.input0, bp_quantity) @staticmethod - def __backproped_quantities(ext, out): - """Fetch backpropagated quantities attached to the module output.""" - return getattr(out, ext.savefield, None) + def __should_backpropagate(extension: BackpropExtension, module: Module) -> bool: + """Determines whether the current extension should perform a backpropagation. + + Args: + extension: current extension + module: current module + + Returns: + whether a backpropagation should be performed + """ + input_requires_grad: bool = module.input0.requires_grad + return input_requires_grad and extension.expects_backpropagation_quantities() @staticmethod - def __backprop_quantities(ext, inp, out, bpQuantities): - """Propagate back additional information by attaching it to the module input.""" + def __should_retain_backproped_quantities(module: Module) -> bool: + """Whether the backpropagation quantities should be kept. + + This is old code inherited and not tested. - setattr(inp, ext.savefield, bpQuantities) + Args: + module: current module - is_a_leaf = out.grad_fn is None - retain_grad_is_on = getattr(out, "retains_grad", False) - inp_is_out = id(inp) == id(out) - should_retain_grad = is_a_leaf or retain_grad_is_on or inp_is_out + Returns: + whether backpropagation quantities should be kept + """ + is_a_leaf = module.output.grad_fn is None + retain_grad_is_on = getattr(module.output, "retains_grad", False) + # inp_is_out = id(module.input0) == id(module.output) + should_retain_grad = is_a_leaf or retain_grad_is_on # or inp_is_out + return should_retain_grad + + @staticmethod + def __get_backproped_quantity( + extension: BackpropExtension, + reference_tensor: Tensor, + delete_old: bool, + ) -> Tensor or None: + """Fetch backpropagated quantities attached to the module output. + + The property reference_tensor.data_ptr() is used as a reference. + + Args: + extension: current BackPACK extension + reference_tensor: the output Tensor of the current module + delete_old: whether to delete the old backpropagated quantity + + Returns: + the backpropagation quantity + """ + return extension.saved_quantities.retrieve_quantity( + reference_tensor.data_ptr(), delete_old + ) - if not should_retain_grad: - if hasattr(out, ext.savefield): - delattr(out, ext.savefield) + @staticmethod + def __save_backproped_quantity( + extension: BackpropExtension, reference_tensor: Tensor, bpQuantities: Any + ) -> None: + """Save additional information backpropagated for a tensor. + + Args: + extension: current BackPACK extension + reference_tensor: reference tensor for which additional information + is backpropagated. + bpQuantities: backpropagation quantities that should be saved + """ + extension.saved_quantities.save_quantity( + reference_tensor.data_ptr(), bpQuantities + ) @staticmethod - def __param_exists_and_requires_grad(module, param): - param_exists = getattr(module, param) is not None - return param_exists and getattr(module, param).requires_grad + def __param_exists_and_requires_grad(module: Module, param_str: str) -> bool: + """Whether the module has the parameter and it requires gradient. + + Args: + module: current module + param_str: parameter name + + Returns: + whether the module has the parameter and it requires gradient + """ + param_exists = getattr(module, param_str) is not None + return param_exists and getattr(module, param_str).requires_grad @staticmethod - def __save(value, extension, module, param): - setattr(getattr(module, param), extension.savefield, value) + def __save_value_on_parameter( + value: Any, extension: BackpropExtension, module: Module, param_str: str + ) -> None: + """Saves the value on the parameter of that module. + + Args: + value: The value that should be saved. + extension: The current BackPACK extension. + module: current module + param_str: parameter name + """ + setattr(getattr(module, param_str), extension.savefield, value) diff --git a/backpack/extensions/saved_quantities.py b/backpack/extensions/saved_quantities.py new file mode 100644 index 000000000..78249dbd5 --- /dev/null +++ b/backpack/extensions/saved_quantities.py @@ -0,0 +1,48 @@ +"""Class for saving backpropagation quantities.""" +from typing import Dict, Union + +from torch import Tensor + + +class SavedQuantities: + """Implements interface to save backpropagation quantities.""" + + def __init__(self): + """Initialization.""" + self._saved_quantities: Dict[int, Tensor] = {} + + def save_quantity(self, key: int, quantity: Tensor) -> None: + """Saves the quantity under the specified key. + + Args: + key: data_ptr() of reference tensor (module.input0). + quantity: tensor to save + + Raises: + NotImplementedError: if the key already exists + """ + if key in self._saved_quantities: + # TODO if exists: accumulate quantities (ResNet) + raise NotImplementedError( + "Quantity with given key already exists. Multiple backpropagated " + "quantities like in ResNets are not supported yet." + ) + else: + self._saved_quantities[key] = quantity + + def retrieve_quantity(self, key: int, delete_old: bool) -> Union[Tensor, None]: + """Returns the saved quantity. + + Args: + key: data_ptr() of reference tensor. + For torch>=1.9.0 the reference tensor is grad_output[0]. + For older versions the reference tensor is module.output. + delete_old: whether to delete the old quantity + + Returns: + the saved quantity, None if it does not exist + """ + get_value = ( + self._saved_quantities.pop if delete_old else self._saved_quantities.get + ) + return get_value(key, None) diff --git a/backpack/extensions/secondorder/base.py b/backpack/extensions/secondorder/base.py new file mode 100644 index 000000000..d65fa548f --- /dev/null +++ b/backpack/extensions/secondorder/base.py @@ -0,0 +1,9 @@ +"""Contains base classes for second order extensions.""" +from backpack.extensions.backprop_extension import BackpropExtension + + +class SecondOrderBackpropExtension(BackpropExtension): + """Base backpropagation extension for second order.""" + + def expects_backpropagation_quantities(self) -> bool: # noqa: D102 + return True diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 45848b5e2..d1cf20f53 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -44,7 +44,7 @@ ) from backpack.custom_module.permute import Permute -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from backpack.extensions.secondorder.hbp import LossHessianStrategy from . import ( @@ -68,7 +68,7 @@ ) -class DiagGGN(BackpropExtension): +class DiagGGN(SecondOrderBackpropExtension): """Base class for diagonal generalized Gauss-Newton/Fisher matrix.""" VALID_LOSS_HESSIAN_STRATEGIES = [ @@ -182,7 +182,7 @@ def get_num_mc_samples(self) -> int: return self._mc_samples -class BatchDiagGGN(BackpropExtension): +class BatchDiagGGN(SecondOrderBackpropExtension): """Base class for batched diagonal generalized Gauss-Newton/Fisher matrix.""" VALID_LOSS_HESSIAN_STRATEGIES = [ diff --git a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py index e48dbc68d..942be14ad 100644 --- a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py +++ b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py @@ -11,9 +11,9 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=True) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) class BatchDiagGGNBatchNormNd(DiagGGNBaseModule): @@ -23,6 +23,6 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=False) - def apply(self, ext, module, g_inp, g_out): # noqa: D102 + def __call__(self, ext, module, g_inp, g_out): # noqa: D102 batch_norm_raise_error_if_train(module) - super().apply(ext, module, g_inp, g_out) + super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/secondorder/diag_ggn/flatten.py b/backpack/extensions/secondorder/diag_ggn/flatten.py index 60c1ca8d4..cf6f63358 100644 --- a/backpack/extensions/secondorder/diag_ggn/flatten.py +++ b/backpack/extensions/secondorder/diag_ggn/flatten.py @@ -5,9 +5,3 @@ class DiagGGNFlatten(DiagGGNBaseModule): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/diag_hessian/__init__.py b/backpack/extensions/secondorder/diag_hessian/__init__.py index 6bb9933d8..ffdf7d639 100644 --- a/backpack/extensions/secondorder/diag_hessian/__init__.py +++ b/backpack/extensions/secondorder/diag_hessian/__init__.py @@ -31,7 +31,7 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from . import ( activations, @@ -50,7 +50,7 @@ ) -class DiagHessian(BackpropExtension): +class DiagHessian(SecondOrderBackpropExtension): """BackPACK extension that computes the Hessian diagonal. Stores the output in :code:`diag_h`, has the same dimensions as the gradient. @@ -96,7 +96,7 @@ def __init__(self): ) -class BatchDiagHessian(BackpropExtension): +class BatchDiagHessian(SecondOrderBackpropExtension): """BackPACK extensions that computes the per-sample (individual) Hessian diagonal. Stores the output in ``diag_h_batch`` as a ``[N x ...]`` tensor, diff --git a/backpack/extensions/secondorder/diag_hessian/flatten.py b/backpack/extensions/secondorder/diag_hessian/flatten.py index d6d28b7c2..d8b01357a 100644 --- a/backpack/extensions/secondorder/diag_hessian/flatten.py +++ b/backpack/extensions/secondorder/diag_hessian/flatten.py @@ -5,9 +5,3 @@ class DiagHFlatten(DiagHBaseModule): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/hbp/__init__.py b/backpack/extensions/secondorder/hbp/__init__.py index 7f469da4d..2e529c4ef 100644 --- a/backpack/extensions/secondorder/hbp/__init__.py +++ b/backpack/extensions/secondorder/hbp/__init__.py @@ -13,8 +13,8 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.curvature import Curvature +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from backpack.extensions.secondorder.hbp.hbp_options import ( BackpropStrategy, ExpectationApproximation, @@ -24,7 +24,7 @@ from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling -class HBP(BackpropExtension): +class HBP(SecondOrderBackpropExtension): def __init__( self, curv_type, diff --git a/backpack/extensions/secondorder/hbp/flatten.py b/backpack/extensions/secondorder/hbp/flatten.py index 990d0b023..c20014e92 100644 --- a/backpack/extensions/secondorder/hbp/flatten.py +++ b/backpack/extensions/secondorder/hbp/flatten.py @@ -5,9 +5,3 @@ class HBPFlatten(HBPBaseModule): def __init__(self): super().__init__(derivatives=FlattenDerivatives()) - - def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py index d258b350f..2438b53d0 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/__init__.py +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -28,7 +28,7 @@ ZeroPad2d, ) -from backpack.extensions.backprop_extension import BackpropExtension +from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from backpack.extensions.secondorder.hbp import LossHessianStrategy from backpack.extensions.secondorder.sqrt_ggn import ( activations, @@ -43,7 +43,7 @@ ) -class SqrtGGN(BackpropExtension): +class SqrtGGN(SecondOrderBackpropExtension): """Base class for extensions that compute the GGN/Fisher matrix square root.""" def __init__(self, loss_hessian_strategy: str, savefield: str): diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py index bae59b402..2a045c957 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/flatten.py +++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py @@ -1,17 +1,7 @@ """Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Tuple, Union - -from torch import Tensor -from torch.nn import Module - from backpack.core.derivatives.flatten import FlattenDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule -if TYPE_CHECKING: - from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC - class SqrtGGNFlatten(SqrtGGNBaseModule): """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Flatten`` module.""" @@ -19,32 +9,3 @@ class SqrtGGNFlatten(SqrtGGNBaseModule): def __init__(self): """Pass derivatives for ``torch.nn.Flatten`` module.""" super().__init__(FlattenDerivatives()) - - def backpropagate( - self, - ext: Union[SqrtGGNExact, SqrtGGNMC], - module: Module, - grad_inp: Tuple[Tensor], - grad_out: Tuple[Tensor], - backproped: Tensor, - ) -> Tensor: - """Backpropagate only if flatten created a node in the computation graph. - - Otherwise, the backward hook will not be called at the right stage and - no action must be performed. - - Args: - ext: BackPACK extension calling out to the module extension. - module: Module that performed the forward pass. - grad_inp: Gradients w.r.t. the module inputs. - grad_out: Gradients w.r.t. the module outputs. - backproped: Backpropagated symmetric factorization of the loss Hessian - from the child module. - - Returns: - Symmetric loss Hessian factorization, backpropagated through the module. - """ - if self.derivatives.is_no_op(module): - return backproped - else: - return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index 8f5caf781..5e292385d 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -1,10 +1,28 @@ """Contains utility functions.""" +from typing import Type from pkg_resources import get_distribution, packaging TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) -VERSION_1_9_1 = packaging.version.parse("1.9.1") -VERSION_1_9_0 = packaging.version.parse("1.9.0") -VERSION_1_8_0 = packaging.version.parse("1.8.0") -VERSION_1_6_0 = packaging.version.parse("1.6.0") -TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= VERSION_1_9_1 +TORCH_VERSION_AT_LEAST_1_7_0 = TORCH_VERSION >= packaging.version.parse("1.7.0") +TORCH_VERSION_AT_LEAST_1_8_0 = TORCH_VERSION >= packaging.version.parse("1.8.0") +TORCH_VERSION_AT_LEAST_1_9_0 = TORCH_VERSION >= packaging.version.parse("1.9.0") +TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1") +FULL_BACKWARD_HOOK: bool = TORCH_VERSION_AT_LEAST_1_9_0 + + +def exception_inside_backward_pass(error: Type[Exception]) -> Type[Exception]: + """Returns the type of exception that gets raised inside a backward pass by PyTorch. + + For Torch>=1.7.0 the error is identical. + + Args: + error: previous exception type + + Returns: + new exception type + """ + if TORCH_VERSION_AT_LEAST_1_7_0: + return error + else: + return RuntimeError diff --git a/backpack/utils/hooks.py b/backpack/utils/hooks.py index d86381ec6..c4a8aff68 100644 --- a/backpack/utils/hooks.py +++ b/backpack/utils/hooks.py @@ -2,5 +2,10 @@ def no_op(*args, **kwargs): - """Placeholder function that accepts arbitrary input and does nothing.""" - return None + """Placeholder function that accepts arbitrary input and does nothing. + + Args: + *args: anything + **kwargs: anything + """ + pass diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py new file mode 100644 index 000000000..7be68d96a --- /dev/null +++ b/backpack/utils/module_classification.py @@ -0,0 +1,30 @@ +"""Contains util function for classification of modules.""" + +from torch.nn import Module, Sequential +from torch.nn.modules.loss import _Loss + +from backpack.custom_module.reduce_tuple import ReduceTuple + + +def is_loss(module: Module) -> bool: + """Return whether `module` is a `torch` loss function. + + Args: + module: A PyTorch module. + + Returns: + Whether `module` is a loss function. + """ + return isinstance(module, _Loss) + + +def is_no_op(module: Module) -> bool: + """Return whether the module does no operation in graph. + + Args: + module: module + + Returns: + whether module is no operation + """ + return isinstance(module, (Sequential, ReduceTuple)) diff --git a/docs_src/examples/use_cases/example_first_order_resnet.py b/docs_src/examples/use_cases/example_first_order_resnet.py index df3ded93d..0a04b4ce3 100644 --- a/docs_src/examples/use_cases/example_first_order_resnet.py +++ b/docs_src/examples/use_cases/example_first_order_resnet.py @@ -53,7 +53,7 @@ def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10): def forward(self, x): residual = self.shortcut(x) x = self.conv2(F.relu(self.conv1(x))) - x += residual + x = x + residual # don't use: x += residual x = x.view(x.size(0), -1) x = self.linear1(x) return x diff --git a/docs_src/examples/use_cases/example_trace_estimation.py b/docs_src/examples/use_cases/example_trace_estimation.py index 86c24ae78..8c2a57107 100644 --- a/docs_src/examples/use_cases/example_trace_estimation.py +++ b/docs_src/examples/use_cases/example_trace_estimation.py @@ -225,7 +225,7 @@ def hutchinson_trace_autodiff_blockwise(V): plt.semilogx( V_list, - trace_estimates, + [trace_estimate.cpu() for trace_estimate in trace_estimates], linestyle="--", color="orange", label="Hutchinson" if i == 0 else None, diff --git a/fully_documented.txt b/fully_documented.txt index 0c9ed68d8..f4944a8c8 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,5 +1,7 @@ setup.py +backpack/__init__.py +backpack/context.py backpack/custom_module/ backpack/core/derivatives/basederivatives.py @@ -14,7 +16,10 @@ backpack/core/derivatives/batchnorm_nd.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py +backpack/extensions/module_extension.py +backpack/extensions/saved_quantities.py backpack/extensions/mat_to_mat_jac_base.py +backpack/extensions/firstorder/base.py backpack/extensions/firstorder/gradient/base.py backpack/extensions/firstorder/gradient/rnn.py backpack/extensions/firstorder/gradient/__init__.py @@ -51,11 +56,15 @@ backpack/utils/linear.py backpack/utils/subsampling.py backpack/utils/errors.py backpack/utils/__init__.py +backpack/utils/module_classification.py +backpack/utils/hooks.py test/extensions/automated_settings.py test/extensions/problem.py test/extensions/utils.py test/extensions/test_backprop_extension.py +test/extensions/test_hooks.py +test/extensions/graph_clear_test.py test/extensions/firstorder/firstorder_settings.py test/extensions/firstorder/variance/ test/extensions/firstorder/batch_grad/batch_grad_settings.py diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index cb6f93120..8c2f8be21 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -1,15 +1,12 @@ """Convert problem settings.""" import copy -from test.core.derivatives.utils import ( - derivative_cls_for, - get_available_devices, - is_loss, -) +from test.core.derivatives.utils import derivative_cls_for, get_available_devices import torch from backpack import extend +from backpack.utils.module_classification import is_loss def make_test_problems(settings): diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py index c5e14b71d..fe8bbda35 100644 --- a/test/core/derivatives/utils.py +++ b/test/core/derivatives/utils.py @@ -44,18 +44,6 @@ def derivative_cls_for(module_cls: Type[Module]) -> Type[BaseDerivatives]: ) -def is_loss(module: Module) -> bool: - """Return whether `module` is a `torch` loss function. - - Args: - module: A PyTorch module. - - Returns: - Whether `module` is a loss function. - """ - return isinstance(module, torch.nn.modules.loss._Loss) - - def classification_targets(size: Tuple[int, ...], num_classes: int) -> Tensor: """Create random targets for classes 0, ..., `num_classes - 1`. diff --git a/test/extensions/graph_clear_test.py b/test/extensions/graph_clear_test.py new file mode 100644 index 000000000..17b6419ee --- /dev/null +++ b/test/extensions/graph_clear_test.py @@ -0,0 +1,61 @@ +"""Test whether the graph is clear after a backward pass.""" +from typing import Tuple + +from pytest import fixture +from torch import Tensor, rand, rand_like +from torch.nn import Flatten, Linear, Module, MSELoss, ReLU, Sequential + +from backpack import backpack, extend +from backpack.extensions import DiagGGNExact + +PROBLEM_STRING = ["standard", "flatten_no_op", "flatten_with_op"] + + +def test_graph_clear(problem) -> None: + """Test that the graph is clear after a backward pass. + + More specifically, test that there are no saved quantities left over. + + Args: + problem: problem consisting of inputs, and model + """ + inputs, model = problem + extension = DiagGGNExact() + outputs = extend(model)(inputs) + loss = extend(MSELoss())(outputs, rand_like(outputs)) + with backpack(extension): + loss.backward() + + # test that the dictionary is empty + saved_quantities: dict = extension.saved_quantities._saved_quantities + assert type(saved_quantities) is dict + assert not saved_quantities + + +@fixture(params=PROBLEM_STRING, ids=PROBLEM_STRING) +def problem(request) -> Tuple[Tensor, Module]: + """Problem setting. + + Args: + request: pytest request, contains parameters + + Yields: + inputs and model + + Raises: + NotImplementedError: if problem string is unknown + """ + batch_size, in_dim, out_dim = 2, 3, 4 + inputs = rand(batch_size, in_dim) + if request.param == PROBLEM_STRING[0]: + model = Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim)) + elif request.param == PROBLEM_STRING[1]: + model = Sequential(Linear(in_dim, out_dim), Flatten(), Linear(out_dim, out_dim)) + elif request.param == PROBLEM_STRING[2]: + inputs = rand(batch_size, in_dim, in_dim) + model = Sequential( + Linear(in_dim, out_dim), Flatten(), Linear(in_dim * out_dim, out_dim) + ) + else: + raise NotImplementedError(f"unknown request.param={request.param}") + yield inputs, model diff --git a/test/extensions/implementation/hooks.py b/test/extensions/implementation/hooks.py index cd3bba1a5..5e6cc2f4c 100644 --- a/test/extensions/implementation/hooks.py +++ b/test/extensions/implementation/hooks.py @@ -1,4 +1,4 @@ -"""Post extension hooks to compact BackPACK quantities during backpropagation.""" +"""Extension hooks to compact BackPACK quantities during backpropagation.""" class ExtensionHookManager: diff --git a/test/extensions/secondorder/hbp/test_kfac.py b/test/extensions/secondorder/hbp/test_kfac.py index ec494a5de..a5c76b6c5 100644 --- a/test/extensions/secondorder/hbp/test_kfac.py +++ b/test/extensions/secondorder/hbp/test_kfac.py @@ -6,7 +6,7 @@ import pytest -from backpack.utils import TORCH_VERSION, VERSION_1_6_0 +from backpack.utils import exception_inside_backward_pass NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,8 +21,7 @@ def test_kfac_not_supported(problem): """ problem.set_up() - exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError - with pytest.raises(exception): + with pytest.raises(exception_inside_backward_pass(NotImplementedError)): BackpackExtensions(problem).kfac() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kflr.py b/test/extensions/secondorder/hbp/test_kflr.py index 0a4e04e46..d5c6819d0 100644 --- a/test/extensions/secondorder/hbp/test_kflr.py +++ b/test/extensions/secondorder/hbp/test_kflr.py @@ -6,7 +6,7 @@ import pytest -from backpack.utils import TORCH_VERSION, VERSION_1_6_0 +from backpack.utils import exception_inside_backward_pass NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,8 +21,7 @@ def test_kflr_not_supported(problem): """ problem.set_up() - exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError - with pytest.raises(exception): + with pytest.raises(exception_inside_backward_pass(RuntimeError)): BackpackExtensions(problem).kflr() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kfra.py b/test/extensions/secondorder/hbp/test_kfra.py index 943fd4e41..033171fb0 100644 --- a/test/extensions/secondorder/hbp/test_kfra.py +++ b/test/extensions/secondorder/hbp/test_kfra.py @@ -6,7 +6,7 @@ import pytest -from backpack.utils import TORCH_VERSION, VERSION_1_6_0 +from backpack.utils import exception_inside_backward_pass NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,8 +21,7 @@ def test_kfra_not_supported(problem): """ problem.set_up() - exception = RuntimeError if TORCH_VERSION == VERSION_1_6_0 else NotImplementedError - with pytest.raises(exception): + with pytest.raises(exception_inside_backward_pass(NotImplementedError)): BackpackExtensions(problem).kfra() problem.tear_down() diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index 44a4483b3..f5f4386a9 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -235,9 +235,11 @@ SECONDORDER_SETTINGS += [ { - # Flatten layer does not add a node in the computation graph and thus the - # backward hook will be called at an unexpected stage. This must explicitly - # be addressed in the `backpropagate` function of the flatten module extension. + # Flatten layer does not add a node in the PyTorch computation graph. + # Thus, the backward hook will be called at an unexpected stage. + # The register_full_backward_hook ensures the execution order is correct -> ok. + # The register_backward_hook has above problem and therefore needs to skip execution. + # This is done in the `backward` function or in the `__call__` of ModuleExtension. "input_fn": lambda: rand(3, 5), "module_fn": lambda: Sequential(Linear(5, 4), Flatten(), Linear(4, 2)), "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), diff --git a/test/extensions/test_hooks.py b/test/extensions/test_hooks.py index 119afa8be..325c1d265 100644 --- a/test/extensions/test_hooks.py +++ b/test/extensions/test_hooks.py @@ -3,89 +3,195 @@ These tests aim at demonstrating the pitfalls one may run into when using hooks that iterate over ``module.parameters()``. """ - from test.core.derivatives.utils import classification_targets, get_available_devices +from typing import Tuple -import pytest -import torch +from pytest import fixture, mark, raises +from torch import Tensor, manual_seed, rand +from torch.nn import CrossEntropyLoss, Linear, Module, Sequential -from backpack import backpack, extend, extensions +from backpack import backpack, extend +from backpack.extensions import BatchGrad, DiagGGNExact +from backpack.extensions.backprop_extension import FAIL_ERROR, BackpropExtension +from backpack.utils import exception_inside_backward_pass DEVICES = get_available_devices() DEVICES_ID = [str(dev) for dev in DEVICES] +NESTED_SEQUENTIAL = "NESTED_SEQUENTIAL" +CUSTOM_CONTAINER = "CUSTOM_CONTAINER" +problem_list = [NESTED_SEQUENTIAL, CUSTOM_CONTAINER] + + +@fixture(params=DEVICES, ids=DEVICES_ID) +def device(request): + """Yields the available device for the test. + + Args: + request: pytest request + + Yields: + an available device + """ + yield request.param -def set_up(device): - """Return extended nested sequential with loss from a forward pass.""" - torch.manual_seed(0) + +@fixture(params=problem_list, ids=problem_list) +def problem(device, request) -> Tuple[Module, Tensor, str]: + """Return extended nested sequential with loss from a forward pass. + + Args: + device: available device + request: pytest request + + Yields: + model, loss and problem_string + + Raises: + NotImplementedError: if the problem_string is unknown + """ + problem_string = request.param + manual_seed(0) B = 2 - X = torch.rand(B, 4).to(device) + X = rand(B, 4).to(device) y = classification_targets((B,), 2).to(device) - model = torch.nn.Sequential( - torch.nn.Linear(4, 3, bias=False), - torch.nn.Sequential( - torch.nn.Linear(3, 2, bias=False), - ), - ).to(device) + if problem_string == NESTED_SEQUENTIAL: + model = Sequential( + Linear(4, 3, bias=False), + Sequential( + Linear(3, 2, bias=False), + ), + ) + elif problem_string == CUSTOM_CONTAINER: + + class _MyCustomModule(Module): + def __init__(self): + super().__init__() + self.linear1 = Linear(4, 3, bias=False) + self.linear2 = Linear(3, 2, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + model = _MyCustomModule() + else: + raise NotImplementedError( + f"problem={problem_string} but no test setting for this." + ) + + model = extend(model.to(device)) + lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device)) + loss = lossfunc(model(X), y) + yield model, loss, problem_string + - model = extend(model) - lossfunc = extend(torch.nn.CrossEntropyLoss(reduction="mean")) +@mark.parametrize( + "extension", [BatchGrad(), DiagGGNExact()], ids=["BatchGrad", "DiagGGNExact"] +) +def test_extension_hook_multiple_parameter_visits( + problem, extension: BackpropExtension +): + """Tests whether each parameter is visited exactly once. - loss = lossfunc(model(X), y) + For those cases where parameters are visited more than once (e.g. Custom containers), + it tests that an error is raised. - return model, loss + Furthermore, it is tested whether first order extensions run fine in either case, + and second order extensions raise an error in the case of custom containers. + Args: + problem: test problem, consisting of model, loss, and problem_string + extension: first or second order extension to test -@pytest.mark.parametrize("device", DEVICES, ids=DEVICES_ID) -def test_extension_hook_multiple_parameter_visits(device): - """Extension hooks iterating over parameters may traverse them more than once.""" - model, loss = set_up(device) + Raises: + NotImplementedError: if the problem_string is unknown + """ + model, loss, problem_string = problem params_visited = {id(p): 0 for p in model.parameters()} def count_visits(module): - """Increase counter in ``params_visited`` for all parameters in ``module``.""" + """Increase counter in ``params_visited`` for all parameters in ``module``. + + Args: + module: the module of which the parameter visits are counted + """ for p in module.parameters(): params_visited[id(p)] += 1 - with backpack(extension_hook=count_visits, debug=True): + if problem_string == CUSTOM_CONTAINER and extension._fail_mode == FAIL_ERROR: + with raises(exception_inside_backward_pass(NotImplementedError)): + with backpack(extension, extension_hook=count_visits, debug=True): + loss.backward() + return + with backpack(extension, extension_hook=count_visits, debug=True): loss.backward() - def check(): - """Raise ``AssertionError`` if a parameter has been visited more than once.""" + def check_all_parameters_visited_once(): + """Checks whether all parameters have been visited exactly once. + + Raises: + AssertionError: if a parameter hasn't been visited exactly once + """ for param_id, visits in params_visited.items(): - if visits == 0: - raise ValueError(f"Hook never visited param {param_id}") - elif visits == 1: - pass - else: - raise AssertionError(f"Hook visited param {param_id} {visits} times ") + if visits != 1: + raise AssertionError(f"Hook visited param {param_id} {visits}≠1 times") + + if problem_string == NESTED_SEQUENTIAL: + check_all_parameters_visited_once() + elif problem_string == CUSTOM_CONTAINER: + with raises(AssertionError): + check_all_parameters_visited_once() + else: + raise NotImplementedError(f"unknown problem_string={problem_string}") - with pytest.raises(AssertionError): - check() +def test_extension_hook_param_before_savefield_exists(problem): + """Extension hooks iterating over parameters may get called before BackPACK. -@pytest.mark.parametrize("device", DEVICES, ids=DEVICES_ID) -def test_extension_hook_param_before_savefield_exists(device): - """Extension hooks iterating over parameters may get called before BackPACK.""" - _, loss = set_up(device) + This leads to the case, that the BackPACK quantities might not be calculated yet. + Thus, derived quantities cannot be calculated. + + Sequential containers just work fine. + Custom containers crash. + + Args: + problem: problem consisting of model, loss, and problem_string + + Raises: + NotImplementedError: if problem_string is unknown + """ + _, loss, problem_string = problem params_without_grad_batch = [] def check_grad_batch(module): - """Raise ``AssertionError`` if one parameter misses ``'grad_batch'``.""" + """Check whether the module has a grad_batch attribute. + + Args: + module: the module to check + + Raises: + AssertionError: if a parameter does not have grad_batch attribute. + """ for p in module.parameters(): if not hasattr(p, "grad_batch"): params_without_grad_batch.append(id(p)) raise AssertionError(f"Param {id(p)} has no 'grad_batch' attribute") - # AssertionError is caught inside BackPACK and will raise a RuntimeError - with pytest.raises(RuntimeError): - with backpack( - extensions.BatchGrad(), extension_hook=check_grad_batch, debug=True - ): + if problem_string == NESTED_SEQUENTIAL: + with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True): loss.backward() - assert len(params_without_grad_batch) > 0 + assert len(params_without_grad_batch) == 0 + elif problem_string == CUSTOM_CONTAINER: + with raises(exception_inside_backward_pass(AssertionError)): + with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True): + loss.backward() + assert len(params_without_grad_batch) > 0 + else: + raise NotImplementedError(f"unknown problem_string={problem_string}") diff --git a/test/test_simple_resnet.py b/test/test_simple_resnet.py deleted file mode 100644 index 27e3a9380..000000000 --- a/test/test_simple_resnet.py +++ /dev/null @@ -1,144 +0,0 @@ -"""An example to check if BackPACK' first-order extensions are working for ResNets.""" - -from test.core.derivatives.utils import classification_targets - -import torch - -from backpack import backpack, extend, extensions - -from .automated_test import check_sizes, check_values - - -def autograd_individual_gradients(X, y, model, loss_func): - """Individual gradients via for loop with automatic differentiation. - - Args: - X (torch.Tensor): Mini-batch of shape `(N, *)` - y (torch.Tensor: Labels for `X` - model (torch.nn.Module): Model for forward pass - loss_func (torch.nn.Module): Loss function for model prediction - - Returns: - [torch.Tensor]: Individual gradients for samples in the mini-batch - with respect to the model parameters. Arranged in the same order - as `model.parameters()`. - """ - N = X.shape[0] - reduction_factor = _get_reduction_factor(X, y, model, loss_func) - - individual_gradients = [ - torch.zeros(N, *p.shape).to(X.device) for p in model.parameters() - ] - - for n in range(N): - x_n = X[n].unsqueeze(0) - y_n = y[n].unsqueeze(0) - - f_n = model(x_n) - l_n = loss_func(f_n, y_n) / reduction_factor - - g_n = torch.autograd.grad(l_n, model.parameters()) - - for idx, g in enumerate(g_n): - individual_gradients[idx][n] = g - - return individual_gradients - - -def _get_reduction_factor(X, y, model, loss_func): - """Return reduction factor of loss function.""" - N = X.shape[0] - - x_0 = X[0].unsqueeze(0) - y_0 = y[0].unsqueeze(0) - - x_0_repeated = x_0.repeat([N if pos == 0 else 1 for pos, _ in enumerate(X.shape)]) - y_0_repeated = y_0.repeat([N if pos == 0 else 1 for pos, _ in enumerate(y.shape)]) - - individual_loss = loss_func(model(x_0), y_0) - reduced_loss = loss_func(model(x_0_repeated), y_0_repeated) - - return (N * individual_loss / reduced_loss).item() - - -def backpack_individual_gradients(X, y, model, loss_func): - """Individual gradients with BackPACK. - - Args: - X (torch.Tensor): Mini-batch of shape `(N, *)` - y (torch.Tensor: Labels for `X` - model (torch.nn.Module): Model for forward pass - loss_func (torch.nn.Module): Loss function for model prediction - - Returns: - [torch.Tensor]: Individual gradients for samples in the mini-batch - with respect to the model parameters. Arranged in the same order - as `model.parameters()`. - """ - model = extend(model) - loss_func = extend(loss_func) - - loss = loss_func(model(X), y) - - with backpack(extensions.BatchGrad()): - loss.backward() - - individual_gradients = [p.grad_batch for p in model.parameters()] - - return individual_gradients - - -class Identity(torch.nn.Module): - """Identity operation.""" - - def forward(self, input): - return input - - -class Parallel(torch.nn.Sequential): - """Feed input to multiple modules, sum the result. - - |-----| - | -> | f_1 | -> | - | |-----| | - | | - | |-----| | - x ->| -> | f_2 | -> + -> f₁(x) + f₂(x) + ... - | |-----| | - | | - | |-----| | - | -> | ... | -> | - |-----| - - """ - - def forward(self, input): - """Process input with all modules, sum the output.""" - for idx, module in enumerate(self.children()): - if idx == 0: - output = module(input) - else: - output = output + module(input) - - return output - - -def test_individual_gradients_simple_resnet(): - """Individual gradients for a simple ResNet with autodiff and BackPACK.""" - - # batch size, feature dimension - N, D = 2, 5 - # classification - C = 3 - - X = torch.rand(N, D) - y = classification_targets((N,), num_classes=C) - - model = Parallel(Identity(), torch.nn.Linear(D, D, bias=True)) - loss_func = torch.nn.CrossEntropyLoss(reduction="sum") - - result_autograd = autograd_individual_gradients(X, y, model, loss_func) - result_backpack = backpack_individual_gradients(X, y, model, loss_func) - - check_sizes(result_autograd, result_backpack) - check_values(result_autograd, result_backpack) From ba287a27594b89804215faa0d76df13bb5f99bb1 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 4 Aug 2021 15:22:46 +0200 Subject: [PATCH 39/54] [REF] Improve batch axis detection (#208) Sub-sampling relies on slicing tensors along their batch axis. So far we detect the batch axis based on the module. However, BackPACK's custom `Permute` module can alter the axes of its input tensor and therefore modify the batch axis position. To account for this, the batch axis not only has to be determined based on the module, but also on the direction of the tensor, i.e. `input` or `output`. This PR introduces such a tensor-agnostic batch size detection. Progress on #12: - Replace hard-coded batch axes in tests with dynamical ones - Make `jac_t_mat_prod` tests work for `Permute` module - Support arbitrary batch axes in `subsample` --- * [ADD] Make batch_axis depend on IO tensor The `Permute` module can modify the batch axis position in its forward pass. Therefore, input and output tensors have different batch axes. * [FIX] Make `jac_t_mat_prod` tests work for `Permute` * [REF] Replace hard-coded batch axis in test case forward pass * [DEL] Remove equality check of batch_axis_in/out * [REF] Generalize slicing in `subsample` to arbitrary axes --- backpack/core/derivatives/basederivatives.py | 2 +- backpack/core/derivatives/lstm.py | 8 ++- backpack/core/derivatives/rnn.py | 12 ++-- backpack/core/derivatives/shape_check.py | 2 +- backpack/custom_module/permute.py | 25 ++++++++- .../firstorder/batch_grad/batch_grad_base.py | 4 +- backpack/utils/subsampling.py | 50 +++++++++++------ fully_documented.txt | 2 + test/core/derivatives/derivatives_test.py | 4 +- .../derivatives/implementation/autograd.py | 15 ++--- test/core/derivatives/problem.py | 21 ++++--- test/custom_module/__init__.py | 1 + test/custom_module/test_permute.py | 25 +++++++++ .../firstorder/firstorder_settings.py | 8 +-- test/extensions/implementation/autograd.py | 2 +- test/extensions/problem.py | 20 ++++--- .../secondorder/diag_ggn/diag_ggn_settings.py | 4 +- test/extensions/utils.py | 2 +- test/utils/skip_test.py | 18 +----- test/utils/test_subsampling.py | 55 +++++++++++++++++++ 20 files changed, 201 insertions(+), 79 deletions(-) create mode 100644 test/custom_module/__init__.py create mode 100644 test/custom_module/test_permute.py create mode 100644 test/utils/test_subsampling.py diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index dd1205d45..677d20a31 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -289,7 +289,7 @@ def reshape_like_input( """ shape = list(module.input0.shape) if subsampling is not None: - shape[get_batch_axis(module)] = len(subsampling) + shape[get_batch_axis(module, "input0")] = len(subsampling) return cls._reshape_like(mat, shape) diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 6d68dd764..173e9633d 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -86,7 +86,7 @@ def _forward_pass( c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) - N_axis = get_batch_axis(module) + N_axis = get_batch_axis(module, "input0") input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling) output = subsample(module.output, dim=N_axis, subsampling=subsampling) @@ -347,7 +347,9 @@ def _weight_ih_l0_jac_t_mat_prod( f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, subsample( - module.input0, dim=get_batch_axis(module), subsampling=subsampling + module.input0, + dim=get_batch_axis(module, "input0"), + subsampling=subsampling, ), ) @@ -377,7 +379,7 @@ def _weight_hh_l0_jac_t_mat_prod( zeros(1, N, H, device=mat.device, dtype=mat.dtype), subsample( module.output, - dim=get_batch_axis(module), + dim=get_batch_axis(module, "input0"), subsampling=subsampling, )[0:-1], ], diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 422abd93c..a049b13c1 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -104,7 +104,9 @@ def _jac_t_mat_prod( "vtnh,hk->vtnk", self._a_jac_t_mat_prod( subsample( - module.output, dim=get_batch_axis(module), subsampling=subsampling + module.output, + dim=get_batch_axis(module, "input0"), + subsampling=subsampling, ), module.weight_hh_l0, mat, @@ -171,7 +173,9 @@ def _bias_ih_l0_jac_t_mat_prod( dim: int = 1 return self._a_jac_t_mat_prod( subsample( - module.output, dim=get_batch_axis(module), subsampling=subsampling + module.output, + dim=get_batch_axis(module, "input0"), + subsampling=subsampling, ), module.weight_hh_l0, mat, @@ -226,7 +230,7 @@ def _weight_ih_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis = get_batch_axis(module) + N_axis = get_batch_axis(module, "input0") return einsum( "vtnh,tnj->" + ("vhj" if sum_batch else "vnhj"), self._a_jac_t_mat_prod( @@ -260,7 +264,7 @@ def _weight_hh_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis = get_batch_axis(module) + N_axis = get_batch_axis(module, "input0") N: int = mat.shape[N_axis + 1] H: int = mat.shape[3] output = subsample(module.output, dim=N_axis, subsampling=subsampling) diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index fcf84634b..4e0755c1a 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -73,7 +73,7 @@ def _check_like(mat, module, name, diff=1, *args, **kwargs): if name in ["output", "input0"] and "subsampling" in kwargs.keys(): compare = subsample( getattr(module, name), - dim=get_batch_axis(module), + dim=get_batch_axis(module, name), subsampling=kwargs["subsampling"], ) else: diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py index 04b4ea619..11ac40ba6 100644 --- a/backpack/custom_module/permute.py +++ b/backpack/custom_module/permute.py @@ -8,14 +8,17 @@ class Permute(Module): """Module to permute a tensor.""" - def __init__(self, *dims: Any): + def __init__(self, *dims: Any, batch_axis: int = 0): """Initialization. Args: dims: The desired ordering of dimensions. + batch_axis: Which axis assumed to be the batch axis in a forward pass. + Defaults to ``0``. """ super().__init__() self.dims = dims + self.batch_axis = batch_axis def forward(self, input: Tensor) -> Tensor: """Permutes the input tensor. @@ -27,3 +30,23 @@ def forward(self, input: Tensor) -> Tensor: view with new ordering """ return input.permute(self.dims) + + def get_batch_axis(self, io_str: str) -> int: + """Return the batch axis assumed by the module. + + Args: + io_str: Name of the tensor. Must be ``'input0'`` or ``'output'``. + + Returns: + Batch axis + + Raises: + ValueError: For invalid IO names. + """ + if io_str == "input0": + return self.batch_axis + elif io_str == "output": + return self.dims.index(self.batch_axis) + else: + valid_io_strs = ["input0", "output"] + raise ValueError(f"io_str must be in {valid_io_strs}, got {io_str}.") diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index c0ee49729..5a3198fa7 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -85,7 +85,9 @@ def param_function( g_inp, g_out, subsample( - g_out[0], dim=get_batch_axis(module), subsampling=subsampling + g_out[0], + dim=get_batch_axis(module, "output"), + subsampling=subsampling, ), sum_batch=False, subsampling=subsampling, diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index 68a02c34f..d1d9b739d 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -2,7 +2,9 @@ from typing import List from torch import Tensor -from torch.nn import LSTM, RNN, Module +from torch.nn import LSTM, RNN, Module, Sequential + +from backpack.custom_module.permute import Permute def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: @@ -16,34 +18,48 @@ def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Te Returns: Tensor of same rank that is sub-sampled along the dimension. - - Raises: - NotImplementedError: If dimension differs from ``0`` or ``1``. """ if subsampling is None: return tensor else: - if dim == 0: - return tensor[subsampling] - elif dim == 1: - return tensor[:, subsampling] - else: - raise NotImplementedError(f"Only supports dim = 0,1. Got {dim}.") + return tensor[(slice(None),) * dim + (subsampling,)] -def get_batch_axis(module: Module) -> int: +def get_batch_axis(module: Module, io_str: str) -> int: """Return the batch axis assumed by the module. + For unknown modules the default axis is determined as ``0``. + Args: module: A module. + io_str: Name of the tensor stored as BackPACK IO. Must be ``'input0'`` or + ``'output'``. + + Note: + This method only inspects single modules and therefore cannot detect whether + the batch axis has been modified by preceding ones. For instance for a ReLU + module, the batch axis will always be detected as ``0``, although the layer + still works if preceded by a ``Permute(0, 1)`` module, but would have batch + axis ``1``. Returns: Batch axis + + Raises: + ValueError: For invalid IO names. """ + valid_io_strs = ["input0", "output"] + if io_str not in valid_io_strs: + raise ValueError(f"io_str must be in {valid_io_strs}, got {io_str}.") + + batch_axis = 0 + if isinstance(module, (RNN, LSTM)): - if module.batch_first: - return 0 - else: - return 1 - else: - return 0 + batch_axis = 0 if module.batch_first else 1 + elif isinstance(module, Permute): + batch_axis = module.get_batch_axis(io_str) + elif isinstance(module, Sequential): + child_idx = {"input0": 0, "output": -1}[io_str] + batch_axis = get_batch_axis(list(module.children())[child_idx], io_str) + + return batch_axis diff --git a/fully_documented.txt b/fully_documented.txt index f4944a8c8..15b013b6a 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -89,3 +89,5 @@ test/core/derivatives/batch_norm_settings.py test/utils/evaluation_mode.py test/utils/skip_test.py test/utils/__init__.py +test/utils/test_subsampling.py +test/custom_module/ diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index c4e9942fa..6cb2f3e69 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -20,7 +20,6 @@ from test.utils.skip_test import ( skip_adaptive_avg_pool3d_cuda, skip_batch_norm_train_mode_with_subsampling, - skip_permute_with_subsampling, skip_subsampling_conflict, ) from typing import List, Union @@ -144,7 +143,6 @@ def test_jac_t_mat_prod( skip_adaptive_avg_pool3d_cuda(request) problem.set_up() - skip_permute_with_subsampling(problem, subsampling) skip_batch_norm_train_mode_with_subsampling(problem, subsampling) skip_subsampling_conflict(problem, subsampling) mat = rand_mat_like_output(V, problem, subsampling=subsampling) @@ -187,7 +185,7 @@ def rand_mat_like_output( subsample_shape = list(problem.output_shape) if subsampling is not None: - N_axis = get_batch_axis(problem.module) + N_axis = get_batch_axis(problem.module, "output") subsample_shape[N_axis] = len(subsampling) return rand(V, *subsample_shape, device=problem.device) diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index c52b2831b..9221624c5 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -43,19 +43,20 @@ def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor: # noqa: D102 else: # for each sample, multiply by full input Jacobian, slice out result: # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n] - batch_axis = get_batch_axis(self.problem.module) - output = subsample(output, dim=batch_axis, subsampling=subsampling) - output = output.split(1, dim=batch_axis) - vec = vec.split(1, dim=batch_axis) + batch_axis_out = get_batch_axis(self.problem.module, "output") + output = subsample(output, dim=batch_axis_out, subsampling=subsampling) + output = output.split(1, dim=batch_axis_out) + vec = vec.split(1, dim=batch_axis_out) + batch_axis_in = get_batch_axis(self.problem.module, "input0") vjps: List[Tensor] = [] for sample_idx, out, v in zip(subsampling, output, vec): vjp = transposed_jacobian_vector_product(out, input, v)[0] - vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx]) + vjp = subsample(vjp, dim=batch_axis_in, subsampling=[sample_idx]) vjps.append(vjp) - return cat(vjps, dim=batch_axis) + return cat(vjps, dim=batch_axis_in) def jac_t_mat_prod( self, mat: Tensor, subsampling: List[int] = None @@ -75,7 +76,7 @@ def param_mjp( param_str, vec, sum_batch, - axis_batch=get_batch_axis(self.problem.module), + axis_batch=get_batch_axis(self.problem.module, "output"), subsampling=subsampling, ) for vec in mat diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index 8c2f8be21..43fb0dcd0 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -2,11 +2,14 @@ import copy from test.core.derivatives.utils import derivative_cls_for, get_available_devices +from typing import Dict, Tuple import torch +from torch import Tensor from backpack import extend from backpack.utils.module_classification import is_loss +from backpack.utils.subsampling import get_batch_axis, subsample def make_test_problems(settings): @@ -138,25 +141,27 @@ def make_output_shape(self): def is_loss(self): return is_loss(self.make_module()) - def forward_pass(self, input_requires_grad=False, sample_idx=None): + def forward_pass( + self, input_requires_grad: bool = False, sample_idx: int = None + ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: """Do a forward pass. Return input, output, and parameters.""" - if sample_idx is None: - input = self.input.clone().detach() - else: - input = self.input.clone()[sample_idx, :].unsqueeze(0).detach() + input: Tensor = self.input.clone().detach() + if sample_idx is not None: + batch_axis_in = get_batch_axis(self.module, "input0") + input = subsample(input, dim=batch_axis_in, subsampling=[sample_idx]) if input_requires_grad: input.requires_grad = True if self.is_loss(): assert sample_idx is None - output = self.module(input, self.target) + output: Tensor = self.module(input, self.target) else: - output = self.module(input) + output: Tensor = self.module(input) if isinstance(output, tuple): # is true for RNN,GRU,LSTM which return tuple (output, ...) - output = output[0] + output: Tensor = output[0] return input, output, dict(self.module.named_parameters()) diff --git a/test/custom_module/__init__.py b/test/custom_module/__init__.py new file mode 100644 index 000000000..b998b1a6b --- /dev/null +++ b/test/custom_module/__init__.py @@ -0,0 +1 @@ +"""Contains tests for BackPACK's custom modules.""" diff --git a/test/custom_module/test_permute.py b/test/custom_module/test_permute.py new file mode 100644 index 000000000..f6d185236 --- /dev/null +++ b/test/custom_module/test_permute.py @@ -0,0 +1,25 @@ +"""Contains tests for BackPACK's custom ``Permute`` module.""" + +from pytest import raises + +from backpack.custom_module.permute import Permute + + +def test_get_batch_axis(): + """Test the Permute module's batch axis detection.""" + # invalid argument + with raises(ValueError): + invalid_io_str = "dummy" + Permute().get_batch_axis(invalid_io_str) + + # batch axis unaffected by forward pass + assert Permute(0, 2, 1).get_batch_axis("input0") == 0 + assert Permute(0, 2, 1).get_batch_axis("output") == 0 + + # batch axis first, affected by forward pass + assert Permute(1, 2, 0).get_batch_axis("input0") == 0 + assert Permute(1, 2, 0).get_batch_axis("output") == 2 + + # batch axis second, affected by forward pass + assert Permute(1, 2, 0, batch_axis=1).get_batch_axis("input0") == 1 + assert Permute(1, 2, 0, batch_axis=1).get_batch_axis("output") == 0 diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 55ca07eb3..99d46914b 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -260,10 +260,10 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2), + Permute(1, 0, 2, batch_axis=0), RNN(input_size=6, hidden_size=3), ReduceTuple(index=0), - Permute(1, 2, 0), + Permute(1, 2, 0, batch_axis=1), ), "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((8, 5), 3), @@ -271,10 +271,10 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2), + Permute(1, 0, 2, batch_axis=0), RNN(input_size=6, hidden_size=3), ReduceTuple(index=0), - Permute(1, 2, 0), + Permute(1, 2, 0, batch_axis=1), Flatten(), ), "loss_function_fn": lambda: MSELoss(), diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index ccee35a0a..bf81604a0 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -17,7 +17,7 @@ class AutogradExtensions(ExtensionsImplementation): def batch_grad( self, subsampling: Union[List[int], None] ) -> List[Tensor]: # noqa: D102 - N = self.problem.input.shape[get_batch_axis(self.problem.model)] + N = self.problem.input.shape[get_batch_axis(self.problem.model, "input0")] samples = list(range(N)) if subsampling is None else subsampling loss_list = zeros(N) diff --git a/test/extensions/problem.py b/test/extensions/problem.py index dcb620458..9edaaa9d0 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -2,12 +2,14 @@ import copy from test.core.derivatives.utils import get_available_devices -from typing import Any, Iterator, List +from typing import Any, Iterator, List, Tuple import torch +from torch import Tensor from torch.nn.parameter import Parameter from backpack import extend +from backpack.utils.subsampling import get_batch_axis, subsample def make_test_problems(settings): @@ -130,26 +132,28 @@ def make_id(self): ).replace(" ", "") ) - def forward_pass(self, sample_idx=None): + def forward_pass(self, sample_idx: int = None) -> Tuple[Tensor, Tensor, Tensor]: """Do a forward pass. Return input, output, and parameters. The forward pass is performed on the selected index. If the index is None, then the forward pass is calculated for the whole batch. Args: - sample_idx (int, optional): Index of the sample to select. - Defaults to None. + sample_idx: Index of the sample to select. Defaults to ``None``. Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - input, output, loss, each with batch axis first + input, output, loss, each with batch axis first """ if sample_idx is None: input = self.input.clone() target = self.target.clone() else: - target = self.target.split(1, dim=0)[sample_idx] - input = self.input.split(1, dim=0)[sample_idx] + batch_axis_in = get_batch_axis(self.model, "input0") + batch_axis_out = get_batch_axis(self.model, "output") + target = subsample( + self.target, dim=batch_axis_out, subsampling=[sample_idx] + ) + input = subsample(self.input, dim=batch_axis_in, subsampling=[sample_idx]) output = self.model(input) diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 596d68bf2..3a75f207f 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -42,10 +42,10 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2), + Permute(1, 0, 2, batch_axis=0), RNN(input_size=6, hidden_size=3), ReduceTuple(index=0), - Permute(1, 2, 0), + Permute(1, 2, 0, batch_axis=1), Flatten(), ), "loss_function_fn": lambda: MSELoss(), diff --git a/test/extensions/utils.py b/test/extensions/utils.py index 2705e92a8..83129772a 100644 --- a/test/extensions/utils.py +++ b/test/extensions/utils.py @@ -17,7 +17,7 @@ def skip_if_subsampling_conflict( problem: Test case. subsampling: Indices of active samples. """ - N = problem.input.shape[get_batch_axis(problem.model)] + N = problem.input.shape[get_batch_axis(problem.model, "input0")] enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip(f"Not enough samples: N={N}, subsampling={subsampling}") diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index 9d02819bf..b1e10554a 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -6,7 +6,6 @@ from pytest import skip from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d -from backpack.custom_module.permute import Permute from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 from backpack.utils.subsampling import get_batch_axis @@ -30,21 +29,6 @@ def skip_adaptive_avg_pool3d_cuda(request) -> None: ) -def skip_permute_with_subsampling( - problem: DerivativesTestProblem, subsampling: Union[List[int], None] -) -> None: - """Skip Permute module when sub-sampling is turned on. - - Permute does not assume a batch axis. - - Args: - problem: Test case. - subsampling: Indices of active samples. - """ - if isinstance(problem.module, Permute) and subsampling is not None: - skip(f"Skipping Permute with sub-sampling: {subsampling}") - - def skip_batch_norm_train_mode_with_subsampling( problem: DerivativesTestProblem, subsampling: Union[List[int], None] ) -> None: @@ -68,7 +52,7 @@ def skip_subsampling_conflict( problem: Test case. subsampling: Indices of active samples. """ - N = problem.input_shape[get_batch_axis(problem.module)] + N = problem.input_shape[get_batch_axis(problem.module, "input0")] enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip("Not enough samples.") diff --git a/test/utils/test_subsampling.py b/test/utils/test_subsampling.py new file mode 100644 index 000000000..50309552a --- /dev/null +++ b/test/utils/test_subsampling.py @@ -0,0 +1,55 @@ +"""Contains tests of sub-sampling functionality.""" + +from pytest import raises +from torch import allclose, manual_seed, rand +from torch.nn import Linear, ReLU, Sequential + +from backpack.custom_module.permute import Permute +from backpack.utils.subsampling import get_batch_axis, subsample + + +def test_get_batch_axis(): + """Test batch axis detection.""" + # invalid argument + with raises(ValueError): + invalid_io_str = "dummy" + some_module = Linear(1, 1) + get_batch_axis(some_module, invalid_io_str) + + # Sequential with unaltered batch axis + model = Sequential(Linear(1, 1), ReLU()) + assert get_batch_axis(model, "input0") == 0 + assert get_batch_axis(model, "output") == 0 + + # Sequential with altered batch axis + model = Sequential(Linear(1, 1), Permute(1, 0)) + assert get_batch_axis(model, "input0") == 0 + assert get_batch_axis(model, "output") == 1 + + # Permute + model = Permute(1, 3, 2, 0, batch_axis=0) + assert get_batch_axis(model, "input0") == 0 + assert get_batch_axis(model, "output") == 3 + + model = Sequential(Permute(0, 1), ReLU()) + assert get_batch_axis(model, "input0") == 0 + # expected failure due to local inspection + batch_axis_output = 1 + assert get_batch_axis(model, "output") != batch_axis_output + + +def test_subsample(): + """Test slicing operations for sub-sampling a tensor's batch axis.""" + manual_seed(0) + tensor = rand(3, 4, 5, 6) + + # leave tensor untouched when `subsampling = None` + assert id(subsample(tensor)) == id(tensor) + assert allclose(subsample(tensor), tensor) + + # slice along correct dimension + idx = [2, 0] + assert allclose(subsample(tensor, dim=0, subsampling=idx), tensor[idx]) + assert allclose(subsample(tensor, dim=1, subsampling=idx), tensor[:, idx]) + assert allclose(subsample(tensor, dim=2, subsampling=idx), tensor[:, :, idx]) + assert allclose(subsample(tensor, dim=3, subsampling=idx), tensor[:, :, :, idx]) From 4dcbfd1acafb6eb7753aefbf897b30a97c95cee5 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 5 Aug 2021 18:21:21 +0200 Subject: [PATCH 40/54] [ADD] Support sub-sampling in `SqrtGGN{MC,Exact}` extensions (#209) Adds a `subsampling` argument to the `SqrtGGN{MC,Exact}` extensions that allows to compute their results only on a sub-set of the mini-batch samples. Progress on #12. - Introduce `subsampling` argument in `BackpropExtension` such that the feature can be realized for other extensions in the future - Test suite - Reparameterize `forward_pass` function of derivatives and extension problems with `subsampling` (list) instead of `sample_idx` (int) - Extract computation of loss reduction factor - Normalize GGN-MC to [-1; 1] before comparing for easier tuning of `atol` --- * [ADD] Introduce subsampling argument in SqrtGGN{Exact,MC} * Generalize get_batch_axis to modules with children * [ADD] Introduce `subsampling` in BackpropExtension interface * [ADD] Consider subsampling in `SqrtGGN` module extensions * [ADD] Make skip condition work for derivatives & extension cases * [TEST] Parametrize forward pass with subsampling, extract reduction factor * [TEST] Test correctness of exact GGN with sub-sampling * [TEST] Check GGN-MC from square root for sub-samplings * [FIX] Docstring * [TEST] Normalize GGNMC to before comparing, reduce MC samples * [FIX] Don't normalize if GGN-MC is exactly zero * [REF] Move skip condition for large parameters * [ADD] Reintroduce montecarlo marker --- backpack/extensions/backprop_extension.py | 16 +++- backpack/extensions/firstorder/base.py | 7 +- .../firstorder/batch_grad/__init__.py | 13 +-- backpack/extensions/mat_to_mat_jac_base.py | 13 ++- .../secondorder/sqrt_ggn/__init__.py | 33 +++++-- .../extensions/secondorder/sqrt_ggn/base.py | 8 +- .../extensions/secondorder/sqrt_ggn/losses.py | 11 ++- .../derivatives/implementation/autograd.py | 2 +- test/core/derivatives/problem.py | 19 ++-- test/extensions/implementation/autograd.py | 52 ++++++----- test/extensions/implementation/backpack.py | 28 ++++-- test/extensions/implementation/base.py | 17 +++- test/extensions/problem.py | 65 ++++++++++---- .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 89 +++++++++++-------- test/utils/skip_test.py | 21 ++++- 15 files changed, 264 insertions(+), 130 deletions(-) diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index 0b584ade9..80ff00820 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -4,7 +4,7 @@ import abc import warnings from abc import ABC -from typing import Dict, Tuple, Type, Union +from typing import Dict, List, Tuple, Type, Union from torch import Tensor from torch.nn import Module @@ -39,6 +39,7 @@ def __init__( savefield: str, module_exts: Dict[Type[Module], ModuleExtension], fail_mode: str = FAIL_ERROR, + subsampling: List[int] = None, ): """Initializes parameters. @@ -51,6 +52,9 @@ def __init__( - "WARN": raise a UserWarning - "SILENT": skip the module silently Defaults to FAIL_ERROR = "ERROR" + subsampling: Indices of active mini-batch samples. ``None`` means + all samples in the mini-batch will be considered by the extension. + Defaults to ``None``. Raises: AssertionError: if fail_mode is not valid @@ -61,6 +65,7 @@ def __init__( self.savefield: str = savefield self.__module_extensions: Dict[Type[Module], ModuleExtension] = module_exts self._fail_mode: str = fail_mode + self._subsampling = subsampling def set_module_extension( self, module: Type[Module], extension: ModuleExtension, overwrite: bool = False @@ -129,3 +134,12 @@ def expects_backpropagation_quantities(self) -> bool: Whether the extension uses additional backpropagation quantities. """ return + + def get_subsampling(self) -> Union[List[int], None]: + """Return indices of active mini-batch samples. + + Returns: + Indices of samples considered by the extension. ``None`` signifies that + the full mini-batch is used. + """ + return self._subsampling diff --git a/backpack/extensions/firstorder/base.py b/backpack/extensions/firstorder/base.py index af761cd18..b3529df3f 100644 --- a/backpack/extensions/firstorder/base.py +++ b/backpack/extensions/firstorder/base.py @@ -1,5 +1,5 @@ """Base class for first order extensions.""" -from typing import Dict, Type +from typing import Dict, List, Type from torch.nn import Module @@ -19,8 +19,11 @@ def __init__( savefield: str, module_exts: Dict[Type[Module], ModuleExtension], fail_mode: str = FAIL_WARN, + subsampling: List[int] = None, ): # noqa: D107 - super().__init__(savefield, module_exts, fail_mode=fail_mode) + super().__init__( + savefield, module_exts, fail_mode=fail_mode, subsampling=subsampling + ) def expects_backpropagation_quantities(self) -> bool: # noqa: D102 return False diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 7edf6176e..4a4d05562 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -2,7 +2,7 @@ It defines the module extension for each module. """ -from typing import List, Union +from typing import List from torch.nn import ( RNN, @@ -81,14 +81,5 @@ def __init__(self, subsampling: List[int] = None): BatchNorm3d: batchnorm_nd.BatchGradBatchNormNd(), RNN: rnn.BatchGradRNN(), }, + subsampling=subsampling, ) - self._subsampling = subsampling - - def get_subsampling(self) -> Union[List[int], None]: - """Get the indices of samples for which individual gradients are requested. - - Returns: - List of indices containing the active samples in the mini-batch. ``None`` - means all samples will be considered. - """ - return self._subsampling diff --git a/backpack/extensions/mat_to_mat_jac_base.py b/backpack/extensions/mat_to_mat_jac_base.py index 23fd71d9e..937d78b33 100644 --- a/backpack/extensions/mat_to_mat_jac_base.py +++ b/backpack/extensions/mat_to_mat_jac_base.py @@ -5,6 +5,7 @@ from torch.nn import Module from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.module_extension import ModuleExtension @@ -23,7 +24,7 @@ def __init__(self, derivatives: BaseDerivatives, params: List[str] = None): def backpropagate( self, - ext: ModuleExtension, + ext: BackpropExtension, module: Module, grad_inp: Tuple[Tensor], grad_out: Tuple[Tensor], @@ -32,7 +33,7 @@ def backpropagate( """Propagates second order information back. Args: - ext: extension + ext: BackPACK extension module: module through which to perform backpropagation grad_inp: input gradients grad_out: output gradients @@ -41,12 +42,16 @@ def backpropagate( Returns: derivative wrt input """ + subsampling = ext.get_subsampling() + if isinstance(backproped, list): return [ - self.derivatives.jac_t_mat_prod(module, grad_inp, grad_out, M) + self.derivatives.jac_t_mat_prod( + module, grad_inp, grad_out, M, subsampling=subsampling + ) for M in backproped ] else: return self.derivatives.jac_t_mat_prod( - module, grad_inp, grad_out, backproped + module, grad_inp, grad_out, backproped, subsampling=subsampling ) diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py index 2438b53d0..4ff71dbbc 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/__init__.py +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -1,5 +1,7 @@ """Defines base class and extensions for computing the GGN/Fisher matrix square root.""" +from typing import List, Union + from torch.nn import ( ELU, SELU, @@ -46,13 +48,19 @@ class SqrtGGN(SecondOrderBackpropExtension): """Base class for extensions that compute the GGN/Fisher matrix square root.""" - def __init__(self, loss_hessian_strategy: str, savefield: str): + def __init__( + self, + loss_hessian_strategy: str, + savefield: str, + subsampling: Union[List[int], None], + ): """Store approximation for backpropagated object and where to save the result. Args: loss_hessian_strategy: Which approximation is used for the backpropagated loss Hessian. Must be ``'exact'`` or ``'sampling'``. savefield: Attribute under which the quantity is saved in a parameter. + subsampling: Indices of active samples. ``None`` uses the full mini-batch. """ self.loss_hessian_strategy = loss_hessian_strategy super().__init__( @@ -85,6 +93,7 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): ELU: activations.SqrtGGNELU(), SELU: activations.SqrtGGNSELU(), }, + subsampling=subsampling, ) def get_loss_hessian_strategy(self) -> str: @@ -103,7 +112,8 @@ class SqrtGGNExact(SqrtGGN): Stores the output in :code:`sqrt_ggn_exact`, has shape ``[C, N, param.shape]``, where ``C`` is the model output dimension (number of classes for classification - problems) and ``N`` is the batch size. + problems) and ``N`` is the batch size. If sub-sampling is enabled, ``N`` is + replaced by the number of active samples, ``len(subsampling)``. For a faster but less precise alternative, see :py:meth:`backpack.extensions.SqrtGGNMC`. @@ -116,9 +126,14 @@ class SqrtGGNExact(SqrtGGN): is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``. """ - def __init__(self): - """Use exact loss Hessian and set savefield to ``sqrt_ggn_exact``.""" - super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact") + def __init__(self, subsampling: List[int] = None): + """Use exact loss Hessian, store results under ``sqrt_ggn_exact``. + + Args: + subsampling: Indices of active samples. Defaults to ``None`` (use all + samples in the mini-batch). + """ + super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact", subsampling) class SqrtGGNMC(SqrtGGN): @@ -129,6 +144,8 @@ class SqrtGGNMC(SqrtGGN): Stores the output in :code:`sqrt_ggn_mc`, has shape ``[M, N, param.shape]``, where ``M`` is the number of Monte-Carlo samples and ``N`` is the batch size. + If sub-sampling is enabled, ``N`` is replaced by the number of active samples, + ``len(subsampling)``. For a more precise but slower alternative, see :py:meth:`backpack.extensions.SqrtGGNExact`. @@ -141,14 +158,16 @@ class SqrtGGNMC(SqrtGGN): is the approximate GGN/Fisher's matrix square root, i.e. ``G ≈ V Vᵀ``. """ - def __init__(self, mc_samples: int = 1): + def __init__(self, mc_samples: int = 1, subsampling: List[int] = None): """Approximate loss Hessian via MC and set savefield to ``sqrt_ggn_mc``. Args: mc_samples: Number of Monte-Carlo samples. Default: ``1``. + subsampling: Indices of active samples. Defaults to ``None`` (use all + samples in the mini-batch). """ self._mc_samples = mc_samples - super().__init__(LossHessianStrategy.SAMPLING, "sqrt_ggn_mc") + super().__init__(LossHessianStrategy.SAMPLING, "sqrt_ggn_mc", subsampling) def get_num_mc_samples(self) -> int: """Return the number of MC samples used to approximate the loss Hessian. diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py index f625e38b1..425766f8e 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/base.py +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -68,7 +68,13 @@ def param_function( GGN/Fisher matrix square root. """ return self.derivatives.param_mjp( - param_str, module, g_inp, g_out, backproped, sum_batch=False + param_str, + module, + g_inp, + g_out, + backproped, + sum_batch=False, + subsampling=ext.get_subsampling(), ) return param_function diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py index 30f561396..2294bc794 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/losses.py +++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -45,13 +45,20 @@ def backpropagate( NotImplementedError: For invalid strategies to represent the loss Hessian. """ loss_hessian_strategy = ext.get_loss_hessian_strategy() + subsampling = ext.get_subsampling() if loss_hessian_strategy == LossHessianStrategy.EXACT: - return self.derivatives.sqrt_hessian(module, grad_inp, grad_out) + return self.derivatives.sqrt_hessian( + module, grad_inp, grad_out, subsampling=subsampling + ) elif loss_hessian_strategy == LossHessianStrategy.SAMPLING: mc_samples = ext.get_num_mc_samples() return self.derivatives.sqrt_hessian_sampled( - module, grad_inp, grad_out, mc_samples=mc_samples + module, + grad_inp, + grad_out, + mc_samples=mc_samples, + subsampling=subsampling, ) else: raise NotImplementedError( diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 9221624c5..a10ea8560 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -159,7 +159,7 @@ def _sample_jac_t_mat_jac_prod(sample_idx, mat): def _sample_jac_t_mat_prod(sample_idx, mat): sample, output, _ = self.problem.forward_pass( - input_requires_grad=True, sample_idx=sample_idx + input_requires_grad=True, subsampling=[sample_idx] ) result = zeros(sample.numel(), mat.size(1), device=sample.device) diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index 43fb0dcd0..d01378e7e 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -2,7 +2,7 @@ import copy from test.core.derivatives.utils import derivative_cls_for, get_available_devices -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch from torch import Tensor @@ -142,19 +142,20 @@ def is_loss(self): return is_loss(self.make_module()) def forward_pass( - self, input_requires_grad: bool = False, sample_idx: int = None + self, input_requires_grad: bool = False, subsampling: List[int] = None ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: """Do a forward pass. Return input, output, and parameters.""" input: Tensor = self.input.clone().detach() - if sample_idx is not None: + + if subsampling is not None: batch_axis_in = get_batch_axis(self.module, "input0") - input = subsample(input, dim=batch_axis_in, subsampling=[sample_idx]) + input = subsample(input, dim=batch_axis_in, subsampling=subsampling) if input_requires_grad: input.requires_grad = True if self.is_loss(): - assert sample_idx is None + assert subsampling is None output: Tensor = self.module(input, self.target) else: output: Tensor = self.module(input) @@ -187,3 +188,11 @@ def has_weight(self): def has_bias(self): module = self.make_module() return hasattr(module, "bias") and module.bias is not None + + def get_batch_size(self) -> int: + """Return the mini-batch size. + + Returns: + Mini-batch size. + """ + return self.input.shape[get_batch_axis(self.module, "input0")] diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index bf81604a0..ac9f94f11 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -1,4 +1,5 @@ """Autograd implementation of BackPACK's extensions.""" +from math import isclose from test.extensions.implementation.base import ExtensionsImplementation from typing import Iterator, List, Union @@ -8,7 +9,6 @@ from backpack.hessianfree.ggnvp import ggn_vector_product from backpack.hessianfree.rop import R_op from backpack.utils.convert_parameters import vector_to_parameter_list -from backpack.utils.subsampling import get_batch_axis class AutogradExtensions(ExtensionsImplementation): @@ -17,24 +17,20 @@ class AutogradExtensions(ExtensionsImplementation): def batch_grad( self, subsampling: Union[List[int], None] ) -> List[Tensor]: # noqa: D102 - N = self.problem.input.shape[get_batch_axis(self.problem.model, "input0")] + N = self.problem.get_batch_size() samples = list(range(N)) if subsampling is None else subsampling - loss_list = zeros(N) gradients_list = [] for b in range(N): - _, _, loss = self.problem.forward_pass(sample_idx=b) + _, _, loss = self.problem.forward_pass(subsampling=[b]) gradients = autograd.grad(loss, self.problem.trainable_parameters()) gradients_list.append(gradients) - loss_list[b] = loss - - _, _, batch_loss = self.problem.forward_pass() - factor = self.problem.get_reduction_factor(batch_loss, loss_list) batch_grads = [ zeros(len(samples), *p.size()).to(self.problem.device) for p in self.problem.trainable_parameters() ] + factor = self.problem.compute_reduction_factor() for out_idx, sample in enumerate(samples): for param_idx, sample_g in enumerate(gradients_list[sample]): @@ -85,18 +81,15 @@ def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa: D102 return self._diag_ggn_exact_batch() def _diag_ggn_exact_batch(self): - batch_size = self.problem.input.shape[0] - _, _, batch_loss = self.problem.forward_pass() - loss_list = zeros(batch_size, device=self.problem.device) - # batch_diag_ggn has entries [sample_idx][param_idx] batch_diag_ggn = [] - for b in range(batch_size): - _, output, loss = self.problem.forward_pass(sample_idx=b) + for b in range(self.problem.get_batch_size()): + _, output, loss = self.problem.forward_pass(subsampling=[b]) diag_ggn = self._get_diag_ggn(loss, output) batch_diag_ggn.append(diag_ggn) - loss_list[b] = loss - factor = self.problem.get_reduction_factor(batch_loss, loss_list) + + factor = self.problem.compute_reduction_factor() + # params_batch_diag_ggn has entries [param_idx][sample_idx] params_batch_diag_ggn = list(zip(*batch_diag_ggn)) return [stack(param) * factor for param in params_batch_diag_ggn] @@ -129,23 +122,28 @@ def diag_h(self) -> List[Tensor]: # noqa: D102 return self._get_diag_h(loss) def diag_h_batch(self) -> List[Tensor]: # noqa: D102 - batch_size = self.problem.input.shape[0] - _, _, batch_loss = self.problem.forward_pass() - loss_list = zeros(batch_size, device=self.problem.device) - batch_diag_h = [] - for b in range(batch_size): - _, _, loss = self.problem.forward_pass(sample_idx=b) - loss_list[b] = loss + for b in range(self.problem.get_batch_size()): + _, _, loss = self.problem.forward_pass(subsampling=[b]) diag_h = self._get_diag_h(loss) batch_diag_h.append(diag_h) - factor = self.problem.get_reduction_factor(batch_loss, loss_list) + + factor = self.problem.compute_reduction_factor() + params_batch_diag_h = list(zip(*batch_diag_h)) return [stack(param) * factor for param in params_batch_diag_h] - def ggn(self) -> Tensor: # noqa: D102 - _, output, loss = self.problem.forward_pass() - return stack(list(self._ggn_columns(loss, output)), dim=1) + def ggn(self, subsampling: List[int] = None) -> Tensor: # noqa: D102 + _, output, loss = self.problem.forward_pass(subsampling=subsampling) + ggn = stack(list(self._ggn_columns(loss, output)), dim=1) + + # correct normalization constant for 'mean' reduction + if subsampling is not None: + factor = self.problem.compute_reduction_factor() + if not isclose(factor, 1.0): + ggn *= len(subsampling) * factor + + return ggn def _ggn_columns(self, loss: Tensor, output: Tensor) -> Iterator[Tensor]: params = list(self.problem.trainable_parameters()) diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index 2784bfc03..4754be907 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -187,40 +187,52 @@ def diag_h_batch(self) -> List[Tensor]: # noqa:D102 loss.backward() return self.problem.collect_data("diag_h_batch") - def ggn(self) -> Tensor: # noqa:D102 - return self._square_sqrt_ggn(self.sqrt_ggn()) + def ggn(self, subsampling: List[int] = None) -> Tensor: # noqa:D102 + return self._square_sqrt_ggn(self.sqrt_ggn(subsampling=subsampling)) - def sqrt_ggn(self) -> List[Tensor]: + def sqrt_ggn(self, subsampling: List[int] = None) -> List[Tensor]: """Compute the matrix square root of the exact generalized Gauss-Newton. + Args: + subsampling: Indices of active samples. Defaults to ``None`` (use all + samples in the mini-batch). + Returns: Parameter-wise matrix square root of the exact GGN. """ - with backpack(new_ext.SqrtGGNExact()): + with backpack(new_ext.SqrtGGNExact(subsampling=subsampling)): _, _, loss = self.problem.forward_pass() loss.backward() return self.problem.collect_data("sqrt_ggn_exact") - def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: + def sqrt_ggn_mc( + self, mc_samples: int, subsampling: List[int] = None + ) -> List[Tensor]: """Compute the approximate matrix square root of the generalized Gauss-Newton. Args: mc_samples: Number of Monte-Carlo samples. + subsampling: Indices of active samples. Defaults to ``None`` (use all + samples in the mini-batch). Returns: Parameter-wise approximate matrix square root of the exact GGN. """ - with backpack(new_ext.SqrtGGNMC(mc_samples=mc_samples)): + with backpack( + new_ext.SqrtGGNMC(mc_samples=mc_samples, subsampling=subsampling) + ): _, _, loss = self.problem.forward_pass() loss.backward() return self.problem.collect_data("sqrt_ggn_mc") - def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: # noqa:D102 + def ggn_mc( + self, mc_samples: int, chunks: int = 1, subsampling: List[int] = None + ) -> Tensor: # noqa:D102 samples = chunk_sizes(mc_samples, chunks) weights = [samples / mc_samples for samples in samples] return sum( - w * self._square_sqrt_ggn(self.sqrt_ggn_mc(s)) + w * self._square_sqrt_ggn(self.sqrt_ggn_mc(s, subsampling=subsampling)) for w, s in zip(weights, samples) ) diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 638c780df..8f53af7fe 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -109,27 +109,38 @@ def diag_h_batch(self) -> List[Tensor]: """Per-sample Hessian diagonal. Returns: - list(torch.Tensor): Parameter-wise per-sample Hessian diagonal. + Parameter-wise per-sample Hessian diagonal. """ return @abstractmethod - def ggn(self) -> Tensor: + def ggn(self, subsampling: List[int] = None) -> Tensor: """Exact generalized Gauss-Newton/Fisher matrix. + Note: + For losses with ``'mean'`` reduction, the GGN is ``¹/N ∑ₙ Jₙᵀ Hₙ Jₙ``. If + sub-sampling is enabled, the sum will only run over active samples. The + normalization will not be ``1/len(subsampling)``, but remain ``1/N``. + + Args: + subsampling: Indices of active samples. Default: ``None`` (all). + Returns: Matrix representation of the exact GGN. """ return @abstractmethod - def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: + def ggn_mc( + self, mc_samples: int, chunks: int = 1, subsampling: List[int] = None + ) -> Tensor: """Compute the MC-approximation of the GGN in chunks of MC samples. Args: mc_samples: Number of Monte-Carlo samples. chunks: Number of sequential portions to split the computation. Default: ``1`` (no sequential split). + subsampling: Indices of active samples. Default: ``None`` (all). Returns: Matrix representation of the Monte-Carlo approximated GGN. diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 9edaaa9d0..77232c0f2 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -132,31 +132,30 @@ def make_id(self): ).replace(" ", "") ) - def forward_pass(self, sample_idx: int = None) -> Tuple[Tensor, Tensor, Tensor]: + def forward_pass( + self, subsampling: List[int] = None + ) -> Tuple[Tensor, Tensor, Tensor]: """Do a forward pass. Return input, output, and parameters. - The forward pass is performed on the selected index. - If the index is None, then the forward pass is calculated for the whole batch. + If sub-sampling is None, the forward pass is calculated on the whole batch. Args: - sample_idx: Index of the sample to select. Defaults to ``None``. + subsampling: Indices of selected samples. Default: ``None`` (all samples). Returns: - input, output, loss, each with batch axis first + input, output, and loss of the forward pass """ - if sample_idx is None: - input = self.input.clone() - target = self.target.clone() - else: + input = self.input.clone() + target = self.target.clone() + + if subsampling is not None: batch_axis_in = get_batch_axis(self.model, "input0") + input = subsample(self.input, dim=batch_axis_in, subsampling=subsampling) + batch_axis_out = get_batch_axis(self.model, "output") - target = subsample( - self.target, dim=batch_axis_out, subsampling=[sample_idx] - ) - input = subsample(self.input, dim=batch_axis_in, subsampling=[sample_idx]) + target = subsample(self.target, dim=batch_axis_out, subsampling=subsampling) output = self.model(input) - loss = self.loss_function(output, target) return input, output, loss @@ -166,15 +165,16 @@ def extend(self): self.model = extend(self.model) self.loss_function = extend(self.loss_function) - def get_reduction_factor(self, loss, unreduced_loss): + @staticmethod + def __get_reduction_factor(loss: Tensor, unreduced_loss: Tensor) -> float: """Return the factor used to reduce the individual losses. Args: - loss (torch.Tensor): the loss after reduction - unreduced_loss (torch.Tensor): the raw loss before reduction + loss: Reduced loss. + unreduced_loss: Unreduced loss. Returns: - float: factor + Reduction factor. Raises: RuntimeError: if either mean or sum cannot be determined @@ -233,3 +233,32 @@ def collect_data(self, savefield: str) -> List[Any]: ) return data + + def get_batch_size(self) -> int: + """Return the mini-batch size. + + Returns: + Mini-batch size. + """ + return self.input.shape[get_batch_axis(self.model, "input0")] + + def compute_reduction_factor(self) -> float: + """Compute loss function's reduction factor for aggregating per-sample losses. + + For instance, if ``reduction='mean'`` is used, then the reduction factor + is ``1 / N`` where ``N`` is the batch size. With ``reduction='sum'``, it + is ``1``. + + Returns: + Reduction factor + """ + _, _, loss = self.forward_pass() + + batch_size = self.get_batch_size() + loss_list = torch.zeros(batch_size, device=self.device) + + for n in range(batch_size): + _, _, loss_n = self.forward_pass(subsampling=[n]) + loss_list[n] = loss_n + + return self.__get_reduction_factor(loss, loss_list) diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py index 75b1dbee4..7034044a3 100644 --- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -1,18 +1,24 @@ """Tests BackPACK's ``SqrtGGNExact`` and ``SqrtGGNMC`` extension.""" +from math import isclose from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import ExtensionsTestProblem, make_test_problems from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS +from test.utils.skip_test import skip_large_parameters, skip_subsampling_conflict +from typing import List, Union -from pytest import fixture, mark, skip +from pytest import fixture, mark PROBLEMS = make_test_problems(SECONDORDER_SETTINGS) +SUBSAMPLINGS = [None, [0, 0], [2, 0]] +SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] + @fixture(params=PROBLEMS, ids=lambda p: p.make_id()) -def instantiated_problem(request) -> ExtensionsTestProblem: +def problem(request) -> ExtensionsTestProblem: """Set seed, create tested model, loss, data. Finally clean up. Args: @@ -27,40 +33,29 @@ def instantiated_problem(request) -> ExtensionsTestProblem: case.tear_down() -@fixture -def small_problem( - instantiated_problem: ExtensionsTestProblem, max_num_params=1000 -) -> ExtensionsTestProblem: - """Skip architectures with too many parameters whose GGN is expensive to evaluate. - - Args: - instantiated_problem: Test case with instantiated model, data, etc. - max_num_params: Maximum number of model parameters to run the case. - Default: ``1000``. - - Yields: - Instantiated test case whose model's are small enough. - """ - num_params = sum(p.numel() for p in instantiated_problem.trainable_parameters()) - if num_params <= max_num_params: - yield instantiated_problem - else: - skip(f"Model has too many parameters: {num_params} > {max_num_params}") - - -def test_ggn_exact(small_problem: ExtensionsTestProblem): +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def test_ggn_exact( + problem: ExtensionsTestProblem, subsampling: Union[List[int], None] +) -> None: """Compare exact GGN from BackPACK's matrix square root with autograd. Args: - small_problem: Test case with small network whose GGN can be evaluated. + problem: Test case with small network whose GGN can be evaluated. + subsampling: Indices of active samples. ``None`` uses the full mini-batch. """ - autograd_res = AutogradExtensions(small_problem).ggn() - backpack_res = BackpackExtensions(small_problem).ggn() + skip_large_parameters(problem) + skip_subsampling_conflict(problem, subsampling) + + autograd_res = AutogradExtensions(problem).ggn(subsampling=subsampling) + backpack_res = BackpackExtensions(problem).ggn(subsampling=subsampling) check_sizes_and_values(autograd_res, backpack_res) -def test_sqrt_ggn_mc_integration(small_problem: ExtensionsTestProblem): +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def test_sqrt_ggn_mc_integration( + problem: ExtensionsTestProblem, subsampling: Union[List[int], None] +) -> None: """Check if MC-approximated GGN matrix square root code executes. Note: @@ -70,21 +65,41 @@ def test_sqrt_ggn_mc_integration(small_problem: ExtensionsTestProblem): frequently. Args: - small_problem: Test case with small network whose GGN can be evaluated. + problem: Test case with small network whose GGN can be evaluated. + subsampling: Indices of active samples. ``None`` uses the full mini-batch. """ - BackpackExtensions(small_problem).sqrt_ggn_mc(mc_samples=1) + skip_large_parameters(problem) + skip_subsampling_conflict(problem, subsampling) + + BackpackExtensions(problem).sqrt_ggn_mc(mc_samples=1, subsampling=subsampling) @mark.montecarlo -def test_ggn_mc(small_problem: ExtensionsTestProblem): - """Compare MC-approximated GGN from BackpACK's with exact version from autograd. +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +def test_ggn_mc( + problem: ExtensionsTestProblem, subsampling: Union[List[int], None] +) -> None: + """Compare MC-approximated GGN from BackPACK with exact version from autograd. Args: - small_problem: Test case with small network whose GGN can be evaluated. + problem: Test case with small network whose GGN can be evaluated. + subsampling: Indices of active samples. ``None`` uses the full mini-batch. """ - autograd_res = AutogradExtensions(small_problem).ggn() - atol, rtol = 5e-2, 1e-2 - mc_samples, chunks = 300000, 30 - backpack_res = BackpackExtensions(small_problem).ggn_mc(mc_samples, chunks=chunks) + skip_large_parameters(problem) + skip_subsampling_conflict(problem, subsampling) + + autograd_res = AutogradExtensions(problem).ggn(subsampling=subsampling) + atol, rtol = 5e-3, 5e-3 + mc_samples, chunks = 150000, 15 + backpack_res = BackpackExtensions(problem).ggn_mc( + mc_samples, chunks=chunks, subsampling=subsampling + ) + + # compare normalized entries ∈ [-1; 1] (easier to tune atol) + max_val = max(autograd_res.abs().max(), backpack_res.abs().max()) + # NOTE: The GGN can be exactly zero; e.g. if a ReLU after all parameters zeroes + # its input, its Jacobian is thus zero and will cancel the backpropagated GGN + if not isclose(max_val, 0): + autograd_res, backpack_res = autograd_res / max_val, backpack_res / max_val check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol) diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index b1e10554a..34257e946 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -1,13 +1,13 @@ """Skip specific tests.""" from test.core.derivatives.problem import DerivativesTestProblem +from test.extensions.problem import ExtensionsTestProblem from typing import List, Union from pytest import skip from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 -from backpack.utils.subsampling import get_batch_axis def skip_adaptive_avg_pool3d_cuda(request) -> None: @@ -44,7 +44,8 @@ def skip_batch_norm_train_mode_with_subsampling( def skip_subsampling_conflict( - problem: DerivativesTestProblem, subsampling: Union[List[int], None] + problem: Union[DerivativesTestProblem, ExtensionsTestProblem], + subsampling: Union[List[int], None], ) -> None: """Skip if some samples in subsampling are not contained in input. @@ -52,7 +53,21 @@ def skip_subsampling_conflict( problem: Test case. subsampling: Indices of active samples. """ - N = problem.input_shape[get_batch_axis(problem.module, "input0")] + N = problem.get_batch_size() enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip("Not enough samples.") + + +def skip_large_parameters( + problem: ExtensionsTestProblem, max_num_params: int = 1000 +) -> None: + """Skip architectures with too many parameters. + + Args: + problem: Test case. + max_num_params: Maximum number of model parameters. Default: ``1000``. + """ + num_params = sum(p.numel() for p in problem.trainable_parameters()) + if num_params > max_num_params: + skip(f"Model has too many parameters: {num_params} > {max_num_params}") From afc94f1d9a1666e56f0cefea24c2288f928f6a56 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 5 Aug 2021 19:22:39 +0200 Subject: [PATCH 41/54] [DOC] Sub-sampling use case (#210) Adds an example to the documentation that shows how to restrict BackPACK extensions to a subset of mini-batch samples. Progress on #12. --- * [ADD] Introduce subsampling argument in SqrtGGN{Exact,MC} * Generalize get_batch_axis to modules with children * [ADD] Introduce `subsampling` in BackpropExtension interface * [ADD] Consider subsampling in `SqrtGGN` module extensions * [ADD] Make skip condition work for derivatives & extension cases * [TEST] Parametrize forward pass with subsampling, extract reduction factor * [TEST] Test correctness of exact GGN with sub-sampling * [TEST] Check GGN-MC from square root for sub-samplings * [FIX] Docstring * [TEST] Normalize GGNMC to before comparing, reduce MC samples * [FIX] Don't normalize if GGN-MC is exactly zero * [REF] Move skip condition for large parameters * [ADD] Reintroduce montecarlo marker * [DOC] Add example for mini-batch sub-sampling Co-authored-by: Felix Dangel --- .../examples/use_cases/example_subsampling.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 docs_src/examples/use_cases/example_subsampling.py diff --git a/docs_src/examples/use_cases/example_subsampling.py b/docs_src/examples/use_cases/example_subsampling.py new file mode 100644 index 000000000..b89d01a3d --- /dev/null +++ b/docs_src/examples/use_cases/example_subsampling.py @@ -0,0 +1,86 @@ +"""Mini-batch sub-sampling +========================== + +By default, BackPACK's extensions consider all samples in a mini-batch. Some extensions +support limiting the computations to a subset of samples. This example shows how to +restrict the computations to such a subset of samples. + +This may be interesting for applications where parts of the samples are used for +different purposes, e.g. computing curvature and gradient information on different +subsets. Limiting the computations to fewer samples also reduces costs. + +.. note:: + Not all extensions support sub-sampling yet. Please create a feature request in the + repository if the extension you need is not supported. +""" + +# %% +# Let's start by loading some dummy data and extending the model + +from torch import allclose, cuda, device, manual_seed +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend +from backpack.extensions import BatchGrad +from backpack.utils.examples import load_one_batch_mnist + +# make deterministic +manual_seed(0) + +dev = device("cuda" if cuda.is_available() else "cpu") + +# data +X, y = load_one_batch_mnist(batch_size=128) +X, y = X.to(dev), y.to(dev) + +# model +model = Sequential(Flatten(), Linear(784, 10)).to(dev) +lossfunc = CrossEntropyLoss().to(dev) + +model = extend(model) +lossfunc = extend(lossfunc) + +# %% +# Individual gradients for a mini-batch subset +# -------------------------------------------- +# +# Let's say we only want to compute individual gradients for samples 0, 1, +# 13, and 42. Naively, we could perform the computation for all samples, then +# slice out the samples we care about. + +# selected samples +subsampling = [0, 1, 13, 42] + +loss = lossfunc(model(X), y) + +with backpack(BatchGrad()): + loss.backward() + +# naive approach: compute for all, slice out relevant +naive = [p.grad_batch[subsampling] for p in model.parameters()] + +# %% +# This is not efficient, as individual gradients are computed for all samples, +# most of them being discarded after. We can do better by specifying the active +# samples directly with the ``subsampling`` argument of +# :py:class:`BatchGrad `. + +loss = lossfunc(model(X), y) + +# efficient approach: specify active samples in backward pass +with backpack(BatchGrad(subsampling=subsampling)): + loss.backward() + +efficient = [p.grad_batch for p in model.parameters()] + +# %% +# Let's verify that both ways yield the same result: + +match = all( + allclose(g_naive, g_efficient) for g_naive, g_efficient in zip(naive, efficient) +) + +print(f"Naive and efficient sub-sampled individual gradients match? {match}") + +if not match: + raise ValueError("Naive and efficient sub-sampled individual gradient don't match.") From f35427c9c679afcaf908b04d6fd749aeb5c9c870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 2 Sep 2021 15:12:33 +0200 Subject: [PATCH 42/54] [ADD] Basic support for residual networks (#202) Enables basic support for architectures with branched computation graphs. Main features: - Converter functionality to replace nodes with BackPACK- compatible modules and enable second-order extensions, exposed through the `extend` function - Example how to use BackPACK with ResNets - Custom modules to build ResNets as containers, without overwriting the forward pass Internals: - Implement aggregation of backpropagated quantities - Test `DiagGGNExact` on `torchvision`'s `resnet18` and `wide_resnet_50_2` - Converter test cases that check the forward remains the same, and `DiagGGNExact` works --- * get changes from branch 'branching' * move branching.py to custom_modules/ * docstrings and type hints * fix import * format * all DiagGGN extensions * generalize Identity to ScaleModule * example from Katharina * format * delete unnecessary mapping * resnets.py working for GGN and KFAC * experiments with torch.fx * experiments with torch.fx * graph conversion: branching * graph conversion: branching * graph conversion: branching * change interface Merge() * change backprop for branch points * converter for Katharina's resnet * simplify graph_utils.py * scaffold for resnet conversion * make resnet18 work, simplify utils * backward pass resnet18 * delete branching mark functions * delete branching mark functions, rename Merge -> Sum * generalize Parallel to arbitrary merge modules * remove MergeModuleExtension * move graph_utils.py to custom_modules/ * improve graph_utils.py * fix usage accumulate_backproped quantities * add option to extend(): use_converter * add TODO * fix: using same input multiple times recognized as branch point * fix: test_branching.py: adjust to updated interface * DiagGGN AdaptiveAvgPool * BatchNorm: first order extensions * BatchNorm: DiagGGN extension * Merge: fix diag_ggn_settings.py * print * transform in-place to normal * reduce data size * align naming * reduce num_classes * format * new transformation: remove duplicate modules * refactor * torch version check * simplify extend, allow GraphModule * increase num_classes, evaluation mode * presentation code * resnet18: first order tests * ResNets: DiagGGN tests * format * clean up, create tests * improve print out * example for using resnets * resnet: reduce problem size, use cuda * resnet example: use alternative small example * resnet example: use alternative small example * fix problem.py * fix merge * try fix github error * align naming * simplify except * reduce tests * all tests again * format * test changed for CI * requirements: tabulate * [ADD] use resnet18 with very small input * [DEL] remove small custom example, keep resnet18 * [DEL] remove _maybe_warn_no_batch_summation * [FIX] merge * [FIX] merge * [FIX] convert resnet18 for first order * [FIX] error type * [REF] explicit imports, type hints * [DOC] docstring * [REF] format, type hint, simplify * [DOC] add new files to fully_documented.txt * [TEST] reduce problem size, documentation * [ADD] example_resnet_all_in_one.py * [DEL] delete redundant resnet examples * [REF] Improve docs, type hints, shorten forward pass * [DOC] Slight tweak * [REF] Change tracer name, add type hints * [REF] Introduce variables for target operations * [DOC] Explain necessity for different accumulation rules * [REF] Remove unused imports * [REF] Reduce nesting in `__should_backpropagate` * [REF] Improve readability * [FIX] Forward pass of Parallel module * [FIX] fix flatten->Flatten (Module) transformation * [DEL] Do not (yet) support custom modules in HBP * [REF] Minor rewrite of ResNet2 * [DOC] Clean up doc * [REF] Make ResNets smaller, improve doc, squeeze out some lines * [DOC] Pass through Section 1 of ResNet use case * [DEL] fully_documented.txt: remove hbp/custom_module.py * [REF] variable name * [ADD] converter: debug option * [FIX] converter: debug option * [REF] remove max_depth, introduce _get_free_name * [TEST] test model_converted ^= model_original * [REF] clarify variable name * [DOC] docstring _change_node_to_module * [REF] rename __should_backpropagate to __get_inputs_for_backpropagation * [TEST] set seed for resnet_torchvision.py * [TEST] improve docstring * [REF] rename resnets_examples.py to resnet_cases.py * [TEST] remove .eval() from ResNet1 * [ADD] enforce types in branching.py * [ADD] mul to ScaleModule: allow arbitrary order * [ADD] enforce types in ScaleModule * [TEST] test derivatives: SumModule, ScaleModule * [TEST] rename resnet/ to converter/, test converter with converter_cases.py * [FIX] flake8 list comprehension * [ADD] Utils for the GGN diagonal with autograd * [DOC] Pass through Section 2 of ResNet use case * [DOC, CFG] Pass over ResNet example Sec. 3, move tabulate dependency * [REF] Remove duplicate GGN diagonal utility function * [REF] Introduce `CONVERTER_AVAILABLE' variable * [DOC] Fully document utilities for examples * [FMT] Replace `.format` with `f`-string * [DOC] Polish docstring * [DOC] Docstring polish * [REF] Replace `AssertionError` with `ValueError` * [ADD] reintroduce example_first_order_resnet.py and link to new location * [DOC] docstring _make_layer * [FIX] ensure last element is excluded in linspace * [TEST][ADD] test WideResNet * [TEST] test some diagonal elements of DiagGGN * [DOCS] White space cosmetics, switch sentences * [DOCS] Replace bullets with enumerated items * [DOCS] Polish * [DOCS] Minor edits * [REF] Less indented transformations * [REF] Reduce duplication of f-strings * [FIX] fix link to new example * [REF] Less indentation in duplicate transformation * [FIX] raising an exception for parameter cycles * [FIX] Remove unused import * [DEL] revert/remove test derivatives for SumModule * [DEL] remove unnecessary print * [REF] Rename variables, reduce atol * [REF] sharpen test tolerance * [REF] descriptive error messages * [REF] Improve error messages in converter * [REF] Make Branch class private, remove redundant test * [DOCS] Reduce atol by one order for GPU * [FIX] remove sum_module_settings.py from fully_documented.txt * [ADD] Converter: check BackPACK compatibility * [DEL] Don't docstring-lint sum_module_settings * [DOC] add docstrings to transformations * [REF] Shorten warning * [DOC] Update limitations section Co-authored-by: Felix Dangel --- backpack/__init__.py | 44 ++- backpack/core/derivatives/batchnorm_nd.py | 16 - backpack/core/derivatives/scale_module.py | 25 ++ backpack/core/derivatives/sum_module.py | 21 ++ backpack/custom_module/branching.py | 123 +++++++ backpack/custom_module/graph_utils.py | 337 ++++++++++++++++++ backpack/custom_module/scale_module.py | 32 ++ backpack/extensions/backprop_extension.py | 23 +- backpack/extensions/module_extension.py | 31 +- backpack/extensions/saved_quantities.py | 26 +- .../secondorder/diag_ggn/__init__.py | 23 ++ .../secondorder/diag_ggn/custom_module.py | 20 ++ backpack/utils/__init__.py | 1 + backpack/utils/examples.py | 124 +++++-- backpack/utils/module_classification.py | 11 +- .../use_cases/example_first_order_resnet.py | 99 +---- .../use_cases/example_resnet_all_in_one.py | 325 +++++++++++++++++ docs_src/rtd/good-to-know.rst | 11 +- fully_documented.txt | 6 + setup.cfg | 1 + test/converter/__init__.py | 1 + test/converter/converter_cases.py | 161 +++++++++ test/converter/resnet_cases.py | 146 ++++++++ test/converter/test_branching.py | 145 ++++++++ test/converter/test_converter.py | 79 ++++ test/core/derivatives/__init__.py | 9 + test/core/derivatives/derivatives_test.py | 12 +- .../core/derivatives/scale_module_settings.py | 29 ++ .../firstorder/firstorder_settings.py | 20 ++ test/extensions/problem.py | 15 +- .../secondorder/diag_ggn/diag_ggn_settings.py | 110 +++++- test/utils/skip_test.py | 8 +- 32 files changed, 1847 insertions(+), 187 deletions(-) create mode 100644 backpack/core/derivatives/scale_module.py create mode 100644 backpack/core/derivatives/sum_module.py create mode 100644 backpack/custom_module/branching.py create mode 100644 backpack/custom_module/graph_utils.py create mode 100644 backpack/custom_module/scale_module.py create mode 100644 backpack/extensions/secondorder/diag_ggn/custom_module.py create mode 100644 docs_src/examples/use_cases/example_resnet_all_in_one.py create mode 100644 test/converter/__init__.py create mode 100644 test/converter/converter_cases.py create mode 100644 test/converter/resnet_cases.py create mode 100644 test/converter/test_branching.py create mode 100644 test/converter/test_converter.py create mode 100644 test/core/derivatives/scale_module_settings.py diff --git a/backpack/__init__.py b/backpack/__init__.py index 2ae09a779..068a973af 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -1,19 +1,22 @@ """BackPACK.""" -import inspect +from inspect import isclass from types import TracebackType from typing import Callable, Optional, Tuple, Type, Union -import torch -from torch import Tensor +from torch import Tensor, is_grad_enabled from torch.nn import Module +from backpack import extensions +from backpack.context import CTX from backpack.extensions.backprop_extension import BackpropExtension +from backpack.utils import CONVERTER_AVAILABLE, FULL_BACKWARD_HOOK from backpack.utils.hooks import no_op +from backpack.utils.module_classification import is_no_op -from . import extensions -from .context import CTX -from .utils import FULL_BACKWARD_HOOK -from .utils.module_classification import is_no_op +if CONVERTER_AVAILABLE: + from torch.fx import GraphModule + + from backpack.custom_module.graph_utils import convert_module_to_backpack class backpack: @@ -23,7 +26,7 @@ def __init__( self, *exts: BackpropExtension, extension_hook: Callable[[Module], None] = None, - debug: bool = False + debug: bool = False, ): """Activate BackPACK extensions. @@ -47,16 +50,16 @@ def __init__( """ for ext in exts: if not isinstance(ext, BackpropExtension): - if inspect.isclass(ext) and issubclass(ext, BackpropExtension): + if isclass(ext) and issubclass(ext, BackpropExtension): raise ValueError( - "backpack expect instances of BackpropExtension," - + " but received a class instead [{}].".format(ext) + "backpack expects instances of BackpropExtension," + + f" but received a class instead [{ext}]." + " Instantiate it before passing it to backpack." ) else: raise ValueError( "backpack expects instances of BackpropExtension," - + " but received [{}].".format(ext) + + f" but received [{ext}]." ) self.exts: Tuple[BackpropExtension, ...] = exts @@ -157,7 +160,7 @@ def hook_store_io( input: List of input tensors output: result of module(input) """ - if disable.should_store_io() and torch.is_grad_enabled(): + if disable.should_store_io() and is_grad_enabled(): for i in range(len(input)): setattr(module, "input{}".format(i), input[i]) if isinstance(output, tuple): @@ -167,7 +170,7 @@ def hook_store_io( module.output = output -def memory_cleanup(module) -> None: +def memory_cleanup(module: Module) -> None: """Remove I/O stored by backpack during the forward pass. Deletes the attributes created by `hook_store_io`. @@ -215,7 +218,7 @@ def hook_run_extensions( memory_cleanup(module) -def extend(module: Module, debug: bool = False) -> Module: +def extend(module: Module, debug: bool = False, use_converter: bool = False) -> Module: """Recursively extend a ``module`` to make it BackPACK-ready. Modules that do not represent an operation in the computation graph (for instance @@ -224,13 +227,24 @@ def extend(module: Module, debug: bool = False) -> Module: Args: module: The module to extend. debug: Print debug messages during the extension. Default: ``False``. + use_converter: Try converting the module to a BackPACK-compatible network. Returns: Extended module. + + Raises: + RuntimeError: if trying to use converter without torch>=1.9.0 """ if debug: print("[DEBUG] Extending", module) + if use_converter: + if not CONVERTER_AVAILABLE: + raise RuntimeError("use_converter=True is only available for torch>=1.9.0.") + + module: GraphModule = convert_module_to_backpack(module, debug) + return extend(module) + for child in module.children(): extend(child, debug=debug) diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index 5d9a611b3..93c8cf99c 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -1,6 +1,5 @@ """Contains derivatives for BatchNorm.""" from typing import List, Tuple, Union -from warnings import warn from torch import Size, Tensor, einsum from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d @@ -146,7 +145,6 @@ def _weight_jac_t_mat_prod( sum_batch: bool = True, subsampling: List[int] = None, ) -> Tensor: - self._maybe_warn_no_batch_summation(sum_batch) x_hat, _ = self._get_normalized_input_and_var(module) x_hat = subsample(x_hat, subsampling=subsampling) @@ -184,7 +182,6 @@ def _bias_jac_t_mat_prod( sum_batch: bool = True, subsampling: List[int] = None, ) -> Tensor: - self._maybe_warn_no_batch_summation(sum_batch) axis_sum: Tuple[int] = self._get_free_axes(module, with_batch_axis=sum_batch) return mat.sum(dim=axis_sum) if axis_sum else mat @@ -323,16 +320,3 @@ def _get_free_axes( for n in range(self._get_n_axis(module)): free_axes.append(index_batch + n + 2) return tuple(free_axes) - - @staticmethod - def _maybe_warn_no_batch_summation(sum_batch: bool) -> None: - """Warn that Jacobians w.r.t. single components are not per-sample gradients. - - Args: - sum_batch: Whether to sum out the batch dimension. - """ - if not sum_batch: - warn( - "BatchNorm batch summation disabled." - "This may not compute meaningful quantities" - ) diff --git a/backpack/core/derivatives/scale_module.py b/backpack/core/derivatives/scale_module.py new file mode 100644 index 000000000..d83d33c04 --- /dev/null +++ b/backpack/core/derivatives/scale_module.py @@ -0,0 +1,25 @@ +"""Derivatives of ScaleModule (implies ActiveIdentity and Identity).""" +from typing import List, Tuple, Union + +from torch import Tensor +from torch.nn import Identity + +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.custom_module.scale_module import ScaleModule + + +class ScaleModuleDerivatives(BaseDerivatives): + """Derivatives of ScaleModule (implies ActiveIdentity and Identity).""" + + def _jac_t_mat_prod( + self, + module: Union[ScaleModule, Identity], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: + if isinstance(module, Identity): + return mat + else: + return mat * module.weight diff --git a/backpack/core/derivatives/sum_module.py b/backpack/core/derivatives/sum_module.py new file mode 100644 index 000000000..e8383fe6b --- /dev/null +++ b/backpack/core/derivatives/sum_module.py @@ -0,0 +1,21 @@ +"""Contains derivatives for SumModule.""" +from typing import List, Tuple + +from torch import Tensor + +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.custom_module.branching import SumModule + + +class SumModuleDerivatives(BaseDerivatives): + """Contains derivatives for SumModule.""" + + def _jac_t_mat_prod( + self, + module: SumModule, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: + return mat diff --git a/backpack/custom_module/branching.py b/backpack/custom_module/branching.py new file mode 100644 index 000000000..fc0d5149a --- /dev/null +++ b/backpack/custom_module/branching.py @@ -0,0 +1,123 @@ +"""Emulating branching with modules.""" +from typing import Any, OrderedDict, Tuple, Union + +from torch import Tensor +from torch.nn import Module + +from backpack.custom_module.scale_module import ScaleModule + + +class ActiveIdentity(ScaleModule): + """Like ``torch.nn.Identity``, but creates a new node in the computation graph.""" + + def __init__(self): + """Initialization with weight=1.0.""" + super().__init__(weight=1.0) + + +class _Branch(Module): + """Module used by BackPACK to handle branching in the computation graph. + + ↗ module1 → output1 + input → module2 → output2 + ↘ ... → ... + """ + + def __init__(self, *args: Union[OrderedDict[str, Module], Module]): + """Use interface of ``torch.nn.Sequential``. Modules are parallel sequence. + + Args: + args: either an ordered dictionary of modules or a tuple of modules + """ + super().__init__() + + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def forward(self, input: Tensor) -> Tuple[Any, ...]: + """Feed input through each child module. + + Args: + input: input tensor + + Returns: + tuple of output tensor + """ + return tuple(module(input) for module in self.children()) + + +class SumModule(Module): + """Module used by BackPACK to handle branch merges in the computation graph. + + module 1 ↘ + module 2 → SumModule (sum) + ... ↗ + """ + + def forward(self, *input: Tensor) -> Tensor: + """Sum up all inputs (a tuple of tensors). + + Args: + input: tuple of input tensors + + Returns: + sum of all inputs + + Raises: + AssertionError: if input is no tuple of matching tensors + """ + if not isinstance(input, tuple): + raise AssertionError(f"Expecting tuple as input. Got {input.__class__}") + elif not all(isinstance(inp, Tensor) for inp in input): + raise AssertionError( + f"Expecting tuple of tensors, but received ({[inp.__class__ for inp in input]})" + ) + elif not all(input[0].shape == input[i].shape for i in range(1, len(input))): + raise AssertionError(f"Shapes don't match: {[inp.shape for inp in input]}.") + else: + return sum(input) + + +class Parallel(Module): + """Feed the same input through a parallel sequence of modules. Sum the results. + + Used by BackPACK to emulate branched computations. + + ↗ module 1 ↘ + Branch → module 2 → SumModule (sum) + ↘ ... ↗ + """ + + def __init__( + self, + *args: Union[OrderedDict[str, Module], Module], + merge_module: Module = None, + ): + """Like ``torch.nn.Sequential``, but defines a parallel module sequence. + + Use interface of ``torch.nn.Sequential``. + + Args: + args: either ordered dictionary of modules or tuple of modules + merge_module: The module used for merging. Defaults to ``None``, which + means ``SumModule()`` is used. + """ + super().__init__() + + self.branch = _Branch(*args) + self.merge = SumModule() if merge_module is None else merge_module + + def forward(self, input: Tensor) -> Tensor: + """Forward pass. Concatenation of Branch and SumModule. + + Args: + input: module input + + Returns: + Merged results from forward pass of each branch + """ + return self.merge(*self.branch(input)) diff --git a/backpack/custom_module/graph_utils.py b/backpack/custom_module/graph_utils.py new file mode 100644 index 000000000..7d8a947df --- /dev/null +++ b/backpack/custom_module/graph_utils.py @@ -0,0 +1,337 @@ +"""Transformation tools to make graph BackPACK compatible.""" +from copy import deepcopy +from warnings import warn + +from torch.fx import Graph, GraphModule, Node, Tracer +from torch.nn import Flatten, Module + +from backpack.custom_module.branching import ActiveIdentity, SumModule, _Branch +from backpack.custom_module.scale_module import ScaleModule +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 + + +class BackpackTracer(Tracer): + """Tracer that recognizes BackPACK's custom modules as 'leaf modules'.""" + + def is_leaf_module( + self, m: Module, module_qualified_name: str + ) -> bool: # noqa: D102 + if isinstance(m, (ScaleModule, SumModule, _Branch, ActiveIdentity)): + return True + else: + return super().is_leaf_module(m, module_qualified_name) + + +def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule: + """Convert all modules to BackPACK-compatible modules. + + Transformations: + - mul -> ScaleModule + - add -> AddModule + - flatten -> nn.Flatten + - inplace to normal + - remove duplicates + - delete unused modules + - check BackPACK compatible + + Args: + module: module to convert + debug: if True prints to command line + + Returns: + BackPACK-compatible module + + Raises: + NotImplementedError: if not torch >= 1.9.0 + """ + if TORCH_VERSION_AT_LEAST_1_9_0 is False: + raise NotImplementedError( + "Conversion is only possible for torch >= 1.9.0. This is because these " + "functions use functionality such as torch.nn.Module.get_submodule" + ) + if debug: + print("\nMake module BackPACK-compatible...") + module_new = _transform_mul_to_scale_module(module, debug) + module_new = _transform_flatten_to_module(module_new, debug) + module_new = _transform_add_to_sum_module(module_new, debug) + _transform_inplace_to_normal(module_new, debug) + module_new = _transform_remove_duplicates(module_new, debug) + if debug: + print("\tDelete unused modules.") + module_new.delete_all_unused_submodules() + _check_backpack_compatible(module_new, debug) + if debug: + print("Finished transformation.\n") + return module_new + + +def _check_backpack_compatible(module: Module, debug: bool) -> None: + """Checks whether the computation graph of the given module is BackPACK compatible. + + More specifically, it checks whether all nodes are either input/output + or a call to a module. Subsequent checks if the module is extendable in BackPACK + have to be done by running the extension. + + Args: + module: module to check + debug: whether to print debug messages + """ + if debug: + print("\tChecking BackPACK compatibility.") + graph: Graph = BackpackTracer().trace(module) + for node in graph.nodes: + if node.op not in ["call_module", "placeholder", "output"]: + warn( + f"Encountered node that may break second-order extensions: op={node.op}" + f", target={node.target}. If you encounter this problem, please open an" + " issue at https://github.com/f-dangel/backpack/issues." + ) + + +def _transform_mul_to_scale_module(module: Module, debug: bool) -> GraphModule: + """Transforms multiplications of tensor with float to ScaleModule. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + + Raises: + RuntimeError: if a multiplication is found but node.args are not (float, Node) + """ + target = "" + if debug: + print(f"\tBegin transformation: {target} -> ScaleModule") + + graph: Graph = BackpackTracer().trace(module) + nodes = [ + n for n in graph.nodes if n.op == "call_function" and str(n.target) == target + ] + + for node in nodes: + if len(node.args) != 2: + raise RuntimeError(f"Expecting 2 arguments, got {len(node.args)}.") + + idx_weight = 0 if isinstance(node.args[0], float) else 1 + idx_tensor = 1 - idx_weight + + weight = node.args[idx_weight] + tensor = node.args[idx_tensor] + + if not (isinstance(weight, float) and isinstance(tensor, Node)): + raise RuntimeError( + f"Expecting types [float, Node], got {[type(weight), type(tensor)]}." + ) + + _change_node_to_module( + node, "scale_module", module, ScaleModule(weight), (tensor,) + ) + + graph.lint() + + if debug: + print(f"\tMultiplications transformed: {len(nodes)}") + + return GraphModule(module, graph) + + +def _transform_add_to_sum_module(module: Module, debug: bool) -> GraphModule: + """Transforms summations of tensors to SumModule (useful in ResNets). + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + """ + target = "" + if debug: + print(f"\tBegin transformation: {target} -> SumModule") + + graph: Graph = BackpackTracer().trace(module) + nodes = [ + n for n in graph.nodes if n.op == "call_function" and str(n.target) == target + ] + + for node in nodes: + _change_node_to_module(node, "sum_module", module, SumModule(), node.args) + + graph.lint() + + if debug: + print(f"\tSummations transformed: {len(nodes)}") + + return GraphModule(module, graph) + + +def _transform_flatten_to_module(module: Module, debug: bool) -> GraphModule: + """Transforms PyTorch's flatten method to the nn.Flatten module. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + """ + target = " Flatten") + + graph: Graph = BackpackTracer().trace(module) + nodes = [ + n for n in graph.nodes if n.op == "call_function" and target in str(n.target) + ] + + for node in nodes: + start_dim = node.args[1] if len(node.args) > 1 else 1 + end_dim = node.args[2] if len(node.args) > 2 else -1 + _change_node_to_module( + node, "flatten", module, Flatten(start_dim, end_dim), (node.args[0],) + ) + + graph.lint() + + if debug: + print(f"\tFlatten transformed: {len(nodes)}") + + return GraphModule(module, graph) + + +def _transform_inplace_to_normal( + module: Module, debug: bool, initialize_recursion: bool = True +) -> None: + """Searches for in-place operations and changes them to standard operations. + + Args: + module: container module to transform + debug: whether to print debug messages + initialize_recursion: whether this is the initial call to this function. + """ + if initialize_recursion: + if debug: + print("\tBegin transformation: in-place -> standard") + _transform_inplace_to_normal.counter = 0 + if hasattr(module, "inplace") and module.inplace: + module.inplace = False + _transform_inplace_to_normal.counter += 1 + for child_module in module.children(): + _transform_inplace_to_normal(child_module, debug, initialize_recursion=False) + + if initialize_recursion: + if debug: + print(f"\tIn-place changed: {_transform_inplace_to_normal.counter}") + del _transform_inplace_to_normal.counter + + +def _transform_remove_duplicates(module: GraphModule, debug: bool) -> GraphModule: + """Removes duplicate modules by creating a copy of the module. + + This is necessary because BackPACK saves input/output which is overwritten + if the module is called multiple times. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + + Raises: + NotImplementedError: if a duplicate module has parameters + """ + if debug: + print("\tBegin transformation: remove duplicates") + + graph: Graph = BackpackTracer().trace(module) + + targets = [n.target for n in graph.nodes] + duplicates = {t for t in targets if targets.count(t) > 1} + nodes = [n for n in graph.nodes if n.target in duplicates] + + for node in nodes: + target = node.target + original_module = module.get_submodule(target) + + for _ in original_module.parameters(): + raise NotImplementedError( + f"Cycle with parameters detected: module {original_module} with target" + f" {target} has parameters and is used {targets.count(target)} times." + ) + + new_module = deepcopy(original_module) + new_target = _get_free_name(module, target) + module.add_submodule(new_target, new_module) + node.target = new_target + + graph.lint() + + if debug: + print(f"\tDuplicates removed: {len(nodes)}") + + return GraphModule(module, graph) + + +def _change_node_to_module( + node: Node, + name: str, + base_module: Module, + new_module: Module, + args: tuple, +) -> None: + """Helper function to change an existing node to a module. + + The new module is registered in the base_module as a submodule. + The attribute name is based on name{int}. + The attributes of the node are changed so they point onto the new module. + + Args: + node: existing node + name: proposed name, real name is name{int} + base_module: the module that should get new_module as a child + new_module: the new module to register on the node and base_module + args: arguments of the new node + """ + new_name = _get_free_name(base_module, name) + node.op = "call_module" + node.target = new_name + node.args = args + setattr(base_module, new_name, new_module) + + +def _get_free_name(module: Module, initial_name: str) -> str: + """Find a free name in the modules naming space. + + Args: + module: the parent module + initial_name: a name suggestion + + Returns: + a string with the pattern {initial_name}{int} where module has no such attribute + + Raises: + RuntimeError: if the module already has an attribute with the intended name + """ + + def _has_target(target: str) -> bool: + try: + module.get_submodule(target) + return True + except AttributeError: + return False + + counter = 0 + while _has_target(f"{initial_name}{counter}"): + counter += 1 + name = f"{initial_name}{counter}" + + if hasattr(module, name): + raise RuntimeError( + f"Unable to find a free name for registering a new module." + f"module={module} already has an attribute named {name}." + ) + + return name diff --git a/backpack/custom_module/scale_module.py b/backpack/custom_module/scale_module.py new file mode 100644 index 000000000..2ee03e1a3 --- /dev/null +++ b/backpack/custom_module/scale_module.py @@ -0,0 +1,32 @@ +"""Contains ScaleModule.""" +from torch import Tensor +from torch.nn import Module + + +class ScaleModule(Module): + """Scale Module scales the input by a constant.""" + + def __init__(self, weight: float = 1.0): + """Store scalar weight. + + Args: + weight: Initial value for weight. Defaults to 1.0. + + Raises: + ValueError: if weight is no float + """ + super().__init__() + if not isinstance(weight, float): + raise ValueError("Weight must be float.") + self.weight: float = weight + + def forward(self, input: Tensor) -> Tensor: + """Defines forward pass. + + Args: + input: input + + Returns: + product of input and weight + """ + return input * self.weight diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index 80ff00820..84e12d254 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -4,7 +4,7 @@ import abc import warnings from abc import ABC -from typing import Dict, List, Tuple, Type, Union +from typing import Any, Dict, List, Tuple, Type, Union from torch import Tensor from torch.nn import Module @@ -143,3 +143,24 @@ def get_subsampling(self) -> Union[List[int], None]: the full mini-batch is used. """ return self._subsampling + + def accumulate_backpropagated_quantities(self, existing: Any, other: Any) -> Any: + """Specify how to accumulate info that is backpropagated to the same node. + + Must be implemented by second-order extensions to function on computation + graphs with branching. + + For instance, + - ``DiagGGN`` extensions must sum their backpropagated tensor quantities. + - ``curvmatprod`` extensions must chain functions to sums of functions. + + Args: + existing: Backpropagated quantity + other: Other backpropagated quantity + + Raises: + NotImplementedError: if not overwritten + """ + raise NotImplementedError( + f"{self}: No accumulation rule for backpropagated info specified" + ) diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index 0c5425e3b..3fd9070f0 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -127,25 +127,38 @@ def __call__( extValue = extFunc(extension, module, g_inp, g_out, bp_quantity) self.__save_value_on_parameter(extValue, extension, module, param) - if self.__should_backpropagate(extension, module): + module_inputs = self.__get_inputs_for_backpropagation(extension, module) + if module_inputs: bp_quantity = self.backpropagate( extension, module, g_inp, g_out, bp_quantity ) - self.__save_backproped_quantity(extension, module.input0, bp_quantity) + for module_inp in module_inputs: + self.__save_backproped_quantity(extension, module_inp, bp_quantity) @staticmethod - def __should_backpropagate(extension: BackpropExtension, module: Module) -> bool: - """Determines whether the current extension should perform a backpropagation. + def __get_inputs_for_backpropagation( + extension: BackpropExtension, module: Module + ) -> Tuple[Tensor]: + """Returns the inputs on which a backpropagation should be performed. Args: extension: current extension module: current module Returns: - whether a backpropagation should be performed + the inputs which need a backpropagation quantity """ - input_requires_grad: bool = module.input0.requires_grad - return input_requires_grad and extension.expects_backpropagation_quantities() + module_inputs: Tuple[Tensor, ...] = () + + if extension.expects_backpropagation_quantities(): + i = 0 + while hasattr(module, f"input{i}"): + input = getattr(module, f"input{i}") + if input.requires_grad: + module_inputs += (input,) + i += 1 + + return module_inputs @staticmethod def __should_retain_backproped_quantities(module: Module) -> bool: @@ -200,7 +213,9 @@ def __save_backproped_quantity( bpQuantities: backpropagation quantities that should be saved """ extension.saved_quantities.save_quantity( - reference_tensor.data_ptr(), bpQuantities + reference_tensor.data_ptr(), + bpQuantities, + extension.accumulate_backpropagated_quantities, ) @staticmethod diff --git a/backpack/extensions/saved_quantities.py b/backpack/extensions/saved_quantities.py index 78249dbd5..3eb38ad8e 100644 --- a/backpack/extensions/saved_quantities.py +++ b/backpack/extensions/saved_quantities.py @@ -1,5 +1,5 @@ """Class for saving backpropagation quantities.""" -from typing import Dict, Union +from typing import Any, Callable, Dict, Union from torch import Tensor @@ -11,24 +11,28 @@ def __init__(self): """Initialization.""" self._saved_quantities: Dict[int, Tensor] = {} - def save_quantity(self, key: int, quantity: Tensor) -> None: + def save_quantity( + self, + key: int, + quantity: Tensor, + accumulation_function: Callable[[Any, Any], Any], + ) -> None: """Saves the quantity under the specified key. + Accumulate quantities which already have an entry. + Args: key: data_ptr() of reference tensor (module.input0). quantity: tensor to save - - Raises: - NotImplementedError: if the key already exists + accumulation_function: function defining how to accumulate quantity """ if key in self._saved_quantities: - # TODO if exists: accumulate quantities (ResNet) - raise NotImplementedError( - "Quantity with given key already exists. Multiple backpropagated " - "quantities like in ResNets are not supported yet." - ) + existing = self.retrieve_quantity(key, delete_old=True) + save_value = accumulation_function(existing, quantity) else: - self._saved_quantities[key] = quantity + save_value = quantity + + self._saved_quantities[key] = save_value def retrieve_quantity(self, key: int, delete_old: bool) -> Union[Tensor, None]: """Returns the saved quantity. diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index d1cf20f53..1fa7742f1 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -8,6 +8,7 @@ BatchDiagGGNExact(BatchDiagGGN) BatchDiagGGNMC(BatchDiagGGN) """ +from torch import Tensor from torch.nn import ( ELU, RNN, @@ -30,6 +31,7 @@ CrossEntropyLoss, Dropout, Flatten, + Identity, LeakyReLU, Linear, LogSigmoid, @@ -43,7 +45,9 @@ ZeroPad2d, ) +from backpack.custom_module.branching import ActiveIdentity, SumModule from backpack.custom_module.permute import Permute +from backpack.custom_module.scale_module import ScaleModule from backpack.extensions.secondorder.base import SecondOrderBackpropExtension from backpack.extensions.secondorder.hbp import LossHessianStrategy @@ -57,6 +61,7 @@ convtranspose1d, convtranspose2d, convtranspose3d, + custom_module, dropout, flatten, linear, @@ -122,6 +127,10 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): LogSigmoid: activations.DiagGGNLogSigmoid(), ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), + Identity: custom_module.DiagGGNScaleModule(), + ActiveIdentity: custom_module.DiagGGNScaleModule(), + ScaleModule: custom_module.DiagGGNScaleModule(), + SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.DiagGGNRNN(), Permute: permute.DiagGGNPermute(), AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), @@ -133,6 +142,11 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): }, ) + def accumulate_backpropagated_quantities( + self, existing: Tensor, other: Tensor + ) -> Tensor: # noqa: D102 + return existing + other + class DiagGGNExact(DiagGGN): """Diagonal of the Generalized Gauss-Newton/Fisher. @@ -235,6 +249,10 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): LogSigmoid: activations.DiagGGNLogSigmoid(), ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), + Identity: custom_module.DiagGGNScaleModule(), + ActiveIdentity: custom_module.DiagGGNScaleModule(), + ScaleModule: custom_module.DiagGGNScaleModule(), + SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.BatchDiagGGNRNN(), Permute: permute.DiagGGNPermute(), AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), @@ -246,6 +264,11 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): }, ) + def accumulate_backpropagated_quantities( + self, existing: Tensor, other: Tensor + ) -> Tensor: # noqa: D102 + return existing + other + class BatchDiagGGNExact(BatchDiagGGN): """Individual diagonal of the Generalized Gauss-Newton/Fisher. diff --git a/backpack/extensions/secondorder/diag_ggn/custom_module.py b/backpack/extensions/secondorder/diag_ggn/custom_module.py new file mode 100644 index 000000000..293ed4281 --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/custom_module.py @@ -0,0 +1,20 @@ +"""DiagGGN extensions for backpack's custom modules.""" +from backpack.core.derivatives.scale_module import ScaleModuleDerivatives +from backpack.core.derivatives.sum_module import SumModuleDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNScaleModule(DiagGGNBaseModule): + """DiagGGN extension for ScaleModule.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=ScaleModuleDerivatives()) + + +class DiagGGNSumModule(DiagGGNBaseModule): + """DiagGGN extension for SumModule.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=SumModuleDerivatives()) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index 5e292385d..e51232283 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -9,6 +9,7 @@ TORCH_VERSION_AT_LEAST_1_9_0 = TORCH_VERSION >= packaging.version.parse("1.9.0") TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1") FULL_BACKWARD_HOOK: bool = TORCH_VERSION_AT_LEAST_1_9_0 +CONVERTER_AVAILABLE: bool = TORCH_VERSION_AT_LEAST_1_9_0 def exception_inside_backward_pass(error: Type[Exception]) -> Type[Exception]: diff --git a/backpack/utils/examples.py b/backpack/utils/examples.py index eb9392613..788241140 100644 --- a/backpack/utils/examples.py +++ b/backpack/utils/examples.py @@ -1,35 +1,117 @@ """Utility functions for examples.""" -import torch -import torchvision +from typing import Iterator, List, Tuple +from torch import Tensor, stack, zeros +from torch.nn import Module +from torch.nn.utils.convert_parameters import parameters_to_vector +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import MNIST +from torchvision.transforms import Compose, Normalize, ToTensor -def load_mnist_dataset(): - """Download and normalize MNIST training data.""" - mnist_dataset = torchvision.datasets.MNIST( +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list + + +def load_mnist_dataset() -> Dataset: + """Download and normalize MNIST training data. + + Returns: + Normalized MNIST dataset + """ + return MNIST( root="./data", train=True, - transform=torchvision.transforms.Compose( - [ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize((0.1307,), (0.3081,)), - ] - ), + transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]), download=True, ) - return mnist_dataset -def get_mnist_dataloader(batch_size=64, shuffle=True): - """Returns a dataloader for MNIST""" - return torch.utils.data.dataloader.DataLoader( - load_mnist_dataset(), - batch_size=batch_size, - shuffle=shuffle, - ) +def get_mnist_dataloader(batch_size: int = 64, shuffle: bool = True) -> DataLoader: + """Returns a dataloader for MNIST. + + Args: + batch_size: Mini-batch size. Default: ``64``. + shuffle: Randomly shuffle the data. Default: ``True``. + + Returns: + MNIST dataloader + """ + return DataLoader(load_mnist_dataset(), batch_size=batch_size, shuffle=shuffle) -def load_one_batch_mnist(batch_size=64, shuffle=True): - """Return a single batch (inputs, labels) of MNIST data.""" +def load_one_batch_mnist( + batch_size: int = 64, shuffle: bool = True +) -> Tuple[Tensor, Tensor]: + """Return a single mini-batch (inputs, labels) from MNIST. + + Args: + batch_size: Mini-batch size. Default: ``64``. + shuffle: Randomly shuffle the data. Default: ``True``. + + Returns: + A single batch (inputs, labels) from MNIST. + """ dataloader = get_mnist_dataloader(batch_size, shuffle) X, y = next(iter(dataloader)) + return X, y + + +def autograd_diag_ggn_exact( + X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None +) -> Tensor: + """Compute the generalized Gauss-Newton diagonal with ``torch.autograd``. + + Args: + X: Input to the model. + y: Labels. + model: The neural network. + loss_function: Loss function module. + idx: Indices for which the diagonal entries are computed. Default value ``None`` + computes the full diagonal. + + Returns: + Exact GGN diagonal (flattened and concatenated). + """ + diag_elements = [ + col[col_idx] + for col_idx, col in _autograd_ggn_exact_columns( + X, y, model, loss_function, idx=idx + ) + ] + + return stack(diag_elements) + + +def _autograd_ggn_exact_columns( + X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None +) -> Iterator[Tuple[int, Tensor]]: + """Yield exact generalized Gauss-Newton's columns computed with ``torch.autograd``. + + Args: + X: Input to the model. + y: Labels. + model: The neural network. + loss_function: Loss function module. + idx: Indices of columns that are computed. Default value ``None`` computes all + columns. + + Yields: + Tuple of column index and respective GGN column (flattened and concatenated). + """ + trainable_parameters = [p for p in model.parameters() if p.requires_grad] + D = sum(p.numel() for p in trainable_parameters) + + outputs = model(X) + loss = loss_function(outputs, y) + + idx = idx if idx is not None else list(range(D)) + + for d in idx: + e_d = zeros(D, device=loss.device, dtype=loss.dtype) + e_d[d] = 1.0 + e_d_list = vector_to_parameter_list(e_d, trainable_parameters) + + ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list) + + yield d, parameters_to_vector(ggn_d_list) diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index 7be68d96a..c28169eab 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -1,9 +1,13 @@ """Contains util function for classification of modules.""" - from torch.nn import Module, Sequential from torch.nn.modules.loss import _Loss +from backpack.custom_module.branching import Parallel, _Branch from backpack.custom_module.reduce_tuple import ReduceTuple +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 + +if TORCH_VERSION_AT_LEAST_1_9_0: + from torch.fx import GraphModule def is_loss(module: Module) -> bool: @@ -27,4 +31,7 @@ def is_no_op(module: Module) -> bool: Returns: whether module is no operation """ - return isinstance(module, (Sequential, ReduceTuple)) + no_op_modules = (Sequential, _Branch, Parallel, ReduceTuple) + if TORCH_VERSION_AT_LEAST_1_9_0: + no_op_modules += (GraphModule,) + return isinstance(module, no_op_modules) diff --git a/docs_src/examples/use_cases/example_first_order_resnet.py b/docs_src/examples/use_cases/example_first_order_resnet.py index 0a04b4ce3..a1a1075bc 100644 --- a/docs_src/examples/use_cases/example_first_order_resnet.py +++ b/docs_src/examples/use_cases/example_first_order_resnet.py @@ -1,101 +1,8 @@ r"""First order extensions with a ResNet ======================================== - """ # %% -# Let's get the imports, configuration and some helper functions out of the way first. - -import torch -import torch.nn.functional as F - -from backpack import backpack, extend -from backpack.extensions import BatchGrad -from backpack.utils.examples import load_one_batch_mnist - -BATCH_SIZE = 3 -torch.manual_seed(0) -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -def get_accuracy(output, targets): - """Helper function to print the accuracy""" - predictions = output.argmax(dim=1, keepdim=True).view_as(targets) - return predictions.eq(targets).float().mean().item() - - -x, y = load_one_batch_mnist(batch_size=BATCH_SIZE) -x, y = x.to(DEVICE), y.to(DEVICE) - - -# %% -# We can build a ResNet by extending :py:class:`torch.nn.Module`. -# As long as the layers with parameters -# (:py:class:`torch.nn.Conv2d` and :py:class:`torch.nn.Linear`) are -# ``nn`` modules, BackPACK can extend them, -# and this is all that is needed for first order extensions. -# We can rewrite the forward to implement the residual connection, -# and :py:func:`extend() ` the resulting model. - - -class MyFirstResNet(torch.nn.Module): - def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10): - super().__init__() - - self.conv1 = torch.nn.Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1) - self.conv2 = torch.nn.Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1) - self.linear1 = torch.nn.Linear(input_dim[0] * input_dim[1] * C_hid, output_dim) - if C_in == C_hid: - self.shortcut = torch.nn.Identity() - else: - self.shortcut = torch.nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1) - - def forward(self, x): - residual = self.shortcut(x) - x = self.conv2(F.relu(self.conv1(x))) - x = x + residual # don't use: x += residual - x = x.view(x.size(0), -1) - x = self.linear1(x) - return x - - -model = extend(MyFirstResNet()).to(DEVICE) - -# %% -# Using :py:class:`BatchGrad ` in a -# :py:class:`with backpack(...) ` block, -# we can access the individual gradients for each sample. -# -# The loss does not need to be extended in this case either, as it does not -# have model parameters and BackPACK does not need to know about it for -# first order extensions. This also means you can use any custom loss function. - -model.zero_grad() -loss = F.cross_entropy(model(x), y, reduction="sum") -with backpack(BatchGrad()): - loss.backward() - -print("{:<20} {:<30} {:<30}".format("Param", "grad", "grad (batch)")) -print("-" * 80) -for name, p in model.named_parameters(): - print( - "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape)) - ) - -# %% -# To check that everything works, let's compute one individual gradient with -# PyTorch (using a single sample in a forward and backward pass) -# and compare it with the one computed by BackPACK. - -sample_to_check = 1 -x_to_check = x[sample_to_check, :].unsqueeze(0) -y_to_check = y[sample_to_check].unsqueeze(0) - -model.zero_grad() -loss = F.cross_entropy(model(x_to_check), y_to_check) -loss.backward() - -print("Do the individual gradients match?") -for name, p in model.named_parameters(): - match = torch.allclose(p.grad_batch[sample_to_check, :], p.grad, atol=1e-7) - print("{:<20} {}".format(name, match)) +# This tutorial has moved. Click +# `here `_ +# to continue to its new location. diff --git a/docs_src/examples/use_cases/example_resnet_all_in_one.py b/docs_src/examples/use_cases/example_resnet_all_in_one.py new file mode 100644 index 000000000..c06483de3 --- /dev/null +++ b/docs_src/examples/use_cases/example_resnet_all_in_one.py @@ -0,0 +1,325 @@ +"""Residual networks +==================== +""" +# %% +# There are three different approaches to using BackPACK with ResNets. +# +# 1. :ref:`Custom ResNet`: (Only works for first-order extensions) Write your own model +# by defining its forward pass. Trainable parameters must be in modules known to +# BackPACK (e.g. :class:`torch.nn.Conv2d`, :class:`torch.nn.Linear`). +# +# 2. :ref:`Custom ResNet with BackPACK custom modules`: (Works for first- and second- +# order extensions) Build your ResNet with custom modules provided by BackPACK +# without overwriting the forward pass. This approach is useful if you want to +# understand how BackPACK handles ResNets, or if you think building a container +# module that implicitly defines the forward pass is more elegant than coding up +# a forward pass. +# +# 3. :ref:`Any ResNet with BackPACK's converter`: (Works for first- and second-order +# extensions) Convert your model into a BackPACK-compatible architecture. +# +# .. note:: +# ResNets are still an experimental feature. Always double-check your +# results, as done in this example! Open an issue if you encounter a bug to help +# us improve the support. +# +# Not all extensions support ResNets (yet). Please create a feature request in the +# repository if the extension you need is not supported. + +# %% +# Let's get the imports out of the way. + +from torch import ( + allclose, + cat, + cuda, + device, + int32, + linspace, + manual_seed, + rand, + rand_like, +) +from torch.nn import ( + Conv2d, + CrossEntropyLoss, + Flatten, + Identity, + Linear, + Module, + MSELoss, + ReLU, + Sequential, +) +from torch.nn.functional import cross_entropy, relu +from torchvision.models import resnet18 + +from backpack import backpack, extend +from backpack.custom_module.branching import ActiveIdentity, Parallel, SumModule +from backpack.custom_module.graph_utils import BackpackTracer +from backpack.extensions import BatchGrad, DiagGGNExact +from backpack.utils.examples import autograd_diag_ggn_exact, load_one_batch_mnist + +manual_seed(0) +DEVICE = device("cuda:0" if cuda.is_available() else "cpu") +x, y = load_one_batch_mnist(batch_size=32) +x, y = x.to(DEVICE), y.to(DEVICE) + + +# %% +# Custom ResNet +# ------------- +# We can build a ResNet by extending :py:class:`torch.nn.Module`. +# As long as the layers with parameters (:py:class:`torch.nn.Conv2d` +# and :py:class:`torch.nn.Linear`) are ``nn`` modules, BackPACK can extend them, +# and this is all that is needed for first-order extensions. +# We can rewrite the :code:`forward` to implement the residual connection, +# and :py:func:`extend() ` the resulting model. +# +# .. note:: +# Using in-place operations is not compatible with PyTorch's +# :meth:`torch.nn.Module.register_full_backward_hook`. Therefore, +# always use :code:`x = x + residual` instead of :code:`x += residual`. + + +class MyFirstResNet(Module): + def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10): + """Instantiate submodules that are used in the forward pass.""" + super().__init__() + + self.conv1 = Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1) + self.conv2 = Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1) + self.linear1 = Linear(input_dim[0] * input_dim[1] * C_hid, output_dim) + if C_in == C_hid: + self.shortcut = Identity() + else: + self.shortcut = Conv2d(C_in, C_hid, kernel_size=1, stride=1) + + def forward(self, x): + """Manual implementation of the forward pass.""" + residual = self.shortcut(x) + x = self.conv2(relu(self.conv1(x))) + x = x + residual # don't use: x += residual + x = x.flatten(start_dim=1) + x = self.linear1(x) + return x + + +model = extend(MyFirstResNet()).to(DEVICE) + +# %% +# The loss does not need to be extended in this case either, as it does not +# have model parameters and BackPACK does not need to know about it for +# first-order extensions. This also means you can use any custom loss function. +# +# Using :py:class:`BatchGrad ` in a +# :py:class:`with backpack(...) ` block, +# we can access the individual gradients for each sample. + +loss = cross_entropy(model(x), y, reduction="sum") + +with backpack(BatchGrad()): + loss.backward() + +for name, parameter in model.named_parameters(): + print(f"{name:>20}'s grad_batch shape: {parameter.grad_batch.shape}") + +# %% +# To check that everything works, let's compute one individual gradient with +# PyTorch (using a single sample in a forward and backward pass) +# and compare it with the one computed by BackPACK. + +sample_to_check = 1 +x_to_check = x[[sample_to_check]] +y_to_check = y[[sample_to_check]] + +model.zero_grad() +loss = cross_entropy(model(x_to_check), y_to_check) +loss.backward() + +print("Do the individual gradients match?") +for name, parameter in model.named_parameters(): + match = allclose(parameter.grad_batch[sample_to_check], parameter.grad, atol=1e-6) + print(f"{name:>20}: {match}") + if not match: + raise AssertionError("Individual gradients don't match!") + +# %% +# Custom ResNet with BackPACK custom modules +# ------------- +# Second-order extensions only work if every node in the computation graph is an +# ``nn`` module that can be extended by BackPACK. The above ResNet class +# :py:class:`MyFirstResNet` does not satisfy these conditions, because +# it implements the skip connection via :py:func:`torch.add` while overwriting the +# :py:meth:`forward() ` method. +# +# To build ResNets without overwriting the forward pass, BackPACK offers custom modules: +# +# 1. :py:class:`Parallel` is similar to +# :py:class:`torch.nn.Sequential`, but implements a container for a parallel sequence +# of modules (followed by an aggregation module), rather than a sequential one. +# +# 2. :py:class:`SumModule` is the module that takes the +# role of :py:func:`torch.add` in the previous example. It sums up multiple inputs. +# We will use it to merge the skip connection. +# +# 3. :py:class:`ActiveIdentity` acts like +# PyTorch's identity, but fixes the backward hook execution order by inserting a new +# node into the graph during a forward pass (for details see +# `this discussion `_). +# The problem is fixed for ``torch >= 1.9.0``, where it's safe to use +# :py:class:`torch.nn.Identity`. If you are on ``torch < 1.9.0``, you +# have to use :py:class:`ActiveIdentity`. +# +# With the above modules, we can build a simple ResNet as a container that implicitly +# defines the forward pass: + +C_in = 1 +C_hid = 2 +input_dim = (28, 28) +output_dim = 10 + +model = Sequential( + Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1), + ReLU(), + Parallel( # skip connection with ReLU-activated convolution + ActiveIdentity(), + Sequential( + Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1), + ReLU(), + ), + merge_module=SumModule(), + ), + Flatten(), + Linear(input_dim[0] * input_dim[1] * C_hid, output_dim), +) + +model = extend(model.to(DEVICE)) +loss_function = extend(CrossEntropyLoss(reduction="mean")).to(DEVICE) + + +# %% +# This ResNets supports BackPACK's second-order extensions: + +loss = loss_function(model(x), y) + +with backpack(DiagGGNExact()): + loss.backward() + +for name, parameter in model.named_parameters(): + print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}") + +diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()]) + +# %% +# Comparison with :py:mod:`torch.autograd`: +# +# .. note:: +# +# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation +# can be slow, depending on the number of parameters. To reduce run time, we only +# compare some elements of the diagonal. + +num_params = sum(p.numel() for p in model.parameters()) +num_to_compare = 10 +idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32) + +diag_ggn_exact_to_compare = autograd_diag_ggn_exact( + x, y, model, loss_function, idx=idx_to_compare +) + +print("Do the exact GGN diagonals match?") +for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): + match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6) + print(f"Diagonal entry {idx:>6}: {match}") + if not match: + raise AssertionError("Exact GGN diagonals don't match!") + +# %% +# Any ResNet with BackPACK's converter +# ------------- +# If you are not building a ResNet through custom modules but for instance want to +# use a prominent ResNet from :py:mod:`torchvision.models`, BackPACK offers a converter. +# It analyzes the model and tries to turn it into a compatible architecture. The result +# is a :py:class:`torch.fx.GraphModule` that exclusively consists of modules. +# +# Here, we demo the converter on :py:class:`resnet18 `. +# +# .. note:: +# +# :py:class:`resnet18 ` has to be in evaluation mode, +# because it contains batch normalization layers that are not supported in train +# mode by the second-order extension used in this example. +# +# Let's create the model, and convert it in the call to :py:func:`extend `: + +loss_function = extend(MSELoss().to(DEVICE)) +model = resnet18(num_classes=5).to(DEVICE).eval() + +# use BackPACK's converter to extend the model (turned off by default) +model = extend(model, use_converter=True) + +# %% +# To get an understanding what happened, we can inspect the model's graph with the +# following helper function: + + +def print_table(module: Module) -> None: + """Prints a table of the module. + + Args: + module: module to analyze + """ + graph = BackpackTracer().trace(module) + graph.print_tabular() + + +print_table(model) + +# %% +# Admittedly, the converted :py:class:`resnet18 `'s graph +# is quite large. Note however that it fully consists of modules (indicated by +# ``call_module`` in the first table column) such that BackPACK's hooks can +# successfully backpropagate additional information for its second-order extensions +# (first-order extensions work, too). +# +# Let's verify that second-order extensions are working: + +x = rand(4, 3, 7, 7, device=DEVICE) # (128, 3, 224, 224) +output = model(x) +y = rand_like(output) + +loss = loss_function(output, y) + +with backpack(DiagGGNExact()): + loss.backward() + +for name, parameter in model.named_parameters(): + print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}") + +diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()]) + +# %% +# Comparison with :py:mod:`torch.autograd`: +# +# .. note:: +# +# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation +# can be slow, depending on the number of parameters. To reduce run time, we only +# compare some elements of the diagonal. + +num_params = sum(p.numel() for p in model.parameters()) +num_to_compare = 10 +idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32) + +diag_ggn_exact_to_compare = autograd_diag_ggn_exact( + x, y, model, loss_function, idx=idx_to_compare +) + +print("Do the exact GGN diagonals match?") +for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): + match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6) + print(f"Diagonal entry {idx:>8}: {match}") + if not match: + raise AssertionError("Exact GGN diagonals don't match!") diff --git a/docs_src/rtd/good-to-know.rst b/docs_src/rtd/good-to-know.rst index 23911ea9e..8f7a777ad 100644 --- a/docs_src/rtd/good-to-know.rst +++ b/docs_src/rtd/good-to-know.rst @@ -74,16 +74,13 @@ for ResNets (see :ref:`this example `). Not (yet) supported models ---------------------------------- -The second-order extensions for BackPACK don't support (yet) residual networks, -and no extension support recurrent architectures. -We're working on how to handle those, as well as adding more -:ref:`layers `. +The second-order extensions for BackPACK partially support residual and +recurrent networks. We're working on how to handle those, as well as +adding more :ref:`layers `. Along those lines, some things that will (most likely) not work with BackPACK, but that we're trying to build support for: -- Reusing the same parameters or module multiple time in the computation graph. +- Reusing the same parameters multiple times in the computation graph. - For second order extensions, this also holds for any module, - whether or not they have parameters. This sadly mean that BackPACK can't compute the individual gradients or second-order information of a L2-regularized loss, for example. diff --git a/fully_documented.txt b/fully_documented.txt index 15b013b6a..f6018851c 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -13,6 +13,8 @@ backpack/core/derivatives/lstm.py backpack/core/derivatives/linear.py backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/core/derivatives/batchnorm_nd.py +backpack/core/derivatives/scale_module.py +backpack/core/derivatives/sum_module.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -44,6 +46,7 @@ backpack/extensions/secondorder/diag_ggn/rnn.py backpack/extensions/secondorder/diag_ggn/permute.py backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py +backpack/extensions/secondorder/diag_ggn/custom_module.py backpack/extensions/secondorder/diag_hessian/__init__.py backpack/extensions/secondorder/diag_hessian/conv1d.py backpack/extensions/secondorder/diag_hessian/conv2d.py @@ -58,6 +61,7 @@ backpack/utils/errors.py backpack/utils/__init__.py backpack/utils/module_classification.py backpack/utils/hooks.py +backpack/utils/examples.py test/extensions/automated_settings.py test/extensions/problem.py @@ -86,8 +90,10 @@ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py test/core/derivatives/batch_norm_settings.py +test/core/derivatives/scale_module_settings.py test/utils/evaluation_mode.py test/utils/skip_test.py test/utils/__init__.py +test/converter/ test/utils/test_subsampling.py test/custom_module/ diff --git a/setup.cfg b/setup.cfg index 32ae54570..82d6f1fe1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -77,6 +77,7 @@ docs = matplotlib sphinx-gallery memory_profiler + tabulate ############################################################################### # Development tool configurations # diff --git a/test/converter/__init__.py b/test/converter/__init__.py new file mode 100644 index 000000000..0699f9ffe --- /dev/null +++ b/test/converter/__init__.py @@ -0,0 +1 @@ +"""Contains tests for the converter and ResNets.""" diff --git a/test/converter/converter_cases.py b/test/converter/converter_cases.py new file mode 100644 index 000000000..67dc9f832 --- /dev/null +++ b/test/converter/converter_cases.py @@ -0,0 +1,161 @@ +"""Test cases for the converter. + +Network with resnet18 +Network with inplace activation +Network with parameter-free module used in multiple places +Network with flatten operation +Network with multiply operation +Network with add operation +""" +import abc +from typing import List, Type + +from torch import Tensor, flatten, rand +from torch.nn import Linear, Module, ReLU +from torchvision.models import resnet18, wide_resnet50_2 + + +class ConverterModule(Module, abc.ABC): + """Interface class for test modules for converter.""" + + @abc.abstractmethod + def input_fn(self) -> Tensor: + """Generate a fitting input for the module. + + Returns: + an input + """ + return + + +CONVERTER_MODULES: List[Type[ConverterModule]] = [] + + +class _ResNet18(ConverterModule): + def __init__(self): + super().__init__() + self.resnet18 = resnet18(num_classes=4).eval() + + def forward(self, x): + return self.resnet18(x) + + def input_fn(self) -> Tensor: + return rand(2, 3, 7, 7) + + +class _WideResNet50(ConverterModule): + def __init__(self): + super().__init__() + self.wide_resnet50_2 = wide_resnet50_2(num_classes=4).eval() + + def forward(self, x): + return self.wide_resnet50_2(x) + + def input_fn(self) -> Tensor: + return rand(2, 3, 7, 7) + + +class _InplaceActivation(ConverterModule): + def __init__(self): + self.in_dim = 3 + out_dim = 2 + super().__init__() + self.linear = Linear(self.in_dim, out_dim) + self.relu = ReLU(inplace=True) + self.linear2 = Linear(out_dim, out_dim) + + def forward(self, x): + x = self.linear(x) + x = self.relu(x) + x = self.linear2(x) + return x + + def input_fn(self) -> Tensor: + return rand(3, self.in_dim) + + +class _MultipleUsages(ConverterModule): + def __init__(self): + super().__init__() + self.in_dim = 3 + out_dim = 2 + self.linear = Linear(self.in_dim, out_dim) + self.relu = ReLU() + self.linear2 = Linear(out_dim, out_dim) + + def forward(self, x): + x = self.relu(x) + x = self.linear(x) + x = self.relu(x) + x = self.linear2(x) + x = self.relu(x) + return x + + def input_fn(self) -> Tensor: + return rand(3, self.in_dim) + + +class _FlattenNetwork(ConverterModule): + def __init__(self): + super().__init__() + self.in_dim = 4 + out_dim = 3 + self.linear = Linear(self.in_dim, out_dim) + + def forward(self, x): + x = self.linear(x) + x = flatten(x, 1) + return x + + def input_fn(self) -> Tensor: + return rand(3, 2, 2, self.in_dim) + + +class _Multiply(ConverterModule): + def __init__(self): + super().__init__() + self.in_dim = 4 + out_dim = 3 + self.linear = Linear(self.in_dim, out_dim) + + def forward(self, x): + x = x * 2.5 + x = self.linear(x) + x = 0.5 * x + return x + + def input_fn(self) -> Tensor: + return rand(2, self.in_dim) + + +class _Add(ConverterModule): + def __init__(self): + super().__init__() + self.in_dim = 3 + out_dim = 2 + self.linear = Linear(self.in_dim, self.in_dim) + self.linear1 = Linear(self.in_dim, out_dim) + self.linear2 = Linear(self.in_dim, out_dim) + self.relu = ReLU() + + def forward(self, x): + x = self.linear(x) + x1 = self.linear1(x) + x2 = self.linear2(x) + x = x1 + x2 + x = self.relu(x) + return x + + def input_fn(self) -> Tensor: + return rand(3, self.in_dim) + + +CONVERTER_MODULES += [ + _ResNet18, + _WideResNet50, + _InplaceActivation, + _MultipleUsages, + _FlattenNetwork, + _Multiply, + _Add, +] diff --git a/test/converter/resnet_cases.py b/test/converter/resnet_cases.py new file mode 100644 index 000000000..4d5e00de7 --- /dev/null +++ b/test/converter/resnet_cases.py @@ -0,0 +1,146 @@ +"""Contains example ResNets to be used in tests.""" +from torch import flatten, tensor +from torch.nn import ( + AdaptiveAvgPool2d, + BatchNorm2d, + Conv2d, + Linear, + MaxPool2d, + Module, + MSELoss, + ReLU, + Sequential, + Tanh, +) +from torchvision.models.resnet import BasicBlock, conv1x1 + + +class ResNet1(Module): + """Small ResNet.""" + + def __init__(self, in_dim: int = 2, out_dim: int = 10): + """Initialization. + + Args: + in_dim: input dimensions + out_dim: output dimensions + """ + super().__init__() + self.net = Sequential( + Linear(in_dim, out_dim), + Tanh(), + Linear(out_dim, out_dim), + Tanh(), + Linear(out_dim, in_dim), + ) + + def forward(self, input): + """Forward pass. One Euler step. + + Args: + input: input tensor + + Returns: + result + """ + x = self.net(input) + return input + x * 0.1 + + input_test = tensor([[1.0, 2.0]]) + target_test = tensor([[1.0, 1.0]]) + loss_test = MSELoss() + + +class ResNet2(Module): + """Replicates resnet18 but a lot smaller.""" + + num_classes: int = 3 + batch_size: int = 2 + picture_width: int = 7 + inplanes = 2 + + input_test = (batch_size, 3, picture_width, picture_width) + target_test = (batch_size, num_classes) + loss_test = MSELoss() + + def __init__(self): + """Initialization.""" + super().__init__() + self.inplanes = ResNet2.inplanes + + self.conv1 = Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = BatchNorm2d(self.inplanes) + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(BasicBlock, ResNet2.inplanes, 2) + self.layer2 = self._make_layer(BasicBlock, 2 * ResNet2.inplanes, 2, stride=2) + self.layer3 = self._make_layer(BasicBlock, 4 * ResNet2.inplanes, 2, stride=2) + self.avgpool = AdaptiveAvgPool2d((1, 1)) + self.fc = Linear(4 * ResNet2.inplanes, self.num_classes) + + def forward(self, x): + """Forward pass. + + Args: + x: input tensor + + Returns: + result + """ + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = flatten(x, 1) + x = self.fc(x) + + return x + + def _make_layer(self, block, planes, blocks, stride=1): + """Creates a concatenation of blocks in the ResNet. + + This function is similar to the one in torchvision/resnets. + https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html + + Args: + block: basic block to use (with one skip connection) + planes: number of parallel planes + blocks: number of sequential blocks + stride: factor between input and output planes + + Returns: + a sequence of blocks + """ + norm_layer = BatchNorm2d + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [ + block(self.inplanes, planes, stride, downsample, 1, 64, 1, norm_layer) + ] + self.inplanes = planes * block.expansion + layers += [ + block( + self.inplanes, + planes, + groups=1, + base_width=64, + dilation=1, + norm_layer=norm_layer, + ) + for _ in range(1, blocks) + ] + + return Sequential(*layers) diff --git a/test/converter/test_branching.py b/test/converter/test_branching.py new file mode 100644 index 000000000..4fd968878 --- /dev/null +++ b/test/converter/test_branching.py @@ -0,0 +1,145 @@ +"""This test demonstrates the custom modules in BackPACK for branching/ResNets. + +It is important for torch < 1.9.0 (no full backward hook), as the test fails without +ActiveIdentity (wrong execution order in backward hook). +For torch>=1.9.0 (full backward hook), all tests pass. + +Additionally, for torch>=1.9.0 there is a convenient option use_converter=True in extend(). + +TODO Delete after supporting torch >= 1.9.0 (full backward hook and converter). +""" +from contextlib import nullcontext +from test.automated_test import check_sizes_and_values +from test.core.derivatives.utils import classification_targets +from typing import Callable, Tuple + +from pytest import mark, raises +from torch import Tensor, cat, manual_seed, rand +from torch.nn import ( + CrossEntropyLoss, + Identity, + Linear, + Module, + ReLU, + Sequential, + Sigmoid, +) + +from backpack import backpack, extend, extensions +from backpack.custom_module.branching import ActiveIdentity, Parallel +from backpack.utils import FULL_BACKWARD_HOOK, exception_inside_backward_pass +from backpack.utils.examples import autograd_diag_ggn_exact + + +def setup( + apply_extend: bool = False, active_identity: bool = True +) -> Tuple[Tensor, Tensor, Module, Module]: + """Set seed. Generate and return inputs, labels, model and loss function. + + A simple ResNet using the ``Parallel`` convenience module around the ``Branch`` and + ``SumModule`` modules to handle branching. + + Args: + active_identity: Whether the identity function should create a new node + in the computation graph. + apply_extend: Whether model and loss function should be extended. + + Returns: + X, y, model, loss_function + """ + manual_seed(0) + + N = 7 + + in_features = 10 + hidden_features = 5 + out_features = 3 + + X = rand((N, in_features)) + y = classification_targets((N,), out_features) + + identity = ActiveIdentity() if active_identity else Identity() + + model = Sequential( + Linear(in_features, hidden_features), + ReLU(), + # skip connection + Parallel( + identity, + Linear(hidden_features, hidden_features), + ), + # end of skip connection + Sigmoid(), + Linear(hidden_features, out_features), + ) + loss_function = CrossEntropyLoss(reduction="mean") + + if apply_extend: + model = extend(model, debug=True) + loss_function = extend(loss_function, debug=True) + + return X, y, model, loss_function + + +def backpack_diag_ggn_exact( + X: Tensor, y: Tensor, model: Module, loss_function: Module +) -> Tensor: + """Compute the generalized Gauss-Newton diagonal via BackPACK. + + Args: + X: data + y: target + model: model + loss_function: loss function + + Returns: + diag_ggn_exact + """ + outputs = model(X) + loss = loss_function(outputs, y) + + with backpack(extensions.DiagGGNExact(), debug=True): + loss.backward() + + return cat([p.diag_ggn_exact.flatten() for p in model.parameters()]) + + +SETUPS = [setup] +SETUPS_IDS = ["simple-resnet"] + + +@mark.parametrize("setup_fn", SETUPS, ids=SETUPS_IDS) +def test_diag_ggn_exact_active_identity(setup_fn: Callable) -> None: + """Compare BackPACK's diagonal GGN of a ResNet with autograd. + + Args: + setup_fn: setup function + """ + X, y, model, loss_function = setup_fn() + autograd_result = autograd_diag_ggn_exact(X, y, model, loss_function) + + X, y, model, loss_function = setup_fn(apply_extend=True) + backpack_result = backpack_diag_ggn_exact(X, y, model, loss_function) + + check_sizes_and_values(autograd_result, backpack_result) + + +@mark.parametrize("setup_fn", SETUPS, ids=SETUPS_IDS) +def test_diag_ggn_exact_nn_identity_fails(setup_fn: Callable) -> None: + """``torch.nn.Identity`` does not create a node and messes up backward hooks. + + However, it works fine if using full backward hook. (torch >= 1.9.0) + + Args: + setup_fn: setup function + """ + X, y, model, loss_function = setup_fn(active_identity=False) + autograd_result = autograd_diag_ggn_exact(X, y, model, loss_function) + + X, y, model, loss_function = setup_fn(apply_extend=True, active_identity=False) + with nullcontext() if FULL_BACKWARD_HOOK else raises( + exception_inside_backward_pass(AssertionError) + ): + backpack_result = backpack_diag_ggn_exact(X, y, model, loss_function) + + check_sizes_and_values(autograd_result, backpack_result) diff --git a/test/converter/test_converter.py b/test/converter/test_converter.py new file mode 100644 index 000000000..66d390474 --- /dev/null +++ b/test/converter/test_converter.py @@ -0,0 +1,79 @@ +"""Tests converter. + +- whether converted network is equivalent to original network +- whether DiagGGN runs without errors on new network +""" +from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule +from test.utils.skip_test import skip_pytorch_below_1_9_0 +from typing import Tuple + +from pytest import fixture +from torch import Tensor, allclose, cat, int32, linspace, manual_seed, rand_like +from torch.nn import Module, MSELoss + +from backpack import backpack, extend +from backpack.extensions import DiagGGNExact +from backpack.utils.examples import autograd_diag_ggn_exact + + +@fixture( + params=CONVERTER_MODULES, + ids=[str(model_class) for model_class in CONVERTER_MODULES], +) +def model_and_input(request) -> Tuple[Module, Tensor]: + """Yield ResNet model and an input to it. + + Args: + request: pytest request + + Yields: + model and input + """ + manual_seed(0) + skip_pytorch_below_1_9_0() + model: ConverterModule = request.param() + inputs: Tensor = model.input_fn() + inputs.requires_grad = True + yield model, inputs + del model, inputs + + +def test_network_diag_ggn(model_and_input): + """Test whether the given module can compute diag_ggn. + + This test is placed here, because some models are too big to run with PyTorch. + Thus, a full diag_ggn comparison with PyTorch is impossible. + This test just checks whether it runs on BackPACK without errors. + Additionally, it checks whether the forward pass is identical to the original model. + Finally, a small number of elements of DiagGGN are compared. + + Args: + model_and_input: module to test + """ + model_original, x = model_and_input + output_compare = model_original(x) + y = rand_like(output_compare) + + num_params = sum(p.numel() for p in model_original.parameters()) + num_to_compare = 10 + idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32) + diag_ggn_exact_to_compare = autograd_diag_ggn_exact( + x, y, model_original, MSELoss(), idx=idx_to_compare + ) + + model_extended = extend(model_original, use_converter=True, debug=True) + output = model_extended(x) + + assert allclose(output, output_compare) + + loss = extend(MSELoss())(output, y) + + with backpack(DiagGGNExact()): + loss.backward() + + diag_ggn_exact_vector = cat( + [p.diag_ggn_exact.flatten() for p in model_extended.parameters()] + ) + + for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): + assert allclose(element, diag_ggn_exact_vector[idx]) diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 28fdd6b9b..552c4a1de 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -21,6 +21,7 @@ ConvTranspose3d, CrossEntropyLoss, Dropout, + Identity, LeakyReLU, Linear, LogSigmoid, @@ -63,11 +64,15 @@ from backpack.core.derivatives.permute import PermuteDerivatives from backpack.core.derivatives.relu import ReLUDerivatives from backpack.core.derivatives.rnn import RNNDerivatives +from backpack.core.derivatives.scale_module import ScaleModuleDerivatives from backpack.core.derivatives.selu import SELUDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives +from backpack.core.derivatives.sum_module import SumModuleDerivatives from backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives +from backpack.custom_module.branching import ActiveIdentity, SumModule from backpack.custom_module.permute import Permute +from backpack.custom_module.scale_module import ScaleModule derivatives_for = { Linear: LinearDerivatives, @@ -103,4 +108,8 @@ BatchNorm1d: BatchNormNdDerivatives, BatchNorm2d: BatchNormNdDerivatives, BatchNorm3d: BatchNormNdDerivatives, + ScaleModule: ScaleModuleDerivatives, + ActiveIdentity: ScaleModuleDerivatives, + Identity: ScaleModuleDerivatives, + SumModule: SumModuleDerivatives, } diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 6cb2f3e69..ae477b0d0 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -16,6 +16,7 @@ from test.core.derivatives.permute_settings import PERMUTE_SETTINGS from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS +from test.core.derivatives.scale_module_settings import SCALE_MODULE_SETTINGS from test.core.derivatives.settings import SETTINGS from test.utils.skip_test import ( skip_adaptive_avg_pool3d_cuda, @@ -54,6 +55,9 @@ BATCH_NORM_PROBLEMS = make_test_problems(BATCH_NORM_SETTINGS) BATCH_NORM_IDS = [problem.make_id() for problem in BATCH_NORM_PROBLEMS] +SCALE_MODULE_PROBLEMS = make_test_problems(SCALE_MODULE_SETTINGS) +SCALE_MODULE_IDS = [problem.make_id() for problem in SCALE_MODULE_PROBLEMS] + SUBSAMPLINGS = [None, [0, 0], [2, 0]] SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] @@ -123,8 +127,12 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: @mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize( "problem", - NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + BATCH_NORM_PROBLEMS, - ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS, + NO_LOSS_PROBLEMS + + RNN_PROBLEMS + + PERMUTE_PROBLEMS + + BATCH_NORM_PROBLEMS + + SCALE_MODULE_PROBLEMS, + ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS + SCALE_MODULE_IDS, ) def test_jac_t_mat_prod( problem: DerivativesTestProblem, diff --git a/test/core/derivatives/scale_module_settings.py b/test/core/derivatives/scale_module_settings.py new file mode 100644 index 000000000..50de81e53 --- /dev/null +++ b/test/core/derivatives/scale_module_settings.py @@ -0,0 +1,29 @@ +"""Test settings for ScaleModule derivatives.""" +from torch import rand +from torch.nn import Identity + +from backpack.custom_module.branching import ActiveIdentity +from backpack.custom_module.scale_module import ScaleModule + +SCALE_MODULE_SETTINGS = [ + { + "module_fn": lambda: ScaleModule(), + "input_fn": lambda: rand(3, 4, 2), + }, + { + "module_fn": lambda: ScaleModule(0.3), + "input_fn": lambda: rand(3, 2), + }, + { + "module_fn": lambda: ScaleModule(5.7), + "input_fn": lambda: rand(2, 3), + }, + { + "module_fn": lambda: ActiveIdentity(), + "input_fn": lambda: rand(3, 2, 4), + }, + { + "module_fn": lambda: Identity(), + "input_fn": lambda: rand(3, 1, 2), + }, +] diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 99d46914b..3eec969f2 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -42,9 +42,14 @@ Sequential, Sigmoid, ) +from torchvision.models import resnet18 from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple +from backpack.utils import CONVERTER_AVAILABLE + +if CONVERTER_AVAILABLE: + from backpack import convert_module_to_backpack FIRSTORDER_SETTINGS = [] @@ -281,3 +286,18 @@ "target_fn": lambda: regression_targets((8, 3 * 5)), }, ] +############################################################################### +# test setting: torchvision resnet # +############################################################################### +if CONVERTER_AVAILABLE: + FIRSTORDER_SETTINGS += [ + { + "input_fn": lambda: rand(2, 3, 7, 7), + "module_fn": lambda: convert_module_to_backpack( + resnet18(num_classes=4).eval(), True + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 4)), + "id_prefix": "resnet18", + }, + ] diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 77232c0f2..24125e681 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -182,12 +182,15 @@ def __get_reduction_factor(loss: Tensor, unreduced_loss: Tensor) -> float: mean_loss = unreduced_loss.flatten().mean() sum_loss = unreduced_loss.flatten().sum() if torch.allclose(mean_loss, sum_loss): - raise RuntimeError( - "Cannot determine reduction factor. ", - "Results from 'mean' and 'sum' reduction are identical. ", - f"'mean': {mean_loss}, 'sum': {sum_loss}", - ) - if torch.allclose(loss, mean_loss): + if unreduced_loss.numel() == 1 and torch.allclose(loss, sum_loss): + factor = 1.0 + else: + raise RuntimeError( + "Cannot determine reduction factor. ", + "Results from 'mean' and 'sum' reduction are identical. ", + f"'mean': {mean_loss}, 'sum': {sum_loss}", + ) + elif torch.allclose(loss, mean_loss): factor = 1.0 / unreduced_loss.numel() elif torch.allclose(loss, sum_loss): factor = 1.0 diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 3a75f207f..7eb0801ac 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -9,7 +9,8 @@ Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. Additional local cases can be defined here through ``LOCAL_SETTINGS``. """ -from test.core.derivatives.utils import regression_targets +from test.converter.resnet_cases import ResNet1, ResNet2 +from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS from test.utils.evaluation_mode import initialize_training_false_recursive @@ -22,15 +23,25 @@ BatchNorm1d, BatchNorm2d, BatchNorm3d, + Conv2d, + CrossEntropyLoss, Flatten, Linear, + MaxPool2d, MSELoss, ReLU, Sequential, + Sigmoid, ) +from backpack.custom_module import branching +from backpack.custom_module.branching import ActiveIdentity, Parallel from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple +from backpack.utils import CONVERTER_AVAILABLE + +if CONVERTER_AVAILABLE: + from backpack import convert_module_to_backpack SHARED_SETTINGS = SECONDORDER_SETTINGS LOCAL_SETTINGS = [] @@ -125,4 +136,101 @@ "target_fn": lambda: regression_targets((3, 4 * 1 * 3 * 3)), }, ] +############################################################################### +# Branched models # +############################################################################### +LOCAL_SETTINGS += [ + { + "input_fn": lambda: rand(3, 10), + "module_fn": lambda: Sequential( + Linear(10, 5), + ReLU(), + # skip connection + Parallel( + ActiveIdentity(), + Linear(5, 5), + ), + # end of skip connection + Sigmoid(), + Linear(5, 4), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((3,), 4), + "id_prefix": "branching-linear", + }, + { + "input_fn": lambda: rand(4, 2, 6, 6), + "module_fn": lambda: Sequential( + Conv2d(2, 3, kernel_size=3, stride=1, padding=1), + ReLU(), + # skip connection + Parallel( + ActiveIdentity(), + Sequential( + Conv2d(3, 5, kernel_size=3, stride=1, padding=1), + ReLU(), + Conv2d(5, 3, kernel_size=3, stride=1, padding=1), + ), + ), + # end of skip connection + MaxPool2d(kernel_size=3, stride=2), + Flatten(), + Linear(12, 5), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((4,), 5), + "id_prefix": "branching-convolution", + }, + { + "input_fn": lambda: rand(4, 3, 6, 6), + "module_fn": lambda: Sequential( + Conv2d(3, 2, kernel_size=3, stride=1, padding=1), + ReLU(), + # skip connection + Parallel( + ActiveIdentity(), + Sequential( + Conv2d(2, 4, kernel_size=3, stride=1, padding=1), + Sigmoid(), + Conv2d(4, 2, kernel_size=3, stride=1, padding=1), + branching.Parallel( + branching.ActiveIdentity(), + Sequential( + Conv2d(2, 4, kernel_size=3, stride=1, padding=1), + ReLU(), + Conv2d(4, 2, kernel_size=3, stride=1, padding=1), + ), + ), + ), + ), + # end of skip connection + MaxPool2d(kernel_size=3, stride=2), + Flatten(), + Linear(8, 5), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((4,), 5), + "id_prefix": "nested-branching-convolution", + }, +] +############################################################################### +# Branched models - converter # +############################################################################### +if CONVERTER_AVAILABLE: + LOCAL_SETTINGS += [ + { + "input_fn": lambda: ResNet1.input_test, + "module_fn": lambda: convert_module_to_backpack(ResNet1(), True), + "loss_function_fn": lambda: ResNet1.loss_test, + "target_fn": lambda: ResNet1.target_test, + "id_prefix": "ResNet1", + }, + { + "input_fn": lambda: rand(ResNet2.input_test), + "module_fn": lambda: convert_module_to_backpack(ResNet2().eval(), True), + "loss_function_fn": lambda: ResNet2.loss_test, + "target_fn": lambda: rand(ResNet2.target_test), + "id_prefix": "ResNet2", + }, + ] DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index 34257e946..47374d779 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -7,7 +7,7 @@ from pytest import skip from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0, TORCH_VERSION_AT_LEAST_1_9_1 def skip_adaptive_avg_pool3d_cuda(request) -> None: @@ -59,6 +59,12 @@ def skip_subsampling_conflict( skip("Not enough samples.") +def skip_pytorch_below_1_9_0() -> None: + """Skip test if pytorch version is below 1.9.0.""" + if not TORCH_VERSION_AT_LEAST_1_9_0: + skip("Test needs PyTorch>=1.9.0") + + def skip_large_parameters( problem: ExtensionsTestProblem, max_num_params: int = 1000 ) -> None: From e871cfa72a19b6c40b636a4f568b5853ebaa23a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 2 Sep 2021 15:15:39 +0200 Subject: [PATCH 43/54] [ADD] Interface to check module extension hyperparameter support (#206) * [REFACTOR] unify interface for checking parameters in module extension * [DOC] Small polish Co-authored-by: Felix Dangel --- .../firstorder/batch_grad/batchnorm_nd.py | 15 ++++++++++-- .../firstorder/batch_l2_grad/batchnorm_nd.py | 15 ++++++++++-- .../firstorder/gradient/batchnorm_nd.py | 15 ++++++++++-- .../sum_grad_squared/batchnorm_nd.py | 15 ++++++++++-- .../firstorder/variance/batchnorm_nd.py | 15 ++++++++++-- backpack/extensions/module_extension.py | 20 ++++++++++++++++ .../secondorder/diag_ggn/batchnorm_nd.py | 24 +++++++++++++++---- 7 files changed, 105 insertions(+), 14 deletions(-) diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py index ea13772cb..83759b0ae 100644 --- a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py @@ -1,5 +1,11 @@ """Contains grad_batch extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase from backpack.utils.errors import batch_norm_raise_error_if_train @@ -13,6 +19,11 @@ def __init__(self): derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module, raise_error=False) - super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py index 5d6d09bf6..9e1941804 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py @@ -1,5 +1,11 @@ """Contains batch_l2 extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base from backpack.utils.errors import batch_norm_raise_error_if_train @@ -11,6 +17,11 @@ def __init__(self): """Initialization.""" super().__init__(["weight", "bias"], BatchNormNdDerivatives()) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/gradient/batchnorm_nd.py b/backpack/extensions/firstorder/gradient/batchnorm_nd.py index 00322852a..5bacc2ad6 100644 --- a/backpack/extensions/firstorder/gradient/batchnorm_nd.py +++ b/backpack/extensions/firstorder/gradient/batchnorm_nd.py @@ -1,5 +1,11 @@ """Gradient extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.utils.errors import batch_norm_raise_error_if_train from .base import GradBaseModule @@ -14,6 +20,11 @@ def __init__(self): derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] ) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py index 891ca2ee3..9ad99de07 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py +++ b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py @@ -1,5 +1,11 @@ """SGS extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase from backpack.utils.errors import batch_norm_raise_error_if_train @@ -11,6 +17,11 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"]) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/firstorder/variance/batchnorm_nd.py b/backpack/extensions/firstorder/variance/batchnorm_nd.py index 00d516515..d2b8512e5 100644 --- a/backpack/extensions/firstorder/variance/batchnorm_nd.py +++ b/backpack/extensions/firstorder/variance/batchnorm_nd.py @@ -1,4 +1,10 @@ """Variance extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.firstorder.gradient.batchnorm_nd import GradBatchNormNd from backpack.extensions.firstorder.sum_grad_squared.batchnorm_nd import SGSBatchNormNd from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule @@ -12,6 +18,11 @@ def __init__(self): """Initialization.""" super().__init__(["weight", "bias"], GradBatchNormNd(), SGSBatchNormNd()) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index 3fd9070f0..d530788a7 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -94,6 +94,7 @@ def __call__( or if a backpropagated quantity is expected, but there is None and the old backward hook is used and the module is not a Flatten no op. """ + self.check_hyperparameters_module_extension(extension, module, g_inp, g_out) delete_old_quantities = not self.__should_retain_backproped_quantities(module) bp_quantity = self.__get_backproped_quantity( extension, module.output, delete_old_quantities @@ -245,3 +246,22 @@ def __save_value_on_parameter( param_str: parameter name """ setattr(getattr(module, param_str), extension.savefield, value) + + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: + """Check whether the current module is supported by the extension. + + Child classes can override this method. + + Args: + ext: current extension + module: module + g_inp: input gradients + g_out: output gradients + """ + pass diff --git a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py index 942be14ad..c0aa7c29b 100644 --- a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py +++ b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py @@ -1,5 +1,11 @@ """DiagGGN extension for BatchNorm.""" +from typing import Tuple, Union + +from torch import Tensor +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d + from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule from backpack.utils.errors import batch_norm_raise_error_if_train @@ -11,9 +17,14 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=True) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) class BatchDiagGGNBatchNormNd(DiagGGNBaseModule): @@ -23,6 +34,11 @@ def __init__(self): """Initialization.""" super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=False) - def __call__(self, ext, module, g_inp, g_out): # noqa: D102 + def check_hyperparameters_module_extension( + self, + ext: BackpropExtension, + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + ) -> None: # noqa: D102 batch_norm_raise_error_if_train(module) - super().__call__(ext, module, g_inp, g_out) From b0929db84f8394963681e14903bce3461cb55c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Mon, 6 Sep 2021 15:54:16 +0200 Subject: [PATCH 44/54] [ADD] LSTM: 1st-order + `DiagGGN{MC,Exact}`, `batch_first` (#215) * [ADD] LSTM: batch_first, extensions, tests * [REF] Remove dead code, clarify internal shape convention * [REF] Less linebreaks in einsum arguments * [REF] Reduce magic numbers * [REF] Use `get_batch_axis` from sub-sampling * [REF] simplify batch axis decision * [DOC] Polish * [REF] Use `torch.Tensor.transpose` * [FIX] Syntax error Co-authored-by: Felix Dangel --- backpack/core/derivatives/lstm.py | 153 +++++++++--------- .../firstorder/batch_grad/__init__.py | 2 + .../firstorder/batch_grad/batch_grad_base.py | 8 +- .../extensions/firstorder/batch_grad/rnn.py | 12 ++ .../firstorder/batch_l2_grad/__init__.py | 2 + .../firstorder/batch_l2_grad/rnn.py | 12 ++ .../extensions/firstorder/gradient/rnn.py | 12 ++ .../firstorder/sum_grad_squared/__init__.py | 2 + .../firstorder/sum_grad_squared/rnn.py | 12 ++ .../firstorder/variance/__init__.py | 2 + .../extensions/firstorder/variance/rnn.py | 19 ++- .../firstorder/variance/variance_base.py | 7 +- .../secondorder/diag_ggn/__init__.py | 3 + .../extensions/secondorder/diag_ggn/rnn.py | 25 +++ test/core/derivatives/lstm_settings.py | 8 +- .../firstorder/firstorder_settings.py | 11 ++ .../secondorder/diag_ggn/diag_ggn_settings.py | 11 ++ 17 files changed, 207 insertions(+), 94 deletions(-) diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 173e9633d..5734d719f 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -6,7 +6,7 @@ from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import TORCH_VERSION_AT_LEAST_1_8_0 -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample class LSTMDerivatives(BaseParameterDerivatives): @@ -27,6 +27,10 @@ class LSTMDerivatives(BaseParameterDerivatives): ifgo[t] = tanh(ifgo_tilde[t]) for g c[t] = f[t] c[t-1] + i[t] g[t] h[t] = o[t] tanh(c[t]) + + Note: + For ``batch_first=True``, most of the internal tensors (e.g. those from + the manual forward pass) are kept with time axis first. """ @staticmethod @@ -43,8 +47,6 @@ def _check_parameters(module: LSTM) -> None: raise NotImplementedError("only num_layers = 1 is supported") if module.bias is not True: raise NotImplementedError("only bias = True is supported") - if module.batch_first is not False: - raise NotImplementedError("only batch_first = False is supported") if module.dropout != 0: raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: @@ -56,7 +58,7 @@ def _check_parameters(module: LSTM) -> None: @staticmethod def _forward_pass( module: LSTM, mat: Tensor, subsampling: List[int] = None - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor]: """This performs an additional forward pass and returns the hidden variables. This is important because the PyTorch implementation does not grant access to @@ -70,10 +72,12 @@ def _forward_pass( subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: - ifgo, c, c_tanh, h + ifgo, c, c_tanh (all in format ``[T, N, ...]``) """ - T: int = mat.shape[1] - N: int = mat.shape[2] + free_axis = 1 + N_axis, T_axis = LSTMDerivatives.get_batch_and_time_axes(module) + T: int = mat.shape[T_axis + free_axis] + N: int = mat.shape[N_axis + free_axis] H: int = module.hidden_size H0: int = 0 * H H1: int = 1 * H @@ -84,12 +88,15 @@ def _forward_pass( ifgo: Tensor = zeros(T, N, 4 * H, device=mat.device, dtype=mat.dtype) c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) - h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) - N_axis = get_batch_axis(module, "input0") input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling) output = subsample(module.output, dim=N_axis, subsampling=subsampling) + # use [T, N, ...] format + if module.batch_first: + input0 = input0.transpose(N_axis, T_axis) + output = output.transpose(N_axis, T_axis) + for t in range(T): ifgo[t] = ( einsum("hi,ni->nh", module.weight_ih_l0, input0[t]) @@ -106,17 +113,18 @@ def _forward_pass( if t != 0: c[t] += ifgo[t, :, H1:H2] * c[t - 1] c_tanh[t] = tanh(c[t]) - h[t] = ifgo[t, :, H3:H4] * c_tanh[t] - return ifgo, c, c_tanh, h + return ifgo, c, c_tanh @classmethod def _ifgo_jac_t_mat_prod( cls, module: LSTM, mat: Tensor, subsampling: List[int] = None ) -> Tensor: + free_axis = 1 + N_axis, T_axis = cls.get_batch_and_time_axes(module) V: int = mat.shape[0] - T: int = mat.shape[1] - N: int = mat.shape[2] + T: int = mat.shape[T_axis + free_axis] + N: int = mat.shape[N_axis + free_axis] H: int = module.hidden_size H0: int = 0 * H H1: int = 1 * H @@ -124,7 +132,7 @@ def _ifgo_jac_t_mat_prod( H3: int = 3 * H H4: int = 4 * H - ifgo, c, c_tanh, _ = cls._forward_pass(module, mat, subsampling=subsampling) + ifgo, c, c_tanh = cls._forward_pass(module, mat, subsampling=subsampling) # backward pass H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) @@ -133,28 +141,20 @@ def _ifgo_jac_t_mat_prod( IFGO_prod: Tensor = zeros(V, T, N, 4 * H, device=mat.device, dtype=mat.dtype) for t in reversed(range(T)): # jac_t_mat_prod until node h - H_prod_t[:] = mat[:, t] + H_prod_t[:] = mat[(slice(None),) * (T_axis + 1) + (t,)] if t != (T - 1): H_prod_t += einsum( - "vnh,hg->vng", - IFGO_prod[:, t + 1], - module.weight_hh_l0, + "vnh,hg->vng", IFGO_prod[:, t + 1], module.weight_hh_l0 ) # C_prod_t = jac_t_mat_prod until node c if t != (T - 1): C_prod_old[:] = C_prod_t C_prod_t[:] = einsum( - "vnh,nh->vnh", - H_prod_t, - ifgo[t, :, H3:H4] * (1 - c_tanh[t] ** 2), + "vnh,nh->vnh", H_prod_t, ifgo[t, :, H3:H4] * (1 - c_tanh[t] ** 2) ) if t != (T - 1): - C_prod_t += einsum( - "vnh,nh->vnh", - C_prod_old, - ifgo[t + 1, :, H1:H2], - ) + C_prod_t += einsum("vnh,nh->vnh", C_prod_old, ifgo[t + 1, :, H1:H2]) IFGO_prod[:, t, :, H3:H4] = einsum( "vnh,nh->vnh", @@ -190,9 +190,11 @@ def _jac_mat_prod( mat: Tensor, sum_batch: bool = True, ) -> Tensor: + free_axis = 1 + N_axis, T_axis = self.get_batch_and_time_axes(module) V: int = mat.shape[0] - T: int = mat.shape[1] - N: int = mat.shape[2] + T: int = mat.shape[T_axis + free_axis] + N: int = mat.shape[N_axis + free_axis] H: int = module.hidden_size H0: int = 0 * H H1: int = 1 * H @@ -200,7 +202,7 @@ def _jac_mat_prod( H3: int = 3 * H H4: int = 4 * H - ifgo, c, c_tanh, h = self._forward_pass(module, mat) + ifgo, c, c_tanh = self._forward_pass(module, mat) H_prod: Tensor = zeros(V, T, N, H, device=mat.device, dtype=mat.dtype) C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) @@ -211,13 +213,11 @@ def _jac_mat_prod( IFGO_prod_t[:] = einsum( "hi,vni->vnh", module.weight_ih_l0, - mat[:, t], + mat[(slice(None),) * (T_axis + free_axis) + (t,)], ) if t != 0: IFGO_prod_t[:] += einsum( - "hg,vng->vnh", - module.weight_hh_l0, - H_prod[:, t - 1], + "hg,vng->vnh", module.weight_hh_l0, H_prod[:, t - 1] ) IFGO_prod_t[:, :, H0:H2] = einsum( "vnh,nh->vnh", @@ -238,42 +238,24 @@ def _jac_mat_prod( # product until node c if t >= 1: C_prod_old[:] = C_prod_t - C_prod_t[:] = ( - einsum( - "vnh,nh->vnh", - IFGO_prod_t[:, :, H0:H1], - ifgo[t, :, H2:H3], - ) - + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[t, :, H0:H1]) - ) + C_prod_t[:] = einsum( + "vnh,nh->vnh", IFGO_prod_t[:, :, H0:H1], ifgo[t, :, H2:H3] + ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[t, :, H0:H1]) if t >= 1: C_prod_t += einsum( - "vnh,nh->vnh", - C_prod_old, - ifgo[t, :, H1:H2], - ) + einsum( - "vnh,nh->vnh", - IFGO_prod_t[:, :, H1:H2], - c[t - 1], - ) + "vnh,nh->vnh", C_prod_old, ifgo[t, :, H1:H2] + ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H1:H2], c[t - 1]) # product until node c_tanh - C_tanh_prod_t[:] = einsum( - "vnh,nh->vnh", - C_prod_t, - 1 - c_tanh[t] ** 2, - ) + C_tanh_prod_t[:] = einsum("vnh,nh->vnh", C_prod_t, 1 - c_tanh[t] ** 2) # product until node h H_prod[:, t] = einsum( - "vnh,nh->vnh", - IFGO_prod_t[:, :, H3:H4], - c_tanh[t], - ) + einsum( - "vnh,nh->vnh", - C_tanh_prod_t, - ifgo[t, :, H3:H4], - ) + "vnh,nh->vnh", IFGO_prod_t[:, :, H3:H4], c_tanh[t] + ) + einsum("vnh,nh->vnh", C_tanh_prod_t, ifgo[t, :, H3:H4]) + + if module.batch_first: + H_prod = H_prod.transpose(T_axis + free_axis, N_axis + free_axis) return H_prod @@ -291,10 +273,11 @@ def _jac_t_mat_prod( module, mat, subsampling=subsampling ) + N_axis, _ = self.get_batch_and_time_axes(module) + batch_time_str = "nt" if N_axis == 0 else "tn" + X_prod: Tensor = einsum( - "vtnh,hi->vtni", - IFGO_prod, - module.weight_ih_l0, + f"vtnh,hi->v{batch_time_str}i", IFGO_prod, module.weight_ih_l0 ) return X_prod @@ -343,14 +326,13 @@ def _weight_ih_l0_jac_t_mat_prod( module, mat, subsampling=subsampling ) + N_axis, _ = self.get_batch_and_time_axes(module) + batch_time_str = "nt" if N_axis == 0 else "tn" + return einsum( - f"vtnh,tni->v{'' if sum_batch else 'n'}hi", + f"vtnh,{batch_time_str}i->v{'' if sum_batch else 'n'}hi", IFGO_prod, - subsample( - module.input0, - dim=get_batch_axis(module, "input0"), - subsampling=subsampling, - ), + subsample(module.input0, dim=N_axis, subsampling=subsampling), ) def _weight_hh_l0_jac_t_mat_prod( @@ -364,25 +346,42 @@ def _weight_hh_l0_jac_t_mat_prod( ) -> Tensor: self._check_parameters(module) - N: int = mat.shape[2] + free_axis = 1 + N_axis, T_axis = self.get_batch_and_time_axes(module) + + N: int = mat.shape[N_axis + free_axis] H: int = module.hidden_size IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( module, mat, subsampling=subsampling ) + subsampled_output = subsample( + module.output, dim=N_axis, subsampling=subsampling + ) + if N_axis == 0: + subsampled_output = subsampled_output.transpose(N_axis, T_axis) + return einsum( f"vtnh,tng->v{'' if sum_batch else 'n'}hg", IFGO_prod, cat( [ zeros(1, N, H, device=mat.device, dtype=mat.dtype), - subsample( - module.output, - dim=get_batch_axis(module, "input0"), - subsampling=subsampling, - )[0:-1], + subsampled_output[0:-1], ], dim=0, ), ) + + @staticmethod + def get_batch_and_time_axes(module: LSTM) -> Tuple[int, int]: + """Return axes interpreted by the module as batch and time axes of the input. + + Args: + module: LSTM module. + + Returns: + Batch axis and time axis. + """ + return (0, 1) if module.batch_first else (1, 0) diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 4a4d05562..10f82e447 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -5,6 +5,7 @@ from typing import List from torch.nn import ( + LSTM, RNN, BatchNorm1d, BatchNorm2d, @@ -80,6 +81,7 @@ def __init__(self, subsampling: List[int] = None): BatchNorm2d: batchnorm_nd.BatchGradBatchNormNd(), BatchNorm3d: batchnorm_nd.BatchGradBatchNormNd(), RNN: rnn.BatchGradRNN(), + LSTM: rnn.BatchGradLSTM(), }, subsampling=subsampling, ) diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index 5a3198fa7..32c41e524 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -79,16 +79,14 @@ def param_function( Scaled individual gradients """ subsampling = ext.get_subsampling() + N_axis = get_batch_axis(module, "output") + return self._derivatives.param_mjp( param_str, module, g_inp, g_out, - subsample( - g_out[0], - dim=get_batch_axis(module, "output"), - subsampling=subsampling, - ), + subsample(g_out[0], dim=N_axis, subsampling=subsampling), sum_batch=False, subsampling=subsampling, ) diff --git a/backpack/extensions/firstorder/batch_grad/rnn.py b/backpack/extensions/firstorder/batch_grad/rnn.py index 50afdf516..9b92f2642 100644 --- a/backpack/extensions/firstorder/batch_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_grad/rnn.py @@ -1,4 +1,5 @@ """Contains BatchGradRNN.""" +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase @@ -12,3 +13,14 @@ def __init__(self): derivatives=RNNDerivatives(), params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], ) + + +class BatchGradLSTM(BatchGradBase): + """Extension for LSTM calculating grad_batch.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=LSTMDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 6a5e770b1..3b97bbcdf 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -4,6 +4,7 @@ Within it, define the extension for each module. """ from torch.nn import ( + LSTM, RNN, BatchNorm1d, BatchNorm2d, @@ -61,6 +62,7 @@ def __init__(self): ConvTranspose2d: convtransposend.BatchL2ConvTranspose2d(), ConvTranspose3d: convtransposend.BatchL2ConvTranspose3d(), RNN: rnn.BatchL2RNN(), + LSTM: rnn.BatchL2LSTM(), BatchNorm1d: batchnorm_nd.BatchL2BatchNorm(), BatchNorm2d: batchnorm_nd.BatchL2BatchNorm(), BatchNorm3d: batchnorm_nd.BatchL2BatchNorm(), diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py index efdbc4320..dbb1a1644 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py @@ -1,4 +1,5 @@ """Contains BatchL2RNN.""" +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base @@ -12,3 +13,14 @@ def __init__(self): ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], derivatives=RNNDerivatives(), ) + + +class BatchL2LSTM(BatchL2Base): + """Extension for LSTM, calculating batch_l2.""" + + def __init__(self): + """Initialization.""" + super().__init__( + ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + derivatives=LSTMDerivatives(), + ) diff --git a/backpack/extensions/firstorder/gradient/rnn.py b/backpack/extensions/firstorder/gradient/rnn.py index c99130efc..7ba76e626 100644 --- a/backpack/extensions/firstorder/gradient/rnn.py +++ b/backpack/extensions/firstorder/gradient/rnn.py @@ -1,4 +1,5 @@ """Contains GradRNN.""" +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.gradient.base import GradBaseModule @@ -12,3 +13,14 @@ def __init__(self): derivatives=RNNDerivatives(), params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], ) + + +class GradLSTM(GradBaseModule): + """Extension for LSTM, calculating gradient.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=LSTMDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index 8dea8e44c..a7acf3c0c 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -3,6 +3,7 @@ Defines module extension for each module. """ from torch.nn import ( + LSTM, RNN, BatchNorm1d, BatchNorm2d, @@ -64,6 +65,7 @@ def __init__(self): ConvTranspose2d: convtranspose2d.SGSConvTranspose2d(), ConvTranspose3d: convtranspose3d.SGSConvTranspose3d(), RNN: rnn.SGSRNN(), + LSTM: rnn.SGSLSTM(), BatchNorm1d: batchnorm_nd.SGSBatchNormNd(), BatchNorm2d: batchnorm_nd.SGSBatchNormNd(), BatchNorm3d: batchnorm_nd.SGSBatchNormNd(), diff --git a/backpack/extensions/firstorder/sum_grad_squared/rnn.py b/backpack/extensions/firstorder/sum_grad_squared/rnn.py index 7388dd746..129229144 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/rnn.py +++ b/backpack/extensions/firstorder/sum_grad_squared/rnn.py @@ -1,4 +1,5 @@ """Contains SGSRNN module.""" +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase @@ -12,3 +13,14 @@ def __init__(self): derivatives=RNNDerivatives(), params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], ) + + +class SGSLSTM(SGSBase): + """Extension for LSTM, calculating sum_gradient_squared.""" + + def __init__(self): + """Initialization.""" + super().__init__( + derivatives=LSTMDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + ) diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index 831c07e40..d0d4aa302 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -3,6 +3,7 @@ Defines module extension for each module. """ from torch.nn import ( + LSTM, RNN, BatchNorm1d, BatchNorm2d, @@ -64,6 +65,7 @@ def __init__(self): ConvTranspose2d: convtranspose2d.VarianceConvTranspose2d(), ConvTranspose3d: convtranspose3d.VarianceConvTranspose3d(), RNN: rnn.VarianceRNN(), + LSTM: rnn.VarianceLSTM(), BatchNorm1d: batchnorm_nd.VarianceBatchNormNd(), BatchNorm2d: batchnorm_nd.VarianceBatchNormNd(), BatchNorm3d: batchnorm_nd.VarianceBatchNormNd(), diff --git a/backpack/extensions/firstorder/variance/rnn.py b/backpack/extensions/firstorder/variance/rnn.py index 93342b32c..62baa1258 100644 --- a/backpack/extensions/firstorder/variance/rnn.py +++ b/backpack/extensions/firstorder/variance/rnn.py @@ -1,6 +1,7 @@ """Contains VarianceRNN.""" -from backpack.extensions.firstorder.gradient.rnn import GradRNN -from backpack.extensions.firstorder.sum_grad_squared.rnn import SGSRNN + +from backpack.extensions.firstorder.gradient.rnn import GradLSTM, GradRNN +from backpack.extensions.firstorder.sum_grad_squared.rnn import SGSLSTM, SGSRNN from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule @@ -15,6 +16,14 @@ def __init__(self): sgs_extension=SGSRNN(), ) - @staticmethod - def _get_axis_batch() -> int: - return 1 + +class VarianceLSTM(VarianceBaseModule): + """Extension for LSTM, calculating variance.""" + + def __init__(self): + """Initialization.""" + super().__init__( + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + grad_extension=GradLSTM(), + sgs_extension=SGSLSTM(), + ) diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index 47e383d46..a1c7a9c0d 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -7,6 +7,7 @@ from torch.nn import Module from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils.subsampling import get_batch_axis if TYPE_CHECKING: from backpack.extensions import Variance @@ -45,10 +46,6 @@ def _variance_from(grad: Tensor, sgs: Tensor, N: int) -> Tensor: avg_gsquared = sgs / N return avg_gsquared - avgg_squared - @staticmethod - def _get_axis_batch() -> int: - return 0 - def _make_param_function( self, param: str ) -> Callable[[Variance, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]: @@ -83,7 +80,7 @@ def param_function( return self._variance_from( getattr(self.grad_ext, param)(ext, module, g_inp, g_out, bpQuantities), getattr(self.sgs_ext, param)(ext, module, g_inp, g_out, bpQuantities), - g_out[0].shape[self._get_axis_batch()], + g_out[0].shape[get_batch_axis(module, "output")], ) return param_function diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 1fa7742f1..9bb965b3f 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -11,6 +11,7 @@ from torch import Tensor from torch.nn import ( ELU, + LSTM, RNN, SELU, AdaptiveAvgPool1d, @@ -132,6 +133,7 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): ScaleModule: custom_module.DiagGGNScaleModule(), SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.DiagGGNRNN(), + LSTM: rnn.DiagGGNLSTM(), Permute: permute.DiagGGNPermute(), AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2), @@ -254,6 +256,7 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): ScaleModule: custom_module.DiagGGNScaleModule(), SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.BatchDiagGGNRNN(), + LSTM: rnn.BatchDiagGGNLSTM(), Permute: permute.DiagGGNPermute(), AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1), AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2), diff --git a/backpack/extensions/secondorder/diag_ggn/rnn.py b/backpack/extensions/secondorder/diag_ggn/rnn.py index 64bb4e6b7..7c926c945 100644 --- a/backpack/extensions/secondorder/diag_ggn/rnn.py +++ b/backpack/extensions/secondorder/diag_ggn/rnn.py @@ -1,4 +1,5 @@ """Module implementing GGN for RNN.""" +from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule @@ -25,3 +26,27 @@ def __init__(self): params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], sum_batch=False, ) + + +class DiagGGNLSTM(DiagGGNBaseModule): + """Calculating GGN diagonal of LSTM.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=LSTMDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + sum_batch=True, + ) + + +class BatchDiagGGNLSTM(DiagGGNBaseModule): + """Calculating per-sample diagonal of GGN.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=LSTMDerivatives(), + params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], + sum_batch=False, + ) diff --git a/test/core/derivatives/lstm_settings.py b/test/core/derivatives/lstm_settings.py index 348129871..94fc9131f 100644 --- a/test/core/derivatives/lstm_settings.py +++ b/test/core/derivatives/lstm_settings.py @@ -23,11 +23,15 @@ ############################################################################### LSTM_SETTINGS += [ { - "module_fn": lambda: LSTM(input_size=4, hidden_size=3, num_layers=1), + "module_fn": lambda: LSTM(input_size=4, hidden_size=3), "input_fn": lambda: rand(size=(5, 3, 4)), }, { - "module_fn": lambda: LSTM(input_size=5, hidden_size=3, num_layers=1), + "module_fn": lambda: LSTM(input_size=5, hidden_size=3), "input_fn": lambda: rand(size=(10, 8, 5)), }, + { + "module_fn": lambda: LSTM(input_size=4, hidden_size=3, batch_first=True), + "input_fn": lambda: rand(size=(3, 5, 4)), + }, ] diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 3eec969f2..3e4e377b3 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -24,6 +24,7 @@ from torch import device, rand from torch.nn import ( + LSTM, RNN, BatchNorm1d, BatchNorm2d, @@ -285,6 +286,16 @@ "loss_function_fn": lambda: MSELoss(), "target_fn": lambda: regression_targets((8, 3 * 5)), }, + { + "input_fn": lambda: rand(4, 5, 3), + "module_fn": lambda: Sequential( + LSTM(3, 4, batch_first=True), + ReduceTuple(index=0), + Flatten(), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((4,), 20), + }, ] ############################################################################### # test setting: torchvision resnet # diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 7eb0801ac..e138d35b6 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -16,6 +16,7 @@ from torch import rand from torch.nn import ( + LSTM, RNN, AdaptiveAvgPool1d, AdaptiveAvgPool2d, @@ -62,6 +63,16 @@ "loss_function_fn": lambda: MSELoss(), "target_fn": lambda: regression_targets((8, 3 * 5)), }, + { + "input_fn": lambda: rand(4, 3, 5), + "module_fn": lambda: Sequential( + LSTM(input_size=5, hidden_size=4, batch_first=True), + ReduceTuple(index=0), + Flatten(), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((4,), 4 * 3), + }, ] ################################################################## # AdaptiveAvgPool settings # From 2ad5c3ba80b787c584eb7b17707b9f86968054e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 22 Sep 2021 09:52:15 +0200 Subject: [PATCH 45/54] [ADD] CrossEntropyLoss: Support additional axes (#211) Before, only inputs of shape `[N, C]` were accepted. Now, inputs of shape `[N, C, ...]` can be treated as well. --- * [REF] remove empty lines * [ADD] CrossEntropyLoss: support additional axes * [DEL] remove TODO * [FIX] support subsampling * [FIX] support subsampling * [ADD] hessian_mat_prod support additional axes, cleanup * [DOC] add to fully_documented.txt * [REF] Remove duplicate `classification_targets` * [FIX] Input to `classification_targets` * [REF] Adapt normalization constant computation * [REF] Reduce magic numbers in `rearrange_input` * [REF] Rename functions to (un-)group additional axes * [FIX] _make_hessian_mat_prod * [DEL] Improve doc, disable multi-dimensional CrossEntropyLoss tests * [DEL] delete unnecessary imports * [ADD] sqrt_hessian: support additional axes * [REF] CrossEntropyLoss: handling of additional axes in _sqrt_hessian() to _expand_sqrt_h * [TEST] test_make_hessian_mat_prod * [ADD] make_hessian_mat_prod: handle additional axes * [ADD] autograd: sum_hessian for multiple axes * [ADD] CrossEntropyLoss: sum_hessian for additional axes, dirty implementation * [REF] CrossEntropyLoss: cleanup sum_hessian * [REF] replace numpy.prod with numel * [REF] simplify expansion of sqrt_hessian to multiple dimensions * [REF] Remove id_prefix * [REF] Improve docstring, minor tweaks * [REF] Implement `sum_hessian_blocks` with flat tensors * [REF] Simplify `hessian_mat_prod` * [FMT] Add linebreaks * [REF] Reduce nesting in `_sum_hessian` * [FIX] flake8 (exception chaining) * [FIX] Number of elements in additional axes * [ADD] Shape checks * [DEL] Shape checks * [REF] Pass `out_shape` more consistently * [TEST] CrossEntropyLoss (with RNN): DiagGGN * [TEST] CrossEntropyLoss: second order extensions * [REF] Use `Tensor.sqrt` instead of `torch.sqrt` Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel --- backpack/core/derivatives/crossentropyloss.py | 176 +++++++++++++++--- .../extensions/secondorder/diag_ggn/losses.py | 2 - fully_documented.txt | 1 + test/core/derivatives/derivatives_test.py | 16 ++ .../derivatives/implementation/autograd.py | 36 ++-- .../derivatives/implementation/backpack.py | 53 +++--- test/core/derivatives/implementation/base.py | 12 ++ test/core/derivatives/loss_settings.py | 18 +- test/core/derivatives/utils.py | 8 +- .../secondorder/diag_ggn/diag_ggn_settings.py | 12 ++ .../secondorder/secondorder_settings.py | 13 ++ test/test_second_order_warnings.py | 9 +- 12 files changed, 270 insertions(+), 86 deletions(-) diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py index 1a596b39f..68690df9a 100644 --- a/backpack/core/derivatives/crossentropyloss.py +++ b/backpack/core/derivatives/crossentropyloss.py @@ -1,9 +1,9 @@ """Partial derivatives for cross-entropy loss.""" from math import sqrt -from typing import List, Tuple +from typing import Callable, Dict, List, Tuple -from torch import Tensor, diag, diag_embed, einsum, multinomial, ones_like, softmax -from torch import sqrt as torchsqrt +from einops import rearrange +from torch import Tensor, diag, diag_embed, einsum, eye, multinomial, ones_like, softmax from torch.nn import CrossEntropyLoss from torch.nn.functional import one_hot @@ -24,20 +24,23 @@ def _sqrt_hessian( g_inp: Tuple[Tensor], g_out: Tuple[Tensor], subsampling: List[int] = None, - ) -> Tensor: # noqa: D102 + ) -> Tensor: self._check_2nd_order_parameters(module) probs = self._get_probs(module, subsampling=subsampling) - tau = torchsqrt(probs) + probs, *rearrange_info = self._merge_batch_and_additional(probs) + + tau = probs.sqrt() V_dim, C_dim = 0, 2 Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim) Id_tautau = Id - einsum("nv,nc->vnc", tau, tau) sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau) if module.reduction == "mean": - N = module.input0.shape[0] - sqrt_H /= sqrt(N) + sqrt_H /= sqrt(self._get_mean_normalization(module.input0)) + sqrt_H = self._ungroup_batch_and_additional(sqrt_H, *rearrange_info) + sqrt_H = self._expand_sqrt_h(sqrt_H) return sqrt_H def _sqrt_hessian_sampled( @@ -47,13 +50,15 @@ def _sqrt_hessian_sampled( g_out: Tuple[Tensor], mc_samples: int = 1, subsampling: List[int] = None, - ) -> Tensor: # noqa: D102 + ) -> Tensor: self._check_2nd_order_parameters(module) M = mc_samples C = module.input0.shape[1] probs = self._get_probs(module, subsampling=subsampling) + probs, *rearrange_info = self._merge_batch_and_additional(probs) + V_dim = 0 probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1) @@ -64,49 +69,68 @@ def _sqrt_hessian_sampled( sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M) if module.reduction == "mean": - N = module.input0.shape[0] - sqrt_mc_h /= sqrt(N) + sqrt_mc_h /= sqrt(self._get_mean_normalization(module.input0)) + sqrt_mc_h = self._ungroup_batch_and_additional(sqrt_mc_h, *rearrange_info) return sqrt_mc_h - def _sum_hessian(self, module, g_inp, g_out): + def _sum_hessian( + self, module: CrossEntropyLoss, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Tensor: self._check_2nd_order_parameters(module) probs = self._get_probs(module) - sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs)) + + if probs.dim() == 2: + diagonal = diag(probs.sum(0)) + sum_H = diagonal - einsum("nc,nd->cd", probs, probs) + else: + out_shape = (*probs.shape[1:], *probs.shape[1:]) + additional = probs.shape[2:].numel() + + diagonal = diag(probs.sum(0).flatten()).reshape(out_shape) + + probs = probs.flatten(2) + kron_delta = eye(additional, device=probs.device, dtype=probs.dtype) + + sum_H = diagonal - einsum( + "ncx,ndy,xy->cxdy", probs, probs, kron_delta + ).reshape(out_shape) if module.reduction == "mean": - N = module.input0.shape[0] - sum_H /= N + sum_H /= self._get_mean_normalization(module.input0) return sum_H - def _make_hessian_mat_prod(self, module, g_inp, g_out): - """Multiplication of the input Hessian with a matrix.""" + def _make_hessian_mat_prod( + self, module: CrossEntropyLoss, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] + ) -> Callable[[Tensor], Tensor]: self._check_2nd_order_parameters(module) probs = self._get_probs(module) def hessian_mat_prod(mat): - Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum( - "bi,bj,cbj->cbi", (probs, probs, mat) + Hmat = einsum("...,v...->v...", probs, mat) - einsum( + "nc...,nd...,vnd...->vnc...", probs, probs, mat ) if module.reduction == "mean": - N = module.input0.shape[0] - Hmat /= N + Hmat /= self._get_mean_normalization(module.input0) return Hmat return hessian_mat_prod - def hessian_is_psd(self): - """Return whether cross-entropy loss Hessian is positive semi-definite.""" + def hessian_is_psd(self) -> bool: + """Return whether cross-entropy loss Hessian is positive semi-definite. + + Returns: + True + """ return True - def _get_probs( - self, module: CrossEntropyLoss, subsampling: List[int] = None - ) -> Tensor: + @staticmethod + def _get_probs(module: CrossEntropyLoss, subsampling: List[int] = None) -> Tensor: """Compute the softmax probabilities from the module input. Args: @@ -120,11 +144,11 @@ def _get_probs( input0 = subsample(module.input0, subsampling=subsampling) return softmax(input0, dim=1) - def _check_2nd_order_parameters(self, module): + def _check_2nd_order_parameters(self, module: CrossEntropyLoss) -> None: """Verify that the parameters are supported by 2nd-order quantities. - Attributes: - module (torch.nn.CrossEntropyLoss): Extended CrossEntropyLoss module + Args: + module: Extended CrossEntropyLoss module Raises: NotImplementedError: If module's setting is not implemented. @@ -145,3 +169,99 @@ def _check_2nd_order_parameters(self, module): implemented_weight, module.weight ) ) + + @staticmethod + def _merge_batch_and_additional( + probs: Tensor, + ) -> Tuple[Tensor, str, Dict[str, int]]: + """Rearranges the input if it has additional axes. + + Treat additional axes like batch axis, i.e. group ``n c d1 d2 -> (n d1 d2) c``. + + Args: + probs: the tensor to rearrange + + Returns: + a tuple containing + - probs: the rearranged tensor + - str_d_dims: a string representation of the additional dimensions + - d_info: a dictionary encoding the size of the additional dimensions + """ + leading = 2 + additional = probs.dim() - leading + + str_d_dims: str = "".join(f"d{i} " for i in range(additional)) + d_info: Dict[str, int] = { + f"d{i}": probs.shape[leading + i] for i in range(additional) + } + + probs = rearrange(probs, f"n c {str_d_dims} -> (n {str_d_dims}) c") + + return probs, str_d_dims, d_info + + @staticmethod + def _ungroup_batch_and_additional( + tensor: Tensor, str_d_dims, d_info, free_axis: int = 1 + ) -> Tensor: + """Rearranges output if it has additional axes. + + Used with group_batch_and_additional. + + Undoes treating additional axes like batch axis and assumes an number of + additional free axes (``v``) were added, i.e. un-groups + ``v (n d1 d2) c -> v n c d1 d2``. + + Args: + tensor: the tensor to rearrange + str_d_dims: a string representation of the additional dimensions + d_info: a dictionary encoding the size of the additional dimensions + free_axis: Number of free leading axes. Default: ``1``. + + Returns: + the rearranged tensor + + Raises: + NotImplementedError: If ``free_axis != 1``. + """ + if free_axis != 1: + raise NotImplementedError(f"Only supports free_axis=1. Got {free_axis}.") + + return rearrange( + tensor, f"v (n {str_d_dims}) c -> v n c {str_d_dims}", **d_info + ) + + @staticmethod + def _expand_sqrt_h(sqrt_h: Tensor) -> Tensor: + """Expands the square root hessian if CrossEntropyLoss has additional axes. + + In the case of e.g. two additional axes (A and B), the input is [N,C,A,B]. + In CrossEntropyLoss the additional axes are treated independently. + Therefore, the intermediate result has shape [C,N,C,A,B]. + In subsequent calculations the additional axes are not independent anymore. + The required shape for sqrt_h_full is then [C*A*B,N,C,A,B]. + Due to the independence, sqrt_h lives on the diagonal of sqrt_h_full. + + Args: + sqrt_h: intermediate result, shape [C,N,C,A,B] + + Returns: + sqrt_h_full, shape [C*A*B,N,C,A,B], sqrt_h on diagonal. + """ + if sqrt_h.dim() > 3: + return diag_embed(sqrt_h.flatten(3), offset=0, dim1=1, dim2=4).reshape( + -1, *sqrt_h.shape[1:] + ) + else: + return sqrt_h + + @staticmethod + def _get_mean_normalization(input: Tensor) -> int: + """Get normalization constant used with reduction='mean'. + + Args: + input: Input to the cross-entropy module. + + Returns: + Divisor for mean reduction. + """ + return input.numel() // input.shape[1] diff --git a/backpack/extensions/secondorder/diag_ggn/losses.py b/backpack/extensions/secondorder/diag_ggn/losses.py index 377adb52b..6679a9b3e 100644 --- a/backpack/extensions/secondorder/diag_ggn/losses.py +++ b/backpack/extensions/secondorder/diag_ggn/losses.py @@ -9,7 +9,6 @@ class DiagGGNLoss(DiagGGNBaseModule): def backpropagate(self, ext, module, grad_inp, grad_out, backproped): hess_func = self.make_loss_hessian_func(ext) - return hess_func(module, grad_inp, grad_out) def make_loss_hessian_func(self, ext): @@ -21,7 +20,6 @@ def make_loss_hessian_func(self, ext): elif loss_hessian_strategy == LossHessianStrategy.SAMPLING: mc_samples = ext.get_num_mc_samples() return partial(self.derivatives.sqrt_hessian_sampled, mc_samples=mc_samples) - else: raise ValueError( "Unknown hessian strategy {}".format(loss_hessian_strategy) diff --git a/fully_documented.txt b/fully_documented.txt index f6018851c..f61b99cbd 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -13,6 +13,7 @@ backpack/core/derivatives/lstm.py backpack/core/derivatives/linear.py backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/core/derivatives/batchnorm_nd.py +backpack/core/derivatives/crossentropyloss.py backpack/core/derivatives/scale_module.py backpack/core/derivatives/sum_module.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index ae477b0d0..6707c65e2 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -469,3 +469,19 @@ def test_hessian_is_zero(no_loss_problem: DerivativesTestProblem) -> None: ) else: assert backpack_res == autograd_res + + +@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS) +def test_make_hessian_mat_prod(problem: DerivativesTestProblem) -> None: + """Test hessian_mat_prod. + + Args: + problem: test problem + """ + problem.set_up() + mat = rand(4, *problem.input_shape, device=problem.device) + + autograd_res = AutogradDerivatives(problem).hessian_mat_prod(mat) + backpack_res = BackpackDerivatives(problem).hessian_mat_prod(mat) + + check_sizes_and_values(backpack_res, autograd_res) diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index a10ea8560..593e1b5f7 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -273,7 +273,8 @@ def input_hessian(self, subsampling: List[int] = None) -> Tensor: subsampling: Indices of active samples. ``None`` uses all samples. Returns: - hessian + Hessian of shape ``[N, *, N, *]`` where ``N`` denotes the + number of sub-samples, and ``*`` is the input feature shape. """ input, output, _ = self.problem.forward_pass(input_requires_grad=True) hessian = self._hessian(output, input) @@ -338,36 +339,33 @@ def _sum_hessian_blocks(self, hessian: Tensor) -> Tensor: Assert second derivative w.r.t. different samples is zero. Args: - hessian: . + hessian: Hessian of the output w.r.t. the input. Has shape ``[N, *, N, *]`` + where ``N`` is the number of active samples and ``*`` is the input's + feature shape. Returns: - sum of hessians - - Raises: - ValueError: if input is not 2d + Sum of Hessians w.r.t. to individual samples. Has shape ``[*, *]``. """ input = self.problem.input - num_axes = len(input.shape) - - if num_axes != 2: - raise ValueError("Only 2D inputs are currently supported.") - N = input.shape[0] - num_features = input.numel() // N + shape_feature = input.shape[1:] + D = shape_feature.numel() - sum_hessian = zeros(num_features, num_features, device=input.device) + hessian = hessian.reshape(N, D, N, D) + sum_hessian = zeros(D, D, device=input.device, dtype=input.dtype) - hessian_different_samples = zeros( - num_features, num_features, device=input.device - ) + hessian_different_samples = zeros(D, D, device=input.device, dtype=input.dtype) for n_1 in range(N): for n_2 in range(N): block = hessian[n_1, :, n_2, :] - if n_1 == n_2: sum_hessian += block - else: assert allclose(block, hessian_different_samples) - return sum_hessian + return sum_hessian.reshape(*shape_feature, *shape_feature) + + def hessian_mat_prod(self, mat: Tensor) -> Tensor: # noqa: D102 + input, output, _ = self.problem.forward_pass(input_requires_grad=True) + + return stack([hessian_vector_product(output, [input], [vec])[0] for vec in mat]) diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index fb7e789e3..092d368c1 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -94,7 +94,10 @@ def input_hessian_via_sqrt_hessian( subsampling: Indices of active samples. ``None`` uses all samples. Returns: - Hessian with respect to the input. + Hessian with respect to the input. Has shape + ``[N, A, B, ..., N, A, B, ...]`` where ``N`` is the batch size or number + of active samples when sub-sampling is used, and ``[A, B, ...]`` are the + input's feature dimensions. """ self.store_forward_io() @@ -127,23 +130,23 @@ def input_hessian_via_sqrt_hessian( def hessian_is_zero(self) -> bool: # noqa: D102 return self.problem.derivative.hessian_is_zero(self.problem.module) - def _sample_hessians_from_sqrt(self, sqrt): + def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor: """Convert individual matrix square root into individual full matrix. Args: sqrt: individual square root of hessian Returns: - individual full matrix - - Raises: - ValueError: if input is not 2d + Individual Hessians of shape ``[N, A, B, ..., A, B, ...]`` where + ``input.shape[1:] = [A, B, ...]`` are the input feature dimensions + and ``N`` is the batch size. """ - # TODO improve readability - if sqrt.dim() == 3: - return einsum("vni,vnj->nij", sqrt, sqrt) - else: - raise ValueError("Only 2D inputs are currently supported.") + N, input_dims = sqrt.shape[1], sqrt.shape[2:] + + sqrt_flat = sqrt.flatten(start_dim=2) + sample_hessians = einsum("vni,vnj->nij", sqrt_flat, sqrt_flat) + + return sample_hessians.reshape(N, *input_dims, *input_dims) def _embed_sample_hessians( self, individual_hessians: Tensor, input: Tensor @@ -152,21 +155,25 @@ def _embed_sample_hessians( Args: individual_hessians: Hessians w.r.t. individual samples in the input. - input: Inputs for the individual Hessians. + input: Inputs for the for samples whose individual Hessians are passed. + Has shape ``[N, A, B, ..., A, B, ...]`` where ``N`` is the number of + active samples and ``[A, B, ...]`` are the feature dimensions. Returns: Hessian that contains the individual Hessians as diagonal blocks. - - Raises: - ValueError: if input is not 2d + Has shape ``[N, A, B, ..., N, A, B, ...]``. """ - hessian_shape = (*input.shape, *input.shape) - hessian = zeros(hessian_shape, device=input.device, dtype=input.dtype) + N, D = input.shape[0], input.shape[1:].numel() + hessian = zeros(N, D, N, D, device=input.device, dtype=input.dtype) + + for n in range(N): + hessian[n, :, n, :] = individual_hessians[n].reshape(D, D) - for idx in range(input.shape[0]): - if input.dim() == 2: - hessian[idx, :, idx, :] = individual_hessians[idx] - else: - raise ValueError("Only 2D inputs are currently supported.") + return hessian.reshape(*input.shape, *input.shape) - return hessian + def hessian_mat_prod(self, mat: Tensor) -> Tensor: # noqa: D102 + self.store_forward_io() + hmp = self.problem.derivative.make_hessian_mat_prod( + self.problem.module, None, None + ) + return hmp(mat) diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 84edba5b4..1bf91c387 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -117,3 +117,15 @@ def hessian_is_zero(self) -> bool: `True`, if Hessian is zero, else `False`. """ raise NotImplementedError + + @abstractmethod + def hessian_mat_prod(self, mat: Tensor) -> Tensor: + """Product of hessian with matrix mat. + + Args: + mat: matrix to multiply + + Returns: + product + """ + raise NotImplementedError diff --git a/test/core/derivatives/loss_settings.py b/test/core/derivatives/loss_settings.py index 7c5a57855..391420cae 100644 --- a/test/core/derivatives/loss_settings.py +++ b/test/core/derivatives/loss_settings.py @@ -3,7 +3,7 @@ Required entries: "module_fn" (callable): Contains a model constructed from `torch.nn` layers "input_fn" (callable): Used for specifying input function - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model @@ -28,7 +28,7 @@ example = { "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "input_fn": lambda: torch.rand(size=(2, 4)), - "target_fn": lambda: classification_targets(size=(2,), num_classes=2), + "target_fn": lambda: classification_targets(size=(2,), num_classes=4), "device": [torch.device("cpu")], # optional "seed": 0, # optional "id_prefix": "loss-example", # optional @@ -37,15 +37,25 @@ LOSS_SETTINGS += [ + { + "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: torch.rand(size=(2, 4, 3)), + "target_fn": lambda: classification_targets(size=(2, 3), num_classes=4), + }, + { + "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "input_fn": lambda: torch.rand(size=(3, 4, 3, 2)), + "target_fn": lambda: classification_targets(size=(3, 3, 2), num_classes=4), + }, { "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "input_fn": lambda: torch.rand(size=(2, 4)), - "target_fn": lambda: classification_targets(size=(2,), num_classes=2), + "target_fn": lambda: classification_targets(size=(2,), num_classes=4), }, { "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "input_fn": lambda: torch.rand(size=(8, 4)), - "target_fn": lambda: classification_targets(size=(8,), num_classes=2), + "target_fn": lambda: classification_targets(size=(8,), num_classes=4), }, { "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="none"), diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py index fe8bbda35..262096b6d 100644 --- a/test/core/derivatives/utils.py +++ b/test/core/derivatives/utils.py @@ -37,11 +37,11 @@ def derivative_cls_for(module_cls: Type[Module]) -> Type[BaseDerivatives]: """ try: return derivatives_for[module_cls] - except KeyError: + except KeyError as e: raise KeyError( - "No derivative available for {}".format(module_cls) - + "Known mappings:\n{}".format(derivatives_for) - ) + f"No derivative available for {module_cls}." + + f"Known mappings:\n{derivatives_for}" + ) from e def classification_targets(size: Tuple[int, ...], num_classes: int) -> Tensor: diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index e138d35b6..a4ba4eba0 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -73,6 +73,18 @@ "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((4,), 4 * 3), }, + { + "input_fn": lambda: rand(5, 8, 6), # TODO (8, 5, 6) + "module_fn": lambda: Sequential( + RNN(input_size=6, hidden_size=3), # TODO batch_first=True + ReduceTuple(index=0), + Permute(1, 0, 2, batch_axis=1), # TODO remove + Linear(3, 3), + Permute(0, 2, 1, batch_axis=0), + ), + "loss_function_fn": lambda: CrossEntropyLoss(), + "target_fn": lambda: classification_targets((8, 5), 3), + }, ] ################################################################## # AdaptiveAvgPool settings # diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index f5f4386a9..fcb46832d 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -301,3 +301,16 @@ ] SECONDORDER_SETTINGS += LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS + +############################################################################### +# test setting: CrossEntropyLoss # +############################################################################### +SECONDORDER_SETTINGS += [ + { + "input_fn": lambda: rand(3, 4, 2, 3, 5), + "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2)), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3, 2, 3, 2), 4), + "id_prefix": "multi-d-CrossEntropyLoss", + }, +] diff --git a/test/test_second_order_warnings.py b/test/test_second_order_warnings.py index 5b7370552..39591e6ba 100644 --- a/test/test_second_order_warnings.py +++ b/test/test_second_order_warnings.py @@ -5,6 +5,8 @@ - using unsupported parameters of the loss """ +from test.core.derivatives.utils import classification_targets + import pytest import torch from torch.nn import CrossEntropyLoss, MSELoss @@ -29,15 +31,10 @@ ] -def classification_targets(N, num_classes): - """Create random targets for classes 0, ..., `num_classes - 1`.""" - return torch.randint(size=(N,), low=0, high=num_classes) - - def dummy_cross_entropy(N=5): y_pred = torch.rand((N, 2)) y_pred.requires_grad = True - y = classification_targets(N, 2) + y = classification_targets((N,), 2) loss_module = extend(CrossEntropyLoss()) return loss_module(y_pred, y) From a33c508fc84091a470237f2b417ab566816b7412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 24 Sep 2021 09:48:49 +0200 Subject: [PATCH 46/54] [ADD] `Embedding` derivatives & extensions (#216) * [ADD] Embedding derivatives and tests * [ADD] Embedding: first order extensions and tests * [FIX] Embedding * [Add] Embedding: DiagGGN extension and test * [DOC] add embedding_settings.py to fully_documented.txt * [FIX] fix merge * [REF] make jac_t_mat_prod error message more similar to PyTorch error * [ADD] docstring explanation batch axis * [REF] change condition for requires_grad * [FIX] flake8 (exception chaining) * [REF] Minor rephrasings * [REF] replace torch.flatten by Tensor.flatten * [DEL] remove todo * [DEL] simplify einsum * [REF] reintroduce if-else branch * [ADD] Embedding: SqrtGGN and tests * [DOC] Fix docstring * [FIX] Change seed to make MC-sampling tests pass Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel --- backpack/core/derivatives/embedding.py | 74 +++++++++++++++++++ .../firstorder/batch_grad/__init__.py | 3 + .../firstorder/batch_grad/embedding.py | 11 +++ .../firstorder/batch_l2_grad/__init__.py | 3 + .../firstorder/batch_l2_grad/embedding.py | 11 +++ .../firstorder/gradient/embedding.py | 11 +++ .../firstorder/sum_grad_squared/__init__.py | 3 + .../firstorder/sum_grad_squared/embedding.py | 11 +++ .../firstorder/variance/__init__.py | 3 + .../firstorder/variance/embedding.py | 16 ++++ .../secondorder/diag_ggn/__init__.py | 4 + .../secondorder/diag_ggn/embedding.py | 23 ++++++ .../secondorder/sqrt_ggn/__init__.py | 3 + .../secondorder/sqrt_ggn/embedding.py | 11 +++ fully_documented.txt | 7 ++ test/core/derivatives/__init__.py | 3 + test/core/derivatives/derivatives_test.py | 7 +- test/core/derivatives/embedding_settings.py | 14 ++++ test/core/derivatives/problem.py | 4 +- test/core/derivatives/utils.py | 2 +- .../firstorder/firstorder_settings.py | 27 ++++++- .../secondorder/diag_ggn/diag_ggn_settings.py | 31 +++++++- .../secondorder/sqrt_ggn/sqrt_ggn_settings.py | 33 +++++++++ .../secondorder/sqrt_ggn/test_sqrt_ggn.py | 4 +- 24 files changed, 311 insertions(+), 8 deletions(-) create mode 100644 backpack/core/derivatives/embedding.py create mode 100644 backpack/extensions/firstorder/batch_grad/embedding.py create mode 100644 backpack/extensions/firstorder/batch_l2_grad/embedding.py create mode 100644 backpack/extensions/firstorder/gradient/embedding.py create mode 100644 backpack/extensions/firstorder/sum_grad_squared/embedding.py create mode 100644 backpack/extensions/firstorder/variance/embedding.py create mode 100644 backpack/extensions/secondorder/diag_ggn/embedding.py create mode 100644 backpack/extensions/secondorder/sqrt_ggn/embedding.py create mode 100644 test/core/derivatives/embedding_settings.py create mode 100644 test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py diff --git a/backpack/core/derivatives/embedding.py b/backpack/core/derivatives/embedding.py new file mode 100644 index 000000000..dc2529358 --- /dev/null +++ b/backpack/core/derivatives/embedding.py @@ -0,0 +1,74 @@ +"""Derivatives for Embedding.""" +from typing import List, Tuple + +from torch import Tensor, einsum, zeros +from torch.nn import Embedding + +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 +from backpack.utils.subsampling import subsample + + +class EmbeddingDerivatives(BaseParameterDerivatives): + """Derivatives for Embedding. + + Note: + These derivatives assume the batch axis to be at position 0. + + Index convention: + v - free axis + n - batch axis + s - num_embeddings + h - embedding_dim + """ + + def _jac_t_mat_prod( + self, + module: Embedding, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: List[int] = None, + ) -> Tensor: + raise NotImplementedError( + "Derivative w.r.t. input not defined: Input to Embedding has type long." + " But only float and complex dtypes can require gradients in PyTorch." + ) + + def _weight_jac_t_mat_prod( + self, + module: Embedding, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + subsampling: List[int] = None, + ) -> Tensor: + self._check_parameters(module) + + input0 = subsample(module.input0, subsampling=subsampling) + delta = zeros(module.num_embeddings, *input0.shape, device=mat.device) + for s in range(module.num_embeddings): + delta[s] = input0 == s + if TORCH_VERSION_AT_LEAST_1_9_0: + equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh" + elif delta.dim() == 2: + equation = f"sn,vnh->v{'' if sum_batch else 'n'}sh" + else: + equation = f"snx,vnxh->v{'' if sum_batch else 'n'}sh" + delta = delta.flatten(start_dim=2, end_dim=-1) + mat = mat.flatten(start_dim=2, end_dim=-2) + return einsum(equation, delta, mat) + + def _check_parameters(self, module: Embedding) -> None: + if module.padding_idx is not None: + raise NotImplementedError("Only padding_idx=None supported.") + elif module.max_norm is not None: + raise NotImplementedError("Only max_norm=None supported.") + elif module.scale_grad_by_freq: + raise NotImplementedError("Only scale_grad_by_freq=False supported.") + elif module.sparse: + raise NotImplementedError("Only sparse=False supported.") + + def hessian_is_zero(self, module: Embedding) -> bool: # noqa: D102 + return False diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 10f82e447..d67bc1d4d 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -16,6 +16,7 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, Linear, ) @@ -29,6 +30,7 @@ conv_transpose1d, conv_transpose2d, conv_transpose3d, + embedding, linear, rnn, ) @@ -82,6 +84,7 @@ def __init__(self, subsampling: List[int] = None): BatchNorm3d: batchnorm_nd.BatchGradBatchNormNd(), RNN: rnn.BatchGradRNN(), LSTM: rnn.BatchGradLSTM(), + Embedding: embedding.BatchGradEmbedding(), }, subsampling=subsampling, ) diff --git a/backpack/extensions/firstorder/batch_grad/embedding.py b/backpack/extensions/firstorder/batch_grad/embedding.py new file mode 100644 index 000000000..35b41f7b0 --- /dev/null +++ b/backpack/extensions/firstorder/batch_grad/embedding.py @@ -0,0 +1,11 @@ +"""BatchGrad extension for Embedding.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase + + +class BatchGradEmbedding(BatchGradBase): + """BatchGrad extension for Embedding.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 3b97bbcdf..8be80f08e 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -15,6 +15,7 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, Linear, ) @@ -23,6 +24,7 @@ batchnorm_nd, convnd, convtransposend, + embedding, linear, rnn, ) @@ -66,5 +68,6 @@ def __init__(self): BatchNorm1d: batchnorm_nd.BatchL2BatchNorm(), BatchNorm2d: batchnorm_nd.BatchL2BatchNorm(), BatchNorm3d: batchnorm_nd.BatchL2BatchNorm(), + Embedding: embedding.BatchL2Embedding(), }, ) diff --git a/backpack/extensions/firstorder/batch_l2_grad/embedding.py b/backpack/extensions/firstorder/batch_l2_grad/embedding.py new file mode 100644 index 000000000..eca2b10cb --- /dev/null +++ b/backpack/extensions/firstorder/batch_l2_grad/embedding.py @@ -0,0 +1,11 @@ +"""BatchL2 extension for Embedding.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base + + +class BatchL2Embedding(BatchL2Base): + """BatchL2 extension for Embedding.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) diff --git a/backpack/extensions/firstorder/gradient/embedding.py b/backpack/extensions/firstorder/gradient/embedding.py new file mode 100644 index 000000000..c394ae509 --- /dev/null +++ b/backpack/extensions/firstorder/gradient/embedding.py @@ -0,0 +1,11 @@ +"""Gradient extension for Embedding.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.firstorder.gradient.base import GradBaseModule + + +class GradEmbedding(GradBaseModule): + """Gradient extension for Embedding.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index a7acf3c0c..76891cff6 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -14,6 +14,7 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, Linear, ) @@ -27,6 +28,7 @@ convtranspose1d, convtranspose2d, convtranspose3d, + embedding, linear, rnn, ) @@ -69,5 +71,6 @@ def __init__(self): BatchNorm1d: batchnorm_nd.SGSBatchNormNd(), BatchNorm2d: batchnorm_nd.SGSBatchNormNd(), BatchNorm3d: batchnorm_nd.SGSBatchNormNd(), + Embedding: embedding.SGSEmbedding(), }, ) diff --git a/backpack/extensions/firstorder/sum_grad_squared/embedding.py b/backpack/extensions/firstorder/sum_grad_squared/embedding.py new file mode 100644 index 000000000..62f34e86b --- /dev/null +++ b/backpack/extensions/firstorder/sum_grad_squared/embedding.py @@ -0,0 +1,11 @@ +"""SGS extension for Embedding.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase + + +class SGSEmbedding(SGSBase): + """SGS extension for Embedding.""" + + def __init__(self): + """Initialization.""" + super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index d0d4aa302..eeb90902f 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -14,6 +14,7 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, Linear, ) @@ -27,6 +28,7 @@ convtranspose1d, convtranspose2d, convtranspose3d, + embedding, linear, rnn, ) @@ -69,5 +71,6 @@ def __init__(self): BatchNorm1d: batchnorm_nd.VarianceBatchNormNd(), BatchNorm2d: batchnorm_nd.VarianceBatchNormNd(), BatchNorm3d: batchnorm_nd.VarianceBatchNormNd(), + Embedding: embedding.VarianceEmbedding(), }, ) diff --git a/backpack/extensions/firstorder/variance/embedding.py b/backpack/extensions/firstorder/variance/embedding.py new file mode 100644 index 000000000..1b38472a6 --- /dev/null +++ b/backpack/extensions/firstorder/variance/embedding.py @@ -0,0 +1,16 @@ +"""Variance extension for Embedding.""" +from backpack.extensions.firstorder.gradient.embedding import GradEmbedding +from backpack.extensions.firstorder.sum_grad_squared.embedding import SGSEmbedding +from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule + + +class VarianceEmbedding(VarianceBaseModule): + """Variance extension for Embedding.""" + + def __init__(self): + """Initialization.""" + super().__init__( + grad_extension=GradEmbedding(), + sgs_extension=SGSEmbedding(), + params=["weight"], + ) diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 9bb965b3f..df0e02759 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -31,6 +31,7 @@ ConvTranspose3d, CrossEntropyLoss, Dropout, + Embedding, Flatten, Identity, LeakyReLU, @@ -64,6 +65,7 @@ convtranspose3d, custom_module, dropout, + embedding, flatten, linear, losses, @@ -141,6 +143,7 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): BatchNorm1d: batchnorm_nd.DiagGGNBatchNormNd(), BatchNorm2d: batchnorm_nd.DiagGGNBatchNormNd(), BatchNorm3d: batchnorm_nd.DiagGGNBatchNormNd(), + Embedding: embedding.DiagGGNEmbedding(), }, ) @@ -264,6 +267,7 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): BatchNorm1d: batchnorm_nd.BatchDiagGGNBatchNormNd(), BatchNorm2d: batchnorm_nd.BatchDiagGGNBatchNormNd(), BatchNorm3d: batchnorm_nd.BatchDiagGGNBatchNormNd(), + Embedding: embedding.BatchDiagGGNEmbedding(), }, ) diff --git a/backpack/extensions/secondorder/diag_ggn/embedding.py b/backpack/extensions/secondorder/diag_ggn/embedding.py new file mode 100644 index 000000000..1021b089b --- /dev/null +++ b/backpack/extensions/secondorder/diag_ggn/embedding.py @@ -0,0 +1,23 @@ +"""DiagGGN extension for Embedding.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNEmbedding(DiagGGNBaseModule): + """DiagGGN extension of Embedding.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=True + ) + + +class BatchDiagGGNEmbedding(DiagGGNBaseModule): + """DiagGGN extension of Embedding.""" + + def __init__(self): + """Initialize.""" + super().__init__( + derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=False + ) diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py index 4ff71dbbc..ebd31dae4 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/__init__.py +++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py @@ -16,6 +16,7 @@ ConvTranspose3d, CrossEntropyLoss, Dropout, + Embedding, Flatten, LeakyReLU, Linear, @@ -37,6 +38,7 @@ convnd, convtransposend, dropout, + embedding, flatten, linear, losses, @@ -92,6 +94,7 @@ def __init__( LogSigmoid: activations.SqrtGGNLogSigmoid(), ELU: activations.SqrtGGNELU(), SELU: activations.SqrtGGNSELU(), + Embedding: embedding.SqrtGGNEmbedding(), }, subsampling=subsampling, ) diff --git a/backpack/extensions/secondorder/sqrt_ggn/embedding.py b/backpack/extensions/secondorder/sqrt_ggn/embedding.py new file mode 100644 index 000000000..070ad217c --- /dev/null +++ b/backpack/extensions/secondorder/sqrt_ggn/embedding.py @@ -0,0 +1,11 @@ +"""Contains extension for the embedding layer used by ``SqrtGGN{Exact, MC}``.""" +from backpack.core.derivatives.embedding import EmbeddingDerivatives +from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule + + +class SqrtGGNEmbedding(SqrtGGNBaseModule): + """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Embedding`` module.""" + + def __init__(self): + """Pass derivatives for ``torch.nn.Embedding`` module.""" + super().__init__(EmbeddingDerivatives(), params=["weight"]) diff --git a/fully_documented.txt b/fully_documented.txt index f61b99cbd..adf917fef 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -13,6 +13,7 @@ backpack/core/derivatives/lstm.py backpack/core/derivatives/linear.py backpack/core/derivatives/adaptive_avg_pool_nd.py backpack/core/derivatives/batchnorm_nd.py +backpack/core/derivatives/embedding.py backpack/core/derivatives/crossentropyloss.py backpack/core/derivatives/scale_module.py backpack/core/derivatives/sum_module.py @@ -27,18 +28,22 @@ backpack/extensions/firstorder/gradient/base.py backpack/extensions/firstorder/gradient/rnn.py backpack/extensions/firstorder/gradient/__init__.py backpack/extensions/firstorder/gradient/batchnorm_nd.py +backpack/extensions/firstorder/gradient/embedding.py backpack/extensions/firstorder/batch_grad/batch_grad_base.py backpack/extensions/firstorder/batch_grad/rnn.py backpack/extensions/firstorder/batch_grad/__init__.py backpack/extensions/firstorder/batch_grad/batchnorm_nd.py +backpack/extensions/firstorder/batch_grad/embedding.py backpack/extensions/firstorder/variance/variance_base.py backpack/extensions/firstorder/variance/rnn.py backpack/extensions/firstorder/variance/__init__.py backpack/extensions/firstorder/variance/batchnorm_nd.py +backpack/extensions/firstorder/variance/embedding.py backpack/extensions/firstorder/sum_grad_squared/sgs_base.py backpack/extensions/firstorder/sum_grad_squared/rnn.py backpack/extensions/firstorder/sum_grad_squared/__init__.py backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py +backpack/extensions/firstorder/sum_grad_squared/embedding.py backpack/extensions/firstorder/batch_l2_grad/ backpack/extensions/secondorder/__init__.py backpack/extensions/secondorder/diag_ggn/__init__.py @@ -47,6 +52,7 @@ backpack/extensions/secondorder/diag_ggn/rnn.py backpack/extensions/secondorder/diag_ggn/permute.py backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py +backpack/extensions/secondorder/diag_ggn/embedding.py backpack/extensions/secondorder/diag_ggn/custom_module.py backpack/extensions/secondorder/diag_hessian/__init__.py backpack/extensions/secondorder/diag_hessian/conv1d.py @@ -91,6 +97,7 @@ test/core/derivatives/permute_settings.py test/core/derivatives/lstm_settings.py test/core/derivatives/pooling_adaptive_settings.py test/core/derivatives/batch_norm_settings.py +test/core/derivatives/embedding_settings.py test/core/derivatives/scale_module_settings.py test/utils/evaluation_mode.py test/utils/skip_test.py diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 552c4a1de..7d3f35df0 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -21,6 +21,7 @@ ConvTranspose3d, CrossEntropyLoss, Dropout, + Embedding, Identity, LeakyReLU, Linear, @@ -53,6 +54,7 @@ from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives from backpack.core.derivatives.dropout import DropoutDerivatives from backpack.core.derivatives.elu import ELUDerivatives +from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives from backpack.core.derivatives.linear import LinearDerivatives from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives @@ -108,6 +110,7 @@ BatchNorm1d: BatchNormNdDerivatives, BatchNorm2d: BatchNormNdDerivatives, BatchNorm3d: BatchNormNdDerivatives, + Embedding: EmbeddingDerivatives, ScaleModule: ScaleModuleDerivatives, ActiveIdentity: ScaleModuleDerivatives, Identity: ScaleModuleDerivatives, diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 6707c65e2..8283bea17 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -9,6 +9,7 @@ from contextlib import nullcontext from test.automated_test import check_sizes_and_values from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS +from test.core.derivatives.embedding_settings import EMBEDDING_SETTINGS from test.core.derivatives.implementation.autograd import AutogradDerivatives from test.core.derivatives.implementation.backpack import BackpackDerivatives from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS @@ -55,6 +56,9 @@ BATCH_NORM_PROBLEMS = make_test_problems(BATCH_NORM_SETTINGS) BATCH_NORM_IDS = [problem.make_id() for problem in BATCH_NORM_PROBLEMS] +EMBEDDING_PROBLEMS = make_test_problems(EMBEDDING_SETTINGS) +EMBEDDING_IDS = [problem.make_id() for problem in EMBEDDING_PROBLEMS] + SCALE_MODULE_PROBLEMS = make_test_problems(SCALE_MODULE_SETTINGS) SCALE_MODULE_IDS = [problem.make_id() for problem in SCALE_MODULE_PROBLEMS] @@ -394,7 +398,8 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None @fixture( - params=PROBLEMS + BATCH_NORM_PROBLEMS + RNN_PROBLEMS, ids=lambda p: p.make_id() + params=PROBLEMS + BATCH_NORM_PROBLEMS + RNN_PROBLEMS + EMBEDDING_PROBLEMS, + ids=lambda p: p.make_id(), ) def problem(request) -> DerivativesTestProblem: """Set seed, create tested layer and data. Finally clean up. diff --git a/test/core/derivatives/embedding_settings.py b/test/core/derivatives/embedding_settings.py new file mode 100644 index 000000000..e6f8b1486 --- /dev/null +++ b/test/core/derivatives/embedding_settings.py @@ -0,0 +1,14 @@ +"""Settings for testing derivatives of Embedding.""" +from torch import randint +from torch.nn import Embedding + +EMBEDDING_SETTINGS = [ + { + "module_fn": lambda: Embedding(3, 5), + "input_fn": lambda: randint(0, 3, (4,)), + }, + { + "module_fn": lambda: Embedding(5, 7), + "input_fn": lambda: randint(0, 5, (8, 3, 3)), + }, +] diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index d01378e7e..4f6306d9a 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -5,7 +5,7 @@ from typing import Dict, List, Tuple import torch -from torch import Tensor +from torch import Tensor, long from backpack import extend from backpack.utils.module_classification import is_loss @@ -151,7 +151,7 @@ def forward_pass( batch_axis_in = get_batch_axis(self.module, "input0") input = subsample(input, dim=batch_axis_in, subsampling=subsampling) - if input_requires_grad: + if input_requires_grad and input.dtype is not long: input.requires_grad = True if self.is_loss(): diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py index 262096b6d..21bfe1834 100644 --- a/test/core/derivatives/utils.py +++ b/test/core/derivatives/utils.py @@ -39,7 +39,7 @@ def derivative_cls_for(module_cls: Type[Module]) -> Type[BaseDerivatives]: return derivatives_for[module_cls] except KeyError as e: raise KeyError( - f"No derivative available for {module_cls}." + f"No derivative available for {module_cls}. " + f"Known mappings:\n{derivatives_for}" ) from e diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 3e4e377b3..c897befaa 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -22,7 +22,7 @@ from test.extensions.automated_settings import make_simple_cnn_setting from test.utils.evaluation_mode import initialize_training_false_recursive -from torch import device, rand +from torch import device, rand, randint from torch.nn import ( LSTM, RNN, @@ -36,6 +36,7 @@ ConvTranspose2d, ConvTranspose3d, CrossEntropyLoss, + Embedding, Flatten, Linear, MSELoss, @@ -297,6 +298,30 @@ "target_fn": lambda: classification_targets((4,), 20), }, ] +############################################################################### +# test setting: Embedding # +############################################################################### +FIRSTORDER_SETTINGS += [ + { + "input_fn": lambda: randint(0, 5, (6,)), + "module_fn": lambda: Sequential( + Embedding(5, 3), + Linear(3, 4), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((6,), 4), + }, + { + "input_fn": lambda: randint(0, 3, (4, 2, 2)), + "module_fn": lambda: Sequential( + Embedding(3, 5), + Flatten(), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((4,), 2 * 5), + }, +] + ############################################################################### # test setting: torchvision resnet # ############################################################################### diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index a4ba4eba0..59316c062 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -14,7 +14,7 @@ from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS from test.utils.evaluation_mode import initialize_training_false_recursive -from torch import rand +from torch import rand, randint from torch.nn import ( LSTM, RNN, @@ -26,6 +26,7 @@ BatchNorm3d, Conv2d, CrossEntropyLoss, + Embedding, Flatten, Linear, MaxPool2d, @@ -159,6 +160,32 @@ "target_fn": lambda: regression_targets((3, 4 * 1 * 3 * 3)), }, ] +############################################################################### +# Embedding # +############################################################################### +LOCAL_SETTINGS += [ + { + "input_fn": lambda: randint(0, 5, (6,)), + "module_fn": lambda: Sequential( + Embedding(5, 3), + Linear(3, 4), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((6,), 4), + }, + { + "input_fn": lambda: randint(0, 3, (3, 2, 2)), + "module_fn": lambda: Sequential( + Embedding(3, 2), + Flatten(), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 2 * 2), + "seed": 2, + }, +] + + ############################################################################### # Branched models # ############################################################################### @@ -236,6 +263,7 @@ "id_prefix": "nested-branching-convolution", }, ] + ############################################################################### # Branched models - converter # ############################################################################### @@ -256,4 +284,5 @@ "id_prefix": "ResNet2", }, ] + DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py new file mode 100644 index 000000000..f6fdc34cd --- /dev/null +++ b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py @@ -0,0 +1,33 @@ +"""Contains test settings for testing SqrtGGN extension.""" +from test.core.derivatives.utils import classification_targets +from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS + +from torch import randint +from torch.nn import CrossEntropyLoss, Embedding, Flatten, Linear, Sequential + +SQRT_GGN_SETTINGS = SECONDORDER_SETTINGS + +############################################################################### +# Embedding # +############################################################################### +SQRT_GGN_SETTINGS += [ + { + "input_fn": lambda: randint(0, 5, (6,)), + "module_fn": lambda: Sequential( + Embedding(5, 3), + Linear(3, 4), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((6,), 4), + }, + { + "input_fn": lambda: randint(0, 3, (3, 2, 2)), + "module_fn": lambda: Sequential( + Embedding(3, 2), + Flatten(), + ), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 2 * 2), + "seed": 1, + }, +] diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py index 7034044a3..92c44f152 100644 --- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py +++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py @@ -5,13 +5,13 @@ from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import ExtensionsTestProblem, make_test_problems -from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS +from test.extensions.secondorder.sqrt_ggn.sqrt_ggn_settings import SQRT_GGN_SETTINGS from test.utils.skip_test import skip_large_parameters, skip_subsampling_conflict from typing import List, Union from pytest import fixture, mark -PROBLEMS = make_test_problems(SECONDORDER_SETTINGS) +PROBLEMS = make_test_problems(SQRT_GGN_SETTINGS) SUBSAMPLINGS = [None, [0, 0], [2, 0]] SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] From ad860422597e6665823cdde53746cc98f14e679a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Tue, 5 Oct 2021 14:26:44 +0200 Subject: [PATCH 47/54] [REF] Assume AdaptiveAvgPoolBug3d solved for `torch>=2.0` (#222) --- backpack/core/derivatives/adaptive_avg_pool_nd.py | 8 ++------ backpack/utils/__init__.py | 3 +++ test/utils/skip_test.py | 8 +++----- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py index 2c0a4ed0a..f6f4d3ad1 100644 --- a/backpack/core/derivatives/adaptive_avg_pool_nd.py +++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py @@ -6,7 +6,7 @@ from torch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_1 +from backpack.utils import ADAPTIVE_AVG_POOL_BUG class AdaptiveAvgPoolNDDerivatives(AvgPoolNDDerivatives): @@ -28,11 +28,7 @@ def check_parameters( Raises: NotImplementedError: if the given shapes do not match """ - if ( - TORCH_VERSION_AT_LEAST_1_9_1 is False - and module.input0.is_cuda - and (self.N == 3) - ): + if ADAPTIVE_AVG_POOL_BUG and module.input0.is_cuda and (self.N == 3): warn( "Be careful when computing gradients of AdaptiveAvgPool3d. " "There is a bug using autograd.grad on cuda with AdaptiveAvgPool3d. " diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index e51232283..d05762d93 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -8,8 +8,11 @@ TORCH_VERSION_AT_LEAST_1_8_0 = TORCH_VERSION >= packaging.version.parse("1.8.0") TORCH_VERSION_AT_LEAST_1_9_0 = TORCH_VERSION >= packaging.version.parse("1.9.0") TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1") +TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0") + FULL_BACKWARD_HOOK: bool = TORCH_VERSION_AT_LEAST_1_9_0 CONVERTER_AVAILABLE: bool = TORCH_VERSION_AT_LEAST_1_9_0 +ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0 def exception_inside_backward_pass(error: Type[Exception]) -> Type[Exception]: diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index 47374d779..b7afc30e2 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -7,7 +7,7 @@ from pytest import skip from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0, TORCH_VERSION_AT_LEAST_1_9_1 +from backpack.utils import ADAPTIVE_AVG_POOL_BUG, TORCH_VERSION_AT_LEAST_1_9_0 def skip_adaptive_avg_pool3d_cuda(request) -> None: @@ -16,16 +16,14 @@ def skip_adaptive_avg_pool3d_cuda(request) -> None: Args: request: problem request """ - if TORCH_VERSION_AT_LEAST_1_9_1: - pass - else: + if ADAPTIVE_AVG_POOL_BUG: if all( string in request.node.callspec.id for string in ["AdaptiveAvgPool3d", "cuda"] ): skip( "Skip test because AdaptiveAvgPool3d does not work on cuda. " - "Is fixed in torch 1.9.1." + "Should be fixed in torch 2.0." ) From 46628306bb204fa67e2de55719ca2704d3016f25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 6 Oct 2021 17:03:08 +0200 Subject: [PATCH 48/54] [ADD] `retain_graph` option in `backpack` context (#217) * [TEST] test retain_graph * [ADD] add option retain_graph in backpack() * [TEST] adjust test to new interface of retain_graph * [REF] rename test_autograd.py to test_retain_graph.py * [ADD] example_indiv_hvp_retain_graph.py * [REF] docstring * [REF] change (not a and not b) to: not (a or b) * [DOC] fully_documented.txt: add test_retain_graph.py * [REF] simplify _clear_input_output * [DOC] docstring _clear_input_output * [REF] rename _clear_input_output to _check_no_io * [REF] fix seed * [REF] reuse inputs and labels * [REF] Use `any` to check for IO in module * [DOC] Wrong IHVP if parameter occurs multiple times in grad Evidence that the method proposed in #213 does not work whenever a parameter is used in the first `backward` pass to compute the gradient. The IHVP is missing terms from the parameter's operation that is used in the backward pass, because `BatchGrad` only installs the Jacobians of modules from the forward pass. * [ADD] test_for_loop_replace from github issue #220 * [DEL] remove example_indiv_hvp_retain_graph.py * [REF] improve test * [TEST] set seed * [DOC] Docstring polish Co-authored-by: Felix Dangel --- backpack/__init__.py | 19 ++++++-- backpack/context.py | 19 ++++++++ fully_documented.txt | 1 + test/test_retain_graph.py | 99 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 4 deletions(-) create mode 100644 test/test_retain_graph.py diff --git a/backpack/__init__.py b/backpack/__init__.py index 068a973af..877fd1b6a 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -27,6 +27,7 @@ def __init__( *exts: BackpropExtension, extension_hook: Callable[[Module], None] = None, debug: bool = False, + retain_graph: bool = False, ): """Activate BackPACK extensions. @@ -39,6 +40,9 @@ def __init__( all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns ``None``. Default: ``None`` (no operation will be performed). debug: Print debug messages during the backward pass. Default: ``False``. + retain_graph: Determines whether BackPack IO should be kept for additional + backward passes. Should have same value as the argument ``retain_graph`` + in ``backward()``. Default: ``False``. .. note:: extension_hook can be used to reduce memory overhead if the goal is to compute @@ -67,15 +71,18 @@ def __init__( self.extension_hook: Callable[[Module], None] = ( no_op if extension_hook is None else extension_hook ) + self.retain_graph = retain_graph def __enter__(self): """Setup backpack environment.""" self.old_CTX = CTX.get_active_exts() self.old_debug = CTX.get_debug() self.old_extension_hook = CTX.get_extension_hook() + self.old_retain_graph = CTX.get_retain_graph() CTX.set_active_exts(self.exts) CTX.set_debug(self.debug) CTX.set_extension_hook(self.extension_hook) + CTX.set_retain_graph(self.retain_graph) def __exit__( self, @@ -93,6 +100,7 @@ def __exit__( CTX.set_active_exts(self.old_CTX) CTX.set_debug(self.old_debug) CTX.set_extension_hook(self.old_extension_hook) + CTX.set_retain_graph(self.old_retain_graph) class disable: @@ -209,10 +217,13 @@ def hook_run_extensions( CTX.get_extension_hook()(module) if not ( - CTX.is_extension_active( - extensions.curvmatprod.HMP, - extensions.curvmatprod.GGNMP, - extensions.curvmatprod.PCHMP, + CTX.get_retain_graph() + or ( + CTX.is_extension_active( + extensions.curvmatprod.HMP, + extensions.curvmatprod.GGNMP, + extensions.curvmatprod.PCHMP, + ) ) ): memory_cleanup(module) diff --git a/backpack/context.py b/backpack/context.py index 433ad474f..39d19ebf7 100644 --- a/backpack/context.py +++ b/backpack/context.py @@ -15,6 +15,7 @@ class CTX: debug: bool = False extension_hook: Callable[[Module], None] = no_op hook_handles: List[RemovableHandle] = [] + retain_graph: bool = False @staticmethod def set_active_exts(active_exts: Iterable[BackpropExtension]) -> None: @@ -97,3 +98,21 @@ def set_extension_hook(extension_hook: Callable[[Module], None]) -> None: extension_hook: the extension hook to run after all other extensions """ CTX.extension_hook = extension_hook + + @staticmethod + def set_retain_graph(retain_graph: bool) -> None: + """Set retain_graph. + + Args: + retain_graph: new value for retain_graph + """ + CTX.retain_graph = retain_graph + + @staticmethod + def get_retain_graph() -> bool: + """Get retain_graph. + + Returns: + retain_graph + """ + return CTX.retain_graph diff --git a/fully_documented.txt b/fully_documented.txt index adf917fef..16bd4b4ed 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -105,3 +105,4 @@ test/utils/__init__.py test/converter/ test/utils/test_subsampling.py test/custom_module/ +test/test_retain_graph.py diff --git a/test/test_retain_graph.py b/test/test_retain_graph.py new file mode 100644 index 000000000..055f18d14 --- /dev/null +++ b/test/test_retain_graph.py @@ -0,0 +1,99 @@ +"""Test autograd functionality like retain_graph.""" +from test.automated_test import check_sizes_and_values + +from pytest import raises +from torch import autograd, manual_seed, ones_like, rand, randint, randn, zeros +from torch.nn import CrossEntropyLoss, Linear, Module, Sequential + +from backpack import backpack, extend +from backpack.extensions import BatchGrad + + +def test_retain_graph(): + """Tests whether retain_graph works as expected. + + Does several forward and backward passes. + In between, it is tested whether BackPACK quantities are present or not. + """ + manual_seed(0) + model = extend(Sequential(Linear(4, 6), Linear(6, 5))) + loss_fn = extend(CrossEntropyLoss()) + + # after a forward pass graph is not clear + inputs = rand(8, 4) + labels = randint(5, (8,)) + loss = loss_fn(model(inputs), labels) + with raises(AssertionError): + _check_no_io(model) + + # after a normal backward pass graph should be clear + loss.backward() + _check_no_io(model) + + # after a backward pass with retain_graph=True graph is not clear + loss = loss_fn(model(inputs), labels) + with backpack(retain_graph=True): + loss.backward(retain_graph=True) + with raises(AssertionError): + _check_no_io(model) + + # doing several backward passes with retain_graph=True + for _ in range(3): + with backpack(retain_graph=True): + loss.backward(retain_graph=True) + with raises(AssertionError): + _check_no_io(model) + + # finally doing a normal backward pass that verifies graph is clear again + with backpack(BatchGrad()): + loss.backward() + _check_no_io(model) + + +def _check_no_io(module: Module) -> None: + """Checks whether the module is clear of any BackPACK inputs and outputs. + + Args: + module: The module to test + + Raises: + AssertionError: if the module or any child module has BackPACK inputs or outputs. + """ + for child_module in module.children(): + _check_no_io(child_module) + + io_strs = ["input0", "output"] + if any(hasattr(module, io) for io in io_strs): + raise AssertionError(f"IO should be clear, but {module} has one of {io_strs}.") + + +def test_for_loop_replace() -> None: + """Application of retain_graph: replace an outer for-loop. + + This test is based on issue #220 opened by Romain3Ch216. + It computes per-component individual gradients of a tensor-valued output + with a for loop over components, rather than over samples and components. + """ + manual_seed(0) + B = 5 + M = 3 + h = 2 + + x = randn(B, h) + fc = extend(Linear(h, M)) + A = fc(x) + + grad_autograd = zeros(B, M, *fc.weight.shape) + for b in range(B): + for m in range(M): + with backpack(retain_graph=True): + grads = autograd.grad(A[b, m], fc.weight, retain_graph=True) + grad_autograd[b, m] = grads[0] + + grad_backpack = zeros(B, M, *fc.weight.shape) + for i in range(M): + with backpack(BatchGrad(), retain_graph=True): + A[:, i].backward(ones_like(A[:, i]), retain_graph=True) + grad_backpack[:, i] = fc.weight.grad_batch + + check_sizes_and_values(grad_backpack, grad_autograd) From a3d6116787878ebf096280db722739108d0842bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 7 Oct 2021 10:56:17 +0200 Subject: [PATCH 49/54] [ADD] Converter functionality for RNNs and some torch methods (#221) Also adds support for `batch_first`. Progress on #16. Resolves fKunstner/backpack-discuss#116. --- * [REF] remove empty lines * [ADD] CrossEntropyLoss: support additional axes * [DEL] remove TODO * [FIX] support subsampling * [FIX] support subsampling * [ADD] hessian_mat_prod support additional axes, cleanup * [DOC] add to fully_documented.txt * [ADD] Embedding derivatives and tests * [ADD] Embedding: first order extensions and tests * [FIX] Embedding * [Add] Embedding: DiagGGN extension and test * [DOC] add embedding_settings.py to fully_documented.txt * [REF] Remove duplicate `classification_targets` * [FIX] Input to `classification_targets` * [REF] Adapt normalization constant computation * [REF] Reduce magic numbers in `rearrange_input` * [REF] Rename functions to (un-)group additional axes * [FIX] _make_hessian_mat_prod * [DEL] Improve doc, disable multi-dimensional CrossEntropyLoss tests * [FIX] fix merge * [DEL] delete unnecessary imports * [ADD] sqrt_hessian: support additional axes * [REF] CrossEntropyLoss: handling of additional axes in _sqrt_hessian() to _expand_sqrt_h * [TEST] test_make_hessian_mat_prod * [ADD] make_hessian_mat_prod: handle additional axes * [ADD] autograd: sum_hessian for multiple axes * [ADD] CrossEntropyLoss: sum_hessian for additional axes, dirty implementation * [REF] CrossEntropyLoss: cleanup sum_hessian * [REF] make jac_t_mat_prod error message more similar to PyTorch error * [REF] replace numpy.prod with numel * [REF] simplify expansion of sqrt_hessian to multiple dimensions * [REF] Remove id_prefix * [REF] Improve docstring, minor tweaks * [REF] Implement `sum_hessian_blocks` with flat tensors * [REF] Simplify `hessian_mat_prod` * [FMT] Add linebreaks * [REF] Reduce nesting in `_sum_hessian` * [FIX] flake8 (exception chaining) * [FIX] Number of elements in additional axes * [ADD] docstring explanation batch axis * [REF] change condition for requires_grad * [FIX] flake8 (exception chaining) * [REF] Minor rephrasings * [ADD] Shape checks * [DEL] Shape checks * [REF] Pass `out_shape` more consistently * [TEST] CrossEntropyLoss (with RNN): DiagGGN * [TEST] CrossEntropyLoss: second order extensions * [REF] replace torch.flatten by Tensor.flatten * [DEL] remove todo * [DEL] simplify einsum * [REF] Use `Tensor.sqrt` instead of `torch.sqrt` * [REF] reintroduce if-else branch * [ADD] RNN converter and test * [REF] RNN converter and test: adjust tests * [ADD] Converter: Flatten method * [ADD] Converter: Mul and Add method * [ADD] Converter: transpose * [ADD] Update Tolstoi and adjust test * [ADD] Converter: RNN * [ADD] RNN: batch_first * [REF] cleanup, RNN version of Tolstoi * [FIX] Dropout evaluation mode, converter_cases.py parameter order * [DOC] document converter might change parameter order * [REF] incorporate suggestions * [REF] graph_utils.py: incorporate suggestions * [DOC] permute.py: clarify docstring * [REF] converter_cases.py: incorporate suggestions * [REF] graph_utils.py: incorporate suggestions * [REF] incorporate suggestions * [DOC] docstrings * [REF] define batch_size Co-authored-by: Felix Dangel Co-authored-by: Felix Dangel --- backpack/__init__.py | 2 + backpack/core/derivatives/dropout.py | 24 +- backpack/core/derivatives/rnn.py | 112 ++++--- backpack/custom_module/graph_utils.py | 281 ++++++++++++++++-- backpack/custom_module/permute.py | 23 +- fully_documented.txt | 1 + test/converter/converter_cases.py | 166 ++++++++++- test/converter/test_converter.py | 47 ++- test/core/derivatives/rnn_settings.py | 6 + .../secondorder/diag_ggn/diag_ggn_settings.py | 5 +- 10 files changed, 573 insertions(+), 94 deletions(-) diff --git a/backpack/__init__.py b/backpack/__init__.py index 877fd1b6a..86a6f8cb4 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -239,6 +239,8 @@ def extend(module: Module, debug: bool = False, use_converter: bool = False) -> module: The module to extend. debug: Print debug messages during the extension. Default: ``False``. use_converter: Try converting the module to a BackPACK-compatible network. + The converter might alter the model, e.g. order of parameters. + Default: ``False``. Returns: Extended module. diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py index 5ca3d4a02..b32aef49a 100644 --- a/backpack/core/derivatives/dropout.py +++ b/backpack/core/derivatives/dropout.py @@ -1,7 +1,7 @@ """Partial derivatives for the dropout layer.""" from typing import List, Tuple -from torch import Tensor, eq +from torch import Tensor, eq, ones_like from torch.nn import Dropout from backpack.core.derivatives.elementwise import ElementwiseDerivatives @@ -9,8 +9,17 @@ class DropoutDerivatives(ElementwiseDerivatives): + """Derivatives for the Dropout module.""" + def hessian_is_zero(self, module: Dropout) -> bool: - """``Dropout''(x) = 0``.""" + """``Dropout''(x) = 0``. + + Args: + module: dropout module + + Returns: + whether hessian is zero + """ return True def df( @@ -19,8 +28,11 @@ def df( g_inp: Tuple[Tensor], g_out: Tuple[Tensor], subsampling: List[int] = None, - ) -> Tensor: + ) -> Tensor: # noqa: D102 output = subsample(module.output, subsampling=subsampling) - scaling = 1 / (1 - module.p) - mask = 1 - eq(output, 0.0).to(output.dtype) - return mask * scaling + if module.training: + scaling = 1 / (1 - module.p) + mask = 1 - eq(output, 0.0).to(output.dtype) + return mask * scaling + else: + return ones_like(output) diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index a049b13c1..b0d8e7011 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -6,7 +6,7 @@ from torch.nn import RNN from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample class RNNDerivatives(BaseParameterDerivatives): @@ -40,8 +40,6 @@ def _check_parameters(module: RNN) -> None: raise NotImplementedError("only nonlinearity = tanh is supported") if module.bias is not True: raise NotImplementedError("only bias = True is supported") - if module.batch_first is not False: - raise NotImplementedError("only batch_first = False is supported") if not module.dropout == 0: raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: @@ -50,38 +48,48 @@ def _check_parameters(module: RNN) -> None: def hessian_is_zero(self, module: RNN) -> bool: # noqa: D102 return False - @staticmethod + @classmethod def _a_jac_t_mat_prod( - output: Tensor, + cls, + module: RNN, weight_hh_l0: Tensor, mat: Tensor, + subsampling: List[int] = None, ) -> Tensor: """Calculates jacobian vector product wrt a. Args: - output: the values of the hidden layer + module: RNN module weight_hh_l0: weight matrix hidden-to-hidden mat: matrix to multiply + subsampling: subsampling Returns: jacobian vector product wrt a """ + free_axis = 1 + N_axis, T_axis = cls.get_batch_and_time_axes(module) V: int = mat.shape[0] - N: int = mat.shape[2] - T: int = mat.shape[1] + N: int = mat.shape[N_axis + free_axis] + T: int = mat.shape[T_axis + free_axis] H: int = mat.shape[3] + output = subsample(module.output, dim=N_axis, subsampling=subsampling) + # use [T, N, ...] format + if module.batch_first: + output = output.transpose(N_axis, T_axis) a_jac_t_mat_prod: Tensor = zeros(V, T, N, H, device=mat.device) for t in reversed(range(T)): + mat_t = mat[(slice(None),) * (T_axis + free_axis) + (t,)] if t == (T - 1): a_jac_t_mat_prod[:, t, ...] = einsum( "vnh,nh->vnh", - mat[:, t, ...], + mat_t, 1 - output[t, ...] ** 2, ) else: a_jac_t_mat_prod[:, t, ...] = einsum( "vnh,nh->vnh", - mat[:, t, ...] + mat_t + einsum( "vng,gh->vnh", a_jac_t_mat_prod[:, t + 1, ...], @@ -101,15 +109,12 @@ def _jac_t_mat_prod( ) -> Tensor: self._check_parameters(module) return torch.einsum( - "vtnh,hk->vtnk", + f"vtnh,hk->v{'nt' if module.batch_first else 'tn'}k", self._a_jac_t_mat_prod( - subsample( - module.output, - dim=get_batch_axis(module, "input0"), - subsampling=subsampling, - ), + module, module.weight_hh_l0, mat, + subsampling, ), module.weight_ih_l0, ) @@ -118,31 +123,47 @@ def _jac_mat_prod( self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: self._check_parameters(module) + free_axis = 1 + N_axis, T_axis = self.get_batch_and_time_axes(module) V: int = mat.shape[0] - N: int = mat.shape[2] - T: int = mat.shape[1] + N: int = mat.shape[N_axis + free_axis] + T: int = mat.shape[T_axis + free_axis] H: int = module.hidden_size + # use [T, N, ...] format + if module.batch_first: + output = module.output.transpose(N_axis, T_axis) + else: + output = module.output _jac_mat_prod: Tensor = torch.zeros(V, T, N, H, device=mat.device) for t in range(T): + mat_t = mat[(slice(None),) * (T_axis + free_axis) + (t,)] if t == 0: _jac_mat_prod[:, t, ...] = einsum( "nh,hi,vni->vnh", - 1 - module.output[t, ...] ** 2, + 1 - output[t, ...] ** 2, module.weight_ih_l0, - mat[:, t, ...], + mat_t, ) else: _jac_mat_prod[:, t, ...] = einsum( "nh,vnh->vnh", - 1 - module.output[t, ...] ** 2, - einsum("hi,vni->vnh", module.weight_ih_l0, mat[:, t, ...]) + 1 - output[t, ...] ** 2, + einsum( + "hi,vni->vnh", + module.weight_ih_l0, + mat_t, + ) + einsum( "hk,vnk->vnh", module.weight_hh_l0, _jac_mat_prod[:, t - 1, ...], ), ) - return _jac_mat_prod + return ( + _jac_mat_prod.transpose(N_axis + free_axis, T_axis + free_axis) + if module.batch_first + else _jac_mat_prod + ) def _bias_ih_l0_jac_t_mat_prod( self, @@ -172,13 +193,10 @@ def _bias_ih_l0_jac_t_mat_prod( else: dim: int = 1 return self._a_jac_t_mat_prod( - subsample( - module.output, - dim=get_batch_axis(module, "input0"), - subsampling=subsampling, - ), + module, module.weight_hh_l0, mat, + subsampling, ).sum(dim=dim) def _bias_hh_l0_jac_t_mat_prod( @@ -230,13 +248,14 @@ def _weight_ih_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis = get_batch_axis(module, "input0") + N_axis, _ = self.get_batch_and_time_axes(module) return einsum( - "vtnh,tnj->" + ("vhj" if sum_batch else "vnhj"), + f"vtnh,{'nt' if module.batch_first else 'tn'}j->v{'' if sum_batch else 'n'}hj", self._a_jac_t_mat_prod( - subsample(module.output, dim=N_axis, subsampling=subsampling), + module, module.weight_hh_l0, mat, + subsampling, ), subsample(module.input0, dim=N_axis, subsampling=subsampling), ) @@ -264,15 +283,32 @@ def _weight_hh_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis = get_batch_axis(module, "input0") + N_axis, T_axis = self.get_batch_and_time_axes(module) N: int = mat.shape[N_axis + 1] H: int = mat.shape[3] output = subsample(module.output, dim=N_axis, subsampling=subsampling) + shape_single_step = (N, 1, H) if module.batch_first else (1, N, H) + output_shifted = cat( + [ + zeros(shape_single_step, device=mat.device, dtype=mat.dtype), + output[(slice(None),) * T_axis + (slice(0, -1),)], + ], + dim=T_axis, + ) return einsum( - "vtnh,tnk->" + ("vhk" if sum_batch else "vnhk"), - self._a_jac_t_mat_prod(output, module.weight_hh_l0, mat), - cat( - [zeros(1, N, H, device=mat.device, dtype=mat.dtype), output[0:-1]], - dim=0, - ), + f"vtnh,{'nt' if module.batch_first else 'tn'}k->v{'' if sum_batch else 'n'}hk", + self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling), + output_shifted, ) + + @staticmethod + def get_batch_and_time_axes(module: RNN) -> Tuple[int, int]: + """Return axes interpreted by the module as batch and time axes of the input. + + Args: + module: RNN module. + + Returns: + Batch axis and time axis. + """ + return (0, 1) if module.batch_first else (1, 0) diff --git a/backpack/custom_module/graph_utils.py b/backpack/custom_module/graph_utils.py index 7d8a947df..e48a451ee 100644 --- a/backpack/custom_module/graph_utils.py +++ b/backpack/custom_module/graph_utils.py @@ -1,11 +1,14 @@ """Transformation tools to make graph BackPACK compatible.""" from copy import deepcopy +from typing import Tuple, Union from warnings import warn from torch.fx import Graph, GraphModule, Node, Tracer -from torch.nn import Flatten, Module +from torch.nn import LSTM, RNN, Dropout, Flatten, Module, Sequential from backpack.custom_module.branching import ActiveIdentity, SumModule, _Branch +from backpack.custom_module.permute import Permute +from backpack.custom_module.reduce_tuple import ReduceTuple from backpack.custom_module.scale_module import ScaleModule from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 @@ -16,7 +19,9 @@ class BackpackTracer(Tracer): def is_leaf_module( self, m: Module, module_qualified_name: str ) -> bool: # noqa: D102 - if isinstance(m, (ScaleModule, SumModule, _Branch, ActiveIdentity)): + if isinstance( + m, (ScaleModule, SumModule, _Branch, ActiveIdentity, ReduceTuple, Permute) + ): return True else: return super().is_leaf_module(m, module_qualified_name) @@ -29,7 +34,11 @@ def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule: - mul -> ScaleModule - add -> AddModule - flatten -> nn.Flatten - - inplace to normal + - getitem -> ReduceTuple + - permute -> Permute + - transpose -> Transpose + - LSTM: split multiple layers + - inplace -> normal - remove duplicates - delete unused modules - check BackPACK compatible @@ -54,6 +63,10 @@ def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule: module_new = _transform_mul_to_scale_module(module, debug) module_new = _transform_flatten_to_module(module_new, debug) module_new = _transform_add_to_sum_module(module_new, debug) + module_new = _transform_get_item_to_module(module_new, debug) + module_new = _transform_permute_to_module(module_new, debug) + module_new = _transform_transpose_to_module(module_new, debug) + module_new = _transform_lstm_rnn(module_new, debug) _transform_inplace_to_normal(module_new, debug) module_new = _transform_remove_duplicates(module_new, debug) if debug: @@ -101,16 +114,24 @@ def _transform_mul_to_scale_module(module: Module, debug: bool) -> GraphModule: Raises: RuntimeError: if a multiplication is found but node.args are not (float, Node) """ - target = "" + target_function = "" + target_method = "multiply" if debug: - print(f"\tBegin transformation: {target} -> ScaleModule") + print(f"\tBegin transformation: {target_function} -> ScaleModule") graph: Graph = BackpackTracer().trace(module) - nodes = [ - n for n in graph.nodes if n.op == "call_function" and str(n.target) == target + nodes_function = [ + n + for n in graph.nodes + if n.op == "call_function" and str(n.target) == target_function + ] + nodes_method = [ + n + for n in graph.nodes + if n.op == "call_method" and str(n.target) == target_method ] - for node in nodes: + for node in nodes_function: if len(node.args) != 2: raise RuntimeError(f"Expecting 2 arguments, got {len(node.args)}.") @@ -128,11 +149,15 @@ def _transform_mul_to_scale_module(module: Module, debug: bool) -> GraphModule: _change_node_to_module( node, "scale_module", module, ScaleModule(weight), (tensor,) ) + for node in nodes_method: + _change_node_to_module( + node, "scale_module", module, ScaleModule(node.args[1]), (node.args[0],) + ) graph.lint() if debug: - print(f"\tMultiplications transformed: {len(nodes)}") + print(f"\tMultiplications transformed: {len(nodes_function)+len(nodes_method)}") return GraphModule(module, graph) @@ -147,13 +172,17 @@ def _transform_add_to_sum_module(module: Module, debug: bool) -> GraphModule: Returns: equivalent transformed module """ - target = "" + target_function = "" + target_method = "add" if debug: - print(f"\tBegin transformation: {target} -> SumModule") + print(f"\tBegin transformation: {target_function} -> SumModule") graph: Graph = BackpackTracer().trace(module) nodes = [ - n for n in graph.nodes if n.op == "call_function" and str(n.target) == target + n + for n in graph.nodes + if (n.op == "call_function" and str(n.target) == target_function) + or (n.op == "call_method" and str(n.target) == target_method) ] for node in nodes: @@ -177,17 +206,21 @@ def _transform_flatten_to_module(module: Module, debug: bool) -> GraphModule: Returns: equivalent transformed module """ - target = " Flatten") + print(f"\tBegin transformation: {target_function} -> Flatten") graph: Graph = BackpackTracer().trace(module) nodes = [ - n for n in graph.nodes if n.op == "call_function" and target in str(n.target) + n + for n in graph.nodes + if (n.op == "call_function" and target_function in str(n.target)) + or (n.op == "call_method" and target_method == str(n.target)) ] for node in nodes: - start_dim = node.args[1] if len(node.args) > 1 else 1 + start_dim = node.args[1] if len(node.args) > 1 else 0 end_dim = node.args[2] if len(node.args) > 2 else -1 _change_node_to_module( node, "flatten", module, Flatten(start_dim, end_dim), (node.args[0],) @@ -196,11 +229,225 @@ def _transform_flatten_to_module(module: Module, debug: bool) -> GraphModule: graph.lint() if debug: - print(f"\tFlatten transformed: {len(nodes)}") + print(f"\tFlatten functions transformed: {len(nodes)}") + + return GraphModule(module, graph) + + +def _transform_get_item_to_module(module: Module, debug: bool) -> GraphModule: + """Transforms the built-in getitem function to ReduceTuple module. + + This function is usually used to reduce the tuple output of RNNs. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + """ + target = "" + if debug: + print(f"\tBegin transformation: {target} -> ReduceTuple") + graph: Graph = BackpackTracer().trace(module) + + nodes = [ + n for n in graph.nodes if n.op == "call_function" and target == str(n.target) + ] + for node in nodes: + _change_node_to_module( + node, + "reduce_tuple", + module, + ReduceTuple(index=node.args[1]), + (node.args[0],), + ) + + graph.lint() + if debug: + print(f"\tReduceTuple transformed: {len(nodes)}") + return GraphModule(module, graph) + + +def _transform_permute_to_module(module: Module, debug: bool) -> GraphModule: + """Transforms permute function or method to Permute module. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + """ + target1 = "permute" + target2 = " Permute") + graph: Graph = BackpackTracer().trace(module) + + nodes = [ + n + for n in graph.nodes + if (n.op == "call_function" and target2 in str(n.target)) + or (n.op == "call_method" and target1 == str(n.target)) + ] + + for node in nodes: + _change_node_to_module( + node, + "permute", + module, + Permute(*node.args[1]) if len(node.args) == 2 else Permute(*node.args[1:]), + (node.args[0],), + ) + + graph.lint() + if debug: + print(f"\tPermute transformed: {len(nodes)}") + return GraphModule(module, graph) + + +def _transform_transpose_to_module(module: Module, debug: bool) -> GraphModule: + """Transforms transpose function or method to Permute module. + + The Permute module is initialized with transpose parameters and computes + the permutation on its first forward pass. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + """ + target_function = " Permute") + graph: Graph = BackpackTracer().trace(module) + + nodes = [ + n + for n in graph.nodes + if (n.op == "call_function" and target_function in str(n.target)) + or (n.op == "call_method" and target_method == str(n.target)) + ] + + for node in nodes: + _change_node_to_module( + node, + "permute", + module, + Permute(*node.args[1:], init_transpose=True), + (node.args[0],), + ) + + graph.lint() + if debug: + print(f"\tPermute transformed: {len(nodes)}") + return GraphModule(module, graph) + + +def _transform_lstm_rnn(module: Module, debug: bool) -> GraphModule: + """Transforms multi-layer RNN/LSTM to Sequential of single-layer RNN/LSTM. + + Converts multi-layer RNN/LSTM to Sequential with single-layer RNN/LSTM. + If dropout probability is nonzero, creates intermediate dropout layers. + Finally, copies training mode. + + Args: + module: container module to transform + debug: whether to print debug messages + + Returns: + equivalent transformed module + + Raises: + NotImplementedError: if initial hidden state is used in forward pass + """ + if debug: + print("\tBegin transformation: LSTM, RNN") + graph: Graph = BackpackTracer().trace(module) + nodes = [ + n + for n in graph.nodes + if n.op == "call_module" + and isinstance(module.get_submodule(n.target), (RNN, LSTM)) + and module.get_submodule(n.target).num_layers > 1 + ] + for node in nodes: + if len(node.args) > 1: + raise NotImplementedError( + "For conversion, LSTM/RNN input must not have hidden states." + ) + lstm_module_replace = _make_rnn_backpack(module.get_submodule(node.target)) + module.add_module(node.target, lstm_module_replace) + + graph.lint() + if debug: + print(f"\tRNNs, LSTMs transformed: {len(nodes)}") return GraphModule(module, graph) +def _rnn_hyperparams(module: Union[RNN, LSTM]) -> Tuple[int, int, float, bool]: + """Determines the hyperparameters for RNN/LSTM conversion. + + Args: + module: module to convert + + Returns: + input_size, hidden_size, dropout, batch_first + + Raises: + NotImplementedError: if any hyperparameter has a forbidden value + """ + if module.bias is not True: + raise NotImplementedError("only bias = True is supported") + if module.bidirectional is not False: + raise NotImplementedError("only bidirectional = False is supported") + if isinstance(module, RNN): + if module.nonlinearity != "tanh": + raise NotImplementedError("only nonlinearity = 'tanh' is supported") + elif isinstance(module, LSTM): + if module.proj_size != 0: + raise NotImplementedError("only proj_size = 0 is supported") + return module.input_size, module.hidden_size, module.dropout, module.batch_first + + +def _make_rnn_backpack(module: Union[RNN, LSTM]) -> Module: + """Creates an equivalent module to the multi-layer RNN/LSTM. + + Converts multi-layer RNN/LSTM to Sequential with single-layer RNN/LSTM. + If dropout probability is nonzero, creates intermediate dropout layers. + Finally, copies training mode. + + Args: + module: RNN/LSTM module to convert + + Returns: + equivalent Sequential module + """ + input_size, hidden_size, dropout, batch_first = _rnn_hyperparams(module) + rnn_class = type(module) + rnn_module_replace = Sequential() + for layer in range(module.num_layers): + rnn_layer = rnn_class( + input_size if layer == 0 else hidden_size, + hidden_size, + batch_first=batch_first, + ) + for param_str in ["weight_ih_l", "weight_hh_l", "bias_ih_l", "bias_hh_l"]: + setattr(rnn_layer, f"{param_str}0", getattr(module, f"{param_str}{layer}")) + rnn_module_replace.add_module(f"lstm_{layer}", rnn_layer) + if layer != (module.num_layers - 1): + rnn_module_replace.add_module(f"reduce_tuple_{layer}", ReduceTuple()) + if dropout != 0: + rnn_module_replace.add_module(f"dropout_{layer}", Dropout(dropout)) + rnn_module_replace.train(module.training) + return rnn_module_replace + + def _transform_inplace_to_normal( module: Module, debug: bool, initialize_recursion: bool = True ) -> None: diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py index 11ac40ba6..d71361a14 100644 --- a/backpack/custom_module/permute.py +++ b/backpack/custom_module/permute.py @@ -8,16 +8,23 @@ class Permute(Module): """Module to permute a tensor.""" - def __init__(self, *dims: Any, batch_axis: int = 0): + def __init__(self, *dims: Any, init_transpose: bool = False, batch_axis: int = 0): """Initialization. + This module supports two variants: permutation and transposition. + If transposition should be used, a tuple (axis1, axis2) should be provided and + init_transpose must be True. + Internally, this is converted to a permutation in the first forward pass. + Args: dims: The desired ordering of dimensions. + init_transpose: If transpose parameters are provided. Default: False. batch_axis: Which axis assumed to be the batch axis in a forward pass. Defaults to ``0``. """ super().__init__() self.dims = dims + self.init_transpose = init_transpose self.batch_axis = batch_axis def forward(self, input: Tensor) -> Tensor: @@ -29,8 +36,22 @@ def forward(self, input: Tensor) -> Tensor: Returns: view with new ordering """ + self._convert_transpose_to_permute(input) return input.permute(self.dims) + def _convert_transpose_to_permute(self, input: Tensor): + """Converts the parameters of transpose to a permutation. + + Args: + input: input tensor. Used to determine number of dimensions. + """ + if self.init_transpose: + permutation = list(range(input.dim())) + permutation[self.dims[0]] = self.dims[1] + permutation[self.dims[1]] = self.dims[0] + self.dims = tuple(permutation) + self.init_transpose = False + def get_batch_axis(self, io_str: str) -> int: """Return the batch axis assumed by the module. diff --git a/fully_documented.txt b/fully_documented.txt index 16bd4b4ed..001a9fc13 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -17,6 +17,7 @@ backpack/core/derivatives/embedding.py backpack/core/derivatives/crossentropyloss.py backpack/core/derivatives/scale_module.py backpack/core/derivatives/sum_module.py +backpack/core/derivatives/dropout.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py diff --git a/test/converter/converter_cases.py b/test/converter/converter_cases.py index 67dc9f832..508c7fc2d 100644 --- a/test/converter/converter_cases.py +++ b/test/converter/converter_cases.py @@ -10,10 +10,25 @@ import abc from typing import List, Type -from torch import Tensor, flatten, rand -from torch.nn import Linear, Module, ReLU +from torch import Tensor, flatten, rand, randint, transpose, zeros_like +from torch.nn import ( + LSTM, + RNN, + CrossEntropyLoss, + Dropout, + Embedding, + Linear, + Module, + MSELoss, + ReLU, +) from torchvision.models import resnet18, wide_resnet50_2 +from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 + +if TORCH_VERSION_AT_LEAST_1_9_0: + from torch import permute + class ConverterModule(Module, abc.ABC): """Interface class for test modules for converter.""" @@ -27,6 +42,14 @@ def input_fn(self) -> Tensor: """ return + def loss_fn(self) -> Module: + """The loss function. + + Returns: + loss function + """ + return MSELoss() + CONVERTER_MODULES: List[Type[ConverterModule]] = [] @@ -57,9 +80,10 @@ def input_fn(self) -> Tensor: class _InplaceActivation(ConverterModule): def __init__(self): + super().__init__() + self.batch_size = 3 self.in_dim = 3 out_dim = 2 - super().__init__() self.linear = Linear(self.in_dim, out_dim) self.relu = ReLU(inplace=True) self.linear2 = Linear(out_dim, out_dim) @@ -71,12 +95,13 @@ def forward(self, x): return x def input_fn(self) -> Tensor: - return rand(3, self.in_dim) + return rand(self.batch_size, self.in_dim) class _MultipleUsages(ConverterModule): def __init__(self): super().__init__() + self.batch_size = 3 self.in_dim = 3 out_dim = 2 self.linear = Linear(self.in_dim, out_dim) @@ -92,45 +117,52 @@ def forward(self, x): return x def input_fn(self) -> Tensor: - return rand(3, self.in_dim) + return rand(self.batch_size, self.in_dim) class _FlattenNetwork(ConverterModule): def __init__(self): super().__init__() - self.in_dim = 4 + self.batch_size = 3 + self.in_dim = (2, 2, 4) out_dim = 3 - self.linear = Linear(self.in_dim, out_dim) + self.linear = Linear(self.in_dim[2], out_dim) + self.linear2 = Linear(self.in_dim[1] * out_dim, out_dim) def forward(self, x): x = self.linear(x) - x = flatten(x, 1) + x = flatten(x, 2) # built-in function flatten + x = self.linear2(x) + x = x.flatten(1) # method flatten return x def input_fn(self) -> Tensor: - return rand(3, 2, 2, self.in_dim) + return rand(self.batch_size, *self.in_dim) class _Multiply(ConverterModule): def __init__(self): super().__init__() + self.batch_size = 2 self.in_dim = 4 out_dim = 3 self.linear = Linear(self.in_dim, out_dim) def forward(self, x): - x = x * 2.5 + x = x * 2.5 # built-in method multiply (Tensor-float) x = self.linear(x) - x = 0.5 * x + x = 0.5 * x # built-in method multiply (float-Tensor) + x = x.multiply(3.1415) # method multiply return x def input_fn(self) -> Tensor: - return rand(2, self.in_dim) + return rand(self.batch_size, self.in_dim) class _Add(ConverterModule): def __init__(self): super().__init__() + self.batch_size = 3 self.in_dim = 3 out_dim = 2 self.linear = Linear(self.in_dim, self.in_dim) @@ -142,12 +174,114 @@ def forward(self, x): x = self.linear(x) x1 = self.linear1(x) x2 = self.linear2(x) - x = x1 + x2 + x = x1 + x2 # built-in method add x = self.relu(x) + x = x.add(x2) # method add + return x + + def input_fn(self) -> Tensor: + return rand(self.batch_size, self.in_dim) + + +class _Permute(ConverterModule): + def __init__(self): + super().__init__() + self.batch_size = 3 + self.in_dim = (5, 3) + out_dim = 2 + self.linear = Linear(self.in_dim[-1], out_dim) + self.linear2 = Linear(self.in_dim[-2], out_dim) + + def forward(self, x): + x = self.linear(x) + x = x.permute(0, 2, 1) # method permute + x = self.linear2(x) + x = permute(x, (0, 2, 1)) # function permute + return x + + def input_fn(self) -> Tensor: + return rand(self.batch_size, *self.in_dim) + + def loss_fn(self) -> Module: + return CrossEntropyLoss() + + +class _Transpose(ConverterModule): + def __init__(self): + super().__init__() + self.batch_size = 3 + self.in_dim = (5, 3) + out_dim = 2 + out_dim2 = 3 + self.linear = Linear(self.in_dim[-1], out_dim) + self.linear2 = Linear(self.in_dim[-2], out_dim2) + + def forward(self, x): + x = self.linear(x) + x = x.transpose(1, 2) # method transpose + x = self.linear2(x) + x = transpose(x, 1, 2) # function transpose return x def input_fn(self) -> Tensor: - return rand(3, self.in_dim) + return rand(self.batch_size, *self.in_dim) + + def loss_fn(self) -> Module: + return CrossEntropyLoss() + + +class _TolstoiCharRNN(ConverterModule): + def __init__(self): + super(_TolstoiCharRNN, self).__init__() + self.batch_size = 8 + self.hidden_dim = 64 + self.num_layers = 2 + self.seq_len = 15 + self.vocab_size = 25 + + self.embedding = Embedding( + num_embeddings=self.vocab_size, embedding_dim=self.hidden_dim + ) + self.dropout = Dropout(p=0.2) + self.lstm = LSTM( + input_size=self.hidden_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + dropout=0.36, + batch_first=True, + ) + self.lstm.bias_ih_l0.data = zeros_like(self.lstm.bias_ih_l0) + self.lstm.bias_ih_l1.data = zeros_like(self.lstm.bias_ih_l1) + self.lstm.bias_ih_l0.requires_grad = False + self.lstm.bias_ih_l1.requires_grad = False + self.dense = Linear(in_features=self.hidden_dim, out_features=self.vocab_size) + + def forward(self, x): + x = self.embedding(x) + x = self.dropout(x) + x, new_state = self.lstm(x) + x = self.dropout(x) + output = self.dense(x) + output = output.permute(0, 2, 1) + return output + + def input_fn(self) -> Tensor: + return randint(0, self.vocab_size, (self.batch_size, self.seq_len)) + + def loss_fn(self) -> Module: + return CrossEntropyLoss() + + +class _TolstoiRNNVersion(_TolstoiCharRNN): + def __init__(self): + super(_TolstoiRNNVersion, self).__init__() + self.lstm = RNN( + input_size=self.hidden_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + dropout=0.36, + batch_first=True, + ) CONVERTER_MODULES += [ @@ -158,4 +292,8 @@ def input_fn(self) -> Tensor: _FlattenNetwork, _Multiply, _Add, + _Permute, + _Transpose, + _TolstoiCharRNN, + _TolstoiRNNVersion, ] diff --git a/test/converter/test_converter.py b/test/converter/test_converter.py index 66d390474..735a880b9 100644 --- a/test/converter/test_converter.py +++ b/test/converter/test_converter.py @@ -4,12 +4,13 @@ - whether DiagGGN runs without errors on new network """ from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule +from test.core.derivatives.utils import classification_targets, regression_targets from test.utils.skip_test import skip_pytorch_below_1_9_0 from typing import Tuple from pytest import fixture -from torch import Tensor, allclose, cat, int32, linspace, manual_seed, rand_like -from torch.nn import Module, MSELoss +from torch import Tensor, allclose, cat, int32, linspace, manual_seed +from torch.nn import CrossEntropyLoss, Module, MSELoss from backpack import backpack, extend from backpack.extensions import DiagGGNExact @@ -20,22 +21,22 @@ params=CONVERTER_MODULES, ids=[str(model_class) for model_class in CONVERTER_MODULES], ) -def model_and_input(request) -> Tuple[Module, Tensor]: +def model_and_input(request) -> Tuple[Module, Tensor, Module]: """Yield ResNet model and an input to it. Args: request: pytest request Yields: - model and input + model and input and loss function """ manual_seed(0) skip_pytorch_below_1_9_0() model: ConverterModule = request.param() inputs: Tensor = model.input_fn() - inputs.requires_grad = True - yield model, inputs - del model, inputs + loss_fn: Module = model.loss_fn() + yield model, inputs, loss_fn + del model, inputs, loss_fn def test_network_diag_ggn(model_and_input): @@ -49,16 +50,28 @@ def test_network_diag_ggn(model_and_input): Args: model_and_input: module to test + + Raises: + NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss """ - model_original, x = model_and_input + model_original, x, loss_fn = model_and_input + model_original = model_original.eval() output_compare = model_original(x) - y = rand_like(output_compare) - - num_params = sum(p.numel() for p in model_original.parameters()) + if isinstance(loss_fn, MSELoss): + y = regression_targets(output_compare.shape) + elif isinstance(loss_fn, CrossEntropyLoss): + y = classification_targets( + (output_compare.shape[0], *output_compare.shape[2:]), + output_compare.shape[1], + ) + else: + raise NotImplementedError(f"test cannot handle loss_fn = {type(loss_fn)}") + + num_params = sum(p.numel() for p in model_original.parameters() if p.requires_grad) num_to_compare = 10 idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32) diag_ggn_exact_to_compare = autograd_diag_ggn_exact( - x, y, model_original, MSELoss(), idx=idx_to_compare + x, y, model_original, loss_fn, idx=idx_to_compare ) model_extended = extend(model_original, use_converter=True, debug=True) @@ -66,14 +79,18 @@ def test_network_diag_ggn(model_and_input): assert allclose(output, output_compare) - loss = extend(MSELoss())(output, y) + loss = extend(loss_fn)(output, y) with backpack(DiagGGNExact()): loss.backward() diag_ggn_exact_vector = cat( - [p.diag_ggn_exact.flatten() for p in model_extended.parameters()] + [ + p.diag_ggn_exact.flatten() + for p in model_extended.parameters() + if p.requires_grad + ] ) for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): - assert allclose(element, diag_ggn_exact_vector[idx]) + assert allclose(element, diag_ggn_exact_vector[idx], atol=1e-5) diff --git a/test/core/derivatives/rnn_settings.py b/test/core/derivatives/rnn_settings.py index 8b4bd88ec..933cf52af 100644 --- a/test/core/derivatives/rnn_settings.py +++ b/test/core/derivatives/rnn_settings.py @@ -24,4 +24,10 @@ "module_fn": lambda: torch.nn.RNN(input_size=4, hidden_size=2), "input_fn": lambda: torch.rand(size=(1, 1, 4)), }, + { + "module_fn": lambda: torch.nn.RNN( + input_size=4, hidden_size=3, batch_first=True + ), + "input_fn": lambda: torch.rand(size=(3, 5, 4)), + }, ] diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 59316c062..808921c43 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -75,11 +75,10 @@ "target_fn": lambda: classification_targets((4,), 4 * 3), }, { - "input_fn": lambda: rand(5, 8, 6), # TODO (8, 5, 6) + "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - RNN(input_size=6, hidden_size=3), # TODO batch_first=True + RNN(input_size=6, hidden_size=3, batch_first=True), ReduceTuple(index=0), - Permute(1, 0, 2, batch_axis=1), # TODO remove Linear(3, 3), Permute(0, 2, 1, batch_axis=0), ), From ff1fbd5bb1fd34b357287f5d9392e974f195829b Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Thu, 7 Oct 2021 14:41:41 +0200 Subject: [PATCH 50/54] [DOC] Improve documentation (#224) Update documentation for new release. --- * [DOC] Improve context manager documentation * [DOC] Improve limitations section * [DOC] Update supported layers, mention converter --- backpack/__init__.py | 2 +- docs_src/rtd/good-to-know.rst | 43 +++++------ docs_src/rtd/main-api.rst | 3 +- docs_src/rtd/supported-layers.rst | 118 +++++++++++++++++------------- 4 files changed, 91 insertions(+), 75 deletions(-) diff --git a/backpack/__init__.py b/backpack/__init__.py index 86a6f8cb4..8e4ab64ef 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -20,7 +20,7 @@ class backpack: - """Activate BackPACK extensions.""" + """Context manager to activate BackPACK extensions.""" def __init__( self, diff --git a/docs_src/rtd/good-to-know.rst b/docs_src/rtd/good-to-know.rst index 8f7a777ad..18131e80f 100644 --- a/docs_src/rtd/good-to-know.rst +++ b/docs_src/rtd/good-to-know.rst @@ -19,7 +19,7 @@ and a backward pass over each sample individually. This is slow, but can be used to check that the values returned by BackPACK match what you expect them to be. -While we test many a use-case and try to write solid code, unexpected +While we test many use-cases and try to write solid code, unexpected behavior (such as some listed on this page) or bugs are not impossible. We recommend that you check that the outputs match your expectations, especially if you're using non-default values on slightly more unusual parameters @@ -55,32 +55,29 @@ are not affected by :py:meth:`zero_grad() `. The :ref:`intro example ` shows how to make a model using a :py:class:`torch.nn.Sequential` module -and how to :py:func:`extend() ` the model and the loss function, -but this setup is only really necessary for -:ref:`second order quantities `. +and how to :py:func:`extend() ` the model and the loss function. +But extending everything is only really necessary for +:ref:`Second order quantities `. For those, BackPACK needs to know about the structure of the whole network -to propagate additional information. - -:ref:`First order extensions ` are more flexible, -and the only :py:class:`torch.nn.Module` that need to be extended -are modules with parameters, to extract more information, -as the gradients are already propagated by PyTorch. -For every operations that is not parametrized, you can use standard operations -from the :std:doc:`torch.nn.functional ` module or standard -tensor operations. This makes it possible to use first order extensions -for ResNets (see :ref:`this example `). +to backpropagate additional information. +Because :ref:`First order extensions ` don't +backpropagate additional information, they are more flexible and only require +every parameterized :py:class:`torch.nn.Module` be extended. For any +unparameterized operation, you can use standard operations from the +:std:doc:`torch.nn.functional ` module or standard tensor +operations. Not (yet) supported models ---------------------------------- -The second-order extensions for BackPACK partially support residual and -recurrent networks. We're working on how to handle those, as well as -adding more :ref:`layers `. -Along those lines, some things that will (most likely) not work with BackPACK, -but that we're trying to build support for: - -- Reusing the same parameters multiple times in the computation graph. +We're working on handling more complex computation graphs, as well as adding +more :ref:`layers `. Along those lines, some things that will +(most likely) **not** work with BackPACK are: - This sadly mean that BackPACK can't compute the individual gradients or - second-order information of a L2-regularized loss, for example. +- **Reusing the same parameters multiple times in the computation graph:** This + sadly means that BackPACK can't compute the individual gradients or + second-order information of an L2-regularized loss or architectures that use + parameters multiple times. +- **Some exotic hyperparameters are not fully supported:** Feature requests on + the repository are welcome. diff --git a/docs_src/rtd/main-api.rst b/docs_src/rtd/main-api.rst index 67a4d0c52..dce6792e2 100644 --- a/docs_src/rtd/main-api.rst +++ b/docs_src/rtd/main-api.rst @@ -66,5 +66,6 @@ and the :ref:`Supported models`. ----- .. autofunction:: backpack.extend -.. autofunction:: backpack.backpack +.. autoclass:: backpack.backpack + :members: __init__ .. autofunction:: backpack.disable diff --git a/docs_src/rtd/supported-layers.rst b/docs_src/rtd/supported-layers.rst index 149bb7dc2..41887141c 100644 --- a/docs_src/rtd/supported-layers.rst +++ b/docs_src/rtd/supported-layers.rst @@ -14,12 +14,13 @@ For example, torch.nn.Linear(64, 10) ) -This page lists the layers currently supported by BackPACK. - +**If you overwrite any** :code:`forward()` **function** (for example in ResNets +and RNNs), the additional backward pass to compute second-order quantities will +break. You can ask BackPACK to inspect the graph and try converting it +into a compatible structure by setting :code:`use_converter=True` in +:py:func:`extend `. -**Do not rewrite the** :code:`forward()` **function of the** :code:`Sequential` **or the inner modules!** -If the forward is not standard, the additional backward pass to compute second-order quantities will not match the actual function. -First-order extensions that extract information might work outside of this framework, but it is not tested. +This page lists the layers currently supported by BackPACK. .. raw:: html @@ -38,55 +39,72 @@ parameters of the following layers; - :py:class:`torch.nn.ConvTranspose1d`, :py:class:`torch.nn.ConvTranspose2d`, :py:class:`torch.nn.ConvTranspose3d` +- :py:class:`torch.nn.BatchNorm1d` (evaluation mode), + :py:class:`torch.nn.BatchNorm2d` (evaluation mode), + :py:class:`torch.nn.BatchNorm3d` (evaluation mode) +- :py:class:`torch.nn.Embedding` +- :py:class:`torch.nn.RNN`, :py:class:`torch.nn.LSTM` -First-order extensions should support any module as long as they do not have parameters, -but some layers lead to the concept of "individual gradient for a sample in a minibatch" -to be ill-defined, as they introduce dependencies across examples -(like :py:class:`torch.nn.BatchNorm`). +Some layers (like :code:`torch.nn.BatchNormNd` in training mode) mix samples and +lead to ill-defined first-order quantities. ----- For second-order extensions -------------------------------------- -BackPACK needs to know how to propagate second-order information. -This is implemented for: - -+-------------------------------+---------------------------------------+ -| **Parametrized layers** | :py:class:`torch.nn.Conv1d`, | -| | :py:class:`torch.nn.Conv2d`, | -| | :py:class:`torch.nn.Conv3d` | -| +---------------------------------------+ -| | :py:class:`torch.nn.ConvTranspose1d`, | -| | :py:class:`torch.nn.ConvTranspose2d`, | -| | :py:class:`torch.nn.ConvTranspose3d` | -| +---------------------------------------+ -| | :py:class:`torch.nn.Linear` | -+-------------------------------+---------------------------------------+ -| **Loss functions** | :py:class:`torch.nn.MSELoss` | -| +---------------------------------------+ -| | :py:class:`torch.nn.CrossEntropyLoss` | -+-------------------------------+---------------------------------------+ -| **Layers without parameters** | :py:class:`torch.nn.MaxPool1d`, | -| | :py:class:`torch.nn.MaxPool2d`, | -| | :py:class:`torch.nn.MaxPool3d` | -| +---------------------------------------+ -| | :py:class:`torch.nn.AvgPool1d`, | -| | :py:class:`torch.nn.AvgPool2d`, | -| | :py:class:`torch.nn.AvgPool3d` | -| +---------------------------------------+ -| | :py:class:`torch.nn.ZeroPad2d`, | -| +---------------------------------------+ -| | :py:class:`torch.nn.Dropout` | -| +---------------------------------------+ -| | :py:class:`torch.nn.ReLU`, | -| | :py:class:`torch.nn.Sigmoid`, | -| | :py:class:`torch.nn.Tanh`, | -| | :py:class:`torch.nn.LeakyReLU`, | -| | :py:class:`torch.nn.LogSigmoid`, | -| | :py:class:`torch.nn.ELU`, | -| | :py:class:`torch.nn.SELU` | -+-------------------------------+---------------------------------------+ - -Some exotic hyperparameters are not fully supported, but feature requests -on the repository are welcome. +BackPACK needs to know how to backpropagate additional information for +second-order quantities. This is implemented for: + ++-------------------------------+-----------------------------------------------+ +| **Parametrized layers** | :py:class:`torch.nn.Conv1d`, | +| | :py:class:`torch.nn.Conv2d`, | +| | :py:class:`torch.nn.Conv3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.ConvTranspose1d`, | +| | :py:class:`torch.nn.ConvTranspose2d`, | +| | :py:class:`torch.nn.ConvTranspose3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.Linear` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.BatchNorm1d`, | +| | :py:class:`torch.nn.BatchNorm2d`, | +| | :py:class:`torch.nn.BatchNorm3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.Embedding` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.RNN`, | +| | :py:class:`torch.nn.LSTM` | ++-------------------------------+-----------------------------------------------+ +| **Loss functions** | :py:class:`torch.nn.MSELoss` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.CrossEntropyLoss` | ++-------------------------------+-----------------------------------------------+ +| **Layers without parameters** | :py:class:`torch.nn.MaxPool1d`, | +| | :py:class:`torch.nn.MaxPool2d`, | +| | :py:class:`torch.nn.MaxPool3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.AvgPool1d`, | +| | :py:class:`torch.nn.AvgPool2d`, | +| | :py:class:`torch.nn.AvgPool3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.AdaptiveAvgPool1d`, | +| | :py:class:`torch.nn.AdaptiveAvgPool2d`, | +| | :py:class:`torch.nn.AdaptiveAvgPool3d` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.ZeroPad2d`, | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.Dropout` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.ReLU`, | +| | :py:class:`torch.nn.Sigmoid`, | +| | :py:class:`torch.nn.Tanh`, | +| | :py:class:`torch.nn.LeakyReLU`, | +| | :py:class:`torch.nn.LogSigmoid`, | +| | :py:class:`torch.nn.ELU`, | +| | :py:class:`torch.nn.SELU` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.Identity` | +| +-----------------------------------------------+ +| | :py:class:`torch.nn.Flatten` | ++-------------------------------+-----------------------------------------------+ From 0864de42d4ffaaf25279047a77408cbc189d3a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 8 Oct 2021 14:20:00 +0200 Subject: [PATCH 51/54] [REF] Bump torch version to 1.9.0 (#226) - Use `torch>=1.9.0` - Remove `backward_hook`, fully rely on `full_backward_hook` - Remove `ActiveIdentity` and tests - Use ellipsis in `einsum` equations - Remove exceptions for unavailable imports when using `torch<1.9.0` --- * [REF] torch >= 1.7.0 * [REF] torch >= 1.8.0 * [REF] torch >= 1.9.0 * [FIX] test semantics * [FIX] remove python 3.10 from tests * [DEL] remove ActiveIdentity, move test to diag_ggn_settings.py * [DEL] remove _no_inplace * [DEL] remove flatten-no-op test * [DOC] simplify docstring * [DEL] linear.py: remove _has_additional_dims() * [FIX] example_resnet_all_in_one.py: remove ActiveIdentity * [DEL] Redundant branching test Co-authored-by: Felix Dangel --- .github/workflows/test.yaml | 7 +- backpack/__init__.py | 22 +-- backpack/core/derivatives/batchnorm_nd.py | 12 +- backpack/core/derivatives/elementwise.py | 31 ---- backpack/core/derivatives/embedding.py | 10 +- backpack/core/derivatives/linear.py | 29 +--- backpack/core/derivatives/lstm.py | 6 +- backpack/core/derivatives/scale_module.py | 4 +- backpack/custom_module/branching.py | 10 -- backpack/custom_module/graph_utils.py | 15 +- backpack/extensions/module_extension.py | 24 +-- backpack/extensions/saved_quantities.py | 4 +- .../secondorder/diag_ggn/__init__.py | 4 +- backpack/utils/__init__.py | 24 --- backpack/utils/module_classification.py | 9 +- .../use_cases/example_resnet_all_in_one.py | 13 +- setup.cfg | 2 +- test/converter/converter_cases.py | 7 +- test/converter/test_branching.py | 145 ------------------ test/converter/test_converter.py | 2 - test/core/derivatives/__init__.py | 3 +- .../core/derivatives/scale_module_settings.py | 5 - .../firstorder/firstorder_settings.py | 28 ++-- .../secondorder/diag_ggn/diag_ggn_settings.py | 49 +++--- test/extensions/secondorder/hbp/test_kfac.py | 4 +- test/extensions/secondorder/hbp/test_kflr.py | 4 +- test/extensions/secondorder/hbp/test_kfra.py | 4 +- .../secondorder/secondorder_settings.py | 15 -- test/extensions/test_hooks.py | 5 +- test/utils/skip_test.py | 8 +- 30 files changed, 68 insertions(+), 437 deletions(-) delete mode 100644 test/converter/test_branching.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3a0d4871a..ee7391bff 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,12 +20,7 @@ jobs: strategy: matrix: python-version: [3.7, 3.8, 3.9] - pytorch-version: [1.6.0, 1.7.1, 1.8.0, 1.9.0] - exclude: - - pytorch-version: 1.6.0 - python-version: 3.9 - - pytorch-version: 1.7.1 - python-version: 3.9 + pytorch-version: [1.9.0, 1.9.1] steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 diff --git a/backpack/__init__.py b/backpack/__init__.py index 8e4ab64ef..85b5aa578 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -4,20 +4,16 @@ from typing import Callable, Optional, Tuple, Type, Union from torch import Tensor, is_grad_enabled +from torch.fx import GraphModule from torch.nn import Module from backpack import extensions from backpack.context import CTX +from backpack.custom_module.graph_utils import convert_module_to_backpack from backpack.extensions.backprop_extension import BackpropExtension -from backpack.utils import CONVERTER_AVAILABLE, FULL_BACKWARD_HOOK from backpack.utils.hooks import no_op from backpack.utils.module_classification import is_no_op -if CONVERTER_AVAILABLE: - from torch.fx import GraphModule - - from backpack.custom_module.graph_utils import convert_module_to_backpack - class backpack: """Context manager to activate BackPACK extensions.""" @@ -244,17 +240,11 @@ def extend(module: Module, debug: bool = False, use_converter: bool = False) -> Returns: Extended module. - - Raises: - RuntimeError: if trying to use converter without torch>=1.9.0 """ if debug: print("[DEBUG] Extending", module) if use_converter: - if not CONVERTER_AVAILABLE: - raise RuntimeError("use_converter=True is only available for torch>=1.9.0.") - module: GraphModule = convert_module_to_backpack(module, debug) return extend(module) @@ -277,10 +267,4 @@ def _register_hooks(module: Module) -> None: module: module that is going to be extended """ CTX.add_hook_handle(module.register_forward_hook(hook_store_io)) - - if FULL_BACKWARD_HOOK: - register_backward_hook_fn = module.register_full_backward_hook - else: - register_backward_hook_fn = module.register_backward_hook - - CTX.add_hook_handle(register_backward_hook_fn(hook_run_extensions)) + CTX.add_hook_handle(module.register_full_backward_hook(hook_run_extensions)) diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index 93c8cf99c..7fe15255a 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -5,7 +5,6 @@ from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 from backpack.utils.subsampling import subsample @@ -148,16 +147,7 @@ def _weight_jac_t_mat_prod( x_hat, _ = self._get_normalized_input_and_var(module) x_hat = subsample(x_hat, subsampling=subsampling) - if TORCH_VERSION_AT_LEAST_1_9_0: - equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c" - # TODO Remove else-branch after deprecating torch<1.9.0 - else: - N: int = self._get_n_axis(module) - spatial_dims = "xyz"[:N] - equation = ( - f"vnc{spatial_dims},nc{spatial_dims}->v{'' if sum_batch else 'n'}c" - ) - + equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c" return einsum(equation, mat, x_hat) def _bias_jac_mat_prod( diff --git a/backpack/core/derivatives/elementwise.py b/backpack/core/derivatives/elementwise.py index e7ec3de29..33c2b4511 100644 --- a/backpack/core/derivatives/elementwise.py +++ b/backpack/core/derivatives/elementwise.py @@ -69,8 +69,6 @@ def hessian_diagonal(self, module, g_inp, g_out): g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs. g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs. """ - self._no_inplace(module) - return self.d2f(module, g_inp, g_out) * g_out[0] def hessian_is_diagonal(self, module): @@ -89,19 +87,13 @@ def _jac_t_mat_prod( mat: Tensor, subsampling: List[int] = None, ) -> Tensor: - self._no_inplace(module) - df_elementwise = self.df(module, g_inp, g_out, subsampling=subsampling) return einsum("...,v...->v...", df_elementwise, mat) def _jac_mat_prod(self, module, g_inp, g_out, mat): - self._no_inplace(module) - return self.jac_t_mat_prod(module, g_inp, g_out, mat) def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): - self._no_inplace(module) - N = module.input0.size(0) df_flat = self.df(module, g_inp, g_out).reshape(N, -1) return einsum("ni,nj,ij->ij", df_flat, df_flat, mat) / N @@ -109,26 +101,3 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): def _residual_mat_prod(self, module, g_inp, g_out, mat): residual = self.d2f(module, g_inp, g_out) * g_out[0] return einsum("...,v...->v...", residual, mat) - - # TODO Deprecate after supporting torch >= 1.8.0 and full_backward_hook - @staticmethod - def _no_inplace(module: Module): - """Do not support inplace modification. - - Jacobians/Hessians might be computed using the modified input instead - of the original. - - Args: - module: Elementwise activation module. - - Raises: - NotImplementedError: If `module` has inplace option enabled. - - Todo: - - Write tests to investigate what happens with `inplace=True`. - """ - has_inplace_option = hasattr(module, "inplace") - - if has_inplace_option: - if module.inplace is True: - raise NotImplementedError("Inplace not supported in {}.".format(module)) diff --git a/backpack/core/derivatives/embedding.py b/backpack/core/derivatives/embedding.py index dc2529358..acb191be8 100644 --- a/backpack/core/derivatives/embedding.py +++ b/backpack/core/derivatives/embedding.py @@ -5,7 +5,6 @@ from torch.nn import Embedding from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 from backpack.utils.subsampling import subsample @@ -50,14 +49,7 @@ def _weight_jac_t_mat_prod( delta = zeros(module.num_embeddings, *input0.shape, device=mat.device) for s in range(module.num_embeddings): delta[s] = input0 == s - if TORCH_VERSION_AT_LEAST_1_9_0: - equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh" - elif delta.dim() == 2: - equation = f"sn,vnh->v{'' if sum_batch else 'n'}sh" - else: - equation = f"snx,vnxh->v{'' if sum_batch else 'n'}sh" - delta = delta.flatten(start_dim=2, end_dim=-1) - mat = mat.flatten(start_dim=2, end_dim=-2) + equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh" return einsum(equation, delta, mat) def _check_parameters(self, module: Embedding) -> None: diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 27743e3b2..a3156f927 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -5,7 +5,6 @@ from torch.nn import Linear from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 from backpack.utils.subsampling import subsample @@ -152,17 +151,7 @@ def _weight_jac_t_mat_prod( """ d_weight = subsample(module.input0, subsampling=subsampling) - if TORCH_VERSION_AT_LEAST_1_9_0: - equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi" - # TODO Remove else-branch after deprecating torch<1.9.0 - else: - if self._has_additional_dims(module): - d_weight = d_weight.flatten(start_dim=1, end_dim=-2) - mat = mat.flatten(start_dim=2, end_dim=-2) - equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi" - else: - equation = f"vno,ni->v{'' if sum_batch else 'n'}oi" - + equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi" return einsum(equation, mat, d_weight) def _bias_jac_mat_prod( @@ -223,22 +212,6 @@ def _bias_jac_t_mat_prod( equation = f"vn...o->v{'' if sum_batch else 'n'}o" return einsum(equation, mat) - # TODO Remove after deprecating torch<1.9.0 - @classmethod - def _has_additional_dims(cls, module: Linear) -> bool: - """Return whether the input to a linear layer has additional (>1) dimensions. - - The input to a linear layer may have shape ``[N, *, out_features]``. - It has additional dimensions if ``*`` is non-empty. - - Args: - module: Linear layer. - - Returns: - Whether the input has hidden dimensions. - """ - return len(cls._get_additional_dims(module)) != 0 - @staticmethod def _get_additional_dims(module: Linear) -> Size: """Return the shape of additional dimensions in the input to a linear layer. diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 5734d719f..5e4b003c8 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -5,7 +5,6 @@ from torch.nn import LSTM from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_8_0 from backpack.utils.subsampling import subsample @@ -51,9 +50,8 @@ def _check_parameters(module: LSTM) -> None: raise NotImplementedError("only dropout = 0 is supported") if module.bidirectional is not False: raise NotImplementedError("only bidirectional = False is supported") - if TORCH_VERSION_AT_LEAST_1_8_0: - if module.proj_size != 0: - raise NotImplementedError("only proj_size = 0 is supported") + if module.proj_size != 0: + raise NotImplementedError("only proj_size = 0 is supported") @staticmethod def _forward_pass( diff --git a/backpack/core/derivatives/scale_module.py b/backpack/core/derivatives/scale_module.py index d83d33c04..9965a204c 100644 --- a/backpack/core/derivatives/scale_module.py +++ b/backpack/core/derivatives/scale_module.py @@ -1,4 +1,4 @@ -"""Derivatives of ScaleModule (implies ActiveIdentity and Identity).""" +"""Derivatives of ScaleModule (implies Identity).""" from typing import List, Tuple, Union from torch import Tensor @@ -9,7 +9,7 @@ class ScaleModuleDerivatives(BaseDerivatives): - """Derivatives of ScaleModule (implies ActiveIdentity and Identity).""" + """Derivatives of ScaleModule (implies Identity).""" def _jac_t_mat_prod( self, diff --git a/backpack/custom_module/branching.py b/backpack/custom_module/branching.py index fc0d5149a..109b88c87 100644 --- a/backpack/custom_module/branching.py +++ b/backpack/custom_module/branching.py @@ -4,16 +4,6 @@ from torch import Tensor from torch.nn import Module -from backpack.custom_module.scale_module import ScaleModule - - -class ActiveIdentity(ScaleModule): - """Like ``torch.nn.Identity``, but creates a new node in the computation graph.""" - - def __init__(self): - """Initialization with weight=1.0.""" - super().__init__(weight=1.0) - class _Branch(Module): """Module used by BackPACK to handle branching in the computation graph. diff --git a/backpack/custom_module/graph_utils.py b/backpack/custom_module/graph_utils.py index e48a451ee..62bc0ea03 100644 --- a/backpack/custom_module/graph_utils.py +++ b/backpack/custom_module/graph_utils.py @@ -6,11 +6,10 @@ from torch.fx import Graph, GraphModule, Node, Tracer from torch.nn import LSTM, RNN, Dropout, Flatten, Module, Sequential -from backpack.custom_module.branching import ActiveIdentity, SumModule, _Branch +from backpack.custom_module.branching import SumModule, _Branch from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple from backpack.custom_module.scale_module import ScaleModule -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 class BackpackTracer(Tracer): @@ -19,9 +18,7 @@ class BackpackTracer(Tracer): def is_leaf_module( self, m: Module, module_qualified_name: str ) -> bool: # noqa: D102 - if isinstance( - m, (ScaleModule, SumModule, _Branch, ActiveIdentity, ReduceTuple, Permute) - ): + if isinstance(m, (ScaleModule, SumModule, _Branch, ReduceTuple, Permute)): return True else: return super().is_leaf_module(m, module_qualified_name) @@ -49,15 +46,7 @@ def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule: Returns: BackPACK-compatible module - - Raises: - NotImplementedError: if not torch >= 1.9.0 """ - if TORCH_VERSION_AT_LEAST_1_9_0 is False: - raise NotImplementedError( - "Conversion is only possible for torch >= 1.9.0. This is because these " - "functions use functionality such as torch.nn.Module.get_submodule" - ) if debug: print("\nMake module BackPACK-compatible...") module_new = _transform_mul_to_scale_module(module, debug) diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index d530788a7..96e4fdfd1 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -5,9 +5,8 @@ from warnings import warn from torch import Tensor -from torch.nn import Flatten, Module +from torch.nn import Module -from backpack.utils import FULL_BACKWARD_HOOK from backpack.utils.module_classification import is_loss if TYPE_CHECKING: @@ -104,23 +103,10 @@ def __call__( and bp_quantity is None and not is_loss(module) ): - if not FULL_BACKWARD_HOOK and isinstance(module, Flatten): - # Flatten layers whose input is already flat do not add a node to the - # graph. This leads to unintuitive order of backward hook execution: - # https://discuss.pytorch.org/t/backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4. # noqa: B950 - # Skip everything below if this scenario is encountered. - no_op = module.input0.shape == module.output.shape - if not no_op: - raise AssertionError( - "Expected no op Flatten module. Got " - + f"{module.input0.shape} -> {module.output.shape}" - ) - return - else: - raise AssertionError( - "BackPACK extension expects a backpropagation quantity but it is None. " - f"Module: {module}, Extension: {extension}." - ) + raise AssertionError( + "BackPACK extension expects a backpropagation quantity but it is None. " + f"Module: {module}, Extension: {extension}." + ) for param in self.__params: if self.__param_exists_and_requires_grad(module, param): diff --git a/backpack/extensions/saved_quantities.py b/backpack/extensions/saved_quantities.py index 3eb38ad8e..c5006fc2a 100644 --- a/backpack/extensions/saved_quantities.py +++ b/backpack/extensions/saved_quantities.py @@ -38,9 +38,7 @@ def retrieve_quantity(self, key: int, delete_old: bool) -> Union[Tensor, None]: """Returns the saved quantity. Args: - key: data_ptr() of reference tensor. - For torch>=1.9.0 the reference tensor is grad_output[0]. - For older versions the reference tensor is module.output. + key: data_ptr() of module.output. delete_old: whether to delete the old quantity Returns: diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index df0e02759..e18d7eea3 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -47,7 +47,7 @@ ZeroPad2d, ) -from backpack.custom_module.branching import ActiveIdentity, SumModule +from backpack.custom_module.branching import SumModule from backpack.custom_module.permute import Permute from backpack.custom_module.scale_module import ScaleModule from backpack.extensions.secondorder.base import SecondOrderBackpropExtension @@ -131,7 +131,6 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), Identity: custom_module.DiagGGNScaleModule(), - ActiveIdentity: custom_module.DiagGGNScaleModule(), ScaleModule: custom_module.DiagGGNScaleModule(), SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.DiagGGNRNN(), @@ -255,7 +254,6 @@ def __init__(self, loss_hessian_strategy: str, savefield: str): ELU: activations.DiagGGNELU(), SELU: activations.DiagGGNSELU(), Identity: custom_module.DiagGGNScaleModule(), - ActiveIdentity: custom_module.DiagGGNScaleModule(), ScaleModule: custom_module.DiagGGNScaleModule(), SumModule: custom_module.DiagGGNSumModule(), RNN: rnn.BatchDiagGGNRNN(), diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index d05762d93..d5fb6701b 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -1,32 +1,8 @@ """Contains utility functions.""" -from typing import Type - from pkg_resources import get_distribution, packaging TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) -TORCH_VERSION_AT_LEAST_1_7_0 = TORCH_VERSION >= packaging.version.parse("1.7.0") -TORCH_VERSION_AT_LEAST_1_8_0 = TORCH_VERSION >= packaging.version.parse("1.8.0") -TORCH_VERSION_AT_LEAST_1_9_0 = TORCH_VERSION >= packaging.version.parse("1.9.0") TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1") TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0") -FULL_BACKWARD_HOOK: bool = TORCH_VERSION_AT_LEAST_1_9_0 -CONVERTER_AVAILABLE: bool = TORCH_VERSION_AT_LEAST_1_9_0 ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0 - - -def exception_inside_backward_pass(error: Type[Exception]) -> Type[Exception]: - """Returns the type of exception that gets raised inside a backward pass by PyTorch. - - For Torch>=1.7.0 the error is identical. - - Args: - error: previous exception type - - Returns: - new exception type - """ - if TORCH_VERSION_AT_LEAST_1_7_0: - return error - else: - return RuntimeError diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index c28169eab..b8d9b5b5f 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -1,13 +1,10 @@ """Contains util function for classification of modules.""" +from torch.fx import GraphModule from torch.nn import Module, Sequential from torch.nn.modules.loss import _Loss from backpack.custom_module.branching import Parallel, _Branch from backpack.custom_module.reduce_tuple import ReduceTuple -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 - -if TORCH_VERSION_AT_LEAST_1_9_0: - from torch.fx import GraphModule def is_loss(module: Module) -> bool: @@ -31,7 +28,5 @@ def is_no_op(module: Module) -> bool: Returns: whether module is no operation """ - no_op_modules = (Sequential, _Branch, Parallel, ReduceTuple) - if TORCH_VERSION_AT_LEAST_1_9_0: - no_op_modules += (GraphModule,) + no_op_modules = (Sequential, _Branch, Parallel, ReduceTuple, GraphModule) return isinstance(module, no_op_modules) diff --git a/docs_src/examples/use_cases/example_resnet_all_in_one.py b/docs_src/examples/use_cases/example_resnet_all_in_one.py index c06483de3..0b3398e3b 100644 --- a/docs_src/examples/use_cases/example_resnet_all_in_one.py +++ b/docs_src/examples/use_cases/example_resnet_all_in_one.py @@ -55,7 +55,7 @@ from torchvision.models import resnet18 from backpack import backpack, extend -from backpack.custom_module.branching import ActiveIdentity, Parallel, SumModule +from backpack.custom_module.branching import Parallel, SumModule from backpack.custom_module.graph_utils import BackpackTracer from backpack.extensions import BatchGrad, DiagGGNExact from backpack.utils.examples import autograd_diag_ggn_exact, load_one_batch_mnist @@ -163,15 +163,6 @@ def forward(self, x): # role of :py:func:`torch.add` in the previous example. It sums up multiple inputs. # We will use it to merge the skip connection. # -# 3. :py:class:`ActiveIdentity` acts like -# PyTorch's identity, but fixes the backward hook execution order by inserting a new -# node into the graph during a forward pass (for details see -# `this discussion `_). -# The problem is fixed for ``torch >= 1.9.0``, where it's safe to use -# :py:class:`torch.nn.Identity`. If you are on ``torch < 1.9.0``, you -# have to use :py:class:`ActiveIdentity`. -# # With the above modules, we can build a simple ResNet as a container that implicitly # defines the forward pass: @@ -184,7 +175,7 @@ def forward(self, x): Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1), ReLU(), Parallel( # skip connection with ReLU-activated convolution - ActiveIdentity(), + Identity(), Sequential( Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1), ReLU(), diff --git a/setup.cfg b/setup.cfg index 82d6f1fe1..b74c2d594 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,7 @@ setup_requires = setuptools_scm # Dependencies of the project (semicolon/line-separated): install_requires = - torch >= 1.6.0, < 2.0.0 + torch >= 1.9.0, < 2.0.0 torchvision >= 0.7.0, < 1.0.0 einops >= 0.3.0, < 1.0.0 # Require a specific Python version, e.g. Python 2.7 or >= 3.4 diff --git a/test/converter/converter_cases.py b/test/converter/converter_cases.py index 508c7fc2d..0715be0b5 100644 --- a/test/converter/converter_cases.py +++ b/test/converter/converter_cases.py @@ -10,7 +10,7 @@ import abc from typing import List, Type -from torch import Tensor, flatten, rand, randint, transpose, zeros_like +from torch import Tensor, flatten, permute, rand, randint, transpose, zeros_like from torch.nn import ( LSTM, RNN, @@ -24,11 +24,6 @@ ) from torchvision.models import resnet18, wide_resnet50_2 -from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0 - -if TORCH_VERSION_AT_LEAST_1_9_0: - from torch import permute - class ConverterModule(Module, abc.ABC): """Interface class for test modules for converter.""" diff --git a/test/converter/test_branching.py b/test/converter/test_branching.py deleted file mode 100644 index 4fd968878..000000000 --- a/test/converter/test_branching.py +++ /dev/null @@ -1,145 +0,0 @@ -"""This test demonstrates the custom modules in BackPACK for branching/ResNets. - -It is important for torch < 1.9.0 (no full backward hook), as the test fails without -ActiveIdentity (wrong execution order in backward hook). -For torch>=1.9.0 (full backward hook), all tests pass. - -Additionally, for torch>=1.9.0 there is a convenient option use_converter=True in extend(). - -TODO Delete after supporting torch >= 1.9.0 (full backward hook and converter). -""" -from contextlib import nullcontext -from test.automated_test import check_sizes_and_values -from test.core.derivatives.utils import classification_targets -from typing import Callable, Tuple - -from pytest import mark, raises -from torch import Tensor, cat, manual_seed, rand -from torch.nn import ( - CrossEntropyLoss, - Identity, - Linear, - Module, - ReLU, - Sequential, - Sigmoid, -) - -from backpack import backpack, extend, extensions -from backpack.custom_module.branching import ActiveIdentity, Parallel -from backpack.utils import FULL_BACKWARD_HOOK, exception_inside_backward_pass -from backpack.utils.examples import autograd_diag_ggn_exact - - -def setup( - apply_extend: bool = False, active_identity: bool = True -) -> Tuple[Tensor, Tensor, Module, Module]: - """Set seed. Generate and return inputs, labels, model and loss function. - - A simple ResNet using the ``Parallel`` convenience module around the ``Branch`` and - ``SumModule`` modules to handle branching. - - Args: - active_identity: Whether the identity function should create a new node - in the computation graph. - apply_extend: Whether model and loss function should be extended. - - Returns: - X, y, model, loss_function - """ - manual_seed(0) - - N = 7 - - in_features = 10 - hidden_features = 5 - out_features = 3 - - X = rand((N, in_features)) - y = classification_targets((N,), out_features) - - identity = ActiveIdentity() if active_identity else Identity() - - model = Sequential( - Linear(in_features, hidden_features), - ReLU(), - # skip connection - Parallel( - identity, - Linear(hidden_features, hidden_features), - ), - # end of skip connection - Sigmoid(), - Linear(hidden_features, out_features), - ) - loss_function = CrossEntropyLoss(reduction="mean") - - if apply_extend: - model = extend(model, debug=True) - loss_function = extend(loss_function, debug=True) - - return X, y, model, loss_function - - -def backpack_diag_ggn_exact( - X: Tensor, y: Tensor, model: Module, loss_function: Module -) -> Tensor: - """Compute the generalized Gauss-Newton diagonal via BackPACK. - - Args: - X: data - y: target - model: model - loss_function: loss function - - Returns: - diag_ggn_exact - """ - outputs = model(X) - loss = loss_function(outputs, y) - - with backpack(extensions.DiagGGNExact(), debug=True): - loss.backward() - - return cat([p.diag_ggn_exact.flatten() for p in model.parameters()]) - - -SETUPS = [setup] -SETUPS_IDS = ["simple-resnet"] - - -@mark.parametrize("setup_fn", SETUPS, ids=SETUPS_IDS) -def test_diag_ggn_exact_active_identity(setup_fn: Callable) -> None: - """Compare BackPACK's diagonal GGN of a ResNet with autograd. - - Args: - setup_fn: setup function - """ - X, y, model, loss_function = setup_fn() - autograd_result = autograd_diag_ggn_exact(X, y, model, loss_function) - - X, y, model, loss_function = setup_fn(apply_extend=True) - backpack_result = backpack_diag_ggn_exact(X, y, model, loss_function) - - check_sizes_and_values(autograd_result, backpack_result) - - -@mark.parametrize("setup_fn", SETUPS, ids=SETUPS_IDS) -def test_diag_ggn_exact_nn_identity_fails(setup_fn: Callable) -> None: - """``torch.nn.Identity`` does not create a node and messes up backward hooks. - - However, it works fine if using full backward hook. (torch >= 1.9.0) - - Args: - setup_fn: setup function - """ - X, y, model, loss_function = setup_fn(active_identity=False) - autograd_result = autograd_diag_ggn_exact(X, y, model, loss_function) - - X, y, model, loss_function = setup_fn(apply_extend=True, active_identity=False) - with nullcontext() if FULL_BACKWARD_HOOK else raises( - exception_inside_backward_pass(AssertionError) - ): - backpack_result = backpack_diag_ggn_exact(X, y, model, loss_function) - - check_sizes_and_values(autograd_result, backpack_result) diff --git a/test/converter/test_converter.py b/test/converter/test_converter.py index 735a880b9..860637910 100644 --- a/test/converter/test_converter.py +++ b/test/converter/test_converter.py @@ -5,7 +5,6 @@ """ from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule from test.core.derivatives.utils import classification_targets, regression_targets -from test.utils.skip_test import skip_pytorch_below_1_9_0 from typing import Tuple from pytest import fixture @@ -31,7 +30,6 @@ def model_and_input(request) -> Tuple[Module, Tensor, Module]: model and input and loss function """ manual_seed(0) - skip_pytorch_below_1_9_0() model: ConverterModule = request.param() inputs: Tensor = model.input_fn() loss_fn: Module = model.loss_fn() diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 7d3f35df0..f1aca1af8 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -72,7 +72,7 @@ from backpack.core.derivatives.sum_module import SumModuleDerivatives from backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives -from backpack.custom_module.branching import ActiveIdentity, SumModule +from backpack.custom_module.branching import SumModule from backpack.custom_module.permute import Permute from backpack.custom_module.scale_module import ScaleModule @@ -112,7 +112,6 @@ BatchNorm3d: BatchNormNdDerivatives, Embedding: EmbeddingDerivatives, ScaleModule: ScaleModuleDerivatives, - ActiveIdentity: ScaleModuleDerivatives, Identity: ScaleModuleDerivatives, SumModule: SumModuleDerivatives, } diff --git a/test/core/derivatives/scale_module_settings.py b/test/core/derivatives/scale_module_settings.py index 50de81e53..3bc089ef9 100644 --- a/test/core/derivatives/scale_module_settings.py +++ b/test/core/derivatives/scale_module_settings.py @@ -2,7 +2,6 @@ from torch import rand from torch.nn import Identity -from backpack.custom_module.branching import ActiveIdentity from backpack.custom_module.scale_module import ScaleModule SCALE_MODULE_SETTINGS = [ @@ -18,10 +17,6 @@ "module_fn": lambda: ScaleModule(5.7), "input_fn": lambda: rand(2, 3), }, - { - "module_fn": lambda: ActiveIdentity(), - "input_fn": lambda: rand(3, 2, 4), - }, { "module_fn": lambda: Identity(), "input_fn": lambda: rand(3, 1, 2), diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index c897befaa..51be5085c 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -46,12 +46,9 @@ ) from torchvision.models import resnet18 +from backpack import convert_module_to_backpack from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple -from backpack.utils import CONVERTER_AVAILABLE - -if CONVERTER_AVAILABLE: - from backpack import convert_module_to_backpack FIRSTORDER_SETTINGS = [] @@ -325,15 +322,14 @@ ############################################################################### # test setting: torchvision resnet # ############################################################################### -if CONVERTER_AVAILABLE: - FIRSTORDER_SETTINGS += [ - { - "input_fn": lambda: rand(2, 3, 7, 7), - "module_fn": lambda: convert_module_to_backpack( - resnet18(num_classes=4).eval(), True - ), - "loss_function_fn": lambda: MSELoss(), - "target_fn": lambda: regression_targets((2, 4)), - "id_prefix": "resnet18", - }, - ] +FIRSTORDER_SETTINGS += [ + { + "input_fn": lambda: rand(2, 3, 7, 7), + "module_fn": lambda: convert_module_to_backpack( + resnet18(num_classes=4).eval(), True + ), + "loss_function_fn": lambda: MSELoss(), + "target_fn": lambda: regression_targets((2, 4)), + "id_prefix": "resnet18", + }, +] diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 808921c43..02038be0c 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -28,6 +28,7 @@ CrossEntropyLoss, Embedding, Flatten, + Identity, Linear, MaxPool2d, MSELoss, @@ -36,14 +37,11 @@ Sigmoid, ) +from backpack import convert_module_to_backpack from backpack.custom_module import branching -from backpack.custom_module.branching import ActiveIdentity, Parallel +from backpack.custom_module.branching import Parallel from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple -from backpack.utils import CONVERTER_AVAILABLE - -if CONVERTER_AVAILABLE: - from backpack import convert_module_to_backpack SHARED_SETTINGS = SECONDORDER_SETTINGS LOCAL_SETTINGS = [] @@ -196,7 +194,7 @@ ReLU(), # skip connection Parallel( - ActiveIdentity(), + Identity(), Linear(5, 5), ), # end of skip connection @@ -214,7 +212,7 @@ ReLU(), # skip connection Parallel( - ActiveIdentity(), + Identity(), Sequential( Conv2d(3, 5, kernel_size=3, stride=1, padding=1), ReLU(), @@ -237,13 +235,13 @@ ReLU(), # skip connection Parallel( - ActiveIdentity(), + Identity(), Sequential( Conv2d(2, 4, kernel_size=3, stride=1, padding=1), Sigmoid(), Conv2d(4, 2, kernel_size=3, stride=1, padding=1), branching.Parallel( - branching.ActiveIdentity(), + Identity(), Sequential( Conv2d(2, 4, kernel_size=3, stride=1, padding=1), ReLU(), @@ -266,22 +264,21 @@ ############################################################################### # Branched models - converter # ############################################################################### -if CONVERTER_AVAILABLE: - LOCAL_SETTINGS += [ - { - "input_fn": lambda: ResNet1.input_test, - "module_fn": lambda: convert_module_to_backpack(ResNet1(), True), - "loss_function_fn": lambda: ResNet1.loss_test, - "target_fn": lambda: ResNet1.target_test, - "id_prefix": "ResNet1", - }, - { - "input_fn": lambda: rand(ResNet2.input_test), - "module_fn": lambda: convert_module_to_backpack(ResNet2().eval(), True), - "loss_function_fn": lambda: ResNet2.loss_test, - "target_fn": lambda: rand(ResNet2.target_test), - "id_prefix": "ResNet2", - }, - ] +LOCAL_SETTINGS += [ + { + "input_fn": lambda: ResNet1.input_test, + "module_fn": lambda: convert_module_to_backpack(ResNet1(), True), + "loss_function_fn": lambda: ResNet1.loss_test, + "target_fn": lambda: ResNet1.target_test, + "id_prefix": "ResNet1", + }, + { + "input_fn": lambda: rand(ResNet2.input_test), + "module_fn": lambda: convert_module_to_backpack(ResNet2().eval(), True), + "loss_function_fn": lambda: ResNet2.loss_test, + "target_fn": lambda: rand(ResNet2.target_test), + "id_prefix": "ResNet2", + }, +] DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/hbp/test_kfac.py b/test/extensions/secondorder/hbp/test_kfac.py index a5c76b6c5..d6205e290 100644 --- a/test/extensions/secondorder/hbp/test_kfac.py +++ b/test/extensions/secondorder/hbp/test_kfac.py @@ -6,8 +6,6 @@ import pytest -from backpack.utils import exception_inside_backward_pass - NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,7 +19,7 @@ def test_kfac_not_supported(problem): """ problem.set_up() - with pytest.raises(exception_inside_backward_pass(NotImplementedError)): + with pytest.raises(NotImplementedError): BackpackExtensions(problem).kfac() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kflr.py b/test/extensions/secondorder/hbp/test_kflr.py index d5c6819d0..3bb4a900b 100644 --- a/test/extensions/secondorder/hbp/test_kflr.py +++ b/test/extensions/secondorder/hbp/test_kflr.py @@ -6,8 +6,6 @@ import pytest -from backpack.utils import exception_inside_backward_pass - NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,7 +19,7 @@ def test_kflr_not_supported(problem): """ problem.set_up() - with pytest.raises(exception_inside_backward_pass(RuntimeError)): + with pytest.raises(RuntimeError): BackpackExtensions(problem).kflr() problem.tear_down() diff --git a/test/extensions/secondorder/hbp/test_kfra.py b/test/extensions/secondorder/hbp/test_kfra.py index 033171fb0..387438308 100644 --- a/test/extensions/secondorder/hbp/test_kfra.py +++ b/test/extensions/secondorder/hbp/test_kfra.py @@ -6,8 +6,6 @@ import pytest -from backpack.utils import exception_inside_backward_pass - NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS) NOT_SUPPORTED_IDS = [problem.make_id() for problem in NOT_SUPPORTED_PROBLEMS] @@ -21,7 +19,7 @@ def test_kfra_not_supported(problem): """ problem.set_up() - with pytest.raises(exception_inside_backward_pass(NotImplementedError)): + with pytest.raises(NotImplementedError): BackpackExtensions(problem).kfra() problem.tear_down() diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index fcb46832d..562be8a1e 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -233,21 +233,6 @@ SECONDORDER_SETTINGS += GROUP_CONV_SETTINGS -SECONDORDER_SETTINGS += [ - { - # Flatten layer does not add a node in the PyTorch computation graph. - # Thus, the backward hook will be called at an unexpected stage. - # The register_full_backward_hook ensures the execution order is correct -> ok. - # The register_backward_hook has above problem and therefore needs to skip execution. - # This is done in the `backward` function or in the `__call__` of ModuleExtension. - "input_fn": lambda: rand(3, 5), - "module_fn": lambda: Sequential(Linear(5, 4), Flatten(), Linear(4, 2)), - "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), - "target_fn": lambda: classification_targets((3,), 2), - "id_prefix": "flatten-no-op", - }, -] - # linear with additional dimension LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS = [ # regression diff --git a/test/extensions/test_hooks.py b/test/extensions/test_hooks.py index 325c1d265..bc6058f87 100644 --- a/test/extensions/test_hooks.py +++ b/test/extensions/test_hooks.py @@ -13,7 +13,6 @@ from backpack import backpack, extend from backpack.extensions import BatchGrad, DiagGGNExact from backpack.extensions.backprop_extension import FAIL_ERROR, BackpropExtension -from backpack.utils import exception_inside_backward_pass DEVICES = get_available_devices() DEVICES_ID = [str(dev) for dev in DEVICES] @@ -124,7 +123,7 @@ def count_visits(module): params_visited[id(p)] += 1 if problem_string == CUSTOM_CONTAINER and extension._fail_mode == FAIL_ERROR: - with raises(exception_inside_backward_pass(NotImplementedError)): + with raises(NotImplementedError): with backpack(extension, extension_hook=count_visits, debug=True): loss.backward() return @@ -189,7 +188,7 @@ def check_grad_batch(module): assert len(params_without_grad_batch) == 0 elif problem_string == CUSTOM_CONTAINER: - with raises(exception_inside_backward_pass(AssertionError)): + with raises(AssertionError): with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True): loss.backward() assert len(params_without_grad_batch) > 0 diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index b7afc30e2..4f282662f 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -7,7 +7,7 @@ from pytest import skip from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d -from backpack.utils import ADAPTIVE_AVG_POOL_BUG, TORCH_VERSION_AT_LEAST_1_9_0 +from backpack.utils import ADAPTIVE_AVG_POOL_BUG def skip_adaptive_avg_pool3d_cuda(request) -> None: @@ -57,12 +57,6 @@ def skip_subsampling_conflict( skip("Not enough samples.") -def skip_pytorch_below_1_9_0() -> None: - """Skip test if pytorch version is below 1.9.0.""" - if not TORCH_VERSION_AT_LEAST_1_9_0: - skip("Test needs PyTorch>=1.9.0") - - def skip_large_parameters( problem: ExtensionsTestProblem, max_num_params: int = 1000 ) -> None: From b6ac25ebeedb312755c7e076bd14e02665ac9403 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Mon, 11 Oct 2021 14:22:36 +0200 Subject: [PATCH 52/54] [DOC] Add RNN example (#229) Adds an example that demonstrates how to use RNNs through custom modules or the converter. --- * [ADD] RNN example * [DOC] Polish RNN example, add forward pass comparison * [FIX] flake8 Co-authored-by: Felix Dangel --- docs_src/examples/use_cases/example_rnn.py | 306 +++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 docs_src/examples/use_cases/example_rnn.py diff --git a/docs_src/examples/use_cases/example_rnn.py b/docs_src/examples/use_cases/example_rnn.py new file mode 100644 index 000000000..283b1a900 --- /dev/null +++ b/docs_src/examples/use_cases/example_rnn.py @@ -0,0 +1,306 @@ +"""Recurrent networks +==================== +""" +# %% +# There are two different approaches to using BackPACK with RNNs. +# +# 1. :ref:`Custom RNN with BackPACK custom modules`: +# Build your RNN with custom modules provided by BackPACK +# without overwriting the forward pass. This approach is useful if you want to +# understand how BackPACK handles RNNs, or if you think building a container +# module that implicitly defines the forward pass is more elegant than coding up +# a forward pass. +# +# 2. :ref:`RNN with BackPACK's converter`: +# Automatically convert your model into a BackPACK-compatible architecture. +# +# .. note:: +# RNNs are still an experimental feature. Always double-check your +# results, as done in this example! Open an issue if you encounter a bug to help +# us improve the support. +# +# Not all extensions support RNNs (yet). Please create a feature request in the +# repository if the extension you need is not supported. + +# %% +# Let's get the imports out of the way. +from torch import ( + allclose, + cat, + device, + int32, + linspace, + manual_seed, + nn, + randint, + zeros_like, +) + +from backpack import backpack, extend +from backpack.custom_module.graph_utils import BackpackTracer +from backpack.custom_module.permute import Permute +from backpack.custom_module.reduce_tuple import ReduceTuple +from backpack.extensions import BatchGrad, DiagGGNExact +from backpack.utils.examples import autograd_diag_ggn_exact + +manual_seed(0) +DEVICE = device("cpu") # Verification via autograd only works on CPU + + +# %% +# For this demo, we will use the Tolstoi Char RNN from +# `DeepOBS `_. +# This network is trained on Leo Tolstoi's War and Peace +# and learns to predict the next character. +class TolstoiCharRNN(nn.Module): + def __init__(self): + super().__init__() + self.batch_size = 8 + self.hidden_dim = 64 + self.num_layers = 2 + self.seq_len = 15 + self.vocab_size = 25 + + self.embedding = nn.Embedding( + num_embeddings=self.vocab_size, embedding_dim=self.hidden_dim + ) + self.dropout = nn.Dropout(p=0.2) + self.lstm = nn.LSTM( + input_size=self.hidden_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + dropout=0.36, + batch_first=True, + ) + # deactivate redundant bias + self.lstm.bias_ih_l0.data = zeros_like(self.lstm.bias_ih_l0) + self.lstm.bias_ih_l1.data = zeros_like(self.lstm.bias_ih_l1) + self.lstm.bias_ih_l0.requires_grad = False + self.lstm.bias_ih_l1.requires_grad = False + self.dense = nn.Linear( + in_features=self.hidden_dim, out_features=self.vocab_size + ) + + def forward(self, x): + x = self.embedding(x) + x = self.dropout(x) + x, _ = self.lstm(x) # last return values are hidden states + x = self.dropout(x) + output = self.dense(x) + output = output.permute(0, 2, 1) # [N, T, C] → [N, C, T] + return output + + def input_target_fn(self): + input = randint(0, self.vocab_size, (self.batch_size, self.seq_len)) + # target is the input shifted by 1 in time axis + target = cat( + [ + randint(0, self.vocab_size, (self.batch_size, 1)), + input[:, :-1], + ], + dim=1, + ) + return input.to(DEVICE), target.to(DEVICE) + + def loss_fn(self) -> nn.Module: + return nn.CrossEntropyLoss().to(DEVICE) + + +manual_seed(1) +tolstoi_char_rnn = TolstoiCharRNN().to(DEVICE).eval() +loss_function = extend(tolstoi_char_rnn.loss_fn()) +x, y = tolstoi_char_rnn.input_target_fn() +# %% +# Note that instead of the real data set, we will feed synthetic data to the network for +# simplicity. We also use the network in evaluation mode. This disables the +# :py:class:`Dropout ` layers and allows double-checking our results +# via :py:mod:`torch.autograd`. +# +# Custom RNN with BackPACK custom modules +# ------------- +# Second-order extensions only work if every node in the computation graph is an +# ``nn`` module that can be extended by BackPACK. The above RNN +# :py:class:`TolstoiCharRNN` does not satisfy these conditions, because +# it has a multi-layer :py:class:`torch.nn.LSTM` and implicitly uses the +# :py:func:`getitem` (for unpacking) and :py:meth:`permute() ` +# functions in the :py:meth:`forward() ` method. +# +# To build RNN without overwriting the forward pass, BackPACK offers custom modules: +# +# 1. :py:class:`ReduceTuple ` +# +# 2. :py:class:`Permute ` +# +# With the above modules, we can build a simple RNN as a container that implicitly +# defines the forward pass: +manual_seed(1) # same seed as used to initialize `tolstoi_char_rnn` +tolstoi_char_rnn_custom = nn.Sequential( + nn.Embedding(tolstoi_char_rnn.vocab_size, tolstoi_char_rnn.hidden_dim), + nn.Dropout(p=0.2), + nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True), + ReduceTuple(index=0), + nn.Dropout(p=0.36), + nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True), + ReduceTuple(index=0), + nn.Dropout(p=0.2), + nn.Linear(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.vocab_size), + Permute(0, 2, 1), +) +tolstoi_char_rnn_custom.eval().to(DEVICE) + +# %% +# Let's check that both models have the same forward pass. +for name, p in tolstoi_char_rnn_custom.named_parameters(): + if "bias_ih_l" in name: + # deactivate redundant bias + p.data = zeros_like(p.data) + p.requires_grad = False + +match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x)) +print(f"Forward pass of custom model matches TolstoiCharRNN? {match}") + +if not match: + raise AssertionError("Forward passes don't match.") + +# %% +# We can :py:func:`extend ` our model and the loss function to +# compute BackPACK extensions. + +tolstoi_char_rnn_custom = extend(tolstoi_char_rnn_custom) +loss = loss_function(tolstoi_char_rnn_custom(x), y) + +with backpack(BatchGrad(), DiagGGNExact()): + loss.backward() + +for name, param in tolstoi_char_rnn_custom.named_parameters(): + if param.requires_grad: + print( + name, + param.shape, + param.grad_batch.shape, + param.diag_ggn_exact.shape, + ) + +# %% +# Comparison of the GGN diagonal extension with :py:mod:`torch.autograd`: +# +# .. note:: +# +# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation +# can be slow, depending on the number of parameters. To reduce run time, we only +# compare some elements of the diagonal. +trainable_params = [p for p in tolstoi_char_rnn_custom.parameters() if p.requires_grad] + +diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in trainable_params]) + +num_params = sum(p.numel() for p in trainable_params) +num_to_compare = 10 +idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32) + +diag_ggn_exact_to_compare = autograd_diag_ggn_exact( + x, y, tolstoi_char_rnn_custom, loss_function, idx=idx_to_compare +) + +print("Do the exact GGN diagonals match?") +for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): + match = allclose(element, diag_ggn_exact_vector[idx]) + print( + f"Diagonal entry {idx:>8}: {match}:" + + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}" + ) + if not match: + raise AssertionError("Exact GGN diagonals don't match!") + +# %% +# RNN with BackPACK's converter +# ------------- +# If you are not building an RNN through custom modules but for instance want to +# directly use the Tolstoi Char RNN, BackPACK offers a converter. +# It analyzes the model and tries to turn it into a compatible architecture. The result +# is a :py:class:`torch.fx.GraphModule` that exclusively consists of modules. +# +# Here, we demonstrate the converter on the above Tolstoi Char RNN. Let's convert it +# while :py:func:`extend `-ing the model: + +# use BackPACK's converter to extend the model (turned off by default) +tolstoi_char_rnn = extend(tolstoi_char_rnn, use_converter=True) + +# %% +# To get an understanding what happened, we can inspect the model's graph with the +# following helper function: + + +def print_table(module: nn.Module) -> None: + """Prints a table of the module. + + Args: + module: module to analyze + """ + graph = BackpackTracer().trace(module) + graph.print_tabular() + + +print_table(tolstoi_char_rnn) + +# %% +# Note that the computation graph fully consists of modules (indicated by +# ``call_module`` in the first table column) such that BackPACK's hooks can +# successfully backpropagate additional information for its second-order extensions +# (first-order extensions work, too). +# +# First, let's compare the forward pass with the custom module from the previous +# section to make sure the converter worked fine: + +match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x)) +print(f"Forward pass of extended TolstoiCharRNN matches custom model? {match}") + +if not match: + raise AssertionError("Forward passes don't match.") + +# %% +# +# Now let's verify that second-order extensions (GGN diagonal) are working: +output = tolstoi_char_rnn(x) +loss = loss_function(output, y) + +with backpack(DiagGGNExact()): + loss.backward() + +for name, parameter in tolstoi_char_rnn.named_parameters(): + if parameter.requires_grad: + print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}") + +diag_ggn_exact_vector = cat( + [ + p.diag_ggn_exact.flatten() + for p in tolstoi_char_rnn.parameters() + if p.requires_grad + ] +) + +# %% +# Finally, we compare BackPACK's GGN diagonal with :py:mod:`torch.autograd`: +# +# .. note:: +# +# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation +# can be slow, depending on the number of parameters. To reduce run time, we only +# compare some elements of the diagonal. + +num_params = sum(p.numel() for p in tolstoi_char_rnn.parameters() if p.requires_grad) +num_to_compare = 10 +idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32) + +diag_ggn_exact_to_compare = autograd_diag_ggn_exact( + x, y, tolstoi_char_rnn, loss_function, idx=idx_to_compare +) + +print("Do the exact GGN diagonals match?") +for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare): + match = allclose(element, diag_ggn_exact_vector[idx]) + print( + f"Diagonal entry {idx:>8}: {match}:" + + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}" + ) + if not match: + raise AssertionError("Exact GGN diagonals don't match!") From 161ed22f82981c8e88671e2d996bc806d794b171 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Mon, 11 Oct 2021 19:14:12 +0200 Subject: [PATCH 53/54] [ADD] Enforce batch axis first (#227) Make BackPACK assume the first axis of inputs/outputs to be the batch axis (including RNN modules that support `[T, N, H]` format). This simplifies the code for mini-batch sub-sampling and `RNN`/`LSTM`. --- * [REF] permute.py batch first * [REF] enforce batch_first=True for RNN/LSTM, delete get_batch_axis * [DEL] RNN/LSTM: remove code from batch_first=False * [REF] RNN: avoid transpose by swapping axes * [REF] LSTM: avoid transpose by swapping axes * [REF] reformat * [REF] batch axis first * [DEL] remove redundant tests * [REF] Shorter lines, add dtype * [REF] Extract single step * [REF] Shorten unpacking * [FMT] Shorten line * [REF] Remove torch import Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 3 +- backpack/core/derivatives/lstm.py | 175 ++++++------------ backpack/core/derivatives/rnn.py | 116 ++++-------- backpack/core/derivatives/shape_check.py | 6 +- backpack/custom_module/permute.py | 33 ++-- .../firstorder/batch_grad/batch_grad_base.py | 6 +- .../firstorder/variance/variance_base.py | 3 +- backpack/utils/subsampling.py | 43 ----- fully_documented.txt | 1 + test/core/derivatives/derivatives_test.py | 4 +- .../derivatives/implementation/autograd.py | 18 +- test/core/derivatives/lstm_settings.py | 8 - test/core/derivatives/permute_settings.py | 4 +- test/core/derivatives/problem.py | 8 +- test/core/derivatives/rnn_settings.py | 8 - test/custom_module/__init__.py | 1 - test/custom_module/test_permute.py | 25 --- .../firstorder/firstorder_settings.py | 10 +- test/extensions/problem.py | 12 +- .../secondorder/diag_ggn/diag_ggn_settings.py | 7 +- test/extensions/utils.py | 4 +- test/test_batch_first.py | 25 +++ test/utils/test_subsampling.py | 35 +--- 23 files changed, 166 insertions(+), 389 deletions(-) delete mode 100644 test/custom_module/__init__.py delete mode 100644 test/custom_module/test_permute.py create mode 100644 test/test_batch_first.py diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 677d20a31..94c152884 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -7,7 +7,6 @@ from torch.nn import Module from backpack.core.derivatives import shape_check -from backpack.utils.subsampling import get_batch_axis class BaseDerivatives(ABC): @@ -289,7 +288,7 @@ def reshape_like_input( """ shape = list(module.input0.shape) if subsampling is not None: - shape[get_batch_axis(module, "input0")] = len(subsampling) + shape[0] = len(subsampling) return cls._reshape_like(mat, shape) diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 5e4b003c8..def7e80da 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -27,9 +27,8 @@ class LSTMDerivatives(BaseParameterDerivatives): c[t] = f[t] c[t-1] + i[t] g[t] h[t] = o[t] tanh(c[t]) - Note: - For ``batch_first=True``, most of the internal tensors (e.g. those from - the manual forward pass) are kept with time axis first. + In general, we assume that it is batch axis first + and the order of axis is (V, N, T, H). """ @staticmethod @@ -42,6 +41,8 @@ def _check_parameters(module: LSTM) -> None: Raises: NotImplementedError: If any parameter of module does not match expectation """ + if not module.batch_first: + raise NotImplementedError("Batch axis must be first.") if module.num_layers != 1: raise NotImplementedError("only num_layers = 1 is supported") if module.bias is not True: @@ -70,12 +71,9 @@ def _forward_pass( subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: - ifgo, c, c_tanh (all in format ``[T, N, ...]``) + ifgo, c, c_tanh (all in format ``[N, T, ...]``) """ - free_axis = 1 - N_axis, T_axis = LSTMDerivatives.get_batch_and_time_axes(module) - T: int = mat.shape[T_axis + free_axis] - N: int = mat.shape[N_axis + free_axis] + _, N, T, _ = mat.shape H: int = module.hidden_size H0: int = 0 * H H1: int = 1 * H @@ -83,34 +81,29 @@ def _forward_pass( H3: int = 3 * H H4: int = 4 * H # forward pass and save i, f, g, o, c, c_tanh-> ifgo, c, c_tanh - ifgo: Tensor = zeros(T, N, 4 * H, device=mat.device, dtype=mat.dtype) - c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) - c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + ifgo: Tensor = zeros(N, T, 4 * H, device=mat.device, dtype=mat.dtype) + c: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype) + c_tanh: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype) - input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling) - output = subsample(module.output, dim=N_axis, subsampling=subsampling) - - # use [T, N, ...] format - if module.batch_first: - input0 = input0.transpose(N_axis, T_axis) - output = output.transpose(N_axis, T_axis) + input0 = subsample(module.input0, dim=0, subsampling=subsampling) + output = subsample(module.output, dim=0, subsampling=subsampling) for t in range(T): - ifgo[t] = ( - einsum("hi,ni->nh", module.weight_ih_l0, input0[t]) + ifgo[:, t] = ( + einsum("hi,ni->nh", module.weight_ih_l0, input0[:, t]) + module.bias_ih_l0 + module.bias_hh_l0 ) if t != 0: - ifgo[t] += einsum("hg,ng->nh", module.weight_hh_l0, output[t - 1]) - ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1]) - ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2]) - ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3]) - ifgo[t, :, H3:H4] = sigmoid(ifgo[t, :, H3:H4]) - c[t] = ifgo[t, :, H0:H1] * ifgo[t, :, H2:H3] + ifgo[:, t] += einsum("hg,ng->nh", module.weight_hh_l0, output[:, t - 1]) + ifgo[:, t, H0:H1] = sigmoid(ifgo[:, t, H0:H1]) + ifgo[:, t, H1:H2] = sigmoid(ifgo[:, t, H1:H2]) + ifgo[:, t, H2:H3] = tanh(ifgo[:, t, H2:H3]) + ifgo[:, t, H3:H4] = sigmoid(ifgo[:, t, H3:H4]) + c[:, t] = ifgo[:, t, H0:H1] * ifgo[:, t, H2:H3] if t != 0: - c[t] += ifgo[t, :, H1:H2] * c[t - 1] - c_tanh[t] = tanh(c[t]) + c[:, t] += ifgo[:, t, H1:H2] * c[:, t - 1] + c_tanh[:, t] = tanh(c[:, t]) return ifgo, c, c_tanh @@ -118,12 +111,7 @@ def _forward_pass( def _ifgo_jac_t_mat_prod( cls, module: LSTM, mat: Tensor, subsampling: List[int] = None ) -> Tensor: - free_axis = 1 - N_axis, T_axis = cls.get_batch_and_time_axes(module) - V: int = mat.shape[0] - T: int = mat.shape[T_axis + free_axis] - N: int = mat.shape[N_axis + free_axis] - H: int = module.hidden_size + V, N, T, H = mat.shape H0: int = 0 * H H1: int = 1 * H H2: int = 2 * H @@ -136,44 +124,44 @@ def _ifgo_jac_t_mat_prod( H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) - IFGO_prod: Tensor = zeros(V, T, N, 4 * H, device=mat.device, dtype=mat.dtype) + IFGO_prod: Tensor = zeros(V, N, T, 4 * H, device=mat.device, dtype=mat.dtype) for t in reversed(range(T)): # jac_t_mat_prod until node h - H_prod_t[:] = mat[(slice(None),) * (T_axis + 1) + (t,)] + H_prod_t[:] = mat[:, :, t] if t != (T - 1): H_prod_t += einsum( - "vnh,hg->vng", IFGO_prod[:, t + 1], module.weight_hh_l0 + "vnh,hg->vng", IFGO_prod[:, :, t + 1], module.weight_hh_l0 ) # C_prod_t = jac_t_mat_prod until node c if t != (T - 1): C_prod_old[:] = C_prod_t C_prod_t[:] = einsum( - "vnh,nh->vnh", H_prod_t, ifgo[t, :, H3:H4] * (1 - c_tanh[t] ** 2) + "vnh,nh->vnh", H_prod_t, ifgo[:, t, H3:H4] * (1 - c_tanh[:, t] ** 2) ) if t != (T - 1): - C_prod_t += einsum("vnh,nh->vnh", C_prod_old, ifgo[t + 1, :, H1:H2]) + C_prod_t += einsum("vnh,nh->vnh", C_prod_old, ifgo[:, t + 1, H1:H2]) - IFGO_prod[:, t, :, H3:H4] = einsum( + IFGO_prod[:, :, t, H3:H4] = einsum( "vnh,nh->vnh", H_prod_t, - c_tanh[t] * (ifgo[t, :, H3:H4] * (1 - ifgo[t, :, H3:H4])), + c_tanh[:, t] * (ifgo[:, t, H3:H4] * (1 - ifgo[:, t, H3:H4])), ) - IFGO_prod[:, t, :, H0:H1] = einsum( + IFGO_prod[:, :, t, H0:H1] = einsum( "vnh,nh->vnh", C_prod_t, - ifgo[t, :, H2:H3] * (ifgo[t, :, H0:H1] * (1 - ifgo[t, :, H0:H1])), + ifgo[:, t, H2:H3] * (ifgo[:, t, H0:H1] * (1 - ifgo[:, t, H0:H1])), ) if t >= 1: - IFGO_prod[:, t, :, H1:H2] = einsum( + IFGO_prod[:, :, t, H1:H2] = einsum( "vnh,nh->vnh", C_prod_t, - c[t - 1] * (ifgo[t, :, H1:H2] * (1 - ifgo[t, :, H1:H2])), + c[:, t - 1] * (ifgo[:, t, H1:H2] * (1 - ifgo[:, t, H1:H2])), ) - IFGO_prod[:, t, :, H2:H3] = einsum( + IFGO_prod[:, :, t, H2:H3] = einsum( "vnh,nh->vnh", C_prod_t, - ifgo[t, :, H0:H1] * (1 - ifgo[t, :, H2:H3] ** 2), + ifgo[:, t, H0:H1] * (1 - ifgo[:, t, H2:H3] ** 2), ) return IFGO_prod @@ -188,11 +176,7 @@ def _jac_mat_prod( mat: Tensor, sum_batch: bool = True, ) -> Tensor: - free_axis = 1 - N_axis, T_axis = self.get_batch_and_time_axes(module) - V: int = mat.shape[0] - T: int = mat.shape[T_axis + free_axis] - N: int = mat.shape[N_axis + free_axis] + V, N, T, _ = mat.shape H: int = module.hidden_size H0: int = 0 * H H1: int = 1 * H @@ -201,7 +185,7 @@ def _jac_mat_prod( H4: int = 4 * H ifgo, c, c_tanh = self._forward_pass(module, mat) - H_prod: Tensor = zeros(V, T, N, H, device=mat.device, dtype=mat.dtype) + H_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype) C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) C_tanh_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) @@ -211,49 +195,46 @@ def _jac_mat_prod( IFGO_prod_t[:] = einsum( "hi,vni->vnh", module.weight_ih_l0, - mat[(slice(None),) * (T_axis + free_axis) + (t,)], + mat[:, :, t], ) if t != 0: IFGO_prod_t[:] += einsum( - "hg,vng->vnh", module.weight_hh_l0, H_prod[:, t - 1] + "hg,vng->vnh", module.weight_hh_l0, H_prod[:, :, t - 1] ) IFGO_prod_t[:, :, H0:H2] = einsum( "vnh,nh->vnh", IFGO_prod_t[:, :, H0:H2], - ifgo[t, :, H0:H2] * (1 - ifgo[t, :, H0:H2]), + ifgo[:, t, H0:H2] * (1 - ifgo[:, t, H0:H2]), ) IFGO_prod_t[:, :, H3:H4] = einsum( "vnh,nh->vnh", IFGO_prod_t[:, :, H3:H4], - ifgo[t, :, H3:H4] * (1 - ifgo[t, :, H3:H4]), + ifgo[:, t, H3:H4] * (1 - ifgo[:, t, H3:H4]), ) IFGO_prod_t[:, :, H2:H3] = einsum( "vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], - 1 - ifgo[t, :, H2:H3] ** 2, + 1 - ifgo[:, t, H2:H3] ** 2, ) # product until node c if t >= 1: C_prod_old[:] = C_prod_t C_prod_t[:] = einsum( - "vnh,nh->vnh", IFGO_prod_t[:, :, H0:H1], ifgo[t, :, H2:H3] - ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[t, :, H0:H1]) + "vnh,nh->vnh", IFGO_prod_t[:, :, H0:H1], ifgo[:, t, H2:H3] + ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[:, t, H0:H1]) if t >= 1: C_prod_t += einsum( - "vnh,nh->vnh", C_prod_old, ifgo[t, :, H1:H2] - ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H1:H2], c[t - 1]) + "vnh,nh->vnh", C_prod_old, ifgo[:, t, H1:H2] + ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H1:H2], c[:, t - 1]) # product until node c_tanh - C_tanh_prod_t[:] = einsum("vnh,nh->vnh", C_prod_t, 1 - c_tanh[t] ** 2) + C_tanh_prod_t[:] = einsum("vnh,nh->vnh", C_prod_t, 1 - c_tanh[:, t] ** 2) # product until node h - H_prod[:, t] = einsum( - "vnh,nh->vnh", IFGO_prod_t[:, :, H3:H4], c_tanh[t] - ) + einsum("vnh,nh->vnh", C_tanh_prod_t, ifgo[t, :, H3:H4]) - - if module.batch_first: - H_prod = H_prod.transpose(T_axis + free_axis, N_axis + free_axis) + H_prod[:, :, t] = einsum( + "vnh,nh->vnh", IFGO_prod_t[:, :, H3:H4], c_tanh[:, t] + ) + einsum("vnh,nh->vnh", C_tanh_prod_t, ifgo[:, t, H3:H4]) return H_prod @@ -270,13 +251,7 @@ def _jac_t_mat_prod( IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( module, mat, subsampling=subsampling ) - - N_axis, _ = self.get_batch_and_time_axes(module) - batch_time_str = "nt" if N_axis == 0 else "tn" - - X_prod: Tensor = einsum( - f"vtnh,hi->v{batch_time_str}i", IFGO_prod, module.weight_ih_l0 - ) + X_prod: Tensor = einsum("vnth,hi->vnti", IFGO_prod, module.weight_ih_l0) return X_prod def _bias_ih_l0_jac_t_mat_prod( @@ -294,7 +269,7 @@ def _bias_ih_l0_jac_t_mat_prod( module, mat, subsampling=subsampling ) - return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod) + return einsum(f"vnth->v{'' if sum_batch else 'n'}h", IFGO_prod) def _bias_hh_l0_jac_t_mat_prod( self, @@ -323,14 +298,10 @@ def _weight_ih_l0_jac_t_mat_prod( IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( module, mat, subsampling=subsampling ) - - N_axis, _ = self.get_batch_and_time_axes(module) - batch_time_str = "nt" if N_axis == 0 else "tn" - return einsum( - f"vtnh,{batch_time_str}i->v{'' if sum_batch else 'n'}hi", + f"vnth,nti->v{'' if sum_batch else 'n'}hi", IFGO_prod, - subsample(module.input0, dim=N_axis, subsampling=subsampling), + subsample(module.input0, dim=0, subsampling=subsampling), ) def _weight_hh_l0_jac_t_mat_prod( @@ -343,43 +314,15 @@ def _weight_hh_l0_jac_t_mat_prod( subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - - free_axis = 1 - N_axis, T_axis = self.get_batch_and_time_axes(module) - - N: int = mat.shape[N_axis + free_axis] - H: int = module.hidden_size - + _, N, _, H = mat.shape IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( module, mat, subsampling=subsampling ) - subsampled_output = subsample( - module.output, dim=N_axis, subsampling=subsampling - ) - if N_axis == 0: - subsampled_output = subsampled_output.transpose(N_axis, T_axis) - + subsampled_output = subsample(module.output, dim=0, subsampling=subsampling) + single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype) return einsum( - f"vtnh,tng->v{'' if sum_batch else 'n'}hg", + f"vnth,ntg->v{'' if sum_batch else 'n'}hg", IFGO_prod, - cat( - [ - zeros(1, N, H, device=mat.device, dtype=mat.dtype), - subsampled_output[0:-1], - ], - dim=0, - ), + cat([single_step, subsampled_output[:, :-1]], dim=1), ) - - @staticmethod - def get_batch_and_time_axes(module: LSTM) -> Tuple[int, int]: - """Return axes interpreted by the module as batch and time axes of the input. - - Args: - module: LSTM module. - - Returns: - Batch axis and time axis. - """ - return (0, 1) if module.batch_first else (1, 0) diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index b0d8e7011..792eda640 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -1,7 +1,6 @@ """Partial derivatives for the torch.nn.RNN layer.""" from typing import List, Tuple -import torch from torch import Tensor, cat, einsum, zeros from torch.nn import RNN @@ -15,6 +14,8 @@ class RNNDerivatives(BaseParameterDerivatives): a_t = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh h_t = tanh(a_t) + We assume that it is always batch axis first. + Index conventions: ------------------ * t: Sequence dimension @@ -34,6 +35,8 @@ def _check_parameters(module: RNN) -> None: Raises: NotImplementedError: If any parameter of module does not match expectation """ + if not module.batch_first: + raise NotImplementedError("Batch axis must be first.") if module.num_layers > 1: raise NotImplementedError("only num_layers = 1 is supported") if not module.nonlinearity == "tanh": @@ -67,35 +70,24 @@ def _a_jac_t_mat_prod( Returns: jacobian vector product wrt a """ - free_axis = 1 - N_axis, T_axis = cls.get_batch_and_time_axes(module) - V: int = mat.shape[0] - N: int = mat.shape[N_axis + free_axis] - T: int = mat.shape[T_axis + free_axis] - H: int = mat.shape[3] - output = subsample(module.output, dim=N_axis, subsampling=subsampling) - # use [T, N, ...] format - if module.batch_first: - output = output.transpose(N_axis, T_axis) - a_jac_t_mat_prod: Tensor = zeros(V, T, N, H, device=mat.device) + V, N, T, H = mat.shape + output = subsample(module.output, dim=0, subsampling=subsampling) + a_jac_t_mat_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype) for t in reversed(range(T)): - mat_t = mat[(slice(None),) * (T_axis + free_axis) + (t,)] if t == (T - 1): - a_jac_t_mat_prod[:, t, ...] = einsum( - "vnh,nh->vnh", - mat_t, - 1 - output[t, ...] ** 2, + a_jac_t_mat_prod[:, :, t] = einsum( + "vnh,nh->vnh", mat[:, :, t], 1 - output[:, t] ** 2 ) else: - a_jac_t_mat_prod[:, t, ...] = einsum( + a_jac_t_mat_prod[:, :, t] = einsum( "vnh,nh->vnh", - mat_t + mat[:, :, t] + einsum( "vng,gh->vnh", - a_jac_t_mat_prod[:, t + 1, ...], + a_jac_t_mat_prod[:, :, t + 1], weight_hh_l0, ), - 1 - output[t, ...] ** 2, + 1 - output[:, t] ** 2, ) return a_jac_t_mat_prod @@ -108,8 +100,8 @@ def _jac_t_mat_prod( subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - return torch.einsum( - f"vtnh,hk->v{'nt' if module.batch_first else 'tn'}k", + return einsum( + f"vnth,hk->v{'nt' if module.batch_first else 'tn'}k", self._a_jac_t_mat_prod( module, module.weight_hh_l0, @@ -123,47 +115,33 @@ def _jac_mat_prod( self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor ) -> Tensor: self._check_parameters(module) - free_axis = 1 - N_axis, T_axis = self.get_batch_and_time_axes(module) - V: int = mat.shape[0] - N: int = mat.shape[N_axis + free_axis] - T: int = mat.shape[T_axis + free_axis] H: int = module.hidden_size - # use [T, N, ...] format - if module.batch_first: - output = module.output.transpose(N_axis, T_axis) - else: - output = module.output - _jac_mat_prod: Tensor = torch.zeros(V, T, N, H, device=mat.device) + V, N, T, _ = mat.shape + _jac_mat_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype) for t in range(T): - mat_t = mat[(slice(None),) * (T_axis + free_axis) + (t,)] if t == 0: - _jac_mat_prod[:, t, ...] = einsum( + _jac_mat_prod[:, :, t] = einsum( "nh,hi,vni->vnh", - 1 - output[t, ...] ** 2, + 1 - module.output[:, t] ** 2, module.weight_ih_l0, - mat_t, + mat[:, :, t], ) else: - _jac_mat_prod[:, t, ...] = einsum( + _jac_mat_prod[:, :, t] = einsum( "nh,vnh->vnh", - 1 - output[t, ...] ** 2, + 1 - module.output[:, t] ** 2, einsum( "hi,vni->vnh", module.weight_ih_l0, - mat_t, + mat[:, :, t], ) + einsum( "hk,vnk->vnh", module.weight_hh_l0, - _jac_mat_prod[:, t - 1, ...], + _jac_mat_prod[:, :, t - 1], ), ) - return ( - _jac_mat_prod.transpose(N_axis + free_axis, T_axis + free_axis) - if module.batch_first - else _jac_mat_prod - ) + return _jac_mat_prod def _bias_ih_l0_jac_t_mat_prod( self, @@ -191,7 +169,7 @@ def _bias_ih_l0_jac_t_mat_prod( if sum_batch: dim: List[int] = [1, 2] else: - dim: int = 1 + dim: int = 2 return self._a_jac_t_mat_prod( module, module.weight_hh_l0, @@ -248,16 +226,10 @@ def _weight_ih_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis, _ = self.get_batch_and_time_axes(module) return einsum( - f"vtnh,{'nt' if module.batch_first else 'tn'}j->v{'' if sum_batch else 'n'}hj", - self._a_jac_t_mat_prod( - module, - module.weight_hh_l0, - mat, - subsampling, - ), - subsample(module.input0, dim=N_axis, subsampling=subsampling), + f"vnth,ntj->v{'' if sum_batch else 'n'}hj", + self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling), + subsample(module.input0, dim=0, subsampling=subsampling), ) def _weight_hh_l0_jac_t_mat_prod( @@ -283,32 +255,12 @@ def _weight_hh_l0_jac_t_mat_prod( product """ self._check_parameters(module) - N_axis, T_axis = self.get_batch_and_time_axes(module) - N: int = mat.shape[N_axis + 1] - H: int = mat.shape[3] - output = subsample(module.output, dim=N_axis, subsampling=subsampling) - shape_single_step = (N, 1, H) if module.batch_first else (1, N, H) - output_shifted = cat( - [ - zeros(shape_single_step, device=mat.device, dtype=mat.dtype), - output[(slice(None),) * T_axis + (slice(0, -1),)], - ], - dim=T_axis, - ) + _, N, _, H = mat.shape + output = subsample(module.output, dim=0, subsampling=subsampling) + single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype) + output_shifted = cat([single_step, output[:, :-1]], dim=1) return einsum( - f"vtnh,{'nt' if module.batch_first else 'tn'}k->v{'' if sum_batch else 'n'}hk", + f"vnth,ntk->v{'' if sum_batch else 'n'}hk", self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling), output_shifted, ) - - @staticmethod - def get_batch_and_time_axes(module: RNN) -> Tuple[int, int]: - """Return axes interpreted by the module as batch and time axes of the input. - - Args: - module: RNN module. - - Returns: - Batch axis and time axis. - """ - return (0, 1) if module.batch_first else (1, 0) diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index 4e0755c1a..d141f8fc8 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.nn import Module -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample ############################################################################### @@ -72,9 +72,7 @@ def check_same_V_dim(mat1, mat2): def _check_like(mat, module, name, diff=1, *args, **kwargs): if name in ["output", "input0"] and "subsampling" in kwargs.keys(): compare = subsample( - getattr(module, name), - dim=get_batch_axis(module, name), - subsampling=kwargs["subsampling"], + getattr(module, name), dim=0, subsampling=kwargs["subsampling"] ) else: compare = getattr(module, name) diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py index d71361a14..3213ccd54 100644 --- a/backpack/custom_module/permute.py +++ b/backpack/custom_module/permute.py @@ -8,7 +8,7 @@ class Permute(Module): """Module to permute a tensor.""" - def __init__(self, *dims: Any, init_transpose: bool = False, batch_axis: int = 0): + def __init__(self, *dims: Any, init_transpose: bool = False): """Initialization. This module supports two variants: permutation and transposition. @@ -19,13 +19,11 @@ def __init__(self, *dims: Any, init_transpose: bool = False, batch_axis: int = 0 Args: dims: The desired ordering of dimensions. init_transpose: If transpose parameters are provided. Default: False. - batch_axis: Which axis assumed to be the batch axis in a forward pass. - Defaults to ``0``. """ super().__init__() self.dims = dims self.init_transpose = init_transpose - self.batch_axis = batch_axis + self._enforce_batch_axis_first() def forward(self, input: Tensor) -> Tensor: """Permutes the input tensor. @@ -52,22 +50,13 @@ def _convert_transpose_to_permute(self, input: Tensor): self.dims = tuple(permutation) self.init_transpose = False - def get_batch_axis(self, io_str: str) -> int: - """Return the batch axis assumed by the module. - - Args: - io_str: Name of the tensor. Must be ``'input0'`` or ``'output'``. - - Returns: - Batch axis - - Raises: - ValueError: For invalid IO names. - """ - if io_str == "input0": - return self.batch_axis - elif io_str == "output": - return self.dims.index(self.batch_axis) + def _enforce_batch_axis_first(self) -> None: + batch_first = False + if self.init_transpose: + if 0 not in self.dims: + batch_first = True else: - valid_io_strs = ["input0", "output"] - raise ValueError(f"io_str must be in {valid_io_strs}, got {io_str}.") + if self.dims[0] == 0: + batch_first = True + if not batch_first: + raise ValueError("Permute: Batch axis must be left unchanged!") diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index 32c41e524..bd8e75a0d 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -8,7 +8,7 @@ from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.extensions.firstorder.base import FirstOrderModuleExtension -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample if TYPE_CHECKING: from backpack.extensions.firstorder import BatchGrad @@ -79,14 +79,14 @@ def param_function( Scaled individual gradients """ subsampling = ext.get_subsampling() - N_axis = get_batch_axis(module, "output") + batch_axis = 0 return self._derivatives.param_mjp( param_str, module, g_inp, g_out, - subsample(g_out[0], dim=N_axis, subsampling=subsampling), + subsample(g_out[0], dim=batch_axis, subsampling=subsampling), sum_batch=False, subsampling=subsampling, ) diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index a1c7a9c0d..b91aac935 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -7,7 +7,6 @@ from torch.nn import Module from backpack.extensions.firstorder.base import FirstOrderModuleExtension -from backpack.utils.subsampling import get_batch_axis if TYPE_CHECKING: from backpack.extensions import Variance @@ -80,7 +79,7 @@ def param_function( return self._variance_from( getattr(self.grad_ext, param)(ext, module, g_inp, g_out, bpQuantities), getattr(self.sgs_ext, param)(ext, module, g_inp, g_out, bpQuantities), - g_out[0].shape[get_batch_axis(module, "output")], + g_out[0].shape[0], ) return param_function diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index d1d9b739d..62d399f4c 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -2,9 +2,6 @@ from typing import List from torch import Tensor -from torch.nn import LSTM, RNN, Module, Sequential - -from backpack.custom_module.permute import Permute def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: @@ -23,43 +20,3 @@ def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Te return tensor else: return tensor[(slice(None),) * dim + (subsampling,)] - - -def get_batch_axis(module: Module, io_str: str) -> int: - """Return the batch axis assumed by the module. - - For unknown modules the default axis is determined as ``0``. - - Args: - module: A module. - io_str: Name of the tensor stored as BackPACK IO. Must be ``'input0'`` or - ``'output'``. - - Note: - This method only inspects single modules and therefore cannot detect whether - the batch axis has been modified by preceding ones. For instance for a ReLU - module, the batch axis will always be detected as ``0``, although the layer - still works if preceded by a ``Permute(0, 1)`` module, but would have batch - axis ``1``. - - Returns: - Batch axis - - Raises: - ValueError: For invalid IO names. - """ - valid_io_strs = ["input0", "output"] - if io_str not in valid_io_strs: - raise ValueError(f"io_str must be in {valid_io_strs}, got {io_str}.") - - batch_axis = 0 - - if isinstance(module, (RNN, LSTM)): - batch_axis = 0 if module.batch_first else 1 - elif isinstance(module, Permute): - batch_axis = module.get_batch_axis(io_str) - elif isinstance(module, Sequential): - child_idx = {"input0": 0, "output": -1}[io_str] - batch_axis = get_batch_axis(list(module.children())[child_idx], io_str) - - return batch_axis diff --git a/fully_documented.txt b/fully_documented.txt index 001a9fc13..de3c98730 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -107,3 +107,4 @@ test/converter/ test/utils/test_subsampling.py test/custom_module/ test/test_retain_graph.py +test/test_batch_first.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 8283bea17..5e751ffe1 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -31,7 +31,6 @@ from torch import Tensor, rand from backpack.core.derivatives.convnd import weight_jac_t_save_memory -from backpack.utils.subsampling import get_batch_axis PROBLEMS = make_test_problems(SETTINGS) IDS = [problem.make_id() for problem in PROBLEMS] @@ -197,8 +196,7 @@ def rand_mat_like_output( subsample_shape = list(problem.output_shape) if subsampling is not None: - N_axis = get_batch_axis(problem.module, "output") - subsample_shape[N_axis] = len(subsampling) + subsample_shape[0] = len(subsampling) return rand(V, *subsample_shape, device=problem.device) diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 593e1b5f7..921e46132 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -7,7 +7,7 @@ from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.lop import transposed_jacobian_vector_product from backpack.hessianfree.rop import jacobian_vector_product -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample class AutogradDerivatives(DerivativesImplementation): @@ -43,20 +43,18 @@ def jac_t_vec_prod(self, vec: Tensor, subsampling=None) -> Tensor: # noqa: D102 else: # for each sample, multiply by full input Jacobian, slice out result: # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n] - batch_axis_out = get_batch_axis(self.problem.module, "output") - output = subsample(output, dim=batch_axis_out, subsampling=subsampling) - output = output.split(1, dim=batch_axis_out) - vec = vec.split(1, dim=batch_axis_out) + batch_axis = 0 + output = subsample(output, dim=batch_axis, subsampling=subsampling) + output = output.split(1, dim=batch_axis) + vec = vec.split(1, dim=batch_axis) - batch_axis_in = get_batch_axis(self.problem.module, "input0") vjps: List[Tensor] = [] - for sample_idx, out, v in zip(subsampling, output, vec): vjp = transposed_jacobian_vector_product(out, input, v)[0] - vjp = subsample(vjp, dim=batch_axis_in, subsampling=[sample_idx]) + vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx]) vjps.append(vjp) - return cat(vjps, dim=batch_axis_in) + return cat(vjps, dim=batch_axis) def jac_t_mat_prod( self, mat: Tensor, subsampling: List[int] = None @@ -76,7 +74,7 @@ def param_mjp( param_str, vec, sum_batch, - axis_batch=get_batch_axis(self.problem.module, "output"), + axis_batch=0, subsampling=subsampling, ) for vec in mat diff --git a/test/core/derivatives/lstm_settings.py b/test/core/derivatives/lstm_settings.py index 94fc9131f..77181d9dd 100644 --- a/test/core/derivatives/lstm_settings.py +++ b/test/core/derivatives/lstm_settings.py @@ -22,14 +22,6 @@ # test settings # ############################################################################### LSTM_SETTINGS += [ - { - "module_fn": lambda: LSTM(input_size=4, hidden_size=3), - "input_fn": lambda: rand(size=(5, 3, 4)), - }, - { - "module_fn": lambda: LSTM(input_size=5, hidden_size=3), - "input_fn": lambda: rand(size=(10, 8, 5)), - }, { "module_fn": lambda: LSTM(input_size=4, hidden_size=3, batch_first=True), "input_fn": lambda: rand(size=(3, 5, 4)), diff --git a/test/core/derivatives/permute_settings.py b/test/core/derivatives/permute_settings.py index 8f499a629..e6ffd360d 100644 --- a/test/core/derivatives/permute_settings.py +++ b/test/core/derivatives/permute_settings.py @@ -23,11 +23,11 @@ "input_fn": lambda: torch.rand(size=(1, 2, 3)), }, { - "module_fn": lambda: Permute(2, 0, 1), + "module_fn": lambda: Permute(0, 2, 1), "input_fn": lambda: torch.rand(size=(4, 3, 2)), }, { - "module_fn": lambda: Permute(3, 1, 0, 2), + "module_fn": lambda: Permute(0, 3, 1, 2), "input_fn": lambda: torch.rand(size=(5, 4, 3, 2)), }, ] diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index 4f6306d9a..fa67a5634 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -9,7 +9,7 @@ from backpack import extend from backpack.utils.module_classification import is_loss -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample def make_test_problems(settings): @@ -148,8 +148,8 @@ def forward_pass( input: Tensor = self.input.clone().detach() if subsampling is not None: - batch_axis_in = get_batch_axis(self.module, "input0") - input = subsample(input, dim=batch_axis_in, subsampling=subsampling) + batch_axis = 0 + input = subsample(input, dim=batch_axis, subsampling=subsampling) if input_requires_grad and input.dtype is not long: input.requires_grad = True @@ -195,4 +195,4 @@ def get_batch_size(self) -> int: Returns: Mini-batch size. """ - return self.input.shape[get_batch_axis(self.module, "input0")] + return self.input.shape[0] diff --git a/test/core/derivatives/rnn_settings.py b/test/core/derivatives/rnn_settings.py index 933cf52af..6641cbb55 100644 --- a/test/core/derivatives/rnn_settings.py +++ b/test/core/derivatives/rnn_settings.py @@ -16,14 +16,6 @@ import torch RNN_SETTINGS = [ - { - "module_fn": lambda: torch.nn.RNN(input_size=4, hidden_size=2), - "input_fn": lambda: torch.rand(size=(5, 7, 4)), - }, - { - "module_fn": lambda: torch.nn.RNN(input_size=4, hidden_size=2), - "input_fn": lambda: torch.rand(size=(1, 1, 4)), - }, { "module_fn": lambda: torch.nn.RNN( input_size=4, hidden_size=3, batch_first=True diff --git a/test/custom_module/__init__.py b/test/custom_module/__init__.py deleted file mode 100644 index b998b1a6b..000000000 --- a/test/custom_module/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Contains tests for BackPACK's custom modules.""" diff --git a/test/custom_module/test_permute.py b/test/custom_module/test_permute.py deleted file mode 100644 index f6d185236..000000000 --- a/test/custom_module/test_permute.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Contains tests for BackPACK's custom ``Permute`` module.""" - -from pytest import raises - -from backpack.custom_module.permute import Permute - - -def test_get_batch_axis(): - """Test the Permute module's batch axis detection.""" - # invalid argument - with raises(ValueError): - invalid_io_str = "dummy" - Permute().get_batch_axis(invalid_io_str) - - # batch axis unaffected by forward pass - assert Permute(0, 2, 1).get_batch_axis("input0") == 0 - assert Permute(0, 2, 1).get_batch_axis("output") == 0 - - # batch axis first, affected by forward pass - assert Permute(1, 2, 0).get_batch_axis("input0") == 0 - assert Permute(1, 2, 0).get_batch_axis("output") == 2 - - # batch axis second, affected by forward pass - assert Permute(1, 2, 0, batch_axis=1).get_batch_axis("input0") == 1 - assert Permute(1, 2, 0, batch_axis=1).get_batch_axis("output") == 0 diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 51be5085c..4fc1f1a8e 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -264,10 +264,9 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2, batch_axis=0), - RNN(input_size=6, hidden_size=3), + RNN(input_size=6, hidden_size=3, batch_first=True), ReduceTuple(index=0), - Permute(1, 2, 0, batch_axis=1), + Permute(0, 2, 1), ), "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((8, 5), 3), @@ -275,10 +274,9 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2, batch_axis=0), - RNN(input_size=6, hidden_size=3), + RNN(input_size=6, hidden_size=3, batch_first=True), ReduceTuple(index=0), - Permute(1, 2, 0, batch_axis=1), + Permute(0, 2, 1), Flatten(), ), "loss_function_fn": lambda: MSELoss(), diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 24125e681..b2f96b92c 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter from backpack import extend -from backpack.utils.subsampling import get_batch_axis, subsample +from backpack.utils.subsampling import subsample def make_test_problems(settings): @@ -149,11 +149,9 @@ def forward_pass( target = self.target.clone() if subsampling is not None: - batch_axis_in = get_batch_axis(self.model, "input0") - input = subsample(self.input, dim=batch_axis_in, subsampling=subsampling) - - batch_axis_out = get_batch_axis(self.model, "output") - target = subsample(self.target, dim=batch_axis_out, subsampling=subsampling) + batch_axis = 0 + input = subsample(self.input, dim=batch_axis, subsampling=subsampling) + target = subsample(self.target, dim=batch_axis, subsampling=subsampling) output = self.model(input) loss = self.loss_function(output, target) @@ -243,7 +241,7 @@ def get_batch_size(self) -> int: Returns: Mini-batch size. """ - return self.input.shape[get_batch_axis(self.model, "input0")] + return self.input.shape[0] def compute_reduction_factor(self) -> float: """Compute loss function's reduction factor for aggregating per-sample losses. diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 02038be0c..6934f754e 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -53,10 +53,9 @@ { "input_fn": lambda: rand(8, 5, 6), "module_fn": lambda: Sequential( - Permute(1, 0, 2, batch_axis=0), - RNN(input_size=6, hidden_size=3), + RNN(input_size=6, hidden_size=3, batch_first=True), ReduceTuple(index=0), - Permute(1, 2, 0, batch_axis=1), + Permute(0, 2, 1), Flatten(), ), "loss_function_fn": lambda: MSELoss(), @@ -78,7 +77,7 @@ RNN(input_size=6, hidden_size=3, batch_first=True), ReduceTuple(index=0), Linear(3, 3), - Permute(0, 2, 1, batch_axis=0), + Permute(0, 2, 1), ), "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((8, 5), 3), diff --git a/test/extensions/utils.py b/test/extensions/utils.py index 83129772a..b517a3c8d 100644 --- a/test/extensions/utils.py +++ b/test/extensions/utils.py @@ -5,8 +5,6 @@ from pytest import skip -from backpack.utils.subsampling import get_batch_axis - def skip_if_subsampling_conflict( problem: ExtensionsTestProblem, subsampling: Union[List[int], None] @@ -17,7 +15,7 @@ def skip_if_subsampling_conflict( problem: Test case. subsampling: Indices of active samples. """ - N = problem.input.shape[get_batch_axis(problem.model, "input0")] + N = problem.input.shape[0] enough_samples = subsampling is None or N > max(subsampling) if not enough_samples: skip(f"Not enough samples: N={N}, subsampling={subsampling}") diff --git a/test/test_batch_first.py b/test/test_batch_first.py new file mode 100644 index 000000000..ec962dab4 --- /dev/null +++ b/test/test_batch_first.py @@ -0,0 +1,25 @@ +"""Tests whether batch axis is always first.""" +from pytest import raises + +from backpack.custom_module.permute import Permute + + +def test_permute_batch_axis() -> None: + """Verify that an Error is raised in the correct settings.""" + Permute(0, 1, 2) + Permute(0, 2, 1) + Permute(0, 2, 3, 1) + with raises(ValueError): + Permute(1, 0, 2) + with raises(ValueError): + Permute(2, 0, 1) + + Permute(1, 2, init_transpose=True) + Permute(3, 1, init_transpose=True) + Permute(2, 1, init_transpose=True) + with raises(ValueError): + Permute(0, 1, init_transpose=True) + with raises(ValueError): + Permute(1, 0, init_transpose=True) + with raises(ValueError): + Permute(2, 0, init_transpose=True) diff --git a/test/utils/test_subsampling.py b/test/utils/test_subsampling.py index 50309552a..750c06b1f 100644 --- a/test/utils/test_subsampling.py +++ b/test/utils/test_subsampling.py @@ -1,41 +1,8 @@ """Contains tests of sub-sampling functionality.""" -from pytest import raises from torch import allclose, manual_seed, rand -from torch.nn import Linear, ReLU, Sequential -from backpack.custom_module.permute import Permute -from backpack.utils.subsampling import get_batch_axis, subsample - - -def test_get_batch_axis(): - """Test batch axis detection.""" - # invalid argument - with raises(ValueError): - invalid_io_str = "dummy" - some_module = Linear(1, 1) - get_batch_axis(some_module, invalid_io_str) - - # Sequential with unaltered batch axis - model = Sequential(Linear(1, 1), ReLU()) - assert get_batch_axis(model, "input0") == 0 - assert get_batch_axis(model, "output") == 0 - - # Sequential with altered batch axis - model = Sequential(Linear(1, 1), Permute(1, 0)) - assert get_batch_axis(model, "input0") == 0 - assert get_batch_axis(model, "output") == 1 - - # Permute - model = Permute(1, 3, 2, 0, batch_axis=0) - assert get_batch_axis(model, "input0") == 0 - assert get_batch_axis(model, "output") == 3 - - model = Sequential(Permute(0, 1), ReLU()) - assert get_batch_axis(model, "input0") == 0 - # expected failure due to local inspection - batch_axis_output = 1 - assert get_batch_axis(model, "output") != batch_axis_output +from backpack.utils.subsampling import subsample def test_subsample(): From 9701baf5966f6b427c37a1e1acc0067e8c3042b6 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Mon, 11 Oct 2021 19:57:24 +0200 Subject: [PATCH 54/54] [MNT] Update changelog for `1.4.0` release (#228) --- changelog.md | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/changelog.md b/changelog.md index 6a00293f4..4bf090ad8 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,108 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.4.0] - 2021-11-12 + +This release ships many new features. Some rely on recent PyTorch functionality. +We now require `torch>=1.9.0`. + +**Highlights:** + +- *ResNets & RNNs:* Thanks to [@schaefertim](https://github.com/schaefertim) for + bringing basic support for RNNs + ([#16](https://github.com/f-dangel/backpack/issues/16), + [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_rnn.html#sphx-glr-use-cases-example-rnn-py)) and + ResNets ([#14](https://github.com/f-dangel/backpack/issues/14), + [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_resnet_all_in_one.html#sphx-glr-use-cases-example-resnet-all-in-one-py)) +- *`SqrtGGN{Exact,MC}` extension:* Symmetric factorization of the generalized + Gauss-Newton/Fisher (see + [arXiv:2106.02624](https://arxiv.org/abs/2106.02624)) +- *Sub-sampling:* Allows for restricting BackPACK extensions to a sub-set of + samples in the mini-batch + ([#12](https://github.com/f-dangel/backpack/issues/12), + [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_subsampling.html#sphx-glr-use-cases-example-subsampling-py)) + +### Added/New +- Converter functionality for basic support of ResNets and RNNs + [[PR1](https://github.com/f-dangel/backpack/pull/202), + [PR2](https://github.com/f-dangel/backpack/pull/221), + [PR3](https://github.com/f-dangel/backpack/pull/229)] +- New extensions: + - `SqrtGGNExact`: Symmetric factorization of the exact GGN/Fisher + [[PR](https://github.com/f-dangel/backpack/pull/180)] + - `SqrtGGNMC`: Symmetric factorization of the MC-approximated GGN/Fisher + [[PR](https://github.com/f-dangel/backpack/pull/182)] +- Module support: + - `Linear`: Support additional (more than 2) input dimensions + [[PR1](https://github.com/f-dangel/backpack/pull/185), + [PR2](https://github.com/f-dangel/backpack/pull/186)] + - `BatchNormNd`: Distinguish evaluation and training mode, support + first-order extensions and `DiagGGN{Exact,MC}` + [[#160](https://github.com/f-dangel/backpack/issues/160), + [PR1](https://github.com/f-dangel/backpack/pull/179), + [PR2](https://github.com/f-dangel/backpack/pull/201)] + - `AdaptiveAvgPoolND`: Support first-order extensions and + `DiagGGN{Exact,MC}` + [[PR](https://github.com/f-dangel/backpack/pull/201)] + - `RNN`: Support first-order extensions and `DiagGGN{MC,Exact}` + [[PR1](https://github.com/f-dangel/backpack/pull/159) + [PR2](https://github.com/f-dangel/backpack/pull/158) + [PR3](https://github.com/f-dangel/backpack/pull/156)] + - `LSTM`: Support first-order extensions and `DiagGGN{MC,Exact}` + [[PR](https://github.com/f-dangel/backpack/pull/215)] + - `CrossEntropyLoss`: Support additional (more than 2) input dimensions. + [[PR](https://github.com/f-dangel/backpack/pull/211)] + - `Embedding`: Support first-order extensions and `DiagGGN{MC,Exact}` + [[PR](https://github.com/f-dangel/backpack/pull/216)] +- Mini-batch sub-sampling + - `BatchGrad` + [[PR1](https://github.com/f-dangel/backpack/pull/200), + [PR2](https://github.com/f-dangel/backpack/pull/210)] + - `SqrtGGN{Exact,MC}` + [[PR](https://github.com/f-dangel/backpack/pull/200)] +- `retain_graph` option for `backpack` context + [[PR](https://github.com/f-dangel/backpack/pull/217)] +- Assume batch axis always first + [[PR](https://github.com/f-dangel/backpack/pull/227)] + +### Fixed/Removed +- Deprecate `python3.6`, require at least `python3.7` + [[PR](https://github.com/f-dangel/backpack/pull/190)] + +### Internal +- Use `full_backward_hook` for `torch>=1.9.0` + [[PR](https://github.com/f-dangel/backpack/pull/194)] +- Core + - Implement derivatives for `LSTM` + [[PR](https://github.com/f-dangel/backpack/pull/169)] + - Implement derivatives for `AdaptiveAvgPoolNd` + [[PR](https://github.com/f-dangel/backpack/pull/165)] + - Sub-sampling + - `weight_jac_t_mat_prod` + [[PR](https://github.com/f-dangel/backpack/pull/195)] + - `bias_jac_t_mat_prod` + [[PR](https://github.com/f-dangel/backpack/pull/196)] + - `*_jac_t_mat_prod` of `RNN` and `LSTM` parameters + [[PR](https://github.com/f-dangel/backpack/pull/197)] + - `jac_t_mat_prod` + [[PR](https://github.com/f-dangel/backpack/pull/205)] + - Hessian square root decomposition (exact and MC) + [[PR](https://github.com/f-dangel/backpack/pull/207)] + - Refactor: Share code for `*_jac_t_mat_prod` + [[PR](https://github.com/f-dangel/backpack/pull/203)] +- Extensions + - Refactor `BatchL2Grad`, introducing a base class + [[PR](https://github.com/f-dangel/backpack/pull/175)] + - Automate parameter functions for `BatchGrad` and `Grad` + [[PR](https://github.com/f-dangel/backpack/pull/150)] + - Introduce interface to check module hyperparameters + [[PR](https://github.com/f-dangel/backpack/pull/206)] +- Tests + - Check if module Hessian is zero + [[PR](https://github.com/f-dangel/backpack/pull/183)] + - Reduce run time + [[PR](https://github.com/f-dangel/backpack/pull/199)] + ## [1.3.0] - 2021-06-16 Thanks to [@sbharadwajj](https://github.com/sbharadwajj) @@ -234,7 +336,8 @@ co-authoring many PRs shipped in this release. Initial release -[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.3.0...HEAD +[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.4.0...HEAD +[1.4.0]: https://github.com/f-dangel/backpack/compare/1.4.0...1.3.0 [1.3.0]: https://github.com/f-dangel/backpack/compare/1.3.0...1.2.0 [1.2.0]: https://github.com/f-dangel/backpack/compare/1.2.0...1.1.1 [1.1.1]: https://github.com/f-dangel/backpack/compare/1.1.0...1.1.1