diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 000000000..e186905f2
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,13 @@
+# https://coverage.readthedocs.io/en/v4.5.x/config.html#config
+[report]
+# Regexes for lines to exclude from consideration
+exclude_lines =
+ # Have to re-enable the standard pragma
+ pragma: no cover
+
+ # Don't complain if tests don't hit defensive assertion code:
+ raise NotImplementedError
+ raise AssertionError
+
+ # TYPE_CHECKING block is never executed during pytest run
+ if TYPE_CHECKING:
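For context, a tiny hypothetical Python module illustrating the line patterns this configuration excludes from coverage reports:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # excluded: never executed during a pytest run
    from torch import Tensor


def _not_implemented_yet(x: "Tensor") -> "Tensor":
    # Defensive code paths like this line are excluded from the report.
    raise NotImplementedError


def _platform_specific() -> None:  # pragma: no cover
    print("only reachable in configurations the test suite does not cover")
```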
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 1ae4e4217..ee7391bff 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -9,27 +9,28 @@ on:
- development
- master
- release
-
jobs:
tests:
- name: "Python ${{ matrix.python-version }}"
+ name: "py${{ matrix.python-version }} torch${{ matrix.pytorch-version}}"
runs-on: ubuntu-latest
env:
- USING_COVERAGE: '3.6,3.8'
+ USING_COVERAGE: '3.7,3.9'
strategy:
matrix:
- python-version: ["3.6", "3.7", "3.8"]
+ python-version: [3.7, 3.8, 3.9]
+ pytorch-version: [1.9.0, 1.9.1]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
with:
- python-version: "${{ matrix.python-version }}"
+ python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
make install-test
+ pip install torch==${{ matrix.pytorch-version }} torchvision
- name: Run test
if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref)
run: |
diff --git a/README-dev.md b/README-dev.md
index e853e7bc8..913b035e5 100644
--- a/README-dev.md
+++ b/README-dev.md
@@ -1,7 +1,7 @@
# BackPACK developer manual
-## General standards
-- Python version: support 3.6+, use 3.7 for development
+## General standards
+- Python version: support 3.7+, use 3.7 for development
- `git` [branching model](https://nvie.com/posts/a-successful-git-branching-model/)
- Docstring style: [Google](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
- Test runner: [`pytest`](https://docs.pytest.org/en/latest/)
diff --git a/README.md b/README.md
index f3f179d15..14cdeea8c 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
[![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack)
[![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack)
-[![Python 3.6+](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360/)
+[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient.
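A minimal usage sketch (toy model and data made up for illustration) of what computing quantities other than the gradient looks like: ``extend`` the model and loss, then request extra quantities inside a ``with backpack(...)`` block. Here the ``BatchGrad`` extension adds per-sample gradients as ``param.grad_batch``.

```python
import torch
from torch.nn import Linear, MSELoss

from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(Linear(10, 1))
loss_fn = extend(MSELoss())

X, y = torch.randn(8, 10), torch.randn(8, 1)
loss = loss_fn(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

for param in model.parameters():
    # .grad as usual, plus per-sample gradients computed by BackPACK
    print(param.grad.shape, param.grad_batch.shape)
```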
diff --git a/backpack/__init__.py b/backpack/__init__.py
index 2a1ba9234..85b5aa578 100644
--- a/backpack/__init__.py
+++ b/backpack/__init__.py
@@ -1,74 +1,102 @@
"""BackPACK."""
-import inspect
+from inspect import isclass
+from types import TracebackType
+from typing import Callable, Optional, Tuple, Type, Union
-import torch
+from torch import Tensor, is_grad_enabled
+from torch.fx import GraphModule
+from torch.nn import Module
+from backpack import extensions
+from backpack.context import CTX
+from backpack.custom_module.graph_utils import convert_module_to_backpack
from backpack.extensions.backprop_extension import BackpropExtension
from backpack.utils.hooks import no_op
-
-from . import extensions
-from .context import CTX
+from backpack.utils.module_classification import is_no_op
class backpack:
- """Activate BackPACK extensions.
-
- Enables the BackPACK extensions passed as arguments in the
- :code:`backward` calls inside the current :code:`with` block.
-
- Args:
- exts ([BackpropExtension]): Extensions to activate in the backward pass.
- extension_hook (function, optional): Function called on each module after
- all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns
- ``None``. Default: ``None`` (no operation will be formed).
-
- Can be used to reduce memory overhead if the goal is to compute
+ """Context manager to activate BackPACK extensions."""
+
+ def __init__(
+ self,
+ *exts: BackpropExtension,
+ extension_hook: Callable[[Module], None] = None,
+ debug: bool = False,
+ retain_graph: bool = False,
+ ):
+ """Activate BackPACK extensions.
+
+ Enables the BackPACK extensions passed as arguments in the
+ :code:`backward` calls inside the current :code:`with` block.
+
+ Args:
+ exts: Extensions to activate in the backward pass.
+ extension_hook: Function called on each module after
+ all BackPACK extensions have run. Takes a ``torch.nn.Module`` and returns
+ ``None``. Default: ``None`` (no operation will be performed).
+ debug: Print debug messages during the backward pass. Default: ``False``.
+ retain_graph: Determines whether BackPACK IO should be kept for additional
+ backward passes. Should have same value as the argument ``retain_graph``
+ in ``backward()``. Default: ``False``.
+
+ .. note::
+ ``extension_hook`` can be used to reduce memory overhead if the goal is to compute
transformations of BackPACK quantities. Information can be compacted
during a backward pass and obsolete tensors be freed manually (``del``).
- .. note::
-
- If the callable iterates over the ``module.parameters()``, the same
- parameter may be seen multiple times across calls. This happens
- if the parameters are part of multiple modules.
- For example, the parameters of a `torch.nn.Linear` module in
- ``model = torch.nn.Sequential(torch.nn.Linear(...))`` are part of
- both the ``Linear`` and the ``Sequential``.
- debug (bool, optional): Print debug messages during the backward pass.
- Default: ``False``.
- """
-
- def __init__(self, *exts: BackpropExtension, extension_hook=None, debug=False):
+ Raises:
+ ValueError: if extensions are not valid
+ """
for ext in exts:
if not isinstance(ext, BackpropExtension):
- if inspect.isclass(ext) and issubclass(ext, BackpropExtension):
+ if isclass(ext) and issubclass(ext, BackpropExtension):
raise ValueError(
- "backpack expect instances of BackpropExtension,"
- + " but received a class instead [{}].".format(ext)
+ "backpack expects instances of BackpropExtension,"
+ + f" but received a class instead [{ext}]."
+ " Instantiate it before passing it to backpack."
)
else:
raise ValueError(
"backpack expects instances of BackpropExtension,"
- + " but received [{}].".format(ext)
+ + f" but received [{ext}]."
)
- self.exts = exts
- self.debug = debug
- self.extension_hook = no_op if extension_hook is None else extension_hook
+ self.exts: Tuple[BackpropExtension, ...] = exts
+ self.debug: bool = debug
+ self.extension_hook: Callable[[Module], None] = (
+ no_op if extension_hook is None else extension_hook
+ )
+ self.retain_graph = retain_graph
def __enter__(self):
+ """Setup backpack environment."""
self.old_CTX = CTX.get_active_exts()
self.old_debug = CTX.get_debug()
self.old_extension_hook = CTX.get_extension_hook()
+ self.old_retain_graph = CTX.get_retain_graph()
CTX.set_active_exts(self.exts)
CTX.set_debug(self.debug)
CTX.set_extension_hook(self.extension_hook)
+ CTX.set_retain_graph(self.retain_graph)
+
+ def __exit__(
+ self,
+ __exc_type: Optional[Type[BaseException]],
+ __exc_value: Optional[BaseException],
+ __traceback: Optional[TracebackType],
+ ):
+ """Leave backpack environment.
- def __exit__(self, type, value, traceback):
+ Args:
+ __exc_type: exception type
+ __exc_value: exception value
+ __traceback: exception traceback
+ """
CTX.set_active_exts(self.old_CTX)
CTX.set_debug(self.old_debug)
CTX.set_extension_hook(self.old_extension_hook)
+ CTX.set_retain_graph(self.old_retain_graph)
class disable:
@@ -91,41 +119,68 @@ class disable:
even if the forward pass is carried out in ``with backpack(...)``.
"""
- store_io = True
+ store_io: bool = True
def __enter__(self):
"""Disable input/output storing."""
- self.old_store_io = disable.store_io
+ self.old_store_io: bool = disable.store_io
disable.store_io = False
- def __exit__(self, type, value, traceback):
- """Set input/output storing to old value."""
+ def __exit__(
+ self,
+ __exc_type: Optional[Type[BaseException]],
+ __exc_value: Optional[BaseException],
+ __traceback: Optional[TracebackType],
+ ):
+ """Leave backpack environment.
+
+ Args:
+ __exc_type: exception type
+ __exc_value: exception value
+ __traceback: exception traceback
+ """
disable.store_io = self.old_store_io
@staticmethod
- def should_store_io():
- """Return whether input and output should be stored."""
+ def should_store_io() -> bool:
+ """Return whether input and output should be stored during forward pass.
+
+ Returns:
+ whether input and output should be stored during the forward pass
+ """
return disable.store_io
-def hook_store_io(module, input, output):
+def hook_store_io(
+ module: Module, input: Tuple[Tensor], output: Union[Tensor, Tuple[Tensor]]
+) -> None:
"""Saves the input and output as attributes of the module.
+ The i-th input is saved as the attribute ``module.input{i}`` (e.g. ``module.input0``).
+ The output is reduced to a single output tensor and saved as ``module.output``.
+
Args:
- module: module
+ module: the module on which to save the inputs/outputs
input: List of input tensors
- output: output tensor
+ output: result of module(input)
"""
- if disable.should_store_io() and torch.is_grad_enabled():
+ if disable.should_store_io() and is_grad_enabled():
for i in range(len(input)):
setattr(module, "input{}".format(i), input[i])
- module.output = output
+ if isinstance(output, tuple):
+ # this is the case for RNN, GRU, and LSTM, which return a tuple (output, ...)
+ module.output = output[0]
+ else:
+ module.output = output
-def memory_cleanup(module):
+def memory_cleanup(module: Module) -> None:
"""Remove I/O stored by backpack during the forward pass.
Deletes the attributes created by `hook_store_io`.
+
+ Args:
+ module: current module
"""
if hasattr(module, "output"):
delattr(module, "output")
@@ -135,60 +190,81 @@ def memory_cleanup(module):
i += 1
-def hook_run_extensions(module, g_inp, g_out):
+def hook_run_extensions(
+ module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+) -> None:
+ """The backward hook function.
+
+ It executes all BackPACK operations during the backward pass.
+
+ Args:
+ module: current module
+ g_inp: input gradients
+ g_out: output gradients
+ """
+ debug = CTX.get_debug()
for backpack_extension in CTX.get_active_exts():
- if CTX.get_debug():
+ if debug:
print("[DEBUG] Running extension", backpack_extension, "on", module)
- backpack_extension.apply(module, g_inp, g_out)
+ backpack_extension(module, g_inp, g_out)
- run_extension_hook(module)
+ if debug:
+ print("[DEBUG] Running extension hook on", module)
+ CTX.get_extension_hook()(module)
if not (
- CTX.is_extension_active(
- extensions.curvmatprod.HMP,
- extensions.curvmatprod.GGNMP,
- extensions.curvmatprod.PCHMP,
+ CTX.get_retain_graph()
+ or (
+ CTX.is_extension_active(
+ extensions.curvmatprod.HMP,
+ extensions.curvmatprod.GGNMP,
+ extensions.curvmatprod.PCHMP,
+ )
)
):
memory_cleanup(module)
-def run_extension_hook(module):
- """Execute the post extensions hook on a module after all BackPACK extensions.
-
- See the `post_backward_hook` argument of the `backpack` context manager for details.
- """
- try:
- CTX.get_extension_hook()(module)
- except Exception as e:
- message = getattr(e, "message", repr(e))
- raise RuntimeError(f"Post extensions hook failed: {message}")
+def extend(module: Module, debug: bool = False, use_converter: bool = False) -> Module:
+ """Recursively extend a ``module`` to make it BackPACK-ready.
-
-def extend(module: torch.nn.Module, debug=False):
- """Extends a ``module`` to make it BackPACK-ready.
-
- If the ``module`` has children, e.g. for a ``torch.nn.Sequential``,
- they will also be extended.
+ Modules that do not represent an operation in the computation graph (for instance
+ containers like ``Sequential``) will not explicitly be extended.
Args:
- module (torch.nn.Module): The module to extend.
- debug (bool, optional): Print debug messages during the extension.
+ module: The module to extend.
+ debug: Print debug messages during the extension. Default: ``False``.
+ use_converter: Try converting the module to a BackPACK-compatible network.
+ The converter might alter the model, e.g. the order of parameters.
Default: ``False``.
Returns:
- torch.nn.Module: Extended module.
+ Extended module.
"""
if debug:
print("[DEBUG] Extending", module)
+ if use_converter:
+ module: GraphModule = convert_module_to_backpack(module, debug)
+ return extend(module)
+
for child in module.children():
extend(child, debug=debug)
- module_was_already_extended = getattr(module, "_backpack_extend", False)
- if not module_was_already_extended:
- CTX.add_hook_handle(module.register_forward_hook(hook_store_io))
- CTX.add_hook_handle(module.register_backward_hook(hook_run_extensions))
- module._backpack_extend = True
+ extended_flag = "_backpack_extend"
+ already_extended = getattr(module, extended_flag, False)
+ if not (already_extended or is_no_op(module)):
+ _register_hooks(module)
+ setattr(module, extended_flag, True)
return module
+
+
+def _register_hooks(module: Module) -> None:
+ """Install forward and backward hooks on a module.
+
+ Args:
+ module: module that is going to be extended
+ """
+ CTX.add_hook_handle(module.register_forward_hook(hook_store_io))
+ CTX.add_hook_handle(module.register_full_backward_hook(hook_run_extensions))
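To illustrate the ``extension_hook`` mechanism documented above, a sketch under assumptions: the model, data, the choice of ``BatchGrad``, and the attribute name ``grad_batch_norm`` are made up for the example. The hook compacts each parameter's per-sample gradients right after the module's extensions have run and frees the large buffer, which is the memory-saving pattern the docstring describes.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(Sequential(Linear(10, 5), ReLU(), Linear(5, 1)))
loss_fn = extend(MSELoss())


def compact_hook(module):
    """Reduce per-sample gradients to their squared norms, then free grad_batch."""
    for param in module.parameters(recurse=False):
        if hasattr(param, "grad_batch"):
            # hypothetical attribute name, one value per sample
            param.grad_batch_norm = (param.grad_batch ** 2).flatten(1).sum(1)
            del param.grad_batch


X, y = torch.randn(8, 10), torch.randn(8, 1)
with backpack(BatchGrad(), extension_hook=compact_hook):
    loss_fn(model(X), y).backward()

for name, param in model.named_parameters():
    print(name, param.grad_batch_norm.shape)  # torch.Size([8]) each
```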
diff --git a/backpack/context.py b/backpack/context.py
index 9da73faa9..39d19ebf7 100644
--- a/backpack/context.py
+++ b/backpack/context.py
@@ -1,56 +1,118 @@
+"""Context class for BackPACK."""
+from typing import Callable, Iterable, List, Tuple, Type
+
+from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+
+from backpack.extensions.backprop_extension import BackpropExtension
from backpack.utils.hooks import no_op
class CTX:
- """
- Global Class holding the configuration of the backward pass
- """
+ """Global Class holding the configuration of the backward pass."""
- active_exts = tuple()
- debug = False
- extension_hook = no_op
+ active_exts: Tuple[BackpropExtension, ...] = tuple()
+ debug: bool = False
+ extension_hook: Callable[[Module], None] = no_op
+ hook_handles: List[RemovableHandle] = []
+ retain_graph: bool = False
@staticmethod
- def set_active_exts(active_exts):
- CTX.active_exts = tuple()
- for act_ext in active_exts:
- CTX.active_exts += (act_ext,)
+ def set_active_exts(active_exts: Iterable[BackpropExtension]) -> None:
+ """Set the active backpack extensions.
+
+ Args:
+ active_exts: the extensions
+ """
+ CTX.active_exts = tuple(active_exts)
@staticmethod
- def get_active_exts():
+ def get_active_exts() -> Tuple[BackpropExtension, ...]:
+ """Get the currently active extensions.
+
+ Returns:
+ active extensions
+ """
return CTX.active_exts
@staticmethod
- def add_hook_handle(hook_handle):
- if getattr(CTX, "hook_handles", None) is None:
- CTX.hook_handles = []
+ def add_hook_handle(hook_handle: RemovableHandle) -> None:
+ """Add the hook handle to internal variable hook_handles.
+
+ Args:
+ hook_handle: the removable handle
+ """
CTX.hook_handles.append(hook_handle)
@staticmethod
- def remove_hooks():
+ def remove_hooks() -> None:
+ """Remove all hooks."""
for handle in CTX.hook_handles:
handle.remove()
CTX.hook_handles = []
@staticmethod
- def is_extension_active(*extension_classes):
- for backpack_ext in CTX.get_active_exts():
- if isinstance(backpack_ext, extension_classes):
- return True
- return False
+ def is_extension_active(*extension_classes: Type[BackpropExtension]) -> bool:
+ """Returns whether the specified class is currently active.
+
+ Args:
+ *extension_classes: classes to test
+
+ Returns:
+ whether at least one of the specified extensions is active
+ """
+ return any(isinstance(ext, extension_classes) for ext in CTX.get_active_exts())
@staticmethod
- def get_debug():
+ def get_debug() -> bool:
+ """Whether debug mode is active.
+
+ Returns:
+ whether debug mode is active
+ """
return CTX.debug
@staticmethod
- def set_debug(debug):
+ def set_debug(debug: bool) -> None:
+ """Set debug mode.
+
+ Args:
+ debug: the mode to set
+ """
CTX.debug = debug
@staticmethod
- def get_extension_hook():
+ def get_extension_hook() -> Callable[[Module], None]:
+ """Return the current extension hook to be run after all other extensions.
+
+ Returns:
+ current extension hook
+ """
return CTX.extension_hook
@staticmethod
- def set_extension_hook(extension_hook):
+ def set_extension_hook(extension_hook: Callable[[Module], None]) -> None:
+ """Set the current extension hook.
+
+ Args:
+ extension_hook: the extension hook to run after all other extensions
+ """
CTX.extension_hook = extension_hook
+
+ @staticmethod
+ def set_retain_graph(retain_graph: bool) -> None:
+ """Set retain_graph.
+
+ Args:
+ retain_graph: new value for retain_graph
+ """
+ CTX.retain_graph = retain_graph
+
+ @staticmethod
+ def get_retain_graph() -> bool:
+ """Get retain_graph.
+
+ Returns:
+ retain_graph
+ """
+ return CTX.retain_graph
diff --git a/backpack/core/derivatives/__init__.py b/backpack/core/derivatives/__init__.py
index d9388ccec..059d55349 100644
--- a/backpack/core/derivatives/__init__.py
+++ b/backpack/core/derivatives/__init__.py
@@ -1,78 +1 @@
-from torch.nn import (
- ELU,
- SELU,
- AvgPool1d,
- AvgPool2d,
- AvgPool3d,
- Conv1d,
- Conv2d,
- Conv3d,
- ConvTranspose1d,
- ConvTranspose2d,
- ConvTranspose3d,
- CrossEntropyLoss,
- Dropout,
- LeakyReLU,
- Linear,
- LogSigmoid,
- MaxPool1d,
- MaxPool2d,
- MaxPool3d,
- MSELoss,
- ReLU,
- Sigmoid,
- Tanh,
- ZeroPad2d,
-)
-
-from .avgpool1d import AvgPool1DDerivatives
-from .avgpool2d import AvgPool2DDerivatives
-from .avgpool3d import AvgPool3DDerivatives
-from .conv1d import Conv1DDerivatives
-from .conv2d import Conv2DDerivatives
-from .conv3d import Conv3DDerivatives
-from .conv_transpose1d import ConvTranspose1DDerivatives
-from .conv_transpose2d import ConvTranspose2DDerivatives
-from .conv_transpose3d import ConvTranspose3DDerivatives
-from .crossentropyloss import CrossEntropyLossDerivatives
-from .dropout import DropoutDerivatives
-from .elu import ELUDerivatives
-from .leakyrelu import LeakyReLUDerivatives
-from .linear import LinearDerivatives
-from .logsigmoid import LogSigmoidDerivatives
-from .maxpool1d import MaxPool1DDerivatives
-from .maxpool2d import MaxPool2DDerivatives
-from .maxpool3d import MaxPool3DDerivatives
-from .mseloss import MSELossDerivatives
-from .relu import ReLUDerivatives
-from .selu import SELUDerivatives
-from .sigmoid import SigmoidDerivatives
-from .tanh import TanhDerivatives
-from .zeropad2d import ZeroPad2dDerivatives
-
-derivatives_for = {
- Linear: LinearDerivatives,
- Conv1d: Conv1DDerivatives,
- Conv2d: Conv2DDerivatives,
- Conv3d: Conv3DDerivatives,
- AvgPool1d: AvgPool1DDerivatives,
- AvgPool2d: AvgPool2DDerivatives,
- AvgPool3d: AvgPool3DDerivatives,
- MaxPool1d: MaxPool1DDerivatives,
- MaxPool2d: MaxPool2DDerivatives,
- MaxPool3d: MaxPool3DDerivatives,
- ZeroPad2d: ZeroPad2dDerivatives,
- Dropout: DropoutDerivatives,
- ReLU: ReLUDerivatives,
- Tanh: TanhDerivatives,
- Sigmoid: SigmoidDerivatives,
- ConvTranspose1d: ConvTranspose1DDerivatives,
- ConvTranspose2d: ConvTranspose2DDerivatives,
- ConvTranspose3d: ConvTranspose3DDerivatives,
- LeakyReLU: LeakyReLUDerivatives,
- LogSigmoid: LogSigmoidDerivatives,
- ELU: ELUDerivatives,
- SELU: SELUDerivatives,
- CrossEntropyLoss: CrossEntropyLossDerivatives,
- MSELoss: MSELossDerivatives,
-}
+"""Contains derivatives of all supported modules."""
diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py
new file mode 100644
index 000000000..f6f4d3ad1
--- /dev/null
+++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py
@@ -0,0 +1,106 @@
+"""Implements the derivatives for AdaptiveAvgPool."""
+from typing import List, Tuple, Union
+from warnings import warn
+
+from torch import Size
+from torch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
+
+from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives
+from backpack.utils import ADAPTIVE_AVG_POOL_BUG
+
+
+class AdaptiveAvgPoolNDDerivatives(AvgPoolNDDerivatives):
+ """Implements the derivatives for AdaptiveAvgPool."""
+
+ def check_parameters(
+ self, module: Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d]
+ ) -> None:
+ """Checks if the parameters are supported.
+
+ Specifically, checks whether the input shape is a multiple of the output shape.
+ In that case, an equivalent (non-adaptive) AvgPoolND parameterization exists.
+
+ https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993 # noqa: B950
+
+ Args:
+ module: module to check
+
+ Raises:
+ NotImplementedError: if the given shapes do not match
+ """
+ if ADAPTIVE_AVG_POOL_BUG and module.input0.is_cuda and (self.N == 3):
+ warn(
+ "Be careful when computing gradients of AdaptiveAvgPool3d. "
+ "There is a bug using autograd.grad on cuda with AdaptiveAvgPool3d. "
+ "https://discuss.pytorch.org/t/bug-report-autograd-grad-adaptiveavgpool3d-cuda/124614 " # noqa: B950
+ "BackPACK derivatives are correct."
+ )
+
+ shape_input: Size = module.input0.shape
+ shape_output: Size = module.output.shape
+
+ # check length of input shape
+ if not len(shape_input) == (self.N + 2):
+ raise NotImplementedError(
+ f"input must be (batch_size, C, ...) with ... {self.N} dimensions"
+ )
+
+ # check if input shape is multiple of output shape
+ if any(shape_input[2 + n] % shape_output[2 + n] != 0 for n in range(self.N)):
+ raise NotImplementedError(
+ f"No equivalent AvgPool (unadaptive): Input shape ({shape_input}) "
+ f"must be multiple of output shape ({shape_output})."
+ )
+
+ def get_avg_pool_parameters(
+ self, module: Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d]
+ ) -> Tuple[List[int], List[int], List[int]]:
+ """Return parameters for an equivalent AvgPool.
+
+ Assumes that check_parameters has been run before.
+ Therefore, does not check parameters.
+
+ Args:
+ module: module to compute on
+
+ Returns:
+ stride, kernel_size, padding as lists of length self.N
+ """
+ shape_input: Size = module.input0.shape
+ shape_target: Size = module.output.shape
+
+ # calculate equivalent AvgPoolND parameters
+ stride: List[int] = []
+ kernel_size: List[int] = []
+ for n in range(self.N):
+ in_dim: int = shape_input[2 + n]
+ out_dim: int = shape_target[2 + n]
+ stride.append(in_dim // out_dim)
+ kernel_size.append(in_dim - (out_dim - 1) * stride[n])
+ padding: List[int] = [0 for _ in range(self.N)]
+
+ return stride, kernel_size, padding
+
+
+class AdaptiveAvgPool1dDerivatives(AdaptiveAvgPoolNDDerivatives):
+ """Derivatives for AdaptiveAvgPool1d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(N=1)
+
+
+class AdaptiveAvgPool2dDerivatives(AdaptiveAvgPoolNDDerivatives):
+ """Derivatives for AdaptiveAvgPool2d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(N=2)
+
+
+class AdaptiveAvgPool3dDerivatives(AdaptiveAvgPoolNDDerivatives):
+ """Derivatives for AdaptiveAvgPool3d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(N=3)
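A small numerical sketch (shapes chosen for illustration) of the equivalence that ``check_parameters`` and ``get_avg_pool_parameters`` rely on: when each input dimension is a multiple of the corresponding output dimension, ``AdaptiveAvgPool2d`` coincides with a plain ``AvgPool2d`` using the derived stride and kernel size.

```python
import torch
from torch.nn import AdaptiveAvgPool2d, AvgPool2d

x = torch.randn(2, 3, 12, 8)  # spatial sizes are multiples of the target sizes
out_size = (6, 4)

# same formulas as get_avg_pool_parameters
stride = (12 // 6, 8 // 4)                                    # in_dim // out_dim
kernel = (12 - (6 - 1) * stride[0], 8 - (4 - 1) * stride[1])  # in - (out - 1) * stride

adaptive = AdaptiveAvgPool2d(out_size)
equivalent = AvgPool2d(kernel_size=kernel, stride=stride, padding=0)

print(torch.allclose(adaptive(x), equivalent(x)))  # True
```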
diff --git a/backpack/core/derivatives/avgpoolnd.py b/backpack/core/derivatives/avgpoolnd.py
index e40d1f0af..b51d21480 100644
--- a/backpack/core/derivatives/avgpoolnd.py
+++ b/backpack/core/derivatives/avgpoolnd.py
@@ -3,57 +3,51 @@
Average pooling can be expressed as convolution over grouped channels with a constant
kernel.
"""
+from typing import Any, List, Tuple
-import torch.nn
from einops import rearrange
-from torch.nn import (
- Conv1d,
- Conv2d,
- Conv3d,
- ConvTranspose1d,
- ConvTranspose2d,
- ConvTranspose3d,
-)
+from torch import Tensor, ones_like
+from torch.nn import Module
from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.utils.conv import get_conv_module
+from backpack.utils.conv_transpose import get_conv_transpose_module
class AvgPoolNDDerivatives(BaseDerivatives):
- def __init__(self, N):
+ def __init__(self, N: int):
+ self.conv = get_conv_module(N)
+ self.convt = get_conv_transpose_module(N)
self.N = N
- if self.N == 1:
- self.conv = Conv1d
- self.convt = ConvTranspose1d
- elif self.N == 2:
- self.conv = Conv2d
- self.convt = ConvTranspose2d
- elif self.N == 3:
- self.conv = Conv3d
- self.convt = ConvTranspose3d
-
- def hessian_is_zero(self):
+
+ def check_parameters(self, module: Module) -> None:
+ assert module.count_include_pad, (
+ "Might not work for exotic hyperparameters of AvgPool2d, "
+ + "like count_include_pad=False"
+ )
+
+ def get_avg_pool_parameters(self, module: Module) -> Tuple[Any, Any, Any]:
+ """Return the parameters of the module.
+
+ Args:
+ module: module
+
+ Returns:
+ stride, kernel_size, padding
+ """
+ return module.stride, module.kernel_size, module.padding
+
+ def hessian_is_zero(self, module):
return True
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
"""Use fact that average pooling can be implemented as conv."""
- if self.N == 1:
- _, C, L_in = module.input0.size()
- _, _, L_out = module.output.size()
- in_features = C * L_in
- out_features = C * L_out
- shape_out = (1, L_out)
- elif self.N == 2:
- _, C, H_in, W_in = module.input0.size()
- _, _, H_out, W_out = module.output.size()
- in_features = C * H_in * W_in
- out_features = C * H_out * W_out
- shape_out = (1, H_out, W_out)
- elif self.N == 3:
- _, C, D_in, H_in, W_in = module.input0.size()
- _, _, D_out, H_out, W_out = module.output.size()
- in_features = C * D_in * H_in * W_in
- out_features = C * D_out * H_out * W_out
- shape_out = (1, D_out, H_out, W_out)
+ self.check_parameters(module)
+
+ C = module.input0.shape[1]
+ shape_out = (1,) + tuple(module.output.shape[2:])
+ in_features = module.input0.shape[1:].numel()
+ out_features = module.output.shape[1:].numel()
mat = mat.reshape(out_features * C, *shape_out)
jac_t_mat = self.__apply_jacobian_t_of(module, mat).reshape(
@@ -66,14 +60,8 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
return jac_t_mat_t_jac.t()
- def check_exotic_parameters(self, module):
- assert module.count_include_pad, (
- "Might not work for exotic hyperparameters of AvgPool2d, "
- + "like count_include_pad=False"
- )
-
def _jac_mat_prod(self, module, g_inp, g_out, mat):
- self.check_exotic_parameters(module)
+ self.check_parameters(module)
mat_as_pool = self.__make_single_channel(mat, module)
jmp_as_pool = self.__apply_jacobian_of(module, mat_as_pool)
@@ -89,79 +77,61 @@ class and channel dimension."""
return result.unsqueeze(C_axis)
def __apply_jacobian_of(self, module, mat):
+ stride, kernel_size, padding = self.get_avg_pool_parameters(module)
convnd = self.conv(
in_channels=1,
out_channels=1,
- kernel_size=module.kernel_size,
- stride=module.stride,
- padding=module.padding,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
bias=False,
).to(module.input0.device)
convnd.weight.requires_grad = False
- avg_kernel = torch.ones_like(convnd.weight) / convnd.weight.numel()
+ avg_kernel = ones_like(convnd.weight) / convnd.weight.numel()
convnd.weight.data = avg_kernel
return convnd(mat)
def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module):
- V = mat.size(0)
- if self.N == 1:
- N, C_out, L_out = module.output.shape
- assert jmp_as_pool.shape == (V * N * C_out, 1, L_out)
- elif self.N == 2:
- N, C_out, H_out, W_out = module.output.shape
- assert jmp_as_pool.shape == (V * N * C_out, 1, H_out, W_out)
- elif self.N == 3:
- N, C_out, D_out, H_out, W_out = module.output.shape
- assert jmp_as_pool.shape == (V * N * C_out, 1, D_out, H_out, W_out)
-
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- self.check_exotic_parameters(module)
+ V = mat.shape[0]
+ N, C_out = module.output.shape[:2]
+
+ assert jmp_as_pool.shape == (V * N * C_out, 1) + module.output.shape[2:]
+
+ def _jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self.check_parameters(module)
mat_as_pool = self.__make_single_channel(mat, module)
jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool)
- self.__check_jmp_in_as_pool(mat, jmp_as_pool, module)
- return self.reshape_like_input(jmp_as_pool, module)
+ return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling)
def __apply_jacobian_t_of(self, module, mat):
+ stride, kernel_size, padding = self.get_avg_pool_parameters(module)
C_for_conv_t = 1
convnd_t = self.convt(
in_channels=C_for_conv_t,
out_channels=C_for_conv_t,
- kernel_size=module.kernel_size,
- stride=module.stride,
- padding=module.padding,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
bias=False,
).to(module.input0.device)
convnd_t.weight.requires_grad = False
- avg_kernel = torch.ones_like(convnd_t.weight) / convnd_t.weight.numel()
+ avg_kernel = ones_like(convnd_t.weight) / convnd_t.weight.numel()
convnd_t.weight.data = avg_kernel
V_N_C_in = mat.size(0)
- if self.N == 1:
- _, _, L_in = module.input0.size()
- output_size = (V_N_C_in, C_for_conv_t, L_in)
- elif self.N == 2:
- _, _, H_in, W_in = module.input0.size()
- output_size = (V_N_C_in, C_for_conv_t, H_in, W_in)
- elif self.N == 3:
- _, _, D_in, H_in, W_in = module.input0.size()
- output_size = (V_N_C_in, C_for_conv_t, D_in, H_in, W_in)
+ output_size = (V_N_C_in, C_for_conv_t) + tuple(module.input0.shape[2:])
return convnd_t(mat, output_size=output_size)
-
- def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module):
- V = mat.size(0)
- if self.N == 1:
- N, C_in, L_in = module.input0.size()
- assert jmp_as_pool.shape == (V * N * C_in, 1, L_in)
- elif self.N == 2:
- N, C_in, H_in, W_in = module.input0.size()
- assert jmp_as_pool.shape == (V * N * C_in, 1, H_in, W_in)
- elif self.N == 3:
- N, C_in, D_in, H_in, W_in = module.input0.size()
- assert jmp_as_pool.shape == (V * N * C_in, 1, D_in, H_in, W_in)
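For reference, a single-channel sketch (hypothetical sizes) of the fact stated in the module docstring, namely that average pooling equals a convolution with a constant kernel of value ``1/kernel_size``; this is the same construction ``__apply_jacobian_of`` uses.

```python
import torch
from torch.nn import AvgPool1d, Conv1d

x = torch.randn(4, 1, 10)  # one channel, so no channel grouping is needed

pool = AvgPool1d(kernel_size=3, stride=2, padding=1)

conv = Conv1d(1, 1, kernel_size=3, stride=2, padding=1, bias=False)
conv.weight.requires_grad = False
conv.weight.data = torch.ones_like(conv.weight) / conv.weight.numel()

print(torch.allclose(pool(x), conv(x)))  # True (count_include_pad=True by default)
```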
diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py
index 1911d9adf..94c152884 100644
--- a/backpack/core/derivatives/basederivatives.py
+++ b/backpack/core/derivatives/basederivatives.py
@@ -1,10 +1,15 @@
"""Base classes for more flexible Jacobians and second-order information."""
import warnings
+from abc import ABC
+from typing import Callable, List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
from backpack.core.derivatives import shape_check
-class BaseDerivatives:
+class BaseDerivatives(ABC):
"""First- and second-order partial derivatives of unparameterized module.
Note:
@@ -38,7 +43,9 @@ class BaseDerivatives:
@shape_check.jac_mat_prod_accept_vectors
@shape_check.jac_mat_prod_check_shapes
- def jac_mat_prod(self, module, g_inp, g_out, mat):
+ def jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
"""Apply Jacobian of the output w.r.t. input to a matrix.
It is assumed that the module input has shape `[N, *]`, while the output is
@@ -49,14 +56,14 @@ def jac_mat_prod(self, module, g_inp, g_out, mat):
`result[v, n, •] = ∑ₖ ∑_* J[n, •, k, *] mat[v, n, *]`.
Args:
- module (torch.nn.Module): Extended module.
- g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
- g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
- mat (torch.Tensor): Matrix the Jacobian will be applied to. Must have
+ module: Extended module.
+ g_inp: Gradients of the module w.r.t. its inputs.
+ g_out: Gradients of the module w.r.t. its outputs.
+ mat: Matrix the Jacobian will be applied to. Must have
shape `[V, N, *]`.
Returns:
- torch.Tensor: Jacobian-matrix product. Has shape [V, N, *].
+ Jacobian-matrix product. Has shape [V, N, *].
Note:
- The Jacobian can be applied without knowledge about backpropagated
@@ -65,40 +72,59 @@ def jac_mat_prod(self, module, g_inp, g_out, mat):
"""
return self._jac_mat_prod(module, g_inp, g_out, mat)
- def _jac_mat_prod(self, module, g_inp, g_out, mat):
- """Internal implementation of the input-output Jacobian."""
+ def _jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
raise NotImplementedError
@shape_check.jac_t_mat_prod_accept_vectors
@shape_check.jac_t_mat_prod_check_shapes
- def jac_t_mat_prod(self, module, g_inp, g_out, mat):
+ def jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""Apply transposed input-ouput Jacobian of module output to a matrix.
Implicit application of Jᵀ:
result[v, ̃n, ̃c, ̃w, ...]
= ∑_{n, c, w} Jᵀ[̃n, ̃c, ̃w, ..., n, c, w, ...] mat[v, n, c, w, ...].
- Parameters:
- -----------
- mat: torch.Tensor
- Matrix the transposed Jacobian will be applied to.
- Must have shape [V, N, C_out, H_out, ...].
+ Args:
+ module: module which derivative is calculated
+ g_inp: input gradients
+ g_out: output gradients
+ mat: Matrix the transposed Jacobian will be applied to.
+ Must have shape ``[V, *module.output.shape]``; but if used with
+ sub-sampling, the batch dimension is replaced by ``len(subsampling)``.
+ subsampling: Indices of samples along the output's batch dimension that
+ should be considered. Defaults to ``None`` (use all samples).
Returns:
- --------
- result: torch.Tensor
Transposed Jacobian-matrix product.
- Has shape [V, N, C_in, H_in, ...].
+ Has shape ``[V, *module.input0.shape]``; but if used with sub-sampling,
+ the batch dimension is replaced by ``len(subsampling)``.
"""
- return self._jac_t_mat_prod(module, g_inp, g_out, mat)
-
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- """Internal implementation of transposed Jacobian."""
+ return self._jac_t_mat_prod(module, g_inp, g_out, mat, subsampling=subsampling)
+
+ def _jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
raise NotImplementedError
# TODO Add shape check
# TODO Use new convention
- def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
+ def ea_jac_t_mat_jac_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
"""Expectation approximation of outer product with input-output Jacobian.
Used for backpropagation in KFRA.
@@ -109,76 +135,178 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
= 1/n ∑ₙₖₗ (𝜕output[n,k] / 𝜕input[n,i]) mat[k,l] (𝜕output[n,j] / 𝜕input[n,l])
Args:
- module (torch.nn.Module): Extended module.
- g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
- g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
- mat (torch.Tensor): Matrix of shape `[D_out, D_out]`.
+ module: Extended module.
+ g_inp: Gradients of the module w.r.t. its inputs.
+ g_out: Gradients of the module w.r.t. its outputs.
+ mat: Matrix of shape `[D_out, D_out]`.
+ # noqa: DAR202
Returns:
- torch.Tensor: Matrix of shape `[D_in, D_in]`.
+ Matrix of shape `[D_in, D_in]`.
Note:
- This operation can be applied without knowledge about backpropagated
derivatives. Both `g_inp` and `g_out` are usually not required and
can be set to `None`.
+
+ Raises:
+ NotImplementedError: if not overwritten
"""
raise NotImplementedError
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module: Module) -> bool:
+ """Returns whether Hessian is zero.
+
+ I.e. whether ``∂²output[i] / ∂input[j] ∂input[k] = 0 ∀ i,j,k``.
+
+ Args:
+ module: current module to evaluate
+
+ # noqa: DAR202
+ Returns:
+ whether Hessian is zero
+
+ Raises:
+ NotImplementedError: if not overwritten
+ """
raise NotImplementedError
- def hessian_is_diagonal(self):
- """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`."""
+ def hessian_is_diagonal(self, module: Module) -> bool:
+ """Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`.
+
+ The Hessian diagonal is only defined for layers that preserve the size
+ of their input.
+
+ Must be implemented by descendants that don't implement ``hessian_is_zero``.
+
+ Args:
+ module: current module to evaluate
+
+ # noqa: DAR202
+ Returns:
+ whether Hessian is diagonal
+
+ Raises:
+ NotImplementedError: if not overwritten
+ """
raise NotImplementedError
- def hessian_diagonal(self):
- """Return `∂²output[i] / ∂input[i]²`.
+ # FIXME Currently returns `∂²output[i] / ∂input[i]² * g_out[0][i]`,
+ # which is the residual matrix diagonal, rather than the Hessian diagonal
+ def hessian_diagonal(
+ self, module: Module, g_in: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Tensor:
+ """Return the Hessian diagonal `∂²output[i] / ∂input[i]²`.
Only required if `hessian_is_diagonal` returns `True`.
+ The Hessian diagonal is only defined for layers that preserve the size
+ of their input.
+
+ Args:
+ module: Module whose output-input Hessian diagonal is computed.
+ g_in: Gradients w.r.t. the module input.
+ g_out: Gradients w.r.t. the module output.
+
+ # noqa: DAR202
+ Returns:
+ Hessian diagonal. Has same shape as module input.
+
+ Raises:
+ NotImplementedError: if not overwritten
"""
raise NotImplementedError
- def hessian_is_psd(self):
- """Is `∂²output[i] / ∂input[j] ∂input[k]` positive semidefinite (PSD)."""
+ def hessian_is_psd(self) -> bool:
+ """Is `∂²output[i] / ∂input[j] ∂input[k]` positive semidefinite (PSD).
+
+ # noqa: DAR202
+ Returns:
+ whether the Hessian is positive semidefinite
+
+ Raises:
+ NotImplementedError: if not overwritten
+ """
raise NotImplementedError
@shape_check.residual_mat_prod_accept_vectors
@shape_check.residual_mat_prod_check_shapes
- def residual_mat_prod(self, module, g_inp, g_out, mat):
+ def residual_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
"""Multiply with the residual term.
Performs mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.
+ Args:
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ mat: matrix to multiply
+
+ Returns:
+ product
+
Note:
- -----
This function only has to be implemented if the residual is not
zero and not diagonal (for instance, `BatchNorm`).
"""
return self._residual_mat_prod(module, g_inp, g_out, mat)
- def _residual_mat_prod(self, module, g_inp, g_out, mat):
+ def _residual_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
raise NotImplementedError
@staticmethod
- def _reshape_like(mat, like):
+ def _reshape_like(mat: Tensor, shape: Tuple[int]) -> Tensor:
"""Reshape as like with trailing and additional 0th dimension.
If like is [N, C, H, ...], returns shape [-1, N, C, H, ...]
+
+ Args:
+ mat: Matrix to reshape.
+ shape: Trailing target shape.
+
+ Returns:
+ reshaped matrix
"""
- V = -1
- shape = (V, *like.shape)
- return mat.reshape(shape)
+ return mat.reshape(-1, *shape)
@classmethod
- def reshape_like_input(cls, mat, module):
- return cls._reshape_like(mat, module.input0)
+ def reshape_like_input(
+ cls, mat: Tensor, module: Module, subsampling: List[int] = None
+ ) -> Tensor:
+ """Reshapes matrix according to input.
+
+ Args:
+ mat: matrix to reshape
+ module: module which input shape is used
+ subsampling: Indices of active samples. ``None`` means use all samples.
+
+ Returns:
+ reshaped matrix
+ """
+ shape = list(module.input0.shape)
+ if subsampling is not None:
+ shape[0] = len(subsampling)
+
+ return cls._reshape_like(mat, shape)
@classmethod
- def reshape_like_output(cls, mat, module):
- return cls._reshape_like(mat, module.output)
+ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor:
+ """Reshapes matrix like output.
+ Args:
+ mat: matrix to reshape
+ module: module which output is used
+
+ Returns:
+ reshaped matrix
+ """
+ return cls._reshape_like(mat, module.output.shape)
-class BaseParameterDerivatives(BaseDerivatives):
+
+class BaseParameterDerivatives(BaseDerivatives, ABC):
"""First- and second order partial derivatives of a module with parameters.
Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`):
@@ -193,169 +321,295 @@ class BaseParameterDerivatives(BaseDerivatives):
For most layers, these shapes correspond to shapes of the module input or output.
"""
- @shape_check.bias_jac_mat_prod_accept_vectors
- @shape_check.bias_jac_mat_prod_check_shapes
- def bias_jac_mat_prod(self, module, g_inp, g_out, mat):
- """Apply Jacobian of the output w.r.t. bias to a matrix.
+ @shape_check.param_mjp_accept_vectors
+ def param_mjp(
+ self,
+ param_str: str,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Compute matrix-Jacobian products (MJPs) of the module w.r.t. a parameter.
+
+ Handles both vector and matrix inputs. Preserves input format in output.
+
+ Internally calls out to the ``_{param_str}_jac_t_mat_prod`` method, which must be
+ implemented by descendants. It follows the same signature, but does not have
+ the ``param_str`` argument.
- Parameters:
- -----------
- mat: torch.Tensor
- Matrix the Jacobian will be applied to.
- Must have shape [V, C_b, ...].
+ Args:
+ param_str: Attribute name under which the parameter is stored in the module.
+ module: Module whose Jacobian will be applied. Must provide access to IO.
+ g_inp: Gradients w.r.t. module input.
+ g_out: Gradients w.r.t. module output.
+ mat: Matrix the Jacobian will be applied to. Has shape
+ ``[V, *module.output.shape]`` (matrix case) or same shape as
+ ``module.output`` (vector case). If used with subsampling, has dimension
+ len(subsampling) instead of batch size along the batch axis.
+ sum_batch: Sum out the MJP's batch axis. Default: ``True``.
+ subsampling: Indices of samples along the output's batch dimension that
+ should be considered. Defaults to ``None`` (use all samples).
Returns:
- --------
- result: torch.Tensor
- Jacobian-matrix product.
- Has shape [V, N, C_out, H_out, ...].
+ Matrix-Jacobian products. Has shape ``[V, *param_shape]`` when batch
+ summation is enabled (same shape as parameter in the vector case). Without
+ batch summation, the result has shape ``[V, N, *param_shape]`` (vector case
+ has shape ``[N, *param_shape]``). If used with subsampling, the batch size N
+ is replaced by len(subsampling).
+
+ Raises:
+ NotImplementedError: if required method is not implemented by derivatives class
"""
- return self._bias_jac_mat_prod(module, g_inp, g_out, mat)
+ # input check
+ shape_check.shape_like_output(mat, module, subsampling=subsampling)
+
+ method_name = f"_{param_str}_jac_t_mat_prod"
+ mjp = getattr(self, method_name, None)
+ if mjp is None:
+ raise NotImplementedError(
+ f"Computation requires implementation of {method_name}, but {self} "
+ f"(defining derivatives of {module}) does not implement it."
+ )
+ mjp_out = mjp(
+ module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
+ )
- def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
- """Internal implementation of the bias Jacobian."""
- raise NotImplementedError
+ # output check
+ shape_check.check_like_with_sum_batch(
+ mjp_out, module, param_str, sum_batch=sum_batch
+ )
+ shape_check.check_same_V_dim(mjp_out, mat)
- @shape_check.bias_jac_t_mat_prod_accept_vectors
- @shape_check.bias_jac_t_mat_prod_check_shapes
- def bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Apply transposed Jacobian of the output w.r.t. bias to a matrix.
+ return mjp_out
- Parameters:
- -----------
- mat: torch.Tensor
- Matrix the transposed Jacobian will be applied to.
- Must have shape [V, N, C_out, H_out, ...].
- sum_batch: bool
- Whether to sum over the batch dimension on the fly.
+ @shape_check.bias_jac_mat_prod_accept_vectors
+ @shape_check.bias_jac_mat_prod_check_shapes
+ def bias_jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ """Apply Jacobian of the output w.r.t. bias to a matrix.
+
+ Args:
+ module: module to perform derivatives on
+ g_inp: input gradients
+ g_out: output gradients
+ mat: Matrix the Jacobian will be applied to.
+ Must have shape [V, C_b, ...].
Returns:
- --------
- result: torch.Tensor
- Jacobian-matrix product.
- Has shape [V, N, C_b, ...] if `sum_batch == False`.
- Has shape [V, C_b, ...] if `sum_batch == True`.
+ Jacobian-matrix product. Has shape [V, N, C_out, H_out, ...].
"""
- return self._bias_jac_t_mat_prod(module, g_inp, g_out, mat, sum_batch=sum_batch)
+ return self._bias_jac_mat_prod(module, g_inp, g_out, mat)
- def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Internal implementation of the transposed bias Jacobian."""
+ def _bias_jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
raise NotImplementedError
@shape_check.weight_jac_mat_prod_accept_vectors
@shape_check.weight_jac_mat_prod_check_shapes
- def weight_jac_mat_prod(self, module, g_inp, g_out, mat):
+ def weight_jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
"""Apply Jacobian of the output w.r.t. weight to a matrix.
- Parameters:
- -----------
- mat: torch.Tensor
- Matrix the Jacobian will be applied to.
- Must have shape [V, C_w, H_w, ...].
+ Args:
+ module: module to perform derivatives on
+ g_inp: input gradients
+ g_out: output gradients
+ mat: Matrix the Jacobian will be applied to.
+ Must have shape [V, C_w, H_w, ...].
Returns:
- --------
- result: torch.Tensor
Jacobian-matrix product.
Has shape [V, N, C_out, H_out, ...].
"""
return self._weight_jac_mat_prod(module, g_inp, g_out, mat)
- def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
- """Internal implementation of weight Jacobian."""
+ def _weight_jac_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
raise NotImplementedError
- @shape_check.weight_jac_t_mat_prod_accept_vectors
- @shape_check.weight_jac_t_mat_prod_check_shapes
- def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Apply transposed Jacobian of the output w.r.t. weight to a matrix.
- Parameters:
- -----------
- mat: torch.Tensor
- Matrix the transposed Jacobian will be applied to.
- Must have shape [V, N, C_out, H_out, ...].
- sum_batch: bool
- Whether to sum over the batch dimension on the fly.
+class BaseLossDerivatives(BaseDerivatives, ABC):
+ """Second- order partial derivatives of loss functions."""
+
+ # TODO Add shape check
+ def sqrt_hessian(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Symmetric factorization ('sqrt') of the loss Hessian.
+
+ The Hessian factorization is returned in format ``Hs = [D, N, D]``, where
+ ``Hs[:, n, :]`` is the Hessian factorization for the ``n``th sample, i.e.
+ ``Hs[:, n, :]ᵀ Hs[:, n, :]`` is the Hessian w.r.t. the ``n``th sample.
+
+ Args:
+ module: Loss layer whose factorized Hessian will be computed.
+ g_inp: Gradients w.r.t. module input.
+ g_out: Gradients w.r.t. module output.
+ subsampling: Indices of data samples to be considered. Default of ``None``
+ uses all data in the mini-batch.
Returns:
- --------
- result: torch.Tensor
- Jacobian-matrix product.
- Has shape [V, N, C_w, H_w, ...] if `sum_batch == False`.
- Has shape [V, C_w, H_w, ...] if `sum_batch == True`.
+ Symmetric factorization of the loss Hessian for each sample. If the input
+ to the loss has shape ``[N, D]``, this is a tensor of shape ``[D, N, D]``;
+ if used with sub-sampling, ``N`` is replaced by ``len(subsampling)``.
+ For fixed ``n``, squaring the matrix implied by the slice ``[:, n, :]``
+ results in the loss Hessian w.r.t. sample ``n``.
"""
- return self._weight_jac_t_mat_prod(
- module, g_inp, g_out, mat, sum_batch=sum_batch
- )
-
- def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Internal implementation of transposed weight Jacobian."""
+ self._check_2nd_order_make_sense(module, g_out)
+ return self._sqrt_hessian(module, g_inp, g_out, subsampling=subsampling)
+
+ def _sqrt_hessian(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
raise NotImplementedError
-
-class BaseLossDerivatives(BaseDerivatives):
- """Second- order partial derivatives of loss functions."""
-
# TODO Add shape check
- def sqrt_hessian(self, module, g_inp, g_out):
- """Symmetric factorization ('sqrt') of the loss Hessian."""
- self.check_2nd_order_make_sense(module, g_inp, g_out)
- return self._sqrt_hessian(module, g_inp, g_out)
+ def sqrt_hessian_sampled(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mc_samples: int = 1,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """A Monte-Carlo sampled symmetric factorization of the loss Hessian.
+
+ The Hessian factorization is returned in format ``Hs = [M, N, D]``, where
+ ``Hs[:, n, :]`` approximates the Hessian factorization for the ``n``th sample,
+ i.e. ``Hs[:, n, :]ᵀ Hs[:, n, :]`` approximates the Hessian w.r.t. sample
+ ``n``.
- def _sqrt_hessian(self, module, g_inp, g_out):
- raise NotImplementedError
+ Args:
+ module: Loss layer whose factorized Hessian will be computed.
+ g_inp: Gradients w.r.t. module input.
+ g_out: Gradients w.r.t. module output.
+ mc_samples: Number of samples used for MC approximation.
+ subsampling: Indices of data samples to be considered. Default of ``None``
+ uses all data in the mini-batch.
- # TODO Add shape check
- def sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
- """Monte-Carlo sampled symmetric factorization of the loss Hessian."""
- self.check_2nd_order_make_sense(module, g_inp, g_out)
- return self._sqrt_hessian_sampled(module, g_inp, g_out, mc_samples=mc_samples)
+ Returns:
+ Symmetric factorization of the loss Hessian for each sample. If the input
+ to the loss has shape ``[N, D]``, this is a tensor of shape ``[M, N, D]``
+ when using ``M`` MC samples; if used with sub-sampling, ``N`` is replaced
+ by ``len(subsampling)``. For fixed ``n``, squaring the matrix implied by the
+ slice ``[:, n, :]`` approximates the loss Hessian w.r.t. sample ``n``.
+ """
+ self._check_2nd_order_make_sense(module, g_out)
+ return self._sqrt_hessian_sampled(
+ module, g_inp, g_out, mc_samples=mc_samples, subsampling=subsampling
+ )
- def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
+ def _sqrt_hessian_sampled(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mc_samples: int = 1,
+ subsampling=None,
+ ) -> Tensor:
raise NotImplementedError
@shape_check.make_hessian_mat_prod_accept_vectors
@shape_check.make_hessian_mat_prod_check_shapes
- def make_hessian_mat_prod(self, module, g_inp, g_out):
+ def make_hessian_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Callable[[Tensor], Tensor]:
"""Multiplication of the input Hessian with a matrix.
Return a function that maps mat to H * mat.
+
+ Args:
+ module: module to perform derivatives on
+ g_inp: input gradients
+ g_out: output gradients
+
+ Returns:
+ function that maps mat to H * mat
"""
- self.check_2nd_order_make_sense(module, g_inp, g_out)
+ self._check_2nd_order_make_sense(module, g_out)
return self._make_hessian_mat_prod(module, g_inp, g_out)
- def _make_hessian_mat_prod(self, module, g_inp, g_out):
+ def _make_hessian_mat_prod(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Callable[[Tensor], Tensor]:
raise NotImplementedError
# TODO Add shape check
- def sum_hessian(self, module, g_inp, g_out):
- """Loss Hessians, summed over the batch dimension."""
- self.check_2nd_order_make_sense(module, g_inp, g_out)
+ def sum_hessian(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Tensor:
+ """Loss Hessians, summed over the batch dimension.
+
+ Args:
+ module: module to perform derivatives on
+ g_inp: input gradients
+ g_out: output gradients
+
+ Returns:
+ sum of hessians
+ """
+ self._check_2nd_order_make_sense(module, g_out)
return self._sum_hessian(module, g_inp, g_out)
- def _sum_hessian(self, module, g_inp, g_out):
+ def _sum_hessian(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Tensor:
raise NotImplementedError
- def check_2nd_order_make_sense(self, module, g_inp, g_out):
+ def _check_2nd_order_make_sense(self, module: Module, g_out: Tuple[Tensor]) -> None:
"""Verify conditions for 2nd-order extensions to be working.
2nd-order extensions are only guaranteed to work if the `loss`,
on which `backward()` is called, is a scalar that has not been
modified further after passing through the loss function module.
+
+ Args:
+ module: module to perform derivatives on
+ g_out: output gradients
"""
self._check_output_is_scalar(module)
self._check_loss_has_not_been_modified(module, g_out)
- def _check_output_is_scalar(self, module):
- """Raise an exception is the module output is not a scalar."""
+ @classmethod
+ def _check_output_is_scalar(cls, module: Module) -> None:
+ """Raise an exception is the module output is not a scalar.
+
+ Args:
+ module: module to perform derivatives on
+
+ Raises:
+ ValueError: if output is not scalar
+ """
if module.output.numel() != 1:
raise ValueError(
"Output must be scalar. Got {}".format(module.output.shape)
)
- def _check_loss_has_not_been_modified(self, module, g_out):
- """Raise a warning if the module output seems to have been changed."""
+ @classmethod
+ def _check_loss_has_not_been_modified(
+ cls, module: Module, g_out: Tuple[Tensor]
+ ) -> None:
+ """Raise a warning if the module output seems to have been changed.
+
+ Args:
+ module: module to perform derivatives on
+ g_out: output gradients
+ """
grad_out_is_identity = g_out is None or (g_out[0] == 1.0).all().item()
if not grad_out_is_identity:
warnings.warn(
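For intuition about the ``jac_t_mat_prod`` convention (shapes ``[V, *module.output.shape]`` → ``[V, *module.input0.shape]``), a plain-autograd sketch with a made-up module and shapes; it computes the same transposed-Jacobian products as vector-Jacobian products instead of going through BackPACK's derivative classes.

```python
import torch
from torch.nn import Sigmoid

module = Sigmoid()
x = torch.randn(3, 5, requires_grad=True)  # input of shape [N, D]
out = module(x)

V = 4
mat = torch.randn(V, *out.shape)  # backpropagated quantity of shape [V, N, D]

# result[v] = Jᵀ mat[v], i.e. one vector-Jacobian product per slice
result = torch.stack(
    [
        torch.autograd.grad(out, x, grad_outputs=mat[v], retain_graph=True)[0]
        for v in range(V)
    ]
)
print(result.shape)  # torch.Size([4, 3, 5]) = [V, *module.input0.shape]
```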
diff --git a/backpack/core/derivatives/batchnorm1d.py b/backpack/core/derivatives/batchnorm1d.py
deleted file mode 100644
index cbdf8cc82..000000000
--- a/backpack/core/derivatives/batchnorm1d.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from warnings import warn
-
-from torch import einsum
-
-from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
-
-
-class BatchNorm1dDerivatives(BaseParameterDerivatives):
- def hessian_is_zero(self):
- return False
-
- def hessian_is_diagonal(self):
- return False
-
- def _jac_mat_prod(self, module, g_inp, g_out, mat):
- return self._jac_t_mat_prod(module, g_inp, g_out, mat)
-
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- """
- Note:
- -----
- The Jacobian is *not independent* among the batch dimension, i.e.
- D z_i = D z_i(x_1, ..., x_B).
-
- This structure breaks the computation of the GGN diagonal,
- for curvature-matrix products it should still work.
-
- References:
- -----------
- https://kevinzakka.github.io/2016/09/14/batch_normalization/
- https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
- """
- assert module.affine is True
-
- N = module.input0.size(0)
- x_hat, var = self.get_normalized_input_and_var(module)
- ivar = 1.0 / (var + module.eps).sqrt()
-
- dx_hat = einsum("vni,i->vni", (mat, module.weight))
-
- jac_t_mat = N * dx_hat
- jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat)
- jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat))
- jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N))
-
- return jac_t_mat
-
- def get_normalized_input_and_var(self, module):
- input = module.input0
- mean = input.mean(dim=0)
- var = input.var(dim=0, unbiased=False)
- return (input - mean) / (var + module.eps).sqrt(), var
-
- def _residual_mat_prod(self, module, g_inp, g_out, mat):
- """Multiply with BatchNorm1d residual-matrix.
-
- Paul Fischer (GitHub: @paulkogni) contributed this code during a research
- project in winter 2019.
-
- Details are described in
-
- - `TODO: Add tech report title`
- _
- by Paul Fischer, 2020.
- """
- N = module.input0.size(0)
- x_hat, var = self.get_normalized_input_and_var(module)
- gamma = module.weight
- eps = module.eps
-
- factor = gamma / (N * (var + eps))
-
- sum_127 = einsum("nc,vnc->vc", (x_hat, mat))
- sum_24 = einsum("nc->c", g_out[0])
- sum_3 = einsum("nc,vnc->vc", (g_out[0], mat))
- sum_46 = einsum("vnc->vc", mat)
- sum_567 = einsum("nc,nc->c", (x_hat, g_out[0]))
-
- r_mat = -einsum("nc,vc->vnc", (g_out[0], sum_127))
- r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_24, sum_127)).unsqueeze(1).expand(
- -1, N, -1
- )
- r_mat -= einsum("nc,vc->vnc", (x_hat, sum_3))
- r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", (x_hat, sum_24, sum_46))
-
- r_mat -= einsum("vnc,c->vnc", (mat, sum_567))
- r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_567, sum_46)).unsqueeze(1).expand(
- -1, N, -1
- )
- r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", (x_hat, sum_127, sum_567))
-
- return einsum("c,vnc->vnc", (factor, r_mat))
-
- def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
- x_hat, _ = self.get_normalized_input_and_var(module)
- return einsum("ni,vi->vni", (x_hat, mat))
-
- def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
- if not sum_batch:
- warn(
- "BatchNorm batch summation disabled."
- "This may not compute meaningful quantities"
- )
- x_hat, _ = self.get_normalized_input_and_var(module)
- equation = "vni,ni->v{}i".format("" if sum_batch is True else "n")
- operands = [mat, x_hat]
- return einsum(equation, operands)
-
- def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
- N = module.input0.size(0)
- return mat.unsqueeze(1).repeat(1, N, 1)
-
- def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- if not sum_batch:
- warn(
- "BatchNorm batch summation disabled."
- "This may not compute meaningful quantities"
- )
- return mat
- else:
- N_axis = 1
- return mat.sum(N_axis)
diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py
new file mode 100644
index 000000000..7fe15255a
--- /dev/null
+++ b/backpack/core/derivatives/batchnorm_nd.py
@@ -0,0 +1,312 @@
+"""Contains derivatives for BatchNorm."""
+from typing import List, Tuple, Union
+
+from torch import Size, Tensor, einsum
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
+
+
+class BatchNormNdDerivatives(BaseParameterDerivatives):
+ """Derivatives for BatchNorm1d, 2d and 3d.
+
+ If training=False: saved statistics are used.
+ If training=True: statistics of current batch are used.
+
+ Index convention:
+ n: batch axis
+ c: channel axis
+ {empty}/l/hw/dhw: dimension axis for 0/1/2/3-dimensions (alternatively using xyz)
+ ...: usually for the remaining dimension axis (same as dhw)
+
+ Links to PyTorch docs:
+ https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
+ https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
+ https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html
+
+ As a starting point for derivative computation, see these references:
+ https://kevinzakka.github.io/2016/09/14/batch_normalization/
+ https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
+ """
+
+ def _check_parameters(
+ self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+ ) -> None:
+ if module.affine is False:
+ raise NotImplementedError("Only implemented for affine=True")
+ if module.track_running_stats is False:
+ raise NotImplementedError("Only implemented for track_running_stats=True")
+
+ def hessian_is_zero(
+ self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+ ) -> bool:
+ """Whether hessian is zero.
+
+ Args:
+ module: current module to evaluate
+
+ Returns:
+ whether hessian is zero
+ """
+ return not module.training
+
+ def hessian_is_diagonal(
+ self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+ ) -> bool:
+ """Whether hessian is diagonal.
+
+ Args:
+ module: current module to evaluate
+
+ Returns:
+ whether hessian is diagonal
+
+ Raises:
+ NotImplementedError: if module is in evaluation mode
+ """
+ if module.training:
+ return False
+ else:
+ raise NotImplementedError(
+ "hessian_is_diagonal is not tested for BatchNorm. "
+ "Create an issue if you need it."
+ )
+
+ def _jac_mat_prod(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ ) -> Tensor:
+ return self._jac_t_mat_prod(module, g_inp, g_out, mat)
+
+ def _jac_t_mat_prod(
+ self,
+ module: BatchNorm1d,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+ N: int = self._get_n_axis(module)
+ if module.training:
+
+ if subsampling is not None:
+ raise NotImplementedError(
+ "BatchNorm VJP sub-sampling is not defined in train mode."
+ )
+
+ denominator: int = self._get_denominator(module)
+ x_hat, var = self._get_normalized_input_and_var(module)
+ ivar = 1.0 / (var + module.eps).sqrt()
+
+ dx_hat: Tensor = einsum("vnc...,c->vnc...", mat, module.weight)
+ jac_t_mat = denominator * dx_hat
+ jac_t_mat -= dx_hat.sum(
+ self._get_free_axes(module),
+ keepdim=True,
+ ).expand_as(jac_t_mat)
+ spatial_dims = "xyz"[:N]
+ jac_t_mat -= einsum(
+ f"nc...,vmc{spatial_dims},mc{spatial_dims}->vnc...",
+ x_hat,
+ dx_hat,
+ x_hat,
+ )
+ jac_t_mat = einsum("vnc...,c->vnc...", jac_t_mat, ivar / denominator)
+ return jac_t_mat
+ else:
+ return einsum(
+ "c,vnc...->vnc...",
+ ((module.running_var + module.eps) ** (-0.5)) * module.weight,
+ mat,
+ )
+
+ def _weight_jac_mat_prod(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ ) -> Tensor:
+ x_hat, _ = self._get_normalized_input_and_var(module)
+ return einsum("nc...,vc->vnc...", x_hat, mat)
+
+ def _weight_jac_t_mat_prod(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ x_hat, _ = self._get_normalized_input_and_var(module)
+ x_hat = subsample(x_hat, subsampling=subsampling)
+
+ equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c"
+ return einsum(equation, mat, x_hat)
+
+ def _bias_jac_mat_prod(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ ) -> Tensor:
+ out = self._unsqueeze_free_axis(module, mat, 1)
+ dim_expand: List[int] = [-1, module.input0.shape[0], -1]
+ for n in range(self._get_n_axis(module)):
+ dim_expand.append(module.input0.shape[2 + n])
+ return out.expand(*dim_expand)
+
+ def _bias_jac_t_mat_prod(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ axis_sum: Tuple[int] = self._get_free_axes(module, with_batch_axis=sum_batch)
+ return mat.sum(dim=axis_sum) if axis_sum else mat
+
+ def _residual_mat_prod(
+ self,
+ module: BatchNorm1d,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ ) -> Tensor:
+ """Multiply with BatchNorm1d residual-matrix.
+
+ Paul Fischer (GitHub: @paulkogni) contributed this code during a research
+ project in winter 2019.
+
+ Details are described in
+
+ `HESSIAN BACKPROPAGATION FOR BATCHNORM`
+
+ by Paul Fischer, 2020.
+
+ Args:
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ mat: matrix to multiply
+
+ Returns:
+ product
+
+ Raises:
+ NotImplementedError: if used with an unsupported mode or input
+ """ # noqa: B950
+ self._check_parameters(module)
+ if module.training is False:
+ raise NotImplementedError("residual_mat_prod works only for training mode.")
+ if module.input0.dim() != 2:
+ raise NotImplementedError(
+ "residual_mat_prod is implemented only for 0 dimensions. "
+ "If you need more dimension make a feature request."
+ )
+
+ N = module.input0.size(0)
+ x_hat, var = self._get_normalized_input_and_var(module)
+ gamma = module.weight
+ eps = module.eps
+
+ factor = gamma / (N * (var + eps))
+
+ sum_127 = einsum("nc,vnc->vc", x_hat, mat)
+ sum_24 = einsum("nc->c", g_out[0])
+ sum_3 = einsum("nc,vnc->vc", g_out[0], mat)
+ sum_46 = einsum("vnc->vc", mat)
+ sum_567 = einsum("nc,nc->c", x_hat, g_out[0])
+
+ r_mat = -einsum("nc,vc->vnc", g_out[0], sum_127)
+ r_mat += (1.0 / N) * einsum("c,vc->vc", sum_24, sum_127).unsqueeze(1).expand(
+ -1, N, -1
+ )
+ r_mat -= einsum("nc,vc->vnc", x_hat, sum_3)
+ r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", x_hat, sum_24, sum_46)
+
+ r_mat -= einsum("vnc,c->vnc", mat, sum_567)
+ r_mat += (1.0 / N) * einsum("c,vc->vc", sum_567, sum_46).unsqueeze(1).expand(
+ -1, N, -1
+ )
+ r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", x_hat, sum_127, sum_567)
+
+ return einsum("c,vnc->vnc", factor, r_mat)
+
+ ###############################################################
+ # HELPER FUNCTIONS ###
+ ###############################################################
+ def _get_normalized_input_and_var(
+ self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+ ) -> Tuple[Tensor, Tensor]:
+ input: Tensor = module.input0
+ if module.training:
+ dim: Tuple[int] = self._get_free_axes(module, index_batch=0)
+ mean: Tensor = input.mean(dim=dim)
+ var: Tensor = input.var(dim=dim, unbiased=False)
+ else:
+ mean: Tensor = module.running_mean
+ var: Tensor = module.running_var
+ mean: Tensor = self._unsqueeze_free_axis(module, mean, 0)
+ var_expanded: Tensor = self._unsqueeze_free_axis(module, var, 0)
+ return (input - mean) / (var_expanded + module.eps).sqrt(), var
+
+ def _get_denominator(
+ self, module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+ ) -> int:
+ shape_input: Size = module.input0.shape
+ free_axes: Tuple[int] = self._get_free_axes(module, index_batch=0)
+ denominator: int = 1
+ for index in free_axes:
+ denominator *= shape_input[index]
+ return denominator
+
+ @staticmethod
+ def _get_n_axis(module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]) -> int:
+ return module.input0.dim() - 2
+
+ def _unsqueeze_free_axis(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ tensor: Tensor,
+ index_batch: int,
+ ) -> Tensor:
+ """Unsqueezes the free dimensions.
+
+ This function is useful to avoid broadcasting.
+ Also useful when applying .expand(self._get_free_axes()) afterwards.
+
+ Args:
+ module: extended module
+ tensor: the tensor to operate on
+ index_batch: the batch axes index
+
+ Returns:
+ tensor with the free dimensions unsqueezed.
+ """
+ out = tensor.unsqueeze(index_batch)
+ for _ in range(self._get_n_axis(module)):
+ out = out.unsqueeze(-1)
+ return out
+
+ def _get_free_axes(
+ self,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ with_batch_axis: bool = True,
+ index_batch: int = 1,
+ ) -> Tuple[int]:
+ free_axes: List[int] = []
+ if with_batch_axis:
+ free_axes.append(index_batch)
+ for n in range(self._get_n_axis(module)):
+ free_axes.append(index_batch + n + 2)
+ return tuple(free_axes)
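
The eval-mode branch of ``_jac_t_mat_prod`` above reduces to a per-channel scaling by ``weight / sqrt(running_var + eps)``. A minimal sketch cross-checking this against autograd (module setup and tensor names below are illustrative, not part of the patch):

import torch
from torch.nn import BatchNorm1d

torch.manual_seed(0)
N, C = 4, 3
bn = BatchNorm1d(C).eval()  # eval mode: saved running statistics are used
with torch.no_grad():
    bn.weight.copy_(torch.rand(C) + 0.5)
    bn.running_mean.copy_(torch.rand(C))
    bn.running_var.copy_(torch.rand(C) + 0.5)

x = torch.rand(N, C, requires_grad=True)
v = torch.rand(N, C)  # vector for the vector-Jacobian product

out = bn(x)
(vjp,) = torch.autograd.grad(out, x, grad_outputs=v)

# in eval mode the Jacobian is diagonal with per-channel entries
scale = bn.weight / (bn.running_var + bn.eps).sqrt()
assert torch.allclose(vjp, v * scale, atol=1e-6)
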
diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py
index 9cf9f223d..f0046337f 100644
--- a/backpack/core/derivatives/conv_transposend.py
+++ b/backpack/core/derivatives/conv_transposend.py
@@ -1,58 +1,42 @@
"""Partial derivatives for ``torch.nn.ConvTranspose{1,2,3}d``."""
+from typing import List, Tuple, Union
+
from einops import rearrange
from numpy import prod
-from torch import einsum
-from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
-from torch.nn.functional import (
- conv1d,
- conv2d,
- conv3d,
- conv_transpose1d,
- conv_transpose2d,
- conv_transpose3d,
-)
+from torch import Tensor, einsum
+from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.grad import _grad_input_padding
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
-from backpack.utils.conv_transpose import unfold_by_conv_transpose
+from backpack.utils.conv import get_conv_function
+from backpack.utils.conv_transpose import (
+ get_conv_transpose_function,
+ unfold_by_conv_transpose,
+)
+from backpack.utils.subsampling import subsample
class ConvTransposeNDDerivatives(BaseParameterDerivatives):
"""Base class for partial derivatives of transpose convolution."""
- def __init__(self, N):
- """Store convolution dimension and operations.
+ def __init__(self, N: int):
+ """Store transpose convolution dimension and operations.
Args:
- N (int): Convolution dimension. Must be ``1``, ``2``, or ``3``.
-
- Raises:
- ValueError: If convolution dimension is unsupported.
+ N: Transpose convolution dimension.
"""
- if N == 1:
- self.module = ConvTranspose1d
- self.conv_func = conv1d
- self.conv_transpose_func = conv_transpose1d
- elif N == 2:
- self.module = ConvTranspose2d
- self.conv_func = conv2d
- self.conv_transpose_func = conv_transpose2d
- elif N == 3:
- self.module = ConvTranspose3d
- self.conv_func = conv3d
- self.conv_transpose_func = conv_transpose3d
- else:
- raise ValueError(f"ConvTranspose{N}d not supported.")
+ self.conv_func = get_conv_function(N)
+ self.conv_transpose_func = get_conv_transpose_function(N)
self.conv_dims = N
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return True
- def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- axes = list(range(3, len(module.output.shape) + 1))
- if sum_batch:
- axes = [1] + axes
- return mat.sum(axes)
+ def _bias_jac_t_mat_prod(
+ self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None
+ ):
+ equation = f"vnc...->v{'' if sum_batch else 'n'}c"
+ return einsum(equation, mat)
def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
# Expand batch dimension
@@ -84,18 +68,26 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
return self.reshape_like_output(jac_mat, module)
- def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
+ def _weight_jac_t_mat_prod(
+ self,
+ module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
V = mat.shape[0]
G = module.groups
C_in = module.input0.shape[1]
- N = module.output.shape[0]
+ N = module.output.shape[0] if subsampling is None else len(subsampling)
C_out = module.output.shape[1]
mat_reshape = mat.reshape(V, N, G, C_out // G, *module.output.shape[2:])
- u = unfold_by_conv_transpose(module.input0, module).reshape(
- N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:]
- )
+ u = unfold_by_conv_transpose(
+ subsample(module.input0, subsampling=subsampling), module
+ ).reshape(N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:])
dims_kern = "xyz"[: self.conv_dims]
dims_data = "abc"[: self.conv_dims]
@@ -139,7 +131,7 @@ def __jac(self, module, mat):
dilation=module.dilation,
)
- jac_t_mat = conv_transpose1d(
+ jac_t_mat = self.conv_transpose_func(
input=mat,
weight=module.weight,
bias=None,
@@ -149,14 +141,22 @@ def __jac(self, module, mat):
groups=module.groups,
dilation=module.dilation,
)
+
return jac_t_mat
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
+ def _jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
jmp_as_conv = self.__jac_t(module, mat_as_conv)
- return self.reshape_like_input(jmp_as_conv, module)
+ return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)
- def __jac_t(self, module, mat):
+ def __jac_t(self, module: Module, mat: Tensor) -> Tensor:
jac_t = self.conv_func(
mat,
module.weight,
diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py
index e0c5bf8cf..167a5ad75 100644
--- a/backpack/core/derivatives/convnd.py
+++ b/backpack/core/derivatives/convnd.py
@@ -1,21 +1,16 @@
-import warnings
+from typing import List, Tuple, Union
+from warnings import warn
from einops import rearrange, reduce
from numpy import prod
-from torch import einsum
-from torch.nn import Conv1d, Conv2d, Conv3d
-from torch.nn.functional import (
- conv1d,
- conv2d,
- conv3d,
- conv_transpose1d,
- conv_transpose2d,
- conv_transpose3d,
-)
+from torch import Tensor, einsum
+from torch.nn import Conv1d, Conv2d, Conv3d, Module
from torch.nn.grad import _grad_input_padding
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
-from backpack.utils import conv as convUtils
+from backpack.utils.conv import get_conv_function, unfold_by_conv
+from backpack.utils.conv_transpose import get_conv_transpose_function
+from backpack.utils.subsampling import subsample
class weight_jac_t_save_memory:
@@ -38,27 +33,15 @@ def __exit__(self, type, value, traceback):
class ConvNDDerivatives(BaseParameterDerivatives):
def __init__(self, N):
- if N == 1:
- self.module = Conv1d
- self.conv_func = conv1d
- self.conv_transpose_func = conv_transpose1d
- elif N == 2:
- self.module = Conv2d
- self.conv_func = conv2d
- self.conv_transpose_func = conv_transpose2d
- elif N == 3:
- self.module = Conv3d
- self.conv_func = conv3d
- self.conv_transpose_func = conv_transpose3d
- else:
- raise ValueError("{}-dimensional Conv. is not implemented.".format(N))
+ self.conv_func = get_conv_function(N)
+ self.conv_transpose_func = get_conv_transpose_function(N)
self.conv_dims = N
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return True
def get_unfolded_input(self, module):
- return convUtils.unfold_by_conv(module.input0, module)
+ return unfold_by_conv(module.input0, module)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
@@ -72,10 +55,17 @@ def _jac_mat_prod(self, module, g_inp, g_out, mat):
)
return self.reshape_like_output(jmp_as_conv, module)
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
+ def _jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
jmp_as_conv = self.__jac_t(module, mat_as_conv)
- return self.reshape_like_input(jmp_as_conv, module)
+ return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)
def __jac_t(self, module, mat):
input_size = list(module.input0.size())
@@ -114,11 +104,11 @@ def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
return jac_mat.expand(*expand_shape)
- def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- axes = list(range(3, len(module.output.shape) + 1))
- if sum_batch:
- axes = [1] + axes
- return mat.sum(axes)
+ def _bias_jac_t_mat_prod(
+ self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None
+ ):
+ equation = f"vnc...->v{'' if sum_batch else 'n'}c"
+ return einsum(equation, mat)
def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
# separate output channel groups
@@ -131,29 +121,41 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
return self.reshape_like_output(jac_mat, module)
- def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
+ def _weight_jac_t_mat_prod(
+ self,
+ module: Union[Conv1d, Conv2d, Conv3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
save_memory = weight_jac_t_save_memory._SAVE_MEMORY
if save_memory and self.conv_dims in [1, 2]:
- return self.__higher_conv_weight_jac_t(module, mat, sum_batch)
-
+ weight_jac_t_func = self.__higher_conv_weight_jac_t
else:
-
if save_memory and self.conv_dims == 3:
- warnings.warn(
- UserWarning(
- "Conv3d: Cannot save memory as there is no Conv4d."
- + " Fallback to more memory-intense method."
- )
+ warn(
+ "Conv3d: Cannot save memory as there is no Conv4d."
+ + " Fallback to more memory-intense method."
)
+ weight_jac_t_func = self.__same_conv_weight_jac_t
- return self.__same_conv_weight_jac_t(module, mat, sum_batch)
+ return weight_jac_t_func(module, mat, sum_batch, subsampling=subsampling)
- def __same_conv_weight_jac_t(self, module, mat, sum_batch):
+ def __same_conv_weight_jac_t(
+ self,
+ module: Union[Conv1d, Conv2d, Conv3d],
+ mat: Tensor,
+ sum_batch: bool,
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""Uses convolution of same order."""
G = module.groups
V = mat.shape[0]
- N, C_out = module.output.shape[0], module.output.shape[1]
+ C_out = module.output.shape[1]
+ N = module.output.shape[0] if subsampling is None else len(subsampling)
C_in = module.input0.shape[1]
C_in_axis = 1
N_axis = 0
@@ -165,7 +167,9 @@ def __same_conv_weight_jac_t(self, module, mat, sum_batch):
mat = rearrange(mat, "a b ... -> (a b) ...")
mat = mat.unsqueeze(C_in_axis)
- input = rearrange(module.input0, "n c ... -> (n c) ...")
+ input = rearrange(
+ subsample(module.input0, subsampling=subsampling), "n c ... -> (n c) ..."
+ )
input = input.unsqueeze(N_axis)
repeat_pattern = [1, V] + [1 for _ in range(self.conv_dims)]
input = input.repeat(*repeat_pattern)
@@ -191,7 +195,13 @@ def __same_conv_weight_jac_t(self, module, mat, sum_batch):
else:
return rearrange(grad_weight, "(v n g i o) ... -> v n (g o) i ...", **dim)
- def __higher_conv_weight_jac_t(self, module, mat, sum_batch):
+ def __higher_conv_weight_jac_t(
+ self,
+ module: Union[Conv1d, Conv2d, Conv3d],
+ mat: Tensor,
+ sum_batch: bool,
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""Requires higher-order convolution.
The algorithm is proposed in:
@@ -201,30 +211,22 @@ def __higher_conv_weight_jac_t(self, module, mat, sum_batch):
"""
G = module.groups
V = mat.shape[0]
- N, C_out = module.output.shape[0], module.output.shape[1]
+ C_out = module.output.shape[1]
+ N = module.output.shape[0] if subsampling is None else len(subsampling)
C_in = module.input0.shape[1]
- if self.conv_dims == 1:
- _, _, L_in = module.input0.size()
- higher_conv_func = conv2d
- K_L_axis = 2
- K_L = module.kernel_size[0]
- spatial_dim = (C_in // G, L_in)
- spatial_dim_axis = (1, V, 1, 1)
- spatial_dim_new = (C_in // G, K_L)
- else:
- _, _, H_in, W_in = module.input0.size()
- higher_conv_func = conv3d
- K_H_axis, K_W_axis = 2, 3
- K_H, K_W = module.kernel_size
- spatial_dim = (C_in // G, H_in, W_in)
- spatial_dim_axis = (1, V, 1, 1, 1)
- spatial_dim_new = (C_in // G, K_H, K_W)
+ higher_conv_func = get_conv_function(self.conv_dims + 1)
+
+ spatial_dim = (C_in // G,) + module.input0.shape[2:]
+ spatial_dim_axis = (1, V) + tuple([1] * (self.conv_dims + 1))
+ spatial_dim_new = (C_in // G,) + module.weight.shape[2:]
# Reshape to extract groups from the convolutional layer
# Channels are seen as an extra spatial dimension with kernel size 1
- input_conv = module.input0.reshape(1, N * G, *spatial_dim).repeat(
- *spatial_dim_axis
+ input_conv = (
+ subsample(module.input0, subsampling=subsampling)
+ .reshape(1, N * G, *spatial_dim)
+ .repeat(*spatial_dim_axis)
)
# Compute convolution between input and output; the batchsize is seen
# as channels, taking advantage of the `groups` argument
@@ -245,10 +247,8 @@ def __higher_conv_weight_jac_t(self, module, mat, sum_batch):
# Because of rounding shapes when using non-default stride or dilation,
# convolution result must be truncated to convolution kernel size
- if self.conv_dims == 1:
- conv = conv.narrow(K_L_axis, 0, K_L)
- else:
- conv = conv.narrow(K_H_axis, 0, K_H).narrow(K_W_axis, 0, K_W)
+ for axis in range(2, 2 + self.conv_dims):
+ conv = conv.narrow(axis, 0, module.weight.shape[axis])
new_shape = [V, N, C_out, *spatial_dim_new]
weight_grad = conv.reshape(*new_shape)
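
The new bias equation ``vnc...->v{''|'n'}c`` simply sums the backpropagated matrix over the spatial axes (and the batch axis when ``sum_batch=True``). A minimal sketch for ``Conv2d``, with made-up shapes, checking the summed case against autograd:

import torch
from torch import einsum
from torch.nn import Conv2d

torch.manual_seed(0)
conv = Conv2d(in_channels=3, out_channels=4, kernel_size=3)
x = torch.rand(2, 3, 8, 8)

out = conv(x)
v = torch.rand_like(out)  # vector shaped like the layer output

# autograd's bias VJP ...
(bias_vjp,) = torch.autograd.grad((out * v).sum(), conv.bias)

# ... equals the sum over batch and spatial axes
# (fixed-shape instance of 'vnc...->vc' with the free axis dropped)
assert torch.allclose(bias_vjp, einsum("nchw->c", v), atol=1e-6)
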
diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py
index 8ead9f404..68690df9a 100644
--- a/backpack/core/derivatives/crossentropyloss.py
+++ b/backpack/core/derivatives/crossentropyloss.py
@@ -1,11 +1,14 @@
"""Partial derivatives for cross-entropy loss."""
from math import sqrt
+from typing import Callable, Dict, List, Tuple
-from torch import diag, diag_embed, einsum, multinomial, ones_like, softmax
-from torch import sqrt as torchsqrt
+from einops import rearrange
+from torch import Tensor, diag, diag_embed, einsum, eye, multinomial, ones_like, softmax
+from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from backpack.core.derivatives.basederivatives import BaseLossDerivatives
+from backpack.utils.subsampling import subsample
class CrossEntropyLossDerivatives(BaseLossDerivatives):
@@ -15,29 +18,47 @@ class CrossEntropyLossDerivatives(BaseLossDerivatives):
and negative log-likelihood.
"""
- def _sqrt_hessian(self, module, g_inp, g_out):
+ def _sqrt_hessian(
+ self,
+ module: CrossEntropyLoss,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
self._check_2nd_order_parameters(module)
- probs = self._get_probs(module)
- tau = torchsqrt(probs)
+ probs = self._get_probs(module, subsampling=subsampling)
+ probs, *rearrange_info = self._merge_batch_and_additional(probs)
+
+ tau = probs.sqrt()
V_dim, C_dim = 0, 2
Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)
if module.reduction == "mean":
- N = module.input0.shape[0]
- sqrt_H /= sqrt(N)
+ sqrt_H /= sqrt(self._get_mean_normalization(module.input0))
+ sqrt_H = self._ungroup_batch_and_additional(sqrt_H, *rearrange_info)
+ sqrt_H = self._expand_sqrt_h(sqrt_H)
return sqrt_H
- def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
+ def _sqrt_hessian_sampled(
+ self,
+ module: CrossEntropyLoss,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mc_samples: int = 1,
+ subsampling: List[int] = None,
+ ) -> Tensor:
self._check_2nd_order_parameters(module)
M = mc_samples
C = module.input0.shape[1]
- probs = self._get_probs(module)
+ probs = self._get_probs(module, subsampling=subsampling)
+ probs, *rearrange_info = self._merge_batch_and_additional(probs)
+
V_dim = 0
probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)
@@ -48,54 +69,86 @@ def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)
if module.reduction == "mean":
- N = module.input0.shape[0]
- sqrt_mc_h /= sqrt(N)
+ sqrt_mc_h /= sqrt(self._get_mean_normalization(module.input0))
+ sqrt_mc_h = self._ungroup_batch_and_additional(sqrt_mc_h, *rearrange_info)
return sqrt_mc_h
- def _sum_hessian(self, module, g_inp, g_out):
+ def _sum_hessian(
+ self, module: CrossEntropyLoss, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Tensor:
self._check_2nd_order_parameters(module)
probs = self._get_probs(module)
- sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs))
+
+ if probs.dim() == 2:
+ diagonal = diag(probs.sum(0))
+ sum_H = diagonal - einsum("nc,nd->cd", probs, probs)
+ else:
+ out_shape = (*probs.shape[1:], *probs.shape[1:])
+ additional = probs.shape[2:].numel()
+
+ diagonal = diag(probs.sum(0).flatten()).reshape(out_shape)
+
+ probs = probs.flatten(2)
+ kron_delta = eye(additional, device=probs.device, dtype=probs.dtype)
+
+ sum_H = diagonal - einsum(
+ "ncx,ndy,xy->cxdy", probs, probs, kron_delta
+ ).reshape(out_shape)
if module.reduction == "mean":
- N = module.input0.shape[0]
- sum_H /= N
+ sum_H /= self._get_mean_normalization(module.input0)
return sum_H
- def _make_hessian_mat_prod(self, module, g_inp, g_out):
- """Multiplication of the input Hessian with a matrix."""
+ def _make_hessian_mat_prod(
+ self, module: CrossEntropyLoss, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> Callable[[Tensor], Tensor]:
self._check_2nd_order_parameters(module)
probs = self._get_probs(module)
def hessian_mat_prod(mat):
- Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
- "bi,bj,cbj->cbi", (probs, probs, mat)
+ Hmat = einsum("...,v...->v...", probs, mat) - einsum(
+ "nc...,nd...,vnd...->vnc...", probs, probs, mat
)
if module.reduction == "mean":
- N = module.input0.shape[0]
- Hmat /= N
+ Hmat /= self._get_mean_normalization(module.input0)
return Hmat
return hessian_mat_prod
- def hessian_is_psd(self):
- """Return whether cross-entropy loss Hessian is positive semi-definite."""
+ def hessian_is_psd(self) -> bool:
+ """Return whether cross-entropy loss Hessian is positive semi-definite.
+
+ Returns:
+ True
+ """
return True
- def _get_probs(self, module):
- return softmax(module.input0, dim=1)
+ @staticmethod
+ def _get_probs(module: CrossEntropyLoss, subsampling: List[int] = None) -> Tensor:
+ """Compute the softmax probabilities from the module input.
- def _check_2nd_order_parameters(self, module):
+ Args:
+ module: cross-entropy loss with I/O.
+ subsampling: Indices of samples to be considered. Default of ``None`` uses
+ the full mini-batch.
+
+ Returns:
+ Softmax probabilities
+ """
+ input0 = subsample(module.input0, subsampling=subsampling)
+ return softmax(input0, dim=1)
+
+ def _check_2nd_order_parameters(self, module: CrossEntropyLoss) -> None:
"""Verify that the parameters are supported by 2nd-order quantities.
- Attributes:
- module (torch.nn.CrossEntropyLoss): Extended CrossEntropyLoss module
+ Args:
+ module: Extended CrossEntropyLoss module
Raises:
NotImplementedError: If module's setting is not implemented.
@@ -116,3 +169,99 @@ def _check_2nd_order_parameters(self, module):
implemented_weight, module.weight
)
)
+
+ @staticmethod
+ def _merge_batch_and_additional(
+ probs: Tensor,
+ ) -> Tuple[Tensor, str, Dict[str, int]]:
+ """Rearranges the input if it has additional axes.
+
+ Treat additional axes like batch axis, i.e. group ``n c d1 d2 -> (n d1 d2) c``.
+
+ Args:
+ probs: the tensor to rearrange
+
+ Returns:
+ a tuple containing
+ - probs: the rearranged tensor
+ - str_d_dims: a string representation of the additional dimensions
+ - d_info: a dictionary encoding the size of the additional dimensions
+ """
+ leading = 2
+ additional = probs.dim() - leading
+
+ str_d_dims: str = "".join(f"d{i} " for i in range(additional))
+ d_info: Dict[str, int] = {
+ f"d{i}": probs.shape[leading + i] for i in range(additional)
+ }
+
+ probs = rearrange(probs, f"n c {str_d_dims} -> (n {str_d_dims}) c")
+
+ return probs, str_d_dims, d_info
+
+ @staticmethod
+ def _ungroup_batch_and_additional(
+ tensor: Tensor, str_d_dims: str, d_info: Dict[str, int], free_axis: int = 1
+ ) -> Tensor:
+ """Rearranges output if it has additional axes.
+
+ Used together with ``_merge_batch_and_additional``.
+
+ Undoes treating additional axes like the batch axis and assumes a number of
+ free leading axes (``v``) were added, i.e. un-groups
+ ``v (n d1 d2) c -> v n c d1 d2``.
+
+ Args:
+ tensor: the tensor to rearrange
+ str_d_dims: a string representation of the additional dimensions
+ d_info: a dictionary encoding the size of the additional dimensions
+ free_axis: Number of free leading axes. Default: ``1``.
+
+ Returns:
+ the rearranged tensor
+
+ Raises:
+ NotImplementedError: If ``free_axis != 1``.
+ """
+ if free_axis != 1:
+ raise NotImplementedError(f"Only supports free_axis=1. Got {free_axis}.")
+
+ return rearrange(
+ tensor, f"v (n {str_d_dims}) c -> v n c {str_d_dims}", **d_info
+ )
+
+ @staticmethod
+ def _expand_sqrt_h(sqrt_h: Tensor) -> Tensor:
+ """Expands the square root hessian if CrossEntropyLoss has additional axes.
+
+ In the case of e.g. two additional axes (A and B), the input is [N,C,A,B].
+ In CrossEntropyLoss the additional axes are treated independently.
+ Therefore, the intermediate result has shape [C,N,C,A,B].
+ In subsequent calculations the additional axes are not independent anymore.
+ The required shape for sqrt_h_full is then [C*A*B,N,C,A,B].
+ Due to the independence, sqrt_h lives on the diagonal of sqrt_h_full.
+
+ Args:
+ sqrt_h: intermediate result, shape [C,N,C,A,B]
+
+ Returns:
+ sqrt_h_full, shape [C*A*B,N,C,A,B], sqrt_h on diagonal.
+ """
+ if sqrt_h.dim() > 3:
+ return diag_embed(sqrt_h.flatten(3), offset=0, dim1=1, dim2=4).reshape(
+ -1, *sqrt_h.shape[1:]
+ )
+ else:
+ return sqrt_h
+
+ @staticmethod
+ def _get_mean_normalization(input: Tensor) -> int:
+ """Get normalization constant used with reduction='mean'.
+
+ Args:
+ input: Input to the cross-entropy module.
+
+ Returns:
+ Divisor for mean reduction.
+ """
+ return input.numel() // input.shape[1]
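
``_get_mean_normalization`` returns ``input.numel() // input.shape[1]``, the number of per-position loss terms that ``reduction='mean'`` averages over when additional axes are present. A small numeric sketch (arbitrary shapes, not part of the patch) confirming this:

import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
N, C, A, B = 2, 3, 4, 5  # batch, classes, two additional axes
logits = torch.rand(N, C, A, B)
targets = torch.randint(C, (N, A, B))

loss_sum = CrossEntropyLoss(reduction="sum")(logits, targets)
loss_mean = CrossEntropyLoss(reduction="mean")(logits, targets)

# 'mean' divides by the number of per-position loss terms: N * A * B
normalization = logits.numel() // logits.shape[1]
assert torch.allclose(loss_mean, loss_sum / normalization)
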
diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py
index 964deaa77..b32aef49a 100644
--- a/backpack/core/derivatives/dropout.py
+++ b/backpack/core/derivatives/dropout.py
@@ -1,13 +1,38 @@
-from torch import eq
+"""Partial derivatives for the dropout layer."""
+from typing import List, Tuple
+
+from torch import Tensor, eq, ones_like
+from torch.nn import Dropout
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class DropoutDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ """Derivatives for the Dropout module."""
+
+ def hessian_is_zero(self, module: Dropout) -> bool:
+ """``Dropout''(x) = 0``.
+
+ Args:
+ module: dropout module
+
+ Returns:
+ whether hessian is zero
+ """
return True
- def df(self, module, g_inp, g_out):
- scaling = 1 / (1 - module.p)
- mask = 1 - eq(module.output, 0.0).float()
- return mask * scaling
+ def df(
+ self,
+ module: Dropout,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor: # noqa: D102
+ output = subsample(module.output, subsampling=subsampling)
+ if module.training:
+ scaling = 1 / (1 - module.p)
+ mask = 1 - eq(output, 0.0).to(output.dtype)
+ return mask * scaling
+ else:
+ return ones_like(output)
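
In train mode, ``df`` recovers the kept-mask from the output and scales it by ``1 / (1 - p)``. A minimal sketch (names are illustrative, not part of the patch) comparing this against autograd:

import torch
from torch.nn import Dropout

torch.manual_seed(0)
p = 0.4
dropout = Dropout(p=p).train()
x = torch.rand(10, requires_grad=True)

out = dropout(x)
(grad,) = torch.autograd.grad(out.sum(), x)

# entries kept by dropout are scaled by 1 / (1 - p), dropped entries are zero
df = (1 - out.eq(0.0).float()) / (1 - p)
assert torch.allclose(grad, df)
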
diff --git a/backpack/core/derivatives/elementwise.py b/backpack/core/derivatives/elementwise.py
index d6f10bce5..33c2b4511 100644
--- a/backpack/core/derivatives/elementwise.py
+++ b/backpack/core/derivatives/elementwise.py
@@ -1,6 +1,9 @@
"""Base class for more flexible Jacobians/Hessians of activation functions."""
-from torch import einsum
+from typing import List, Tuple
+
+from torch import Tensor, einsum
+from torch.nn import Module
from backpack.core.derivatives.basederivatives import BaseDerivatives
@@ -20,18 +23,24 @@ class ElementwiseDerivatives(BaseDerivatives):
- If the activation is piece-wise linear: `hessian_is_zero`, else `d2f`.
"""
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ):
"""Elementwise first derivative.
Args:
- module (torch.nn.Module): PyTorch activation function module.
- g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
- g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
+ module: PyTorch activation module.
+ g_inp: Gradients of the module w.r.t. its inputs.
+ g_out: Gradients of the module w.r.t. its outputs.
+ subsampling: Indices of active samples. ``None`` means all samples.
Returns:
- (torch.Tensor): Tensor containing the derivatives `f'(input[i]) ∀ i`.
+ Tensor containing the derivatives `f'(input[i]) ∀ i`.
"""
-
raise NotImplementedError("First derivatives not implemented")
def d2f(self, module, g_inp, g_out):
@@ -40,14 +49,13 @@ def d2f(self, module, g_inp, g_out):
Only needs to be implemented for non piece-wise linear functions.
Args:
- module (torch.nn.Module): PyTorch activation function module.
+ module (torch.nn.Module): PyTorch activation module.
g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
Returns:
(torch.Tensor): Tensor containing the derivatives `f''(input[i]) ∀ i`.
"""
-
raise NotImplementedError("Second derivatives not implemented")
def hessian_diagonal(self, module, g_inp, g_out):
@@ -57,15 +65,13 @@ def hessian_diagonal(self, module, g_inp, g_out):
- Only required if `hessian_is_diagonal` returns `True`.
Args:
- module (torch.nn.Module): PyTorch activation function module.
+ module (torch.nn.Module): PyTorch activation module.
g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
"""
- self._no_inplace(module)
-
return self.d2f(module, g_inp, g_out) * g_out[0]
- def hessian_is_diagonal(self):
+ def hessian_is_diagonal(self, module):
"""Elementwise activation function Hessians are diagonal.
Returns:
@@ -73,46 +79,25 @@ def hessian_is_diagonal(self):
"""
return True
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- self._no_inplace(module)
-
- df_elementwise = self.df(module, g_inp, g_out)
- return einsum("...,v...->v...", (df_elementwise, mat))
+ def _jac_t_mat_prod(
+ self,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ df_elementwise = self.df(module, g_inp, g_out, subsampling=subsampling)
+ return einsum("...,v...->v...", df_elementwise, mat)
def _jac_mat_prod(self, module, g_inp, g_out, mat):
- self._no_inplace(module)
-
return self.jac_t_mat_prod(module, g_inp, g_out, mat)
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
- self._no_inplace(module)
-
N = module.input0.size(0)
df_flat = self.df(module, g_inp, g_out).reshape(N, -1)
- return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / N
+ return einsum("ni,nj,ij->ij", df_flat, df_flat, mat) / N
def _residual_mat_prod(self, module, g_inp, g_out, mat):
residual = self.d2f(module, g_inp, g_out) * g_out[0]
- return einsum("...,v...->v...", (residual, mat))
-
- @staticmethod
- def _no_inplace(module):
- """Do not support inplace modification.
-
- Jacobians/Hessians might be computed using the modified input instead
- of the original.
-
- Args:
- module (torch.nn.Module): Elementwise activation module.
-
- Raises:
- NotImplementedError: If `module` has inplace option enabled.
-
- Todo:
- - Write tests to investigate what happens with `inplace=True`.
- """
- has_inplace_option = hasattr(module, "inplace")
-
- if has_inplace_option:
- if module.inplace is True:
- raise NotImplementedError("Inplace not supported in {}.".format(module))
+ return einsum("...,v...->v...", residual, mat)
diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py
index 74092e883..5d3778223 100644
--- a/backpack/core/derivatives/elu.py
+++ b/backpack/core/derivatives/elu.py
@@ -1,22 +1,33 @@
"""Partial derivatives for the ELU activation function."""
-from torch import exp, le, ones_like, zeros_like
+from typing import List, Tuple
+
+from torch import Tensor, exp, le, ones_like, zeros_like
+from torch.nn import ELU
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class ELUDerivatives(ElementwiseDerivatives):
"""Implement first- and second-order partial derivatives of ELU."""
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module: ELU) -> bool:
"""`ELU''(x) ≠ 0`."""
return False
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: ELU,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ):
"""First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`."""
- non_pos = le(module.input0, 0)
+ input0 = subsample(module.input0, subsampling=subsampling)
+ non_pos = le(input0, 0)
- result = ones_like(module.input0)
- result[non_pos] = module.alpha * exp(module.input0[non_pos])
+ result = ones_like(input0)
+ result[non_pos] = module.alpha * exp(input0[non_pos])
return result
diff --git a/backpack/core/derivatives/embedding.py b/backpack/core/derivatives/embedding.py
new file mode 100644
index 000000000..acb191be8
--- /dev/null
+++ b/backpack/core/derivatives/embedding.py
@@ -0,0 +1,66 @@
+"""Derivatives for Embedding."""
+from typing import List, Tuple
+
+from torch import Tensor, einsum, zeros
+from torch.nn import Embedding
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
+
+
+class EmbeddingDerivatives(BaseParameterDerivatives):
+ """Derivatives for Embedding.
+
+ Note:
+ These derivatives assume the batch axis to be at position 0.
+
+ Index convention:
+ v - free axis
+ n - batch axis
+ s - num_embeddings
+ h - embedding_dim
+ """
+
+ def _jac_t_mat_prod(
+ self,
+ module: Embedding,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ raise NotImplementedError(
+ "Derivative w.r.t. input not defined: Input to Embedding has type long."
+ " But only float and complex dtypes can require gradients in PyTorch."
+ )
+
+ def _weight_jac_t_mat_prod(
+ self,
+ module: Embedding,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+
+ input0 = subsample(module.input0, subsampling=subsampling)
+ delta = zeros(module.num_embeddings, *input0.shape, device=mat.device)
+ for s in range(module.num_embeddings):
+ delta[s] = input0 == s
+ equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh"
+ return einsum(equation, delta, mat)
+
+ def _check_parameters(self, module: Embedding) -> None:
+ if module.padding_idx is not None:
+ raise NotImplementedError("Only padding_idx=None supported.")
+ elif module.max_norm is not None:
+ raise NotImplementedError("Only max_norm=None supported.")
+ elif module.scale_grad_by_freq:
+ raise NotImplementedError("Only scale_grad_by_freq=False supported.")
+ elif module.sparse:
+ raise NotImplementedError("Only sparse=False supported.")
+
+ def hessian_is_zero(self, module: Embedding) -> bool: # noqa: D102
+ return False
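
The one-hot contraction in ``_weight_jac_t_mat_prod`` can be cross-checked against autograd's weight VJP. A minimal sketch for the ``sum_batch=True`` case with the free axis dropped (shapes and names are made up):

import torch
from torch import einsum, zeros
from torch.nn import Embedding

torch.manual_seed(0)
N, T, S, H = 3, 4, 6, 2  # batch, sequence length, num_embeddings, embedding_dim
emb = Embedding(S, H)
idx = torch.randint(S, (N, T))
v = torch.rand(N, T, H)  # vector shaped like the layer output

out = emb(idx)
(weight_vjp,) = torch.autograd.grad((out * v).sum(), emb.weight)

# one-hot indicator delta[s, n, t] = (idx[n, t] == s), as in the code above
delta = zeros(S, N, T)
for s in range(S):
    delta[s] = idx == s

# fixed-shape instance of the 'sn...,vn...h->vsh' contraction (free axis dropped)
manual = einsum("snt,nth->sh", delta, v)
assert torch.allclose(weight_vjp, manual, atol=1e-6)
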
diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py
index 8366c770e..aac7f7992 100644
--- a/backpack/core/derivatives/flatten.py
+++ b/backpack/core/derivatives/flatten.py
@@ -1,25 +1,34 @@
+"""Partial derivatives of the flatten layer."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Flatten
+
from backpack.core.derivatives.basederivatives import BaseDerivatives
class FlattenDerivatives(BaseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return True
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
return mat
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- return self.reshape_like_input(mat, module)
+ def _jac_t_mat_prod(
+ self,
+ module: Flatten,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ return self.reshape_like_input(mat, module, subsampling=subsampling)
- def _jac_mat_prod(self, module, g_inp, g_out, mat):
+ def _jac_mat_prod(
+ self,
+ module: Flatten,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ ) -> Tensor:
return self.reshape_like_output(mat, module)
-
- def is_no_op(self, module):
- """Does flatten add an operation to the computational graph.
-
- If the input is already flattened, no operation will be added for
- the `Flatten` layer. This can lead to an intuitive order of backward
- hook execution, see the discussion at https://discuss.pytorch.org/t/
- backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4 .
- """
- return tuple(module.input0.shape) == tuple(module.output.shape)
diff --git a/backpack/core/derivatives/leakyrelu.py b/backpack/core/derivatives/leakyrelu.py
index 60a650c93..7cb0dfa1e 100644
--- a/backpack/core/derivatives/leakyrelu.py
+++ b/backpack/core/derivatives/leakyrelu.py
@@ -1,16 +1,27 @@
-from torch import gt
+"""Partial derivatives for the leaky ReLU layer."""
+from typing import List, Tuple
+
+from torch import Tensor, gt
+from torch.nn import LeakyReLU
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class LeakyReLUDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module: LeakyReLU) -> bool:
"""`LeakyReLU''(x) = 0`."""
return True
- def df(self, module, g_inp, g_out):
- """First LeakyReLU derivative:
- `LeakyReLU'(x) = negative_slope if x < 0 else 1`."""
- df_leakyrelu = gt(module.input0, 0).float()
+ def df(
+ self,
+ module: LeakyReLU,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """``LeakyReLU'(x) = negative_slope if x < 0 else 1``."""
+ input0 = subsample(module.input0, subsampling=subsampling)
+ df_leakyrelu = gt(input0, 0).to(input0.dtype)
df_leakyrelu[df_leakyrelu == 0] = module.negative_slope
return df_leakyrelu
diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py
index 0e1f5f32b..a3156f927 100644
--- a/backpack/core/derivatives/linear.py
+++ b/backpack/core/derivatives/linear.py
@@ -1,6 +1,11 @@
-from torch import einsum
+"""Contains partial derivatives for the ``torch.nn.Linear`` layer."""
+from typing import List, Tuple
+
+from torch import Size, Tensor, einsum
+from torch.nn import Linear
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
class LinearDerivatives(BaseParameterDerivatives):
@@ -14,43 +19,208 @@ class LinearDerivatives(BaseParameterDerivatives):
* i: Input dimension
"""
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module: Linear) -> bool:
+ """Linear layer output is linear w.r.t. to its input.
+
+ Args:
+ module: current module
+
+ Returns:
+ True
+ """
return True
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
- """Apply transposed Jacobian of the output w.r.t. the input."""
- d_input = module.weight.data
- return einsum("oi,vno->vni", (d_input, mat))
-
- def _jac_mat_prod(self, module, g_inp, g_out, mat):
- """Apply Jacobian of the output w.r.t. the input."""
- d_input = module.weight.data
- return einsum("oi,vni->vno", (d_input, mat))
-
- def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
- jac = module.weight.data
- return einsum("ik,ij,jl->kl", (jac, mat, jac))
-
- def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
- """Apply Jacobian of the output w.r.t. the weight."""
- d_weight = module.input0
- return einsum("ni,voi->vno", (d_weight, mat))
-
- def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Apply transposed Jacobian of the output w.r.t. the weight."""
- d_weight = module.input0
- contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi"
- return einsum(contract, (mat, d_weight))
-
- def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
- """Apply Jacobian of the output w.r.t. the bias."""
- N = module.input0.size(0)
- return mat.unsqueeze(1).expand(-1, N, -1)
-
- def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
- """Apply transposed Jacobian of the output w.r.t. the bias."""
- if sum_batch:
- N_axis = 1
- return mat.sum(N_axis)
- else:
- return mat
+ def _jac_t_mat_prod(
+ self,
+ module: Linear,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Batch-apply transposed Jacobian of the output w.r.t. the input.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of same shape as the layer output
+ (``[N, *, out_features]``) to which the transposed output-input Jacobian
+ is applied. Has shape ``[V, N, *, out_features]``; but if used with
+ sub-sampling, ``N`` is replaced by ``len(subsampling)``.
+ subsampling: Indices of active samples. ``None`` means all samples.
+
+ Returns:
+ Batched transposed Jacobian vector products. Has shape
+ ``[V, N, *, in_features]``. If used with sub-sampling, ``N`` is replaced
+ by ``len(subsampling)``.
+ """
+ return einsum("vn...o,oi->vn...i", mat, module.weight)
+
+ def _jac_mat_prod(
+ self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ """Batch-apply Jacobian of the output w.r.t. the input.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of same shape as the layer input
+ (``[N, *, in_features]``) to which the output-input Jacobian is applied.
+ Has shape ``[V, N, *, in_features]``.
+
+ Returns:
+ Batched Jacobian vector products. Has shape ``[V, N, *, out_features]``.
+ """
+ return einsum("oi,vn...i->vn...o", module.weight, mat)
+
+ def ea_jac_t_mat_jac_prod(
+ self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ """Expectation approximation of outer product with input-output Jacobian.
+
+ Used for KFRA backpropagation: ``mat ← E(Jₙᵀ mat Jₙ) = 1/N ∑ₙ Jₙᵀ mat Jₙ``.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Matrix of shape
+ ``[module.output.numel() // N, module.output.numel() // N]``.
+
+ Returns:
+ Matrix of shape
+ ``[module.input0.numel() // N, module.input0.numel() // N]``.
+ """
+ add_features = self._get_additional_dims(module).numel()
+ in_features, out_features = module.in_features, module.out_features
+
+ result = mat.reshape(add_features, out_features, add_features, out_features)
+ result = einsum("ik,xiyj,jl->xkyl", module.weight, result, module.weight)
+
+ return result.reshape(in_features * add_features, in_features * add_features)
+
+ def _weight_jac_mat_prod(
+ self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ """Batch-apply Jacobian of the output w.r.t. the weight.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of shape ``module.weight.shape`` to which the
+ output-weight Jacobian is applied. Has shape
+ ``[V, *module.weight.shape]``.
+
+ Returns:
+ Batched Jacobian vector products. Has shape
+ ``[V, N, *module.output.shape]``.
+ """
+ return einsum("n...i,voi->vn...o", module.input0, mat)
+
+ def _weight_jac_t_mat_prod(
+ self,
+ module: Linear,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Batch-apply transposed Jacobian of the output w.r.t. the weight.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of same shape as the layer output
+ (``[N, *, out_features]``) to which the transposed output-input Jacobian
+ is applied. Has shape ``[V, N, *, out_features]`` if subsampling is not
+ used, otherwise ``N`` must be ``len(subsampling)`` instead.
+ sum_batch: Sum the result's batch axis. Default: ``True``.
+ subsampling: Indices of samples along the output's batch dimension that
+ should be considered. Defaults to ``None`` (use all samples).
+
+ Returns:
+ Batched transposed Jacobian vector products. Has shape
+ ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With
+ ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. If sub-
+ sampling is used, ``N`` must be ``len(subsampling)`` instead.
+ """
+ d_weight = subsample(module.input0, subsampling=subsampling)
+
+ equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi"
+ return einsum(equation, mat, d_weight)
+
+ def _bias_jac_mat_prod(
+ self, module: Linear, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ """Batch-apply Jacobian of the output w.r.t. the bias.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of shape ``module.bias.shape`` to which the
+ output-bias Jacobian is applied. Has shape
+ ``[V, *module.bias.shape]``.
+
+ Returns:
+ Batched Jacobian vector products. Has shape
+ ``[V, N, *module.output.shape]``.
+ """
+ N = module.input0.shape[0]
+ additional_dims = list(self._get_additional_dims(module))
+
+ for _ in range(len(additional_dims) + 1):
+ mat = mat.unsqueeze(1)
+
+ expand = [-1, N] + additional_dims + [-1]
+
+ return mat.expand(*expand)
+
+ def _bias_jac_t_mat_prod(
+ self,
+ module: Linear,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Batch-apply transposed Jacobian of the output w.r.t. the bias.
+
+ Args:
+ module: Linear layer.
+ g_inp: Gradients w.r.t. module input. Not required by the implementation.
+ g_out: Gradients w.r.t. module output. Not required by the implementation.
+ mat: Batch of ``V`` vectors of same shape as the layer output
+ (``[N, *, out_features]``) to which the transposed output-input Jacobian
+ is applied. Has shape ``[V, N, *, out_features]``.
+ sum_batch: Sum the result's batch axis. Default: ``True``.
+ subsampling: Indices of samples along the output's batch dimension that
+ should be considered. Defaults to ``None`` (use all samples).
+
+ Returns:
+ Batched transposed Jacobian vector products. Has shape
+ ``[V, N, *module.bias.shape]`` when ``sum_batch`` is ``False``. With
+ ``sum_batch=True``, has shape ``[V, *module.bias.shape]``. If sub-
+ sampling is used, ``N`` is replaced by ``len(subsampling)``.
+ """
+ equation = f"vn...o->v{'' if sum_batch else 'n'}o"
+ return einsum(equation, mat)
+
+ @staticmethod
+ def _get_additional_dims(module: Linear) -> Size:
+ """Return the shape of additional dimensions in the input to a linear layer.
+
+ Args:
+ module: A linear layer.
+
+ Returns:
+ Shape of the additional dimensions. Corresponds to ``*`` in the
+ input shape ``[N, *, in_features]``.
+ """
+ return module.input0.shape[1:-1]
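
With ``sum_batch=False`` and ``mat`` holding output gradients, the contraction in ``_weight_jac_t_mat_prod`` yields per-sample weight gradients. A minimal sketch (arbitrary loss, illustrative names, not part of the patch) comparing against a loop of autograd calls:

import torch
from torch import einsum
from torch.nn import Linear

torch.manual_seed(0)
N, D_in, D_out = 4, 3, 2
linear = Linear(D_in, D_out)
x = torch.rand(N, D_in)
y = torch.rand(N, D_out)

out = linear(x)
g_out = 2 * (out - y)  # gradient of sum((out - y) ** 2) w.r.t. the output

# 'vno,ni->vnoi' with the free axis dropped: per-sample weight gradients
per_sample = einsum("no,ni->noi", g_out, x)

expected = torch.stack(
    [
        torch.autograd.grad(((linear(x[n]) - y[n]) ** 2).sum(), linear.weight)[0]
        for n in range(N)
    ]
)
assert torch.allclose(per_sample, expected, atol=1e-6)
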
diff --git a/backpack/core/derivatives/logsigmoid.py b/backpack/core/derivatives/logsigmoid.py
index f9e016c11..917784010 100644
--- a/backpack/core/derivatives/logsigmoid.py
+++ b/backpack/core/derivatives/logsigmoid.py
@@ -1,16 +1,28 @@
-from torch import exp
+"""Contains partial derivatives for the ``torch.nn.LogSigmoid`` layer."""
+from typing import List, Tuple
+
+from torch import Tensor, exp
+from torch.nn import LogSigmoid
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class LogSigmoidDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
"""`logsigmoid''(x) ≠ 0`."""
return False
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: LogSigmoid,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1) `."""
- return 1 / (exp(module.input0) + 1)
+ input0 = subsample(module.input0, subsampling=subsampling)
+ return 1 / (exp(input0) + 1)
def d2f(self, module, g_inp, g_out):
"""Second Logsigmoid derivative: `logsigmoid''(x) = - e^x / (e^x + 1)^2`."""
diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py
new file mode 100644
index 000000000..def7e80da
--- /dev/null
+++ b/backpack/core/derivatives/lstm.py
@@ -0,0 +1,328 @@
+"""Partial derivatives for nn.LSTM."""
+from typing import List, Tuple
+
+from torch import Tensor, cat, einsum, sigmoid, tanh, zeros
+from torch.nn import LSTM
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
+
+
+class LSTMDerivatives(BaseParameterDerivatives):
+ """Partial derivatives for nn.LSTM layer.
+
+ Index conventions:
+ ------------------
+ * t: Sequence dimension
+ * v: Free dimension
+ * n: Batch dimension
+ * h: Output dimension
+ * i: Input dimension
+
+ LSTM forward pass (definition of variables):
+ see https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
+ ifgo_tilde[t] = W_ih x[t] + b_ih + W_hh h[t-1] + b_hh
+ ifgo[t] = sigma(ifgo_tilde[t]) for i, f, o
+ ifgo[t] = tanh(ifgo_tilde[t]) for g
+ c[t] = f[t] c[t-1] + i[t] g[t]
+ h[t] = o[t] tanh(c[t])
+
+ In general, we assume batch-first inputs and that the
+ axis order is (V, N, T, H).
+ """
+
+ @staticmethod
+ def _check_parameters(module: LSTM) -> None:
+ """Check the parameters of module.
+
+ Args:
+ module: module to check
+
+ Raises:
+ NotImplementedError: If any parameter of the module does not match expectations
+ """
+ if not module.batch_first:
+ raise NotImplementedError("Batch axis must be first.")
+ if module.num_layers != 1:
+ raise NotImplementedError("only num_layers = 1 is supported")
+ if module.bias is not True:
+ raise NotImplementedError("only bias = True is supported")
+ if module.dropout != 0:
+ raise NotImplementedError("only dropout = 0 is supported")
+ if module.bidirectional is not False:
+ raise NotImplementedError("only bidirectional = False is supported")
+ if module.proj_size != 0:
+ raise NotImplementedError("only proj_size = 0 is supported")
+
+ @staticmethod
+ def _forward_pass(
+ module: LSTM, mat: Tensor, subsampling: List[int] = None
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ """This performs an additional forward pass and returns the hidden variables.
+
+ This is important because the PyTorch implementation does not grant access to
+ some of the hidden variables. Those are computed and returned.
+
+ See also forward pass in class docstring.
+
+ Args:
+ module: module
+ mat: matrix, used to extract device and shapes.
+ subsampling: Indices of active samples. Defaults to ``None`` (all samples).
+
+ Returns:
+ ifgo, c, c_tanh (all in format ``[N, T, ...]``)
+ """
+ _, N, T, _ = mat.shape
+ H: int = module.hidden_size
+ H0: int = 0 * H
+ H1: int = 1 * H
+ H2: int = 2 * H
+ H3: int = 3 * H
+ H4: int = 4 * H
+ # forward pass, saving i, f, g, o, c, c_tanh -> ifgo, c, c_tanh
+ ifgo: Tensor = zeros(N, T, 4 * H, device=mat.device, dtype=mat.dtype)
+ c: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)
+ c_tanh: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)
+
+ input0 = subsample(module.input0, dim=0, subsampling=subsampling)
+ output = subsample(module.output, dim=0, subsampling=subsampling)
+
+ for t in range(T):
+ ifgo[:, t] = (
+ einsum("hi,ni->nh", module.weight_ih_l0, input0[:, t])
+ + module.bias_ih_l0
+ + module.bias_hh_l0
+ )
+ if t != 0:
+ ifgo[:, t] += einsum("hg,ng->nh", module.weight_hh_l0, output[:, t - 1])
+ ifgo[:, t, H0:H1] = sigmoid(ifgo[:, t, H0:H1])
+ ifgo[:, t, H1:H2] = sigmoid(ifgo[:, t, H1:H2])
+ ifgo[:, t, H2:H3] = tanh(ifgo[:, t, H2:H3])
+ ifgo[:, t, H3:H4] = sigmoid(ifgo[:, t, H3:H4])
+ c[:, t] = ifgo[:, t, H0:H1] * ifgo[:, t, H2:H3]
+ if t != 0:
+ c[:, t] += ifgo[:, t, H1:H2] * c[:, t - 1]
+ c_tanh[:, t] = tanh(c[:, t])
+
+ return ifgo, c, c_tanh
+
+ @classmethod
+ def _ifgo_jac_t_mat_prod(
+ cls, module: LSTM, mat: Tensor, subsampling: List[int] = None
+ ) -> Tensor:
+ V, N, T, H = mat.shape
+ H0: int = 0 * H
+ H1: int = 1 * H
+ H2: int = 2 * H
+ H3: int = 3 * H
+ H4: int = 4 * H
+
+ ifgo, c, c_tanh = cls._forward_pass(module, mat, subsampling=subsampling)
+
+ # backward pass
+ H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ IFGO_prod: Tensor = zeros(V, N, T, 4 * H, device=mat.device, dtype=mat.dtype)
+ for t in reversed(range(T)):
+ # jac_t_mat_prod until node h
+ H_prod_t[:] = mat[:, :, t]
+ if t != (T - 1):
+ H_prod_t += einsum(
+ "vnh,hg->vng", IFGO_prod[:, :, t + 1], module.weight_hh_l0
+ )
+
+ # C_prod_t = jac_t_mat_prod until node c
+ if t != (T - 1):
+ C_prod_old[:] = C_prod_t
+ C_prod_t[:] = einsum(
+ "vnh,nh->vnh", H_prod_t, ifgo[:, t, H3:H4] * (1 - c_tanh[:, t] ** 2)
+ )
+ if t != (T - 1):
+ C_prod_t += einsum("vnh,nh->vnh", C_prod_old, ifgo[:, t + 1, H1:H2])
+
+ IFGO_prod[:, :, t, H3:H4] = einsum(
+ "vnh,nh->vnh",
+ H_prod_t,
+ c_tanh[:, t] * (ifgo[:, t, H3:H4] * (1 - ifgo[:, t, H3:H4])),
+ )
+ IFGO_prod[:, :, t, H0:H1] = einsum(
+ "vnh,nh->vnh",
+ C_prod_t,
+ ifgo[:, t, H2:H3] * (ifgo[:, t, H0:H1] * (1 - ifgo[:, t, H0:H1])),
+ )
+ if t >= 1:
+ IFGO_prod[:, :, t, H1:H2] = einsum(
+ "vnh,nh->vnh",
+ C_prod_t,
+ c[:, t - 1] * (ifgo[:, t, H1:H2] * (1 - ifgo[:, t, H1:H2])),
+ )
+ IFGO_prod[:, :, t, H2:H3] = einsum(
+ "vnh,nh->vnh",
+ C_prod_t,
+ ifgo[:, t, H0:H1] * (1 - ifgo[:, t, H2:H3] ** 2),
+ )
+ return IFGO_prod
+
+ def hessian_is_zero(self, module: LSTM) -> bool: # noqa: D102
+ return False
+
+ def _jac_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ ) -> Tensor:
+ V, N, T, _ = mat.shape
+ H: int = module.hidden_size
+ H0: int = 0 * H
+ H1: int = 1 * H
+ H2: int = 2 * H
+ H3: int = 3 * H
+ H4: int = 4 * H
+
+ ifgo, c, c_tanh = self._forward_pass(module, mat)
+ H_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype)
+ C_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ C_prod_old: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ C_tanh_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
+ IFGO_prod_t: Tensor = zeros(V, N, 4 * H, device=mat.device, dtype=mat.dtype)
+ for t in range(T):
+ # product until nodes ifgo
+ IFGO_prod_t[:] = einsum(
+ "hi,vni->vnh",
+ module.weight_ih_l0,
+ mat[:, :, t],
+ )
+ if t != 0:
+ IFGO_prod_t[:] += einsum(
+ "hg,vng->vnh", module.weight_hh_l0, H_prod[:, :, t - 1]
+ )
+ IFGO_prod_t[:, :, H0:H2] = einsum(
+ "vnh,nh->vnh",
+ IFGO_prod_t[:, :, H0:H2],
+ ifgo[:, t, H0:H2] * (1 - ifgo[:, t, H0:H2]),
+ )
+ IFGO_prod_t[:, :, H3:H4] = einsum(
+ "vnh,nh->vnh",
+ IFGO_prod_t[:, :, H3:H4],
+ ifgo[:, t, H3:H4] * (1 - ifgo[:, t, H3:H4]),
+ )
+ IFGO_prod_t[:, :, H2:H3] = einsum(
+ "vnh,nh->vnh",
+ IFGO_prod_t[:, :, H2:H3],
+ 1 - ifgo[:, t, H2:H3] ** 2,
+ )
+
+ # product until node c
+ if t >= 1:
+ C_prod_old[:] = C_prod_t
+ C_prod_t[:] = einsum(
+ "vnh,nh->vnh", IFGO_prod_t[:, :, H0:H1], ifgo[:, t, H2:H3]
+ ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H2:H3], ifgo[:, t, H0:H1])
+ if t >= 1:
+ C_prod_t += einsum(
+ "vnh,nh->vnh", C_prod_old, ifgo[:, t, H1:H2]
+ ) + einsum("vnh,nh->vnh", IFGO_prod_t[:, :, H1:H2], c[:, t - 1])
+
+ # product until node c_tanh
+ C_tanh_prod_t[:] = einsum("vnh,nh->vnh", C_prod_t, 1 - c_tanh[:, t] ** 2)
+
+ # product until node h
+ H_prod[:, :, t] = einsum(
+ "vnh,nh->vnh", IFGO_prod_t[:, :, H3:H4], c_tanh[:, t]
+ ) + einsum("vnh,nh->vnh", C_tanh_prod_t, ifgo[:, t, H3:H4])
+
+ return H_prod
+
+ def _jac_t_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+
+ IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
+ module, mat, subsampling=subsampling
+ )
+ X_prod: Tensor = einsum("vnth,hi->vnti", IFGO_prod, module.weight_ih_l0)
+ return X_prod
+
+ def _bias_ih_l0_jac_t_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+
+ IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
+ module, mat, subsampling=subsampling
+ )
+
+ return einsum(f"vnth->v{'' if sum_batch else 'n'}h", IFGO_prod)
+
+ def _bias_hh_l0_jac_t_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ return self._bias_ih_l0_jac_t_mat_prod(
+ module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
+ )
+
+ def _weight_ih_l0_jac_t_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+
+ IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
+ module, mat, subsampling=subsampling
+ )
+ return einsum(
+ f"vnth,nti->v{'' if sum_batch else 'n'}hi",
+ IFGO_prod,
+ subsample(module.input0, dim=0, subsampling=subsampling),
+ )
+
+ def _weight_hh_l0_jac_t_mat_prod(
+ self,
+ module: LSTM,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+ _, N, _, H = mat.shape
+ IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
+ module, mat, subsampling=subsampling
+ )
+
+ subsampled_output = subsample(module.output, dim=0, subsampling=subsampling)
+ single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
+ return einsum(
+ f"vnth,ntg->v{'' if sum_batch else 'n'}hg",
+ IFGO_prod,
+ cat([single_step, subsampled_output[:, :-1]], dim=1),
+ )
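# Editor's sketch (not part of the patch): numerical check of the recurrence
# documented in the LSTMDerivatives docstring above. A manual i/f/g/o loop with
# zero initial states should reproduce a single-layer, batch-first nn.LSTM.
import torch

torch.manual_seed(0)
N, T, D_in, H = 3, 5, 4, 6  # batch, sequence, input, hidden sizes
lstm = torch.nn.LSTM(D_in, H, batch_first=True)
x = torch.rand(N, T, D_in)

h, c, outputs = torch.zeros(N, H), torch.zeros(N, H), []
for t in range(T):
    ifgo = (
        x[:, t] @ lstm.weight_ih_l0.T
        + lstm.bias_ih_l0
        + h @ lstm.weight_hh_l0.T
        + lstm.bias_hh_l0
    )
    i, f, g, o = ifgo.chunk(4, dim=1)  # PyTorch gate order: input, forget, cell, output
    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
    c = f * c + i * g
    h = o * c.tanh()
    outputs.append(h)

assert torch.allclose(torch.stack(outputs, dim=1), lstm(x)[0], atol=1e-6)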
diff --git a/backpack/core/derivatives/maxpoolnd.py b/backpack/core/derivatives/maxpoolnd.py
index b3960459b..dc1c6af59 100644
--- a/backpack/core/derivatives/maxpoolnd.py
+++ b/backpack/core/derivatives/maxpoolnd.py
@@ -1,28 +1,31 @@
+from typing import List, Tuple, Union
+
from einops import rearrange
-from torch import zeros
+from torch import Tensor, zeros
+from torch.nn import MaxPool1d, MaxPool2d, MaxPool3d
from torch.nn.functional import max_pool1d, max_pool2d, max_pool3d
from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.utils.subsampling import subsample
class MaxPoolNDDerivatives(BaseDerivatives):
- def __init__(self, N):
+ def __init__(self, N: int):
self.N = N
- if self.N == 1:
- self.maxpool = max_pool1d
- elif self.N == 2:
- self.maxpool = max_pool2d
- elif self.N == 3:
- self.maxpool = max_pool3d
- else:
- raise ValueError(
- "{}-dimensional Maxpool. is not implemented.".format(self.N)
- )
+ self.maxpool = {
+ 1: max_pool1d,
+ 2: max_pool2d,
+ 3: max_pool3d,
+ }[N]
# TODO: Do not recompute but get from forward pass of module
- def get_pooling_idx(self, module):
+ def get_pooling_idx(
+ self,
+ module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
+ subsampling: List[int] = None,
+ ) -> Tensor:
_, pool_idx = self.maxpool(
- module.input0,
+ subsample(module.input0, subsampling=subsampling),
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
@@ -32,7 +35,7 @@ def get_pooling_idx(self, module):
)
return pool_idx
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return True
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
@@ -48,22 +51,9 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
"""
device = mat.device
- if self.N == 1:
- N, C, L_in = module.input0.size()
- _, _, L_out = module.output.size()
- in_pixels = L_in
- out_pixels = L_out
- elif self.N == 2:
- N, C, H_in, W_in = module.input0.size()
- _, _, H_out, W_out = module.output.size()
- in_pixels = H_in * W_in
- out_pixels = H_out * W_out
- elif self.N == 3:
- N, C, D_in, H_in, W_in = module.input0.size()
- _, _, D_out, H_out, W_out = module.output.size()
- in_pixels = D_in * H_in * W_in
- out_pixels = D_out * H_out * W_out
-
+ N, C = module.input0.shape[:2]
+ in_pixels = module.input0.shape[2:].numel()
+ out_pixels = module.output.shape[2:].numel()
in_features = C * in_pixels
pool_idx = self.get_pooling_idx(module).view(N, C, out_pixels)
@@ -102,46 +92,56 @@ def __apply_jacobian_of(self, module, mat):
pool_idx = self.__pool_idx_for_jac(module, V)
return mat.gather(N_axis, pool_idx)
- def __pool_idx_for_jac(self, module, V):
+ def __pool_idx_for_jac(
+ self,
+ module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
+ V: int,
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""Manipulated pooling indices ready-to-use in jac(t)."""
- pool_idx = self.get_pooling_idx(module)
+ pool_idx = self.get_pooling_idx(module, subsampling=subsampling)
pool_idx = rearrange(pool_idx, "n c ... -> n c (...)")
- V_axis = 0
-
- return pool_idx.unsqueeze(V_axis).expand(V, -1, -1, -1)
+ return pool_idx.unsqueeze(0).expand(V, -1, -1, -1)
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
+ def _jac_t_mat_prod(
+ self,
+ module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
mat_as_pool = rearrange(mat, "v n c ... -> v n c (...)")
- jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool)
- return self.reshape_like_input(jmp_as_pool, module)
-
- def __apply_jacobian_t_of(self, module, mat):
+ jmp_as_pool = self.__apply_jacobian_t_of(
+ module, mat_as_pool, subsampling=subsampling
+ )
+ return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling)
+
+ def __apply_jacobian_t_of(
+ self,
+ module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
V = mat.shape[0]
- result = self.__zero_for_jac_t(module, V, mat.device)
- pool_idx = self.__pool_idx_for_jac(module, V)
+ result = self.__zero_for_jac_t(module, V, subsampling=subsampling)
+ pool_idx = self.__pool_idx_for_jac(module, V, subsampling=subsampling)
N_axis = 3
result.scatter_add_(N_axis, pool_idx, mat)
return result
- def __zero_for_jac_t(self, module, V, device):
- if self.N == 1:
- N, C_out, _ = module.output.shape
- _, _, L_in = module.input0.size()
-
- shape = (V, N, C_out, L_in)
-
- elif self.N == 2:
- N, C_out, _, _ = module.output.shape
- _, _, H_in, W_in = module.input0.size()
-
- shape = (V, N, C_out, H_in * W_in)
-
- elif self.N == 3:
- N, C_out, _, _, _ = module.output.shape
- _, _, D_in, H_in, W_in = module.input0.size()
+ def __zero_for_jac_t(
+ self,
+ module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
+ V: int,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ N, C_out = module.output.shape[:2]
+ in_pixels = module.input0.shape[2:].numel()
+ N = N if subsampling is None else len(subsampling)
- shape = (V, N, C_out, D_in * H_in * W_in)
+ shape = (V, N, C_out, in_pixels)
- return zeros(shape, device=device)
+ return zeros(shape, device=module.output.device, dtype=module.output.dtype)
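# Editor's illustration (not part of the patch): the transposed Jacobian of max
# pooling routes each output entry back to its argmax input location, which is
# exactly what the scatter_add_ call in __apply_jacobian_t_of implements.
import torch
from torch.nn.functional import max_pool1d

x = torch.tensor([[[1.0, 3.0, 2.0, 5.0]]])                  # [N=1, C=1, L_in=4]
_, idx = max_pool1d(x, kernel_size=2, return_indices=True)  # argmax positions: [1, 3]
mat = torch.ones(1, 1, 1, 2)                                # [V, N, C, L_out]
result = torch.zeros(1, 1, 1, 4)
result.scatter_add_(3, idx.unsqueeze(0), mat)
print(result)  # tensor([[[[0., 1., 0., 1.]]]])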
diff --git a/backpack/core/derivatives/mseloss.py b/backpack/core/derivatives/mseloss.py
index e18c12b97..f09750052 100644
--- a/backpack/core/derivatives/mseloss.py
+++ b/backpack/core/derivatives/mseloss.py
@@ -1,8 +1,10 @@
"""Derivatives of the MSE Loss."""
from math import sqrt
+from typing import List, Tuple
-from torch import einsum, eye, normal
+from torch import Tensor, eye, normal, ones
+from torch.nn import MSELoss
from backpack.core.derivatives.basederivatives import BaseLossDerivatives
@@ -16,50 +18,52 @@ class MSELossDerivatives(BaseLossDerivatives):
`∑ᵢ₌₁ⁿ ‖X[i,∶] − Y[i,∶]‖²`. If `reduce=mean`, the result is divided by `nd`.
"""
- def _sqrt_hessian(self, module, g_inp, g_out):
- """Square-root of the hessian of the MSE for each minibatch elements.
-
- Returns the Hessian in format `Hs = [D, N, D]`, where
- `Hs[:, n, :]` is the Hessian for the `n`th element.
-
- Attributes:
- module: (torch.nn.MSELoss) module
- g_inp: Gradient of loss w.r.t. input
- g_out: Gradient of loss w.r.t. output
-
- Returns:
- Batch of hessians, in format [D, N, D]
- """
+ def _sqrt_hessian(
+ self,
+ module: MSELoss,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor: # noqa: D102
self.check_input_dims(module)
- N, D = module.input0.shape
- sqrt_H = sqrt(2) * eye(D, device=module.input0.device) # [D, D]
- sqrt_H = sqrt_H.unsqueeze(0).repeat(N, 1, 1) # [N, D, D]
- sqrt_H = einsum("nab->anb", sqrt_H) # [D, N, D]
+ input0: Tensor = module.input0
+ N, D = input0.shape
+ N_active = N if subsampling is None else len(subsampling)
+ scale = sqrt(2)
if module.reduction == "mean":
- sqrt_H /= sqrt(module.input0.numel())
+ scale /= sqrt(input0.numel())
- return sqrt_H
+ sqrt_H_diag = scale * ones(D, device=input0.device, dtype=input0.dtype)
+ sqrt_H = sqrt_H_diag.diag().unsqueeze(1).expand(-1, N_active, -1)
- def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
- """A Monte-Carlo estimate of the square-root of the Hessian.
+ return sqrt_H
- Attributes:
- module: (torch.nn.MSELoss) module.
- g_inp: Gradient of loss w.r.t. input.
- g_out: Gradient of loss w.r.t. output.
- mc_samples: (int, optional) Number of MC samples to use. Default: 1.
+ def _sqrt_hessian_sampled(
+ self,
+ module: MSELoss,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mc_samples: int = 1,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self.check_input_dims(module)
- Returns:
- tensor:
- """
- N, D = module.input0.shape
- samples = normal(0, 1, size=[mc_samples, N, D], device=module.input0.device)
+ input0: Tensor = module.input0
+ N, D = input0.shape
+ N_active = N if subsampling is None else len(subsampling)
+ samples = normal(
+ 0,
+ 1,
+ size=[mc_samples, N_active, D],
+ device=input0.device,
+ dtype=input0.dtype,
+ )
samples *= sqrt(2) / sqrt(mc_samples)
if module.reduction == "mean":
- samples /= sqrt(module.input0.numel())
+ samples /= sqrt(input0.numel())
return samples
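# Editor's sketch (not part of the patch): for reduction="mean", the per-sample
# MSE Hessian w.r.t. the input is 2 / (N * D) * I, and the diagonal square-root
# factor built above (scale * I with scale = sqrt(2) / sqrt(N * D)) recombines
# to it. Checked here with torch.autograd.functional.hessian.
from math import sqrt

import torch

N, D = 4, 3
x, y = torch.rand(N, D), torch.rand(N, D)
hess = torch.autograd.functional.hessian(
    lambda inp: torch.nn.functional.mse_loss(inp, y, reduction="mean"), x
)
block = hess[0, :, 0, :]  # Hessian block of sample n=0
assert torch.allclose(block, 2 / (N * D) * torch.eye(D))

sqrt_H = sqrt(2) / sqrt(N * D) * torch.eye(D)
assert torch.allclose(sqrt_H @ sqrt_H.T, block)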
diff --git a/backpack/core/derivatives/permute.py b/backpack/core/derivatives/permute.py
new file mode 100644
index 000000000..396803876
--- /dev/null
+++ b/backpack/core/derivatives/permute.py
@@ -0,0 +1,28 @@
+"""Module containing derivatives of Permute."""
+from typing import List, Tuple
+
+from torch import Tensor, argsort
+
+from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.custom_module.permute import Permute
+
+
+class PermuteDerivatives(BaseDerivatives):
+ """Derivatives of Permute."""
+
+ def _jac_t_mat_prod(
+ self,
+ module: Permute,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ return mat.permute(
+ [0] + [element + 1 for element in argsort(Tensor(module.dims))]
+ )
+
+ def _jac_mat_prod(
+ self, module: Permute, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ return mat.permute([0] + [element + 1 for element in module.dims])
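# Editor's note (not part of the patch): the transposed Jacobian of a permutation
# is its inverse permutation, which is what argsort(dims) yields; the extra
# leading entry in the code above only accounts for the free dimension V.
import torch

dims = (2, 0, 1)
x = torch.rand(2, 3, 4)
inverse = torch.argsort(torch.tensor(dims)).tolist()
assert torch.equal(x.permute(dims).permute(inverse), x)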
diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py
index eae9d5ebf..18dab75fa 100644
--- a/backpack/core/derivatives/relu.py
+++ b/backpack/core/derivatives/relu.py
@@ -1,13 +1,25 @@
-from torch import gt
+"""Partial derivatives for the ReLU activation function."""
+from typing import List, Tuple
+
+from torch import Tensor, gt
+from torch.nn import ReLU
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class ReLUDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
"""`ReLU''(x) = 0`."""
return True
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: ReLU,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`."""
- return gt(module.input0, 0).float()
+ input0 = subsample(module.input0, subsampling=subsampling)
+ return gt(input0, 0).to(input0.dtype)
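# Editor's sketch of the semantics assumed for backpack.utils.subsampling.subsample
# throughout this patch: select the active samples along a given axis, and return
# the tensor unchanged when subsampling is None. Illustrative only, not the actual
# helper.
from typing import List

import torch

def subsample_sketch(tensor: torch.Tensor, dim: int = 0, subsampling: List[int] = None):
    if subsampling is None:
        return tensor
    return tensor.index_select(dim, torch.tensor(subsampling, device=tensor.device))

x = torch.rand(8, 5)                                   # batch of 8 samples
print(subsample_sketch(x, subsampling=[0, 3]).shape)   # torch.Size([2, 5])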
diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py
new file mode 100644
index 000000000..792eda640
--- /dev/null
+++ b/backpack/core/derivatives/rnn.py
@@ -0,0 +1,266 @@
+"""Partial derivatives for the torch.nn.RNN layer."""
+from typing import List, Tuple
+
+from torch import Tensor, cat, einsum, zeros
+from torch.nn import RNN
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
+
+
+class RNNDerivatives(BaseParameterDerivatives):
+ """Partial derivatives for the torch.nn.RNN layer.
+
+ a_t = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh
+ h_t = tanh(a_t)
+
+ We assume the batch axis always comes first.
+
+ Index conventions:
+ ------------------
+ * t: Sequence dimension
+ * v: Free dimension
+ * n: Batch dimension
+ * h: Output dimension
+ * i: Input dimension
+ """
+
+ @staticmethod
+ def _check_parameters(module: RNN) -> None:
+ """Check the parameters of module.
+
+ Args:
+ module: module to check
+
+ Raises:
+ NotImplementedError: If any parameter of module does not match expectation
+ """
+ if not module.batch_first:
+ raise NotImplementedError("Batch axis must be first.")
+ if module.num_layers > 1:
+ raise NotImplementedError("only num_layers = 1 is supported")
+ if not module.nonlinearity == "tanh":
+ raise NotImplementedError("only nonlinearity = tanh is supported")
+ if module.bias is not True:
+ raise NotImplementedError("only bias = True is supported")
+ if not module.dropout == 0:
+ raise NotImplementedError("only dropout = 0 is supported")
+ if module.bidirectional is not False:
+ raise NotImplementedError("only bidirectional = False is supported")
+
+ def hessian_is_zero(self, module: RNN) -> bool: # noqa: D102
+ return False
+
+ @classmethod
+ def _a_jac_t_mat_prod(
+ cls,
+ module: RNN,
+ weight_hh_l0: Tensor,
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Calculates jacobian vector product wrt a.
+
+ Args:
+ module: RNN module
+ weight_hh_l0: weight matrix hidden-to-hidden
+ mat: matrix to multiply
+ subsampling: subsampling
+
+ Returns:
+ jacobian vector product wrt a
+ """
+ V, N, T, H = mat.shape
+ output = subsample(module.output, dim=0, subsampling=subsampling)
+ a_jac_t_mat_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype)
+ for t in reversed(range(T)):
+ if t == (T - 1):
+ a_jac_t_mat_prod[:, :, t] = einsum(
+ "vnh,nh->vnh", mat[:, :, t], 1 - output[:, t] ** 2
+ )
+ else:
+ a_jac_t_mat_prod[:, :, t] = einsum(
+ "vnh,nh->vnh",
+ mat[:, :, t]
+ + einsum(
+ "vng,gh->vnh",
+ a_jac_t_mat_prod[:, :, t + 1],
+ weight_hh_l0,
+ ),
+ 1 - output[:, t] ** 2,
+ )
+ return a_jac_t_mat_prod
+
+ def _jac_t_mat_prod(
+ self,
+ module: RNN,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ self._check_parameters(module)
+ return einsum(
+ f"vnth,hk->v{'nt' if module.batch_first else 'tn'}k",
+ self._a_jac_t_mat_prod(
+ module,
+ module.weight_hh_l0,
+ mat,
+ subsampling,
+ ),
+ module.weight_ih_l0,
+ )
+
+ def _jac_mat_prod(
+ self, module: RNN, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
+ ) -> Tensor:
+ self._check_parameters(module)
+ H: int = module.hidden_size
+ V, N, T, _ = mat.shape
+ _jac_mat_prod: Tensor = zeros(V, N, T, H, device=mat.device, dtype=mat.dtype)
+ for t in range(T):
+ if t == 0:
+ _jac_mat_prod[:, :, t] = einsum(
+ "nh,hi,vni->vnh",
+ 1 - module.output[:, t] ** 2,
+ module.weight_ih_l0,
+ mat[:, :, t],
+ )
+ else:
+ _jac_mat_prod[:, :, t] = einsum(
+ "nh,vnh->vnh",
+ 1 - module.output[:, t] ** 2,
+ einsum(
+ "hi,vni->vnh",
+ module.weight_ih_l0,
+ mat[:, :, t],
+ )
+ + einsum(
+ "hk,vnk->vnh",
+ module.weight_hh_l0,
+ _jac_mat_prod[:, :, t - 1],
+ ),
+ )
+ return _jac_mat_prod
+
+ def _bias_ih_l0_jac_t_mat_prod(
+ self,
+ module: RNN,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Apply transposed Jacobian of the output w.r.t. bias_ih_l0.
+
+ Args:
+ module: extended module
+ g_inp: input gradient
+ g_out: output gradient
+ mat: matrix to multiply
+ sum_batch: Whether to sum along batch axis. Defaults to True.
+ subsampling: Indices of active samples. Defaults to ``None`` (all samples).
+
+ Returns:
+ product
+ """
+ self._check_parameters(module)
+ if sum_batch:
+ dim: List[int] = [1, 2]
+ else:
+ dim: int = 2
+ return self._a_jac_t_mat_prod(
+ module,
+ module.weight_hh_l0,
+ mat,
+ subsampling,
+ ).sum(dim=dim)
+
+ def _bias_hh_l0_jac_t_mat_prod(
+ self,
+ module: RNN,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Apply transposed Jacobian of the output w.r.t. bias_hh_l0.
+
+ Args:
+ module: extended module
+ g_inp: input gradient
+ g_out: output gradient
+ mat: matrix to multiply
+ sum_batch: Whether to sum along batch axis. Defaults to True.
+ subsampling: Indices of active samples. Defaults to ``None`` (all samples).
+
+ Returns:
+ product
+ """
+ return self._bias_ih_l0_jac_t_mat_prod(
+ module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
+ )
+
+ def _weight_ih_l0_jac_t_mat_prod(
+ self,
+ module: RNN,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Apply transposed Jacobian of the output w.r.t. weight_ih_l0.
+
+ Args:
+ module: extended module
+ g_inp: input gradient
+ g_out: output gradient
+ mat: matrix to multiply
+ sum_batch: Whether to sum along batch axis. Defaults to True.
+ subsampling: Indices of active samples. Defaults to ``None`` (all samples).
+
+ Returns:
+ product
+ """
+ self._check_parameters(module)
+ return einsum(
+ f"vnth,ntj->v{'' if sum_batch else 'n'}hj",
+ self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling),
+ subsample(module.input0, dim=0, subsampling=subsampling),
+ )
+
+ def _weight_hh_l0_jac_t_mat_prod(
+ self,
+ module: RNN,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ sum_batch: bool = True,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Apply transposed Jacobian of the output w.r.t. weight_hh_l0.
+
+ Args:
+ module: extended module
+ g_inp: input gradient
+ g_out: output gradient
+ mat: matrix to multiply
+ sum_batch: Whether to sum along batch axis. Defaults to True.
+ subsampling: Indices of active samples. Defaults to ``None`` (all samples).
+
+ Returns:
+ product
+ """
+ self._check_parameters(module)
+ _, N, _, H = mat.shape
+ output = subsample(module.output, dim=0, subsampling=subsampling)
+ single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
+ output_shifted = cat([single_step, output[:, :-1]], dim=1)
+ return einsum(
+ f"vnth,ntk->v{'' if sum_batch else 'n'}hk",
+ self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat, subsampling),
+ output_shifted,
+ )
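# Editor's sketch (not part of the patch): numerical check of the recurrence in
# the RNNDerivatives docstring above, h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh),
# against a single-layer, batch-first nn.RNN with zero initial hidden state.
import torch

torch.manual_seed(0)
N, T, D_in, H = 3, 5, 4, 6
rnn = torch.nn.RNN(D_in, H, batch_first=True)
x = torch.rand(N, T, D_in)

h, outputs = torch.zeros(N, H), []
for t in range(T):
    h = torch.tanh(
        x[:, t] @ rnn.weight_ih_l0.T
        + rnn.bias_ih_l0
        + h @ rnn.weight_hh_l0.T
        + rnn.bias_hh_l0
    )
    outputs.append(h)

assert torch.allclose(torch.stack(outputs, dim=1), rnn(x)[0], atol=1e-6)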
diff --git a/backpack/core/derivatives/scale_module.py b/backpack/core/derivatives/scale_module.py
new file mode 100644
index 000000000..9965a204c
--- /dev/null
+++ b/backpack/core/derivatives/scale_module.py
@@ -0,0 +1,25 @@
+"""Derivatives of ScaleModule (implies Identity)."""
+from typing import List, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Identity
+
+from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.custom_module.scale_module import ScaleModule
+
+
+class ScaleModuleDerivatives(BaseDerivatives):
+ """Derivatives of ScaleModule (implies Identity)."""
+
+ def _jac_t_mat_prod(
+ self,
+ module: Union[ScaleModule, Identity],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ if isinstance(module, Identity):
+ return mat
+ else:
+ return mat * module.weight
diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py
index 33c4a9ceb..b6e1c6852 100644
--- a/backpack/core/derivatives/selu.py
+++ b/backpack/core/derivatives/selu.py
@@ -1,7 +1,11 @@
"""Partial derivatives for the SELU activation function."""
-from torch import exp, le, ones_like, zeros_like
+from typing import List, Tuple
+
+from torch import Tensor, exp, le, ones_like, zeros_like
+from torch.nn import SELU
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class SELUDerivatives(ElementwiseDerivatives):
@@ -10,16 +14,23 @@ class SELUDerivatives(ElementwiseDerivatives):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
"""`SELU''(x) != 0`."""
return False
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: SELU,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""First SELU derivative: `SELU'(x) = scale if x > 0 else scale*alpha*e^x`."""
- non_pos = le(module.input0, 0)
+ input0 = subsample(module.input0, subsampling=subsampling)
+ non_pos = le(input0, 0)
- result = self.scale * ones_like(module.input0)
- result[non_pos] = self.scale * self.alpha * exp(module.input0[non_pos])
+ result = self.scale * ones_like(input0)
+ result[non_pos] = self.scale * self.alpha * exp(input0[non_pos])
return result
diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py
index 32b593a89..d141f8fc8 100644
--- a/backpack/core/derivatives/shape_check.py
+++ b/backpack/core/derivatives/shape_check.py
@@ -1,18 +1,24 @@
-"""
-Helpers to support application of Jacobians to vectors
+"""Helpers to support application of Jacobians to vectors.
+
Helpers to check input and output sizes of Jacobian-matrix products.
"""
import functools
+from typing import Any, Callable
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.utils.subsampling import subsample
###############################################################################
# Utility functions #
###############################################################################
-def add_V_dim(mat):
+def _add_V_dim(mat):
return mat.unsqueeze(0)
-def remove_V_dim(mat):
+def _remove_V_dim(mat):
if mat.shape[0] != 1:
raise RuntimeError(
"Cannot unsqueeze dimension 0. ", "Got tensor of shape {}".format(mat.shape)
@@ -20,8 +26,17 @@ def remove_V_dim(mat):
return mat.squeeze(0)
-def check_shape(mat, like, diff=1):
- """Compare dimension diff,diff+1, ... with dimension 0,1,..."""
+def check_shape(mat: Tensor, like: Tensor, diff: int = 1) -> None:
+ """Compare dimension diff,diff+1, ... with dimension 0,1,...
+
+ Args:
+ mat: matrix
+ like: comparison matrix
+ diff: difference in dimensions. Defaults to 1.
+
+ Raises:
+ RuntimeError: if shape does not fit
+ """
mat_shape = [int(dim) for dim in mat.shape]
like_shape = [int(dim) for dim in like.shape]
@@ -40,80 +55,113 @@ def check_shape(mat, like, diff=1):
def check_same_V_dim(mat1, mat2):
+ """Check whether V dim (first dim) matches.
+
+ Args:
+ mat1: first tensor
+ mat2: second tensor
+
+ Raises:
+ RuntimeError: if V dim (first dim) doesn't match
+ """
V1, V2 = mat1.shape[0], mat2.shape[0]
if V1 != V2:
raise RuntimeError("Number of vectors changed. Got {} and {}".format(V1, V2))
-def check_like(mat, module, name, diff=1, *args, **kwargs):
- return check_shape(mat, getattr(module, name), diff=diff)
+def _check_like(mat, module, name, diff=1, *args, **kwargs):
+ if name in ["output", "input0"] and "subsampling" in kwargs.keys():
+ compare = subsample(
+ getattr(module, name), dim=0, subsampling=kwargs["subsampling"]
+ )
+ else:
+ compare = getattr(module, name)
+
+ return check_shape(mat, compare, diff=diff)
def check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs):
+ """Checks shape, considers sum_batch.
+
+ Args:
+ mat: matrix to multiply
+ module: module
+ name: parameter to operate on: module.name
+ sum_batch: whether to consider with or without sum
+ *args: ignored
+ **kwargs: ignored
+ """
diff = 1 if sum_batch else 2
- return check_shape(mat, getattr(module, name), diff=diff)
+ check_shape(mat, getattr(module, name), diff=diff)
-def same_dim_as(mat, module, name, *args, **kwargs):
+def _same_dim_as(mat, module, name, *args, **kwargs):
return len(mat.shape) == len(getattr(module, name).shape)
###############################################################################
# Decorators for handling vectors as matrix special case #
###############################################################################
-def mat_prod_accept_vectors(mat_prod, vec_criterion):
+def _mat_prod_accept_vectors(
+ mat_prod: Callable[..., Tensor],
+ vec_criterion: Callable[[Tensor, Module, Any, Any], bool],
+) -> Callable[..., Tensor]:
"""Add support for vectors to matrix products.
vec_criterion(mat, module) returns if mat is a vector.
+
+ Args:
+ mat_prod: Function that processes multiple vectors in format of a matrix.
+ vec_criterion: Function that returns true if an input is a single vector
+ that must be formatted into a matrix first before processing.
+
+ Returns:
+ Wrapped ``mat_prod`` function that processes multiple vectors in format of
+ a matrix, and supports vector-shaped inputs which are internally converted
+ to the correct format.
+ Preserves format of input:
+ If the input format is a vector, the output format is a vector.
+ If the input format is a matrix, the output format is a matrix.
"""
@functools.wraps(mat_prod)
- def wrapped_mat_prod_accept_vectors(
+ def _wrapped_mat_prod_accept_vectors(
self, module, g_inp, g_out, mat, *args, **kwargs
):
is_vec = vec_criterion(mat, module, *args, **kwargs)
- mat_in = mat if not is_vec else add_V_dim(mat)
+ mat_in = mat if not is_vec else _add_V_dim(mat)
mat_out = mat_prod(self, module, g_inp, g_out, mat_in, *args, **kwargs)
- mat_out = mat_out if not is_vec else remove_V_dim(mat_out)
+ mat_out = mat_out if not is_vec else _remove_V_dim(mat_out)
return mat_out
- return wrapped_mat_prod_accept_vectors
+ return _wrapped_mat_prod_accept_vectors
# vec criteria
-same_dim_as_output = functools.partial(same_dim_as, name="output")
-same_dim_as_input = functools.partial(same_dim_as, name="input0")
-same_dim_as_weight = functools.partial(same_dim_as, name="weight")
-same_dim_as_bias = functools.partial(same_dim_as, name="bias")
+same_dim_as_output = functools.partial(_same_dim_as, name="output")
+same_dim_as_input = functools.partial(_same_dim_as, name="input0")
+same_dim_as_weight = functools.partial(_same_dim_as, name="weight")
+same_dim_as_bias = functools.partial(_same_dim_as, name="bias")
# decorators for handling vectors
jac_t_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
- vec_criterion=same_dim_as_output,
-)
-
-weight_jac_t_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
- vec_criterion=same_dim_as_output,
-)
-bias_jac_t_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
+ _mat_prod_accept_vectors,
vec_criterion=same_dim_as_output,
)
jac_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
+ _mat_prod_accept_vectors,
vec_criterion=same_dim_as_input,
)
weight_jac_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
+ _mat_prod_accept_vectors,
vec_criterion=same_dim_as_weight,
)
bias_jac_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
+ _mat_prod_accept_vectors,
vec_criterion=same_dim_as_bias,
)
@@ -121,8 +169,20 @@ def wrapped_mat_prod_accept_vectors(
###############################################################################
# Decorators for checking inputs and outputs of mat_prod routines #
###############################################################################
-def mat_prod_check_shapes(mat_prod, in_check, out_check):
- """Check that input and output have correct shapes."""
+def mat_prod_check_shapes(
+ mat_prod: Callable, in_check: Callable, out_check: Callable
+) -> Callable[..., Tensor]:
+ """Check that input and output have correct shapes.
+
+ Args:
+ mat_prod: Function that applies a derivative operator to multiple vectors
+ handed in as a matrix.
+ in_check: Function that checks the input to mat_prod
+ out_check: Function that checks the output to mat_prod
+
+ Returns:
+ Wrapped mat_prod function with input and output checks
+ """
@functools.wraps(mat_prod)
def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs):
@@ -137,16 +197,10 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar
# input/output checker
-shape_like_output = functools.partial(check_like, name="output")
-shape_like_input = functools.partial(check_like, name="input0")
-shape_like_weight = functools.partial(check_like, name="weight")
-shape_like_bias = functools.partial(check_like, name="bias")
-shape_like_weight_with_sum_batch = functools.partial(
- check_like_with_sum_batch, name="weight"
-)
-shape_like_bias_with_sum_batch = functools.partial(
- check_like_with_sum_batch, name="bias"
-)
+shape_like_output = functools.partial(_check_like, name="output")
+shape_like_input = functools.partial(_check_like, name="input0")
+shape_like_weight = functools.partial(_check_like, name="weight")
+shape_like_bias = functools.partial(_check_like, name="bias")
# decorators for shape checking
jac_mat_prod_check_shapes = functools.partial(
@@ -165,18 +219,6 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar
mat_prod_check_shapes, in_check=shape_like_output, out_check=shape_like_input
)
-
-weight_jac_t_mat_prod_check_shapes = functools.partial(
- mat_prod_check_shapes,
- in_check=shape_like_output,
- out_check=shape_like_weight_with_sum_batch,
-)
-bias_jac_t_mat_prod_check_shapes = functools.partial(
- mat_prod_check_shapes,
- in_check=shape_like_output,
- out_check=shape_like_bias_with_sum_batch,
-)
-
###############################################################################
# Wrapper for second-order extensions #
###############################################################################
@@ -185,44 +227,102 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwar
)
residual_mat_prod_accept_vectors = functools.partial(
- mat_prod_accept_vectors,
+ _mat_prod_accept_vectors,
vec_criterion=same_dim_as_input,
)
# TODO Refactor using partials
-def make_hessian_mat_prod_accept_vectors(make_hessian_mat_prod):
+def make_hessian_mat_prod_accept_vectors(
+ make_hessian_mat_prod: Callable,
+) -> Callable[..., Callable[..., Tensor]]:
+ """Accept vectors for hessian_mat_prod.
+
+ Args:
+ make_hessian_mat_prod: Function that creates multiplication routine
+ of a matrix with the module Hessian
+
+ Returns:
+ Wrapped hessian_mat_prod which converts vector-format inputs to a matrix
+ before processing. Preserves format of input.
+ """
+
@functools.wraps(make_hessian_mat_prod)
- def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out):
+ def _wrapped_make_hessian_mat_prod(self, module, g_inp, g_out):
hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out)
- def new_hessian_mat_prod(mat):
- is_vec = same_dim_as(mat, module, "input0")
- mat_in = mat if not is_vec else add_V_dim(mat)
+ def _new_hessian_mat_prod(mat):
+ is_vec = _same_dim_as(mat, module, "input0")
+ mat_in = mat if not is_vec else _add_V_dim(mat)
mat_out = hessian_mat_prod(mat_in)
- mat_out = mat_out if not is_vec else remove_V_dim(mat_out)
+ mat_out = mat_out if not is_vec else _remove_V_dim(mat_out)
return mat_out
- return new_hessian_mat_prod
+ return _new_hessian_mat_prod
+
+ return _wrapped_make_hessian_mat_prod
+
- return wrapped_make_hessian_mat_prod
+def make_hessian_mat_prod_check_shapes(
+ make_hessian_mat_prod: Callable[..., Callable[..., Tensor]],
+) -> Callable[..., Callable[..., Tensor]]:
+ """Wrap hessian_mat_prod with shape checks for input and output.
+ Args:
+ make_hessian_mat_prod: function that creates multiplication routine of
+ a matrix with the module Hessian.
+
+ Returns:
+ wrapped hessian_mat_prod with shape checks for input and output
+ """
-def make_hessian_mat_prod_check_shapes(make_hessian_mat_prod):
@functools.wraps(make_hessian_mat_prod)
- def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out):
+ def _wrapped_make_hessian_mat_prod(self, module, g_inp, g_out):
hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out)
- def new_hessian_mat_prod(mat):
- check_like(mat, module, "input0")
+ def _new_hessian_mat_prod(mat):
+ _check_like(mat, module, "input0")
result = hessian_mat_prod(mat)
- check_like(result, module, "input0")
+ _check_like(result, module, "input0")
return result
- return new_hessian_mat_prod
+ return _new_hessian_mat_prod
+
+ return _wrapped_make_hessian_mat_prod
+
+
+def param_mjp_accept_vectors(mat_prod: Callable[..., Tensor]) -> Callable[..., Tensor]:
+ """Add support for vectors to matrix products.
+
+ ``mat`` counts as a vector if it has the same dimensionality as the module output.
+
+ Args:
+ mat_prod: Function that processes multiple vectors in format of a matrix.
+
+ Returns:
+ Wrapped ``mat_prod`` function that processes multiple vectors in format of
+ a matrix, and supports vector-shaped inputs which are internally converted
+ to the correct format.
+ Preserves format of input:
+ If the input format is a vector, the output format is a vector.
+ If the input format is a matrix, the output format is a matrix.
+ """
+
+ @functools.wraps(mat_prod)
+ def _wrapped_mat_prod_accept_vectors(
+ self, param_str, module, g_inp, g_out, mat, *args, **kwargs
+ ):
+ is_vec = same_dim_as_output(mat, module)
+ mat_in = mat if not is_vec else _add_V_dim(mat)
+ mat_out = mat_prod(
+ self, param_str, module, g_inp, g_out, mat_in, *args, **kwargs
+ )
+ mat_out = mat_out if not is_vec else _remove_V_dim(mat_out)
+
+ return mat_out
- return wrapped_make_hessian_mat_prod
+ return _wrapped_mat_prod_accept_vectors
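# Editor's illustration (not part of the patch) of the vector/matrix convention
# the decorators above implement: a vector input temporarily gains a leading V
# dimension before the matrix routine runs and loses it again afterwards.
import torch

def mat_prod(mat):  # toy stand-in for a jac_t_mat_prod routine, expects [V, ...]
    return 2.0 * mat

def accept_vectors(mat, like):
    is_vec = mat.dim() == like.dim()
    mat_in = mat.unsqueeze(0) if is_vec else mat
    mat_out = mat_prod(mat_in)
    return mat_out.squeeze(0) if is_vec else mat_out

output = torch.rand(3, 4)                                 # module output
print(accept_vectors(torch.rand(3, 4), output).shape)     # torch.Size([3, 4])
print(accept_vectors(torch.rand(5, 3, 4), output).shape)  # torch.Size([5, 3, 4])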
diff --git a/backpack/core/derivatives/sigmoid.py b/backpack/core/derivatives/sigmoid.py
index 5b45a114b..f03573e57 100644
--- a/backpack/core/derivatives/sigmoid.py
+++ b/backpack/core/derivatives/sigmoid.py
@@ -1,14 +1,28 @@
+"""Partial derivatives for the Sigmoid activation function."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Sigmoid
+
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class SigmoidDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
"""`σ''(x) ≠ 0`."""
return False
- def df(self, module, g_inp, g_out):
+ def df(
+ self,
+ module: Sigmoid,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
"""First sigmoid derivative: `σ'(x) = σ(x) (1 - σ(x))`."""
- return module.output * (1.0 - module.output)
+ output = subsample(module.output, subsampling=subsampling)
+ return output * (1.0 - output)
def d2f(self, module, g_inp, g_out):
"""Second sigmoid derivative: `σ''(x) = σ(x) (1 - σ(x)) (1 - 2 σ(x))`."""
diff --git a/backpack/core/derivatives/sum_module.py b/backpack/core/derivatives/sum_module.py
new file mode 100644
index 000000000..e8383fe6b
--- /dev/null
+++ b/backpack/core/derivatives/sum_module.py
@@ -0,0 +1,21 @@
+"""Contains derivatives for SumModule."""
+from typing import List, Tuple
+
+from torch import Tensor
+
+from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.custom_module.branching import SumModule
+
+
+class SumModuleDerivatives(BaseDerivatives):
+ """Contains derivatives for SumModule."""
+
+ def _jac_t_mat_prod(
+ self,
+ module: SumModule,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ return mat
diff --git a/backpack/core/derivatives/tanh.py b/backpack/core/derivatives/tanh.py
index 525cb3aa2..1fd6c9c14 100644
--- a/backpack/core/derivatives/tanh.py
+++ b/backpack/core/derivatives/tanh.py
@@ -1,12 +1,26 @@
+"""Partial derivatives for the Tanh activation function."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Tanh
+
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
+from backpack.utils.subsampling import subsample
class TanhDerivatives(ElementwiseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return False
- def df(self, module, g_inp, g_out):
- return 1.0 - module.output ** 2
+ def df(
+ self,
+ module: Tanh,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ output = subsample(module.output, subsampling=subsampling)
+ return 1.0 - output ** 2
def d2f(self, module, g_inp, g_out):
return -2.0 * module.output * (1.0 - module.output ** 2)
diff --git a/backpack/core/derivatives/zeropad2d.py b/backpack/core/derivatives/zeropad2d.py
index 197566461..07af6c95e 100644
--- a/backpack/core/derivatives/zeropad2d.py
+++ b/backpack/core/derivatives/zeropad2d.py
@@ -1,11 +1,15 @@
+"""Partial derivatives for the ZeroPad2d function."""
+from typing import List, Tuple
+
from einops import rearrange
-from torch.nn import functional
+from torch import Tensor
+from torch.nn import ZeroPad2d, functional
from backpack.core.derivatives.basederivatives import BaseDerivatives
class ZeroPad2dDerivatives(BaseDerivatives):
- def hessian_is_zero(self):
+ def hessian_is_zero(self, module):
return True
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
@@ -27,7 +31,14 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
return result.view(in_features, in_features)
- def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
+ def _jac_t_mat_prod(
+ self,
+ module: ZeroPad2d,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ mat: Tensor,
+ subsampling: List[int] = None,
+ ) -> Tensor:
(W_top, W_bottom), (H_bottom, H_top) = self.__unpad_indices(module)
return mat[:, :, :, W_top:W_bottom, H_bottom:H_top]
diff --git a/backpack/custom_module/__init__.py b/backpack/custom_module/__init__.py
new file mode 100644
index 000000000..1b3c52f2b
--- /dev/null
+++ b/backpack/custom_module/__init__.py
@@ -0,0 +1,4 @@
+"""This package adds torch.nn.Module type modules.
+
+These are used as utilities.
+"""
diff --git a/backpack/custom_module/branching.py b/backpack/custom_module/branching.py
new file mode 100644
index 000000000..109b88c87
--- /dev/null
+++ b/backpack/custom_module/branching.py
@@ -0,0 +1,113 @@
+"""Emulating branching with modules."""
+from typing import Any, OrderedDict, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+
+class _Branch(Module):
+ """Module used by BackPACK to handle branching in the computation graph.
+
+ ↗ module1 → output1
+ input → module2 → output2
+ ↘ ... → ...
+ """
+
+ def __init__(self, *args: Union[OrderedDict[str, Module], Module]):
+ """Use interface of ``torch.nn.Sequential``. Modules are parallel sequence.
+
+ Args:
+ args: either an ordered dictionary of modules or a tuple of modules
+ """
+ super().__init__()
+
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+
+ def forward(self, input: Tensor) -> Tuple[Any, ...]:
+ """Feed input through each child module.
+
+ Args:
+ input: input tensor
+
+ Returns:
+ tuple of the outputs of all child modules
+ """
+ return tuple(module(input) for module in self.children())
+
+
+class SumModule(Module):
+ """Module used by BackPACK to handle branch merges in the computation graph.
+
+ module 1 ↘
+ module 2 → SumModule (sum)
+ ... ↗
+ """
+
+ def forward(self, *input: Tensor) -> Tensor:
+ """Sum up all inputs (a tuple of tensors).
+
+ Args:
+ input: tuple of input tensors
+
+ Returns:
+ sum of all inputs
+
+ Raises:
+ AssertionError: if input is not a tuple of tensors with matching shapes
+ """
+ if not isinstance(input, tuple):
+ raise AssertionError(f"Expecting tuple as input. Got {input.__class__}")
+ elif not all(isinstance(inp, Tensor) for inp in input):
+ raise AssertionError(
+ f"Expecting tuple of tensors, but received ({[inp.__class__ for inp in input]})"
+ )
+ elif not all(input[0].shape == input[i].shape for i in range(1, len(input))):
+ raise AssertionError(f"Shapes don't match: {[inp.shape for inp in input]}.")
+ else:
+ return sum(input)
+
+
+class Parallel(Module):
+ """Feed the same input through a parallel sequence of modules. Sum the results.
+
+ Used by BackPACK to emulate branched computations.
+
+ ↗ module 1 ↘
+ Branch → module 2 → SumModule (sum)
+ ↘ ... ↗
+ """
+
+ def __init__(
+ self,
+ *args: Union[OrderedDict[str, Module], Module],
+ merge_module: Module = None,
+ ):
+ """Like ``torch.nn.Sequential``, but defines a parallel module sequence.
+
+ Use interface of ``torch.nn.Sequential``.
+
+ Args:
+ args: either ordered dictionary of modules or tuple of modules
+ merge_module: The module used for merging. Defaults to ``None``, which
+ means ``SumModule()`` is used.
+ """
+ super().__init__()
+
+ self.branch = _Branch(*args)
+ self.merge = SumModule() if merge_module is None else merge_module
+
+ def forward(self, input: Tensor) -> Tensor:
+ """Forward pass. Concatenation of Branch and SumModule.
+
+ Args:
+ input: module input
+
+ Returns:
+ Merged results from forward pass of each branch
+ """
+ return self.merge(*self.branch(input))
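# Editor's usage sketch, assuming this patch is installed: a basic residual
# connection emulated with the modules above. One branch is a Linear layer, the
# other the identity, and SumModule merges the two outputs.
import torch
from torch.nn import Identity, Linear

from backpack.custom_module.branching import Parallel

linear = Linear(10, 10)
block = Parallel(linear, Identity())
x = torch.rand(8, 10)
assert torch.allclose(block(x), linear(x) + x)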
diff --git a/backpack/custom_module/graph_utils.py b/backpack/custom_module/graph_utils.py
new file mode 100644
index 000000000..62bc0ea03
--- /dev/null
+++ b/backpack/custom_module/graph_utils.py
@@ -0,0 +1,573 @@
+"""Transformation tools to make graph BackPACK compatible."""
+from copy import deepcopy
+from typing import Tuple, Union
+from warnings import warn
+
+from torch.fx import Graph, GraphModule, Node, Tracer
+from torch.nn import LSTM, RNN, Dropout, Flatten, Module, Sequential
+
+from backpack.custom_module.branching import SumModule, _Branch
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.reduce_tuple import ReduceTuple
+from backpack.custom_module.scale_module import ScaleModule
+
+
+class BackpackTracer(Tracer):
+ """Tracer that recognizes BackPACK's custom modules as 'leaf modules'."""
+
+ def is_leaf_module(
+ self, m: Module, module_qualified_name: str
+ ) -> bool: # noqa: D102
+ if isinstance(m, (ScaleModule, SumModule, _Branch, ReduceTuple, Permute)):
+ return True
+ else:
+ return super().is_leaf_module(m, module_qualified_name)
+
+
+def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule:
+ """Convert all modules to BackPACK-compatible modules.
+
+ Transformations:
+ - mul -> ScaleModule
+ - add -> SumModule
+ - flatten -> nn.Flatten
+ - getitem -> ReduceTuple
+ - permute -> Permute
+ - transpose -> Permute
+ - multi-layer RNN/LSTM: split into single-layer modules
+ - in-place -> standard operations
+ - remove duplicate modules
+ - delete unused modules
+ - check BackPACK compatibility
+
+ Args:
+ module: module to convert
+ debug: if True prints to command line
+
+ Returns:
+ BackPACK-compatible module
+ """
+ if debug:
+ print("\nMake module BackPACK-compatible...")
+ module_new = _transform_mul_to_scale_module(module, debug)
+ module_new = _transform_flatten_to_module(module_new, debug)
+ module_new = _transform_add_to_sum_module(module_new, debug)
+ module_new = _transform_get_item_to_module(module_new, debug)
+ module_new = _transform_permute_to_module(module_new, debug)
+ module_new = _transform_transpose_to_module(module_new, debug)
+ module_new = _transform_lstm_rnn(module_new, debug)
+ _transform_inplace_to_normal(module_new, debug)
+ module_new = _transform_remove_duplicates(module_new, debug)
+ if debug:
+ print("\tDelete unused modules.")
+ module_new.delete_all_unused_submodules()
+ _check_backpack_compatible(module_new, debug)
+ if debug:
+ print("Finished transformation.\n")
+ return module_new
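# Editor's usage sketch, assuming this patch is installed (torch >= 1.9 for
# torch.fx). A model whose forward pass uses "+" for a residual connection is
# rewritten so that the addition becomes a SumModule; the converted GraphModule
# computes the same output. In practice the conversion is typically triggered
# through backpack.extend(..., use_converter=True) rather than called directly.
import torch

from backpack.custom_module.graph_utils import convert_module_to_backpack

class TinyResNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.linear(x))  # "+" -> SumModule after conversion

original = TinyResNet()
converted = convert_module_to_backpack(original, debug=False)
x = torch.rand(8, 10)
assert torch.allclose(converted(x), original(x))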
+
+
+def _check_backpack_compatible(module: Module, debug: bool) -> None:
+ """Checks whether the computation graph of the given module is BackPACK compatible.
+
+ More specifically, it checks whether all nodes are either input/output
+ or a call to a module. Whether the individual modules are themselves supported
+ by BackPACK can only be verified by running the extension.
+
+ Args:
+ module: module to check
+ debug: whether to print debug messages
+ """
+ if debug:
+ print("\tChecking BackPACK compatibility.")
+ graph: Graph = BackpackTracer().trace(module)
+ for node in graph.nodes:
+ if node.op not in ["call_module", "placeholder", "output"]:
+ warn(
+ f"Encountered node that may break second-order extensions: op={node.op}"
+ f", target={node.target}. If you encounter this problem, please open an"
+ " issue at https://github.com/f-dangel/backpack/issues."
+ )
+
+
+def _transform_mul_to_scale_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms multiplications of tensor with float to ScaleModule.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+
+ Raises:
+ RuntimeError: if a multiplication is found but node.args are not (float, Node)
+ """
+ target_function = ""
+ target_method = "multiply"
+ if debug:
+ print(f"\tBegin transformation: {target_function} -> ScaleModule")
+
+ graph: Graph = BackpackTracer().trace(module)
+ nodes_function = [
+ n
+ for n in graph.nodes
+ if n.op == "call_function" and str(n.target) == target_function
+ ]
+ nodes_method = [
+ n
+ for n in graph.nodes
+ if n.op == "call_method" and str(n.target) == target_method
+ ]
+
+ for node in nodes_function:
+ if len(node.args) != 2:
+ raise RuntimeError(f"Expecting 2 arguments, got {len(node.args)}.")
+
+ idx_weight = 0 if isinstance(node.args[0], float) else 1
+ idx_tensor = 1 - idx_weight
+
+ weight = node.args[idx_weight]
+ tensor = node.args[idx_tensor]
+
+ if not (isinstance(weight, float) and isinstance(tensor, Node)):
+ raise RuntimeError(
+ f"Expecting types [float, Node], got {[type(weight), type(tensor)]}."
+ )
+
+ _change_node_to_module(
+ node, "scale_module", module, ScaleModule(weight), (tensor,)
+ )
+ for node in nodes_method:
+ _change_node_to_module(
+ node, "scale_module", module, ScaleModule(node.args[1]), (node.args[0],)
+ )
+
+ graph.lint()
+
+ if debug:
+ print(f"\tMultiplications transformed: {len(nodes_function)+len(nodes_method)}")
+
+ return GraphModule(module, graph)
+
+
+def _transform_add_to_sum_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms summations of tensors to SumModule (useful in ResNets).
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+ """
+ target_function = ""
+ target_method = "add"
+ if debug:
+ print(f"\tBegin transformation: {target_function} -> SumModule")
+
+ graph: Graph = BackpackTracer().trace(module)
+ nodes = [
+ n
+ for n in graph.nodes
+ if (n.op == "call_function" and str(n.target) == target_function)
+ or (n.op == "call_method" and str(n.target) == target_method)
+ ]
+
+ for node in nodes:
+ _change_node_to_module(node, "sum_module", module, SumModule(), node.args)
+
+ graph.lint()
+
+ if debug:
+ print(f"\tSummations transformed: {len(nodes)}")
+
+ return GraphModule(module, graph)
+
+
+def _transform_flatten_to_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms PyTorch's flatten method to the nn.Flatten module.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+ """
+ target_function = " Flatten")
+
+ graph: Graph = BackpackTracer().trace(module)
+ nodes = [
+ n
+ for n in graph.nodes
+ if (n.op == "call_function" and target_function in str(n.target))
+ or (n.op == "call_method" and target_method == str(n.target))
+ ]
+
+ for node in nodes:
+ start_dim = node.args[1] if len(node.args) > 1 else 0
+ end_dim = node.args[2] if len(node.args) > 2 else -1
+ _change_node_to_module(
+ node, "flatten", module, Flatten(start_dim, end_dim), (node.args[0],)
+ )
+
+ graph.lint()
+
+ if debug:
+ print(f"\tFlatten functions transformed: {len(nodes)}")
+
+ return GraphModule(module, graph)
+
+
+def _transform_get_item_to_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms the built-in getitem function to ReduceTuple module.
+
+ This function is usually used to reduce the tuple output of RNNs.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+ """
+ target = ""
+ if debug:
+ print(f"\tBegin transformation: {target} -> ReduceTuple")
+ graph: Graph = BackpackTracer().trace(module)
+
+ nodes = [
+ n for n in graph.nodes if n.op == "call_function" and target == str(n.target)
+ ]
+ for node in nodes:
+ _change_node_to_module(
+ node,
+ "reduce_tuple",
+ module,
+ ReduceTuple(index=node.args[1]),
+ (node.args[0],),
+ )
+
+ graph.lint()
+ if debug:
+ print(f"\tReduceTuple transformed: {len(nodes)}")
+ return GraphModule(module, graph)
+
+
+def _transform_permute_to_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms permute function or method to Permute module.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+ """
+ target1 = "permute"
+ target2 = " Permute")
+ graph: Graph = BackpackTracer().trace(module)
+
+ nodes = [
+ n
+ for n in graph.nodes
+ if (n.op == "call_function" and target2 in str(n.target))
+ or (n.op == "call_method" and target1 == str(n.target))
+ ]
+
+ for node in nodes:
+ _change_node_to_module(
+ node,
+ "permute",
+ module,
+ Permute(*node.args[1]) if len(node.args) == 2 else Permute(*node.args[1:]),
+ (node.args[0],),
+ )
+
+ graph.lint()
+ if debug:
+ print(f"\tPermute transformed: {len(nodes)}")
+ return GraphModule(module, graph)
+
+
+def _transform_transpose_to_module(module: Module, debug: bool) -> GraphModule:
+ """Transforms transpose function or method to Permute module.
+
+ The Permute module is initialized with transpose parameters and computes
+ the permutation on its first forward pass.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+ """
+ target_function = " Permute")
+ graph: Graph = BackpackTracer().trace(module)
+
+ nodes = [
+ n
+ for n in graph.nodes
+ if (n.op == "call_function" and target_function in str(n.target))
+ or (n.op == "call_method" and target_method == str(n.target))
+ ]
+
+ for node in nodes:
+ _change_node_to_module(
+ node,
+ "permute",
+ module,
+ Permute(*node.args[1:], init_transpose=True),
+ (node.args[0],),
+ )
+
+ graph.lint()
+ if debug:
+ print(f"\tPermute transformed: {len(nodes)}")
+ return GraphModule(module, graph)
+
+
+def _transform_lstm_rnn(module: Module, debug: bool) -> GraphModule:
+ """Transforms multi-layer RNN/LSTM to Sequential of single-layer RNN/LSTM.
+
+ Converts multi-layer RNN/LSTM to Sequential with single-layer RNN/LSTM.
+ If dropout probability is nonzero, creates intermediate dropout layers.
+ Finally, copies training mode.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+
+ Raises:
+ NotImplementedError: if initial hidden state is used in forward pass
+ """
+ if debug:
+ print("\tBegin transformation: LSTM, RNN")
+ graph: Graph = BackpackTracer().trace(module)
+
+ nodes = [
+ n
+ for n in graph.nodes
+ if n.op == "call_module"
+ and isinstance(module.get_submodule(n.target), (RNN, LSTM))
+ and module.get_submodule(n.target).num_layers > 1
+ ]
+ for node in nodes:
+ if len(node.args) > 1:
+ raise NotImplementedError(
+ "For conversion, LSTM/RNN input must not have hidden states."
+ )
+ lstm_module_replace = _make_rnn_backpack(module.get_submodule(node.target))
+ module.add_module(node.target, lstm_module_replace)
+
+ graph.lint()
+ if debug:
+ print(f"\tRNNs, LSTMs transformed: {len(nodes)}")
+ return GraphModule(module, graph)
+
+
+def _rnn_hyperparams(module: Union[RNN, LSTM]) -> Tuple[int, int, float, bool]:
+ """Determines the hyperparameters for RNN/LSTM conversion.
+
+ Args:
+ module: module to convert
+
+ Returns:
+ input_size, hidden_size, dropout, batch_first
+
+ Raises:
+ NotImplementedError: if any hyperparameter has a forbidden value
+ """
+ if module.bias is not True:
+ raise NotImplementedError("only bias = True is supported")
+ if module.bidirectional is not False:
+ raise NotImplementedError("only bidirectional = False is supported")
+ if isinstance(module, RNN):
+ if module.nonlinearity != "tanh":
+ raise NotImplementedError("only nonlinearity = 'tanh' is supported")
+ elif isinstance(module, LSTM):
+ if module.proj_size != 0:
+ raise NotImplementedError("only proj_size = 0 is supported")
+ return module.input_size, module.hidden_size, module.dropout, module.batch_first
+
+
+def _make_rnn_backpack(module: Union[RNN, LSTM]) -> Module:
+ """Creates an equivalent module to the multi-layer RNN/LSTM.
+
+ Converts multi-layer RNN/LSTM to Sequential with single-layer RNN/LSTM.
+ If dropout probability is nonzero, creates intermediate dropout layers.
+ Finally, copies training mode.
+
+ Args:
+ module: RNN/LSTM module to convert
+
+ Returns:
+ equivalent Sequential module
+ """
+ input_size, hidden_size, dropout, batch_first = _rnn_hyperparams(module)
+ rnn_class = type(module)
+ rnn_module_replace = Sequential()
+ for layer in range(module.num_layers):
+ rnn_layer = rnn_class(
+ input_size if layer == 0 else hidden_size,
+ hidden_size,
+ batch_first=batch_first,
+ )
+ for param_str in ["weight_ih_l", "weight_hh_l", "bias_ih_l", "bias_hh_l"]:
+ setattr(rnn_layer, f"{param_str}0", getattr(module, f"{param_str}{layer}"))
+ rnn_module_replace.add_module(f"lstm_{layer}", rnn_layer)
+ if layer != (module.num_layers - 1):
+ rnn_module_replace.add_module(f"reduce_tuple_{layer}", ReduceTuple())
+ if dropout != 0:
+ rnn_module_replace.add_module(f"dropout_{layer}", Dropout(dropout))
+ rnn_module_replace.train(module.training)
+ return rnn_module_replace
+
+
+def _transform_inplace_to_normal(
+ module: Module, debug: bool, initialize_recursion: bool = True
+) -> None:
+ """Searches for in-place operations and changes them to standard operations.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+ initialize_recursion: whether this is the initial call to this function.
+ """
+ if initialize_recursion:
+ if debug:
+ print("\tBegin transformation: in-place -> standard")
+ _transform_inplace_to_normal.counter = 0
+ if hasattr(module, "inplace") and module.inplace:
+ module.inplace = False
+ _transform_inplace_to_normal.counter += 1
+ for child_module in module.children():
+ _transform_inplace_to_normal(child_module, debug, initialize_recursion=False)
+
+ if initialize_recursion:
+ if debug:
+ print(f"\tIn-place changed: {_transform_inplace_to_normal.counter}")
+ del _transform_inplace_to_normal.counter
+
+
+def _transform_remove_duplicates(module: GraphModule, debug: bool) -> GraphModule:
+ """Removes duplicate modules by creating a copy of the module.
+
+ This is necessary because BackPACK saves input/output which is overwritten
+ if the module is called multiple times.
+
+ Args:
+ module: container module to transform
+ debug: whether to print debug messages
+
+ Returns:
+ equivalent transformed module
+
+ Raises:
+ NotImplementedError: if a duplicate module has parameters
+ """
+ if debug:
+ print("\tBegin transformation: remove duplicates")
+
+ graph: Graph = BackpackTracer().trace(module)
+
+ targets = [n.target for n in graph.nodes]
+ duplicates = {t for t in targets if targets.count(t) > 1}
+ nodes = [n for n in graph.nodes if n.target in duplicates]
+
+ for node in nodes:
+ target = node.target
+ original_module = module.get_submodule(target)
+
+ for _ in original_module.parameters():
+ raise NotImplementedError(
+ f"Cycle with parameters detected: module {original_module} with target"
+ f" {target} has parameters and is used {targets.count(target)} times."
+ )
+
+ new_module = deepcopy(original_module)
+ new_target = _get_free_name(module, target)
+ module.add_submodule(new_target, new_module)
+ node.target = new_target
+
+ graph.lint()
+
+ if debug:
+ print(f"\tDuplicates removed: {len(nodes)}")
+
+ return GraphModule(module, graph)
+
+
+def _change_node_to_module(
+ node: Node,
+ name: str,
+ base_module: Module,
+ new_module: Module,
+ args: tuple,
+) -> None:
+ """Helper function to change an existing node to a module.
+
+ The new module is registered in the base_module as a submodule.
+ The attribute name is based on name{int}.
+ The attributes of the node are changed so that they point to the new module.
+
+ Args:
+ node: existing node
+ name: proposed name, real name is name{int}
+ base_module: the module that should get new_module as a child
+ new_module: the new module to register on the node and base_module
+ args: arguments of the new node
+ """
+ new_name = _get_free_name(base_module, name)
+ node.op = "call_module"
+ node.target = new_name
+ node.args = args
+ setattr(base_module, new_name, new_module)
+
+
+def _get_free_name(module: Module, initial_name: str) -> str:
+ """Find a free name in the module's naming space.
+
+ Args:
+ module: the parent module
+ initial_name: a name suggestion
+
+ Returns:
+ a string with the pattern {initial_name}{int} where module has no such attribute
+
+ Raises:
+ RuntimeError: if the module already has an attribute with the intended name
+ """
+
+ def _has_target(target: str) -> bool:
+ try:
+ module.get_submodule(target)
+ return True
+ except AttributeError:
+ return False
+
+ counter = 0
+ while _has_target(f"{initial_name}{counter}"):
+ counter += 1
+ name = f"{initial_name}{counter}"
+
+ if hasattr(module, name):
+ raise RuntimeError(
+ "Unable to find a free name for registering a new module. "
+ f"module={module} already has an attribute named {name}."
+ )
+
+ return name
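
The node rewriting above is easier to follow with a self-contained torch.fx sketch. The snippet below mimics what _change_node_to_module and _get_free_name do for a permute call; Toy, MyPermute, and pick_free_name are illustrative stand-ins, not BackPACK code.

from torch import nn, rand
from torch.fx import GraphModule, symbolic_trace

class Toy(nn.Module):  # hypothetical example model
    def forward(self, x):
        return x.permute(0, 2, 1)

class MyPermute(nn.Module):  # stand-in for the Permute module of this patch
    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)

def pick_free_name(module, initial_name):
    # mirrors _get_free_name: append a counter until the name is unused
    counter = 0
    while hasattr(module, f"{initial_name}{counter}"):
        counter += 1
    return f"{initial_name}{counter}"

model = Toy()
graph = symbolic_trace(model).graph

for node in graph.nodes:
    if node.op == "call_method" and str(node.target) == "permute":
        # same steps as _change_node_to_module: register submodule, repoint node
        name = pick_free_name(model, "permute")
        setattr(model, name, MyPermute(*node.args[1:]))
        node.op = "call_module"
        node.target = name
        node.args = (node.args[0],)

graph.lint()
transformed = GraphModule(model, graph)
assert transformed(rand(2, 3, 4)).shape == (2, 4, 3)
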
diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py
new file mode 100644
index 000000000..3213ccd54
--- /dev/null
+++ b/backpack/custom_module/permute.py
@@ -0,0 +1,62 @@
+"""Module containing Permute module."""
+from typing import Any
+
+from torch import Tensor
+from torch.nn import Module
+
+
+class Permute(Module):
+ """Module to permute a tensor."""
+
+ def __init__(self, *dims: Any, init_transpose: bool = False):
+ """Initialization.
+
+ This module supports two variants: permutation and transposition.
+ If transposition should be used, a tuple (axis1, axis2) should be provided and
+ init_transpose must be True.
+ Internally, this is converted to a permutation in the first forward pass.
+
+ Args:
+ dims: The desired ordering of dimensions.
+ init_transpose: Whether dims are transpose parameters (axis1, axis2) rather than a full permutation. Default: False.
+ """
+ super().__init__()
+ self.dims = dims
+ self.init_transpose = init_transpose
+ self._enforce_batch_axis_first()
+
+ def forward(self, input: Tensor) -> Tensor:
+ """Permutes the input tensor.
+
+ Args:
+ input: input tensor
+
+ Returns:
+ view with new ordering
+ """
+ self._convert_transpose_to_permute(input)
+ return input.permute(self.dims)
+
+ def _convert_transpose_to_permute(self, input: Tensor):
+ """Converts the parameters of transpose to a permutation.
+
+ Args:
+ input: input tensor. Used to determine number of dimensions.
+ """
+ if self.init_transpose:
+ permutation = list(range(input.dim()))
+ permutation[self.dims[0]] = self.dims[1]
+ permutation[self.dims[1]] = self.dims[0]
+ self.dims = tuple(permutation)
+ self.init_transpose = False
+
+ def _enforce_batch_axis_first(self) -> None:
+ batch_first = False
+ if self.init_transpose:
+ if 0 not in self.dims:
+ batch_first = True
+ else:
+ if self.dims[0] == 0:
+ batch_first = True
+ if not batch_first:
+ raise ValueError("Permute: Batch axis must be left unchanged!")
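
A short usage sketch for Permute (illustrative, not part of the patch), covering both the plain permutation and the transpose flavour that is converted on the first forward pass:

from torch import allclose, rand

from backpack.custom_module.permute import Permute

x = rand(8, 5, 3)  # batch axis first

permute = Permute(0, 2, 1)  # explicit permutation
assert permute(x).shape == (8, 3, 5)

transpose = Permute(1, 2, init_transpose=True)  # two axes, like Tensor.transpose
y = transpose(x)
assert transpose.dims == (0, 2, 1)  # filled in during the first forward pass
assert allclose(y, x.transpose(1, 2))
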
diff --git a/backpack/custom_module/reduce_tuple.py b/backpack/custom_module/reduce_tuple.py
new file mode 100644
index 000000000..02fa9f5cc
--- /dev/null
+++ b/backpack/custom_module/reduce_tuple.py
@@ -0,0 +1,29 @@
+"""Module containing ReduceTuple module."""
+from typing import Union
+
+from torch import Tensor
+from torch.nn import Module
+
+
+class ReduceTuple(Module):
+ """Module reducing tuple input."""
+
+ def __init__(self, index: int = 0):
+ """Initialization.
+
+ Args:
+ index: which element to choose
+ """
+ super().__init__()
+ self.index = index
+
+ def forward(self, input: tuple) -> Union[tuple, Tensor]:
+ """Reduces the tuple.
+
+ Args:
+ input: the tuple of data
+
+ Returns:
+ the selected element
+ """
+ return input[self.index]
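
In the converted networks above, ReduceTuple typically sits right after an RNN/LSTM and picks the output sequence out of the (output, hidden) tuple, so that the next layer receives a plain tensor. A minimal sketch with illustrative layer sizes:

from torch import nn, rand

from backpack.custom_module.reduce_tuple import ReduceTuple

seq = nn.Sequential(
    nn.LSTM(input_size=4, hidden_size=6, batch_first=True),
    ReduceTuple(index=0),  # keep the output sequence, drop (h_n, c_n)
    nn.Linear(6, 2),
)

x = rand(8, 10, 4)  # [batch, time, features]
assert seq(x).shape == (8, 10, 2)
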
diff --git a/backpack/custom_module/scale_module.py b/backpack/custom_module/scale_module.py
new file mode 100644
index 000000000..2ee03e1a3
--- /dev/null
+++ b/backpack/custom_module/scale_module.py
@@ -0,0 +1,32 @@
+"""Contains ScaleModule."""
+from torch import Tensor
+from torch.nn import Module
+
+
+class ScaleModule(Module):
+ """Scale Module scales the input by a constant."""
+
+ def __init__(self, weight: float = 1.0):
+ """Store scalar weight.
+
+ Args:
+ weight: Initial value for weight. Defaults to 1.0.
+
+ Raises:
+ ValueError: if weight is not a float
+ """
+ super().__init__()
+ if not isinstance(weight, float):
+ raise ValueError("Weight must be float.")
+ self.weight: float = weight
+
+ def forward(self, input: Tensor) -> Tensor:
+ """Defines forward pass.
+
+ Args:
+ input: input
+
+ Returns:
+ product of input and weight
+ """
+ return input * self.weight
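
ScaleModule just multiplies its input by a fixed float; a quick check (illustrative, not part of the patch):

from torch import allclose, rand

from backpack.custom_module.scale_module import ScaleModule

scale = ScaleModule(weight=0.5)
x = rand(4, 3)
assert allclose(scale(x), 0.5 * x)  # forward returns input * weight
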
diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py
index a84a64f71..df0bd558d 100644
--- a/backpack/extensions/__init__.py
+++ b/backpack/extensions/__init__.py
@@ -13,6 +13,8 @@
DiagGGNExact,
DiagGGNMC,
DiagHessian,
+ SqrtGGNExact,
+ SqrtGGNMC,
)
__all__ = [
@@ -33,4 +35,6 @@
"BatchDiagGGNMC",
"DiagHessian",
"BatchDiagHessian",
+ "SqrtGGNExact",
+ "SqrtGGNMC",
]
diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py
index e32d12185..84e12d254 100644
--- a/backpack/extensions/backprop_extension.py
+++ b/backpack/extensions/backprop_extension.py
@@ -1,19 +1,23 @@
"""Implements the backpropagation mechanism."""
+from __future__ import annotations
+
+import abc
import warnings
-from typing import Type
+from abc import ABC
+from typing import Any, Dict, List, Tuple, Type, Union
-import torch.nn
-from torch.nn import Sequential
+from torch import Tensor
+from torch.nn import Module
from backpack.extensions.module_extension import ModuleExtension
-from backpack.utils.hooks import no_op
+from backpack.extensions.saved_quantities import SavedQuantities
FAIL_ERROR = "ERROR"
-FAIL_WARN = "WARN"
+FAIL_WARN = "WARNING"
FAIL_SILENT = "SILENT"
-class BackpropExtension:
+class BackpropExtension(ABC):
"""Base class for the BackPACK extensions.
Descendants of this class need to
@@ -30,28 +34,41 @@ class BackpropExtension:
```
"""
- def __init__(self, savefield, module_exts, fail_mode=FAIL_ERROR):
+ def __init__(
+ self,
+ savefield: str,
+ module_exts: Dict[Type[Module], ModuleExtension],
+ fail_mode: str = FAIL_ERROR,
+ subsampling: List[int] = None,
+ ):
"""Initializes parameters.
Args:
- savefield(str): Where to save results
- module_exts(dict): Maps module classes to `ModuleExtension` instances
- fail_mode(str, optional): Behavior when encountering an unknown layer.
+ savefield: Where to save results
+ module_exts: Maps module classes to `ModuleExtension` instances
+ fail_mode: Behavior when encountering an unknown layer.
Can be
- "ERROR": raise a NotImplementedError
- "WARN": raise a UserWarning
- "SILENT": skip the module silently
Defaults to FAIL_ERROR = "ERROR"
+ subsampling: Indices of active mini-batch samples. ``None`` means
+ all samples in the mini-batch will be considered by the extension.
+ Defaults to ``None``.
+
+ Raises:
+ AssertionError: if fail_mode is not valid
"""
- self.savefield = savefield
- self.__module_extensions = module_exts
- self.__fail_mode = fail_mode
+ if fail_mode not in (FAIL_WARN, FAIL_ERROR, FAIL_SILENT):
+ raise AssertionError(f"no valid fail mode: {fail_mode}")
+ self.saved_quantities: SavedQuantities = SavedQuantities()
+ self.savefield: str = savefield
+ self.__module_extensions: Dict[Type[Module], ModuleExtension] = module_exts
+ self._fail_mode: str = fail_mode
+ self._subsampling = subsampling
def set_module_extension(
- self,
- module: Type[torch.nn.Module],
- extension: ModuleExtension,
- overwrite: bool = False,
+ self, module: Type[Module], extension: ModuleExtension, overwrite: bool = False
) -> None:
"""Adds a module mapping to module_extensions.
@@ -73,38 +90,77 @@ def set_module_extension(
)
self.__module_extensions[module] = extension
- def __get_module_extension(self, module):
+ def __get_module_extension(self, module: Module) -> Union[ModuleExtension, None]:
module_extension = self.__module_extensions.get(module.__class__)
if module_extension is None:
-
- if isinstance(module, Sequential):
- return no_op
-
- if self.__fail_mode is FAIL_ERROR:
+ if self._fail_mode is FAIL_ERROR:
+ # PyTorch converts this Error into a RuntimeError for torch<1.7.0
raise NotImplementedError(
- "Extension saving to {} ".format(self.savefield)
- + "does not have an extension for "
- + "Module {}".format(module.__class__)
- )
- elif self.__fail_mode == FAIL_WARN:
- warnings.warn(
- "Extension saving to {} ".format(self.savefield)
- + "does not have an extension for "
- + "Module {}".format(module.__class__)
+ f"Extension saving to {self.savefield} "
+ "does not have an extension for "
+ f"Module {module.__class__}"
)
+ elif self._fail_mode == FAIL_WARN:
+ for _ in module.parameters():
+ warnings.warn(
+ f"Extension saving to {self.savefield} does not have an "
+ f"extension for Module {module.__class__} "
+ f"although the module has parameters"
+ )
+ break
+
+ return module_extension
+
+ def __call__(
+ self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]
+ ) -> None:
+ """Applies backpropagation.
+
+ Args:
+ module: module to perform backpropagation on
+ g_inp: input gradient
+ g_out: output gradient
+ """
+ module_extension = self.__get_module_extension(module)
+ if module_extension is not None:
+ module_extension(self, module, g_inp, g_out)
+
+ @abc.abstractmethod
+ def expects_backpropagation_quantities(self) -> bool:
+ """Whether the extension uses additional backpropagation quantities.
- return no_op
+ Returns:
+ Whether the extension uses additional backpropagation quantities.
+ """
+ return
- return module_extension.apply
+ def get_subsampling(self) -> Union[List[int], None]:
+ """Return indices of active mini-batch samples.
- def apply(self, module, g_inp, g_out):
- """Applies backpropagation.
+ Returns:
+ Indices of samples considered by the extension. ``None`` signifies that
+ the full mini-batch is used.
+ """
+ return self._subsampling
+
+ def accumulate_backpropagated_quantities(self, existing: Any, other: Any) -> Any:
+ """Specify how to accumulate info that is backpropagated to the same node.
+
+ Must be implemented by second-order extensions to function on computation
+ graphs with branching.
+
+ For instance,
+ - ``DiagGGN`` extensions must sum their backpropagated tensor quantities.
+ - ``curvmatprod`` extensions must chain functions to sums of functions.
Args:
- module(torch.nn.module): module to perform backpropagation on
- g_inp(tuple[torch.Tensor]): input gradient
- g_out(tuple[torch.Tensor]): output gradient
+ existing: Backpropagated quantity
+ other: Other backpropagated quantity
+
+ Raises:
+ NotImplementedError: if not overwritten
"""
- module_extension = self.__get_module_extension(module)
- module_extension(self, module, g_inp, g_out)
+ raise NotImplementedError(
+ f"{self}: No accumulation rule for backpropagated info specified"
+ )
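
To make the new abstract interface concrete, here is a minimal hypothetical subclass (names invented for illustration) that fills in the two hooks introduced above: it announces that it backpropagates additional quantities and sums them where branches of the graph merge, as the DiagGGN example in the docstring suggests.

from torch import Tensor

from backpack.extensions.backprop_extension import BackpropExtension

class MyCurvatureExtension(BackpropExtension):  # hypothetical, for illustration
    def __init__(self):
        # no module extensions registered in this sketch
        super().__init__(savefield="my_quantity", module_exts={})

    def expects_backpropagation_quantities(self) -> bool:
        # second-order style: quantities are passed backward through the graph
        return True

    def accumulate_backpropagated_quantities(
        self, existing: Tensor, other: Tensor
    ) -> Tensor:
        # branching graphs: tensors arriving at the same node are summed
        return existing + other
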
diff --git a/backpack/extensions/curvmatprod/ggnmp/__init__.py b/backpack/extensions/curvmatprod/ggnmp/__init__.py
index c80f3c264..0813e367b 100644
--- a/backpack/extensions/curvmatprod/ggnmp/__init__.py
+++ b/backpack/extensions/curvmatprod/ggnmp/__init__.py
@@ -18,7 +18,7 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from . import (
activations,
@@ -33,7 +33,7 @@
)
-class GGNMP(BackpropExtension):
+class GGNMP(SecondOrderBackpropExtension):
"""
Matrix-free Multiplication with the block-diagonal generalized Gauss-Newton/Fisher.
diff --git a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py
index 03c05ffa5..831117e99 100644
--- a/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py
+++ b/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py
@@ -1,11 +1,11 @@
-from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase
class GGNMPBatchNorm1d(GGNMPBase):
def __init__(self):
super().__init__(
- derivatives=BatchNorm1dDerivatives(), params=["weight", "bias"]
+ derivatives=BatchNormNdDerivatives(), params=["weight", "bias"]
)
def weight(self, ext, module, g_inp, g_out, backproped):
@@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_ggnmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_ggnmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/ggnmp/conv2d.py b/backpack/extensions/curvmatprod/ggnmp/conv2d.py
index afd9785b6..825d88038 100644
--- a/backpack/extensions/curvmatprod/ggnmp/conv2d.py
+++ b/backpack/extensions/curvmatprod/ggnmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_ggnmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_ggnmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/ggnmp/flatten.py b/backpack/extensions/curvmatprod/ggnmp/flatten.py
index 532c24cb6..d47d9d84d 100644
--- a/backpack/extensions/curvmatprod/ggnmp/flatten.py
+++ b/backpack/extensions/curvmatprod/ggnmp/flatten.py
@@ -5,9 +5,3 @@
class GGNMPFlatten(GGNMPBase):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/curvmatprod/ggnmp/linear.py b/backpack/extensions/curvmatprod/ggnmp/linear.py
index 21d92b685..dfb35d3d8 100644
--- a/backpack/extensions/curvmatprod/ggnmp/linear.py
+++ b/backpack/extensions/curvmatprod/ggnmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_ggnmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_ggnmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/hmp/__init__.py b/backpack/extensions/curvmatprod/hmp/__init__.py
index a358b7f67..f94965b35 100644
--- a/backpack/extensions/curvmatprod/hmp/__init__.py
+++ b/backpack/extensions/curvmatprod/hmp/__init__.py
@@ -16,7 +16,7 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from . import (
activations,
@@ -31,7 +31,7 @@
)
-class HMP(BackpropExtension):
+class HMP(SecondOrderBackpropExtension):
"""Matrix-free multiplication with the block-diagonal Hessian.
Stores the multiplication function in :code:`hmp`.
diff --git a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py
index d441a82c4..aa9f70f3d 100644
--- a/backpack/extensions/curvmatprod/hmp/batchnorm1d.py
+++ b/backpack/extensions/curvmatprod/hmp/batchnorm1d.py
@@ -1,11 +1,11 @@
-from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase
class HMPBatchNorm1d(HMPBase):
def __init__(self):
super().__init__(
- derivatives=BatchNorm1dDerivatives(), params=["weight", "bias"]
+ derivatives=BatchNormNdDerivatives(), params=["weight", "bias"]
)
def weight(self, ext, module, g_inp, g_out, backproped):
@@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_hmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_hmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/hmp/conv2d.py b/backpack/extensions/curvmatprod/hmp/conv2d.py
index 69c74cfa8..7430c042a 100644
--- a/backpack/extensions/curvmatprod/hmp/conv2d.py
+++ b/backpack/extensions/curvmatprod/hmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_hmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_hmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/hmp/flatten.py b/backpack/extensions/curvmatprod/hmp/flatten.py
index ab94fd5e2..300f669cf 100644
--- a/backpack/extensions/curvmatprod/hmp/flatten.py
+++ b/backpack/extensions/curvmatprod/hmp/flatten.py
@@ -5,9 +5,3 @@
class HMPFlatten(HMPBase):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/curvmatprod/hmp/hmpbase.py b/backpack/extensions/curvmatprod/hmp/hmpbase.py
index 013582c06..459be8fa4 100644
--- a/backpack/extensions/curvmatprod/hmp/hmpbase.py
+++ b/backpack/extensions/curvmatprod/hmp/hmpbase.py
@@ -29,7 +29,7 @@ def h_in_mat_prod(mat):
result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result)
# Multiply with the residual term: mat → [∑ᵢ Hzᵢ(x) δzᵢ] mat.
- if not self.derivatives.hessian_is_zero():
+ if not self.derivatives.hessian_is_zero(module):
result += self.derivatives.residual_mat_prod(module, g_inp, g_out, mat)
return result
diff --git a/backpack/extensions/curvmatprod/hmp/linear.py b/backpack/extensions/curvmatprod/hmp/linear.py
index 11428f18f..7917dfa1a 100644
--- a/backpack/extensions/curvmatprod/hmp/linear.py
+++ b/backpack/extensions/curvmatprod/hmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_hmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_hmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/pchmp/__init__.py b/backpack/extensions/curvmatprod/pchmp/__init__.py
index 11dd33e03..058f1c762 100644
--- a/backpack/extensions/curvmatprod/pchmp/__init__.py
+++ b/backpack/extensions/curvmatprod/pchmp/__init__.py
@@ -17,12 +17,12 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling
-class PCHMP(BackpropExtension):
+class PCHMP(SecondOrderBackpropExtension):
"""
Matrix-free multiplication with the block-diagonal positive-curvature Hessian (PCH).
diff --git a/backpack/extensions/curvmatprod/pchmp/conv2d.py b/backpack/extensions/curvmatprod/pchmp/conv2d.py
index 620fb2a6d..213601e6c 100644
--- a/backpack/extensions/curvmatprod/pchmp/conv2d.py
+++ b/backpack/extensions/curvmatprod/pchmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_pchmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_pchmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/pchmp/flatten.py b/backpack/extensions/curvmatprod/pchmp/flatten.py
index 29403437c..1cbaedce1 100644
--- a/backpack/extensions/curvmatprod/pchmp/flatten.py
+++ b/backpack/extensions/curvmatprod/pchmp/flatten.py
@@ -5,9 +5,3 @@
class PCHMPFlatten(PCHMPBase):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/curvmatprod/pchmp/linear.py b/backpack/extensions/curvmatprod/pchmp/linear.py
index 37dbc49d7..d38539622 100644
--- a/backpack/extensions/curvmatprod/pchmp/linear.py
+++ b/backpack/extensions/curvmatprod/pchmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
def weight_pchmp(mat):
result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, result
- )
+ result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)
return result
@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
def bias_pchmp(mat):
result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
result = h_out_mat_prod(result)
- result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+ result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)
return result
diff --git a/backpack/extensions/curvmatprod/pchmp/pchmpbase.py b/backpack/extensions/curvmatprod/pchmp/pchmpbase.py
index 4c6cd1a07..8438c3750 100644
--- a/backpack/extensions/curvmatprod/pchmp/pchmpbase.py
+++ b/backpack/extensions/curvmatprod/pchmp/pchmpbase.py
@@ -20,9 +20,9 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped):
Given mat → ℋz(x) mat, backpropagate mat → ℋx mat.
"""
- diagonal_or_zero_residual = (
- self.derivatives.hessian_is_zero() or self.derivatives.hessian_is_diagonal()
- )
+ diagonal_or_zero_residual = self.derivatives.hessian_is_zero(
+ module
+ ) or self.derivatives.hessian_is_diagonal(module)
if not diagonal_or_zero_residual:
raise ValueError("Only linear or element-wise operations supported.")
@@ -45,7 +45,7 @@ def h_in_mat_prod(mat):
result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result)
# Multiply with the residual term: mat → [∑ᵢ Hzᵢ(x) δzᵢ] mat.
- if not self.derivatives.hessian_is_zero():
+ if not self.derivatives.hessian_is_zero(module):
result += self.modified_residual_mat_prod(
ext, module, g_inp, g_out, mat, modify
)
diff --git a/backpack/extensions/firstorder/base.py b/backpack/extensions/firstorder/base.py
index e3c62d08b..b3529df3f 100644
--- a/backpack/extensions/firstorder/base.py
+++ b/backpack/extensions/firstorder/base.py
@@ -1,6 +1,29 @@
+"""Base class for first order extensions."""
+from typing import Dict, List, Type
+
+from torch.nn import Module
+
+from backpack.extensions.backprop_extension import FAIL_WARN, BackpropExtension
from backpack.extensions.module_extension import ModuleExtension
class FirstOrderModuleExtension(ModuleExtension):
- def backpropagate(self, ext, module, g_inp, g_out, bpQuantities):
- return None
+ """Base class for first order module extensions."""
+
+
+class FirstOrderBackpropExtension(BackpropExtension):
+ """Base backpropagation extension for first order."""
+
+ def __init__(
+ self,
+ savefield: str,
+ module_exts: Dict[Type[Module], ModuleExtension],
+ fail_mode: str = FAIL_WARN,
+ subsampling: List[int] = None,
+ ): # noqa: D107
+ super().__init__(
+ savefield, module_exts, fail_mode=fail_mode, subsampling=subsampling
+ )
+
+ def expects_backpropagation_quantities(self) -> bool: # noqa: D102
+ return False
diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py
index 4d72865ef..d67bc1d4d 100644
--- a/backpack/extensions/firstorder/batch_grad/__init__.py
+++ b/backpack/extensions/firstorder/batch_grad/__init__.py
@@ -1,34 +1,50 @@
+"""Contains the backpropagation extension for grad_batch: BatchGrad.
+
+It defines the module extension for each supported module.
+"""
+from typing import List
+
from torch.nn import (
+ LSTM,
+ RNN,
BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ Embedding,
Linear,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.base import FirstOrderBackpropExtension
from . import (
- batchnorm1d,
+ batchnorm_nd,
conv1d,
conv2d,
conv3d,
conv_transpose1d,
conv_transpose2d,
conv_transpose3d,
+ embedding,
linear,
+ rnn,
)
-class BatchGrad(BackpropExtension):
+class BatchGrad(FirstOrderBackpropExtension):
"""Individual gradients for each sample in a minibatch.
Stores the output in ``grad_batch`` as a ``[N x ...]`` tensor,
where ``N`` is the batch size and ``...`` is the shape of the gradient.
+ If ``subsampling`` is specified, ``N`` is replaced by the number of active
+ samples.
+
.. note::
Beware of scaling issue
@@ -42,13 +58,19 @@ class BatchGrad(BackpropExtension):
The concept of individual gradients is only meaningful if the
objective is a sum of independent functions (no batchnorm).
-
"""
- def __init__(self):
+ def __init__(self, subsampling: List[int] = None):
+ """Initialization.
+
+ Defines the extension for each module.
+
+ Args:
+ subsampling: Indices of samples in the mini-batch for which individual
+ gradients will be computed. Defaults to ``None`` (use all samples).
+ """
super().__init__(
savefield="grad_batch",
- fail_mode="WARNING",
module_exts={
Linear: linear.BatchGradLinear(),
Conv1d: conv1d.BatchGradConv1d(),
@@ -57,6 +79,12 @@ def __init__(self):
ConvTranspose1d: conv_transpose1d.BatchGradConvTranspose1d(),
ConvTranspose2d: conv_transpose2d.BatchGradConvTranspose2d(),
ConvTranspose3d: conv_transpose3d.BatchGradConvTranspose3d(),
- BatchNorm1d: batchnorm1d.BatchGradBatchNorm1d(),
+ BatchNorm1d: batchnorm_nd.BatchGradBatchNormNd(),
+ BatchNorm2d: batchnorm_nd.BatchGradBatchNormNd(),
+ BatchNorm3d: batchnorm_nd.BatchGradBatchNormNd(),
+ RNN: rnn.BatchGradRNN(),
+ LSTM: rnn.BatchGradLSTM(),
+ Embedding: embedding.BatchGradEmbedding(),
},
+ subsampling=subsampling,
)
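
A usage sketch for the new subsampling argument (model and data are illustrative): with three active samples, grad_batch holds three individual gradients per parameter.

from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear

from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(Linear(10, 3))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(rand(8, 10)), randint(0, 3, (8,)))

with backpack(BatchGrad(subsampling=[0, 2, 5])):
    loss.backward()

# individual gradients for the three active samples only
assert model.weight.grad_batch.shape == (3, 3, 10)
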
diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py
index 1e25de41e..bd8e75a0d 100644
--- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py
+++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py
@@ -1,17 +1,94 @@
+"""Calculates the batch_grad derivative."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+from backpack.utils.subsampling import subsample
+
+if TYPE_CHECKING:
+ from backpack.extensions.firstorder import BatchGrad
class BatchGradBase(FirstOrderModuleExtension):
- def __init__(self, derivatives, params=None):
- self.derivatives = derivatives
+ """Calculates the batch_grad derivative.
+
+ Passes the calls for the parameters to the derivatives class.
+ Implements functions with method names from params.
+
+ If a child class wants to overwrite these methods
+ - for example, to support an additional external module -
+ it can do so using the interface for parameter "param1"::
+
+ param1(ext, module, g_inp, g_out, bpQuantities):
+ return batch_grads
+
+ In this case, the method is not overwritten by this class.
+ """
+
+ def __init__(
+ self, derivatives: BaseParameterDerivatives, params: List[str]
+ ) -> None:
+ """Initializes all methods.
+
+ If the param method has already been defined, it is left unchanged.
+
+ Args:
+ derivatives: Derivatives object used to apply parameter Jacobians.
+ params: List of parameter names.
+ """
+ self._derivatives = derivatives
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
super().__init__(params=params)
- def bias(self, ext, module, g_inp, g_out, bpQuantities):
- return self.derivatives.bias_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=False
- )
+ def _make_param_function(
+ self, param_str: str
+ ) -> Callable[[BatchGrad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]:
+ """Creates a function that calculates batch_grad w.r.t. param.
+
+ Args:
+ param_str: Parameter name.
+
+ Returns:
+ Function that calculates batch_grad wrt param
+ """
+
+ def param_function(
+ ext: BatchGrad,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ bpQuantities: None,
+ ) -> Tensor:
+ """Calculates batch_grad with the help of derivatives object.
+
+ Args:
+ ext: extension that is used
+ module: module that performed forward pass
+ g_inp: input gradient tensors
+ g_out: output gradient tensors
+ bpQuantities: additional quantities for second order
+
+ Returns:
+ Scaled individual gradients
+ """
+ subsampling = ext.get_subsampling()
+ batch_axis = 0
+
+ return self._derivatives.param_mjp(
+ param_str,
+ module,
+ g_inp,
+ g_out,
+ subsample(g_out[0], dim=batch_axis, subsampling=subsampling),
+ sum_batch=False,
+ subsampling=subsampling,
+ )
- def weight(self, ext, module, g_inp, g_out, bpQuantities):
- return self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=False
- )
+ return param_function
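
The overriding mechanism described in the docstring, shown on a hypothetical child class (for illustration only): "weight" is defined on the class, so the constructor leaves it untouched and only auto-generates "bias".

from backpack.core.derivatives.linear import LinearDerivatives
from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase

class BatchGradCustomLinear(BatchGradBase):  # hypothetical, for illustration
    def __init__(self):
        # "bias" gets the auto-generated param function; "weight" below survives
        super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])

    def weight(self, ext, module, g_inp, g_out, bpQuantities):
        # hand-written rule for 2d inputs: outer product of output grads and inputs
        return g_out[0].unsqueeze(-1) * module.input0.unsqueeze(-2)
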
diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py
deleted file mode 100644
index 74d8737d9..000000000
--- a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives
-from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
-
-
-class BatchGradBatchNorm1d(BatchGradBase):
- def __init__(self):
- super().__init__(
- derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"]
- )
diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py
new file mode 100644
index 000000000..83759b0ae
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py
@@ -0,0 +1,29 @@
+"""Contains grad_batch extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class BatchGradBatchNormNd(BatchGradBase):
+ """BatchGrad extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=BatchNormNdDerivatives(), params=["bias", "weight"]
+ )
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module, raise_error=False)
diff --git a/backpack/extensions/firstorder/batch_grad/embedding.py b/backpack/extensions/firstorder/batch_grad/embedding.py
new file mode 100644
index 000000000..35b41f7b0
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_grad/embedding.py
@@ -0,0 +1,11 @@
+"""BatchGrad extension for Embedding."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
+
+
+class BatchGradEmbedding(BatchGradBase):
+ """BatchGrad extension for Embedding."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"])
diff --git a/backpack/extensions/firstorder/batch_grad/rnn.py b/backpack/extensions/firstorder/batch_grad/rnn.py
new file mode 100644
index 000000000..9b92f2642
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_grad/rnn.py
@@ -0,0 +1,26 @@
+"""Contains BatchGradRNN."""
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
+
+
+class BatchGradRNN(BatchGradBase):
+ """Extension for RNN calculating grad_batch."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=RNNDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
+
+
+class BatchGradLSTM(BatchGradBase):
+ """Extension for LSTM calculating grad_batch."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=LSTMDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py
index 90ae2a775..8be80f08e 100644
--- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py
+++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py
@@ -1,27 +1,36 @@
+"""Contains BatchL2Grad.
+
+Defines the backpropagation extension.
+Within it, the module extension for each supported layer is defined.
+"""
from torch.nn import (
+ LSTM,
+ RNN,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ Embedding,
Linear,
)
-from backpack.extensions.backprop_extension import BackpropExtension
-
-from . import (
- conv1d,
- conv2d,
- conv3d,
- convtranspose1d,
- convtranspose2d,
- convtranspose3d,
+from backpack.extensions.firstorder.base import FirstOrderBackpropExtension
+from backpack.extensions.firstorder.batch_l2_grad import (
+ batchnorm_nd,
+ convnd,
+ convtransposend,
+ embedding,
linear,
+ rnn,
)
-class BatchL2Grad(BackpropExtension):
+class BatchL2Grad(FirstOrderBackpropExtension):
"""The squared L2 norm of individual gradients in the minibatch.
Stores the output in ``batch_l2`` as a tensor of size ``[N]``,
@@ -40,16 +49,25 @@ class BatchL2Grad(BackpropExtension):
"""
def __init__(self):
+ """Initialization.
+
+ Defines the extension for each module.
+ """
super().__init__(
savefield="batch_l2",
- fail_mode="WARNING",
module_exts={
Linear: linear.BatchL2Linear(),
- Conv1d: conv1d.BatchL2Conv1d(),
- Conv2d: conv2d.BatchL2Conv2d(),
- Conv3d: conv3d.BatchL2Conv3d(),
- ConvTranspose1d: convtranspose1d.BatchL2ConvTranspose1d(),
- ConvTranspose2d: convtranspose2d.BatchL2ConvTranspose2d(),
- ConvTranspose3d: convtranspose3d.BatchL2ConvTranspose3d(),
+ Conv1d: convnd.BatchL2Conv1d(),
+ Conv2d: convnd.BatchL2Conv2d(),
+ Conv3d: convnd.BatchL2Conv3d(),
+ ConvTranspose1d: convtransposend.BatchL2ConvTranspose1d(),
+ ConvTranspose2d: convtransposend.BatchL2ConvTranspose2d(),
+ ConvTranspose3d: convtransposend.BatchL2ConvTranspose3d(),
+ RNN: rnn.BatchL2RNN(),
+ LSTM: rnn.BatchL2LSTM(),
+ BatchNorm1d: batchnorm_nd.BatchL2BatchNorm(),
+ BatchNorm2d: batchnorm_nd.BatchL2BatchNorm(),
+ BatchNorm3d: batchnorm_nd.BatchL2BatchNorm(),
+ Embedding: embedding.BatchL2Embedding(),
},
)
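
A usage sketch (illustrative model and data) that also checks the defining property of batch_l2 against BatchGrad: each entry is the squared L2 norm of one sample's individual gradient.

from torch import allclose, rand, randint
from torch.nn import CrossEntropyLoss, Linear

from backpack import backpack, extend
from backpack.extensions import BatchGrad, BatchL2Grad

model = extend(Linear(10, 3))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(rand(8, 10)), randint(0, 3, (8,)))

with backpack(BatchGrad(), BatchL2Grad()):
    loss.backward()

for p in model.parameters():
    assert p.batch_l2.shape == (8,)
    assert allclose(p.batch_l2, p.grad_batch.flatten(start_dim=1).pow(2).sum(1))
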
diff --git a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py
new file mode 100644
index 000000000..f7b4f79dd
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py
@@ -0,0 +1,75 @@
+"""Contains Base class for batch_l2_grad."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+
+if TYPE_CHECKING:
+ from backpack.extensions import BatchL2Grad
+
+
+class BatchL2Base(FirstOrderModuleExtension):
+ """BaseExtension for batch_l2."""
+
+ def __init__(self, params: List[str], derivatives: BaseParameterDerivatives = None):
+ """Initialization.
+
+ If a derivatives object is provided, methods that compute batch_l2 are initialized.
+ If a child class already defines such a method, it is not overwritten.
+
+ Args:
+ params: parameter names
+ derivatives: derivatives object. Defaults to None.
+ """
+ if derivatives is not None:
+ self.derivatives: BaseParameterDerivatives = derivatives
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
+ super().__init__(params=params)
+
+ def _make_param_function(
+ self, param_str: str
+ ) -> Callable[[BatchL2Grad, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]:
+ """Creates a function that calculates batch_l2.
+
+ Args:
+ param_str: name of parameter
+
+ Returns:
+ function that calculates batch_l2
+ """
+
+ def param_function(
+ ext: BatchL2Grad,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ bpQuantities: None,
+ ) -> Tensor:
+ """Calculates batch_l2 with the help of derivatives object.
+
+ Args:
+ ext: extension that is used
+ module: module that performed forward pass
+ g_inp: input gradient tensors
+ g_out: output gradient tensors
+ bpQuantities: additional quantities for second order
+
+ Returns:
+ batch_l2
+ """
+ param_dims: List[int] = list(range(1, 1 + getattr(module, param_str).dim()))
+ return (
+ self.derivatives.param_mjp(
+ param_str, module, g_inp, g_out, g_out[0], sum_batch=False
+ )
+ ** 2
+ ).sum(param_dims)
+
+ return param_function
diff --git a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py
new file mode 100644
index 000000000..9e1941804
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py
@@ -0,0 +1,27 @@
+"""Contains batch_l2 extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class BatchL2BatchNorm(BatchL2Base):
+ """batch_l2 extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["weight", "bias"], BatchNormNdDerivatives())
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv1d.py b/backpack/extensions/firstorder/batch_l2_grad/conv1d.py
deleted file mode 100644
index 64eb36066..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/conv1d.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND
-
-
-class BatchL2Conv1d(BatchL2ConvND):
- def __init__(self):
- super().__init__(N=1, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py b/backpack/extensions/firstorder/batch_l2_grad/conv2d.py
deleted file mode 100644
index 327c90598..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND
-
-
-class BatchL2Conv2d(BatchL2ConvND):
- def __init__(self):
- super().__init__(N=2, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv3d.py b/backpack/extensions/firstorder/batch_l2_grad/conv3d.py
deleted file mode 100644
index 369f6bb8a..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/conv3d.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convnd import BatchL2ConvND
-
-
-class BatchL2Conv3d(BatchL2ConvND):
- def __init__(self):
- super().__init__(N=3, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/convnd.py b/backpack/extensions/firstorder/batch_l2_grad/convnd.py
index 55c542a03..991eb96e2 100644
--- a/backpack/extensions/firstorder/batch_l2_grad/convnd.py
+++ b/backpack/extensions/firstorder/batch_l2_grad/convnd.py
@@ -1,23 +1,54 @@
+"""batch_l2 extension for Conv."""
from torch import einsum
-from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+from backpack.core.derivatives.conv1d import Conv1DDerivatives
+from backpack.core.derivatives.conv2d import Conv2DDerivatives
+from backpack.core.derivatives.conv3d import Conv3DDerivatives
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
from backpack.utils import conv as convUtils
-class BatchL2ConvND(FirstOrderModuleExtension):
- def __init__(self, N, params=None):
- super().__init__(params=params)
- self.N = N
+class BatchL2ConvND(BatchL2Base):
+ """batch_l2 extension for Conv."""
- # TODO Use bias Jacobian to compute `bias_gradient`
- def bias(self, ext, module, g_inp, g_out, backproped):
- spatial_dims = list(range(2, g_out[0].dim()))
- channel_dim = 1
+ def weight(self, ext, module, g_inp, g_out, backproped):
+ """batch_l2 for weight.
- return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim)
+ Args:
+ ext: extension
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ backproped: backpropagation quantities
- def weight(self, ext, module, g_inp, g_out, backproped):
+ Returns:
+ batch_l2 for weight
+ """
X, dE_dY = convUtils.get_weight_gradient_factors(
- module.input0, g_out[0], module, self.N
+ module.input0, g_out[0], module
)
- return einsum("nmi,nki,nmj,nkj->n", (dE_dY, X, dE_dY, X))
+ return einsum("nmi,nki,nmj,nkj->n", dE_dY, X, dE_dY, X)
+
+
+class BatchL2Conv1d(BatchL2ConvND):
+ """batch_l2 extension for Conv1d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=Conv1DDerivatives())
+
+
+class BatchL2Conv2d(BatchL2ConvND):
+ """batch_l2 extension for Conv2d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=Conv2DDerivatives())
+
+
+class BatchL2Conv3d(BatchL2ConvND):
+ """batch_l2 extension for Conv3d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=Conv3DDerivatives())
diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py
deleted file mode 100644
index aad345c58..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose1d.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convtransposend import (
- BatchL2ConvTransposeND,
-)
-
-
-class BatchL2ConvTranspose1d(BatchL2ConvTransposeND):
- def __init__(self):
- super().__init__(N=1, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py
deleted file mode 100644
index 0d916fbed..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose2d.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convtransposend import (
- BatchL2ConvTransposeND,
-)
-
-
-class BatchL2ConvTranspose2d(BatchL2ConvTransposeND):
- def __init__(self):
- super().__init__(N=2, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py b/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py
deleted file mode 100644
index 8a1f5e257..000000000
--- a/backpack/extensions/firstorder/batch_l2_grad/convtranspose3d.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from backpack.extensions.firstorder.batch_l2_grad.convtransposend import (
- BatchL2ConvTransposeND,
-)
-
-
-class BatchL2ConvTranspose3d(BatchL2ConvTransposeND):
- def __init__(self):
- super().__init__(N=3, params=["bias", "weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py
index 9ceaa7881..3c54be1f5 100644
--- a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py
+++ b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py
@@ -1,23 +1,54 @@
+"""batch_l2 extension for ConvTranspose."""
from torch import einsum
-from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives
+from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives
+from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
from backpack.utils import conv_transpose as convTransposeUtils
-class BatchL2ConvTransposeND(FirstOrderModuleExtension):
- def __init__(self, N, params=None):
- super().__init__(params=params)
- self.N = N
+class BatchL2ConvTransposeND(BatchL2Base):
+ """batch_l2 extension for ConvTranspose."""
- # TODO Use bias Jacobian to compute `bias_gradient`
- def bias(self, ext, module, g_inp, g_out, backproped):
- spatial_dims = list(range(2, g_out[0].dim()))
- channel_dim = 1
+ def weight(self, ext, module, g_inp, g_out, backproped):
+ """batch_l2 for weight.
- return g_out[0].sum(spatial_dims).pow_(2).sum(channel_dim)
+ Args:
+ ext: extension
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ backproped: backpropagation quantities
- def weight(self, ext, module, g_inp, g_out, backproped):
+ Returns:
+ batch_l2 for weight
+ """
X, dE_dY = convTransposeUtils.get_weight_gradient_factors(
- module.input0, g_out[0], module, self.N
+ module.input0, g_out[0], module
)
- return einsum("nmi,nki,nmj,nkj->n", (dE_dY, X, dE_dY, X))
+ return einsum("nmi,nki,nmj,nkj->n", dE_dY, X, dE_dY, X)
+
+
+class BatchL2ConvTranspose1d(BatchL2ConvTransposeND):
+ """batch_l2 extension for ConvTranspose1d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=ConvTranspose1DDerivatives())
+
+
+class BatchL2ConvTranspose2d(BatchL2ConvTransposeND):
+ """batch_l2 extension for ConvTranspose2d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=ConvTranspose2DDerivatives())
+
+
+class BatchL2ConvTranspose3d(BatchL2ConvTransposeND):
+ """batch_l2 extension for ConvTranspose3d."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=ConvTranspose3DDerivatives())
diff --git a/backpack/extensions/firstorder/batch_l2_grad/embedding.py b/backpack/extensions/firstorder/batch_l2_grad/embedding.py
new file mode 100644
index 000000000..eca2b10cb
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_l2_grad/embedding.py
@@ -0,0 +1,11 @@
+"""BatchL2 extension for Embedding."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
+
+
+class BatchL2Embedding(BatchL2Base):
+ """BatchL2 extension for Embedding."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"])
diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py
index 6a9b3b73c..1c5c10d7c 100644
--- a/backpack/extensions/firstorder/batch_l2_grad/linear.py
+++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py
@@ -1,15 +1,52 @@
-from torch import einsum
+"""Contains batch_l2 extension for Linear."""
+from __future__ import annotations
-from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+from typing import TYPE_CHECKING, Tuple
+from torch import Tensor, einsum
+from torch.nn import Linear
+
+from backpack.core.derivatives.linear import LinearDerivatives
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
+
+if TYPE_CHECKING:
+ from backpack.extensions import BatchL2Grad
+
+
+class BatchL2Linear(BatchL2Base):
+ """batch_l2 extension for Linear."""
-class BatchL2Linear(FirstOrderModuleExtension):
def __init__(self):
- super().__init__(params=["bias", "weight"])
+ """Initialization."""
+ super().__init__(["bias", "weight"], derivatives=LinearDerivatives())
+
+ def weight(
+ self,
+ ext: BatchL2Grad,
+ module: Linear,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ backproped: None,
+ ) -> Tensor:
+ """batch_l2 for weight.
+
+ Args:
+ ext: extension
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ backproped: backpropagation quantities
- def bias(self, ext, module, g_inp, g_out, backproped):
- C_axis = 1
- return (g_out[0] ** 2).sum(C_axis)
+ Returns:
+ batch_l2 for weight
+ """
+ has_additional_axes = g_out[0].dim() > 2
- def weight(self, ext, module, g_inp, g_out, backproped):
- return einsum("ni,nj->n", (g_out[0] ** 2, module.input0 ** 2))
+ if has_additional_axes:
+ # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class
+ # implementation: https://github.com/fKunstner/backpack-discuss/issues/111
+ dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2)
+ X = module.input0.flatten(start_dim=1, end_dim=-2)
+ return einsum("nmi,nmj,nki,nkj->n", dE_dY, X, dE_dY, X)
+ else:
+ return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2)
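A minimal, standalone numeric check of the ``else`` branch above (plain PyTorch, toy shapes): the per-sample weight gradient of a Linear layer is the outer product of output gradient and input, so its squared norm factorizes into the product of the two squared norms.

from torch import allclose, einsum, rand

N, D_in, D_out = 4, 3, 2
X = rand(N, D_in)        # stands in for module.input0
dE_dY = rand(N, D_out)   # stands in for g_out[0]

batch_l2 = einsum("ni,nj->n", dE_dY ** 2, X ** 2)

# reference: per-sample weight gradients are the outer products dE_dY[n] x X[n]
grad_batch = einsum("no,ni->noi", dE_dY, X)
assert allclose(batch_l2, (grad_batch ** 2).flatten(start_dim=1).sum(1))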
diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py
new file mode 100644
index 000000000..dbb1a1644
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py
@@ -0,0 +1,26 @@
+"""Contains BatchL2RNN."""
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base
+
+
+class BatchL2RNN(BatchL2Base):
+ """Extension for RNN, calculating batch_l2."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ derivatives=RNNDerivatives(),
+ )
+
+
+class BatchL2LSTM(BatchL2Base):
+ """Extension for LSTM, calculating batch_l2."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ derivatives=LSTMDerivatives(),
+ )
diff --git a/backpack/extensions/firstorder/gradient/__init__.py b/backpack/extensions/firstorder/gradient/__init__.py
index 7a5228834..89c7cff43 100644
--- a/backpack/extensions/firstorder/gradient/__init__.py
+++ b/backpack/extensions/firstorder/gradient/__init__.py
@@ -1 +1,5 @@
+"""This package contains the gradient extension.
+
+It calculates the same result as PyTorch's ``backward()`` (autograd).
+"""
# TODO: Rewrite variance to not need this extension
diff --git a/backpack/extensions/firstorder/gradient/base.py b/backpack/extensions/firstorder/gradient/base.py
index 3ebaf8932..b9f198855 100644
--- a/backpack/extensions/firstorder/gradient/base.py
+++ b/backpack/extensions/firstorder/gradient/base.py
@@ -1,17 +1,65 @@
+"""Calculates the gradient."""
from backpack.extensions.firstorder.base import FirstOrderModuleExtension
class GradBaseModule(FirstOrderModuleExtension):
- def __init__(self, derivatives, params=None):
+ """Calculates the gradient.
+
+ Passes the calls for the parameters to the derivatives class.
+ Implements functions with method names from params.
+
+    If a child class wants to overwrite these methods
+ - for example to support an additional external module -
+ it can do so using the interface for parameter "param1"::
+
+ param1(ext, module, g_inp, g_out, bpQuantities):
+ return batch_grads
+
+ In this case, the method is not overwritten by this class.
+ """
+
+ def __init__(self, derivatives, params):
+ """Initializes all methods.
+
+ If the param method has already been defined, it is left unchanged.
+
+ Args:
+ derivatives(backpack.core.derivatives.basederivatives.BaseParameterDerivatives): # noqa: B950
+ Derivatives object assigned to self.derivatives.
+ params (list[str]): list of strings with parameter names.
+ For each, a method is assigned.
+ """
self.derivatives = derivatives
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
super().__init__(params=params)
- def bias(self, ext, module, g_inp, g_out, bpQuantities):
- return self.derivatives.bias_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=True
- )
+ def _make_param_function(self, param_str):
+ """Creates a function that calculates gradient wrt param.
+
+ Args:
+ param_str: name of parameter
+
+ Returns:
+ function: function that calculates gradient wrt param
+ """
+
+ def param_function(ext, module, g_inp, g_out, bpQuantities):
+ """Calculates gradient with the help of derivatives object.
+
+ Args:
+ ext(backpack.extensions.BatchGrad): extension that is used
+ module(torch.nn.Module): module that performed forward pass
+ g_inp(tuple[torch.Tensor]): input gradient tensors
+ g_out(tuple[torch.Tensor]): output gradient tensors
+ bpQuantities(None): additional quantities for second order
+
+ Returns:
+ torch.Tensor: gradient of the batch, similar to autograd
+ """
+ return self.derivatives.param_mjp(
+ param_str, module, g_inp, g_out, g_out[0], sum_batch=True
+ )
- def weight(self, ext, module, g_inp, g_out, bpQuantities):
- return self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=True
- )
+ return param_function
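A self-contained sketch of the dispatch pattern used in ``__init__`` above (the class and parameter names below are invented for illustration): one closure per parameter name is attached to the instance, unless a subclass already provides a specialized method of that name.

class ParamDispatcher:
    """Toy stand-in for the parameter dispatch of GradBaseModule."""

    def __init__(self, params):
        for name in params:
            # keep a hand-written method if the (sub)class defines one
            if not hasattr(self, name):
                setattr(self, name, self._make_param_function(name))

    def _make_param_function(self, name):
        def param_function(g_out):
            # stands in for self.derivatives.param_mjp(name, ..., sum_batch=True)
            return f"generic rule for {name} applied to {g_out}"

        return param_function


class WithCustomWeight(ParamDispatcher):
    def weight(self, g_out):  # the specialized implementation wins
        return f"custom rule for weight applied to {g_out}"


ext = WithCustomWeight(["weight", "bias"])
print(ext.weight("g"))  # custom rule for weight applied to g
print(ext.bias("g"))    # generic rule for bias applied to g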
diff --git a/backpack/extensions/firstorder/gradient/batchnorm1d.py b/backpack/extensions/firstorder/gradient/batchnorm1d.py
deleted file mode 100644
index 5e0f3b6fd..000000000
--- a/backpack/extensions/firstorder/gradient/batchnorm1d.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives
-
-from .base import GradBaseModule
-
-
-class GradBatchNorm1d(GradBaseModule):
- def __init__(self):
- super().__init__(
- derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"]
- )
diff --git a/backpack/extensions/firstorder/gradient/batchnorm_nd.py b/backpack/extensions/firstorder/gradient/batchnorm_nd.py
new file mode 100644
index 000000000..5bacc2ad6
--- /dev/null
+++ b/backpack/extensions/firstorder/gradient/batchnorm_nd.py
@@ -0,0 +1,30 @@
+"""Gradient extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+from .base import GradBaseModule
+
+
+class GradBatchNormNd(GradBaseModule):
+ """Gradient extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=BatchNormNdDerivatives(), params=["bias", "weight"]
+ )
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
diff --git a/backpack/extensions/firstorder/gradient/embedding.py b/backpack/extensions/firstorder/gradient/embedding.py
new file mode 100644
index 000000000..c394ae509
--- /dev/null
+++ b/backpack/extensions/firstorder/gradient/embedding.py
@@ -0,0 +1,11 @@
+"""Gradient extension for Embedding."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.firstorder.gradient.base import GradBaseModule
+
+
+class GradEmbedding(GradBaseModule):
+ """Gradient extension for Embedding."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"])
diff --git a/backpack/extensions/firstorder/gradient/rnn.py b/backpack/extensions/firstorder/gradient/rnn.py
new file mode 100644
index 000000000..7ba76e626
--- /dev/null
+++ b/backpack/extensions/firstorder/gradient/rnn.py
@@ -0,0 +1,26 @@
+"""Contains GradRNN."""
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.extensions.firstorder.gradient.base import GradBaseModule
+
+
+class GradRNN(GradBaseModule):
+ """Extension for RNN, calculating gradient."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=RNNDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
+
+
+class GradLSTM(GradBaseModule):
+ """Extension for LSTM, calculating gradient."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=LSTMDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py
index fffe54a91..76891cff6 100644
--- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py
+++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py
@@ -1,27 +1,40 @@
+"""Contains backpropagation extension for sum_grad_squared: SumGradSquared.
+
+Defines module extension for each module.
+"""
from torch.nn import (
+ LSTM,
+ RNN,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ Embedding,
Linear,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.base import FirstOrderBackpropExtension
from . import (
+ batchnorm_nd,
conv1d,
conv2d,
conv3d,
convtranspose1d,
convtranspose2d,
convtranspose3d,
+ embedding,
linear,
+ rnn,
)
-class SumGradSquared(BackpropExtension):
+class SumGradSquared(FirstOrderBackpropExtension):
"""The sum of individual-gradients-squared, or second moment of the gradient.
Stores the output in ``sum_grad_squared``. Same dimension as the gradient.
@@ -39,9 +52,12 @@ class SumGradSquared(BackpropExtension):
"""
def __init__(self):
+ """Initialization.
+
+ Defines module extension for each module.
+ """
super().__init__(
savefield="sum_grad_squared",
- fail_mode="WARNING",
module_exts={
Linear: linear.SGSLinear(),
Conv1d: conv1d.SGSConv1d(),
@@ -50,5 +66,11 @@ def __init__(self):
ConvTranspose1d: convtranspose1d.SGSConvTranspose1d(),
ConvTranspose2d: convtranspose2d.SGSConvTranspose2d(),
ConvTranspose3d: convtranspose3d.SGSConvTranspose3d(),
+ RNN: rnn.SGSRNN(),
+ LSTM: rnn.SGSLSTM(),
+ BatchNorm1d: batchnorm_nd.SGSBatchNormNd(),
+ BatchNorm2d: batchnorm_nd.SGSBatchNormNd(),
+ BatchNorm3d: batchnorm_nd.SGSBatchNormNd(),
+ Embedding: embedding.SGSEmbedding(),
},
)
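A hedged usage sketch (toy model and data): the second moment ends up in the ``sum_grad_squared`` attribute of each parameter and has the same shape as the parameter.

from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import SumGradSquared

model = extend(Sequential(Linear(10, 5)))
loss_func = extend(CrossEntropyLoss())
X, y = rand(8, 10), randint(0, 5, (8,))

with backpack(SumGradSquared()):
    loss_func(model(X), y).backward()

print(model[0].weight.sum_grad_squared.shape)  # torch.Size([5, 10])
print(model[0].bias.sum_grad_squared.shape)    # torch.Size([5])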
diff --git a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py
new file mode 100644
index 000000000..9ad99de07
--- /dev/null
+++ b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py
@@ -0,0 +1,27 @@
+"""SGS extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class SGSBatchNormNd(SGSBase):
+ """SGS extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(BatchNormNdDerivatives(), ["weight", "bias"])
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
diff --git a/backpack/extensions/firstorder/sum_grad_squared/embedding.py b/backpack/extensions/firstorder/sum_grad_squared/embedding.py
new file mode 100644
index 000000000..62f34e86b
--- /dev/null
+++ b/backpack/extensions/firstorder/sum_grad_squared/embedding.py
@@ -0,0 +1,11 @@
+"""SGS extension for Embedding."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase
+
+
+class SGSEmbedding(SGSBase):
+ """SGS extension for Embedding."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"])
diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py
index 4cf75db1d..99bb73c40 100644
--- a/backpack/extensions/firstorder/sum_grad_squared/linear.py
+++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py
@@ -18,4 +18,13 @@ def weight(self, ext, module, g_inp, g_out, backproped):
For details, see page 12 (paragraph about "second moment") of the
paper (https://arxiv.org/pdf/1912.10985.pdf).
"""
- return einsum("ni,nj->ij", (g_out[0] ** 2, module.input0 ** 2))
+ has_additional_axes = g_out[0].dim() > 2
+
+ if has_additional_axes:
+ # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class
+ # implementation: https://github.com/fKunstner/backpack-discuss/issues/111
+ dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2)
+ X = module.input0.flatten(start_dim=1, end_dim=-2)
+ return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X)
+ else:
+ return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2)
diff --git a/backpack/extensions/firstorder/sum_grad_squared/rnn.py b/backpack/extensions/firstorder/sum_grad_squared/rnn.py
new file mode 100644
index 000000000..129229144
--- /dev/null
+++ b/backpack/extensions/firstorder/sum_grad_squared/rnn.py
@@ -0,0 +1,26 @@
+"""Contains SGSRNN module."""
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase
+
+
+class SGSRNN(SGSBase):
+ """Extension for RNN, calculating sum_gradient_squared."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=RNNDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
+
+
+class SGSLSTM(SGSBase):
+ """Extension for LSTM, calculating sum_gradient_squared."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ derivatives=LSTMDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ )
diff --git a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py
index e654e8016..3e1d171ab 100644
--- a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py
+++ b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py
@@ -1,20 +1,73 @@
+"""Contains SGSBase, the base module for sum_grad_squared extension."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+if TYPE_CHECKING:
+ from backpack.extensions import SumGradSquared
+
class SGSBase(FirstOrderModuleExtension):
- def __init__(self, derivatives, params=None):
- self.derivatives = derivatives
- self.N_axis = 0
+ """Base class for extensions calculating sum_grad_squared."""
+
+ def __init__(self, derivatives: BaseParameterDerivatives, params: List[str] = None):
+ """Initialization.
+
+        For each parameter, a function named after the parameter is created.
+
+ Args:
+ derivatives: calculates the derivatives wrt parameters
+ params: list of parameter names
+ """
+ self.derivatives: BaseParameterDerivatives = derivatives
+ self.N_axis: int = 0
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
super().__init__(params=params)
- def bias(self, ext, module, g_inp, g_out, bpQuantities):
- grad_batch = self.derivatives.bias_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=False
- )
- return (grad_batch ** 2).sum(self.N_axis)
-
- def weight(self, ext, module, g_inp, g_out, bpQuantities):
- grad_batch = self.derivatives.weight_jac_t_mat_prod(
- module, g_inp, g_out, g_out[0], sum_batch=False
- )
- return (grad_batch ** 2).sum(self.N_axis)
+ def _make_param_function(
+ self, param_str: str
+ ) -> Callable[[SumGradSquared, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]:
+ """Creates a function that calculates sum_grad_squared.
+
+ Args:
+ param_str: name of parameter
+
+ Returns:
+ function that calculates sum_grad_squared
+ """
+
+ def param_function(
+ ext: SumGradSquared,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ bpQuantities: None,
+ ) -> Tensor:
+ """Calculates sum_grad_squared with the help of derivatives object.
+
+ Args:
+ ext: extension that is used
+ module: module that performed forward pass
+ g_inp: input gradient tensors
+ g_out: output gradient tensors
+ bpQuantities: additional quantities for second order
+
+ Returns:
+ sum_grad_squared
+ """
+ return (
+ self.derivatives.param_mjp(
+ param_str, module, g_inp, g_out, g_out[0], sum_batch=False
+ )
+ ** 2
+ ).sum(self.N_axis)
+
+ return param_function
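A standalone numeric sketch of the reduction in ``param_function`` above; the random tensor stands in for the ``sum_batch=False`` result of ``param_mjp``, i.e. the individual gradients.

from torch import rand

N_axis = 0
grad_batch = rand(8, 5, 10)                       # [N, *parameter.shape]
sum_grad_squared = (grad_batch ** 2).sum(N_axis)  # square, then sum over the batch
print(sum_grad_squared.shape)                     # torch.Size([5, 10])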
diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py
index b1f162228..eeb90902f 100644
--- a/backpack/extensions/firstorder/variance/__init__.py
+++ b/backpack/extensions/firstorder/variance/__init__.py
@@ -1,27 +1,40 @@
+"""Defines backpropagation extension for variance: Variance.
+
+Defines module extension for each module.
+"""
from torch.nn import (
+ LSTM,
+ RNN,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ Embedding,
Linear,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.base import FirstOrderBackpropExtension
from . import (
+ batchnorm_nd,
conv1d,
conv2d,
conv3d,
convtranspose1d,
convtranspose2d,
convtranspose3d,
+ embedding,
linear,
+ rnn,
)
-class Variance(BackpropExtension):
+class Variance(FirstOrderBackpropExtension):
"""Estimates the variance of the gradient using the samples in the minibatch.
Stores the output in ``variance``. Same dimension as the gradient.
@@ -39,9 +52,12 @@ class Variance(BackpropExtension):
"""
def __init__(self):
+ """Initialization.
+
+ Defines module extension for each module.
+ """
super().__init__(
savefield="variance",
- fail_mode="WARNING",
module_exts={
Linear: linear.VarianceLinear(),
Conv1d: conv1d.VarianceConv1d(),
@@ -50,5 +66,11 @@ def __init__(self):
ConvTranspose1d: convtranspose1d.VarianceConvTranspose1d(),
ConvTranspose2d: convtranspose2d.VarianceConvTranspose2d(),
ConvTranspose3d: convtranspose3d.VarianceConvTranspose3d(),
+ RNN: rnn.VarianceRNN(),
+ LSTM: rnn.VarianceLSTM(),
+ BatchNorm1d: batchnorm_nd.VarianceBatchNormNd(),
+ BatchNorm2d: batchnorm_nd.VarianceBatchNormNd(),
+ BatchNorm3d: batchnorm_nd.VarianceBatchNormNd(),
+ Embedding: embedding.VarianceEmbedding(),
},
)
diff --git a/backpack/extensions/firstorder/variance/batchnorm_nd.py b/backpack/extensions/firstorder/variance/batchnorm_nd.py
new file mode 100644
index 000000000..d2b8512e5
--- /dev/null
+++ b/backpack/extensions/firstorder/variance/batchnorm_nd.py
@@ -0,0 +1,28 @@
+"""Variance extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.firstorder.gradient.batchnorm_nd import GradBatchNormNd
+from backpack.extensions.firstorder.sum_grad_squared.batchnorm_nd import SGSBatchNormNd
+from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class VarianceBatchNormNd(VarianceBaseModule):
+ """Variance extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(["weight", "bias"], GradBatchNormNd(), SGSBatchNormNd())
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
diff --git a/backpack/extensions/firstorder/variance/embedding.py b/backpack/extensions/firstorder/variance/embedding.py
new file mode 100644
index 000000000..1b38472a6
--- /dev/null
+++ b/backpack/extensions/firstorder/variance/embedding.py
@@ -0,0 +1,16 @@
+"""Variance extension for Embedding."""
+from backpack.extensions.firstorder.gradient.embedding import GradEmbedding
+from backpack.extensions.firstorder.sum_grad_squared.embedding import SGSEmbedding
+from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule
+
+
+class VarianceEmbedding(VarianceBaseModule):
+ """Variance extension for Embedding."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ grad_extension=GradEmbedding(),
+ sgs_extension=SGSEmbedding(),
+ params=["weight"],
+ )
diff --git a/backpack/extensions/firstorder/variance/rnn.py b/backpack/extensions/firstorder/variance/rnn.py
new file mode 100644
index 000000000..62baa1258
--- /dev/null
+++ b/backpack/extensions/firstorder/variance/rnn.py
@@ -0,0 +1,29 @@
+"""Contains VarianceRNN."""
+
+from backpack.extensions.firstorder.gradient.rnn import GradLSTM, GradRNN
+from backpack.extensions.firstorder.sum_grad_squared.rnn import SGSLSTM, SGSRNN
+from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule
+
+
+class VarianceRNN(VarianceBaseModule):
+ """Extension for RNN, calculating variance."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ grad_extension=GradRNN(),
+ sgs_extension=SGSRNN(),
+ )
+
+
+class VarianceLSTM(VarianceBaseModule):
+ """Extension for LSTM, calculating variance."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ grad_extension=GradLSTM(),
+ sgs_extension=SGSLSTM(),
+ )
diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py
index 64d8c17e2..b91aac935 100644
--- a/backpack/extensions/firstorder/variance/variance_base.py
+++ b/backpack/extensions/firstorder/variance/variance_base.py
@@ -1,30 +1,85 @@
+"""Contains VarianceBaseModule."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+
from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+if TYPE_CHECKING:
+ from backpack.extensions import Variance
+ from backpack.extensions.firstorder.gradient.base import GradBaseModule
+ from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase
+
class VarianceBaseModule(FirstOrderModuleExtension):
- def __init__(self, params, grad_extension, sgs_extension):
+ """Base class for extensions calculating variance."""
+
+ def __init__(
+ self,
+ params: List[str],
+ grad_extension: GradBaseModule,
+ sgs_extension: SGSBase,
+ ):
+ """Initialization.
+
+ Creates a function named after each parameter.
+
+ Args:
+ params: list of parameter names
+ grad_extension: the extension calculating grad.
+            sgs_extension: the extension calculating sum_grad_squared.
+ """
+ self.grad_ext: GradBaseModule = grad_extension
+ self.sgs_ext: SGSBase = sgs_extension
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
super().__init__(params=params)
- self.grad_ext = grad_extension
- self.sgs_ext = sgs_extension
@staticmethod
- def variance_from(grad, sgs, N):
+ def _variance_from(grad: Tensor, sgs: Tensor, N: int) -> Tensor:
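+        """Compute the element-wise gradient variance from summed quantities.
+
+        Args:
+            grad: gradient summed over the batch
+            sgs: sum of squared individual gradients
+            N: batch size
+
+        Returns:
+            element-wise variance of the individual gradients
+        """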
avgg_squared = (grad / N) ** 2
avg_gsquared = sgs / N
return avg_gsquared - avgg_squared
- def bias(self, ext, module, g_inp, g_out, backproped):
- N = g_out[0].shape[0]
- return self.variance_from(
- self.grad_ext.bias(ext, module, g_inp, g_out, backproped),
- self.sgs_ext.bias(ext, module, g_inp, g_out, backproped),
- N,
- )
-
- def weight(self, ext, module, g_inp, g_out, backproped):
- N = g_out[0].shape[0]
- return self.variance_from(
- self.grad_ext.weight(ext, module, g_inp, g_out, backproped),
- self.sgs_ext.weight(ext, module, g_inp, g_out, backproped),
- N,
- )
+ def _make_param_function(
+ self, param: str
+ ) -> Callable[[Variance, Module, Tuple[Tensor], Tuple[Tensor], None], Tensor]:
+ """Creates a function that calculates variance of grad_batch.
+
+ Args:
+ param: name of parameter
+
+ Returns:
+ function that calculates variance of grad_batch
+ """
+
+ def param_function(
+ ext: Variance,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ bpQuantities: None,
+ ) -> Tensor:
+ """Calculates variance with the help of derivatives object.
+
+ Args:
+ ext: extension that is used
+ module: module that performed forward pass
+ g_inp: input gradient tensors
+ g_out: output gradient tensors
+ bpQuantities: additional quantities for second order
+
+ Returns:
+ variance of the batch
+ """
+ return self._variance_from(
+ getattr(self.grad_ext, param)(ext, module, g_inp, g_out, bpQuantities),
+ getattr(self.sgs_ext, param)(ext, module, g_inp, g_out, bpQuantities),
+ g_out[0].shape[0],
+ )
+
+ return param_function
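A standalone numeric check of ``_variance_from`` (double precision keeps the comparison tight): with ``grad`` the gradient summed over the batch and ``sgs`` the sum of squared individual gradients, the result is the element-wise biased sample variance.

from torch import allclose, float64, rand

N = 8
grad_batch = rand(N, 5, 10, dtype=float64)  # individual gradients
grad = grad_batch.sum(0)                    # what param.grad would hold
sgs = (grad_batch ** 2).sum(0)              # sum_grad_squared

variance = sgs / N - (grad / N) ** 2        # same algebra as _variance_from
assert allclose(variance, grad_batch.var(0, unbiased=False))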
diff --git a/backpack/extensions/mat_to_mat_jac_base.py b/backpack/extensions/mat_to_mat_jac_base.py
index ca9d214e0..937d78b33 100644
--- a/backpack/extensions/mat_to_mat_jac_base.py
+++ b/backpack/extensions/mat_to_mat_jac_base.py
@@ -1,24 +1,57 @@
-from .module_extension import ModuleExtension
+"""Contains base class for second order extensions."""
+from typing import List, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.module_extension import ModuleExtension
class MatToJacMat(ModuleExtension):
- """
- Base class for backpropagating matrices by multiplying with Jacobians.
- """
+ """Base class for backpropagation of matrices by multiplying with Jacobians."""
+
+ def __init__(self, derivatives: BaseDerivatives, params: List[str] = None):
+ """Initialization.
- def __init__(self, derivatives, params=None):
+ Args:
+ derivatives: class containing derivatives
+ params: list of parameter names
+ """
super().__init__(params)
self.derivatives = derivatives
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
+ def backpropagate(
+ self,
+ ext: BackpropExtension,
+ module: Module,
+ grad_inp: Tuple[Tensor],
+ grad_out: Tuple[Tensor],
+ backproped: Union[List[Tensor], Tensor],
+ ) -> Union[List[Tensor], Tensor]:
+ """Propagates second order information back.
+
+ Args:
+ ext: BackPACK extension
+ module: module through which to perform backpropagation
+ grad_inp: input gradients
+ grad_out: output gradients
+ backproped: backpropagation information
+
+ Returns:
+            backpropagated quantities w.r.t. the input
+ """
+ subsampling = ext.get_subsampling()
if isinstance(backproped, list):
- M_list = [
- self.derivatives.jac_t_mat_prod(module, grad_inp, grad_out, M)
+ return [
+ self.derivatives.jac_t_mat_prod(
+ module, grad_inp, grad_out, M, subsampling=subsampling
+ )
for M in backproped
]
- return list(M_list)
else:
return self.derivatives.jac_t_mat_prod(
- module, grad_inp, grad_out, backproped
+ module, grad_inp, grad_out, backproped, subsampling=subsampling
)
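A conceptual sketch in plain autograd (not the BackPACK internals) of what multiplying a backpropagated ``[V, N, ...]`` matrix with the transposed Jacobian means for a single module:

from torch import rand, stack
from torch.autograd import grad
from torch.nn import Linear

module = Linear(4, 3)
x = rand(2, 4, requires_grad=True)
y = module(x)

M = rand(5, 2, 3)  # V x N x out_features: V vectors to backpropagate per sample
JT_M = stack([grad(y, x, grad_outputs=v, retain_graph=True)[0] for v in M])
print(JT_M.shape)  # torch.Size([5, 2, 4]) = V x N x in_features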
diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py
index 1ed811a34..96e4fdfd1 100644
--- a/backpack/extensions/module_extension.py
+++ b/backpack/extensions/module_extension.py
@@ -1,9 +1,20 @@
-import warnings
+"""Contains base class for BackPACK module extensions."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from warnings import warn
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.utils.module_classification import is_loss
+
+if TYPE_CHECKING:
+ from backpack import BackpropExtension
class ModuleExtension:
- """
- Base class for a Module Extension for BackPACK.
+ """Base class for a Module Extension for BackPACK.
Descendants of this class need to
- define what parameters of the Module need to be treated (weight, bias)
@@ -12,92 +23,231 @@ class ModuleExtension:
needs to be propagated through the graph.
"""
- def __init__(self, params=None):
- """
- Parameters
- ----------
- params: [str]
- List of module parameters that need special treatment.
- for each param `p` in the list, instances of the extended module `m`
- need to have a field `m.p` and the class extending `ModuleExtension`
- need to provide a method with the same signature as the `backprop`
- method.
- The result of this method will be saved in the savefield of `m.p`.
- """
- if params is None:
- params = []
-
- self.__params = params
+ def __init__(self, params: List[str] = None):
+ """Initialization.
- for param in self.__params:
- extFunc = getattr(self, param, None)
- if extFunc is None:
- raise NotImplementedError
+ Args:
+ params: List of module parameters that need special treatment.
+ For each param `p` in the list, instances of the extended module `m`
+ need to have a field `m.p` and the class extending `ModuleExtension`
+ needs to provide a method with the same signature as the `backpropagate`
+ method.
+ The result of this method will be saved in the savefield of `m.p`.
- def backpropagate(self, ext, module, g_inp, g_out, bpQuantities):
+ Raises:
+ NotImplementedError: if child class doesn't have a method for each parameter
"""
- Main method to extend to backpropagate additional information through
- the graph.
-
- Parameters
- ----------
- ext: BackpropExtension
- Instance of the extension currently running
- module: torch.nn.Module
- Instance of the extended module
- g_inp: [Tensor]
- Gradient of the loss w.r.t. the inputs
- g_out: Tensor
- Gradient of the loss w.r.t. the output
- bpQuantities:
- Quantities backpropagated w.r.t. the output
+ self.__params: List[str] = [] if params is None else params
+
+ for param in self.__params:
+ if not hasattr(self, param):
+ raise NotImplementedError(
+ f"The module extension {self} is missing an implementation "
+ f"of how to calculate the quantity for {param}. "
+ f"This should be realized in a function "
+ f"{param}(extension, module, g_inp, g_out, bpQuantities) -> Any."
+ )
+
+ def backpropagate(
+ self,
+ extension: BackpropExtension,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ bpQuantities: Any,
+ ) -> Any:
+ """Backpropagation of additional information through the graph.
+
+ Args:
+ extension: Instance of the extension currently running
+ module: Instance of the extended module
+ g_inp: Gradient of the loss w.r.t. the inputs
+ g_out: Gradient of the loss w.r.t. the output
+ bpQuantities: Quantities backpropagated w.r.t. the output
-        Returns
-        -------
-        bpQuantities:
+        Returns:
Quantities backpropagated w.r.t. the input
"""
- warnings.warn("Backpropagate has not been overwritten")
-
- def apply(self, ext, module, g_inp, g_out):
- inp = module.input0
- out = module.output
-
- bpQuantities = self.__backproped_quantities(ext, out)
+ warn("Backpropagate has not been overwritten")
+
+ def __call__(
+ self,
+ extension: BackpropExtension,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None:
+ """Apply all actions required by the extension.
+
+ Fetch backpropagated quantities from module output, apply backpropagation
+ rule, and store as backpropagated quantities for the module input(s).
+
+ Args:
+ extension: current backpropagation extension
+ module: current module
+ g_inp: input gradients
+ g_out: output gradients
+
+ Raises:
+            AssertionError: if the extension expects a backpropagated quantity,
+                but none is stored for the module output and the module is not a loss.
+ """
+ self.check_hyperparameters_module_extension(extension, module, g_inp, g_out)
+ delete_old_quantities = not self.__should_retain_backproped_quantities(module)
+ bp_quantity = self.__get_backproped_quantity(
+ extension, module.output, delete_old_quantities
+ )
+ if (
+ extension.expects_backpropagation_quantities()
+ and bp_quantity is None
+ and not is_loss(module)
+ ):
+ raise AssertionError(
+ "BackPACK extension expects a backpropagation quantity but it is None. "
+ f"Module: {module}, Extension: {extension}."
+ )
for param in self.__params:
if self.__param_exists_and_requires_grad(module, param):
extFunc = getattr(self, param)
- extValue = extFunc(ext, module, g_inp, g_out, bpQuantities)
- self.__save(extValue, ext, module, param)
+ extValue = extFunc(extension, module, g_inp, g_out, bp_quantity)
+ self.__save_value_on_parameter(extValue, extension, module, param)
- bpQuantities = self.backpropagate(ext, module, g_inp, g_out, bpQuantities)
-
- self.__backprop_quantities(ext, inp, out, bpQuantities)
+ module_inputs = self.__get_inputs_for_backpropagation(extension, module)
+ if module_inputs:
+ bp_quantity = self.backpropagate(
+ extension, module, g_inp, g_out, bp_quantity
+ )
+ for module_inp in module_inputs:
+ self.__save_backproped_quantity(extension, module_inp, bp_quantity)
@staticmethod
- def __backproped_quantities(ext, out):
- return getattr(out, ext.savefield, None)
+ def __get_inputs_for_backpropagation(
+ extension: BackpropExtension, module: Module
+ ) -> Tuple[Tensor]:
+ """Returns the inputs on which a backpropagation should be performed.
+
+ Args:
+ extension: current extension
+ module: current module
+
+ Returns:
+ the inputs which need a backpropagation quantity
+ """
+ module_inputs: Tuple[Tensor, ...] = ()
+
+ if extension.expects_backpropagation_quantities():
+ i = 0
+ while hasattr(module, f"input{i}"):
+ input = getattr(module, f"input{i}")
+ if input.requires_grad:
+ module_inputs += (input,)
+ i += 1
+
+ return module_inputs
@staticmethod
- def __backprop_quantities(ext, inp, out, bpQuantities):
+ def __should_retain_backproped_quantities(module: Module) -> bool:
+ """Whether the backpropagation quantities should be kept.
- setattr(inp, ext.savefield, bpQuantities)
+ This is old code inherited and not tested.
- is_a_leaf = out.grad_fn is None
- retain_grad_is_on = getattr(out, "retains_grad", False)
- inp_is_out = id(inp) == id(out)
- should_retain_grad = is_a_leaf or retain_grad_is_on or inp_is_out
+ Args:
+ module: current module
+
+ Returns:
+ whether backpropagation quantities should be kept
+ """
+ is_a_leaf = module.output.grad_fn is None
+ retain_grad_is_on = getattr(module.output, "retains_grad", False)
+ # inp_is_out = id(module.input0) == id(module.output)
+ should_retain_grad = is_a_leaf or retain_grad_is_on # or inp_is_out
+ return should_retain_grad
+
+ @staticmethod
+ def __get_backproped_quantity(
+ extension: BackpropExtension,
+ reference_tensor: Tensor,
+ delete_old: bool,
+    ) -> Optional[Tensor]:
+ """Fetch backpropagated quantities attached to the module output.
+
+        The ``data_ptr()`` of ``reference_tensor`` is used as the lookup key.
+
+ Args:
+ extension: current BackPACK extension
+ reference_tensor: the output Tensor of the current module
+ delete_old: whether to delete the old backpropagated quantity
+
+ Returns:
+ the backpropagation quantity
+ """
+ return extension.saved_quantities.retrieve_quantity(
+ reference_tensor.data_ptr(), delete_old
+ )
- if not should_retain_grad:
- if hasattr(out, ext.savefield):
- delattr(out, ext.savefield)
+ @staticmethod
+ def __save_backproped_quantity(
+ extension: BackpropExtension, reference_tensor: Tensor, bpQuantities: Any
+ ) -> None:
+ """Save additional information backpropagated for a tensor.
+
+ Args:
+ extension: current BackPACK extension
+ reference_tensor: reference tensor for which additional information
+ is backpropagated.
+ bpQuantities: backpropagation quantities that should be saved
+ """
+ extension.saved_quantities.save_quantity(
+ reference_tensor.data_ptr(),
+ bpQuantities,
+ extension.accumulate_backpropagated_quantities,
+ )
@staticmethod
- def __param_exists_and_requires_grad(module, param):
- param_exists = getattr(module, param) is not None
- return param_exists and getattr(module, param).requires_grad
+ def __param_exists_and_requires_grad(module: Module, param_str: str) -> bool:
+ """Whether the module has the parameter and it requires gradient.
+
+ Args:
+ module: current module
+ param_str: parameter name
+
+ Returns:
+ whether the module has the parameter and it requires gradient
+ """
+ param_exists = getattr(module, param_str) is not None
+ return param_exists and getattr(module, param_str).requires_grad
@staticmethod
- def __save(value, extension, module, param):
- setattr(getattr(module, param), extension.savefield, value)
+ def __save_value_on_parameter(
+ value: Any, extension: BackpropExtension, module: Module, param_str: str
+ ) -> None:
+ """Saves the value on the parameter of that module.
+
+ Args:
+ value: The value that should be saved.
+ extension: The current BackPACK extension.
+ module: current module
+ param_str: parameter name
+ """
+ setattr(getattr(module, param_str), extension.savefield, value)
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None:
+ """Check whether the current module is supported by the extension.
+
+ Child classes can override this method.
+
+ Args:
+ ext: current extension
+ module: module
+ g_inp: input gradients
+ g_out: output gradients
+ """
+ pass
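A hedged sketch of the contract enforced in ``__init__`` above (``MyModuleExtension`` and its rule are invented): an extension that lists ``weight`` in ``params`` must provide a ``weight`` method with the ``backpropagate`` signature, and its return value is stored on ``module.weight`` under the extension's savefield.

from backpack.extensions.module_extension import ModuleExtension


class MyModuleExtension(ModuleExtension):
    """Toy module extension that treats the ``weight`` parameter."""

    def __init__(self):
        # without a `weight` method, super().__init__ would raise NotImplementedError
        super().__init__(params=["weight"])

    def weight(self, extension, module, g_inp, g_out, bpQuantities):
        # toy per-parameter rule derived from the output gradient
        return (g_out[0] ** 2).sum()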
diff --git a/backpack/extensions/saved_quantities.py b/backpack/extensions/saved_quantities.py
new file mode 100644
index 000000000..c5006fc2a
--- /dev/null
+++ b/backpack/extensions/saved_quantities.py
@@ -0,0 +1,50 @@
+"""Class for saving backpropagation quantities."""
+from typing import Any, Callable, Dict, Union
+
+from torch import Tensor
+
+
+class SavedQuantities:
+ """Implements interface to save backpropagation quantities."""
+
+ def __init__(self):
+ """Initialization."""
+ self._saved_quantities: Dict[int, Tensor] = {}
+
+ def save_quantity(
+ self,
+ key: int,
+ quantity: Tensor,
+ accumulation_function: Callable[[Any, Any], Any],
+ ) -> None:
+ """Saves the quantity under the specified key.
+
+ Accumulate quantities which already have an entry.
+
+ Args:
+ key: data_ptr() of reference tensor (module.input0).
+ quantity: tensor to save
+ accumulation_function: function defining how to accumulate quantity
+ """
+ if key in self._saved_quantities:
+ existing = self.retrieve_quantity(key, delete_old=True)
+ save_value = accumulation_function(existing, quantity)
+ else:
+ save_value = quantity
+
+ self._saved_quantities[key] = save_value
+
+ def retrieve_quantity(self, key: int, delete_old: bool) -> Union[Tensor, None]:
+ """Returns the saved quantity.
+
+ Args:
+ key: data_ptr() of module.output.
+ delete_old: whether to delete the old quantity
+
+ Returns:
+ the saved quantity, None if it does not exist
+ """
+ get_value = (
+ self._saved_quantities.pop if delete_old else self._saved_quantities.get
+ )
+ return get_value(key, None)
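A minimal usage sketch of ``SavedQuantities`` (the tensors are placeholders): quantities are keyed by the ``data_ptr()`` of the tensor they belong to, and colliding entries are merged with the supplied accumulation function.

from operator import add

from torch import rand

from backpack.extensions.saved_quantities import SavedQuantities

reference = rand(3)  # stands in for a module input/output tensor
store = SavedQuantities()

store.save_quantity(reference.data_ptr(), rand(2, 3), add)
store.save_quantity(reference.data_ptr(), rand(2, 3), add)  # accumulated via add

quantity = store.retrieve_quantity(reference.data_ptr(), delete_old=True)
print(quantity.shape)                                       # torch.Size([2, 3])
print(store.retrieve_quantity(reference.data_ptr(), True))  # None: entry was deleted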
diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py
index afe9ebc2f..3d3de33c9 100644
--- a/backpack/extensions/secondorder/__init__.py
+++ b/backpack/extensions/secondorder/__init__.py
@@ -17,11 +17,22 @@
   :func:`KFRA <backpack.extensions.KFRA>`,
   :func:`KFLR <backpack.extensions.KFLR>`.
 - The diagonal of the Hessian :func:`DiagHessian <backpack.extensions.DiagHessian>`
+- The symmetric (square root) factorization of the GGN/Fisher information,
+  using exact computation
+  (:func:`SqrtGGNExact <backpack.extensions.SqrtGGNExact>`)
+  or a Monte-Carlo (MC) approximation
+  (:func:`SqrtGGNMC <backpack.extensions.SqrtGGNMC>`)
"""
-from .diag_ggn import BatchDiagGGNExact, BatchDiagGGNMC, DiagGGNExact, DiagGGNMC
-from .diag_hessian import BatchDiagHessian, DiagHessian
-from .hbp import HBP, KFAC, KFLR, KFRA
+from backpack.extensions.secondorder.diag_ggn import (
+ BatchDiagGGNExact,
+ BatchDiagGGNMC,
+ DiagGGNExact,
+ DiagGGNMC,
+)
+from backpack.extensions.secondorder.diag_hessian import BatchDiagHessian, DiagHessian
+from backpack.extensions.secondorder.hbp import HBP, KFAC, KFLR, KFRA
+from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC
__all__ = [
"DiagGGNExact",
@@ -34,4 +45,6 @@
"KFLR",
"KFRA",
"HBP",
+ "SqrtGGNExact",
+ "SqrtGGNMC",
]
diff --git a/backpack/extensions/secondorder/base.py b/backpack/extensions/secondorder/base.py
new file mode 100644
index 000000000..d65fa548f
--- /dev/null
+++ b/backpack/extensions/secondorder/base.py
@@ -0,0 +1,9 @@
+"""Contains base classes for second order extensions."""
+from backpack.extensions.backprop_extension import BackpropExtension
+
+
+class SecondOrderBackpropExtension(BackpropExtension):
+ """Base backpropagation extension for second order."""
+
+ def expects_backpropagation_quantities(self) -> bool: # noqa: D102
+ return True
diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py
index 98ab7c0ea..e18d7eea3 100644
--- a/backpack/extensions/secondorder/diag_ggn/__init__.py
+++ b/backpack/extensions/secondorder/diag_ggn/__init__.py
@@ -1,9 +1,28 @@
+"""Module contains definitions of DiagGGN extensions.
+
+Contains:
+DiagGGN(SecondOrderBackpropExtension)
+DiagGGNExact(DiagGGN)
+DiagGGNMC(DiagGGN)
+BatchDiagGGN(SecondOrderBackpropExtension)
+BatchDiagGGNExact(BatchDiagGGN)
+BatchDiagGGNMC(BatchDiagGGN)
+"""
+from torch import Tensor
from torch.nn import (
ELU,
+ LSTM,
+ RNN,
SELU,
+ AdaptiveAvgPool1d,
+ AdaptiveAvgPool2d,
+ AdaptiveAvgPool3d,
AvgPool1d,
AvgPool2d,
AvgPool3d,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
@@ -12,7 +31,9 @@
ConvTranspose3d,
CrossEntropyLoss,
Dropout,
+ Embedding,
Flatten,
+ Identity,
LeakyReLU,
Linear,
LogSigmoid,
@@ -26,27 +47,36 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.custom_module.branching import SumModule
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.scale_module import ScaleModule
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from backpack.extensions.secondorder.hbp import LossHessianStrategy
from . import (
activations,
+ adaptive_avg_pool_nd,
+ batchnorm_nd,
conv1d,
conv2d,
conv3d,
convtranspose1d,
convtranspose2d,
convtranspose3d,
+ custom_module,
dropout,
+ embedding,
flatten,
linear,
losses,
padding,
+ permute,
pooling,
+ rnn,
)
-class DiagGGN(BackpropExtension):
+class DiagGGN(SecondOrderBackpropExtension):
"""Base class for diagonal generalized Gauss-Newton/Fisher matrix."""
VALID_LOSS_HESSIAN_STRATEGIES = [
@@ -54,7 +84,16 @@ class DiagGGN(BackpropExtension):
LossHessianStrategy.SAMPLING,
]
- def __init__(self, loss_hessian_strategy, savefield):
+ def __init__(self, loss_hessian_strategy: str, savefield: str):
+ """Initialization.
+
+ Args:
+ loss_hessian_strategy: either LossHessianStrategy.EXACT or .SAMPLING
+ savefield: the field where to save the calculated property
+
+ Raises:
+ ValueError: if chosen loss strategy is not valid.
+ """
if loss_hessian_strategy not in self.VALID_LOSS_HESSIAN_STRATEGIES:
raise ValueError(
"Unknown hessian strategy: {}".format(loss_hessian_strategy)
@@ -91,13 +130,31 @@ def __init__(self, loss_hessian_strategy, savefield):
LogSigmoid: activations.DiagGGNLogSigmoid(),
ELU: activations.DiagGGNELU(),
SELU: activations.DiagGGNSELU(),
+ Identity: custom_module.DiagGGNScaleModule(),
+ ScaleModule: custom_module.DiagGGNScaleModule(),
+ SumModule: custom_module.DiagGGNSumModule(),
+ RNN: rnn.DiagGGNRNN(),
+ LSTM: rnn.DiagGGNLSTM(),
+ Permute: permute.DiagGGNPermute(),
+ AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1),
+ AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2),
+ AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(3),
+ BatchNorm1d: batchnorm_nd.DiagGGNBatchNormNd(),
+ BatchNorm2d: batchnorm_nd.DiagGGNBatchNormNd(),
+ BatchNorm3d: batchnorm_nd.DiagGGNBatchNormNd(),
+ Embedding: embedding.DiagGGNEmbedding(),
},
)
+ def accumulate_backpropagated_quantities(
+ self, existing: Tensor, other: Tensor
+ ) -> Tensor: # noqa: D102
+ return existing + other
+
class DiagGGNExact(DiagGGN):
- """
- Diagonal of the Generalized Gauss-Newton/Fisher.
+ """Diagonal of the Generalized Gauss-Newton/Fisher.
+
Uses the exact Hessian of the loss w.r.t. the model output.
Stores the output in :code:`diag_ggn_exact`,
@@ -105,16 +162,16 @@ class DiagGGNExact(DiagGGN):
For a faster but less precise alternative,
see :py:meth:`backpack.extensions.DiagGGNMC`.
-
"""
def __init__(self):
+ """Initialization. Chooses exact loss strategy and savefield diag_ggn_exact."""
super().__init__(LossHessianStrategy.EXACT, "diag_ggn_exact")
class DiagGGNMC(DiagGGN):
- """
- Diagonal of the Generalized Gauss-Newton/Fisher.
+ """Diagonal of the Generalized Gauss-Newton/Fisher.
+
Uses a Monte-Carlo approximation of
the Hessian of the loss w.r.t. the model output.
@@ -123,21 +180,27 @@ class DiagGGNMC(DiagGGN):
For a more precise but slower alternative,
see :py:meth:`backpack.extensions.DiagGGNExact`.
-
- Args:
- mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``.
-
"""
- def __init__(self, mc_samples=1):
+ def __init__(self, mc_samples: int = 1):
+ """Initialization. Chooses sampling loss strategy and savefield diag_ggn_mc.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples. Default: ``1``.
+ """
self._mc_samples = mc_samples
super().__init__(LossHessianStrategy.SAMPLING, "diag_ggn_mc")
- def get_num_mc_samples(self):
+ def get_num_mc_samples(self) -> int:
+ """Returns number of Monte-Carlo samples.
+
+ Returns:
+ number of Monte-Carlo samples
+ """
return self._mc_samples
-class BatchDiagGGN(BackpropExtension):
+class BatchDiagGGN(SecondOrderBackpropExtension):
"""Base class for batched diagonal generalized Gauss-Newton/Fisher matrix."""
VALID_LOSS_HESSIAN_STRATEGIES = [
@@ -145,7 +208,16 @@ class BatchDiagGGN(BackpropExtension):
LossHessianStrategy.SAMPLING,
]
- def __init__(self, loss_hessian_strategy, savefield):
+ def __init__(self, loss_hessian_strategy: str, savefield: str):
+ """Initialization.
+
+ Args:
+ loss_hessian_strategy: either LossHessianStrategy.EXACT or .SAMPLING
+ savefield: name of variable where to save calculated quantity
+
+ Raises:
+ ValueError: if chosen loss strategy is not valid.
+ """
if loss_hessian_strategy not in self.VALID_LOSS_HESSIAN_STRATEGIES:
raise ValueError(
"Unknown hessian strategy: {}".format(loss_hessian_strategy)
@@ -181,13 +253,31 @@ def __init__(self, loss_hessian_strategy, savefield):
LogSigmoid: activations.DiagGGNLogSigmoid(),
ELU: activations.DiagGGNELU(),
SELU: activations.DiagGGNSELU(),
+ Identity: custom_module.DiagGGNScaleModule(),
+ ScaleModule: custom_module.DiagGGNScaleModule(),
+ SumModule: custom_module.DiagGGNSumModule(),
+ RNN: rnn.BatchDiagGGNRNN(),
+ LSTM: rnn.BatchDiagGGNLSTM(),
+ Permute: permute.DiagGGNPermute(),
+ AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(1),
+ AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(2),
+ AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagGGNAdaptiveAvgPoolNd(3),
+ BatchNorm1d: batchnorm_nd.BatchDiagGGNBatchNormNd(),
+ BatchNorm2d: batchnorm_nd.BatchDiagGGNBatchNormNd(),
+ BatchNorm3d: batchnorm_nd.BatchDiagGGNBatchNormNd(),
+ Embedding: embedding.BatchDiagGGNEmbedding(),
},
)
+ def accumulate_backpropagated_quantities(
+ self, existing: Tensor, other: Tensor
+ ) -> Tensor: # noqa: D102
+ return existing + other
+
class BatchDiagGGNExact(BatchDiagGGN):
- """
- Individual diagonal of the Generalized Gauss-Newton/Fisher.
+ """Individual diagonal of the Generalized Gauss-Newton/Fisher.
+
Uses the exact Hessian of the loss w.r.t. the model output.
Stores the output in ``diag_ggn_exact_batch`` as a ``[N x ...]`` tensor,
@@ -195,15 +285,16 @@ class BatchDiagGGNExact(BatchDiagGGN):
"""
def __init__(self):
- super().__init__(
- loss_hessian_strategy=LossHessianStrategy.EXACT,
- savefield="diag_ggn_exact_batch",
- )
+ """Initialization.
+
+ Chooses exact loss strategy and savefield diag_ggn_exact_batch.
+ """
+ super().__init__(LossHessianStrategy.EXACT, "diag_ggn_exact_batch")
class BatchDiagGGNMC(BatchDiagGGN):
- """
- Individual diagonal of the Generalized Gauss-Newton/Fisher.
+ """Individual diagonal of the Generalized Gauss-Newton/Fisher.
+
Uses a Monte-Carlo approximation of
the Hessian of the loss w.r.t. the model output.
@@ -212,18 +303,23 @@ class BatchDiagGGNMC(BatchDiagGGN):
For a more precise but slower alternative,
see :py:meth:`backpack.extensions.BatchDiagGGNExact`.
+ """
- Args:
- mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``.
+ def __init__(self, mc_samples: int = 1):
+ """Initialization.
- """
+ Chooses sampling loss strategy and savefield diag_ggn_mc_batch.
- def __init__(self, mc_samples=1):
+ Args:
+ mc_samples: Number of Monte-Carlo samples. Default: ``1``.
+ """
self._mc_samples = mc_samples
- super().__init__(
- loss_hessian_strategy=LossHessianStrategy.SAMPLING,
- savefield="diag_ggn_mc_batch",
- )
+ super().__init__(LossHessianStrategy.SAMPLING, "diag_ggn_mc_batch")
+
+ def get_num_mc_samples(self) -> int:
+ """Returns number of Monte-Carlo samples.
- def get_num_mc_samples(self):
+ Returns:
+ number of Monte-Carlo samples
+ """
return self._mc_samples
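A hedged usage sketch (toy model and data): ``DiagGGNExact`` writes the GGN diagonal to ``diag_ggn_exact`` with the parameter's shape, while ``BatchDiagGGNExact`` writes per-sample diagonals to ``diag_ggn_exact_batch`` with a leading batch axis.

from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchDiagGGNExact, DiagGGNExact

model = extend(Sequential(Linear(10, 5)))
loss_func = extend(CrossEntropyLoss())
X, y = rand(8, 10), randint(0, 5, (8,))

with backpack(DiagGGNExact(), BatchDiagGGNExact()):
    loss_func(model(X), y).backward()

weight = model[0].weight
print(weight.diag_ggn_exact.shape)        # torch.Size([5, 10])
print(weight.diag_ggn_exact_batch.shape)  # torch.Size([8, 5, 10])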
diff --git a/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py
new file mode 100644
index 000000000..b2cfceb46
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py
@@ -0,0 +1,15 @@
+"""DiagGGN extension for AdaptiveAvgPool."""
+from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNAdaptiveAvgPoolNd(DiagGGNBaseModule):
+ """DiagGGN extension for AdaptiveAvgPool."""
+
+ def __init__(self, N: int):
+ """Initialization.
+
+ Args:
+ N: number of free dimensions, e.g. use N=1 for AdaptiveAvgPool1d
+ """
+ super().__init__(derivatives=AdaptiveAvgPoolNDDerivatives(N=N))
diff --git a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py
new file mode 100644
index 000000000..c0aa7c29b
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py
@@ -0,0 +1,44 @@
+"""DiagGGN extension for BatchNorm."""
+from typing import Tuple, Union
+
+from torch import Tensor
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+from backpack.utils.errors import batch_norm_raise_error_if_train
+
+
+class DiagGGNBatchNormNd(DiagGGNBaseModule):
+ """DiagGGN extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=True)
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
+
+
+class BatchDiagGGNBatchNormNd(DiagGGNBaseModule):
+ """BatchDiagGGN extension for BatchNorm."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(BatchNormNdDerivatives(), ["weight", "bias"], sum_batch=False)
+
+ def check_hyperparameters_module_extension(
+ self,
+ ext: BackpropExtension,
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ ) -> None: # noqa: D102
+ batch_norm_raise_error_if_train(module)
diff --git a/backpack/extensions/secondorder/diag_ggn/custom_module.py b/backpack/extensions/secondorder/diag_ggn/custom_module.py
new file mode 100644
index 000000000..293ed4281
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/custom_module.py
@@ -0,0 +1,20 @@
+"""DiagGGN extensions for backpack's custom modules."""
+from backpack.core.derivatives.scale_module import ScaleModuleDerivatives
+from backpack.core.derivatives.sum_module import SumModuleDerivatives
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNScaleModule(DiagGGNBaseModule):
+ """DiagGGN extension for ScaleModule."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=ScaleModuleDerivatives())
+
+
+class DiagGGNSumModule(DiagGGNBaseModule):
+ """DiagGGN extension for SumModule."""
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__(derivatives=SumModuleDerivatives())
diff --git a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py
index e97334dea..203b8ebd6 100644
--- a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py
+++ b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py
@@ -1,6 +1,76 @@
+"""Contains DiagGGN base class."""
+from typing import Callable, List, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import (
+ BaseDerivatives,
+ BaseParameterDerivatives,
+)
from backpack.extensions.mat_to_mat_jac_base import MatToJacMat
+from backpack.extensions.module_extension import ModuleExtension
class DiagGGNBaseModule(MatToJacMat):
- def __init__(self, derivatives, params=None):
+ """Base class for DiagGGN extension."""
+
+ def __init__(
+ self,
+ derivatives: Union[BaseDerivatives, BaseParameterDerivatives],
+ params: List[str] = None,
+ sum_batch: bool = None,
+ ):
+ """Initialization.
+
+ If params and sum_batch is provided:
+ Creates a method named after parameter for each parameter. Checks if that
+ method is implemented, so a child class can implement a more efficient version.
+
+ Args:
+ derivatives: class containing derivatives
+ params: list of parameter names. Defaults to None.
+ sum_batch: Specifies whether the created method sums along batch axis.
+                Should be False for batch-wise extensions (BatchDiagGGN)
+                and True for extensions that sum over the batch (DiagGGN).
+ Defaults to None.
+ """
+ if params is not None and sum_batch is not None:
+ for param in params:
+ if not hasattr(self, param):
+ setattr(self, param, self._make_param_method(param, sum_batch))
super().__init__(derivatives, params=params)
+
+ def _make_param_method(
+ self, param_str: str, sum_batch: bool
+ ) -> Callable[
+ [ModuleExtension, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor
+ ]:
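+        """Create a function computing the (per-sample) GGN diagonal for a parameter.
+
+        Args:
+            param_str: name of the parameter
+            sum_batch: whether the result is summed over the batch axis
+
+        Returns:
+            function that squares the backpropagated square-root factors and sums them
+        """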
+ def _param(
+ ext: ModuleExtension,
+ module: Module,
+ grad_inp: Tuple[Tensor],
+ grad_out: Tuple[Tensor],
+ backproped: Tensor,
+ ) -> Tensor:
+ """Returns diagonal of GGN.
+
+ Args:
+ ext: extension
+ module: module through which to backpropagate
+ grad_inp: input gradients
+ grad_out: output gradients
+ backproped: backpropagated information
+
+ Returns:
+ diagonal
+ """
+            axis: Tuple[int, ...] = (0, 1) if sum_batch else (0,)
+ return (
+ self.derivatives.param_mjp(
+ param_str, module, grad_inp, grad_out, backproped, sum_batch=False
+ )
+ ** 2
+ ).sum(axis=axis)
+
+ return _param
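A standalone numeric check of why squaring the backpropagated square-root factors and summing over the leading axes (as in ``_param`` above) yields the GGN diagonal; shapes and names are illustrative.

from torch import allclose, einsum, float64, rand

C, N, P = 3, 4, 5
sqrt_ggn = rand(C, N, P, dtype=float64)  # square-root factors, [C, N, *param.shape]

ggn = einsum("cnp,cnq->pq", sqrt_ggn, sqrt_ggn)  # full P x P GGN block
diag = (sqrt_ggn ** 2).sum(axis=(0, 1))          # the sum_batch=True reduction
assert allclose(diag, ggn.diagonal())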
diff --git a/backpack/extensions/secondorder/diag_ggn/embedding.py b/backpack/extensions/secondorder/diag_ggn/embedding.py
new file mode 100644
index 000000000..1021b089b
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/embedding.py
@@ -0,0 +1,23 @@
+"""DiagGGN extension for Embedding."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNEmbedding(DiagGGNBaseModule):
+ """DiagGGN extension of Embedding."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=True
+ )
+
+
+class BatchDiagGGNEmbedding(DiagGGNBaseModule):
+ """DiagGGN extension of Embedding."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=False
+ )
diff --git a/backpack/extensions/secondorder/diag_ggn/flatten.py b/backpack/extensions/secondorder/diag_ggn/flatten.py
index 60c1ca8d4..cf6f63358 100644
--- a/backpack/extensions/secondorder/diag_ggn/flatten.py
+++ b/backpack/extensions/secondorder/diag_ggn/flatten.py
@@ -5,9 +5,3 @@
class DiagGGNFlatten(DiagGGNBaseModule):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/secondorder/diag_ggn/losses.py b/backpack/extensions/secondorder/diag_ggn/losses.py
index 377adb52b..6679a9b3e 100644
--- a/backpack/extensions/secondorder/diag_ggn/losses.py
+++ b/backpack/extensions/secondorder/diag_ggn/losses.py
@@ -9,7 +9,6 @@
class DiagGGNLoss(DiagGGNBaseModule):
def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
hess_func = self.make_loss_hessian_func(ext)
-
return hess_func(module, grad_inp, grad_out)
def make_loss_hessian_func(self, ext):
@@ -21,7 +20,6 @@ def make_loss_hessian_func(self, ext):
elif loss_hessian_strategy == LossHessianStrategy.SAMPLING:
mc_samples = ext.get_num_mc_samples()
return partial(self.derivatives.sqrt_hessian_sampled, mc_samples=mc_samples)
-
else:
raise ValueError(
"Unknown hessian strategy {}".format(loss_hessian_strategy)
diff --git a/backpack/extensions/secondorder/diag_ggn/permute.py b/backpack/extensions/secondorder/diag_ggn/permute.py
new file mode 100644
index 000000000..7e7db118c
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/permute.py
@@ -0,0 +1,11 @@
+"""Module defining DiagGGNPermute."""
+from backpack.core.derivatives.permute import PermuteDerivatives
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNPermute(DiagGGNBaseModule):
+ """DiagGGN extension of Permute."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(derivatives=PermuteDerivatives())
diff --git a/backpack/extensions/secondorder/diag_ggn/rnn.py b/backpack/extensions/secondorder/diag_ggn/rnn.py
new file mode 100644
index 000000000..7c926c945
--- /dev/null
+++ b/backpack/extensions/secondorder/diag_ggn/rnn.py
@@ -0,0 +1,52 @@
+"""Module implementing GGN for RNN."""
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNRNN(DiagGGNBaseModule):
+ """Calculating diagonal of GGN."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=RNNDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ sum_batch=True,
+ )
+
+
+class BatchDiagGGNRNN(DiagGGNBaseModule):
+ """Calculating per-sample diagonal of GGN."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=RNNDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ sum_batch=False,
+ )
+
+
+class DiagGGNLSTM(DiagGGNBaseModule):
+ """Calculating GGN diagonal of LSTM."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=LSTMDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ sum_batch=True,
+ )
+
+
+class BatchDiagGGNLSTM(DiagGGNBaseModule):
+ """Calculating per-sample diagonal of GGN."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ derivatives=LSTMDerivatives(),
+ params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"],
+ sum_batch=False,
+ )
diff --git a/backpack/extensions/secondorder/diag_hessian/__init__.py b/backpack/extensions/secondorder/diag_hessian/__init__.py
index 6bb9933d8..ffdf7d639 100644
--- a/backpack/extensions/secondorder/diag_hessian/__init__.py
+++ b/backpack/extensions/secondorder/diag_hessian/__init__.py
@@ -31,7 +31,7 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from . import (
activations,
@@ -50,7 +50,7 @@
)
-class DiagHessian(BackpropExtension):
+class DiagHessian(SecondOrderBackpropExtension):
"""BackPACK extension that computes the Hessian diagonal.
Stores the output in :code:`diag_h`, has the same dimensions as the gradient.
@@ -96,7 +96,7 @@ def __init__(self):
)
-class BatchDiagHessian(BackpropExtension):
+class BatchDiagHessian(SecondOrderBackpropExtension):
"""BackPACK extensions that computes the per-sample (individual) Hessian diagonal.
Stores the output in ``diag_h_batch`` as a ``[N x ...]`` tensor,
diff --git a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py
index 58c15ee60..acf58e718 100644
--- a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py
+++ b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py
@@ -24,9 +24,9 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped):
return {"matrices": bp_matrices, "signs": bp_signs}
def __local_curvatures(self, module, g_inp, g_out):
- if self.derivatives.hessian_is_zero():
+ if self.derivatives.hessian_is_zero(module):
return []
- if not self.derivatives.hessian_is_diagonal():
+ if not self.derivatives.hessian_is_diagonal(module):
raise NotImplementedError
def positive_part(sign, H):
diff --git a/backpack/extensions/secondorder/diag_hessian/flatten.py b/backpack/extensions/secondorder/diag_hessian/flatten.py
index d6d28b7c2..d8b01357a 100644
--- a/backpack/extensions/secondorder/diag_hessian/flatten.py
+++ b/backpack/extensions/secondorder/diag_hessian/flatten.py
@@ -5,9 +5,3 @@
class DiagHFlatten(DiagHBaseModule):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/secondorder/hbp/__init__.py b/backpack/extensions/secondorder/hbp/__init__.py
index 7f469da4d..2e529c4ef 100644
--- a/backpack/extensions/secondorder/hbp/__init__.py
+++ b/backpack/extensions/secondorder/hbp/__init__.py
@@ -13,8 +13,8 @@
ZeroPad2d,
)
-from backpack.extensions.backprop_extension import BackpropExtension
from backpack.extensions.curvature import Curvature
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
from backpack.extensions.secondorder.hbp.hbp_options import (
BackpropStrategy,
ExpectationApproximation,
@@ -24,7 +24,7 @@
from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling
-class HBP(BackpropExtension):
+class HBP(SecondOrderBackpropExtension):
def __init__(
self,
curv_type,
diff --git a/backpack/extensions/secondorder/hbp/flatten.py b/backpack/extensions/secondorder/hbp/flatten.py
index 990d0b023..c20014e92 100644
--- a/backpack/extensions/secondorder/hbp/flatten.py
+++ b/backpack/extensions/secondorder/hbp/flatten.py
@@ -5,9 +5,3 @@
class HBPFlatten(HBPBaseModule):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())
-
- def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
- if self.derivatives.is_no_op(module):
- return backproped
- else:
- return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/secondorder/hbp/hbpbase.py b/backpack/extensions/secondorder/hbp/hbpbase.py
index 6bf2647a8..e6258b79b 100644
--- a/backpack/extensions/secondorder/hbp/hbpbase.py
+++ b/backpack/extensions/secondorder/hbp/hbpbase.py
@@ -34,10 +34,10 @@ def backpropagate_batch_average(self, ext, module, g_inp, g_out, H):
return ggn
def second_order_module_effects(self, module, g_inp, g_out):
- if self.derivatives.hessian_is_zero():
+ if self.derivatives.hessian_is_zero(module):
return None
- elif not self.derivatives.hessian_is_diagonal():
+ elif not self.derivatives.hessian_is_diagonal(module):
raise NotImplementedError(
"Residual terms are only supported for elementwise functions"
)
diff --git a/backpack/extensions/secondorder/hbp/linear.py b/backpack/extensions/secondorder/hbp/linear.py
index 779459e14..c89791a3c 100644
--- a/backpack/extensions/secondorder/hbp/linear.py
+++ b/backpack/extensions/secondorder/hbp/linear.py
@@ -1,4 +1,5 @@
from torch import einsum
+from torch.nn import Linear
from backpack.core.derivatives.linear import LinearDerivatives
from backpack.extensions.secondorder.hbp.hbp_options import (
@@ -13,6 +14,7 @@ def __init__(self):
super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"])
def weight(self, ext, module, g_inp, g_out, backproped):
+ self.check_parameters(ext, module)
bp_strategy = ext.get_backprop_strategy()
if BackpropStrategy.is_batch_average(bp_strategy):
@@ -44,6 +46,7 @@ def _factor_from_sqrt(self, backproped):
return [einsum("vni,vnj->ij", (backproped, backproped))]
def bias(self, ext, module, g_inp, g_out, backproped):
+ self.check_parameters(ext, module)
bp_strategy = ext.get_backprop_strategy()
if BackpropStrategy.is_batch_average(bp_strategy):
@@ -61,3 +64,19 @@ def __mean_input_outer(self, module):
N = module.input0.size(0)
flat_input = module.input0.reshape(N, -1)
return einsum("ni,nj->ij", (flat_input, flat_input)) / N
+
+ def check_parameters(self, ext, module: Linear) -> None:
+ """Raise an exception if module parameters are not supported.
+
+ Args:
+ ext (KFAC or KFRA or KFLR): Extension calling out to the module.
+ module: Linear layer.
+
+ Raises:
+ NotImplementedError: If the setting is not implemented.
+ """
+ if module.input0.dim() != 2:
+ raise NotImplementedError(
+ f"Only 2d inputs are supported by {ext.__class__.__name__} "
+ + f"(got {module.input0.dim()})."
+ )
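A hypothetical usage sketch for the check above, assuming BackPACK's ``KFAC`` extension and its ``kfac`` savefield; KFAC on ``Linear`` requires 2d layer inputs, which ``check_parameters`` enforces:

```python
import torch
from torch.nn import Linear, MSELoss, Sequential

from backpack import backpack, extend
from backpack.extensions import KFAC

model = extend(Sequential(Linear(4, 3), Linear(3, 2)))
lossfunc = extend(MSELoss())

X, y = torch.randn(8, 4), torch.randn(8, 2)  # 2d input: supported
loss = lossfunc(model(X), y)

with backpack(KFAC(mc_samples=1)):
    loss.backward()

for param in model.parameters():
    print([factor.shape for factor in param.kfac])  # Kronecker factors

# A 3d input, e.g. X of shape (8, 7, 4), would make check_parameters raise
# NotImplementedError before any Kronecker factors are computed.
```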
diff --git a/backpack/extensions/secondorder/sqrt_ggn/__init__.py b/backpack/extensions/secondorder/sqrt_ggn/__init__.py
new file mode 100644
index 000000000..ebd31dae4
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/__init__.py
@@ -0,0 +1,181 @@
+"""Defines base class and extensions for computing the GGN/Fisher matrix square root."""
+
+from typing import List, Union
+
+from torch.nn import (
+ ELU,
+ SELU,
+ AvgPool1d,
+ AvgPool2d,
+ AvgPool3d,
+ Conv1d,
+ Conv2d,
+ Conv3d,
+ ConvTranspose1d,
+ ConvTranspose2d,
+ ConvTranspose3d,
+ CrossEntropyLoss,
+ Dropout,
+ Embedding,
+ Flatten,
+ LeakyReLU,
+ Linear,
+ LogSigmoid,
+ MaxPool1d,
+ MaxPool2d,
+ MaxPool3d,
+ MSELoss,
+ ReLU,
+ Sigmoid,
+ Tanh,
+ ZeroPad2d,
+)
+
+from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
+from backpack.extensions.secondorder.hbp import LossHessianStrategy
+from backpack.extensions.secondorder.sqrt_ggn import (
+ activations,
+ convnd,
+ convtransposend,
+ dropout,
+ embedding,
+ flatten,
+ linear,
+ losses,
+ padding,
+ pooling,
+)
+
+
+class SqrtGGN(SecondOrderBackpropExtension):
+ """Base class for extensions that compute the GGN/Fisher matrix square root."""
+
+ def __init__(
+ self,
+ loss_hessian_strategy: str,
+ savefield: str,
+ subsampling: Union[List[int], None],
+ ):
+ """Store approximation for backpropagated object and where to save the result.
+
+ Args:
+ loss_hessian_strategy: Which approximation is used for the backpropagated
+ loss Hessian. Must be ``'exact'`` or ``'sampling'``.
+ savefield: Attribute under which the quantity is saved in a parameter.
+ subsampling: Indices of active samples. ``None`` uses the full mini-batch.
+ """
+ self.loss_hessian_strategy = loss_hessian_strategy
+ super().__init__(
+ savefield=savefield,
+ fail_mode="ERROR",
+ module_exts={
+ MSELoss: losses.SqrtGGNMSELoss(),
+ CrossEntropyLoss: losses.SqrtGGNCrossEntropyLoss(),
+ Linear: linear.SqrtGGNLinear(),
+ MaxPool1d: pooling.SqrtGGNMaxPool1d(),
+ MaxPool2d: pooling.SqrtGGNMaxPool2d(),
+ AvgPool1d: pooling.SqrtGGNAvgPool1d(),
+ MaxPool3d: pooling.SqrtGGNMaxPool3d(),
+ AvgPool2d: pooling.SqrtGGNAvgPool2d(),
+ AvgPool3d: pooling.SqrtGGNAvgPool3d(),
+ ZeroPad2d: padding.SqrtGGNZeroPad2d(),
+ Conv1d: convnd.SqrtGGNConv1d(),
+ Conv2d: convnd.SqrtGGNConv2d(),
+ Conv3d: convnd.SqrtGGNConv3d(),
+ ConvTranspose1d: convtransposend.SqrtGGNConvTranspose1d(),
+ ConvTranspose2d: convtransposend.SqrtGGNConvTranspose2d(),
+ ConvTranspose3d: convtransposend.SqrtGGNConvTranspose3d(),
+ Dropout: dropout.SqrtGGNDropout(),
+ Flatten: flatten.SqrtGGNFlatten(),
+ ReLU: activations.SqrtGGNReLU(),
+ Sigmoid: activations.SqrtGGNSigmoid(),
+ Tanh: activations.SqrtGGNTanh(),
+ LeakyReLU: activations.SqrtGGNLeakyReLU(),
+ LogSigmoid: activations.SqrtGGNLogSigmoid(),
+ ELU: activations.SqrtGGNELU(),
+ SELU: activations.SqrtGGNSELU(),
+ Embedding: embedding.SqrtGGNEmbedding(),
+ },
+ subsampling=subsampling,
+ )
+
+ def get_loss_hessian_strategy(self) -> str:
+ """Return the strategy used to represent the backpropagated loss Hessian.
+
+ Returns:
+ Loss Hessian strategy.
+ """
+ return self.loss_hessian_strategy
+
+
+class SqrtGGNExact(SqrtGGN):
+ """Exact matrix square root of the generalized Gauss-Newton/Fisher.
+
+ Uses the exact Hessian of the loss w.r.t. the model output.
+
+ Stores the output in :code:`sqrt_ggn_exact`, which has shape ``[C, N, param.shape]``,
+ where ``C`` is the model output dimension (number of classes for classification
+ problems) and ``N`` is the batch size. If sub-sampling is enabled, ``N`` is
+ replaced by the number of active samples, ``len(subsampling)``.
+
+ For a faster but less precise alternative, see
+ :py:class:`backpack.extensions.SqrtGGNMC`.
+
+ .. note::
+
+ (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_exact``
+ can be viewed as a ``[C * N, param.numel()]`` matrix. Concatenating this
+ matrix over all parameters results in a matrix ``Vᵀ``, which
+ is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``.
+ """
+
+ def __init__(self, subsampling: List[int] = None):
+ """Use exact loss Hessian, store results under ``sqrt_ggn_exact``.
+
+ Args:
+ subsampling: Indices of active samples. Defaults to ``None`` (use all
+ samples in the mini-batch).
+ """
+ super().__init__(LossHessianStrategy.EXACT, "sqrt_ggn_exact", subsampling)
+
+
+class SqrtGGNMC(SqrtGGN):
+ """Approximate matrix square root of the generalized Gauss-Newton/Fisher.
+
+ Uses a Monte-Carlo (MC) approximation of the Hessian of the loss w.r.t. the model
+ output.
+
+ Stores the output in :code:`sqrt_ggn_mc`, which has shape ``[M, N, param.shape]``,
+ where ``M`` is the number of Monte-Carlo samples and ``N`` is the batch size.
+ If sub-sampling is enabled, ``N`` is replaced by the number of active samples,
+ ``len(subsampling)``.
+
+ For a more precise but slower alternative, see
+ :py:class:`backpack.extensions.SqrtGGNExact`.
+
+ .. note::
+
+ (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_mc``
+ can be viewed as a ``[M * N, param.numel()]`` matrix. Concatenating this
+ matrix over all parameters results in a matrix ``Vᵀ``, which
+ is the approximate GGN/Fisher's matrix square root, i.e. ``G ≈ V Vᵀ``.
+ """
+
+ def __init__(self, mc_samples: int = 1, subsampling: List[int] = None):
+ """Approximate loss Hessian via MC and set savefield to ``sqrt_ggn_mc``.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples. Default: ``1``.
+ subsampling: Indices of active samples. Defaults to ``None`` (use all
+ samples in the mini-batch).
+ """
+ self._mc_samples = mc_samples
+ super().__init__(LossHessianStrategy.SAMPLING, "sqrt_ggn_mc", subsampling)
+
+ def get_num_mc_samples(self) -> int:
+ """Return the number of MC samples used to approximate the loss Hessian.
+
+ Returns:
+ Number of Monte-Carlo samples.
+ """
+ return self._mc_samples
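A minimal usage sketch mirroring the docstrings above; reconstructing a parameter's GGN block from the stored square root follows the ``G = V Vᵀ`` note:

```python
import torch
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import SqrtGGNExact

model = extend(Sequential(Linear(10, 5), Linear(5, 3)))
lossfunc = extend(CrossEntropyLoss())

X, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
loss = lossfunc(model(X), y)

with backpack(SqrtGGNExact()):
    loss.backward()

for param in model.parameters():
    # sqrt_ggn_exact has shape [C, N, *param.shape]; flatten it into the
    # [C * N, param.numel()] matrix called Vᵀ in the docstring
    Vt = param.sqrt_ggn_exact.flatten(start_dim=2).flatten(end_dim=1)
    ggn_block = Vt.T @ Vt  # exact GGN/Fisher block for this parameter, G = V Vᵀ
    print(param.sqrt_ggn_exact.shape, ggn_block.shape)
```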
diff --git a/backpack/extensions/secondorder/sqrt_ggn/activations.py b/backpack/extensions/secondorder/sqrt_ggn/activations.py
new file mode 100644
index 000000000..3aaf8fff2
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/activations.py
@@ -0,0 +1,65 @@
+"""Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.elu import ELUDerivatives
+from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives
+from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives
+from backpack.core.derivatives.relu import ReLUDerivatives
+from backpack.core.derivatives.selu import SELUDerivatives
+from backpack.core.derivatives.sigmoid import SigmoidDerivatives
+from backpack.core.derivatives.tanh import TanhDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNReLU(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ReLU`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ReLU`` module."""
+ super().__init__(ReLUDerivatives())
+
+
+class SqrtGGNSigmoid(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Sigmoid`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Sigmoid`` module."""
+ super().__init__(SigmoidDerivatives())
+
+
+class SqrtGGNTanh(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Tanh`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Tanh`` module."""
+ super().__init__(TanhDerivatives())
+
+
+class SqrtGGNELU(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ELU`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ELU`` module."""
+ super().__init__(ELUDerivatives())
+
+
+class SqrtGGNSELU(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.SELU`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.SELU`` module."""
+ super().__init__(SELUDerivatives())
+
+
+class SqrtGGNLeakyReLU(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LeakyReLU`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.LeakyReLU`` module."""
+ super().__init__(LeakyReLUDerivatives())
+
+
+class SqrtGGNLogSigmoid(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.LogSigmoid`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.LogSigmoid`` module."""
+ super().__init__(LogSigmoidDerivatives())
diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py
new file mode 100644
index 000000000..425766f8e
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/base.py
@@ -0,0 +1,80 @@
+"""Contains base class for ``SqrtGGN{Exact, MC}`` module extensions."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.basederivatives import BaseDerivatives
+from backpack.extensions.mat_to_mat_jac_base import MatToJacMat
+
+if TYPE_CHECKING:
+ from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC
+
+
+class SqrtGGNBaseModule(MatToJacMat):
+ """Base module extension for ``SqrtGGN{Exact, MC}``."""
+
+ def __init__(self, derivatives: BaseDerivatives, params: List[str] = None):
+ """Store parameter names and derivatives.
+
+ Sets up methods that extract the GGN/Fisher matrix square root for the
+ passed parameters, unless these methods are overwritten by a child class.
+
+ Args:
+ derivatives: derivatives object.
+ params: List of parameter names. Defaults to None.
+ """
+ if params is not None:
+ for param_str in params:
+ if not hasattr(self, param_str):
+ setattr(self, param_str, self._make_param_function(param_str))
+
+ super().__init__(derivatives, params=params)
+
+ def _make_param_function(
+ self, param_str: str
+ ) -> Callable[
+ [Union[SqrtGGNExact, SqrtGGNMC], Module, Tuple[Tensor], Tuple[Tensor], Tensor],
+ Tensor,
+ ]:
+ """Create a function that computes the GGN/Fisher square root for a parameter.
+
+ Args:
+ param_str: name of parameter
+
+ Returns:
+ Function that computes the GGN/Fisher matrix square root.
+ """
+
+ def param_function(
+ ext: Union[SqrtGGNExact, SqrtGGNMC],
+ module: Module,
+ g_inp: Tuple[Tensor],
+ g_out: Tuple[Tensor],
+ backproped: Tensor,
+ ) -> Tensor:
+ """Calculate the GGN/Fisher matrix square root with the derivatives object.
+
+ Args:
+ ext: extension that is used
+ module: module that performed forward pass
+ g_inp: input gradient tensors
+ g_out: output gradient tensors
+ backproped: Backpropagated quantities from second-order extension.
+
+ Returns:
+ GGN/Fisher matrix square root.
+ """
+ return self.derivatives.param_mjp(
+ param_str,
+ module,
+ g_inp,
+ g_out,
+ backproped,
+ sum_batch=False,
+ subsampling=ext.get_subsampling(),
+ )
+
+ return param_function
diff --git a/backpack/extensions/secondorder/sqrt_ggn/convnd.py b/backpack/extensions/secondorder/sqrt_ggn/convnd.py
new file mode 100644
index 000000000..74a88651c
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/convnd.py
@@ -0,0 +1,29 @@
+"""Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.conv1d import Conv1DDerivatives
+from backpack.core.derivatives.conv2d import Conv2DDerivatives
+from backpack.core.derivatives.conv3d import Conv3DDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNConv1d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv1d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Conv1d`` module."""
+ super().__init__(Conv1DDerivatives(), params=["bias", "weight"])
+
+
+class SqrtGGNConv2d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv2d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Conv2d`` module."""
+ super().__init__(Conv2DDerivatives(), params=["bias", "weight"])
+
+
+class SqrtGGNConv3d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv3d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Conv3d`` module."""
+ super().__init__(Conv3DDerivatives(), params=["bias", "weight"])
diff --git a/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py
new file mode 100644
index 000000000..a18331976
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py
@@ -0,0 +1,29 @@
+"""Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives
+from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives
+from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNConvTranspose1d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose1d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ConvTranspose1d`` module."""
+ super().__init__(ConvTranspose1DDerivatives(), params=["bias", "weight"])
+
+
+class SqrtGGNConvTranspose2d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose2d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ConvTranspose2d`` module."""
+ super().__init__(ConvTranspose2DDerivatives(), params=["bias", "weight"])
+
+
+class SqrtGGNConvTranspose3d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose3d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ConvTranspose3d`` module."""
+ super().__init__(ConvTranspose3DDerivatives(), params=["bias", "weight"])
diff --git a/backpack/extensions/secondorder/sqrt_ggn/dropout.py b/backpack/extensions/secondorder/sqrt_ggn/dropout.py
new file mode 100644
index 000000000..2f03b8aa9
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/dropout.py
@@ -0,0 +1,11 @@
+"""Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.dropout import DropoutDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNDropout(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Dropout`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Dropout`` module."""
+ super().__init__(DropoutDerivatives())
diff --git a/backpack/extensions/secondorder/sqrt_ggn/embedding.py b/backpack/extensions/secondorder/sqrt_ggn/embedding.py
new file mode 100644
index 000000000..070ad217c
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/embedding.py
@@ -0,0 +1,11 @@
+"""Contains extension for the embedding layer used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNEmbedding(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Embedding`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Embedding`` module."""
+ super().__init__(EmbeddingDerivatives(), params=["weight"])
diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py
new file mode 100644
index 000000000..2a045c957
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py
@@ -0,0 +1,11 @@
+"""Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.flatten import FlattenDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNFlatten(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Flatten`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Flatten`` module."""
+ super().__init__(FlattenDerivatives())
diff --git a/backpack/extensions/secondorder/sqrt_ggn/linear.py b/backpack/extensions/secondorder/sqrt_ggn/linear.py
new file mode 100644
index 000000000..4aecca6f5
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/linear.py
@@ -0,0 +1,11 @@
+"""Contains extension for the linear layer used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.linear import LinearDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNLinear(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Linear`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.Linear`` module."""
+ super().__init__(LinearDerivatives(), params=["bias", "weight"])
diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py
new file mode 100644
index 000000000..2294bc794
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/losses.py
@@ -0,0 +1,82 @@
+"""Contains base class and extensions for losses used by ``SqrtGGN{Exact, MC}``."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Tuple, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives
+from backpack.core.derivatives.mseloss import MSELossDerivatives
+from backpack.extensions.secondorder.hbp import LossHessianStrategy
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+if TYPE_CHECKING:
+ from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact, SqrtGGNMC
+
+
+class SqrtGGNBaseLossModule(SqrtGGNBaseModule):
+ """Base class for losses used by ``SqrtGGN{Exact, MC}``."""
+
+ def backpropagate(
+ self,
+ ext: Union[SqrtGGNExact, SqrtGGNMC],
+ module: Module,
+ grad_inp: Tuple[Tensor],
+ grad_out: Tuple[Tensor],
+ backproped: None,
+ ) -> Tensor:
+ """Initialize the backpropagated quantity.
+
+ Uses the exact loss Hessian square root, or a Monte-Carlo approximation
+ thereof.
+
+ Args:
+ ext: BackPACK extension calling out to the module extension.
+ module: Module that performed the forward pass.
+ grad_inp: Gradients w.r.t. the module inputs.
+ grad_out: Gradients w.r.t. the module outputs.
+ backproped: Backpropagated information. Should be ``None``.
+
+ Returns:
+ Symmetric factorization of the loss Hessian w.r.t. the module input.
+
+ Raises:
+ NotImplementedError: For invalid strategies to represent the loss Hessian.
+ """
+ loss_hessian_strategy = ext.get_loss_hessian_strategy()
+ subsampling = ext.get_subsampling()
+
+ if loss_hessian_strategy == LossHessianStrategy.EXACT:
+ return self.derivatives.sqrt_hessian(
+ module, grad_inp, grad_out, subsampling=subsampling
+ )
+ elif loss_hessian_strategy == LossHessianStrategy.SAMPLING:
+ mc_samples = ext.get_num_mc_samples()
+ return self.derivatives.sqrt_hessian_sampled(
+ module,
+ grad_inp,
+ grad_out,
+ mc_samples=mc_samples,
+ subsampling=subsampling,
+ )
+ else:
+ raise NotImplementedError(
+ f"Unknown hessian strategy {loss_hessian_strategy}"
+ )
+
+
+class SqrtGGNMSELoss(SqrtGGNBaseLossModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MSELoss`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.MSELoss`` module."""
+ super().__init__(MSELossDerivatives())
+
+
+class SqrtGGNCrossEntropyLoss(SqrtGGNBaseLossModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.CrossEntropyLoss`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.CrossEntropyLoss`` module."""
+ super().__init__(CrossEntropyLossDerivatives())
diff --git a/backpack/extensions/secondorder/sqrt_ggn/padding.py b/backpack/extensions/secondorder/sqrt_ggn/padding.py
new file mode 100644
index 000000000..18574f685
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/padding.py
@@ -0,0 +1,11 @@
+"""Contains extensions for padding layers used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNZeroPad2d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ZeroPad2d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.ZeroPad2d`` module."""
+ super().__init__(ZeroPad2dDerivatives())
diff --git a/backpack/extensions/secondorder/sqrt_ggn/pooling.py b/backpack/extensions/secondorder/sqrt_ggn/pooling.py
new file mode 100644
index 000000000..e19cfba2a
--- /dev/null
+++ b/backpack/extensions/secondorder/sqrt_ggn/pooling.py
@@ -0,0 +1,56 @@
+"""Contains extensions for pooling layers used by ``SqrtGGN{Exact, MC}``."""
+from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives
+from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives
+from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives
+from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives
+from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives
+from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives
+from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule
+
+
+class SqrtGGNMaxPool1d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool1d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.MaxPool1d`` module."""
+ super().__init__(MaxPool1DDerivatives())
+
+
+class SqrtGGNMaxPool2d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool2d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.MaxPool2d`` module."""
+ super().__init__(MaxPool2DDerivatives())
+
+
+class SqrtGGNMaxPool3d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.MaxPool3d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.MaxPool3d`` module."""
+ super().__init__(MaxPool3DDerivatives())
+
+
+class SqrtGGNAvgPool1d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool1d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.AvgPool1d`` module."""
+ super().__init__(AvgPool1DDerivatives())
+
+
+class SqrtGGNAvgPool2d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool2d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.AvgPool2d`` module."""
+ super().__init__(AvgPool2DDerivatives())
+
+
+class SqrtGGNAvgPool3d(SqrtGGNBaseModule):
+ """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.AvgPool3d`` module."""
+
+ def __init__(self):
+ """Pass derivatives for ``torch.nn.AvgPool3d`` module."""
+ super().__init__(AvgPool3DDerivatives())
diff --git a/backpack/hessianfree/ggnvp.py b/backpack/hessianfree/ggnvp.py
index 92c083636..165aff253 100644
--- a/backpack/hessianfree/ggnvp.py
+++ b/backpack/hessianfree/ggnvp.py
@@ -1,42 +1,56 @@
-from .hvp import hessian_vector_product
-from .lop import L_op
-from .rop import R_op
-
-
-def ggn_vector_product(loss, output, model, v):
+"""Autodiff-only matrix-free multiplication by the generalized Gauss-Newton/Fisher."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from backpack.hessianfree.hvp import hessian_vector_product
+from backpack.hessianfree.lop import L_op
+from backpack.hessianfree.rop import R_op
+
+
+def ggn_vector_product(
+ loss: Tensor, output: Tensor, model: Module, v: List[Tensor]
+) -> Tuple[Tensor]:
+ """Multiply a vector with the generalized Gauss-Newton/Fisher.
+
+ Note:
+ ``G v = J.T @ H @ J @ v`` where ``J`` is the Jacobian of ``output`` w.r.t.
+ ``model``'s trainable parameters and ``H`` is the Hessian of ``loss`` w.r.t.
+ ``output``. ``v`` is the flattened and concatenated version of the passed
+ list of vectors.
+
+ Args:
+ loss: Scalar tensor that represents the loss.
+ output: Model output.
+ model: The model.
+ v: Vector specified as list of tensors matching the trainable parameters.
+
+ Returns:
+ GGN-vector product in list format, i.e. as a list that matches the sizes
+ of the trainable model parameters.
"""
- Multiplies the vector `v` with the Generalized Gauss-Newton,
- `ggn_v = J.T @ H @ J @ v`
-
- where `J` is the Jacobian of `output` w.r.t. `model.parameters()`
- and `H` is the Hessian of `loss` w.r.t. `output`.
+ return ggn_vector_product_from_plist(
+ loss, output, [p for p in model.parameters() if p.requires_grad], v
+ )
- Example usage:
- ```
- X, Y = data()
- model = torch.nn.Linear(784, 10)
- lossfunc = torch.nn.CrossEntropyLoss()
- output = model(X)
- loss = lossfunc(output, Y)
+def ggn_vector_product_from_plist(
+ loss: Tensor, output: Tensor, plist: List[Parameter], v: List[Tensor]
+) -> Tuple[Tensor]:
+ """Multiply a vector with a sub-block of the generalized Gauss-Newton/Fisher.
- v = list([torch.randn_like(p) for p in model.parameters])
+ Args:
+ loss: Scalar tensor that represents the loss.
+ output: Model output.
+ plist: List of trainable parameters whose GGN block is used for multiplication.
+ v: Vector specified as list of tensors matching the sizes of ``plist``.
- GGNv = ggn_vector_product(loss, output, model, v)
- ```
-
- Parameters:
- -----------
- loss: torch.Tensor
- output: torch.Tensor
- model: torch.nn.Module
- v: [torch.Tensor]
- List of tensors matching the sizes of model.parameters()
+ Returns:
+ GGN-vector product in list format, i.e. as a list that matches the sizes of
+ ``plist``.
"""
- return ggn_vector_product_from_plist(loss, output, list(model.parameters()), v)
-
-
-def ggn_vector_product_from_plist(loss, output, plist, v):
Jv = R_op(output, plist, v)
HJv = hessian_vector_product(loss, output, Jv)
JTHJv = L_op(output, plist, HJv)
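A minimal sketch of the intended usage, mirroring the example that was removed from the old docstring:

```python
import torch
from torch.nn import CrossEntropyLoss, Linear

from backpack.hessianfree.ggnvp import ggn_vector_product

model = Linear(784, 10)
lossfunc = CrossEntropyLoss()

X, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
output = model(X)
loss = lossfunc(output, y)

# one vector per trainable parameter, matching its shape
v = [torch.randn_like(p) for p in model.parameters()]

GGNv = ggn_vector_product(loss, output, model, v)
print([g.shape for g in GGNv])  # same shapes as the parameters
```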
diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py
index e69de29bb..d5fb6701b 100644
--- a/backpack/utils/__init__.py
+++ b/backpack/utils/__init__.py
@@ -0,0 +1,8 @@
+"""Contains utility functions."""
+from pkg_resources import get_distribution, packaging
+
+TORCH_VERSION = packaging.version.parse(get_distribution("torch").version)
+TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1")
+TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0")
+
+ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0
diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py
index 14d394f54..2cc1c5adb 100644
--- a/backpack/utils/conv.py
+++ b/backpack/utils/conv.py
@@ -1,22 +1,56 @@
+from typing import Callable, Type, Union
+
import torch
from einops import rearrange
-from torch import einsum
+from torch import Tensor, einsum
+from torch.nn import Conv1d, Conv2d, Conv3d, Module
from torch.nn.functional import conv1d, conv2d, conv3d, unfold
-def unfold_input(module, input):
+def get_conv_module(N: int) -> Type[Module]:
+ """Return the PyTorch module class of N-dimensional convolution.
+
+ Args:
+ N: Convolution dimension.
+
+ Returns:
+ Convolution class.
+ """
+ return {
+ 1: Conv1d,
+ 2: Conv2d,
+ 3: Conv3d,
+ }[N]
+
+
+def get_conv_function(N: int) -> Callable:
+ """Return the PyTorch function of N-dimensional convolution.
+
+ Args:
+ N: Convolution dimension.
+
+ Returns:
+ Convolution function.
+ """
+ return {
+ 1: conv1d,
+ 2: conv2d,
+ 3: conv3d,
+ }[N]
+
+
+def unfold_input(module: Union[Conv1d, Conv2d, Conv3d], input: Tensor) -> Tensor:
"""Return unfolded input to a convolution.
Use PyTorch's ``unfold`` operation for 2d convolutions (4d input tensors),
otherwise fall back to a custom implementation.
Args:
- module (torch.nn.Conv1d or torch.nn.Conv2d or torch.nn.Conv3d): Convolution
- module whose hyperparameters are used for the unfold.
- input (torch.Tensor): Input to convolution that will be unfolded.
+ module: Convolution module whose hyperparameters are used for the unfold.
+ input: Input to convolution that will be unfolded.
Returns:
- torch.Tensor: Unfolded input.
+ Unfolded input.
"""
if input.dim() == 4:
return unfold(
@@ -30,7 +64,7 @@ def unfold_input(module, input):
return unfold_by_conv(input, module)
-def get_weight_gradient_factors(input, grad_out, module, N):
+def get_weight_gradient_factors(input, grad_out, module):
X = unfold_input(module, input)
dE_dY = rearrange(grad_out, "n c ... -> n c (...)")
return X, dE_dY
@@ -116,15 +150,9 @@ def make_weight():
repeat = [C_in, 1] + [1 for _ in kernel_size]
return weight.repeat(*repeat)
- def get_conv():
- functional_for_module_cls = {
- torch.nn.Conv1d: conv1d,
- torch.nn.Conv2d: conv2d,
- torch.nn.Conv3d: conv3d,
- }
- return functional_for_module_cls[module.__class__]
+ conv_dim = input.dim() - 2
+ conv = get_conv_function(conv_dim)
- conv = get_conv()
unfold = conv(
input,
make_weight().to(input.device),
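As a self-contained illustration (plain PyTorch, not part of the patch) of what unfolding buys: for a 2d convolution, the layer's output equals a matrix product between the flattened kernel and the unfolded input.

```python
import torch
from torch.nn import Conv2d
from torch.nn.functional import unfold

torch.manual_seed(0)
N, C_in, C_out = 2, 3, 8
conv = Conv2d(C_in, C_out, kernel_size=3, padding=1, bias=False)
x = torch.randn(N, C_in, 10, 10)

x_unfolded = unfold(x, kernel_size=3, padding=1)      # [N, C_in * 3 * 3, L]
out = conv.weight.flatten(start_dim=1) @ x_unfolded   # [N, C_out, L]
out = out.reshape(N, C_out, 10, 10)

print(torch.allclose(out, conv(x), atol=1e-6))  # True
```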
diff --git a/backpack/utils/conv_transpose.py b/backpack/utils/conv_transpose.py
index d1183e909..3c90834be 100644
--- a/backpack/utils/conv_transpose.py
+++ b/backpack/utils/conv_transpose.py
@@ -1,14 +1,49 @@
"""Utility functions for extracting transpose convolution BackPACK quantities."""
+from typing import Callable, Type
+
import torch
from einops import rearrange
from torch import einsum
+from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.functional import conv_transpose1d, conv_transpose2d, conv_transpose3d
from backpack.utils.conv import extract_bias_diagonal as conv_extract_bias_diagonal
-def get_weight_gradient_factors(input, grad_out, module, N):
+def get_conv_transpose_module(N: int) -> Type[Module]:
+ """Return the PyTorch module class of N-dimensional transpose convolution.
+
+ Args:
+ N: Transpose convolution dimension.
+
+ Returns:
+ Transpose convolution class.
+ """
+ return {
+ 1: ConvTranspose1d,
+ 2: ConvTranspose2d,
+ 3: ConvTranspose3d,
+ }[N]
+
+
+def get_conv_transpose_function(N: int) -> Callable:
+ """Return the PyTorch function of N-dimensional transpose convolution.
+
+ Args:
+ N: Transpose convolution dimension.
+
+ Returns:
+ Transpose convolution function.
+ """
+ return {
+ 1: conv_transpose1d,
+ 2: conv_transpose2d,
+ 3: conv_transpose3d,
+ }[N]
+
+
+def get_weight_gradient_factors(input, grad_out, module):
M, C_in = input.shape[0], input.shape[1]
kernel_size_numel = module.weight.shape[2:].numel()
@@ -109,15 +144,9 @@ def make_weight():
weight = weight.repeat(*repeat)
return weight.to(module.weight.device)
- def get_conv_transpose():
- functional_for_module_cls = {
- torch.nn.ConvTranspose1d: conv_transpose1d,
- torch.nn.ConvTranspose2d: conv_transpose2d,
- torch.nn.ConvTranspose3d: conv_transpose3d,
- }
- return functional_for_module_cls[module.__class__]
+ conv_dim = input.dim() - 2
+ conv_transpose = get_conv_transpose_function(conv_dim)
- conv_transpose = get_conv_transpose()
unfold = conv_transpose(
input,
make_weight().to(module.weight.device),
diff --git a/backpack/utils/errors.py b/backpack/utils/errors.py
new file mode 100644
index 000000000..690dc451b
--- /dev/null
+++ b/backpack/utils/errors.py
@@ -0,0 +1,29 @@
+"""Contains errors for BackPACK."""
+from typing import Union
+from warnings import warn
+
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+
+def batch_norm_raise_error_if_train(
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], raise_error: bool = True
+) -> None:
+ """Check if BatchNorm module is in training mode.
+
+ Args:
+ module: BatchNorm module to check
+ raise_error: whether to raise an error, alternatively warn. Default: True.
+
+ Raises:
+ NotImplementedError: if module is in training mode and ``raise_error`` is True
+ """
+ if module.training:
+ message = (
+ "Encountered BatchNorm module in training mode. BackPACK's computation "
+ "will pass, but results like individual gradients may not be meaningful, "
+ "as BatchNorm mixes samples. Only proceed if you know what you are doing."
+ )
+ if raise_error:
+ raise NotImplementedError(message)
+ else:
+ warn(message)
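A short sketch of the new check's behavior:

```python
from torch.nn import BatchNorm1d

from backpack.utils.errors import batch_norm_raise_error_if_train

bn = BatchNorm1d(4)

bn.eval()
batch_norm_raise_error_if_train(bn)  # evaluation mode: passes silently

bn.train()
batch_norm_raise_error_if_train(bn, raise_error=False)  # training mode: warns
# batch_norm_raise_error_if_train(bn)  # training mode: raises NotImplementedError
```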
diff --git a/backpack/utils/examples.py b/backpack/utils/examples.py
index eb9392613..788241140 100644
--- a/backpack/utils/examples.py
+++ b/backpack/utils/examples.py
@@ -1,35 +1,117 @@
"""Utility functions for examples."""
-import torch
-import torchvision
+from typing import Iterator, List, Tuple
+from torch import Tensor, stack, zeros
+from torch.nn import Module
+from torch.nn.utils.convert_parameters import parameters_to_vector
+from torch.utils.data import DataLoader, Dataset
+from torchvision.datasets import MNIST
+from torchvision.transforms import Compose, Normalize, ToTensor
-def load_mnist_dataset():
- """Download and normalize MNIST training data."""
- mnist_dataset = torchvision.datasets.MNIST(
+from backpack.hessianfree.ggnvp import ggn_vector_product
+from backpack.utils.convert_parameters import vector_to_parameter_list
+
+
+def load_mnist_dataset() -> Dataset:
+ """Download and normalize MNIST training data.
+
+ Returns:
+ Normalized MNIST dataset
+ """
+ return MNIST(
root="./data",
train=True,
- transform=torchvision.transforms.Compose(
- [
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Normalize((0.1307,), (0.3081,)),
- ]
- ),
+ transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]),
download=True,
)
- return mnist_dataset
-def get_mnist_dataloader(batch_size=64, shuffle=True):
- """Returns a dataloader for MNIST"""
- return torch.utils.data.dataloader.DataLoader(
- load_mnist_dataset(),
- batch_size=batch_size,
- shuffle=shuffle,
- )
+def get_mnist_dataloader(batch_size: int = 64, shuffle: bool = True) -> DataLoader:
+ """Returns a dataloader for MNIST.
+
+ Args:
+ batch_size: Mini-batch size. Default: ``64``.
+ shuffle: Randomly shuffle the data. Default: ``True``.
+
+ Returns:
+ MNIST dataloader
+ """
+ return DataLoader(load_mnist_dataset(), batch_size=batch_size, shuffle=shuffle)
-def load_one_batch_mnist(batch_size=64, shuffle=True):
- """Return a single batch (inputs, labels) of MNIST data."""
+def load_one_batch_mnist(
+ batch_size: int = 64, shuffle: bool = True
+) -> Tuple[Tensor, Tensor]:
+ """Return a single mini-batch (inputs, labels) from MNIST.
+
+ Args:
+ batch_size: Mini-batch size. Default: ``64``.
+ shuffle: Randomly shuffle the data. Default: ``True``.
+
+ Returns:
+ A single batch (inputs, labels) from MNIST.
+ """
dataloader = get_mnist_dataloader(batch_size, shuffle)
X, y = next(iter(dataloader))
+
return X, y
+
+
+def autograd_diag_ggn_exact(
+ X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None
+) -> Tensor:
+ """Compute the generalized Gauss-Newton diagonal with ``torch.autograd``.
+
+ Args:
+ X: Input to the model.
+ y: Labels.
+ model: The neural network.
+ loss_function: Loss function module.
+ idx: Indices for which the diagonal entries are computed. Default value ``None``
+ computes the full diagonal.
+
+ Returns:
+ Exact GGN diagonal (flattened and concatenated).
+ """
+ diag_elements = [
+ col[col_idx]
+ for col_idx, col in _autograd_ggn_exact_columns(
+ X, y, model, loss_function, idx=idx
+ )
+ ]
+
+ return stack(diag_elements)
+
+
+def _autograd_ggn_exact_columns(
+ X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None
+) -> Iterator[Tuple[int, Tensor]]:
+ """Yield exact generalized Gauss-Newton's columns computed with ``torch.autograd``.
+
+ Args:
+ X: Input to the model.
+ y: Labels.
+ model: The neural network.
+ loss_function: Loss function module.
+ idx: Indices of columns that are computed. Default value ``None`` computes all
+ columns.
+
+ Yields:
+ Tuple of column index and respective GGN column (flattened and concatenated).
+ """
+ trainable_parameters = [p for p in model.parameters() if p.requires_grad]
+ D = sum(p.numel() for p in trainable_parameters)
+
+ outputs = model(X)
+ loss = loss_function(outputs, y)
+
+ idx = idx if idx is not None else list(range(D))
+
+ for d in idx:
+ e_d = zeros(D, device=loss.device, dtype=loss.dtype)
+ e_d[d] = 1.0
+ e_d_list = vector_to_parameter_list(e_d, trainable_parameters)
+
+ ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list)
+
+ yield d, parameters_to_vector(ggn_d_list)
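A short sketch of the new autograd helper on a toy model (the model and data here are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack.utils.examples import autograd_diag_ggn_exact

torch.manual_seed(0)
model = Sequential(Linear(10, 5), Linear(5, 3))
loss_function = CrossEntropyLoss()
X, y = torch.randn(8, 10), torch.randint(0, 3, (8,))

# full GGN diagonal, flattened and concatenated over all parameters
diag_ggn = autograd_diag_ggn_exact(X, y, model, loss_function)
print(diag_ggn.shape)  # torch.Size([73]) for this toy model

# restrict to a subset of diagonal entries
print(autograd_diag_ggn_exact(X, y, model, loss_function, idx=[0, 5, 7]).shape)
```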
diff --git a/backpack/utils/hooks.py b/backpack/utils/hooks.py
index d86381ec6..c4a8aff68 100644
--- a/backpack/utils/hooks.py
+++ b/backpack/utils/hooks.py
@@ -2,5 +2,10 @@
def no_op(*args, **kwargs):
- """Placeholder function that accepts arbitrary input and does nothing."""
- return None
+ """Placeholder function that accepts arbitrary input and does nothing.
+
+ Args:
+ *args: arbitrary positional arguments (ignored)
+ **kwargs: arbitrary keyword arguments (ignored)
+ """
+ pass
diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py
index b3a2453b3..60912e5ac 100644
--- a/backpack/utils/linear.py
+++ b/backpack/utils/linear.py
@@ -1,17 +1,62 @@
-from torch import einsum
+"""Contains utility functions to extract the GGN diagonal for linear layers."""
+from torch import Tensor, einsum
+from torch.nn import Linear
-def extract_weight_diagonal(module, backproped, sum_batch=True):
- if sum_batch:
- equation = "vno,ni->oi"
+def extract_weight_diagonal(
+ module: Linear, S: Tensor, sum_batch: bool = True
+) -> Tensor:
+ """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian.
+
+ Args:
+ module: Linear layer for which the diagonal is extracted w.r.t. the weight.
+ S: Backpropagated symmetric factorization of the loss Hessian. Has shape
+ ``(V, *module.output.shape)``.
+ sum_batch: Sum out the weight diagonal's batch dimension. Default: ``True``.
+
+ Returns:
+ Per-sample weight diagonal if ``sum_batch=False`` (shape
+ ``(N, *module.weight.shape)`` with batch size ``N``) or summed weight diagonal
+ if ``sum_batch=True`` (shape ``module.weight.shape``).
+ """
+ has_additional_axes = module.input0.dim() > 2
+
+ if has_additional_axes:
+ S_flat = S.flatten(start_dim=2, end_dim=-2)
+ X_flat = module.input0.flatten(start_dim=1, end_dim=-2)
+ equation = f"vnmo,nmi,vnko,nki->{'' if sum_batch else 'n'}oi"
+ # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class
+ # implementation: https://github.com/fKunstner/backpack-discuss/issues/111
+ return einsum(equation, S_flat, X_flat, S_flat, X_flat)
+
else:
- equation = "vno,ni->noi"
- return einsum(equation, (backproped ** 2, module.input0 ** 2))
+ equation = f"vno,ni->{'' if sum_batch else 'n'}oi"
+ return einsum(equation, S ** 2, module.input0 ** 2)
+
+# TODO This method applies the bias Jacobian, then squares and sums the result.
+# Introduce a base class for {Batch}DiagHessian and DiagGGN{Exact,MC} and remove this method
+def extract_bias_diagonal(module: Linear, S: Tensor, sum_batch: bool = True) -> Tensor:
+ """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian.
-def extract_bias_diagonal(module, backproped, sum_batch=True):
- if sum_batch:
- equation = "vno->o"
+ Args:
+ module: Linear layer for which the diagonal is extracted w.r.t. the bias.
+ S: Backpropagated symmetric factorization of the loss Hessian. Has shape
+ ``(V, *module.output.shape)``.
+ sum_batch: Sum out the bias diagonal's batch dimension. Default: ``True``.
+
+ Returns:
+ Per-sample bias diagonal if ``sum_batch=False`` (shape
+ ``(N, *module.bias.shape)`` with batch size ``N``) or summed bias diagonal
+ if ``sum_batch=True`` (shape ``module.bias.shape``).
+ """
+ additional_axes = list(range(2, module.input0.dim()))
+
+ if additional_axes:
+ JS = S.sum(additional_axes)
else:
- equation = "vno->no"
- return einsum(equation, backproped ** 2)
+ JS = S
+
+ equation = f"vno->{'' if sum_batch else 'n'}o"
+
+ return einsum(equation, JS ** 2)
diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py
new file mode 100644
index 000000000..b8d9b5b5f
--- /dev/null
+++ b/backpack/utils/module_classification.py
@@ -0,0 +1,32 @@
+"""Contains util function for classification of modules."""
+from torch.fx import GraphModule
+from torch.nn import Module, Sequential
+from torch.nn.modules.loss import _Loss
+
+from backpack.custom_module.branching import Parallel, _Branch
+from backpack.custom_module.reduce_tuple import ReduceTuple
+
+
+def is_loss(module: Module) -> bool:
+ """Return whether `module` is a `torch` loss function.
+
+ Args:
+ module: A PyTorch module.
+
+ Returns:
+ Whether `module` is a loss function.
+ """
+ return isinstance(module, _Loss)
+
+
+def is_no_op(module: Module) -> bool:
+ """Return whether the module does no operation in graph.
+
+ Args:
+ module: module
+
+ Returns:
+ whether module is no operation
+ """
+ no_op_modules = (Sequential, _Branch, Parallel, ReduceTuple, GraphModule)
+ return isinstance(module, no_op_modules)
diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py
new file mode 100644
index 000000000..62d399f4c
--- /dev/null
+++ b/backpack/utils/subsampling.py
@@ -0,0 +1,22 @@
+"""Utility functions to enable mini-batch subsampling in extensions."""
+from typing import List
+
+from torch import Tensor
+
+
+def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor:
+ """Select samples from a tensor along a dimension.
+
+ Args:
+ tensor: Tensor to select from.
+ dim: Selection dimension. Defaults to ``0``.
+ subsampling: Indices of samples that are sliced along the dimension.
+ Defaults to ``None`` (use all samples).
+
+ Returns:
+ Tensor of same rank that is sub-sampled along the dimension.
+ """
+ if subsampling is None:
+ return tensor
+ else:
+ return tensor[(slice(None),) * dim + (subsampling,)]
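A short sketch of the helper's behavior:

```python
from torch import allclose, randn

from backpack.utils.subsampling import subsample

tensor = randn(4, 3, 2)  # batch axis first by convention

assert subsample(tensor) is tensor  # subsampling=None keeps all samples
assert allclose(subsample(tensor, subsampling=[0, 2]), tensor[[0, 2]])
assert allclose(subsample(tensor, dim=1, subsampling=[1]), tensor[:, [1]])
```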
diff --git a/changelog.md b/changelog.md
index 6a00293f4..4bf090ad8 100644
--- a/changelog.md
+++ b/changelog.md
@@ -6,6 +6,108 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [1.4.0] - 2021-11-12
+
+This release ships many new features. Some rely on recent PyTorch functionality.
+We now require `torch>=1.9.0`.
+
+**Highlights:**
+
+- *ResNets & RNNs:* Thanks to [@schaefertim](https://github.com/schaefertim) for
+ bringing basic support for RNNs
+ ([#16](https://github.com/f-dangel/backpack/issues/16),
+ [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_rnn.html#sphx-glr-use-cases-example-rnn-py)) and
+ ResNets ([#14](https://github.com/f-dangel/backpack/issues/14),
+ [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_resnet_all_in_one.html#sphx-glr-use-cases-example-resnet-all-in-one-py))
+- *`SqrtGGN{Exact,MC}` extension:* Symmetric factorization of the generalized
+ Gauss-Newton/Fisher (see
+ [arXiv:2106.02624](https://arxiv.org/abs/2106.02624))
+- *Sub-sampling:* Allows for restricting BackPACK extensions to a sub-set of
+ samples in the mini-batch
+ ([#12](https://github.com/f-dangel/backpack/issues/12),
+ [tutorial](https://docs.backpack.pt/en/1.4.0/use_cases/example_subsampling.html#sphx-glr-use-cases-example-subsampling-py))
+
+### Added/New
+- Converter functionality for basic support of ResNets and RNNs
+ [[PR1](https://github.com/f-dangel/backpack/pull/202),
+ [PR2](https://github.com/f-dangel/backpack/pull/221),
+ [PR3](https://github.com/f-dangel/backpack/pull/229)]
+- New extensions:
+ - `SqrtGGNExact`: Symmetric factorization of the exact GGN/Fisher
+ [[PR](https://github.com/f-dangel/backpack/pull/180)]
+ - `SqrtGGNMC`: Symmetric factorization of the MC-approximated GGN/Fisher
+ [[PR](https://github.com/f-dangel/backpack/pull/182)]
+- Module support:
+ - `Linear`: Support additional (more than 2) input dimensions
+ [[PR1](https://github.com/f-dangel/backpack/pull/185),
+ [PR2](https://github.com/f-dangel/backpack/pull/186)]
+ - `BatchNormNd`: Distinguish evaluation and training mode, support
+ first-order extensions and `DiagGGN{Exact,MC}`
+ [[#160](https://github.com/f-dangel/backpack/issues/160),
+ [PR1](https://github.com/f-dangel/backpack/pull/179),
+ [PR2](https://github.com/f-dangel/backpack/pull/201)]
+ - `AdaptiveAvgPoolND`: Support first-order extensions and
+ `DiagGGN{Exact,MC}`
+ [[PR](https://github.com/f-dangel/backpack/pull/201)]
+ - `RNN`: Support first-order extensions and `DiagGGN{MC,Exact}`
+ [[PR1](https://github.com/f-dangel/backpack/pull/159),
+ [PR2](https://github.com/f-dangel/backpack/pull/158),
+ [PR3](https://github.com/f-dangel/backpack/pull/156)]
+ - `LSTM`: Support first-order extensions and `DiagGGN{MC,Exact}`
+ [[PR](https://github.com/f-dangel/backpack/pull/215)]
+ - `CrossEntropyLoss`: Support additional (more than 2) input dimensions.
+ [[PR](https://github.com/f-dangel/backpack/pull/211)]
+ - `Embedding`: Support first-order extensions and `DiagGGN{MC,Exact}`
+ [[PR](https://github.com/f-dangel/backpack/pull/216)]
+- Mini-batch sub-sampling
+ - `BatchGrad`
+ [[PR1](https://github.com/f-dangel/backpack/pull/200),
+ [PR2](https://github.com/f-dangel/backpack/pull/210)]
+ - `SqrtGGN{Exact,MC}`
+ [[PR](https://github.com/f-dangel/backpack/pull/200)]
+- `retain_graph` option for `backpack` context
+ [[PR](https://github.com/f-dangel/backpack/pull/217)]
+- Assume batch axis always first
+ [[PR](https://github.com/f-dangel/backpack/pull/227)]
+
+### Fixed/Removed
+- Deprecate `python3.6`, require at least `python3.7`
+ [[PR](https://github.com/f-dangel/backpack/pull/190)]
+
+### Internal
+- Use `full_backward_hook` for `torch>=1.9.0`
+ [[PR](https://github.com/f-dangel/backpack/pull/194)]
+- Core
+ - Implement derivatives for `LSTM`
+ [[PR](https://github.com/f-dangel/backpack/pull/169)]
+ - Implement derivatives for `AdaptiveAvgPoolNd`
+ [[PR](https://github.com/f-dangel/backpack/pull/165)]
+ - Sub-sampling
+ - `weight_jac_t_mat_prod`
+ [[PR](https://github.com/f-dangel/backpack/pull/195)]
+ - `bias_jac_t_mat_prod`
+ [[PR](https://github.com/f-dangel/backpack/pull/196)]
+ - `*_jac_t_mat_prod` of `RNN` and `LSTM` parameters
+ [[PR](https://github.com/f-dangel/backpack/pull/197)]
+ - `jac_t_mat_prod`
+ [[PR](https://github.com/f-dangel/backpack/pull/205)]
+ - Hessian square root decomposition (exact and MC)
+ [[PR](https://github.com/f-dangel/backpack/pull/207)]
+ - Refactor: Share code for `*_jac_t_mat_prod`
+ [[PR](https://github.com/f-dangel/backpack/pull/203)]
+- Extensions
+ - Refactor `BatchL2Grad`, introducing a base class
+ [[PR](https://github.com/f-dangel/backpack/pull/175)]
+ - Automate parameter functions for `BatchGrad` and `Grad`
+ [[PR](https://github.com/f-dangel/backpack/pull/150)]
+ - Introduce interface to check module hyperparameters
+ [[PR](https://github.com/f-dangel/backpack/pull/206)]
+- Tests
+ - Check if module Hessian is zero
+ [[PR](https://github.com/f-dangel/backpack/pull/183)]
+ - Reduce run time
+ [[PR](https://github.com/f-dangel/backpack/pull/199)]
+
## [1.3.0] - 2021-06-16
Thanks to [@sbharadwajj](https://github.com/sbharadwajj)
@@ -234,7 +336,8 @@ co-authoring many PRs shipped in this release.
Initial release
-[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.3.0...HEAD
+[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.4.0...HEAD
+[1.4.0]: https://github.com/f-dangel/backpack/compare/1.4.0...1.3.0
[1.3.0]: https://github.com/f-dangel/backpack/compare/1.3.0...1.2.0
[1.2.0]: https://github.com/f-dangel/backpack/compare/1.2.0...1.1.1
[1.1.1]: https://github.com/f-dangel/backpack/compare/1.1.0...1.1.1
diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py
index 49a8f1e46..cb7aba42d 100644
--- a/docs_src/examples/basic_usage/example_all_in_one.py
+++ b/docs_src/examples/basic_usage/example_all_in_one.py
@@ -29,6 +29,8 @@
DiagGGNExact,
DiagGGNMC,
DiagHessian,
+ SqrtGGNExact,
+ SqrtGGNMC,
SumGradSquared,
Variance,
)
@@ -166,6 +168,19 @@
print(".diag_h.shape: ", param.diag_h.shape)
print(".diag_h_batch.shape: ", param.diag_h_batch.shape)
+# %%
+# Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation
+
+loss = lossfunc(model(X), y)
+with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
+ loss.backward()
+
+for name, param in model.named_parameters():
+ print(name)
+ print(".grad.shape: ", param.grad.shape)
+ print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape)
+ print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape)
+
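+# %%
+# The GGN diagonal can be recovered from the square root factors. As a quick sanity
+# check (a sketch which assumes the factors are stored as ``[C, N, *param.shape]``
+# and that the ``DiagGGNExact`` section above was run on the same data):
+
+for name, param in model.named_parameters():
+    diag_from_sqrt = (param.sqrt_ggn_exact**2).sum([0, 1])
+    err = (diag_from_sqrt - param.diag_ggn_exact).abs().max()
+    print(f"{name}: max. difference to diag_ggn_exact: {err:.2e}")
+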
# %%
# Block-diagonal curvature products
# ---------------------------------
diff --git a/docs_src/examples/use_cases/example_first_order_resnet.py b/docs_src/examples/use_cases/example_first_order_resnet.py
index df3ded93d..a1a1075bc 100644
--- a/docs_src/examples/use_cases/example_first_order_resnet.py
+++ b/docs_src/examples/use_cases/example_first_order_resnet.py
@@ -1,101 +1,8 @@
r"""First order extensions with a ResNet
========================================
-
"""
# %%
-# Let's get the imports, configuration and some helper functions out of the way first.
-
-import torch
-import torch.nn.functional as F
-
-from backpack import backpack, extend
-from backpack.extensions import BatchGrad
-from backpack.utils.examples import load_one_batch_mnist
-
-BATCH_SIZE = 3
-torch.manual_seed(0)
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
-def get_accuracy(output, targets):
- """Helper function to print the accuracy"""
- predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
- return predictions.eq(targets).float().mean().item()
-
-
-x, y = load_one_batch_mnist(batch_size=BATCH_SIZE)
-x, y = x.to(DEVICE), y.to(DEVICE)
-
-
-# %%
-# We can build a ResNet by extending :py:class:`torch.nn.Module`.
-# As long as the layers with parameters
-# (:py:class:`torch.nn.Conv2d` and :py:class:`torch.nn.Linear`) are
-# ``nn`` modules, BackPACK can extend them,
-# and this is all that is needed for first order extensions.
-# We can rewrite the forward to implement the residual connection,
-# and :py:func:`extend() <backpack.extend>` the resulting model.
-
-
-class MyFirstResNet(torch.nn.Module):
- def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
- super().__init__()
-
- self.conv1 = torch.nn.Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
- self.conv2 = torch.nn.Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
- self.linear1 = torch.nn.Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)
- if C_in == C_hid:
- self.shortcut = torch.nn.Identity()
- else:
- self.shortcut = torch.nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1)
-
- def forward(self, x):
- residual = self.shortcut(x)
- x = self.conv2(F.relu(self.conv1(x)))
- x += residual
- x = x.view(x.size(0), -1)
- x = self.linear1(x)
- return x
-
-
-model = extend(MyFirstResNet()).to(DEVICE)
-
-# %%
-# Using :py:class:`BatchGrad <backpack.extensions.BatchGrad>` in a
-# :py:class:`with backpack(...) <backpack.backpack>` block,
-# we can access the individual gradients for each sample.
-#
-# The loss does not need to be extended in this case either, as it does not
-# have model parameters and BackPACK does not need to know about it for
-# first order extensions. This also means you can use any custom loss function.
-
-model.zero_grad()
-loss = F.cross_entropy(model(x), y, reduction="sum")
-with backpack(BatchGrad()):
- loss.backward()
-
-print("{:<20} {:<30} {:<30}".format("Param", "grad", "grad (batch)"))
-print("-" * 80)
-for name, p in model.named_parameters():
- print(
- "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape))
- )
-
-# %%
-# To check that everything works, let's compute one individual gradient with
-# PyTorch (using a single sample in a forward and backward pass)
-# and compare it with the one computed by BackPACK.
-
-sample_to_check = 1
-x_to_check = x[sample_to_check, :].unsqueeze(0)
-y_to_check = y[sample_to_check].unsqueeze(0)
-
-model.zero_grad()
-loss = F.cross_entropy(model(x_to_check), y_to_check)
-loss.backward()
-
-print("Do the individual gradients match?")
-for name, p in model.named_parameters():
- match = torch.allclose(p.grad_batch[sample_to_check, :], p.grad, atol=1e-7)
- print("{:<20} {}".format(name, match))
+# This tutorial has moved. Click
+# `here `_
+# to continue to its new location.
diff --git a/docs_src/examples/use_cases/example_resnet_all_in_one.py b/docs_src/examples/use_cases/example_resnet_all_in_one.py
new file mode 100644
index 000000000..0b3398e3b
--- /dev/null
+++ b/docs_src/examples/use_cases/example_resnet_all_in_one.py
@@ -0,0 +1,316 @@
+"""Residual networks
+====================
+"""
+# %%
+# There are three different approaches to using BackPACK with ResNets.
+#
+# 1. :ref:`Custom ResNet`: (Only works for first-order extensions) Write your own model
+# by defining its forward pass. Trainable parameters must be in modules known to
+# BackPACK (e.g. :class:`torch.nn.Conv2d`, :class:`torch.nn.Linear`).
+#
+# 2. :ref:`Custom ResNet with BackPACK custom modules`: (Works for first- and second-
+# order extensions) Build your ResNet with custom modules provided by BackPACK
+# without overwriting the forward pass. This approach is useful if you want to
+# understand how BackPACK handles ResNets, or if you think building a container
+# module that implicitly defines the forward pass is more elegant than coding up
+# a forward pass.
+#
+# 3. :ref:`Any ResNet with BackPACK's converter`: (Works for first- and second-order
+# extensions) Convert your model into a BackPACK-compatible architecture.
+#
+# .. note::
+# ResNets are still an experimental feature. Always double-check your
+# results, as done in this example! Open an issue if you encounter a bug to help
+# us improve the support.
+#
+# Not all extensions support ResNets (yet). Please create a feature request in the
+# repository if the extension you need is not supported.
+
+# %%
+# Let's get the imports out of the way.
+
+from torch import (
+ allclose,
+ cat,
+ cuda,
+ device,
+ int32,
+ linspace,
+ manual_seed,
+ rand,
+ rand_like,
+)
+from torch.nn import (
+ Conv2d,
+ CrossEntropyLoss,
+ Flatten,
+ Identity,
+ Linear,
+ Module,
+ MSELoss,
+ ReLU,
+ Sequential,
+)
+from torch.nn.functional import cross_entropy, relu
+from torchvision.models import resnet18
+
+from backpack import backpack, extend
+from backpack.custom_module.branching import Parallel, SumModule
+from backpack.custom_module.graph_utils import BackpackTracer
+from backpack.extensions import BatchGrad, DiagGGNExact
+from backpack.utils.examples import autograd_diag_ggn_exact, load_one_batch_mnist
+
+manual_seed(0)
+DEVICE = device("cuda:0" if cuda.is_available() else "cpu")
+x, y = load_one_batch_mnist(batch_size=32)
+x, y = x.to(DEVICE), y.to(DEVICE)
+
+
+# %%
+# Custom ResNet
+# -------------
+# We can build a ResNet by extending :py:class:`torch.nn.Module`.
+# As long as the layers with parameters (:py:class:`torch.nn.Conv2d`
+# and :py:class:`torch.nn.Linear`) are ``nn`` modules, BackPACK can extend them,
+# and this is all that is needed for first-order extensions.
+# We can rewrite the :code:`forward` to implement the residual connection,
+# and :py:func:`extend() <backpack.extend>` the resulting model.
+#
+# .. note::
+# Using in-place operations is not compatible with PyTorch's
+# :meth:`torch.nn.Module.register_full_backward_hook`. Therefore,
+# always use :code:`x = x + residual` instead of :code:`x += residual`.
+
+
+class MyFirstResNet(Module):
+ def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
+ """Instantiate submodules that are used in the forward pass."""
+ super().__init__()
+
+ self.conv1 = Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
+ self.conv2 = Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
+ self.linear1 = Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)
+ if C_in == C_hid:
+ self.shortcut = Identity()
+ else:
+ self.shortcut = Conv2d(C_in, C_hid, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ """Manual implementation of the forward pass."""
+ residual = self.shortcut(x)
+ x = self.conv2(relu(self.conv1(x)))
+ x = x + residual # don't use: x += residual
+ x = x.flatten(start_dim=1)
+ x = self.linear1(x)
+ return x
+
+
+model = extend(MyFirstResNet()).to(DEVICE)
+
+# %%
+# The loss does not need to be extended in this case either, as it does not
+# have model parameters and BackPACK does not need to know about it for
+# first-order extensions. This also means you can use any custom loss function.
+#
+# Using :py:class:`BatchGrad <backpack.extensions.BatchGrad>` in a
+# :py:class:`with backpack(...) <backpack.backpack>` block,
+# we can access the individual gradients for each sample.
+
+loss = cross_entropy(model(x), y, reduction="sum")
+
+with backpack(BatchGrad()):
+ loss.backward()
+
+for name, parameter in model.named_parameters():
+ print(f"{name:>20}'s grad_batch shape: {parameter.grad_batch.shape}")
+
+# %%
+# To check that everything works, let's compute one individual gradient with
+# PyTorch (using a single sample in a forward and backward pass)
+# and compare it with the one computed by BackPACK.
+
+sample_to_check = 1
+x_to_check = x[[sample_to_check]]
+y_to_check = y[[sample_to_check]]
+
+model.zero_grad()
+loss = cross_entropy(model(x_to_check), y_to_check)
+loss.backward()
+
+print("Do the individual gradients match?")
+for name, parameter in model.named_parameters():
+ match = allclose(parameter.grad_batch[sample_to_check], parameter.grad, atol=1e-6)
+ print(f"{name:>20}: {match}")
+ if not match:
+ raise AssertionError("Individual gradients don't match!")
+
+# %%
+# Custom ResNet with BackPACK custom modules
+# ------------------------------------------
+# Second-order extensions only work if every node in the computation graph is an
+# ``nn`` module that can be extended by BackPACK. The above ResNet class
+# :py:class:`MyFirstResNet` does not satisfy these conditions, because
+# it implements the skip connection via :py:func:`torch.add` while overwriting the
+# :py:meth:`forward() <torch.nn.Module.forward>` method.
+#
+# To build ResNets without overwriting the forward pass, BackPACK offers custom modules:
+#
+# 1. :py:class:`Parallel <backpack.custom_module.branching.Parallel>` is similar to
+# :py:class:`torch.nn.Sequential`, but implements a container for a parallel sequence
+# of modules (followed by an aggregation module), rather than a sequential one.
+#
+# 2. :py:class:`SumModule <backpack.custom_module.branching.SumModule>` is the
+#    module that takes the role of :py:func:`torch.add` in the previous example.
+#    It sums up multiple inputs. We will use it to merge the skip connection.
+#
+# With the above modules, we can build a simple ResNet as a container that implicitly
+# defines the forward pass:
+
+C_in = 1
+C_hid = 2
+input_dim = (28, 28)
+output_dim = 10
+
+model = Sequential(
+ Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ Parallel( # skip connection with ReLU-activated convolution
+ Identity(),
+ Sequential(
+ Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ ),
+ merge_module=SumModule(),
+ ),
+ Flatten(),
+ Linear(input_dim[0] * input_dim[1] * C_hid, output_dim),
+)
+
+model = extend(model.to(DEVICE))
+loss_function = extend(CrossEntropyLoss(reduction="mean")).to(DEVICE)
+
+
+# %%
+# This ResNet supports BackPACK's second-order extensions:
+
+loss = loss_function(model(x), y)
+
+with backpack(DiagGGNExact()):
+ loss.backward()
+
+for name, parameter in model.named_parameters():
+ print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")
+
+diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])
+
+# %%
+# Comparison with :py:mod:`torch.autograd`:
+#
+# .. note::
+#
+# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation
+# can be slow, depending on the number of parameters. To reduce run time, we only
+# compare some elements of the diagonal.
+
+num_params = sum(p.numel() for p in model.parameters())
+num_to_compare = 10
+idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
+
+diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
+ x, y, model, loss_function, idx=idx_to_compare
+)
+
+print("Do the exact GGN diagonals match?")
+for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
+ match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
+ print(f"Diagonal entry {idx:>6}: {match}")
+ if not match:
+ raise AssertionError("Exact GGN diagonals don't match!")
+
+# %%
+# Any ResNet with BackPACK's converter
+# ------------------------------------
+# If you are not building a ResNet through custom modules but for instance want to
+# use a prominent ResNet from :py:mod:`torchvision.models`, BackPACK offers a converter.
+# It analyzes the model and tries to turn it into a compatible architecture. The result
+# is a :py:class:`torch.fx.GraphModule` that exclusively consists of modules.
+#
+# Here, we demonstrate the converter on
+# :py:class:`resnet18 <torchvision.models.resnet18>`.
+#
+# .. note::
+#
+#    :py:class:`resnet18 <torchvision.models.resnet18>` has to be in evaluation mode,
+# because it contains batch normalization layers that are not supported in train
+# mode by the second-order extension used in this example.
+#
+# Let's create the model and convert it in the call to
+# :py:func:`extend <backpack.extend>`:
+
+loss_function = extend(MSELoss().to(DEVICE))
+model = resnet18(num_classes=5).to(DEVICE).eval()
+
+# use BackPACK's converter to extend the model (turned off by default)
+model = extend(model, use_converter=True)
+
+# %%
+# To get an understanding of what happened, we can inspect the model's graph with the
+# following helper function:
+
+
+def print_table(module: Module) -> None:
+ """Prints a table of the module.
+
+ Args:
+ module: module to analyze
+ """
+ graph = BackpackTracer().trace(module)
+ graph.print_tabular()
+
+
+print_table(model)
+
+# %%
+# Admittedly, the converted :py:class:`resnet18 <torchvision.models.resnet18>`'s graph
+# is quite large. Note, however, that it fully consists of modules (indicated by
+# ``call_module`` in the first table column) such that BackPACK's hooks can
+# successfully backpropagate additional information for its second-order extensions
+# (first-order extensions work, too).
+#
+# Let's verify that second-order extensions are working:
+
+x = rand(4, 3, 7, 7, device=DEVICE)  # small toy input instead of (128, 3, 224, 224)
+output = model(x)
+y = rand_like(output)
+
+loss = loss_function(output, y)
+
+with backpack(DiagGGNExact()):
+ loss.backward()
+
+for name, parameter in model.named_parameters():
+ print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")
+
+diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])
+
+# %%
+# Comparison with :py:mod:`torch.autograd`:
+#
+# .. note::
+#
+# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation
+# can be slow, depending on the number of parameters. To reduce run time, we only
+# compare some elements of the diagonal.
+
+num_params = sum(p.numel() for p in model.parameters())
+num_to_compare = 10
+idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
+
+diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
+ x, y, model, loss_function, idx=idx_to_compare
+)
+
+print("Do the exact GGN diagonals match?")
+for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
+ match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
+ print(f"Diagonal entry {idx:>8}: {match}")
+ if not match:
+ raise AssertionError("Exact GGN diagonals don't match!")
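+
+# %%
+# As claimed above, first-order extensions work on the converted model, too. Here is
+# a minimal sketch (reusing the random regression data from the previous cell):
+
+model.zero_grad()
+loss = loss_function(model(x), y)
+
+with backpack(BatchGrad()):
+    loss.backward()
+
+for name, parameter in model.named_parameters():
+    print(f"{name}'s grad_batch: {parameter.grad_batch.shape}")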
diff --git a/docs_src/examples/use_cases/example_rnn.py b/docs_src/examples/use_cases/example_rnn.py
new file mode 100644
index 000000000..283b1a900
--- /dev/null
+++ b/docs_src/examples/use_cases/example_rnn.py
@@ -0,0 +1,306 @@
+"""Recurrent networks
+====================
+"""
+# %%
+# There are two different approaches to using BackPACK with RNNs.
+#
+# 1. :ref:`Custom RNN with BackPACK custom modules`:
+# Build your RNN with custom modules provided by BackPACK
+# without overwriting the forward pass. This approach is useful if you want to
+# understand how BackPACK handles RNNs, or if you think building a container
+# module that implicitly defines the forward pass is more elegant than coding up
+# a forward pass.
+#
+# 2. :ref:`RNN with BackPACK's converter`:
+# Automatically convert your model into a BackPACK-compatible architecture.
+#
+# .. note::
+# RNNs are still an experimental feature. Always double-check your
+# results, as done in this example! Open an issue if you encounter a bug to help
+# us improve the support.
+#
+# Not all extensions support RNNs (yet). Please create a feature request in the
+# repository if the extension you need is not supported.
+
+# %%
+# Let's get the imports out of the way.
+from torch import (
+ allclose,
+ cat,
+ device,
+ int32,
+ linspace,
+ manual_seed,
+ nn,
+ randint,
+ zeros_like,
+)
+
+from backpack import backpack, extend
+from backpack.custom_module.graph_utils import BackpackTracer
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.reduce_tuple import ReduceTuple
+from backpack.extensions import BatchGrad, DiagGGNExact
+from backpack.utils.examples import autograd_diag_ggn_exact
+
+manual_seed(0)
+DEVICE = device("cpu") # Verification via autograd only works on CPU
+
+
+# %%
+# For this demo, we will use the Tolstoi Char RNN from
+# `DeepOBS `_.
+# This network is trained on Leo Tolstoi's War and Peace
+# and learns to predict the next character.
+class TolstoiCharRNN(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 8
+ self.hidden_dim = 64
+ self.num_layers = 2
+ self.seq_len = 15
+ self.vocab_size = 25
+
+ self.embedding = nn.Embedding(
+ num_embeddings=self.vocab_size, embedding_dim=self.hidden_dim
+ )
+ self.dropout = nn.Dropout(p=0.2)
+ self.lstm = nn.LSTM(
+ input_size=self.hidden_dim,
+ hidden_size=self.hidden_dim,
+ num_layers=self.num_layers,
+ dropout=0.36,
+ batch_first=True,
+ )
+ # deactivate redundant bias
+ self.lstm.bias_ih_l0.data = zeros_like(self.lstm.bias_ih_l0)
+ self.lstm.bias_ih_l1.data = zeros_like(self.lstm.bias_ih_l1)
+ self.lstm.bias_ih_l0.requires_grad = False
+ self.lstm.bias_ih_l1.requires_grad = False
+ self.dense = nn.Linear(
+ in_features=self.hidden_dim, out_features=self.vocab_size
+ )
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.dropout(x)
+ x, _ = self.lstm(x) # last return values are hidden states
+ x = self.dropout(x)
+ output = self.dense(x)
+ output = output.permute(0, 2, 1) # [N, T, C] → [N, C, T]
+ return output
+
+ def input_target_fn(self):
+ input = randint(0, self.vocab_size, (self.batch_size, self.seq_len))
+        # target is the input shifted by one step along the time axis
+ target = cat(
+ [
+ randint(0, self.vocab_size, (self.batch_size, 1)),
+ input[:, :-1],
+ ],
+ dim=1,
+ )
+ return input.to(DEVICE), target.to(DEVICE)
+
+ def loss_fn(self) -> nn.Module:
+ return nn.CrossEntropyLoss().to(DEVICE)
+
+
+manual_seed(1)
+tolstoi_char_rnn = TolstoiCharRNN().to(DEVICE).eval()
+loss_function = extend(tolstoi_char_rnn.loss_fn())
+x, y = tolstoi_char_rnn.input_target_fn()
+# %%
+# Note that instead of the real data set, we will feed synthetic data to the network for
+# simplicity. We also use the network in evaluation mode. This disables the
+# :py:class:`Dropout <torch.nn.Dropout>` layers and allows double-checking our results
+# via :py:mod:`torch.autograd`.
+#
+# Custom RNN with BackPACK custom modules
+# ---------------------------------------
+# Second-order extensions only work if every node in the computation graph is an
+# ``nn`` module that can be extended by BackPACK. The above RNN
+# :py:class:`TolstoiCharRNN` does not satisfy these conditions, because
+# it has a multi-layer :py:class:`torch.nn.LSTM` and implicitly uses the
+# :py:func:`getitem` (for unpacking) and :py:meth:`permute() <torch.Tensor.permute>`
+# functions in the :py:meth:`forward() <torch.nn.Module.forward>` method.
+#
+# To build RNNs without overwriting the forward pass, BackPACK offers custom modules:
+#
+# 1. :py:class:`ReduceTuple <backpack.custom_module.reduce_tuple.ReduceTuple>`
+#
+# 2. :py:class:`Permute <backpack.custom_module.permute.Permute>`
+#
+# With the above modules, we can build a simple RNN as a container that implicitly
+# defines the forward pass:
+manual_seed(1) # same seed as used to initialize `tolstoi_char_rnn`
+tolstoi_char_rnn_custom = nn.Sequential(
+ nn.Embedding(tolstoi_char_rnn.vocab_size, tolstoi_char_rnn.hidden_dim),
+ nn.Dropout(p=0.2),
+ nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True),
+ ReduceTuple(index=0),
+ nn.Dropout(p=0.36),
+ nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True),
+ ReduceTuple(index=0),
+ nn.Dropout(p=0.2),
+ nn.Linear(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.vocab_size),
+ Permute(0, 2, 1),
+)
+tolstoi_char_rnn_custom.eval().to(DEVICE)
+
+# %%
+# Let's check that both models have the same forward pass.
+for name, p in tolstoi_char_rnn_custom.named_parameters():
+ if "bias_ih_l" in name:
+ # deactivate redundant bias
+ p.data = zeros_like(p.data)
+ p.requires_grad = False
+
+match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
+print(f"Forward pass of custom model matches TolstoiCharRNN? {match}")
+
+if not match:
+ raise AssertionError("Forward passes don't match.")
+
+# %%
+# We can :py:func:`extend <backpack.extend>` our model and the loss function to
+# compute BackPACK extensions.
+
+tolstoi_char_rnn_custom = extend(tolstoi_char_rnn_custom)
+loss = loss_function(tolstoi_char_rnn_custom(x), y)
+
+with backpack(BatchGrad(), DiagGGNExact()):
+ loss.backward()
+
+for name, param in tolstoi_char_rnn_custom.named_parameters():
+ if param.requires_grad:
+ print(
+ name,
+ param.shape,
+ param.grad_batch.shape,
+ param.diag_ggn_exact.shape,
+ )
+
+# %%
+# Comparison of the GGN diagonal extension with :py:mod:`torch.autograd`:
+#
+# .. note::
+#
+# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation
+# can be slow, depending on the number of parameters. To reduce run time, we only
+# compare some elements of the diagonal.
+trainable_params = [p for p in tolstoi_char_rnn_custom.parameters() if p.requires_grad]
+
+diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in trainable_params])
+
+num_params = sum(p.numel() for p in trainable_params)
+num_to_compare = 10
+idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
+
+diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
+ x, y, tolstoi_char_rnn_custom, loss_function, idx=idx_to_compare
+)
+
+print("Do the exact GGN diagonals match?")
+for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
+ match = allclose(element, diag_ggn_exact_vector[idx])
+ print(
+ f"Diagonal entry {idx:>8}: {match}:"
+ + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}"
+ )
+ if not match:
+ raise AssertionError("Exact GGN diagonals don't match!")
+
+# %%
+# RNN with BackPACK's converter
+# -----------------------------
+# If you are not building an RNN through custom modules but for instance want to
+# directly use the Tolstoi Char RNN, BackPACK offers a converter.
+# It analyzes the model and tries to turn it into a compatible architecture. The result
+# is a :py:class:`torch.fx.GraphModule` that exclusively consists of modules.
+#
+# Here, we demonstrate the converter on the above Tolstoi Char RNN. Let's convert it
+# while :py:func:`extend <backpack.extend>`-ing the model:
+
+# use BackPACK's converter to extend the model (turned off by default)
+tolstoi_char_rnn = extend(tolstoi_char_rnn, use_converter=True)
+
+# %%
+# To get an understanding of what happened, we can inspect the model's graph with the
+# following helper function:
+
+
+def print_table(module: nn.Module) -> None:
+ """Prints a table of the module.
+
+ Args:
+ module: module to analyze
+ """
+ graph = BackpackTracer().trace(module)
+ graph.print_tabular()
+
+
+print_table(tolstoi_char_rnn)
+
+# %%
+# Note that the computation graph fully consists of modules (indicated by
+# ``call_module`` in the first table column) such that BackPACK's hooks can
+# successfully backpropagate additional information for its second-order extensions
+# (first-order extensions work, too).
+#
+# First, let's compare the forward pass with the custom module from the previous
+# section to make sure the converter worked fine:
+
+match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
+print(f"Forward pass of extended TolstoiCharRNN matches custom model? {match}")
+
+if not match:
+ raise AssertionError("Forward passes don't match.")
+
+# %%
+#
+# Now let's verify that second-order extensions (GGN diagonal) are working:
+output = tolstoi_char_rnn(x)
+loss = loss_function(output, y)
+
+with backpack(DiagGGNExact()):
+ loss.backward()
+
+for name, parameter in tolstoi_char_rnn.named_parameters():
+ if parameter.requires_grad:
+ print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")
+
+diag_ggn_exact_vector = cat(
+ [
+ p.diag_ggn_exact.flatten()
+ for p in tolstoi_char_rnn.parameters()
+ if p.requires_grad
+ ]
+)
+
+# %%
+# Finally, we compare BackPACK's GGN diagonal with :py:mod:`torch.autograd`:
+#
+# .. note::
+#
+# Computing the full GGN diagonal with PyTorch's built-in automatic differentiation
+# can be slow, depending on the number of parameters. To reduce run time, we only
+# compare some elements of the diagonal.
+
+num_params = sum(p.numel() for p in tolstoi_char_rnn.parameters() if p.requires_grad)
+num_to_compare = 10
+idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
+
+diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
+ x, y, tolstoi_char_rnn, loss_function, idx=idx_to_compare
+)
+
+print("Do the exact GGN diagonals match?")
+for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
+ match = allclose(element, diag_ggn_exact_vector[idx])
+ print(
+ f"Diagonal entry {idx:>8}: {match}:"
+ + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}"
+ )
+ if not match:
+ raise AssertionError("Exact GGN diagonals don't match!")
diff --git a/docs_src/examples/use_cases/example_subsampling.py b/docs_src/examples/use_cases/example_subsampling.py
new file mode 100644
index 000000000..b89d01a3d
--- /dev/null
+++ b/docs_src/examples/use_cases/example_subsampling.py
@@ -0,0 +1,86 @@
+"""Mini-batch sub-sampling
+==========================
+
+By default, BackPACK's extensions consider all samples in a mini-batch. Some extensions
+support limiting the computations to a subset of samples. This example shows how to
+restrict the computations to such a subset of samples.
+
+This may be interesting for applications where parts of the samples are used for
+different purposes, e.g. computing curvature and gradient information on different
+subsets. Limiting the computations to fewer samples also reduces costs.
+
+.. note::
+ Not all extensions support sub-sampling yet. Please create a feature request in the
+ repository if the extension you need is not supported.
+"""
+
+# %%
+# Let's start by loading some dummy data and extending the model
+
+from torch import allclose, cuda, device, manual_seed
+from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
+
+from backpack import backpack, extend
+from backpack.extensions import BatchGrad
+from backpack.utils.examples import load_one_batch_mnist
+
+# make deterministic
+manual_seed(0)
+
+dev = device("cuda" if cuda.is_available() else "cpu")
+
+# data
+X, y = load_one_batch_mnist(batch_size=128)
+X, y = X.to(dev), y.to(dev)
+
+# model
+model = Sequential(Flatten(), Linear(784, 10)).to(dev)
+lossfunc = CrossEntropyLoss().to(dev)
+
+model = extend(model)
+lossfunc = extend(lossfunc)
+
+# %%
+# Individual gradients for a mini-batch subset
+# --------------------------------------------
+#
+# Let's say we only want to compute individual gradients for samples 0, 1,
+# 13, and 42. Naively, we could perform the computation for all samples, then
+# slice out the samples we care about.
+
+# selected samples
+subsampling = [0, 1, 13, 42]
+
+loss = lossfunc(model(X), y)
+
+with backpack(BatchGrad()):
+ loss.backward()
+
+# naive approach: compute for all, slice out relevant
+naive = [p.grad_batch[subsampling] for p in model.parameters()]
+
+# %%
+# This is not efficient, as individual gradients are computed for all samples,
+# most of which are then discarded. We can do better by specifying the active
+# samples directly with the ``subsampling`` argument of
+# :py:class:`BatchGrad <backpack.extensions.BatchGrad>`.
+
+loss = lossfunc(model(X), y)
+
+# efficient approach: specify active samples in backward pass
+with backpack(BatchGrad(subsampling=subsampling)):
+ loss.backward()
+
+efficient = [p.grad_batch for p in model.parameters()]
+
+# %%
+# Let's verify that both ways yield the same result:
+
+match = all(
+ allclose(g_naive, g_efficient) for g_naive, g_efficient in zip(naive, efficient)
+)
+
+print(f"Naive and efficient sub-sampled individual gradients match? {match}")
+
+if not match:
+    raise ValueError(
+        "Naive and efficient sub-sampled individual gradients don't match."
+    )
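+
+# %%
+# As a rough check of the savings (a sketch, assuming the sub-sampled ``grad_batch``
+# stacks only the active samples along its leading axis):
+
+for p in model.parameters():
+    print(f"grad_batch stores {p.grad_batch.shape[0]} of {X.shape[0]} gradients")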
diff --git a/docs_src/examples/use_cases/example_trace_estimation.py b/docs_src/examples/use_cases/example_trace_estimation.py
index 86c24ae78..8c2a57107 100644
--- a/docs_src/examples/use_cases/example_trace_estimation.py
+++ b/docs_src/examples/use_cases/example_trace_estimation.py
@@ -225,7 +225,7 @@ def hutchinson_trace_autodiff_blockwise(V):
plt.semilogx(
V_list,
- trace_estimates,
+ [trace_estimate.cpu() for trace_estimate in trace_estimates],
linestyle="--",
color="orange",
label="Hutchinson" if i == 0 else None,
diff --git a/docs_src/rtd/extensions.rst b/docs_src/rtd/extensions.rst
index 2fa9e2369..9eea8df02 100644
--- a/docs_src/rtd/extensions.rst
+++ b/docs_src/rtd/extensions.rst
@@ -25,6 +25,8 @@ Available Extensions
.. autofunction:: backpack.extensions.KFRA
.. autofunction:: backpack.extensions.DiagHessian
.. autofunction:: backpack.extensions.BatchDiagHessian
+.. autofunction:: backpack.extensions.SqrtGGNExact
+.. autofunction:: backpack.extensions.SqrtGGNMC
-----
diff --git a/docs_src/rtd/good-to-know.rst b/docs_src/rtd/good-to-know.rst
index 23911ea9e..18131e80f 100644
--- a/docs_src/rtd/good-to-know.rst
+++ b/docs_src/rtd/good-to-know.rst
@@ -19,7 +19,7 @@ and a backward pass over each sample individually.
This is slow, but can be used to check that the values returned by BackPACK
match what you expect them to be.
-While we test many a use-case and try to write solid code, unexpected
+While we test many use-cases and try to write solid code, unexpected
behavior (such as some listed on this page) or bugs are not impossible.
We recommend that you check that the outputs match your expectations,
especially if you're using non-default values on slightly more unusual parameters
@@ -55,35 +55,29 @@ are not affected by :py:meth:`zero_grad() `.
The :ref:`intro example ` shows how to make a model
using a :py:class:`torch.nn.Sequential` module
-and how to :py:func:`extend() <backpack.extend>` the model and the loss function,
-but this setup is only really necessary for
-:ref:`second order quantities `.
+and how to :py:func:`extend() <backpack.extend>` the model and the loss function.
+But extending everything is only really necessary for
+:ref:`Second order quantities `.
For those, BackPACK needs to know about the structure of the whole network
-to propagate additional information.
-
-:ref:`First order extensions ` are more flexible,
-and the only :py:class:`torch.nn.Module` that need to be extended
-are modules with parameters, to extract more information,
-as the gradients are already propagated by PyTorch.
-For every operations that is not parametrized, you can use standard operations
-from the :std:doc:`torch.nn.functional ` module or standard
-tensor operations. This makes it possible to use first order extensions
-for ResNets (see :ref:`this example `).
+to backpropagate additional information.
+Because :ref:`First order extensions ` don't
+backpropagate additional information, they are more flexible and only require
+every parameterized :py:class:`torch.nn.Module` be extended. For any
+unparameterized operation, you can use standard operations from the
+:std:doc:`torch.nn.functional <nn.functional>` module or standard tensor
+operations.
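+
+For example (a minimal sketch; ``TwoLayerNet`` and its layers are made up for
+illustration), a first-order extension such as ``BatchGrad`` only needs the layers
+that hold parameters to be extended:
+
+.. code-block:: python
+
+    import torch
+    from torch.nn.functional import relu
+
+    from backpack import backpack, extend
+    from backpack.extensions import BatchGrad
+
+    class TwoLayerNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            # only the layers that hold parameters are extended
+            self.fc1 = extend(torch.nn.Linear(10, 10))
+            self.fc2 = extend(torch.nn.Linear(10, 2))
+
+        def forward(self, x):
+            # unparameterized operations need no extension
+            return self.fc2(relu(self.fc1(x)))
+
+    model = TwoLayerNet()
+    loss = model(torch.rand(8, 10)).sum()
+
+    with backpack(BatchGrad()):
+        loss.backward()  # fc1/fc2 parameters now hold a .grad_batch attribute
+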
Not (yet) supported models
----------------------------------
-The second-order extensions for BackPACK don't support (yet) residual networks,
-and no extension support recurrent architectures.
-We're working on how to handle those, as well as adding more
-:ref:`layers `.
-Along those lines, some things that will (most likely) not work with BackPACK,
-but that we're trying to build support for:
-
-- Reusing the same parameters or module multiple time in the computation graph.
+We're working on handling more complex computation graphs, as well as adding
+more :ref:`layers `. Along those lines, some things that will
+(most likely) **not** work with BackPACK are:
- For second order extensions, this also holds for any module,
- whether or not they have parameters.
- This sadly mean that BackPACK can't compute the individual gradients or
- second-order information of a L2-regularized loss, for example.
+- **Reusing the same parameters multiple times in the computation graph:** This
+ sadly means that BackPACK can't compute the individual gradients or
+ second-order information of an L2-regularized loss or architectures that use
+ parameters multiple times.
+- **Some exotic hyperparameters are not fully supported:** Feature requests on
+ the repository are welcome.
diff --git a/docs_src/rtd/main-api.rst b/docs_src/rtd/main-api.rst
index 67a4d0c52..dce6792e2 100644
--- a/docs_src/rtd/main-api.rst
+++ b/docs_src/rtd/main-api.rst
@@ -66,5 +66,6 @@ and the :ref:`Supported models`.
-----
.. autofunction:: backpack.extend
-.. autofunction:: backpack.backpack
+.. autoclass:: backpack.backpack
+ :members: __init__
.. autofunction:: backpack.disable
diff --git a/docs_src/rtd/supported-layers.rst b/docs_src/rtd/supported-layers.rst
index 149bb7dc2..41887141c 100644
--- a/docs_src/rtd/supported-layers.rst
+++ b/docs_src/rtd/supported-layers.rst
@@ -14,12 +14,13 @@ For example,
torch.nn.Linear(64, 10)
)
-This page lists the layers currently supported by BackPACK.
-
+**If you overwrite any** :code:`forward()` **function** (for example in ResNets
+and RNNs), the additional backward pass to compute second-order quantities will
+break. You can ask BackPACK to inspect the graph and try converting it
+into a compatible structure by setting :code:`use_converter=True` in
+:py:func:`extend <backpack.extend>`.
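+
+For instance (a minimal sketch using ``torchvision``'s ``resnet18``; the
+architecture and ``num_classes`` are only for illustration):
+
+.. code-block:: python
+
+    from torchvision.models import resnet18
+
+    from backpack import extend
+
+    # eval mode: BatchNorm in train mode is not supported by second-order extensions
+    model = extend(resnet18(num_classes=10).eval(), use_converter=True)
+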
-**Do not rewrite the** :code:`forward()` **function of the** :code:`Sequential` **or the inner modules!**
-If the forward is not standard, the additional backward pass to compute second-order quantities will not match the actual function.
-First-order extensions that extract information might work outside of this framework, but it is not tested.
+This page lists the layers currently supported by BackPACK.
.. raw:: html
@@ -38,55 +39,72 @@ parameters of the following layers;
- :py:class:`torch.nn.ConvTranspose1d`,
:py:class:`torch.nn.ConvTranspose2d`,
:py:class:`torch.nn.ConvTranspose3d`
+- :py:class:`torch.nn.BatchNorm1d` (evaluation mode),
+ :py:class:`torch.nn.BatchNorm2d` (evaluation mode),
+ :py:class:`torch.nn.BatchNorm3d` (evaluation mode)
+- :py:class:`torch.nn.Embedding`
+- :py:class:`torch.nn.RNN`, :py:class:`torch.nn.LSTM`
-First-order extensions should support any module as long as they do not have parameters,
-but some layers lead to the concept of "individual gradient for a sample in a minibatch"
-to be ill-defined, as they introduce dependencies across examples
-(like :py:class:`torch.nn.BatchNorm`).
+Some layers (like :code:`torch.nn.BatchNormNd` in training mode) mix samples and
+lead to ill-defined first-order quantities.
-----
For second-order extensions
--------------------------------------
-BackPACK needs to know how to propagate second-order information.
-This is implemented for:
-
-+-------------------------------+---------------------------------------+
-| **Parametrized layers** | :py:class:`torch.nn.Conv1d`, |
-| | :py:class:`torch.nn.Conv2d`, |
-| | :py:class:`torch.nn.Conv3d` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.ConvTranspose1d`, |
-| | :py:class:`torch.nn.ConvTranspose2d`, |
-| | :py:class:`torch.nn.ConvTranspose3d` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.Linear` |
-+-------------------------------+---------------------------------------+
-| **Loss functions** | :py:class:`torch.nn.MSELoss` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.CrossEntropyLoss` |
-+-------------------------------+---------------------------------------+
-| **Layers without parameters** | :py:class:`torch.nn.MaxPool1d`, |
-| | :py:class:`torch.nn.MaxPool2d`, |
-| | :py:class:`torch.nn.MaxPool3d` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.AvgPool1d`, |
-| | :py:class:`torch.nn.AvgPool2d`, |
-| | :py:class:`torch.nn.AvgPool3d` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.ZeroPad2d`, |
-| +---------------------------------------+
-| | :py:class:`torch.nn.Dropout` |
-| +---------------------------------------+
-| | :py:class:`torch.nn.ReLU`, |
-| | :py:class:`torch.nn.Sigmoid`, |
-| | :py:class:`torch.nn.Tanh`, |
-| | :py:class:`torch.nn.LeakyReLU`, |
-| | :py:class:`torch.nn.LogSigmoid`, |
-| | :py:class:`torch.nn.ELU`, |
-| | :py:class:`torch.nn.SELU` |
-+-------------------------------+---------------------------------------+
-
-Some exotic hyperparameters are not fully supported, but feature requests
-on the repository are welcome.
+BackPACK needs to know how to backpropagate additional information for
+second-order quantities. This is implemented for:
+
++-------------------------------+-----------------------------------------------+
+| **Parametrized layers** | :py:class:`torch.nn.Conv1d`, |
+| | :py:class:`torch.nn.Conv2d`, |
+| | :py:class:`torch.nn.Conv3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.ConvTranspose1d`, |
+| | :py:class:`torch.nn.ConvTranspose2d`, |
+| | :py:class:`torch.nn.ConvTranspose3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.Linear` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.BatchNorm1d`, |
+| | :py:class:`torch.nn.BatchNorm2d`, |
+| | :py:class:`torch.nn.BatchNorm3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.Embedding` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.RNN`, |
+| | :py:class:`torch.nn.LSTM` |
++-------------------------------+-----------------------------------------------+
+| **Loss functions** | :py:class:`torch.nn.MSELoss` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.CrossEntropyLoss` |
++-------------------------------+-----------------------------------------------+
+| **Layers without parameters** | :py:class:`torch.nn.MaxPool1d`, |
+| | :py:class:`torch.nn.MaxPool2d`, |
+| | :py:class:`torch.nn.MaxPool3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.AvgPool1d`, |
+| | :py:class:`torch.nn.AvgPool2d`, |
+| | :py:class:`torch.nn.AvgPool3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.AdaptiveAvgPool1d`, |
+| | :py:class:`torch.nn.AdaptiveAvgPool2d`, |
+| | :py:class:`torch.nn.AdaptiveAvgPool3d` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.ZeroPad2d`, |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.Dropout` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.ReLU`, |
+| | :py:class:`torch.nn.Sigmoid`, |
+| | :py:class:`torch.nn.Tanh`, |
+| | :py:class:`torch.nn.LeakyReLU`, |
+| | :py:class:`torch.nn.LogSigmoid`, |
+| | :py:class:`torch.nn.ELU`, |
+| | :py:class:`torch.nn.SELU` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.Identity` |
+| +-----------------------------------------------+
+| | :py:class:`torch.nn.Flatten` |
++-------------------------------+-----------------------------------------------+
diff --git a/fully_documented.txt b/fully_documented.txt
index 3eaa0f12f..de3c98730 100644
--- a/fully_documented.txt
+++ b/fully_documented.txt
@@ -1,14 +1,110 @@
-test/extensions/test_backprop_extension.py
+setup.py
-test/extensions/secondorder/secondorder_settings.py
+backpack/__init__.py
+backpack/context.py
+backpack/custom_module/
-test/extensions/secondorder/hbp
+backpack/core/derivatives/basederivatives.py
+backpack/core/derivatives/rnn.py
+backpack/core/derivatives/shape_check.py
+backpack/core/derivatives/__init__.py
+backpack/core/derivatives/permute.py
+backpack/core/derivatives/lstm.py
+backpack/core/derivatives/linear.py
+backpack/core/derivatives/adaptive_avg_pool_nd.py
+backpack/core/derivatives/batchnorm_nd.py
+backpack/core/derivatives/embedding.py
+backpack/core/derivatives/crossentropyloss.py
+backpack/core/derivatives/scale_module.py
+backpack/core/derivatives/sum_module.py
+backpack/core/derivatives/dropout.py
+backpack/extensions/__init__.py
+backpack/extensions/backprop_extension.py
+backpack/extensions/module_extension.py
+backpack/extensions/saved_quantities.py
+backpack/extensions/mat_to_mat_jac_base.py
+backpack/extensions/firstorder/base.py
+backpack/extensions/firstorder/gradient/base.py
+backpack/extensions/firstorder/gradient/rnn.py
+backpack/extensions/firstorder/gradient/__init__.py
+backpack/extensions/firstorder/gradient/batchnorm_nd.py
+backpack/extensions/firstorder/gradient/embedding.py
+backpack/extensions/firstorder/batch_grad/batch_grad_base.py
+backpack/extensions/firstorder/batch_grad/rnn.py
+backpack/extensions/firstorder/batch_grad/__init__.py
+backpack/extensions/firstorder/batch_grad/batchnorm_nd.py
+backpack/extensions/firstorder/batch_grad/embedding.py
+backpack/extensions/firstorder/variance/variance_base.py
+backpack/extensions/firstorder/variance/rnn.py
+backpack/extensions/firstorder/variance/__init__.py
+backpack/extensions/firstorder/variance/batchnorm_nd.py
+backpack/extensions/firstorder/variance/embedding.py
+backpack/extensions/firstorder/sum_grad_squared/sgs_base.py
+backpack/extensions/firstorder/sum_grad_squared/rnn.py
+backpack/extensions/firstorder/sum_grad_squared/__init__.py
+backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py
+backpack/extensions/firstorder/sum_grad_squared/embedding.py
+backpack/extensions/firstorder/batch_l2_grad/
backpack/extensions/secondorder/__init__.py
-
+backpack/extensions/secondorder/diag_ggn/__init__.py
+backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py
+backpack/extensions/secondorder/diag_ggn/rnn.py
+backpack/extensions/secondorder/diag_ggn/permute.py
+backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py
+backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py
+backpack/extensions/secondorder/diag_ggn/embedding.py
+backpack/extensions/secondorder/diag_ggn/custom_module.py
backpack/extensions/secondorder/diag_hessian/__init__.py
backpack/extensions/secondorder/diag_hessian/conv1d.py
backpack/extensions/secondorder/diag_hessian/conv2d.py
backpack/extensions/secondorder/diag_hessian/conv3d.py
+backpack/extensions/secondorder/sqrt_ggn/
-backpack/extensions/__init__.py
+backpack/hessianfree/ggnvp.py
+
+backpack/utils/linear.py
+backpack/utils/subsampling.py
+backpack/utils/errors.py
+backpack/utils/__init__.py
+backpack/utils/module_classification.py
+backpack/utils/hooks.py
+backpack/utils/examples.py
+
+test/extensions/automated_settings.py
+test/extensions/problem.py
+test/extensions/utils.py
+test/extensions/test_backprop_extension.py
+test/extensions/test_hooks.py
+test/extensions/graph_clear_test.py
+test/extensions/firstorder/firstorder_settings.py
+test/extensions/firstorder/variance/
+test/extensions/firstorder/batch_grad/batch_grad_settings.py
+test/extensions/firstorder/batch_grad/test_batch_grad.py
+test/extensions/secondorder/secondorder_settings.py
+test/extensions/secondorder/diag_ggn/
+test/extensions/secondorder/hbp/
+test/extensions/secondorder/sqrt_ggn/
+test/extensions/implementation/base.py
+test/extensions/implementation/autograd.py
+test/extensions/implementation/backpack.py
+test/adaptive_avg_pool/
+test/core/derivatives/derivatives_test.py
+test/core/derivatives/__init__.py
+test/core/derivatives/rnn_settings.py
+test/core/derivatives/utils.py
+test/core/derivatives/implementation/
+test/core/derivatives/permute_settings.py
+test/core/derivatives/lstm_settings.py
+test/core/derivatives/pooling_adaptive_settings.py
+test/core/derivatives/batch_norm_settings.py
+test/core/derivatives/embedding_settings.py
+test/core/derivatives/scale_module_settings.py
+test/utils/evaluation_mode.py
+test/utils/skip_test.py
+test/utils/__init__.py
+test/converter/
+test/utils/test_subsampling.py
+test/custom_module/
+test/test_retain_graph.py
+test/test_batch_first.py
diff --git a/makefile b/makefile
index fbcd353f1..b0d4726a3 100644
--- a/makefile
+++ b/makefile
@@ -56,16 +56,16 @@ help:
###
# Test coverage
test:
- @pytest -vx --run-optional-tests=montecarlo --cov=backpack .
+ @pytest -vx -rs --run-optional-tests=montecarlo --cov=backpack .
test-light:
- @pytest -vx --cov=backpack .
+ @pytest -vx -rs --cov=backpack .
test-no-gpu:
- @pytest -k "not cuda" -vx --run-optional-tests=montecarlo --cov=backpack .
+ @pytest -k "not cuda" -vx -rs --run-optional-tests=montecarlo --cov=backpack .
test-light-no-gpu:
- @pytest -k "not cuda" -vx --cov=backpack .
+ @pytest -k "not cuda" -vx -rs --cov=backpack .
###
# Linter and autoformatter
diff --git a/setup.cfg b/setup.cfg
index bf2aa19f0..b74c2d594 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -22,7 +22,6 @@ classifiers =
Development Status :: 4 - Beta
License :: OSI Approved :: MIT License
Operating System :: OS Independent
- Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
@@ -35,11 +34,11 @@ setup_requires =
setuptools_scm
# Dependencies of the project (semicolon/line-separated):
install_requires =
- torch >= 1.6.0, < 2.0.0
+ torch >= 1.9.0, < 2.0.0
torchvision >= 0.7.0, < 1.0.0
einops >= 0.3.0, < 1.0.0
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
-python_requires = >=3.6
+python_requires = >=3.7
[options.packages.find]
exclude = test*
@@ -78,6 +77,7 @@ docs =
matplotlib
sphinx-gallery
memory_profiler
+ tabulate
###############################################################################
# Development tool configurations #
diff --git a/test/adaptive_avg_pool/__init__.py b/test/adaptive_avg_pool/__init__.py
new file mode 100644
index 000000000..d9ed32847
--- /dev/null
+++ b/test/adaptive_avg_pool/__init__.py
@@ -0,0 +1,4 @@
+"""Module tests AdaptiveAvgPoolNDDerivatives.
+
+Especially the shape checker for equivalence with AvgPoolND.
+"""
diff --git a/test/adaptive_avg_pool/problem.py b/test/adaptive_avg_pool/problem.py
new file mode 100644
index 000000000..e96713ed4
--- /dev/null
+++ b/test/adaptive_avg_pool/problem.py
@@ -0,0 +1,179 @@
+"""Test problems for the AdaptiveAvgPool shape checker."""
+from __future__ import annotations
+
+import copy
+from test.automated_test import check_sizes_and_values
+from test.core.derivatives.utils import get_available_devices
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from torch import Tensor, randn
+from torch.nn import (
+ AdaptiveAvgPool1d,
+ AdaptiveAvgPool2d,
+ AdaptiveAvgPool3d,
+ AvgPool1d,
+ AvgPool2d,
+ AvgPool3d,
+ Module,
+)
+
+from backpack import extend
+from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives
+
+
+def make_test_problems(settings: List[Dict[str, Any]]) -> List[AdaptiveAvgPoolProblem]:
+ """Creates the test problem from settings.
+
+ Args:
+ settings: list of dictionaries with settings
+
+ Returns:
+ a list of the test problems
+ """
+ problem_dicts: List[Dict[str, Any]] = []
+
+ for setting in settings:
+ setting = add_missing_defaults(setting)
+ devices = setting["device"]
+
+ for dev in devices:
+ problem = copy.deepcopy(setting)
+ problem["device"] = dev
+ problem_dicts.append(problem)
+
+ return [AdaptiveAvgPoolProblem(**p) for p in problem_dicts]
+
+
+def add_missing_defaults(setting: Dict[str, Any]) -> Dict[str, Any]:
+ """Add missing entries in settings such that the new format works.
+
+ Args:
+ setting: dictionary with required settings and some optional settings
+
+ Returns:
+ complete settings including the default values for missing optional settings
+
+ Raises:
+ ValueError: if the settings do not work
+ """
+ required = ["N", "shape_input", "shape_target", "works"]
+ optional = {
+ "id_prefix": "",
+ "device": get_available_devices(),
+ "seed": 0,
+ }
+
+ for req in required:
+ if req not in setting.keys():
+ raise ValueError(f"Missing configuration entry for {req}")
+
+ for opt, default in optional.items():
+ if opt not in setting.keys():
+ setting[opt] = default
+
+ for s in setting.keys():
+ if s not in required and s not in optional.keys():
+ raise ValueError(f"Unknown config: {s}")
+
+ return setting
+
+
+class AdaptiveAvgPoolProblem:
+ """Test problem for testing AdaptiveAvgPoolNDDerivatives.check_parameters()."""
+
+ def __init__(
+ self,
+ N: int,
+ shape_input: Any,
+ shape_target: Tuple[int],
+ works: bool,
+ device,
+ seed: int,
+ id_prefix: str,
+ ):
+ """Initialization.
+
+ Args:
+ N: number of dimensions
+ shape_input: input shape
+ shape_target: target shape
+ works: whether the test should run without errors
+ device: device
+ seed: seed for torch
+ id_prefix: prefix for problem id
+
+ Raises:
+ NotImplementedError: if N is not in [1, 2, 3]
+ """
+ if N not in [1, 2, 3]:
+ raise NotImplementedError(f"N={N} not implemented in test suite.")
+ self.N = N
+ self.shape_input = shape_input
+ self.shape_target = shape_target
+ self.works = works
+ self.device = device
+ self.seed = seed
+ self.id_prefix = id_prefix
+
+ def make_id(self) -> str:
+ """Create an id from problem parameters.
+
+ Returns:
+ problem id
+ """
+ prefix = (self.id_prefix + "-") if self.id_prefix != "" else ""
+ return (
+ prefix + f"dev={self.device}-N={self.N}-in={self.shape_input}-"
+ f"out={self.shape_target}-works={self.works}"
+ )
+
+ def set_up(self) -> None:
+ """Set up problem and do one forward pass."""
+ torch.manual_seed(self.seed)
+ self.module = self._make_module()
+ self.input = randn(self.shape_input)
+ self.output = self.module(self.input)
+
+ def tear_down(self):
+ """Delete created torch variables."""
+ del self.module
+ del self.input
+ del self.output
+
+ def _make_module(
+ self,
+ ) -> Union[AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d]:
+ map_class = {1: AdaptiveAvgPool1d, 2: AdaptiveAvgPool2d, 3: AdaptiveAvgPool3d}
+ module = map_class[self.N](output_size=self.shape_target)
+ return extend(module.to(device=self.device))
+
+ def check_parameters(self) -> None:
+ """Key method for test.
+
+ Run the AdaptiveAvgPoolNDDerivatives.check_parameters() method.
+ """
+ self._get_derivatives().check_parameters(module=self.module)
+
+ def _get_derivatives(self) -> AdaptiveAvgPoolNDDerivatives:
+ return AdaptiveAvgPoolNDDerivatives(N=self.N)
+
+ def check_equivalence(self) -> None:
+ """Check if the given parameters lead to the same output.
+
+ Checks the sizes and values.
+ """
+ stride, kernel_size, _ = self._get_derivatives().get_avg_pool_parameters(
+ self.module
+ )
+ module_equivalent: Module = self._make_module_equivalent(stride, kernel_size)
+ output_equivalent: Tensor = module_equivalent(self.input)
+
+ check_sizes_and_values(self.output, output_equivalent)
+
+ def _make_module_equivalent(
+ self, stride: List[int], kernel_size: List[int]
+ ) -> Union[AvgPool1d, AvgPool2d, AvgPool3d]:
+ map_class = {1: AvgPool1d, 2: AvgPool2d, 3: AvgPool3d}
+ module = map_class[self.N](kernel_size=kernel_size, stride=stride)
+ return module.to(self.device)
diff --git a/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py
new file mode 100644
index 000000000..58fdaaebb
--- /dev/null
+++ b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py
@@ -0,0 +1,49 @@
+"""Settings to run test_adaptive_avg_pool_nd."""
+from typing import Any, Dict, List
+
+from torch import Size
+
+SETTINGS: List[Dict[str, Any]] = [
+ {
+ "N": 1,
+ "shape_target": 2,
+ "shape_input": (1, 5, 8),
+ "works": True,
+ },
+ {
+ "N": 1,
+ "shape_target": 2,
+ "shape_input": (1, 8, 7),
+ "works": False,
+ },
+ {
+ "N": 2,
+ "shape_target": Size((4, 3)),
+ "shape_input": (1, 64, 8, 9),
+ "works": True,
+ },
+ {
+ "N": 2,
+ "shape_target": 2,
+ "shape_input": (1, 64, 8, 10),
+ "works": True,
+ },
+ {
+ "N": 2,
+ "shape_target": 2,
+ "shape_input": (1, 64, 8, 9),
+ "works": False,
+ },
+ {
+ "N": 2,
+ "shape_target": (5, 2),
+ "shape_input": (1, 64, 64, 10),
+ "works": False,
+ },
+ {
+ "N": 3,
+ "shape_target": (None, 2, None),
+ "shape_input": (1, 64, 7, 10, 5),
+ "works": True,
+ },
+]
diff --git a/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py
new file mode 100644
index 000000000..eebf1931b
--- /dev/null
+++ b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py
@@ -0,0 +1,28 @@
+"""Test the shape checker of AdaptiveAvgPoolNDDerivatives."""
+from test.adaptive_avg_pool.problem import AdaptiveAvgPoolProblem, make_test_problems
+from test.adaptive_avg_pool.settings_adaptive_avg_pool_nd import SETTINGS
+from typing import List
+
+import pytest
+
+PROBLEMS: List[AdaptiveAvgPoolProblem] = make_test_problems(SETTINGS)
+IDS: List[str] = [problem.make_id() for problem in PROBLEMS]
+
+
+@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
+def test_adaptive_avg_pool_check_parameters(problem: AdaptiveAvgPoolProblem):
+ """Test AdaptiveAvgPoolNDDerivatives.check_parameters().
+
+ Additionally check if returned parameters are indeed equivalent.
+
+ Args:
+ problem: test problem
+ """
+ problem.set_up()
+ if problem.works:
+ problem.check_parameters()
+ problem.check_equivalence()
+ else:
+ with pytest.raises(NotImplementedError):
+ problem.check_parameters()
+ problem.tear_down()
diff --git a/test/benchmark/jvp.py b/test/benchmark/jvp.py
index 7e5c89256..1a1a1464b 100644
--- a/test/benchmark/jvp.py
+++ b/test/benchmark/jvp.py
@@ -1,11 +1,11 @@
from functools import partial
+from test.core.derivatives import derivatives_for
import pytest
import torch
from torch import allclose
from torch.nn import Dropout, ReLU, Sigmoid, Tanh
-from backpack.core.derivatives import derivatives_for
from backpack.hessianfree.lop import transposed_jacobian_vector_product
from backpack.hessianfree.rop import jacobian_vector_product
@@ -130,7 +130,7 @@ def bp_jtv_weight_func(module, vin):
def f():
r = (
derivatives_for[module.__class__]()
- .weight_jac_t_mat_prod(module, None, None, vin)
+ .param_mjp("weight", module, None, None, vin)
.contiguous()
)
if vin.is_cuda:
@@ -160,7 +160,7 @@ def bp_jtv_bias_func(module, vin):
def f():
r = (
derivatives_for[module.__class__]()
- .bias_jac_t_mat_prod(module, None, None, vin.unsqueeze(2))
+ .param_mjp("bias", module, None, None, vin.unsqueeze(2))
.contiguous()
)
if vin.is_cuda:
diff --git a/test/bugfixes_test.py b/test/bugfixes_test.py
index 088ecd907..ba4ef8a19 100644
--- a/test/bugfixes_test.py
+++ b/test/bugfixes_test.py
@@ -4,6 +4,7 @@
import torch
import backpack
+from backpack.core.derivatives.convnd import weight_jac_t_save_memory
def parameters_issue_30():
@@ -31,7 +32,12 @@ def parameters_issue_30():
@pytest.mark.parametrize("params", **parameters_issue_30())
-def test_convolutions_stride_issue_30(params):
+@pytest.mark.parametrize(
+ "save_memory",
+ [True, False],
+ ids=["save_memory=True", "save_memory=False"],
+)
+def test_convolutions_stride_issue_30(params, save_memory):
"""
https://github.com/f-dangel/backpack/issues/30
@@ -51,7 +57,9 @@ def test_convolutions_stride_issue_30(params):
backpack.extend(mod)
x = torch.randn(size=(params["N"], params["C_in"], params["W"], params["H"]))
- with backpack.backpack(backpack.extensions.BatchGrad()):
+ with weight_jac_t_save_memory(save_memory), backpack.backpack(
+ backpack.extensions.BatchGrad()
+ ):
loss = torch.sum(mod(x))
loss.backward()
diff --git a/test/converter/__init__.py b/test/converter/__init__.py
new file mode 100644
index 000000000..0699f9ffe
--- /dev/null
+++ b/test/converter/__init__.py
@@ -0,0 +1 @@
+"""Contains tests for the converter and ResNets."""
diff --git a/test/converter/converter_cases.py b/test/converter/converter_cases.py
new file mode 100644
index 000000000..0715be0b5
--- /dev/null
+++ b/test/converter/converter_cases.py
@@ -0,0 +1,294 @@
+"""Test cases for the converter.
+
+Network with resnet18
+Network with inplace activation
+Network with parameter-free module used in multiple places
+Network with flatten operation
+Network with multiply operation
+Network with add operation
+"""
+import abc
+from typing import List, Type
+
+from torch import Tensor, flatten, permute, rand, randint, transpose, zeros_like
+from torch.nn import (
+ LSTM,
+ RNN,
+ CrossEntropyLoss,
+ Dropout,
+ Embedding,
+ Linear,
+ Module,
+ MSELoss,
+ ReLU,
+)
+from torchvision.models import resnet18, wide_resnet50_2
+
+
+class ConverterModule(Module, abc.ABC):
+ """Interface class for test modules for converter."""
+
+ @abc.abstractmethod
+ def input_fn(self) -> Tensor:
+ """Generate a fitting input for the module.
+
+ Returns:
+ an input
+ """
+ return
+
+ def loss_fn(self) -> Module:
+ """The loss function.
+
+ Returns:
+ loss function
+ """
+ return MSELoss()
+
+
+CONVERTER_MODULES: List[Type[ConverterModule]] = []
+
+
+class _ResNet18(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.resnet18 = resnet18(num_classes=4).eval()
+
+ def forward(self, x):
+ return self.resnet18(x)
+
+ def input_fn(self) -> Tensor:
+ return rand(2, 3, 7, 7)
+
+
+class _WideResNet50(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.wide_resnet50_2 = wide_resnet50_2(num_classes=4).eval()
+
+ def forward(self, x):
+ return self.wide_resnet50_2(x)
+
+ def input_fn(self) -> Tensor:
+ return rand(2, 3, 7, 7)
+
+
+class _InplaceActivation(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = 3
+ out_dim = 2
+ self.linear = Linear(self.in_dim, out_dim)
+ self.relu = ReLU(inplace=True)
+ self.linear2 = Linear(out_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, self.in_dim)
+
+
+class _MultipleUsages(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = 3
+ out_dim = 2
+ self.linear = Linear(self.in_dim, out_dim)
+ self.relu = ReLU()
+ self.linear2 = Linear(out_dim, out_dim)
+
+ def forward(self, x):
+ x = self.relu(x)
+ x = self.linear(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ x = self.relu(x)
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, self.in_dim)
+
+
+class _FlattenNetwork(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = (2, 2, 4)
+ out_dim = 3
+ self.linear = Linear(self.in_dim[2], out_dim)
+ self.linear2 = Linear(self.in_dim[1] * out_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = flatten(x, 2) # built-in function flatten
+ x = self.linear2(x)
+ x = x.flatten(1) # method flatten
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, *self.in_dim)
+
+
+class _Multiply(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 2
+ self.in_dim = 4
+ out_dim = 3
+ self.linear = Linear(self.in_dim, out_dim)
+
+ def forward(self, x):
+ x = x * 2.5 # built-in method multiply (Tensor-float)
+ x = self.linear(x)
+ x = 0.5 * x # built-in method multiply (float-Tensor)
+ x = x.multiply(3.1415) # method multiply
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, self.in_dim)
+
+
+class _Add(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = 3
+ out_dim = 2
+ self.linear = Linear(self.in_dim, self.in_dim)
+ self.linear1 = Linear(self.in_dim, out_dim)
+ self.linear2 = Linear(self.in_dim, out_dim)
+ self.relu = ReLU()
+
+ def forward(self, x):
+ x = self.linear(x)
+ x1 = self.linear1(x)
+ x2 = self.linear2(x)
+ x = x1 + x2 # built-in method add
+ x = self.relu(x)
+ x = x.add(x2) # method add
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, self.in_dim)
+
+
+class _Permute(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = (5, 3)
+ out_dim = 2
+ self.linear = Linear(self.in_dim[-1], out_dim)
+ self.linear2 = Linear(self.in_dim[-2], out_dim)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = x.permute(0, 2, 1) # method permute
+ x = self.linear2(x)
+ x = permute(x, (0, 2, 1)) # function permute
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, *self.in_dim)
+
+ def loss_fn(self) -> Module:
+ return CrossEntropyLoss()
+
+
+class _Transpose(ConverterModule):
+ def __init__(self):
+ super().__init__()
+ self.batch_size = 3
+ self.in_dim = (5, 3)
+ out_dim = 2
+ out_dim2 = 3
+ self.linear = Linear(self.in_dim[-1], out_dim)
+ self.linear2 = Linear(self.in_dim[-2], out_dim2)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = x.transpose(1, 2) # method transpose
+ x = self.linear2(x)
+ x = transpose(x, 1, 2) # function transpose
+ return x
+
+ def input_fn(self) -> Tensor:
+ return rand(self.batch_size, *self.in_dim)
+
+ def loss_fn(self) -> Module:
+ return CrossEntropyLoss()
+
+
+class _TolstoiCharRNN(ConverterModule):
+ def __init__(self):
+ super(_TolstoiCharRNN, self).__init__()
+ self.batch_size = 8
+ self.hidden_dim = 64
+ self.num_layers = 2
+ self.seq_len = 15
+ self.vocab_size = 25
+
+ self.embedding = Embedding(
+ num_embeddings=self.vocab_size, embedding_dim=self.hidden_dim
+ )
+ self.dropout = Dropout(p=0.2)
+ self.lstm = LSTM(
+ input_size=self.hidden_dim,
+ hidden_size=self.hidden_dim,
+ num_layers=self.num_layers,
+ dropout=0.36,
+ batch_first=True,
+ )
+ self.lstm.bias_ih_l0.data = zeros_like(self.lstm.bias_ih_l0)
+ self.lstm.bias_ih_l1.data = zeros_like(self.lstm.bias_ih_l1)
+ self.lstm.bias_ih_l0.requires_grad = False
+ self.lstm.bias_ih_l1.requires_grad = False
+ self.dense = Linear(in_features=self.hidden_dim, out_features=self.vocab_size)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.dropout(x)
+ x, new_state = self.lstm(x)
+ x = self.dropout(x)
+ output = self.dense(x)
+ output = output.permute(0, 2, 1)
+ return output
+
+ def input_fn(self) -> Tensor:
+ return randint(0, self.vocab_size, (self.batch_size, self.seq_len))
+
+ def loss_fn(self) -> Module:
+ return CrossEntropyLoss()
+
+
+class _TolstoiRNNVersion(_TolstoiCharRNN):
+ def __init__(self):
+ super(_TolstoiRNNVersion, self).__init__()
+ self.lstm = RNN(
+ input_size=self.hidden_dim,
+ hidden_size=self.hidden_dim,
+ num_layers=self.num_layers,
+ dropout=0.36,
+ batch_first=True,
+ )
+
+
+CONVERTER_MODULES += [
+ _ResNet18,
+ _WideResNet50,
+ _InplaceActivation,
+ _MultipleUsages,
+ _FlattenNetwork,
+ _Multiply,
+ _Add,
+ _Permute,
+ _Transpose,
+ _TolstoiCharRNN,
+ _TolstoiRNNVersion,
+]
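
New cases follow the same recipe: subclass `ConverterModule`, implement `forward` and `input_fn`, optionally override `loss_fn`, and append the class to `CONVERTER_MODULES`. A hypothetical sketch of an additional case (class name and sizes are made up for illustration):

```python
from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule

from torch import Tensor, rand
from torch.nn import Linear, ReLU


class _HypotheticalTwoLayerNet(ConverterModule):
    """Hypothetical extra case: two linear layers joined by a ReLU."""

    def __init__(self):
        super().__init__()
        self.batch_size, self.in_dim = 4, 5
        self.linear1 = Linear(self.in_dim, 3)
        self.relu = ReLU()
        self.linear2 = Linear(3, 2)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, self.in_dim)

    # loss_fn() is inherited from ConverterModule and defaults to MSELoss().


CONVERTER_MODULES += [_HypotheticalTwoLayerNet]
```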
diff --git a/test/converter/resnet_cases.py b/test/converter/resnet_cases.py
new file mode 100644
index 000000000..4d5e00de7
--- /dev/null
+++ b/test/converter/resnet_cases.py
@@ -0,0 +1,146 @@
+"""Contains example ResNets to be used in tests."""
+from torch import flatten, tensor
+from torch.nn import (
+ AdaptiveAvgPool2d,
+ BatchNorm2d,
+ Conv2d,
+ Linear,
+ MaxPool2d,
+ Module,
+ MSELoss,
+ ReLU,
+ Sequential,
+ Tanh,
+)
+from torchvision.models.resnet import BasicBlock, conv1x1
+
+
+class ResNet1(Module):
+ """Small ResNet."""
+
+ def __init__(self, in_dim: int = 2, out_dim: int = 10):
+ """Initialization.
+
+ Args:
+ in_dim: input dimensions
+ out_dim: output dimensions
+ """
+ super().__init__()
+ self.net = Sequential(
+ Linear(in_dim, out_dim),
+ Tanh(),
+ Linear(out_dim, out_dim),
+ Tanh(),
+ Linear(out_dim, in_dim),
+ )
+
+ def forward(self, input):
+ """Forward pass. One Euler step.
+
+ Args:
+ input: input tensor
+
+ Returns:
+ result
+ """
+ x = self.net(input)
+ return input + x * 0.1
+
+ input_test = tensor([[1.0, 2.0]])
+ target_test = tensor([[1.0, 1.0]])
+ loss_test = MSELoss()
+
+
+class ResNet2(Module):
+ """Replicates resnet18 but a lot smaller."""
+
+ num_classes: int = 3
+ batch_size: int = 2
+ picture_width: int = 7
+ inplanes = 2
+
+ input_test = (batch_size, 3, picture_width, picture_width)
+ target_test = (batch_size, num_classes)
+ loss_test = MSELoss()
+
+ def __init__(self):
+ """Initialization."""
+ super().__init__()
+ self.inplanes = ResNet2.inplanes
+
+ self.conv1 = Conv2d(
+ 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+ )
+ self.bn1 = BatchNorm2d(self.inplanes)
+ self.relu = ReLU(inplace=True)
+ self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(BasicBlock, ResNet2.inplanes, 2)
+ self.layer2 = self._make_layer(BasicBlock, 2 * ResNet2.inplanes, 2, stride=2)
+ self.layer3 = self._make_layer(BasicBlock, 4 * ResNet2.inplanes, 2, stride=2)
+ self.avgpool = AdaptiveAvgPool2d((1, 1))
+ self.fc = Linear(4 * ResNet2.inplanes, self.num_classes)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x: input tensor
+
+ Returns:
+ result
+ """
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = self.avgpool(x)
+ x = flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ """Creates a concatenation of blocks in the ResNet.
+
+ This function is similar to the one in torchvision/resnets.
+ https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
+
+ Args:
+ block: basic block to use (with one skip connection)
+ planes: number of parallel planes
+ blocks: number of sequential blocks
+ stride: factor between input and output planes
+
+ Returns:
+ a sequence of blocks
+ """
+ norm_layer = BatchNorm2d
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = [
+ block(self.inplanes, planes, stride, downsample, 1, 64, 1, norm_layer)
+ ]
+ self.inplanes = planes * block.expansion
+ layers += [
+ block(
+ self.inplanes,
+ planes,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=norm_layer,
+ )
+ for _ in range(1, blocks)
+ ]
+
+ return Sequential(*layers)
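
Because `ResNet1` adds its input back onto the output of its sub-network, it cannot be expressed as a plain `Sequential` container, which is exactly the situation the converter targets. A minimal usage sketch with the class defined above (batch size chosen arbitrarily):

```python
from torch import rand
from torch.nn import MSELoss

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

model = extend(ResNet1(in_dim=2, out_dim=10), use_converter=True)
loss_fn = extend(MSELoss())

x = rand(8, 2)
y = rand(8, 2)  # output has the same shape as the input (residual step)

loss = loss_fn(model(x), y)
with backpack(DiagGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    print(name, param.diag_ggn_exact.shape)  # same shape as the parameter
```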
diff --git a/test/converter/test_converter.py b/test/converter/test_converter.py
new file mode 100644
index 000000000..860637910
--- /dev/null
+++ b/test/converter/test_converter.py
@@ -0,0 +1,94 @@
+"""Tests converter.
+
+- whether converted network is equivalent to original network
+- whether DiagGGN runs without errors on new network
+"""
+from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule
+from test.core.derivatives.utils import classification_targets, regression_targets
+from typing import Tuple
+
+from pytest import fixture
+from torch import Tensor, allclose, cat, int32, linspace, manual_seed
+from torch.nn import CrossEntropyLoss, Module, MSELoss
+
+from backpack import backpack, extend
+from backpack.extensions import DiagGGNExact
+from backpack.utils.examples import autograd_diag_ggn_exact
+
+
+@fixture(
+ params=CONVERTER_MODULES,
+ ids=[str(model_class) for model_class in CONVERTER_MODULES],
+)
+def model_and_input(request) -> Tuple[Module, Tensor, Module]:
+ """Yield ResNet model and an input to it.
+
+ Args:
+ request: pytest request
+
+ Yields:
+        model, input, and loss function
+ """
+ manual_seed(0)
+ model: ConverterModule = request.param()
+ inputs: Tensor = model.input_fn()
+ loss_fn: Module = model.loss_fn()
+ yield model, inputs, loss_fn
+ del model, inputs, loss_fn
+
+
+def test_network_diag_ggn(model_and_input):
+ """Test whether the given module can compute diag_ggn.
+
+    This test is placed here because some models are too large to compute the full
+    GGN diagonal with autograd, so a complete comparison is infeasible.
+    Instead, the test checks that BackPACK runs without errors, that the forward
+    pass of the converted model matches the original model, and that a small
+    number of DiagGGN elements agree with autograd.
+
+ Args:
+ model_and_input: module to test
+
+ Raises:
+ NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss
+ """
+ model_original, x, loss_fn = model_and_input
+ model_original = model_original.eval()
+ output_compare = model_original(x)
+ if isinstance(loss_fn, MSELoss):
+ y = regression_targets(output_compare.shape)
+ elif isinstance(loss_fn, CrossEntropyLoss):
+ y = classification_targets(
+ (output_compare.shape[0], *output_compare.shape[2:]),
+ output_compare.shape[1],
+ )
+ else:
+ raise NotImplementedError(f"test cannot handle loss_fn = {type(loss_fn)}")
+
+ num_params = sum(p.numel() for p in model_original.parameters() if p.requires_grad)
+ num_to_compare = 10
+ idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32)
+ diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
+ x, y, model_original, loss_fn, idx=idx_to_compare
+ )
+
+ model_extended = extend(model_original, use_converter=True, debug=True)
+ output = model_extended(x)
+
+ assert allclose(output, output_compare)
+
+ loss = extend(loss_fn)(output, y)
+
+ with backpack(DiagGGNExact()):
+ loss.backward()
+
+ diag_ggn_exact_vector = cat(
+ [
+ p.diag_ggn_exact.flatten()
+ for p in model_extended.parameters()
+ if p.requires_grad
+ ]
+ )
+
+ for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
+ assert allclose(element, diag_ggn_exact_vector[idx], atol=1e-5)
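
The target construction above mirrors PyTorch's loss conventions: `CrossEntropyLoss` expects class scores along dimension 1, while `MSELoss` targets simply match the output shape. A short sketch with shapes as in `_TolstoiCharRNN` (batch 8, 25 classes, sequence length 15):

```python
from test.core.derivatives.utils import classification_targets, regression_targets

from torch import rand

output = rand(8, 25, 15)  # [batch, classes, sequence]

# CrossEntropyLoss: integer labels, class dimension (dim 1) dropped from the shape.
y_ce = classification_targets((output.shape[0], *output.shape[2:]), output.shape[1])
print(y_ce.shape)  # torch.Size([8, 15]), entries in [0, 25)

# MSELoss: real-valued targets with the same shape as the output.
y_mse = regression_targets(output.shape)
print(y_mse.shape)  # torch.Size([8, 25, 15])
```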
diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py
index e5a0962f9..f1aca1af8 100644
--- a/test/core/derivatives/__init__.py
+++ b/test/core/derivatives/__init__.py
@@ -1 +1,117 @@
"""Test functionality of `backpack.core.derivatives` module."""
+from torch.nn import (
+ ELU,
+ LSTM,
+ RNN,
+ SELU,
+ AdaptiveAvgPool1d,
+ AdaptiveAvgPool2d,
+ AdaptiveAvgPool3d,
+ AvgPool1d,
+ AvgPool2d,
+ AvgPool3d,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
+ Conv1d,
+ Conv2d,
+ Conv3d,
+ ConvTranspose1d,
+ ConvTranspose2d,
+ ConvTranspose3d,
+ CrossEntropyLoss,
+ Dropout,
+ Embedding,
+ Identity,
+ LeakyReLU,
+ Linear,
+ LogSigmoid,
+ MaxPool1d,
+ MaxPool2d,
+ MaxPool3d,
+ MSELoss,
+ ReLU,
+ Sigmoid,
+ Tanh,
+ ZeroPad2d,
+)
+
+from backpack.core.derivatives.adaptive_avg_pool_nd import (
+ AdaptiveAvgPool1dDerivatives,
+ AdaptiveAvgPool2dDerivatives,
+ AdaptiveAvgPool3dDerivatives,
+)
+from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives
+from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives
+from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives
+from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
+from backpack.core.derivatives.conv1d import Conv1DDerivatives
+from backpack.core.derivatives.conv2d import Conv2DDerivatives
+from backpack.core.derivatives.conv3d import Conv3DDerivatives
+from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives
+from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives
+from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives
+from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives
+from backpack.core.derivatives.dropout import DropoutDerivatives
+from backpack.core.derivatives.elu import ELUDerivatives
+from backpack.core.derivatives.embedding import EmbeddingDerivatives
+from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives
+from backpack.core.derivatives.linear import LinearDerivatives
+from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives
+from backpack.core.derivatives.lstm import LSTMDerivatives
+from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives
+from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives
+from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives
+from backpack.core.derivatives.mseloss import MSELossDerivatives
+from backpack.core.derivatives.permute import PermuteDerivatives
+from backpack.core.derivatives.relu import ReLUDerivatives
+from backpack.core.derivatives.rnn import RNNDerivatives
+from backpack.core.derivatives.scale_module import ScaleModuleDerivatives
+from backpack.core.derivatives.selu import SELUDerivatives
+from backpack.core.derivatives.sigmoid import SigmoidDerivatives
+from backpack.core.derivatives.sum_module import SumModuleDerivatives
+from backpack.core.derivatives.tanh import TanhDerivatives
+from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives
+from backpack.custom_module.branching import SumModule
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.scale_module import ScaleModule
+
+derivatives_for = {
+ Linear: LinearDerivatives,
+ Conv1d: Conv1DDerivatives,
+ Conv2d: Conv2DDerivatives,
+ Conv3d: Conv3DDerivatives,
+ AvgPool1d: AvgPool1DDerivatives,
+ AvgPool2d: AvgPool2DDerivatives,
+ AvgPool3d: AvgPool3DDerivatives,
+ MaxPool1d: MaxPool1DDerivatives,
+ MaxPool2d: MaxPool2DDerivatives,
+ MaxPool3d: MaxPool3DDerivatives,
+ ZeroPad2d: ZeroPad2dDerivatives,
+ Dropout: DropoutDerivatives,
+ ReLU: ReLUDerivatives,
+ Tanh: TanhDerivatives,
+ Sigmoid: SigmoidDerivatives,
+ ConvTranspose1d: ConvTranspose1DDerivatives,
+ ConvTranspose2d: ConvTranspose2DDerivatives,
+ ConvTranspose3d: ConvTranspose3DDerivatives,
+ LeakyReLU: LeakyReLUDerivatives,
+ LogSigmoid: LogSigmoidDerivatives,
+ ELU: ELUDerivatives,
+ SELU: SELUDerivatives,
+ CrossEntropyLoss: CrossEntropyLossDerivatives,
+ MSELoss: MSELossDerivatives,
+ RNN: RNNDerivatives,
+ Permute: PermuteDerivatives,
+ LSTM: LSTMDerivatives,
+ AdaptiveAvgPool1d: AdaptiveAvgPool1dDerivatives,
+ AdaptiveAvgPool2d: AdaptiveAvgPool2dDerivatives,
+ AdaptiveAvgPool3d: AdaptiveAvgPool3dDerivatives,
+ BatchNorm1d: BatchNormNdDerivatives,
+ BatchNorm2d: BatchNormNdDerivatives,
+ BatchNorm3d: BatchNormNdDerivatives,
+ Embedding: EmbeddingDerivatives,
+ ScaleModule: ScaleModuleDerivatives,
+ Identity: ScaleModuleDerivatives,
+ SumModule: SumModuleDerivatives,
+}
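
This mapping is what the benchmark and derivative tests use to pair a `torch.nn` layer with its BackPACK derivative implementation, e.g. `derivatives_for[module.__class__]().param_mjp("weight", ...)` in `test/benchmark/jvp.py`. A small sketch of the lookup itself (no forward pass is performed, so no derivative method is called here):

```python
from test.core.derivatives import derivatives_for

from torch.nn import BatchNorm1d, BatchNorm3d, Linear, ReLU

# Instantiate the derivative implementation registered for a layer class.
linear_derivatives = derivatives_for[Linear]()  # LinearDerivatives instance
relu_derivatives = derivatives_for[ReLU]()      # ReLUDerivatives instance

# All BatchNorm dimensionalities share a single implementation.
assert derivatives_for[BatchNorm1d] is derivatives_for[BatchNorm3d]
print(type(linear_derivatives).__name__, type(relu_derivatives).__name__)
```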
diff --git a/test/core/derivatives/batch_norm_settings.py b/test/core/derivatives/batch_norm_settings.py
new file mode 100644
index 000000000..7994e1716
--- /dev/null
+++ b/test/core/derivatives/batch_norm_settings.py
@@ -0,0 +1,57 @@
+"""Test configurations for `backpack.core.derivatives` BatchNorm layers.
+
+Required entries:
+ "module_fn" (callable): Contains a model constructed from `torch.nn` layers
+ "input_fn" (callable): Used for specifying input function
+
+Optional entries:
+ "target_fn" (callable): Fetches the groundtruth/target classes
+ of regression/classification task
+ "loss_function_fn" (callable): Loss function used in the model
+ "device" [list(torch.device)]: List of devices to run the test on.
+ "id_prefix" (str): Prefix to be included in the test name.
+ "seed" (int): seed for the random number for torch.rand
+"""
+from test.utils.evaluation_mode import initialize_batch_norm_eval
+
+from torch import rand
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+BATCH_NORM_SETTINGS = [
+ {
+ "module_fn": lambda: BatchNorm1d(num_features=7),
+ "input_fn": lambda: rand(size=(5, 7)),
+ },
+ {
+ "module_fn": lambda: BatchNorm1d(num_features=7),
+ "input_fn": lambda: rand(size=(5, 7, 4)),
+ },
+ {
+ "module_fn": lambda: BatchNorm2d(num_features=7),
+ "input_fn": lambda: rand(size=(5, 7, 3, 4)),
+ },
+ {
+ "module_fn": lambda: BatchNorm3d(num_features=3),
+ "input_fn": lambda: rand(size=(5, 3, 3, 4, 2)),
+ },
+ {
+ "module_fn": lambda: initialize_batch_norm_eval(BatchNorm1d(num_features=7)),
+ "input_fn": lambda: rand(size=(5, 7)),
+ "id_prefix": "training=False",
+ },
+ {
+ "module_fn": lambda: initialize_batch_norm_eval(BatchNorm1d(num_features=7)),
+ "input_fn": lambda: rand(size=(5, 7, 4)),
+ "id_prefix": "training=False",
+ },
+ {
+ "module_fn": lambda: initialize_batch_norm_eval(BatchNorm2d(num_features=7)),
+ "input_fn": lambda: rand(size=(5, 7, 3, 4)),
+ "id_prefix": "training=False",
+ },
+ {
+ "module_fn": lambda: initialize_batch_norm_eval(BatchNorm3d(num_features=7)),
+ "input_fn": lambda: rand(size=(5, 7, 3, 4, 2)),
+ "id_prefix": "training=False",
+ },
+]
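
Each entry follows the schema documented at the top of the file and is turned into test problems exactly as `derivatives_test.py` does. A sketch with one hypothetical extra entry (sizes are arbitrary):

```python
from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS
from test.core.derivatives.problem import make_test_problems
from test.utils.evaluation_mode import initialize_batch_norm_eval

from torch import rand
from torch.nn import BatchNorm2d

EXTRA_SETTING = {  # hypothetical additional case in evaluation mode
    "module_fn": lambda: initialize_batch_norm_eval(BatchNorm2d(num_features=4)),
    "input_fn": lambda: rand(size=(6, 4, 5, 5)),
    "id_prefix": "training=False",
}

problems = make_test_problems(BATCH_NORM_SETTINGS + [EXTRA_SETTING])
ids = [problem.make_id() for problem in problems]
print(len(problems), ids[-1])
```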
diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py
index 776edcd03..5e751ffe1 100644
--- a/test/core/derivatives/derivatives_test.py
+++ b/test/core/derivatives/derivatives_test.py
@@ -6,16 +6,29 @@
- Jacobian-matrix products with respect to layer parameters
- Transposed Jacobian-matrix products with respect to layer parameters
"""
-
+from contextlib import nullcontext
from test.automated_test import check_sizes_and_values
+from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS
+from test.core.derivatives.embedding_settings import EMBEDDING_SETTINGS
from test.core.derivatives.implementation.autograd import AutogradDerivatives
from test.core.derivatives.implementation.backpack import BackpackDerivatives
from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS
-from test.core.derivatives.problem import make_test_problems
+from test.core.derivatives.lstm_settings import LSTM_SETTINGS
+from test.core.derivatives.permute_settings import PERMUTE_SETTINGS
+from test.core.derivatives.problem import DerivativesTestProblem, make_test_problems
+from test.core.derivatives.rnn_settings import RNN_SETTINGS as RNN_SETTINGS
+from test.core.derivatives.scale_module_settings import SCALE_MODULE_SETTINGS
from test.core.derivatives.settings import SETTINGS
+from test.utils.skip_test import (
+ skip_adaptive_avg_pool3d_cuda,
+ skip_batch_norm_train_mode_with_subsampling,
+ skip_subsampling_conflict,
+)
+from typing import List, Union
+from warnings import warn
-import pytest
-import torch
+from pytest import fixture, mark, raises, skip
+from torch import Tensor, rand
from backpack.core.derivatives.convnd import weight_jac_t_save_memory
@@ -32,17 +45,80 @@
LOSS_FAIL_PROBLEMS = make_test_problems(LOSS_FAIL_SETTINGS)
LOSS_FAIL_IDS = [problem.make_id() for problem in LOSS_FAIL_PROBLEMS]
+RNN_PROBLEMS = make_test_problems(RNN_SETTINGS)
+RNN_PROBLEMS += make_test_problems(LSTM_SETTINGS)
+RNN_IDS = [problem.make_id() for problem in RNN_PROBLEMS]
+
+PERMUTE_PROBLEMS = make_test_problems(PERMUTE_SETTINGS)
+PERMUTE_IDS = [problem.make_id() for problem in PERMUTE_PROBLEMS]
+
+BATCH_NORM_PROBLEMS = make_test_problems(BATCH_NORM_SETTINGS)
+BATCH_NORM_IDS = [problem.make_id() for problem in BATCH_NORM_PROBLEMS]
+
+EMBEDDING_PROBLEMS = make_test_problems(EMBEDDING_SETTINGS)
+EMBEDDING_IDS = [problem.make_id() for problem in EMBEDDING_PROBLEMS]
+
+SCALE_MODULE_PROBLEMS = make_test_problems(SCALE_MODULE_SETTINGS)
+SCALE_MODULE_IDS = [problem.make_id() for problem in SCALE_MODULE_PROBLEMS]
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_jac_mat_prod(problem, V=3):
+SUBSAMPLINGS = [None, [0, 0], [2, 0]]
+SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS]
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"])
+def test_param_mjp(
+ problem: DerivativesTestProblem,
+ sum_batch: bool,
+    subsampling: Union[List[int], None],
+ request,
+) -> None:
+ """Test all parameter derivatives.
+
+ Args:
+ problem: test problem
+ sum_batch: whether to sum along batch axis
+ subsampling: subsampling indices
+ request: problem request
+ """
+ skip_subsampling_conflict(problem, subsampling)
+ test_save_memory: bool = "Conv" in request.node.callspec.id
+ V = 3
+
+ for param_str, _ in problem.module.named_parameters():
+ print(f"testing derivative wrt {param_str}")
+ for save_memory in [True, False] if test_save_memory else [None]:
+ if test_save_memory:
+ print(f"testing with save_memory={save_memory}")
+
+ mat = rand_mat_like_output(V, problem, subsampling=subsampling)
+ with weight_jac_t_save_memory(
+ save_memory=save_memory
+ ) if test_save_memory else nullcontext():
+ backpack_res = BackpackDerivatives(problem).param_mjp(
+ param_str, mat, sum_batch, subsampling=subsampling
+ )
+ autograd_res = AutogradDerivatives(problem).param_mjp(
+ param_str, mat, sum_batch, subsampling=subsampling
+ )
+
+ check_sizes_and_values(autograd_res, backpack_res)
+
+
+@mark.parametrize(
+ "problem",
+ NO_LOSS_PROBLEMS + RNN_PROBLEMS + PERMUTE_PROBLEMS + BATCH_NORM_PROBLEMS,
+ ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS,
+)
+def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None:
"""Test the Jacobian-matrix product.
Args:
- problem (DerivativesProblem): Problem for derivative test.
- V (int): Number of vectorized Jacobian-vector products.
+ problem: Test case.
+ V: Number of vectorized Jacobian-vector products. Default: ``3``.
"""
problem.set_up()
- mat = torch.rand(V, *problem.input_shape).to(problem.device)
+ mat = rand(V, *problem.input_shape).to(problem.device)
backpack_res = BackpackDerivatives(problem).jac_mat_prod(mat)
autograd_res = AutogradDerivatives(problem).jac_mat_prod(mat)
@@ -51,19 +127,43 @@ def test_jac_mat_prod(problem, V=3):
problem.tear_down()
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_jac_t_mat_prod(problem, V=3):
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize(
+ "problem",
+ NO_LOSS_PROBLEMS
+ + RNN_PROBLEMS
+ + PERMUTE_PROBLEMS
+ + BATCH_NORM_PROBLEMS
+ + SCALE_MODULE_PROBLEMS,
+ ids=NO_LOSS_IDS + RNN_IDS + PERMUTE_IDS + BATCH_NORM_IDS + SCALE_MODULE_IDS,
+)
+def test_jac_t_mat_prod(
+ problem: DerivativesTestProblem,
+ subsampling: Union[None, List[int]],
+ request,
+ V: int = 3,
+) -> None:
"""Test the transposed Jacobian-matrix product.
Args:
- problem (DerivativesProblem): Problem for derivative test.
- V (int): Number of vectorized transposed Jacobian-vector products.
+ problem: Problem for derivative test.
+ subsampling: Indices of active samples.
+ request: Pytest request, used for getting id.
+ V: Number of vectorized transposed Jacobian-vector products. Default: ``3``.
"""
+ skip_adaptive_avg_pool3d_cuda(request)
+
problem.set_up()
- mat = torch.rand(V, *problem.output_shape).to(problem.device)
+ skip_batch_norm_train_mode_with_subsampling(problem, subsampling)
+ skip_subsampling_conflict(problem, subsampling)
+ mat = rand_mat_like_output(V, problem, subsampling=subsampling)
- backpack_res = BackpackDerivatives(problem).jac_t_mat_prod(mat)
- autograd_res = AutogradDerivatives(problem).jac_t_mat_prod(mat)
+ backpack_res = BackpackDerivatives(problem).jac_t_mat_prod(
+ mat, subsampling=subsampling
+ )
+ autograd_res = AutogradDerivatives(problem).jac_t_mat_prod(
+ mat, subsampling=subsampling
+ )
check_sizes_and_values(autograd_res, backpack_res)
problem.tear_down()
@@ -77,47 +177,44 @@ def test_jac_t_mat_prod(problem, V=3):
IDS_WITH_WEIGHTS.append(problem_id)
-@pytest.mark.parametrize(
- "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]
-)
-@pytest.mark.parametrize(
- "save_memory",
- [True, False],
- ids=["save_memory=True", "save_memory=False"],
-)
-@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS)
-def test_weight_jac_t_mat_prod(problem, sum_batch, save_memory, V=3):
- """Test the transposed Jacobian-matrix product w.r.t. to the weights.
+def rand_mat_like_output(
+ V: int, problem: DerivativesTestProblem, subsampling: List[int] = None
+) -> Tensor:
+ """Generate random matrix whose columns are shaped like the layer output.
+
+ Can be used to generate random inputs to functions that act on tensors
+ shaped like the module output (like ``*_jac_t_mat_prod``).
Args:
- problem (DerivativesProblem): Problem for derivative test.
- sum_batch (bool): Sum results over the batch dimension.
- save_memory (bool): Use Owkin implementation to save memory.
- V (int): Number of vectorized transposed Jacobian-vector products.
+ V: Number of rows.
+ problem: Test case.
+ subsampling: Indices of samples used by sub-sampling.
+
+ Returns:
+ Random matrix with (subsampled) output shape.
"""
- problem.set_up()
- mat = torch.rand(V, *problem.output_shape).to(problem.device)
+ subsample_shape = list(problem.output_shape)
- with weight_jac_t_save_memory(save_memory):
- backpack_res = BackpackDerivatives(problem).weight_jac_t_mat_prod(
- mat, sum_batch
- )
- autograd_res = AutogradDerivatives(problem).weight_jac_t_mat_prod(mat, sum_batch)
+ if subsampling is not None:
+ subsample_shape[0] = len(subsampling)
- check_sizes_and_values(autograd_res, backpack_res)
- problem.tear_down()
+ return rand(V, *subsample_shape, device=problem.device)
-@pytest.mark.parametrize("problem", PROBLEMS_WITH_WEIGHTS, ids=IDS_WITH_WEIGHTS)
-def test_weight_jac_mat_prod(problem, V=3):
- """Test the Jacobian-matrix product w.r.t. to the weights.
+@mark.parametrize(
+ "problem",
+ PROBLEMS_WITH_WEIGHTS + BATCH_NORM_PROBLEMS,
+ ids=IDS_WITH_WEIGHTS + BATCH_NORM_IDS,
+)
+def test_weight_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None:
+ """Test the Jacobian-matrix product w.r.t. to the weight.
Args:
- problem (DerivativesProblem): Problem for derivative test.
- V (int): Number of vectorized transposed Jacobian-vector products.
+ problem: Test case.
+ V: Number of vectorized Jacobian-vector products. Default: ``3``.
"""
problem.set_up()
- mat = torch.rand(V, *problem.module.weight.shape).to(problem.device)
+ mat = rand(V, *problem.module.weight.shape).to(problem.device)
backpack_res = BackpackDerivatives(problem).weight_jac_mat_prod(mat)
autograd_res = AutogradDerivatives(problem).weight_jac_mat_prod(mat)
@@ -134,46 +231,20 @@ def test_weight_jac_mat_prod(problem, V=3):
IDS_WITH_BIAS.append(problem_id)
-@pytest.mark.parametrize(
- "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"]
-)
-@pytest.mark.parametrize(
- "problem",
- PROBLEMS_WITH_BIAS,
- ids=IDS_WITH_BIAS,
-)
-def test_bias_jac_t_mat_prod(problem, sum_batch, V=3):
- """Test the transposed Jacobian-matrix product w.r.t. to the biass.
-
- Args:
- problem (DerivativesProblem): Problem for derivative test.
- sum_batch (bool): Sum results over the batch dimension.
- V (int): Number of vectorized transposed Jacobian-vector products.
- """
- problem.set_up()
- mat = torch.rand(V, *problem.output_shape).to(problem.device)
-
- backpack_res = BackpackDerivatives(problem).bias_jac_t_mat_prod(mat, sum_batch)
- autograd_res = AutogradDerivatives(problem).bias_jac_t_mat_prod(mat, sum_batch)
-
- check_sizes_and_values(autograd_res, backpack_res)
- problem.tear_down()
-
-
-@pytest.mark.parametrize(
+@mark.parametrize(
"problem",
- PROBLEMS_WITH_BIAS,
- ids=IDS_WITH_BIAS,
+ PROBLEMS_WITH_BIAS + BATCH_NORM_PROBLEMS,
+ ids=IDS_WITH_BIAS + BATCH_NORM_IDS,
)
-def test_bias_jac_mat_prod(problem, V=3):
- """Test the Jacobian-matrix product w.r.t. to the biass.
+def test_bias_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None:
+ """Test the Jacobian-matrix product w.r.t. to the bias.
Args:
- problem (DerivativesProblem): Problem for derivative test.
- V (int): Number of vectorized transposed Jacobian-vector products.
+ problem: Test case.
+ V: Number of vectorized Jacobian-vector products. Default: ``3``.
"""
problem.set_up()
- mat = torch.rand(V, *problem.module.bias.shape).to(problem.device)
+ mat = rand(V, *problem.module.bias.shape).to(problem.device)
backpack_res = BackpackDerivatives(problem).bias_jac_mat_prod(mat)
autograd_res = AutogradDerivatives(problem).bias_jac_mat_prod(mat)
@@ -182,61 +253,93 @@ def test_bias_jac_mat_prod(problem, V=3):
problem.tear_down()
-@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
-def test_sqrt_hessian_squared_equals_hessian(problem):
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
+def test_sqrt_hessian_squared_equals_hessian(
+ problem: DerivativesTestProblem, subsampling: Union[List[int], None]
+) -> None:
"""Test the sqrt decomposition of the input Hessian.
Args:
- problem (DerivativesProblem): Problem for derivative test.
+ problem: Test case.
+ subsampling: Indices of active samples.
Compares the Hessian to reconstruction from individual Hessian sqrt.
"""
problem.set_up()
+ skip_subsampling_conflict(problem, subsampling)
- backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian()
- autograd_res = AutogradDerivatives(problem).input_hessian()
-
- print(backpack_res.device)
- print(autograd_res.device)
+ backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian(
+ subsampling=subsampling
+ )
+ autograd_res = AutogradDerivatives(problem).input_hessian(subsampling=subsampling)
check_sizes_and_values(autograd_res, backpack_res)
problem.tear_down()
-@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
-def test_sqrt_hessian_should_fail(problem):
- with pytest.raises(ValueError):
- test_sqrt_hessian_squared_equals_hessian(problem)
-
-
-@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
-def test_sqrt_hessian_sampled_squared_approximates_hessian(problem, mc_samples=100000):
- """Test the MC-sampled sqrt decomposition of the input Hessian.
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
+def test_sqrt_hessian_should_fail(
+ problem: DerivativesTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Test that sqrt_hessian fails.
Args:
- problem (DerivativesProblem): Problem for derivative test.
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ with raises(ValueError):
+ test_sqrt_hessian_squared_equals_hessian(problem, subsampling)
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
+def test_sqrt_hessian_sampled_squared_approximates_hessian(
+ problem: DerivativesTestProblem,
+ subsampling: Union[List[int], None],
+ mc_samples: int = 1000000,
+ chunks: int = 10,
+) -> None:
+ """Test the MC-sampled sqrt decomposition of the input Hessian.
Compares the Hessian to reconstruction from individual Hessian MC-sampled sqrt.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ mc_samples: number of samples. Defaults to 1000000.
+        chunks: Number of sequential chunks into which the MC samples are split.
"""
problem.set_up()
+ skip_subsampling_conflict(problem, subsampling)
backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian(
- mc_samples=mc_samples
+ mc_samples=mc_samples, chunks=chunks, subsampling=subsampling
)
- autograd_res = AutogradDerivatives(problem).input_hessian()
+ autograd_res = AutogradDerivatives(problem).input_hessian(subsampling=subsampling)
- RTOL, ATOL = 1e-2, 2e-2
+ RTOL, ATOL = 1e-2, 7e-3
check_sizes_and_values(autograd_res, backpack_res, rtol=RTOL, atol=ATOL)
problem.tear_down()
-@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
-def test_sqrt_hessian_sampled_should_fail(problem):
- with pytest.raises(ValueError):
- test_sqrt_hessian_sampled_squared_approximates_hessian(problem)
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
+def test_sqrt_hessian_sampled_should_fail(
+ problem: DerivativesTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Test that sqrt_hessian_samples fails.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ with raises(ValueError):
+ test_sqrt_hessian_sampled_squared_approximates_hessian(problem, subsampling)
-@pytest.mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
+@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
def test_sum_hessian(problem):
"""Test the summed Hessian.
@@ -252,15 +355,20 @@ def test_sum_hessian(problem):
problem.tear_down()
-@pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
+@mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS)
def test_sum_hessian_should_fail(problem):
- with pytest.raises(ValueError):
+ """Test sum_hessian, should fail.
+
+ Args:
+ problem: test problem
+ """
+ with raises(ValueError):
test_sum_hessian(problem)
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_ea_jac_t_mat_jac_prod(problem):
- """Test KFRA backpropagation
+@mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
+def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None:
+ """Test KFRA backpropagation.
H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ
@@ -271,11 +379,14 @@ def test_ea_jac_t_mat_jac_prod(problem):
as `Dropout` is not deterministic.
Args:
- problem (DerivativesProblem): Problem for derivative test.
+ problem: Test case.
+ request: PyTest request, used to get test id.
"""
+ skip_adaptive_avg_pool3d_cuda(request)
+
problem.set_up()
- out_features = torch.prod(torch.tensor(problem.output_shape[1:]))
- mat = torch.rand(out_features, out_features).to(problem.device)
+ out_features = problem.output_shape[1:].numel()
+ mat = rand(out_features, out_features).to(problem.device)
backpack_res = BackpackDerivatives(problem).ea_jac_t_mat_jac_prod(mat)
autograd_res = AutogradDerivatives(problem).ea_jac_t_mat_jac_prod(mat)
@@ -284,47 +395,96 @@ def test_ea_jac_t_mat_jac_prod(problem):
problem.tear_down()
-@pytest.mark.skip("[WAITING] Autograd issue with Hessian-vector products")
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_hessian_is_zero(problem):
- """Check if the input-output Hessian is (non-)zero."""
- problem.set_up()
+@fixture(
+ params=PROBLEMS + BATCH_NORM_PROBLEMS + RNN_PROBLEMS + EMBEDDING_PROBLEMS,
+ ids=lambda p: p.make_id(),
+)
+def problem(request) -> DerivativesTestProblem:
+ """Set seed, create tested layer and data. Finally clean up.
- backpack_res = BackpackDerivatives(problem).hessian_is_zero()
- autograd_res = AutogradDerivatives(problem).hessian_is_zero()
+ Args:
+        request: Pytest request for the fixture from a test/fixture function.
- assert backpack_res == autograd_res
- problem.tear_down()
+ Yields:
+ Test case with deterministically constructed attributes.
+ """
+ case = request.param
+ case.set_up()
+ yield case
+ case.tear_down()
-@pytest.mark.skip
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_hessian_is_diagonal(problem):
- problem.set_up()
+@fixture
+def small_input_problem(
+ problem: DerivativesTestProblem, max_input_numel: int = 100
+) -> DerivativesTestProblem:
+ """Skip cases with large inputs.
- # TODO
- raise NotImplementedError
+ Args:
+ problem: Test case with deterministically constructed attributes.
+ max_input_numel: Maximum input size. Default: ``100``.
- problem.tear_down()
+ Yields:
+ Instantiated test case with small input.
+ """
+ if problem.input.numel() > max_input_numel:
+ skip("Input is too large:" + f" {problem.input.numel()} > {max_input_numel}")
+ else:
+ yield problem
-@pytest.mark.skip
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_hessian_diagonal(problem):
- problem.set_up()
+@fixture
+def no_loss_problem(
+ small_input_problem: DerivativesTestProblem,
+) -> DerivativesTestProblem:
+ """Skip cases that are loss functions.
+
+ Args:
+ small_input_problem: Test case with small input.
- # TODO
- raise NotImplementedError
+ Yields:
+ Instantiated test case that is not a loss layer.
+ """
+ if small_input_problem.is_loss():
+ skip("Only required for non-loss layers.")
+ else:
+ yield small_input_problem
- problem.tear_down()
+def test_hessian_is_zero(no_loss_problem: DerivativesTestProblem) -> None:
+ """Check if the input-output Hessian is (non-)zero.
-@pytest.mark.skip
-@pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS)
-def test_hessian_is_psd(problem):
+ Note:
+ `hessian_is_zero` is a global statement that assumes arbitrary inputs.
+ It can thus happen that the Hessian diagonal is zero for the current
+ input, but not in general.
+
+ Args:
+ no_loss_problem: Test case whose module is not a loss.
+ """
+ backpack_res = BackpackDerivatives(no_loss_problem).hessian_is_zero()
+ autograd_res = AutogradDerivatives(no_loss_problem).hessian_is_zero()
+
+ if autograd_res and not backpack_res:
+ warn(
+ "Autograd Hessian diagonal is zero for this input "
+ " while BackPACK implementation implies inputs with non-zero Hessian."
+ )
+ else:
+ assert backpack_res == autograd_res
+
+
+@mark.parametrize("problem", LOSS_PROBLEMS, ids=LOSS_IDS)
+def test_make_hessian_mat_prod(problem: DerivativesTestProblem) -> None:
+ """Test hessian_mat_prod.
+
+ Args:
+ problem: test problem
+ """
problem.set_up()
+ mat = rand(4, *problem.input_shape, device=problem.device)
- # TODO
- raise NotImplementedError
+ autograd_res = AutogradDerivatives(problem).hessian_mat_prod(mat)
+ backpack_res = BackpackDerivatives(problem).hessian_mat_prod(mat)
- problem.tear_down()
+ check_sizes_and_values(backpack_res, autograd_res)
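
The shape convention behind `rand_mat_like_output` can be illustrated without the test machinery: the batch dimension of the output shape is replaced by the number of active samples. A small sketch for a layer with output shape `[N, out_features]` (values chosen arbitrarily):

```python
from torch import rand

V, N, out_features = 3, 5, 2  # stacked vectors, batch size, output features
subsampling = [2, 0]          # active samples; duplicates such as [0, 0] are allowed

subsample_shape = [N, out_features]  # list(problem.output_shape)
if subsampling is not None:
    subsample_shape[0] = len(subsampling)

mat = rand(V, *subsample_shape)
print(mat.shape)  # torch.Size([3, 2, 2]) - V x (sub-sampled output shape)
```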
diff --git a/test/core/derivatives/embedding_settings.py b/test/core/derivatives/embedding_settings.py
new file mode 100644
index 000000000..e6f8b1486
--- /dev/null
+++ b/test/core/derivatives/embedding_settings.py
@@ -0,0 +1,14 @@
+"""Settings for testing derivatives of Embedding."""
+from torch import randint
+from torch.nn import Embedding
+
+EMBEDDING_SETTINGS = [
+ {
+ "module_fn": lambda: Embedding(3, 5),
+ "input_fn": lambda: randint(0, 3, (4,)),
+ },
+ {
+ "module_fn": lambda: Embedding(5, 7),
+ "input_fn": lambda: randint(0, 5, (8, 3, 3)),
+ },
+]
diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py
index e7a60f6b2..921e46132 100644
--- a/test/core/derivatives/implementation/autograd.py
+++ b/test/core/derivatives/implementation/autograd.py
@@ -1,104 +1,166 @@
+"""Derivatives computed with PyTorch's autograd."""
from test.core.derivatives.implementation.base import DerivativesImplementation
+from typing import Iterator, List
-import torch
+from torch import Tensor, allclose, backends, cat, stack, zeros, zeros_like
from backpack.hessianfree.hvp import hessian_vector_product
from backpack.hessianfree.lop import transposed_jacobian_vector_product
from backpack.hessianfree.rop import jacobian_vector_product
+from backpack.utils.subsampling import subsample
class AutogradDerivatives(DerivativesImplementation):
"""Derivative implementations with autograd."""
- def jac_vec_prod(self, vec):
- input, output, _ = self.problem.forward_pass(input_requires_grad=True)
- return jacobian_vector_product(output, input, vec)[0]
-
- def jac_mat_prod(self, mat):
- V = mat.shape[0]
+    def jac_vec_prod(self, vec: Tensor) -> Tensor:
+ """Product of input-output-Jacobian and a vector.
- vecs = [mat[v] for v in range(V)]
- jac_vec_prods = [self.jac_vec_prod(vec) for vec in vecs]
+ Args:
+            vec: Vector to be multiplied by the input-output Jacobian.
- return torch.stack(jac_vec_prods)
-
- def jac_t_vec_prod(self, vec):
+ Returns:
+            Jacobian-vector product.
+ """
input, output, _ = self.problem.forward_pass(input_requires_grad=True)
- return transposed_jacobian_vector_product(output, input, vec)[0]
-
- def jac_t_mat_prod(self, mat):
- V = mat.shape[0]
-
- vecs = [mat[v] for v in range(V)]
- jac_t_vec_prods = [self.jac_t_vec_prod(vec) for vec in vecs]
+ return jacobian_vector_product(output, input, vec)[0]
- return torch.stack(jac_t_vec_prods)
+ def jac_mat_prod(self, mat): # noqa: D102
+ try:
+ return stack([self.jac_vec_prod(vec) for vec in mat])
+ except RuntimeError:
+ # A RuntimeError is thrown for RNNs on CUDA,
+ # because PyTorch does not support double-backwards pass for them.
+ # This is the recommended workaround.
+ with backends.cudnn.flags(enabled=False):
+ return stack([self.jac_vec_prod(vec) for vec in mat])
+
+    def jac_t_vec_prod(self, vec: Tensor, subsampling: List[int] = None) -> Tensor:  # noqa: D102
+ input, output, _ = self.problem.forward_pass(input_requires_grad=True)
- def weight_jac_t_mat_prod(self, mat, sum_batch):
- return self.param_jac_t_mat_prod("weight", mat, sum_batch)
+ if subsampling is None:
+ return transposed_jacobian_vector_product(output, input, vec)[0]
+ else:
+ # for each sample, multiply by full input Jacobian, slice out result:
+ # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n]
+ batch_axis = 0
+ output = subsample(output, dim=batch_axis, subsampling=subsampling)
+ output = output.split(1, dim=batch_axis)
+ vec = vec.split(1, dim=batch_axis)
+
+ vjps: List[Tensor] = []
+ for sample_idx, out, v in zip(subsampling, output, vec):
+ vjp = transposed_jacobian_vector_product(out, input, v)[0]
+ vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx])
+ vjps.append(vjp)
+
+ return cat(vjps, dim=batch_axis)
+
+ def jac_t_mat_prod(
+ self, mat: Tensor, subsampling: List[int] = None
+ ) -> Tensor: # noqa: D102
+ return stack([self.jac_t_vec_prod(vec, subsampling=subsampling) for vec in mat])
+
+ def param_mjp(
+ self,
+ param_str: str,
+ mat: Tensor,
+ sum_batch: bool,
+ subsampling: List[int] = None,
+ ) -> Tensor: # noqa: D102
+ return stack(
+ [
+ self._param_vjp(
+ param_str,
+ vec,
+ sum_batch,
+ axis_batch=0,
+ subsampling=subsampling,
+ )
+ for vec in mat
+ ]
+ )
- def bias_jac_t_mat_prod(self, mat, sum_batch):
- return self.param_jac_t_mat_prod("bias", mat, sum_batch)
+ def _param_vjp(
+ self,
+ name: str,
+ vec: Tensor,
+ sum_batch: bool,
+ axis_batch: int = 0,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Compute the product of jac_t and the given vector.
+
+ Args:
+            name: Name of the parameter w.r.t. which the derivative is taken.
+            vec: Vector to multiply with the transposed Jacobian.
+            sum_batch: Whether to sum along the batch axis.
+ axis_batch: index of batch axis. Defaults to 0.
+ subsampling: Indices of active samples. Default: ``None`` (all).
- def param_jac_t_vec_prod(self, name, vec, sum_batch):
+ Returns:
+ product of jac_t and vec
+ """
input, output, named_params = self.problem.forward_pass()
param = named_params[name]
- if sum_batch:
- return transposed_jacobian_vector_product(output, param, vec)[0]
- else:
- N = input.shape[0]
+ samples = range(input.shape[axis_batch]) if subsampling is None else subsampling
+ sample_outputs = output.split(1, dim=axis_batch)
+ sample_vecs = vec.split(1, dim=axis_batch)
- sample_outputs = [output[n] for n in range(N)]
- sample_vecs = [vec[n] for n in range(N)]
+ jac_t_sample_prods = stack(
+ [
+ transposed_jacobian_vector_product(sample_outputs[n], param, vec_n)[0]
+ for n, vec_n in zip(samples, sample_vecs)
+ ],
+ )
- jac_t_sample_prods = [
- transposed_jacobian_vector_product(n_out, param, n_vec)[0]
- for n_out, n_vec in zip(sample_outputs, sample_vecs)
- ]
+ if sum_batch:
+ jac_t_sample_prods = jac_t_sample_prods.sum(0)
- return torch.stack(jac_t_sample_prods)
+ return jac_t_sample_prods
- def param_jac_t_mat_prod(self, name, mat, sum_batch):
- V = mat.shape[0]
+ def weight_jac_mat_prod(self, mat) -> Tensor:
+ """Product of jacobian and matrix.
- vecs = [mat[v] for v in range(V)]
- jac_t_vec_prods = [
- self.param_jac_t_vec_prod(name, vec, sum_batch) for vec in vecs
- ]
+ Args:
+ mat: matrix
- return torch.stack(jac_t_vec_prods)
+ Returns:
+ product
+ """
+ return self._param_jac_mat_prod("weight", mat)
+
+ def bias_jac_mat_prod(self, mat) -> Tensor:
+ """Product of jacobian and matrix.
- def weight_jac_mat_prod(self, mat):
- return self.param_jac_mat_prod("weight", mat)
+ Args:
+ mat: matrix
- def bias_jac_mat_prod(self, mat):
- return self.param_jac_mat_prod("bias", mat)
+ Returns:
+ product
+ """
+ return self._param_jac_mat_prod("bias", mat)
- def param_jac_vec_prod(self, name, vec):
+ def _param_jac_vec_prod(self, name, vec):
input, output, named_params = self.problem.forward_pass()
param = named_params[name]
return jacobian_vector_product(output, param, vec)[0]
- def param_jac_mat_prod(self, name, mat):
- V = mat.shape[0]
+ def _param_jac_mat_prod(self, name, mat):
+ return stack([self._param_jac_vec_prod(name, vec) for vec in mat])
- vecs = [mat[v] for v in range(V)]
- jac_vec_prods = [self.param_jac_vec_prod(name, vec) for vec in vecs]
-
- return torch.stack(jac_vec_prods)
-
- def ea_jac_t_mat_jac_prod(self, mat):
- def sample_jac_t_mat_jac_prod(sample_idx, mat):
+ def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102
+ def _sample_jac_t_mat_jac_prod(sample_idx, mat):
assert len(mat.shape) == 2
- def sample_jac_t_mat_prod(sample_idx, mat):
+ def _sample_jac_t_mat_prod(sample_idx, mat):
sample, output, _ = self.problem.forward_pass(
- input_requires_grad=True, sample_idx=sample_idx
+ input_requires_grad=True, subsampling=[sample_idx]
)
- result = torch.zeros(sample.numel(), mat.size(1), device=sample.device)
+ result = zeros(sample.numel(), mat.size(1), device=sample.device)
for col in range(mat.size(1)):
column = mat[:, col].reshape(output.shape)
@@ -108,9 +170,9 @@ def sample_jac_t_mat_prod(sample_idx, mat):
return result
- jac_t_mat = sample_jac_t_mat_prod(sample_idx, mat)
+ jac_t_mat = _sample_jac_t_mat_prod(sample_idx, mat)
mat_t_jac = jac_t_mat.t()
- jac_t_mat_t_jac = sample_jac_t_mat_prod(sample_idx, mat_t_jac)
+ jac_t_mat_t_jac = _sample_jac_t_mat_prod(sample_idx, mat_t_jac)
jac_t_mat_jac = jac_t_mat_t_jac.t()
return jac_t_mat_jac
@@ -118,24 +180,26 @@ def sample_jac_t_mat_prod(sample_idx, mat):
N = self.problem.input.shape[0]
input_features = self.problem.input.shape.numel() // N
- result = torch.zeros(input_features, input_features).to(self.problem.device)
+ result = zeros(input_features, input_features).to(self.problem.device)
for n in range(N):
- result += sample_jac_t_mat_jac_prod(n, mat)
+ result += _sample_jac_t_mat_jac_prod(n, mat)
return result / N
- def hessian(self, loss, x):
+ def _hessian(self, loss: Tensor, x: Tensor) -> Tensor:
"""Return the Hessian matrix of a scalar `loss` w.r.t. a tensor `x`.
- Arguments:
- loss (torch.Tensor): A scalar-valued tensor.
- x (torch.Tensor): Tensor used in the computation graph of `loss`.
+ Args:
+ loss: A scalar-valued tensor.
+ x: Tensor used in the computation graph of `loss`.
+
Shapes:
loss: `[1,]`
x: `[A, B, C, ...]`
+
Returns:
- torch.Tensor: Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape
+ Hessian tensor of `loss` w.r.t. `x`. The Hessian has shape
`[A, B, C, ..., A, B, C, ...]`.
"""
assert loss.numel() == 1
@@ -143,11 +207,11 @@ def hessian(self, loss, x):
vectorized_shape = (x.numel(), x.numel())
final_shape = (*x.shape, *x.shape)
- hessian_vec_x = torch.zeros(vectorized_shape).to(loss.device)
+ hessian_vec_x = zeros(vectorized_shape).to(loss.device)
num_cols = hessian_vec_x.shape[1]
for column_idx in range(num_cols):
- unit = torch.zeros(num_cols).to(loss.device)
+ unit = zeros(num_cols).to(loss.device)
unit[column_idx] = 1.0
unit = unit.view_as(x)
@@ -157,16 +221,12 @@ def hessian(self, loss, x):
return hessian_vec_x.reshape(final_shape)
- def elementwise_hessian(self, tensor, x):
- """Yield the Hessian of each element in `tensor` w.r.t `x`.
+    def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Iterator[Tensor]:
+        """Yield the Hessian of each element in `tensor` w.r.t. `x`.
- Hessians are returned in the order of elements in the flattened tensor.
- """
- for t in tensor.flatten():
- yield self.hessian(t, x)
-
- def tensor_hessian(self, tensor, x):
- """Return the Hessian of a tensor `tensor` w.r.t. a tensor `x`.
+ If ``tensor`` is linear in ``x``, autograd raises a ``RuntimeError``.
+ If ``tensor`` does not depend on ``x``, autograd raises an ``AttributeError``.
+ In both cases, a Hessian of zeros is created manually and returned.
Given a `tensor` of shape `[A, B, C]` and another tensor `x` with shape `[D, E]`
used in the computation of `tensor`, the generalized Hessian has shape
@@ -174,73 +234,136 @@ def tensor_hessian(self, tensor, x):
`hessian[a, b, c]` contains the Hessian of the scalar entry `tensor[a, b, c]`
w.r.t. `x[a, b, c]`.
+
Arguments:
- tensor (torch.Tensor): An arbitrary tensor.
- x (torch.Tensor): Tensor used in the computation graph of `tensor`.
+ tensor: An arbitrary tensor.
+ x: Tensor used in the computation graph of `tensor`.
- Returns:
- torch.Tensor: Generalized Hessian of `tensor` w.r.t. `x`.
+ Yields:
+ Hessians in the order of elements in the flattened tensor.
"""
- shape = (*tensor.shape, *x.shape, *x.shape)
-
- return torch.cat(list(self.elementwise_hessian(tensor, x))).reshape(shape)
-
- def hessian_is_zero(self):
- """Return whether the input-output Hessian is zero.
+ for t in tensor.flatten():
+ try:
+ yield self._hessian(t, x)
+ except (RuntimeError, AttributeError):
+ yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype)
- Returns:
- bool: `True`, if Hessian is zero, else `False`.
- """
+ def hessian_is_zero(self) -> bool: # noqa: D102
input, output, _ = self.problem.forward_pass(input_requires_grad=True)
zero = None
- for hessian in self.elementwise_hessian(output, input):
+ for hessian in self._elementwise_hessian(output, input):
if zero is None:
- zero = torch.zeros_like(hessian)
+ zero = zeros_like(hessian)
- if not torch.allclose(hessian, zero):
+ if not allclose(hessian, zero):
return False
return True
- def input_hessian(self):
- """Compute the Hessian of the module output w.r.t. the input."""
+ def input_hessian(self, subsampling: List[int] = None) -> Tensor:
+ """Compute the Hessian of the module output w.r.t. the input.
+
+ Args:
+ subsampling: Indices of active samples. ``None`` uses all samples.
+
+ Returns:
+ Hessian of shape ``[N, *, N, *]`` where ``N`` denotes the
+ number of sub-samples, and ``*`` is the input feature shape.
+ """
input, output, _ = self.problem.forward_pass(input_requires_grad=True)
- return self.hessian(output, input)
+ hessian = self._hessian(output, input)
+ return self._subsample_input_hessian(hessian, input, subsampling=subsampling)
+
+ @staticmethod
+ def _subsample_input_hessian(
+ hessian: Tensor, input: Tensor, subsampling: List[int] = None
+ ) -> Tensor:
+ """Slice sub-samples out of Hessian w.r.t the full input.
+
+ If ``subsampling`` is set to ``None``, leaves the Hessian unchanged.
+
+ Args:
+ hessian: The Hessian w.r.t. the module input.
+ input: Module input.
+ subsampling: List of active samples. Default of ``None`` uses all samples.
+
+ Returns:
+ Sub-sampled Hessian of shape ``[N, *, N, *]`` where ``N`` denotes the
+ number of sub-samples, and ``*`` is the input feature shape.
+ """
+ N, D_shape = input.shape[0], input.shape[1:]
+ D = input.numel() // N
+
+ subsampled_hessian = hessian.reshape(N, D, N, D)[subsampling, :, :, :][
+ :, :, subsampling, :
+ ]
- def sum_hessian(self):
- """Compute the Hessian of a loss module w.r.t. its input."""
+ has_duplicates = subsampling is not None and len(set(subsampling)) != len(
+ subsampling
+ )
+ if has_duplicates:
+            # For duplicates in `subsampling`, the above slicing is not sufficient,
+            # and off-diagonal blocks need to be zeroed. E.g. if subsampling is [0, 0],
+ # then the sliced input Hessian has non-zero off-diagonal blocks (1, 0) and
+ # (0, 1), which should be zero as the samples are considered independent.
+ for idx1, sample1 in enumerate(subsampling[:-1]):
+ for idx2, sample2 in enumerate(subsampling[idx1 + 1 :], start=idx1 + 1):
+ if sample1 == sample2:
+ subsampled_hessian[idx1, :, idx2, :] = 0
+ subsampled_hessian[idx2, :, idx1, :] = 0
+
+ N_active = N if subsampling is None else len(subsampling)
+ out_shape = [N_active, *D_shape, N_active, *D_shape]
+
+ return subsampled_hessian.reshape(out_shape)
+
+ def sum_hessian(self) -> Tensor:
+ """Compute the Hessian of a loss module w.r.t. its input.
+
+ Returns:
+            Sum of the per-sample Hessians of the loss w.r.t. its input.
+ """
hessian = self.input_hessian()
return self._sum_hessian_blocks(hessian)
- def _sum_hessian_blocks(self, hessian):
+ def _sum_hessian_blocks(self, hessian: Tensor) -> Tensor:
"""Sum second derivatives over the batch dimension.
Assert second derivative w.r.t. different samples is zero.
- """
- input = self.problem.input
- num_axes = len(input.shape)
- if num_axes != 2:
- raise ValueError("Only 2D inputs are currently supported.")
+ Args:
+ hessian: Hessian of the output w.r.t. the input. Has shape ``[N, *, N, *]``
+ where ``N`` is the number of active samples and ``*`` is the input's
+ feature shape.
+ Returns:
+ Sum of Hessians w.r.t. to individual samples. Has shape ``[*, *]``.
+ """
+ input = self.problem.input
N = input.shape[0]
- num_features = input.numel() // N
+ shape_feature = input.shape[1:]
+ D = shape_feature.numel()
- sum_hessian = torch.zeros(num_features, num_features, device=input.device)
+ hessian = hessian.reshape(N, D, N, D)
+ sum_hessian = zeros(D, D, device=input.device, dtype=input.dtype)
- hessian_different_samples = torch.zeros(
- num_features, num_features, device=input.device
- )
+ hessian_different_samples = zeros(D, D, device=input.device, dtype=input.dtype)
for n_1 in range(N):
for n_2 in range(N):
block = hessian[n_1, :, n_2, :]
-
if n_1 == n_2:
sum_hessian += block
-
else:
- assert torch.allclose(block, hessian_different_samples)
+ assert allclose(block, hessian_different_samples)
+
+ return sum_hessian.reshape(*shape_feature, *shape_feature)
+
+ def hessian_mat_prod(self, mat: Tensor) -> Tensor: # noqa: D102
+ input, output, _ = self.problem.forward_pass(input_requires_grad=True)
- return sum_hessian
+ return stack([hessian_vector_product(output, [input], [vec])[0] for vec in mat])
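
The per-sample strategy used by `jac_t_vec_prod` with sub-sampling (multiply each active sample's output by the full input Jacobian, then slice out that sample) can be reproduced in isolation. A minimal sketch with a `Linear` layer, using the same helpers as the file above (sizes are illustrative):

```python
from torch import cat, rand
from torch.nn import Linear

from backpack.hessianfree.lop import transposed_jacobian_vector_product
from backpack.utils.subsampling import subsample

N, D_in, D_out = 4, 3, 2
module = Linear(D_in, D_out)

x = rand(N, D_in, requires_grad=True)
out = module(x)

subsampling = [3, 1]                 # active samples
vec = rand(len(subsampling), D_out)  # one vector per active sample

out_split = subsample(out, dim=0, subsampling=subsampling).split(1, dim=0)
vec_split = vec.split(1, dim=0)

vjps = []
for sample_idx, o, v in zip(subsampling, out_split, vec_split):
    vjp = transposed_jacobian_vector_product(o, x, v)[0]  # shape [N, D_in]
    vjps.append(subsample(vjp, dim=0, subsampling=[sample_idx]))

print(cat(vjps, dim=0).shape)  # torch.Size([2, 3])
```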
diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py
index c36d94d9f..092d368c1 100644
--- a/test/core/derivatives/implementation/backpack.py
+++ b/test/core/derivatives/implementation/backpack.py
@@ -1,116 +1,179 @@
+"""Contains derivative calculation with BackPACK."""
from test.core.derivatives.implementation.base import DerivativesImplementation
+from test.utils import chunk_sizes
+from typing import List
-import torch
+from torch import Tensor, einsum, zeros
+
+from backpack.utils.subsampling import subsample
class BackpackDerivatives(DerivativesImplementation):
"""Derivative implementations with BackPACK."""
def __init__(self, problem):
+ """Initialization.
+
+ Args:
+ problem: test problem
+ """
problem.extend()
super().__init__(problem)
def store_forward_io(self):
+ """Do one forward pass.
+
+ This implicitly saves relevant quantities for the backward pass.
+ """
self.problem.forward_pass()
- def jac_mat_prod(self, mat):
+ def jac_mat_prod(self, mat): # noqa: D102
self.store_forward_io()
return self.problem.derivative.jac_mat_prod(
self.problem.module, None, None, mat
)
- def jac_t_mat_prod(self, mat):
+ def jac_t_mat_prod(
+ self, mat: Tensor, subsampling: List[int]
+ ) -> Tensor: # noqa: D102
self.store_forward_io()
return self.problem.derivative.jac_t_mat_prod(
- self.problem.module, None, None, mat
+ self.problem.module, None, None, mat, subsampling=subsampling
)
- def weight_jac_t_mat_prod(self, mat, sum_batch):
+ def param_mjp(
+ self,
+ param_str: str,
+ mat: Tensor,
+ sum_batch: bool,
+ subsampling: List[int] = None,
+ ) -> Tensor: # noqa: D102
self.store_forward_io()
- return self.problem.derivative.weight_jac_t_mat_prod(
- self.problem.module, None, None, mat, sum_batch=sum_batch
+ return self.problem.derivative.param_mjp(
+ param_str,
+ self.problem.module,
+ None,
+ None,
+ mat,
+ sum_batch=sum_batch,
+ subsampling=subsampling,
)
- def bias_jac_t_mat_prod(self, mat, sum_batch):
- self.store_forward_io()
- return self.problem.derivative.bias_jac_t_mat_prod(
- self.problem.module, None, None, mat, sum_batch=sum_batch
- )
-
- def weight_jac_mat_prod(self, mat):
+ def weight_jac_mat_prod(self, mat): # noqa: D102
self.store_forward_io()
return self.problem.derivative.weight_jac_mat_prod(
self.problem.module, None, None, mat
)
- def bias_jac_mat_prod(self, mat):
+ def bias_jac_mat_prod(self, mat): # noqa: D102
self.store_forward_io()
return self.problem.derivative.bias_jac_mat_prod(
self.problem.module, None, None, mat
)
- def ea_jac_t_mat_jac_prod(self, mat):
+ def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102
self.store_forward_io()
return self.problem.derivative.ea_jac_t_mat_jac_prod(
self.problem.module, None, None, mat
)
- def sum_hessian(self):
+ def sum_hessian(self): # noqa: D102
self.store_forward_io()
return self.problem.derivative.sum_hessian(self.problem.module, None, None)
- def input_hessian_via_sqrt_hessian(self, mc_samples=None):
- # MC_SAMPLES = 100000
+ def input_hessian_via_sqrt_hessian(
+ self, mc_samples: int = None, chunks: int = 1, subsampling: List[int] = None
+ ) -> Tensor:
+ """Compute the Hessian w.r.t. the input from its matrix square root.
+
+ Args:
+ mc_samples: If int, uses an MC approximation with the specified
+ number of samples. If None, uses the exact Hessian. Default: ``None``.
+ chunks: Maximum sequential split of the computation. Default: ``1``.
+ Only used if mc_samples is specified.
+ subsampling: Indices of active samples. ``None`` uses all samples.
+
+ Returns:
+ Hessian with respect to the input. Has shape
+ ``[N, A, B, ..., N, A, B, ...]`` where ``N`` is the batch size or number
+ of active samples when sub-sampling is used, and ``[A, B, ...]`` are the
+ input's feature dimensions.
+ """
self.store_forward_io()
if mc_samples is not None:
- sqrt_hessian = self.problem.derivative.sqrt_hessian_sampled(
- self.problem.module, None, None, mc_samples=mc_samples
+ chunk_samples = chunk_sizes(mc_samples, chunks)
+ chunk_weights = [samples / mc_samples for samples in chunk_samples]
+
+ individual_hessians: Tensor = sum(
+ weight
+ * self._sample_hessians_from_sqrt(
+ self.problem.derivative.sqrt_hessian_sampled(
+ self.problem.module,
+ None,
+ None,
+ mc_samples=samples,
+ subsampling=subsampling,
+ )
+ )
+ for weight, samples in zip(chunk_weights, chunk_samples)
)
else:
sqrt_hessian = self.problem.derivative.sqrt_hessian(
- self.problem.module, None, None
+ self.problem.module, None, None, subsampling=subsampling
)
+ individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian)
- individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian)
+ input0 = subsample(self.problem.module.input0, subsampling=subsampling)
+ return self._embed_sample_hessians(individual_hessians, input0)
- return self._embed_sample_hessians(
- individual_hessians, self.problem.module.input0
- )
+ def hessian_is_zero(self) -> bool: # noqa: D102
+ return self.problem.derivative.hessian_is_zero(self.problem.module)
+
+ def _sample_hessians_from_sqrt(self, sqrt: Tensor) -> Tensor:
+ """Convert individual matrix square root into individual full matrix.
- def hessian_is_zero(self):
- """Return whether the input-output Hessian is zero.
+ Args:
+ sqrt: Individual matrix square root of the Hessian. Has shape ``[V, N, A, B, ...]``.
Returns:
- bool: `True`, if Hessian is zero, else `False`.
+ Individual Hessians of shape ``[N, A, B, ..., A, B, ...]`` where
+ ``input.shape[1:] = [A, B, ...]`` are the input feature dimensions
+ and ``N`` is the batch size.
"""
- return self.problem.derivative.hessian_is_zero()
+ N, input_dims = sqrt.shape[1], sqrt.shape[2:]
- def _sample_hessians_from_sqrt(self, sqrt):
- """Convert individual matrix square root into individual full matrix."""
- equation = None
- num_axes = len(sqrt.shape)
+ sqrt_flat = sqrt.flatten(start_dim=2)
+ sample_hessians = einsum("vni,vnj->nij", sqrt_flat, sqrt_flat)
- # TODO improve readability
- if num_axes == 3:
- equation = "vni,vnj->nij"
- else:
- raise ValueError("Only 2D inputs are currently supported.")
+ return sample_hessians.reshape(N, *input_dims, *input_dims)
- return torch.einsum(equation, sqrt, sqrt)
+ def _embed_sample_hessians(
+ self, individual_hessians: Tensor, input: Tensor
+ ) -> Tensor:
+ """Embed Hessians w.r.t. individual samples into Hessian w.r.t. all samples.
- def _embed_sample_hessians(self, individual_hessians, input):
- hessian_shape = (*input.shape, *input.shape)
- hessian = torch.zeros(hessian_shape, device=input.device)
+ Args:
+ individual_hessians: Hessians w.r.t. individual samples in the input.
+ Has shape ``[N, A, B, ..., A, B, ...]``.
+ input: Input of the samples whose individual Hessians are passed.
+ Has shape ``[N, A, B, ...]`` where ``N`` is the number of
+ active samples and ``[A, B, ...]`` are the feature dimensions.
- N = input.shape[0]
+ Returns:
+ Hessian that contains the individual Hessians as diagonal blocks.
+ Has shape ``[N, A, B, ..., N, A, B, ...]``.
+ """
+ N, D = input.shape[0], input.shape[1:].numel()
+ hessian = zeros(N, D, N, D, device=input.device, dtype=input.dtype)
for n in range(N):
- num_axes = len(input.shape)
+ hessian[n, :, n, :] = individual_hessians[n].reshape(D, D)
- if num_axes == 2:
- hessian[n, :, n, :] = individual_hessians[n]
- else:
- raise ValueError("Only 2D inputs are currently supported.")
+ return hessian.reshape(*input.shape, *input.shape)
- return hessian
+ def hessian_mat_prod(self, mat: Tensor) -> Tensor: # noqa: D102
+ self.store_forward_io()
+ hmp = self.problem.derivative.make_hessian_mat_prod(
+ self.problem.module, None, None
+ )
+ return hmp(mat)
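
The einsum in ``_sample_hessians_from_sqrt`` contracts the leading ``v`` dimension of the matrix square root. Assuming the usual factorization H_n = sum_v s_{v,n} s_{v,n}^T, a quick sanity check (a sketch, not part of the test suite) against an explicit outer-product sum:

from torch import allclose, einsum, rand, stack

V, N, D = 4, 3, 5
sqrt_flat = rand(V, N, D)  # flattened square root, shape [V, N, D]

via_einsum = einsum("vni,vnj->nij", sqrt_flat, sqrt_flat)
via_outer = stack(
    [sum(sqrt_flat[v, n].outer(sqrt_flat[v, n]) for v in range(V)) for n in range(N)]
)
assert allclose(via_einsum, via_outer)
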
diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py
index 9edaa8194..1bf91c387 100644
--- a/test/core/derivatives/implementation/base.py
+++ b/test/core/derivatives/implementation/base.py
@@ -1,23 +1,131 @@
-class DerivativesImplementation:
+"""Contains DerivativesImplementation, the base class for autograd and backpack."""
+from abc import ABC, abstractmethod
+from typing import List
+
+from torch import Tensor
+
+
+class DerivativesImplementation(ABC):
"""Base class for autograd and BackPACK implementations."""
def __init__(self, problem):
+ """Initialization.
+
+ Args:
+ problem: test problem
+ """
self.problem = problem
- def jac_mat_prod(self, mat):
+ @abstractmethod
+ def jac_mat_prod(self, mat: Tensor) -> Tensor:
+ """Vectorized product of input-output-Jacobian and a matrix.
+
+ Args:
+ mat: Matrix whose vectors along the leading dimension will be multiplied.
+
+ Returns:
+ Tensor representing the result of Jacobian-vector product.
+ product[v] = J @ mat[v]
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def jac_t_mat_prod(self, mat: Tensor, subsampling: List[int] = None) -> Tensor:
+ """Vectorized product of transposed jacobian and matrix.
+
+ Args:
+ mat: Matrix whose vectors along the leading dimension will be multiplied.
+ subsampling: Active samples in the output. Default: ``None`` (all).
+
+ Returns:
+ Tensor representing the result of Jacobian-vector product.
+ product[v] = mat[v] @ J
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def param_mjp(
+ self,
+ param_str: str,
+ mat: Tensor,
+ sum_batch: bool,
+ subsampling: List[int] = None,
+ ) -> Tensor:
+ """Matrix-Jacobian products w.r.t. the parameter.
+
+ Args:
+ param_str: parameter name
+ mat: matrix
+ sum_batch: whether to sum along batch axis
+ subsampling: Active samples in the output. Default: ``None`` (all).
+
+ Returns:
+ product
+ """
raise NotImplementedError
- def jac_t_mat_prod(self, mat):
+ @abstractmethod
+ def weight_jac_mat_prod(self, mat: Tensor) -> Tensor:
+ """Product of the Jacobian w.r.t. the weight and the matrix.
+
+ Args:
+ mat: matrix
+
+ Returns:
+ product
+ """
raise NotImplementedError
- def weight_jac_t_mat_prod(self, mat, sum_batch):
+ @abstractmethod
+ def bias_jac_mat_prod(self, mat: Tensor) -> Tensor:
+ """Product of the Jacobian w.r.t. the bias and the matrix.
+
+ Args:
+ mat: matrix
+
+ Returns:
+ product
+ """
raise NotImplementedError
- def bias_jac_t_mat_prod(self, mat, sum_batch):
+ @abstractmethod
+ def ea_jac_t_mat_jac_prod(self, mat: Tensor) -> Tensor:
+ """Product of transposed Jacobian, matrix, and Jacobian (expectation approximation).
+
+ Args:
+ mat: matrix
+
+ Returns:
+ product
+ """
raise NotImplementedError
- def weight_jac_mat_prod(self, mat):
+ @abstractmethod
+ def sum_hessian(self) -> Tensor:
+ """Sum of Hessians over the batch dimension.
+
+ Returns:
+ Sum of the loss Hessians w.r.t. the module input.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def hessian_is_zero(self) -> bool:
+ """Return whether the input-output Hessian is zero.
+
+ Returns:
+ `True`, if Hessian is zero, else `False`.
+ """
raise NotImplementedError
- def bias_jac_mat_prod(self, mat):
+ @abstractmethod
+ def hessian_mat_prod(self, mat: Tensor) -> Tensor:
+ """Product of the Hessian with the matrix ``mat``.
+
+ Args:
+ mat: matrix to multiply
+
+ Returns:
+ product
+ """
raise NotImplementedError
diff --git a/test/core/derivatives/linear_settings.py b/test/core/derivatives/linear_settings.py
index df7132d99..a98df7aab 100644
--- a/test/core/derivatives/linear_settings.py
+++ b/test/core/derivatives/linear_settings.py
@@ -6,7 +6,7 @@
"input_fn" (callable): Used for specifying input function
Optional entries:
- "target_fn" (callable): Fetches the groundtruth/target classes
+ "target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
"device" [list(torch.device)]: List of devices to run the test on.
@@ -55,3 +55,22 @@
),
},
]
+
+# additional dimensions
+LINEAR_SETTINGS += [
+ {
+ "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True),
+ "input_fn": lambda: torch.rand(size=(3, 2, 4)),
+ "id_prefix": "one-additional",
+ },
+ {
+ "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True),
+ "input_fn": lambda: torch.rand(size=(3, 2, 3, 4)),
+ "id_prefix": "two-additional",
+ },
+ {
+ "module_fn": lambda: torch.nn.Linear(in_features=4, out_features=3, bias=True),
+ "input_fn": lambda: torch.rand(size=(3, 2, 3, 5, 4)),
+ "id_prefix": "three-additional",
+ },
+]
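
For context on the ``one/two/three-additional`` cases above (a sketch, not part of the patch): ``torch.nn.Linear`` only transforms the last dimension, so any extra axes behave like additional batch dimensions.

from torch import rand
from torch.nn import Linear

layer = Linear(in_features=4, out_features=3)
print(layer(rand(3, 2, 4)).shape)        # torch.Size([3, 2, 3])
print(layer(rand(3, 2, 3, 5, 4)).shape)  # torch.Size([3, 2, 3, 5, 3])
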
diff --git a/test/core/derivatives/loss_settings.py b/test/core/derivatives/loss_settings.py
index 7c5a57855..391420cae 100644
--- a/test/core/derivatives/loss_settings.py
+++ b/test/core/derivatives/loss_settings.py
@@ -3,7 +3,7 @@
Required entries:
"module_fn" (callable): Contains a model constructed from `torch.nn` layers
"input_fn" (callable): Used for specifying input function
- "target_fn" (callable): Fetches the groundtruth/target classes
+ "target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
@@ -28,7 +28,7 @@
example = {
"module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"input_fn": lambda: torch.rand(size=(2, 4)),
- "target_fn": lambda: classification_targets(size=(2,), num_classes=2),
+ "target_fn": lambda: classification_targets(size=(2,), num_classes=4),
"device": [torch.device("cpu")], # optional
"seed": 0, # optional
"id_prefix": "loss-example", # optional
@@ -37,15 +37,25 @@
LOSS_SETTINGS += [
+ {
+ "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
+ "input_fn": lambda: torch.rand(size=(2, 4, 3)),
+ "target_fn": lambda: classification_targets(size=(2, 3), num_classes=4),
+ },
+ {
+ "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
+ "input_fn": lambda: torch.rand(size=(3, 4, 3, 2)),
+ "target_fn": lambda: classification_targets(size=(3, 3, 2), num_classes=4),
+ },
{
"module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
"input_fn": lambda: torch.rand(size=(2, 4)),
- "target_fn": lambda: classification_targets(size=(2,), num_classes=2),
+ "target_fn": lambda: classification_targets(size=(2,), num_classes=4),
},
{
"module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"input_fn": lambda: torch.rand(size=(8, 4)),
- "target_fn": lambda: classification_targets(size=(8,), num_classes=2),
+ "target_fn": lambda: classification_targets(size=(8,), num_classes=4),
},
{
"module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="none"),
diff --git a/test/core/derivatives/lstm_settings.py b/test/core/derivatives/lstm_settings.py
new file mode 100644
index 000000000..77181d9dd
--- /dev/null
+++ b/test/core/derivatives/lstm_settings.py
@@ -0,0 +1,29 @@
+"""Test configurations for `backpack.core.derivatives` LSTM layers.
+
+Required entries:
+ "module_fn" (callable): Contains a model constructed from `torch.nn` layers
+ "input_fn" (callable): Used for specifying input function
+
+Optional entries:
+ "target_fn" (callable): Fetches the groundtruth/target classes
+ of regression/classification task
+ "loss_function_fn" (callable): Loss function used in the model
+ "device" [list(torch.device)]: List of devices to run the test on.
+ "id_prefix" (str): Prefix to be included in the test name.
+ "seed" (int): seed for the random number generator
+"""
+
+from torch import rand
+from torch.nn import LSTM
+
+LSTM_SETTINGS = []
+
+###############################################################################
+# test settings #
+###############################################################################
+LSTM_SETTINGS += [
+ {
+ "module_fn": lambda: LSTM(input_size=4, hidden_size=3, batch_first=True),
+ "input_fn": lambda: rand(size=(3, 5, 4)),
+ },
+]
diff --git a/test/core/derivatives/permute_settings.py b/test/core/derivatives/permute_settings.py
new file mode 100644
index 000000000..e6ffd360d
--- /dev/null
+++ b/test/core/derivatives/permute_settings.py
@@ -0,0 +1,33 @@
+"""Test configurations for `backpack.core.derivatives` Permute.
+
+Required entries:
+ "module_fn" (callable): Contains a model constructed from `torch.nn` layers
+ "input_fn" (callable): Used for specifying input function
+
+Optional entries:
+ "target_fn" (callable): Fetches the groundtruth/target classes
+ of regression/classification task
+ "loss_function_fn" (callable): Loss function used in the model
+ "device" [list(torch.device)]: List of devices to run the test on.
+ "id_prefix" (str): Prefix to be included in the test name.
+ "seed" (int): seed for the random number generator
+"""
+
+import torch
+
+from backpack.custom_module.permute import Permute
+
+PERMUTE_SETTINGS = [
+ {
+ "module_fn": lambda: Permute(0, 1, 2),
+ "input_fn": lambda: torch.rand(size=(1, 2, 3)),
+ },
+ {
+ "module_fn": lambda: Permute(0, 2, 1),
+ "input_fn": lambda: torch.rand(size=(4, 3, 2)),
+ },
+ {
+ "module_fn": lambda: Permute(0, 3, 1, 2),
+ "input_fn": lambda: torch.rand(size=(5, 4, 3, 2)),
+ },
+]
diff --git a/test/core/derivatives/pooling_adaptive_settings.py b/test/core/derivatives/pooling_adaptive_settings.py
new file mode 100644
index 000000000..0222f4c4d
--- /dev/null
+++ b/test/core/derivatives/pooling_adaptive_settings.py
@@ -0,0 +1,36 @@
+"""Test configurations for `backpack.core.derivatives` for adaptive pooling layers.
+
+Required entries:
+ "module_fn" (callable): Contains a model constructed from `torch.nn` layers
+ "input_fn" (callable): Used for specifying input function
+
+Optional entries:
+ "target_fn" (callable): Fetches the groundtruth/target classes
+ of regression/classification task
+ "loss_function_fn" (callable): Loss function used in the model
+ "device" [list(torch.device)]: List of devices to run the test on.
+ "id_prefix" (str): Prefix to be included in the test name.
+ "seed" (int): seed for the random number generator
+"""
+
+import torch
+
+POOLING_ADAPTIVE_SETTINGS = []
+
+###############################################################################
+# test settings #
+###############################################################################
+POOLING_ADAPTIVE_SETTINGS += [
+ {
+ "module_fn": lambda: torch.nn.AdaptiveAvgPool1d(output_size=(3,)),
+ "input_fn": lambda: torch.rand(size=(1, 4, 9)),
+ },
+ {
+ "module_fn": lambda: torch.nn.AdaptiveAvgPool2d(output_size=(3, 5)),
+ "input_fn": lambda: torch.rand(size=(2, 3, 9, 20)),
+ },
+ {
+ "module_fn": lambda: torch.nn.AdaptiveAvgPool3d(output_size=(2, 2, 2)),
+ "input_fn": lambda: torch.rand(size=(1, 3, 4, 8, 8)),
+ },
+]
diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py
index eba54f394..fa67a5634 100644
--- a/test/core/derivatives/problem.py
+++ b/test/core/derivatives/problem.py
@@ -1,15 +1,15 @@
"""Convert problem settings."""
import copy
-from test.core.derivatives.utils import (
- derivative_cls_for,
- get_available_devices,
- is_loss,
-)
+from test.core.derivatives.utils import derivative_cls_for, get_available_devices
+from typing import Dict, List, Tuple
import torch
+from torch import Tensor, long
from backpack import extend
+from backpack.utils.module_classification import is_loss
+from backpack.utils.subsampling import subsample
def make_test_problems(settings):
@@ -132,26 +132,37 @@ def make_output_shape(self):
else:
output = module(input, target)
+ if isinstance(output, tuple):
+ # is true for RNN, GRU, LSTM, which return a tuple (output, ...)
+ output = output[0]
+
return output.shape
def is_loss(self):
return is_loss(self.make_module())
- def forward_pass(self, input_requires_grad=False, sample_idx=None):
+ def forward_pass(
+ self, input_requires_grad: bool = False, subsampling: List[int] = None
+ ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]:
"""Do a forward pass. Return input, output, and parameters."""
- if sample_idx is None:
- input = self.input.clone().detach()
- else:
- input = self.input.clone()[sample_idx, :].unsqueeze(0).detach()
+ input: Tensor = self.input.clone().detach()
- if input_requires_grad:
+ if subsampling is not None:
+ batch_axis = 0
+ input = subsample(input, dim=batch_axis, subsampling=subsampling)
+
+ if input_requires_grad and input.dtype is not long:
input.requires_grad = True
if self.is_loss():
- assert sample_idx is None
- output = self.module(input, self.target)
+ assert subsampling is None
+ output: Tensor = self.module(input, self.target)
else:
- output = self.module(input)
+ output: Tensor = self.module(input)
+
+ if isinstance(output, tuple):
+ # is true for RNN, GRU, LSTM, which return a tuple (output, ...)
+ output: Tensor = output[0]
return input, output, dict(self.module.named_parameters())
@@ -177,3 +188,11 @@ def has_weight(self):
def has_bias(self):
module = self.make_module()
return hasattr(module, "bias") and module.bias is not None
+
+ def get_batch_size(self) -> int:
+ """Return the mini-batch size.
+
+ Returns:
+ Mini-batch size.
+ """
+ return self.input.shape[0]
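
A minimal sketch of the sub-sampling used in ``forward_pass`` (assuming ``subsample`` selects the listed indices along the batch axis, as its use above suggests); it is equivalent to plain integer-array indexing:

from torch import allclose, rand

from backpack.utils.subsampling import subsample

x = rand(4, 5, 3)  # batch of 4 samples
active = [2, 0]    # duplicates and reordering are allowed
assert allclose(subsample(x, dim=0, subsampling=active), x[active])
assert allclose(subsample(x, subsampling=None), x)  # None keeps all samples
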
diff --git a/test/core/derivatives/rnn_settings.py b/test/core/derivatives/rnn_settings.py
new file mode 100644
index 000000000..6641cbb55
--- /dev/null
+++ b/test/core/derivatives/rnn_settings.py
@@ -0,0 +1,25 @@
+"""Test configurations for `backpack.core.derivatives` RNN layers.
+
+Required entries:
+ "module_fn" (callable): Contains a model constructed from `torch.nn` layers
+ "input_fn" (callable): Used for specifying input function
+
+Optional entries:
+ "target_fn" (callable): Fetches the groundtruth/target classes
+ of regression/classification task
+ "loss_function_fn" (callable): Loss function used in the model
+ "device" [list(torch.device)]: List of devices to run the test on.
+ "id_prefix" (str): Prefix to be included in the test name.
+ "seed" (int): seed for the random number generator
+"""
+
+import torch
+
+RNN_SETTINGS = [
+ {
+ "module_fn": lambda: torch.nn.RNN(
+ input_size=4, hidden_size=3, batch_first=True
+ ),
+ "input_fn": lambda: torch.rand(size=(3, 5, 4)),
+ },
+]
diff --git a/test/core/derivatives/scale_module_settings.py b/test/core/derivatives/scale_module_settings.py
new file mode 100644
index 000000000..3bc089ef9
--- /dev/null
+++ b/test/core/derivatives/scale_module_settings.py
@@ -0,0 +1,24 @@
+"""Test settings for ScaleModule derivatives."""
+from torch import rand
+from torch.nn import Identity
+
+from backpack.custom_module.scale_module import ScaleModule
+
+SCALE_MODULE_SETTINGS = [
+ {
+ "module_fn": lambda: ScaleModule(),
+ "input_fn": lambda: rand(3, 4, 2),
+ },
+ {
+ "module_fn": lambda: ScaleModule(0.3),
+ "input_fn": lambda: rand(3, 2),
+ },
+ {
+ "module_fn": lambda: ScaleModule(5.7),
+ "input_fn": lambda: rand(2, 3),
+ },
+ {
+ "module_fn": lambda: Identity(),
+ "input_fn": lambda: rand(3, 1, 2),
+ },
+]
diff --git a/test/core/derivatives/settings.py b/test/core/derivatives/settings.py
index 0e1c33024..7379f5af7 100644
--- a/test/core/derivatives/settings.py
+++ b/test/core/derivatives/settings.py
@@ -15,6 +15,7 @@
from test.core.derivatives.linear_settings import LINEAR_SETTINGS
from test.core.derivatives.loss_settings import LOSS_SETTINGS
from test.core.derivatives.padding_settings import PADDING_SETTINGS
+from test.core.derivatives.pooling_adaptive_settings import POOLING_ADAPTIVE_SETTINGS
from test.core.derivatives.pooling_settings import POOLING_SETTINGS
SETTINGS = (
@@ -24,4 +25,5 @@
+ LOSS_SETTINGS
+ PADDING_SETTINGS
+ POOLING_SETTINGS
+ + POOLING_ADAPTIVE_SETTINGS
)
diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py
index cfdca8433..21bfe1834 100644
--- a/test/core/derivatives/utils.py
+++ b/test/core/derivatives/utils.py
@@ -1,8 +1,12 @@
-"""Utility functions to test `backpack.core.derivatives`"""
+"""Utility functions to test `backpack.core.derivatives`."""
+from test.core.derivatives import derivatives_for
+from typing import Tuple, Type
import torch
+from torch import Tensor
+from torch.nn import Module
-from backpack.core.derivatives import derivatives_for
+from backpack.core.derivatives.basederivatives import BaseDerivatives
def get_available_devices():
@@ -19,42 +23,47 @@ def get_available_devices():
return devices
-def derivative_cls_for(module_cls):
+def derivative_cls_for(module_cls: Type[Module]) -> Type[BaseDerivatives]:
"""Return the associated derivative class for a module.
Args:
- module_cls (torch.nn.Module): Layer class.
+ module_cls: Layer class.
Returns:
- backpack.core.derivatives.Derivatives: Class implementing the
- derivatives for `module_cls`.
+ Class implementing the derivatives for `module_cls`.
+
+ Raises:
+ KeyError: if derivative for module is missing
"""
try:
return derivatives_for[module_cls]
- except KeyError:
+ except KeyError as e:
raise KeyError(
- "No derivative available for {}".format(module_cls)
- + "Known mappings:\n{}".format(derivatives_for)
- )
+ f"No derivative available for {module_cls}. "
+ + f"Known mappings:\n{derivatives_for}"
+ ) from e
-def is_loss(module):
- """Return whether `module` is a `torch` loss function.
+def classification_targets(size: Tuple[int, ...], num_classes: int) -> Tensor:
+ """Create random targets for classes 0, ..., `num_classes - 1`.
Args:
- module (torch.nn.Module): A PyTorch module.
+ size: shape of targets
+ num_classes: number of classes
Returns:
- bool: Whether `module` is a loss function.
+ classification targets
"""
- return isinstance(module, torch.nn.modules.loss._Loss)
+ return torch.randint(size=size, low=0, high=num_classes)
-def classification_targets(size, num_classes):
- """Create random targets for classes 0, ..., `num_classes - 1`."""
- return torch.randint(size=size, low=0, high=num_classes)
+def regression_targets(size: Tuple[int, ...]) -> Tensor:
+ """Create random targets for regression.
+ Args:
+ size: shape of targets
-def regression_targets(size):
- """Create random targets for regression."""
+ Returns:
+ regression targets
+ """
return torch.rand(size=size)
diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py
index 58fbec8ac..f2334c515 100644
--- a/test/extensions/automated_settings.py
+++ b/test/extensions/automated_settings.py
@@ -1,35 +1,45 @@
+"""Contains helpers to create CNN test cases."""
from test.core.derivatives.utils import classification_targets
+from typing import Any, Tuple, Type
-import torch
+from torch import Tensor, rand
+from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, Module, ReLU, Sequential
-###
-# Helpers
-###
+def set_requires_grad(model: Module, new_requires_grad: bool) -> None:
+ """Set the ``requires_grad`` attribute of the model parameters.
-def make_simple_act_setting(act_cls, bias):
+ Args:
+ model: Network or layer.
+ new_requires_grad: New value for ``requires_grad``.
"""
- input: Activation function & Bias setting
- return: simple CNN Network
+ for p in model.parameters():
+ p.requires_grad = new_requires_grad
- This function is used to automatically create a
- simple CNN Network consisting of CNN & Linear layer
- for different activation functions.
- It is used to test `test.extensions`.
+
+def make_simple_act_setting(act_cls: Type[Module], bias: bool) -> dict:
+ """Create a simple CNN with activation as test case dictionary.
+
+ Make parameters of final linear layer non-differentiable to save run time.
+
+ Args:
+ act_cls: Class of the activation function.
+ bias: Use bias in the convolution.
+
+ Returns:
+ Dictionary representation of the simple CNN test case.
"""
- def make_simple_cnn(act_cls, bias):
- return torch.nn.Sequential(
- torch.nn.Conv2d(3, 2, 2, bias=bias),
- act_cls(),
- torch.nn.Flatten(),
- torch.nn.Linear(72, 5),
- )
+ def _make_simple_cnn(act_cls: Type[Module], bias: bool) -> Sequential:
+ linear = Linear(72, 5)
+ set_requires_grad(linear, False)
+
+ return Sequential(Conv2d(3, 2, 2, bias=bias), act_cls(), Flatten(), linear)
dict_setting = {
- "input_fn": lambda: torch.rand(3, 3, 7, 7),
- "module_fn": lambda: make_simple_cnn(act_cls, bias),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(),
+ "input_fn": lambda: rand(3, 3, 7, 7),
+ "module_fn": lambda: _make_simple_cnn(act_cls, bias),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn-act",
}
@@ -37,40 +47,37 @@ def make_simple_cnn(act_cls, bias):
return dict_setting
-def make_simple_cnn_setting(input_size, conv_class, conv_params):
- """
- input_size: tuple of input size of (N*C*Image Size)
- conv_class: convolutional class
- conv_params: configurations for convolutional class
- return: simple CNN Network
-
- This function is used to automatically create a
- simple CNN Network consisting of CNN & Linear layer
- for different convolutional layers.
- It is used to test `test.extensions`.
+def make_simple_cnn_setting(
+ input_size: Tuple[int], conv_cls: Type[Module], conv_params: Tuple[Any]
+) -> dict:
+ """Create ReLU CNN with convolution hyperparameters as test case dictionary.
+
+ Make parameters of final linear layer non-differentiable to save run time.
+
+ Args:
+ input_size: Input shape ``[N, C_in, ...]``.
+ conv_cls: Class of convolution layer.
+ conv_params: Convolution hyperparameters.
+
+ Returns:
+ Dictionary representation of the test case.
"""
- def make_cnn(conv_class, output_size, conv_params):
- """Note: output class size is assumed to be 5"""
- return torch.nn.Sequential(
- conv_class(*conv_params),
- torch.nn.ReLU(),
- torch.nn.Flatten(),
- torch.nn.Linear(output_size, 5),
- )
+ def _make_cnn(
+ conv_cls: Type[Module], output_dim: int, conv_params: Tuple
+ ) -> Sequential:
+ linear = Linear(output_dim, 5)
+ set_requires_grad(linear, False)
- def get_output_shape(module, module_params, input):
- """Returns the output shape for a given layer."""
- output = module(*module_params)(input)
- return output.numel() // output.shape[0]
+ return Sequential(conv_cls(*conv_params), ReLU(), Flatten(), linear)
- input = torch.rand(input_size)
- output_size = get_output_shape(conv_class, conv_params, input)
+ input = rand(input_size)
+ output_dim = _get_output_dim(conv_cls(*conv_params), input)
dict_setting = {
- "input_fn": lambda: torch.rand(input_size),
- "module_fn": lambda: make_cnn(conv_class, output_size, conv_params),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
+ "input_fn": lambda: rand(input_size),
+ "module_fn": lambda: _make_cnn(conv_cls, output_dim, conv_params),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn",
}
@@ -78,49 +85,59 @@ def get_output_shape(module, module_params, input):
return dict_setting
-def make_simple_pooling_setting(input_size, conv_class, pool_cls, pool_params):
- """
- input_size: tuple of input size of (N*C*Image Size)
- conv_class: convolutional class
- conv_params: configurations for convolutional class
- return: simple CNN Network
-
- This function is used to automatically create a
- simple CNN Network consisting of CNN & Linear layer
- for different convolutional layers.
- It is used to test `test.extensions`.
+def make_simple_pooling_setting(
+ input_size: Tuple[int],
+ conv_cls: Type[Module],
+ pool_cls: Type[Module],
+ pool_params: Tuple[Any],
+) -> dict:
+ """Create CNN with convolution and pooling layer as test case dictionary.
+
+ Make parameters of final linear layer non-differentiable to save run time.
+
+ Args:
+ input_size: Input shape ``[N, C_in, ...]``.
+ conv_cls: Class of convolution layer.
+ pool_cls: Class of pooling layer.
+ pool_params: Pooling hyperparameters.
+
+ Returns:
+ Dictionary representation of the test case.
"""
- def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params):
- """Note: output class size is assumed to be 5"""
- return torch.nn.Sequential(
- conv_class(*conv_params),
- torch.nn.ReLU(),
- pool_cls(*pool_params),
- torch.nn.Flatten(),
- torch.nn.Linear(output_size, 5),
+ def _make_cnn(
+ conv_cls: Type[Module],
+ output_size: int,
+ conv_params: Tuple[Any],
+ pool_cls: Type[Module],
+ pool_params: Tuple[Any],
+ ) -> Sequential:
+ linear = Linear(output_size, 5)
+ set_requires_grad(linear, False)
+
+ return Sequential(
+ conv_cls(*conv_params), ReLU(), pool_cls(*pool_params), Flatten(), linear
)
- def get_output_shape(module, module_params, input, pool, pool_params):
- """Returns the output shape for a given layer."""
- output_1 = module(*module_params)(input)
- output = pool_cls(*pool_params)(output_1)
- return output.numel() // output.shape[0]
-
conv_params = (3, 2, 2)
- input = torch.rand(input_size)
- output_size = get_output_shape(
- conv_class, conv_params, input, pool_cls, pool_params
+ input = rand(input_size)
+ output_dim = _get_output_dim(
+ Sequential(conv_cls(*conv_params), pool_cls(*pool_params)), input
)
dict_setting = {
- "input_fn": lambda: torch.rand(input_size),
- "module_fn": lambda: make_cnn(
- conv_class, output_size, conv_params, pool_cls, pool_params
+ "input_fn": lambda: rand(input_size),
+ "module_fn": lambda: _make_cnn(
+ conv_cls, output_dim, conv_params, pool_cls, pool_params
),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn",
}
return dict_setting
+
+
+def _get_output_dim(module: Module, input: Tensor) -> int:
+ output = module(input)
+ return output.numel() // output.shape[0]
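
``_get_output_dim`` returns the flattened per-sample feature count of a layer's output, which is then used as ``in_features`` of the final linear layer. A quick sketch (the numbers match the ``Conv2d(3, 2, 2)`` case above):

from torch import rand
from torch.nn import Conv2d

out = Conv2d(3, 2, 2)(rand(3, 3, 7, 7))  # output shape [3, 2, 6, 6]
print(out.numel() // out.shape[0])       # 72 -> Linear(72, 5)
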
diff --git a/test/extensions/firstorder/batch_grad/batch_grad_settings.py b/test/extensions/firstorder/batch_grad/batch_grad_settings.py
new file mode 100644
index 000000000..7b1926d63
--- /dev/null
+++ b/test/extensions/firstorder/batch_grad/batch_grad_settings.py
@@ -0,0 +1,11 @@
+"""Test cases for BackPACK's ``BatchGrad`` extension.
+
+The tests are taken from ``test.extensions.firstorder.firstorder_settings``,
+but additional custom tests can be defined here by appending them to the list.
+"""
+from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS
+
+SHARED_SETTINGS = FIRSTORDER_SETTINGS
+LOCAL_SETTINGS = []
+
+BATCH_GRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS
diff --git a/test/extensions/firstorder/batch_grad/batchgrad_settings.py b/test/extensions/firstorder/batch_grad/batchgrad_settings.py
deleted file mode 100644
index 862050ac7..000000000
--- a/test/extensions/firstorder/batch_grad/batchgrad_settings.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""Test configurations to test batch_grad
-
-The tests are taken from `test.extensions.firstorder.firstorder_settings`,
-but additional custom tests can be defined here by appending it to the list.
-"""
-
-from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS
-
-BATCHGRAD_SETTINGS = []
-
-SHARED_SETTINGS = FIRSTORDER_SETTINGS
-LOCAL_SETTING = []
-
-BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTING
diff --git a/test/extensions/firstorder/batch_grad/test_batch_grad.py b/test/extensions/firstorder/batch_grad/test_batch_grad.py
new file mode 100644
index 000000000..7c916568e
--- /dev/null
+++ b/test/extensions/firstorder/batch_grad/test_batch_grad.py
@@ -0,0 +1,51 @@
+"""Test BackPACK's ``BatchGrad`` extension."""
+from test.automated_test import check_sizes_and_values
+from test.extensions.firstorder.batch_grad.batch_grad_settings import (
+ BATCH_GRAD_SETTINGS,
+)
+from test.extensions.implementation.autograd import AutogradExtensions
+from test.extensions.implementation.backpack import BackpackExtensions
+from test.extensions.problem import ExtensionsTestProblem, make_test_problems
+from test.extensions.utils import skip_if_subsampling_conflict
+from typing import List, Union
+
+from pytest import fixture, mark
+
+PROBLEMS = make_test_problems(BATCH_GRAD_SETTINGS)
+
+SUBSAMPLINGS = [None, [0, 0], [2, 0]]
+SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS]
+
+
+@fixture(params=PROBLEMS, ids=lambda p: p.make_id())
+def problem(request) -> ExtensionsTestProblem:
+ """Set up and tear down a test case.
+
+ Args:
+ request: Pytest request.
+
+ Yields:
+ Instantiated test case.
+ """
+ problem = request.param
+ problem.set_up()
+ yield problem
+ problem.tear_down()
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+def test_batch_grad(
+ problem: ExtensionsTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Test individual gradients.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ skip_if_subsampling_conflict(problem, subsampling)
+
+ backpack_res = BackpackExtensions(problem).batch_grad(subsampling)
+ autograd_res = AutogradExtensions(problem).batch_grad(subsampling)
+
+ check_sizes_and_values(autograd_res, backpack_res)
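
For reference, the quantity tested above, sketched with plain autograd on a toy problem (an illustration, not the reference implementation used by the tests): individual gradients of the selected samples, rescaled by 1/N to match the mean reduction of the mini-batch loss.

from torch import rand, stack
from torch.autograd import grad
from torch.nn import Linear, MSELoss

model, loss_func = Linear(4, 2), MSELoss(reduction="mean")
X, y = rand(3, 4), rand(3, 2)

subsampling = [2, 0]
batch_grad = stack(
    [
        grad(loss_func(model(X[[n]]), y[[n]]), model.weight)[0] / X.shape[0]
        for n in subsampling
    ]
)
print(batch_grad.shape)  # torch.Size([2, 2, 4]): one weight gradient per active sample
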
diff --git a/test/extensions/firstorder/batch_grad/test_batchgrad.py b/test/extensions/firstorder/batch_grad/test_batchgrad.py
deleted file mode 100644
index c93ae69df..000000000
--- a/test/extensions/firstorder/batch_grad/test_batchgrad.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Test class for module Batch_grad (batch gradients)
-from `backpack.core.extensions.firstorder`
-
-Test individual gradients for the following layers:
-- batch gradients of linear layers
-- batch gradients of convolutional layers
-
-"""
-from test.automated_test import check_sizes_and_values
-from test.extensions.firstorder.batch_grad.batchgrad_settings import BATCHGRAD_SETTINGS
-from test.extensions.implementation.autograd import AutogradExtensions
-from test.extensions.implementation.backpack import BackpackExtensions
-from test.extensions.problem import make_test_problems
-
-import pytest
-
-PROBLEMS = make_test_problems(BATCHGRAD_SETTINGS)
-IDS = [problem.make_id() for problem in PROBLEMS]
-
-
-@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
-def test_batch_grad(problem):
- """Test individual gradients
-
- Args:
- problem (ExtensionsTestProblem): Problem for extension test.
- """
- problem.set_up()
-
- backpack_res = BackpackExtensions(problem).batch_grad()
- autograd_res = AutogradExtensions(problem).batch_grad()
-
- check_sizes_and_values(autograd_res, backpack_res)
- problem.tear_down()
diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py
index 0385e3c4e..4fc1f1a8e 100644
--- a/test/extensions/firstorder/firstorder_settings.py
+++ b/test/extensions/firstorder/firstorder_settings.py
@@ -1,37 +1,54 @@
-"""Test configurations for `backpack.core.extensions.firstorder`
-that is shared among the following firstorder methods:
-- batch_grad
-- batch_l2_grad
-- sum_grad_sqaured
-- variance
+"""Shared test cases for BackPACK's first-order extensions.
+Shared by the tests of:
+- ``BatchGrad``
+- ``BatchL2Grad``
+- ``SumGradSquared``
+- ``Variance``
Required entries:
"module_fn" (callable): Contains a model constructed from `torch.nn` layers
"input_fn" (callable): Used for specifying input function
- "target_fn" (callable): Fetches the groundtruth/target classes
+ "target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
Optional entries:
- "device" [list(torch.device)]: List of devices to run the test on.
+ "device" [list(device)]: List of devices to run the test on.
"id_prefix" (str): Prefix to be included in the test name.
- "seed" (int): seed for the random number for torch.rand
+ "seed" (int): seed set before initializing a case.
"""
-
-
from test.core.derivatives.utils import classification_targets, regression_targets
from test.extensions.automated_settings import make_simple_cnn_setting
+from test.utils.evaluation_mode import initialize_training_false_recursive
-import torch
+from torch import device, rand, randint
from torch.nn import (
+ LSTM,
+ RNN,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ CrossEntropyLoss,
+ Embedding,
+ Flatten,
+ Linear,
+ MSELoss,
+ ReLU,
+ Sequential,
+ Sigmoid,
)
+from torchvision.models import resnet18
+
+from backpack import convert_module_to_backpack
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.reduce_tuple import ReduceTuple
FIRSTORDER_SETTINGS = []
@@ -40,11 +57,11 @@
###############################################################################
example = {
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
- "device": [torch.device("cpu")],
+ "device": [device("cpu")],
"seed": 0,
"id_prefix": "example",
}
@@ -57,47 +74,91 @@
FIRSTORDER_SETTINGS += [
# classification
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
- # Regression
+ # regression
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 5)),
},
]
+# linear with additional dimension
+FIRSTORDER_SETTINGS += [
+ # regression
+ {
+ "input_fn": lambda: rand(3, 4, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
+ "target_fn": lambda: regression_targets((3, 4, 2)),
+ "id_prefix": "one-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Sigmoid(), Linear(3, 2)),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
+ "target_fn": lambda: regression_targets((3, 4, 2, 2)),
+ "id_prefix": "two-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 3, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)),
+ "loss_function_fn": lambda: MSELoss(reduction="sum"),
+ "target_fn": lambda: regression_targets((3, 4, 2, 3, 2)),
+ "id_prefix": "three-additional",
+ },
+ # classification
+ {
+ "input_fn": lambda: rand(3, 4, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 8),
+ "id_prefix": "one-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 5),
+ "module_fn": lambda: Sequential(
+ Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten()
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 16),
+ "id_prefix": "two-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 3, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
+ "target_fn": lambda: classification_targets((3,), 48),
+ "id_prefix": "three-additional",
+ },
+]
+
###############################################################################
# test setting: Convolutional Layers #
"""
-Syntax with default parameters:
- - `torch.nn.ConvNd(in_channels, out_channels,
- kernel_size, stride=1, padding=0, dilation=1,
- groups=1, bias=True, padding_mode='zeros)`
+Syntax with default parameters:
+ - `torch.nn.ConvNd(in_channels, out_channels,
+ kernel_size, stride=1, padding=0, dilation=1,
+ groups=1, bias=True, padding_mode='zeros)`
- - `torch.nn.ConvTransposeNd(in_channels, out_channels,
- kernel_size, stride=1, padding=0, output_padding=0,
+ - `torch.nn.ConvTransposeNd(in_channels, out_channels,
+ kernel_size, stride=1, padding=0, output_padding=0,
groups=1, bias=True, dilation=1, padding_mode='zeros)`
-Note: There are 5 tests added to each `torch.nn.layers`.
+Note: There are 5 tests added to each `torch.nn.layers`.
For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d`
-only 3 tests are added because they are very memory intensive.
+only 3 tests are added because they are very memory intensive.
"""
###############################################################################
@@ -152,3 +213,121 @@
(3, 3, 2, 7, 7), ConvTranspose3d, (3, 2, 2, 4, 2, 0, 1, False)
),
]
+
+###############################################################################
+# test setting: BatchNorm #
+###############################################################################
+FIRSTORDER_SETTINGS += [
+ {
+ "input_fn": lambda: rand(2, 3, 4),
+ "module_fn": lambda: initialize_training_false_recursive(
+ BatchNorm1d(num_features=3)
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((2, 4), 3),
+ },
+ {
+ "input_fn": lambda: rand(3, 2, 4, 3),
+ "module_fn": lambda: initialize_training_false_recursive(
+ BatchNorm2d(num_features=2)
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3, 4, 3), 2),
+ },
+ {
+ "input_fn": lambda: rand(3, 3, 4, 1, 2),
+ "module_fn": lambda: initialize_training_false_recursive(
+ BatchNorm3d(num_features=3)
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3, 4, 1, 2), 3),
+ },
+ {
+ "input_fn": lambda: rand(3, 3, 4, 1, 2),
+ "module_fn": lambda: initialize_training_false_recursive(
+ Sequential(
+ BatchNorm3d(num_features=3),
+ Linear(2, 3),
+ BatchNorm3d(num_features=3),
+ ReLU(),
+ BatchNorm3d(num_features=3),
+ )
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3, 4, 1, 3), 3),
+ },
+]
+###############################################################################
+# test setting: RNN Layers #
+###############################################################################
+FIRSTORDER_SETTINGS += [
+ {
+ "input_fn": lambda: rand(8, 5, 6),
+ "module_fn": lambda: Sequential(
+ RNN(input_size=6, hidden_size=3, batch_first=True),
+ ReduceTuple(index=0),
+ Permute(0, 2, 1),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((8, 5), 3),
+ },
+ {
+ "input_fn": lambda: rand(8, 5, 6),
+ "module_fn": lambda: Sequential(
+ RNN(input_size=6, hidden_size=3, batch_first=True),
+ ReduceTuple(index=0),
+ Permute(0, 2, 1),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((8, 3 * 5)),
+ },
+ {
+ "input_fn": lambda: rand(4, 5, 3),
+ "module_fn": lambda: Sequential(
+ LSTM(3, 4, batch_first=True),
+ ReduceTuple(index=0),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((4,), 20),
+ },
+]
+###############################################################################
+# test setting: Embedding #
+###############################################################################
+FIRSTORDER_SETTINGS += [
+ {
+ "input_fn": lambda: randint(0, 5, (6,)),
+ "module_fn": lambda: Sequential(
+ Embedding(5, 3),
+ Linear(3, 4),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((6,), 4),
+ },
+ {
+ "input_fn": lambda: randint(0, 3, (4, 2, 2)),
+ "module_fn": lambda: Sequential(
+ Embedding(3, 5),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((4,), 2 * 5),
+ },
+]
+
+###############################################################################
+# test setting: torchvision resnet #
+###############################################################################
+FIRSTORDER_SETTINGS += [
+ {
+ "input_fn": lambda: rand(2, 3, 7, 7),
+ "module_fn": lambda: convert_module_to_backpack(
+ resnet18(num_classes=4).eval(), True
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((2, 4)),
+ "id_prefix": "resnet18",
+ },
+]
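
On the recurrent settings above (a sketch, not part of the patch; it assumes ``ReduceTuple(index=0)`` selects the RNN's output tensor and ``Permute`` wraps ``Tensor.permute``, as their use suggests): the permutation moves the hidden dimension to position 1, where ``CrossEntropyLoss`` expects the class scores.

from torch import rand
from torch.nn import RNN

rnn = RNN(input_size=6, hidden_size=3, batch_first=True)
output, _ = rnn(rand(8, 5, 6))        # ReduceTuple(index=0) keeps `output`
print(output.shape)                   # torch.Size([8, 5, 3])
print(output.permute(0, 2, 1).shape)  # torch.Size([8, 3, 5]): [N, classes, seq]
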
diff --git a/test/extensions/firstorder/variance/__init__.py b/test/extensions/firstorder/variance/__init__.py
index e69de29bb..7ca2b624c 100644
--- a/test/extensions/firstorder/variance/__init__.py
+++ b/test/extensions/firstorder/variance/__init__.py
@@ -0,0 +1 @@
+"""Contains tests for BackPACK's ``Variance`` extension."""
diff --git a/test/extensions/firstorder/variance/test_variance.py b/test/extensions/firstorder/variance/test_variance.py
index 50cf21675..8680c2086 100644
--- a/test/extensions/firstorder/variance/test_variance.py
+++ b/test/extensions/firstorder/variance/test_variance.py
@@ -1,16 +1,9 @@
-"""Test class for module variance
-from `backpack.core.extensions.firstorder`
-
-Test variances for the following layers:
-- variance of linear layers
-- variance of convolutional layers
-
-"""
+"""Test BackPACK's ``Variance`` extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
-from test.extensions.problem import make_test_problems
+from test.extensions.problem import ExtensionsTestProblem, make_test_problems
import pytest
@@ -19,16 +12,17 @@
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
-def test_variance(problem):
- """Test variance of individual gradients
+def test_variance(problem: ExtensionsTestProblem) -> None:
+ """Test variance of individual gradients.
Args:
- problem (ExtensionsTestProblem): Problem for extension test.
+ problem: Test case.
"""
problem.set_up()
backpack_res = BackpackExtensions(problem).variance()
autograd_res = AutogradExtensions(problem).variance()
- check_sizes_and_values(autograd_res, backpack_res)
+ rtol = 5e-5
+ check_sizes_and_values(autograd_res, backpack_res, rtol=rtol)
problem.tear_down()
diff --git a/test/extensions/firstorder/variance/variance_settings.py b/test/extensions/firstorder/variance/variance_settings.py
index 61c39d0d6..c8a8de2da 100644
--- a/test/extensions/firstorder/variance/variance_settings.py
+++ b/test/extensions/firstorder/variance/variance_settings.py
@@ -1,7 +1,7 @@
-"""Test configurations to test variance
+"""Test cases for ``Variance`` extension.
-The tests are taken from `test.extensions.firstorder.firstorder_settings`,
-but additional custom tests can be defined here by appending it to the list.
+Uses shared test cases from `test.extensions.firstorder.firstorder_settings`,
+and the local cases defined in this file.
"""
from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS
diff --git a/test/extensions/graph_clear_test.py b/test/extensions/graph_clear_test.py
new file mode 100644
index 000000000..17b6419ee
--- /dev/null
+++ b/test/extensions/graph_clear_test.py
@@ -0,0 +1,61 @@
+"""Test whether the graph is clear after a backward pass."""
+from typing import Tuple
+
+from pytest import fixture
+from torch import Tensor, rand, rand_like
+from torch.nn import Flatten, Linear, Module, MSELoss, ReLU, Sequential
+
+from backpack import backpack, extend
+from backpack.extensions import DiagGGNExact
+
+PROBLEM_STRING = ["standard", "flatten_no_op", "flatten_with_op"]
+
+
+def test_graph_clear(problem) -> None:
+ """Test that the graph is clear after a backward pass.
+
+ More specifically, test that there are no saved quantities left over.
+
+ Args:
+ problem: problem consisting of inputs and model
+ """
+ inputs, model = problem
+ extension = DiagGGNExact()
+ outputs = extend(model)(inputs)
+ loss = extend(MSELoss())(outputs, rand_like(outputs))
+ with backpack(extension):
+ loss.backward()
+
+ # test that the dictionary is empty
+ saved_quantities: dict = extension.saved_quantities._saved_quantities
+ assert type(saved_quantities) is dict
+ assert not saved_quantities
+
+
+@fixture(params=PROBLEM_STRING, ids=PROBLEM_STRING)
+def problem(request) -> Tuple[Tensor, Module]:
+ """Problem setting.
+
+ Args:
+ request: pytest request, contains parameters
+
+ Yields:
+ inputs and model
+
+ Raises:
+ NotImplementedError: if problem string is unknown
+ """
+ batch_size, in_dim, out_dim = 2, 3, 4
+ inputs = rand(batch_size, in_dim)
+ if request.param == PROBLEM_STRING[0]:
+ model = Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim))
+ elif request.param == PROBLEM_STRING[1]:
+ model = Sequential(Linear(in_dim, out_dim), Flatten(), Linear(out_dim, out_dim))
+ elif request.param == PROBLEM_STRING[2]:
+ inputs = rand(batch_size, in_dim, in_dim)
+ model = Sequential(
+ Linear(in_dim, out_dim), Flatten(), Linear(in_dim * out_dim, out_dim)
+ )
+ else:
+ raise NotImplementedError(f"unknown request.param={request.param}")
+ yield inputs, model
diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py
index adedb7cf0..ac9f94f11 100644
--- a/test/extensions/implementation/autograd.py
+++ b/test/extensions/implementation/autograd.py
@@ -1,8 +1,12 @@
+"""Autograd implementation of BackPACK's extensions."""
+from math import isclose
from test.extensions.implementation.base import ExtensionsImplementation
+from typing import Iterator, List, Union
-import torch
+from torch import Tensor, autograd, backends, cat, stack, var, zeros, zeros_like
+from torch.nn.utils.convert_parameters import parameters_to_vector
-from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
+from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.hessianfree.rop import R_op
from backpack.utils.convert_parameters import vector_to_parameter_list
@@ -10,126 +14,102 @@
class AutogradExtensions(ExtensionsImplementation):
"""Extension implementations with autograd."""
- def batch_grad(self):
- N = self.problem.input.shape[0]
- batch_grads = [
- torch.zeros(N, *p.size()).to(self.problem.device)
- for p in self.problem.model.parameters()
- ]
+ def batch_grad(
+ self, subsampling: Union[List[int], None]
+ ) -> List[Tensor]: # noqa: D102
+ N = self.problem.get_batch_size()
+ samples = list(range(N)) if subsampling is None else subsampling
- loss_list = torch.zeros((N))
gradients_list = []
for b in range(N):
- _, _, loss = self.problem.forward_pass(sample_idx=b)
- gradients = torch.autograd.grad(loss, self.problem.model.parameters())
+ _, _, loss = self.problem.forward_pass(subsampling=[b])
+ gradients = autograd.grad(loss, self.problem.trainable_parameters())
gradients_list.append(gradients)
- loss_list[b] = loss
- _, _, batch_loss = self.problem.forward_pass()
- factor = self.problem.get_reduction_factor(batch_loss, loss_list)
+ batch_grads = [
+ zeros(len(samples), *p.size()).to(self.problem.device)
+ for p in self.problem.trainable_parameters()
+ ]
+ factor = self.problem.compute_reduction_factor()
- for b, gradients in zip(range(N), gradients_list):
- for idx, g in enumerate(gradients):
- batch_grads[idx][b, :] = g.detach() * factor
+ for out_idx, sample in enumerate(samples):
+ for param_idx, sample_g in enumerate(gradients_list[sample]):
+ batch_grads[param_idx][out_idx, :] = sample_g.detach() * factor
return batch_grads
- def batch_l2_grad(self):
- batch_grad = self.batch_grad()
- batch_l2_grads = [(g ** 2).flatten(start_dim=1).sum(1) for g in batch_grad]
- return batch_l2_grads
-
- def sgs(self):
- N = self.problem.input.shape[0]
- sgs = [
- torch.zeros(*p.size()).to(self.problem.device)
- for p in self.problem.model.parameters()
+ def batch_l2_grad(self) -> List[Tensor]: # noqa: D102
+ return [
+ (g ** 2).flatten(start_dim=1).sum(1)
+ for g in self.batch_grad(subsampling=None)
]
- loss_list = torch.zeros((N))
- gradients_list = []
- for b in range(N):
- _, _, loss = self.problem.forward_pass(sample_idx=b)
- gradients = torch.autograd.grad(loss, self.problem.model.parameters())
- loss_list[b] = loss
- gradients_list.append(gradients)
-
- _, _, batch_loss = self.problem.forward_pass()
- factor = self.problem.get_reduction_factor(batch_loss, loss_list)
-
- for _, gradients in zip(range(N), gradients_list):
- for idx, g in enumerate(gradients):
- sgs[idx] += (g.detach() * factor) ** 2
- return sgs
+ def sgs(self) -> List[Tensor]: # noqa: D102
+ return [(g ** 2).sum(0) for g in self.batch_grad(subsampling=None)]
- def variance(self):
- batch_grad = self.batch_grad()
- variances = [torch.var(g, dim=0, unbiased=False) for g in batch_grad]
- return variances
-
- def _get_diag_ggn(self, loss, output):
- def extract_ith_element_of_diag_ggn(i, p, loss, output):
- v = torch.zeros(p.numel()).to(self.problem.device)
- v[i] = 1.0
- vs = vector_to_parameter_list(v, [p])
- GGN_vs = ggn_vector_product_from_plist(loss, output, [p], vs)
- GGN_v = torch.cat([g.detach().view(-1) for g in GGN_vs])
- return GGN_v[i]
-
- diag_ggns = []
- for p in list(self.problem.model.parameters()):
- diag_ggn_p = torch.zeros_like(p).view(-1)
-
- for parameter_index in range(p.numel()):
- diag_value = extract_ith_element_of_diag_ggn(
- parameter_index, p, loss, output
- )
- diag_ggn_p[parameter_index] = diag_value
-
- diag_ggns.append(diag_ggn_p.view(p.size()))
- return diag_ggns
-
- def diag_ggn(self):
- _, output, loss = self.problem.forward_pass()
- return self._get_diag_ggn(loss, output)
-
- def diag_ggn_batch(self):
- batch_size = self.problem.input.shape[0]
- _, _, batch_loss = self.problem.forward_pass()
- loss_list = torch.zeros(batch_size, device=self.problem.device)
+ def variance(self) -> List[Tensor]: # noqa: D102
+ return [
+ var(g, dim=0, unbiased=False) for g in self.batch_grad(subsampling=None)
+ ]
+ def _get_diag_ggn(self, loss: Tensor, output: Tensor) -> List[Tensor]:
+ diag_ggn_flat = cat(
+ [col[[i]] for i, col in enumerate(self._ggn_columns(loss, output))]
+ )
+ return vector_to_parameter_list(
+ diag_ggn_flat, list(self.problem.trainable_parameters())
+ )
+
+ def diag_ggn(self) -> List[Tensor]: # noqa: D102
+ try:
+ _, output, loss = self.problem.forward_pass()
+ return self._get_diag_ggn(loss, output)
+ except RuntimeError:
+ # torch does not implement cuda double-backwards pass on RNNs and
+ # recommends this workaround
+ with backends.cudnn.flags(enabled=False):
+ _, output, loss = self.problem.forward_pass()
+ return self._get_diag_ggn(loss, output)
+
+ def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa: D102
+ try:
+ return self._diag_ggn_exact_batch()
+ except RuntimeError:
+ # torch does not implement cuda double-backwards pass on RNNs and
+ # recommends this workaround
+ with backends.cudnn.flags(enabled=False):
+ return self._diag_ggn_exact_batch()
+
+ def _diag_ggn_exact_batch(self):
# batch_diag_ggn has entries [sample_idx][param_idx]
batch_diag_ggn = []
- for b in range(batch_size):
- _, output, loss = self.problem.forward_pass(sample_idx=b)
+ for b in range(self.problem.get_batch_size()):
+ _, output, loss = self.problem.forward_pass(subsampling=[b])
diag_ggn = self._get_diag_ggn(loss, output)
batch_diag_ggn.append(diag_ggn)
- loss_list[b] = loss
- factor = self.problem.get_reduction_factor(batch_loss, loss_list)
+
+ factor = self.problem.compute_reduction_factor()
+
# params_batch_diag_ggn has entries [param_idx][sample_idx]
params_batch_diag_ggn = list(zip(*batch_diag_ggn))
- return [torch.stack(param) * factor for param in params_batch_diag_ggn]
+ return [stack(param) * factor for param in params_batch_diag_ggn]
def _get_diag_h(self, loss):
- def hvp(df_dx, x, v):
- Hv = R_op(df_dx, x, v)
- return [j.detach() for j in Hv]
-
def extract_ith_element_of_diag_h(i, p, df_dx):
- v = torch.zeros(p.numel()).to(self.problem.device)
+ v = zeros_like(p).flatten()
v[i] = 1.0
vs = vector_to_parameter_list(v, [p])
- Hvs = hvp(df_dx, [p], vs)
- Hv = torch.cat([g.detach().view(-1) for g in Hvs])
+ Hvs = R_op(df_dx, [p], vs)
+ Hv = cat([g.flatten() for g in Hvs])
return Hv[i]
diag_hs = []
- for p in list(self.problem.model.parameters()):
- diag_h_p = torch.zeros_like(p).view(-1)
+ for p in list(self.problem.trainable_parameters()):
+ diag_h_p = zeros_like(p).flatten()
- df_dx = torch.autograd.grad(loss, [p], create_graph=True, retain_graph=True)
+ df_dx = autograd.grad(loss, [p], create_graph=True, retain_graph=True)
for parameter_index in range(p.numel()):
diag_value = extract_ith_element_of_diag_h(parameter_index, p, df_dx)
diag_h_p[parameter_index] = diag_value
@@ -137,21 +117,64 @@ def extract_ith_element_of_diag_h(i, p, df_dx):
diag_hs.append(diag_h_p.view(p.size()))
return diag_hs
- def diag_h(self):
+ def diag_h(self) -> List[Tensor]: # noqa: D102
_, _, loss = self.problem.forward_pass()
return self._get_diag_h(loss)
- def diag_h_batch(self):
- batch_size = self.problem.input.shape[0]
- _, _, batch_loss = self.problem.forward_pass()
- loss_list = torch.zeros(batch_size, device=self.problem.device)
-
+ def diag_h_batch(self) -> List[Tensor]: # noqa: D102
batch_diag_h = []
- for b in range(batch_size):
- _, _, loss = self.problem.forward_pass(sample_idx=b)
- loss_list[b] = loss
+ for b in range(self.problem.get_batch_size()):
+ _, _, loss = self.problem.forward_pass(subsampling=[b])
diag_h = self._get_diag_h(loss)
batch_diag_h.append(diag_h)
- factor = self.problem.get_reduction_factor(batch_loss, loss_list)
+
+ factor = self.problem.compute_reduction_factor()
+
params_batch_diag_h = list(zip(*batch_diag_h))
- return [torch.stack(param) * factor for param in params_batch_diag_h]
+ return [stack(param) * factor for param in params_batch_diag_h]
+
+ def ggn(self, subsampling: List[int] = None) -> Tensor: # noqa: D102
+ _, output, loss = self.problem.forward_pass(subsampling=subsampling)
+ ggn = stack(list(self._ggn_columns(loss, output)), dim=1)
+
+ # correct normalization constant for 'mean' reduction
+ if subsampling is not None:
+ factor = self.problem.compute_reduction_factor()
+ if not isclose(factor, 1.0):
+ ggn *= len(subsampling) * factor
+
+ return ggn
+
+ def _ggn_columns(self, loss: Tensor, output: Tensor) -> Iterator[Tensor]:
+ params = list(self.problem.trainable_parameters())
+ num_params = sum(p.numel() for p in params)
+ model = self.problem.model
+
+ for i in range(num_params):
+ # GGN-vector product with the i-th unit vector yields the i-th column of the (symmetric) GGN
+ e_i = zeros(num_params).to(self.problem.device)
+ e_i[i] = 1.0
+
+ # convert to model parameter shapes
+ e_i_list = vector_to_parameter_list(e_i, params)
+ ggn_i_list = ggn_vector_product(loss, output, model, e_i_list)
+
+ yield parameters_to_vector(ggn_i_list)
+
+ def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa: D102
+ raise NotImplementedError
+
+ def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]: # noqa: D102
+ raise NotImplementedError
+
+ def ggn_mc(self, mc_samples: int, chunks: int = 1): # noqa: D102
+ raise NotImplementedError
+
+ def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa: D102
+ raise NotImplementedError
+
+ def kflr(self) -> List[List[Tensor]]: # noqa: D102
+ raise NotImplementedError
+
+ def kfra(self) -> List[List[Tensor]]: # noqa: D102
+ raise NotImplementedError
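An aside on the ``batch_grad`` implementation above: with a ``'mean'``-reduced loss, each per-sample gradient must be scaled by the reduction factor ``1/N`` so that the scaled per-sample gradients sum to the mini-batch gradient. A minimal sketch in plain PyTorch (hypothetical toy model, not part of the patch):

from torch import allclose, rand, stack
from torch.nn import Linear, MSELoss

model, lossfn = Linear(4, 2), MSELoss(reduction="mean")
X, y = rand(8, 4), rand(8, 2)

lossfn(model(X), y).backward()
batch_grad = model.weight.grad.clone()  # gradient of the mini-batch loss

factor = 1.0 / X.shape[0]  # reduction factor for 'mean'
per_sample = []
for n in range(X.shape[0]):
    model.zero_grad()
    lossfn(model(X[[n]]), y[[n]]).backward()  # loss of sample n only
    per_sample.append(factor * model.weight.grad.clone())

# the scaled per-sample gradients sum to the mini-batch gradient
assert allclose(stack(per_sample).sum(0), batch_grad, atol=1e-6)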
diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py
index ae121c6c5..4754be907 100644
--- a/test/extensions/implementation/backpack.py
+++ b/test/extensions/implementation/backpack.py
@@ -1,9 +1,15 @@
+"""Extension implementations with BackPACK."""
from test.extensions.implementation.base import ExtensionsImplementation
from test.extensions.implementation.hooks import (
BatchL2GradHook,
ExtensionHookManager,
SumGradSquaredHook,
)
+from test.extensions.problem import ExtensionsTestProblem
+from test.utils import chunk_sizes
+from typing import List
+
+from torch import Tensor, cat, einsum
import backpack.extensions as new_ext
from backpack import backpack
@@ -12,91 +18,98 @@
class BackpackExtensions(ExtensionsImplementation):
"""Extension implementations with BackPACK."""
- def __init__(self, problem):
+ def __init__(self, problem: ExtensionsTestProblem):
+ """Add BackPACK functionality to, and store, the test case.
+
+ Args:
+ problem: Test case.
+ """
problem.extend()
super().__init__(problem)
- def batch_grad(self):
- with backpack(new_ext.BatchGrad()):
+ def batch_grad(self, subsampling) -> List[Tensor]: # noqa:D102
+ with backpack(new_ext.BatchGrad(subsampling=subsampling)):
_, _, loss = self.problem.forward_pass()
loss.backward()
- batch_grads = [p.grad_batch for p in self.problem.model.parameters()]
- return batch_grads
+ return self.problem.collect_data("grad_batch")
- def batch_l2_grad(self):
+ def batch_l2_grad(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.BatchL2Grad()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- batch_l2_grad = [p.batch_l2 for p in self.problem.model.parameters()]
- return batch_l2_grad
+ return self.problem.collect_data("batch_l2")
- def batch_l2_grad_extension_hook(self):
- """Individual gradient squared ℓ₂ norms via extension hook."""
+ def batch_l2_grad_extension_hook(self) -> List[Tensor]:
+ """Individual gradient squared ℓ₂ norms via extension hook.
+
+ Returns:
+ Parameter-wise squared ℓ₂ norms of individual gradients.
+ """
hook = ExtensionHookManager(BatchL2GradHook())
with backpack(new_ext.BatchGrad(), extension_hook=hook):
_, _, loss = self.problem.forward_pass()
loss.backward()
- batch_l2_grad = [p.batch_l2_hook for p in self.problem.model.parameters()]
- return batch_l2_grad
+ return self.problem.collect_data("batch_l2_hook")
- def sgs(self):
+ def sgs(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.SumGradSquared()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- sgs = [p.sum_grad_squared for p in self.problem.model.parameters()]
- return sgs
+ return self.problem.collect_data("sum_grad_squared")
+
+ def sgs_extension_hook(self) -> List[Tensor]:
+ """Individual gradient second moment via extension hook.
- def sgs_extension_hook(self):
- """Individual gradient second moment via extension hook."""
+ Returns:
+ Parameter-wise individual gradient second moment.
+ """
hook = ExtensionHookManager(SumGradSquaredHook())
with backpack(new_ext.BatchGrad(), extension_hook=hook):
_, _, loss = self.problem.forward_pass()
loss.backward()
- sgs = [p.sum_grad_squared_hook for p in self.problem.model.parameters()]
- return sgs
+ return self.problem.collect_data("sum_grad_squared_hook")
- def variance(self):
+ def variance(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.Variance()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- variances = [p.variance for p in self.problem.model.parameters()]
- return variances
+ return self.problem.collect_data("variance")
- def diag_ggn(self):
+ def diag_ggn(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.DiagGGNExact()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_ggn = [p.diag_ggn_exact for p in self.problem.model.parameters()]
- return diag_ggn
+ return self.problem.collect_data("diag_ggn_exact")
- def diag_ggn_exact_batch(self):
+ def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.BatchDiagGGNExact()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_ggn_exact_batch = [
- p.diag_ggn_exact_batch for p in self.problem.model.parameters()
- ]
- return diag_ggn_exact_batch
+ return self.problem.collect_data("diag_ggn_exact_batch")
- def diag_ggn_mc(self, mc_samples):
+ def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa:D102
with backpack(new_ext.DiagGGNMC(mc_samples=mc_samples)):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_ggn_mc = [p.diag_ggn_mc for p in self.problem.model.parameters()]
- return diag_ggn_mc
+ return self.problem.collect_data("diag_ggn_mc")
- def diag_ggn_mc_batch(self, mc_samples):
+ def diag_ggn_mc_batch(self, mc_samples) -> List[Tensor]: # noqa:D102
with backpack(new_ext.BatchDiagGGNMC(mc_samples=mc_samples)):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_ggn_mc_batch = [
- p.diag_ggn_mc_batch for p in self.problem.model.parameters()
- ]
- return diag_ggn_mc_batch
+ return self.problem.collect_data("diag_ggn_mc_batch")
- def diag_ggn_mc_chunk(self, mc_samples, chunks=10):
- """Like ``diag_ggn_mc``, but handles larger number of samples by chunking."""
- chunk_samples = self.chunk_sizes(mc_samples, chunks)
+ def diag_ggn_mc_chunk(self, mc_samples: int, chunks: int = 10) -> List[Tensor]:
+ """Like ``diag_ggn_mc``, but can handle more samples by chunking.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples.
+ chunks: Maximum sequential split of the computation. Default: ``10``.
+
+ Returns:
+ Parameter-wise MC-approximation of the GGN diagonal.
+ """
+ chunk_samples = chunk_sizes(mc_samples, chunks)
chunk_weights = [samples / mc_samples for samples in chunk_samples]
diag_ggn_mc = None
@@ -113,11 +126,19 @@ def diag_ggn_mc_chunk(self, mc_samples, chunks=10):
return diag_ggn_mc
- def diag_ggn_mc_batch_chunk(self, mc_samples, chunks=10):
- """
- Like ``diag_ggn_mc_batch``, but handles larger number of samples by chunking.
+ def diag_ggn_mc_batch_chunk(
+ self, mc_samples: int, chunks: int = 10
+ ) -> List[Tensor]:
+ """Like ``diag_ggn_mc_batch``, but can handle more samples by chunking.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples.
+ chunks: Maximum sequential split of the computation. Default: ``10``.
+
+ Returns:
+ Parameter-wise MC-approximation of the per-sample GGN diagonals.
"""
- chunk_samples = self.chunk_sizes(mc_samples, chunks)
+ chunk_samples = chunk_sizes(mc_samples, chunks)
chunk_weights = [samples / mc_samples for samples in chunk_samples]
diag_ggn_mc_batch = None
@@ -136,57 +157,94 @@ def diag_ggn_mc_batch_chunk(self, mc_samples, chunks=10):
return diag_ggn_mc_batch
- @staticmethod
- def chunk_sizes(total_size, num_chunks):
- """Return list containing the sizes of chunks."""
- chunk_size = max(total_size // num_chunks, 1)
-
- if chunk_size == 1:
- sizes = total_size * [chunk_size]
- else:
- equal, rest = divmod(total_size, chunk_size)
- sizes = equal * [chunk_size]
-
- if rest != 0:
- sizes.append(rest)
-
- return sizes
-
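For reference, a worked example of the chunking behavior, assuming the ``chunk_sizes`` helper now imported from ``test.utils`` behaves like the static method removed above; in the chunked methods, each chunk's result is then weighted by ``samples / mc_samples`` before summation:

from test.utils import chunk_sizes

# 10 // 3 = 3 samples per chunk, the remainder 1 forms a last, smaller chunk
assert chunk_sizes(10, 3) == [3, 3, 3, 1]
# divides evenly: 15 chunks of 10000 samples each
assert chunk_sizes(150000, 15) == [10000] * 15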
- def diag_h(self):
+ def diag_h(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.DiagHessian()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_h = [p.diag_h for p in self.problem.model.parameters()]
- return diag_h
+ return self.problem.collect_data("diag_h")
- def kfac(self, mc_samples=1):
+ def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa:D102
with backpack(new_ext.KFAC(mc_samples=mc_samples)):
_, _, loss = self.problem.forward_pass()
loss.backward()
- kfac = [p.kfac for p in self.problem.model.parameters()]
+ return self.problem.collect_data("kfac")
- return kfac
-
- def kflr(self):
+ def kflr(self) -> List[List[Tensor]]: # noqa:D102
with backpack(new_ext.KFLR()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- kflr = [p.kflr for p in self.problem.model.parameters()]
-
- return kflr
+ return self.problem.collect_data("kflr")
- def kfra(self):
+ def kfra(self) -> List[List[Tensor]]: # noqa:D102
with backpack(new_ext.KFRA()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- kfra = [p.kfra for p in self.problem.model.parameters()]
-
- return kfra
+ return self.problem.collect_data("kfra")
- def diag_h_batch(self):
+ def diag_h_batch(self) -> List[Tensor]: # noqa:D102
with backpack(new_ext.BatchDiagHessian()):
_, _, loss = self.problem.forward_pass()
loss.backward()
- diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()]
+ return self.problem.collect_data("diag_h_batch")
+
+ def ggn(self, subsampling: List[int] = None) -> Tensor: # noqa:D102
+ return self._square_sqrt_ggn(self.sqrt_ggn(subsampling=subsampling))
+
+ def sqrt_ggn(self, subsampling: List[int] = None) -> List[Tensor]:
+ """Compute the matrix square root of the exact generalized Gauss-Newton.
+
+ Args:
+ subsampling: Indices of active samples. Defaults to ``None`` (use all
+ samples in the mini-batch).
+
+ Returns:
+ Parameter-wise matrix square root of the exact GGN.
+ """
+ with backpack(new_ext.SqrtGGNExact(subsampling=subsampling)):
+ _, _, loss = self.problem.forward_pass()
+ loss.backward()
+ return self.problem.collect_data("sqrt_ggn_exact")
+
+ def sqrt_ggn_mc(
+ self, mc_samples: int, subsampling: List[int] = None
+ ) -> List[Tensor]:
+ """Compute the approximate matrix square root of the generalized Gauss-Newton.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples.
+ subsampling: Indices of active samples. Defaults to ``None`` (use all
+ samples in the mini-batch).
+
+ Returns:
+ Parameter-wise MC-approximated matrix square root of the GGN.
+ """
+ with backpack(
+ new_ext.SqrtGGNMC(mc_samples=mc_samples, subsampling=subsampling)
+ ):
+ _, _, loss = self.problem.forward_pass()
+ loss.backward()
+ return self.problem.collect_data("sqrt_ggn_mc")
+
+ def ggn_mc(
+ self, mc_samples: int, chunks: int = 1, subsampling: List[int] = None
+ ) -> Tensor: # noqa:D102
+ samples = chunk_sizes(mc_samples, chunks)
+ weights = [n_samples / mc_samples for n_samples in samples]
- return diag_h_batch
+ return sum(
+ w * self._square_sqrt_ggn(self.sqrt_ggn_mc(s, subsampling=subsampling))
+ for w, s in zip(weights, samples)
+ )
+
+ @staticmethod
+ def _square_sqrt_ggn(sqrt_ggn: List[Tensor]) -> Tensor:
+ """Utility function to concatenate and square the GGN factorization.
+
+ Args:
+ sqrt_ggn: Parameter-wise matrix square root of the GGN.
+
+ Returns:
+ Matrix representation of the GGN.
+ """
+ sqrt_mat = cat([s.flatten(start_dim=2) for s in sqrt_ggn], dim=2)
+ return einsum("cni,cnj->ij", sqrt_mat, sqrt_mat)
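A quick sanity check of the identity behind ``_square_sqrt_ggn``: for a square-root factor ``S`` of shape ``[C, N, D]`` (class, sample, and flattened parameter dimensions), ``einsum("cni,cnj->ij", S, S)`` is the Gram matrix of ``S`` flattened over its first two axes. A sketch with random data (not part of the patch):

from torch import allclose, einsum, rand

C, N, D = 3, 4, 5
S = rand(C, N, D)

via_einsum = einsum("cni,cnj->ij", S, S)
via_matmul = S.reshape(C * N, D).T @ S.reshape(C * N, D)

assert allclose(via_einsum, via_matmul)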
diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py
index 0b2eb73b4..8f53af7fe 100644
--- a/test/extensions/implementation/base.py
+++ b/test/extensions/implementation/base.py
@@ -1,75 +1,148 @@
-class ExtensionsImplementation:
+"""Base class containing the functions to compare BackPACK and autograd."""
+from abc import ABC, abstractmethod
+from test.extensions.problem import ExtensionsTestProblem
+from typing import List, Union
+
+from torch import Tensor
+
+
+class ExtensionsImplementation(ABC):
"""Base class for autograd and BackPACK implementations of extensions."""
- def __init__(self, problem):
+ def __init__(self, problem: ExtensionsTestProblem):
+ """Store the test case.
+
+ Args:
+ problem: Test case.
+ """
self.problem = problem
- def batch_grad(self):
- """Individual gradients."""
- raise NotImplementedError
+ @abstractmethod
+ def batch_grad(self, subsampling: Union[List[int], None]) -> List[Tensor]:
+ """Individual gradients.
+
+ Args:
+ subsampling: List of active samples. ``None`` means all samples.
+ """
+ return
- def batch_l2_grad(self):
+ @abstractmethod
+ def batch_l2_grad(self) -> List[Tensor]:
"""L2 norm of Individual gradients."""
- raise NotImplementedError
+ return
- def sgs(self):
- """Sum of Square of Individual gradients"""
- raise NotImplementedError
+ @abstractmethod
+ def sgs(self) -> List[Tensor]:
+ """Sum of Square of Individual gradients."""
+ return
- def variance(self):
- """Variance of Individual gradients"""
- raise NotImplementedError
+ @abstractmethod
+ def variance(self) -> List[Tensor]:
+ """Variance of Individual gradients."""
+ return
- def diag_ggn(self):
- """Diagonal of Gauss Newton"""
- raise NotImplementedError
+ @abstractmethod
+ def diag_ggn(self) -> List[Tensor]:
+ """Diagonal of Gauss Newton."""
+ return
- def diag_ggn_batch(self):
- """Individual diagonal of Generalized Gauss-Newton/Fisher"""
- raise NotImplementedError
+ @abstractmethod
+ def diag_ggn_exact_batch(self) -> List[Tensor]:
+ """Individual diagonal of Generalized Gauss-Newton/Fisher."""
+ return
- def diag_ggn_mc(self, mc_samples):
- """MC approximation of Diagonal of Gauss Newton"""
- raise NotImplementedError
+ @abstractmethod
+ def diag_ggn_mc(self, mc_samples: int) -> List[Tensor]:
+ """MC approximation of the generalized Gauss-Newton/Fisher diagonal.
- def diag_ggn_mc_batch(self, mc_samples):
- """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal."""
- raise NotImplementedError
+ Args:
+ mc_samples: Number of Monte-Carlo samples used for the approximation.
+ """
+ return
+
+ @abstractmethod
+ def diag_ggn_mc_batch(self, mc_samples: int) -> List[Tensor]:
+ """MC approximation of individual Generalized Gauss-Newton/Fisher diagonal.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples used for the approximation.
+ """
+ return
- def diag_h(self):
- """Diagonal of Hessian"""
- raise NotImplementedError
+ @abstractmethod
+ def diag_h(self) -> List[Tensor]:
+ """Diagonal of Hessian."""
+ return
- def kfac(self, mc_samples=1):
+ @abstractmethod
+ def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]:
"""Kronecker-factored approximate curvature (KFAC).
Args:
- mc_samples (int, optional): Number of Monte-Carlo samples. Default: ``1``.
+ mc_samples: Number of Monte-Carlo samples. Default: ``1``.
Returns:
- list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors.
+ Parameter-wise lists of Kronecker factors.
"""
- raise NotImplementedError
+ return
- def kflr(self):
+ @abstractmethod
+ def kflr(self) -> List[List[Tensor]]:
"""Kronecker-factored low-rank approximation (KFLR).
Returns:
- list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors.
+ Parameter-wise lists of Kronecker factors.
"""
- raise NotImplementedError
+ return
- def kfra(self):
+ @abstractmethod
+ def kfra(self) -> List[List[Tensor]]:
"""Kronecker-factored recursive approximation (KFRA).
Returns:
- list(list(torch.Tensor)): Parameter-wise lists of Kronecker factors.
+ Parameter-wise lists of Kronecker factors.
"""
+ return
- def diag_h_batch(self):
+ @abstractmethod
+ def diag_h_batch(self) -> List[Tensor]:
"""Per-sample Hessian diagonal.
Returns:
- list(torch.Tensor): Parameter-wise per-sample Hessian diagonal.
+ Parameter-wise per-sample Hessian diagonal.
+ """
+ return
+
+ @abstractmethod
+ def ggn(self, subsampling: List[int] = None) -> Tensor:
+ """Exact generalized Gauss-Newton/Fisher matrix.
+
+ Note:
+ For losses with ``'mean'`` reduction, the GGN is ``¹/N ∑ₙ Jₙᵀ Hₙ Jₙ``. If
+ sub-sampling is enabled, the sum will only run over active samples. The
+ normalization will not be ``1/len(subsampling)``, but remain ``1/N``.
+
+ Args:
+ subsampling: Indices of active samples. Default: ``None`` (all).
+
+ Returns:
+ Matrix representation of the exact GGN.
+ """
+ return
+
+ @abstractmethod
+ def ggn_mc(
+ self, mc_samples: int, chunks: int = 1, subsampling: List[int] = None
+ ) -> Tensor:
+ """Compute the MC-approximation of the GGN in chunks of MC samples.
+
+ Args:
+ mc_samples: Number of Monte-Carlo samples.
+ chunks: Number of sequential portions to split the computation.
+ Default: ``1`` (no sequential split).
+ subsampling: Indices of active samples. Default: ``None`` (all).
+
+ Returns:
+ Matrix representation of the Monte-Carlo approximated GGN.
"""
- raise NotImplementedError
+ return
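To illustrate the normalization convention documented in ``ggn`` above: with ``'mean'`` reduction, a forward pass on the sub-batch divides by ``len(subsampling)``, so rescaling by ``len(subsampling) / N`` (the ``len(subsampling) * factor`` correction in the autograd implementation) restores the documented ``1/N`` normalization. A small numeric sketch with stand-in per-sample losses:

from torch import allclose, rand

N, subsampling = 4, [0, 2]
losses = rand(N)  # stand-in for per-sample losses

sub_mean = losses[subsampling].mean()       # what 'mean' reduction yields on the sub-batch
convention = losses[subsampling].sum() / N  # documented 1/N normalization

assert allclose(sub_mean * len(subsampling) / N, convention)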
diff --git a/test/extensions/implementation/hooks.py b/test/extensions/implementation/hooks.py
index cd3bba1a5..5e6cc2f4c 100644
--- a/test/extensions/implementation/hooks.py
+++ b/test/extensions/implementation/hooks.py
@@ -1,4 +1,4 @@
-"""Post extension hooks to compact BackPACK quantities during backpropagation."""
+"""Extension hooks to compact BackPACK quantities during backpropagation."""
class ExtensionHookManager:
diff --git a/test/extensions/problem.py b/test/extensions/problem.py
index 4c41ef191..b2f96b92c 100644
--- a/test/extensions/problem.py
+++ b/test/extensions/problem.py
@@ -2,13 +2,25 @@
import copy
from test.core.derivatives.utils import get_available_devices
+from typing import Any, Iterator, List, Tuple
import torch
+from torch import Tensor
+from torch.nn.parameter import Parameter
from backpack import extend
+from backpack.utils.subsampling import subsample
def make_test_problems(settings):
+ """Creates test problems from settings.
+
+ Args:
+ settings (list[dict]): raw settings of the problems
+
+ Returns:
+ list[ExtensionsTestProblem]: test problems constructed from the settings
+ """
problem_dicts = []
for setting in settings:
@@ -24,11 +36,16 @@ def make_test_problems(settings):
def add_missing_defaults(setting):
- """Create extensions test problem from setting.
+ """Create full settings from setting.
+
Args:
setting (dict): configuration dictionary
+
Returns:
- ExtensionsTestProblem: problem with specified settings.
+ dict: full settings.
+
+ Raises:
+ ValueError: if the settings are missing required entries or contain unknown ones
"""
required = ["module_fn", "input_fn", "loss_function_fn", "target_fn"]
optional = {
@@ -53,6 +70,8 @@ def add_missing_defaults(setting):
class ExtensionsTestProblem:
+ """Class providing functions and parameters."""
+
def __init__(
self,
input_fn,
@@ -84,6 +103,7 @@ def __init__(
self.id_prefix = id_prefix
def set_up(self):
+ """Set up problem from settings."""
torch.manual_seed(self.seed)
self.model = self.module_fn().to(self.device)
@@ -92,10 +112,15 @@ def set_up(self):
self.loss_function = self.loss_function_fn().to(self.device)
def tear_down(self):
+ """Delete all variables after problem."""
del self.model, self.input, self.target, self.loss_function
def make_id(self):
- """Needs to function without call to `set_up`."""
+ """Needs to function without call to `set_up`.
+
+ Returns:
+ str: id of problem
+ """
prefix = (self.id_prefix + "-") if self.id_prefix != "" else ""
return (
prefix
@@ -107,37 +132,63 @@ def make_id(self):
).replace(" ", "")
)
- def forward_pass(self, sample_idx=None):
- """Do a forward pass. Return input, output, and parameters."""
- if sample_idx is None:
- input = self.input.clone().detach()
- target = self.target.clone().detach()
- else:
- input = self.input.clone()[sample_idx, :].unsqueeze(0).detach()
- target = self.target.clone()[sample_idx].unsqueeze(0).detach()
+ def forward_pass(
+ self, subsampling: List[int] = None
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ """Do a forward pass. Return input, output, and parameters.
+
+ If sub-sampling is None, the forward pass is calculated on the whole batch.
+
+ Args:
+ subsampling: Indices of selected samples. Default: ``None`` (all samples).
+
+ Returns:
+ input, output, and loss of the forward pass
+ """
+ input = self.input.clone()
+ target = self.target.clone()
+
+ if subsampling is not None:
+ batch_axis = 0
+ input = subsample(self.input, dim=batch_axis, subsampling=subsampling)
+ target = subsample(self.target, dim=batch_axis, subsampling=subsampling)
- print(self.target.shape)
- print(target.shape)
output = self.model(input)
loss = self.loss_function(output, target)
return input, output, loss
def extend(self):
+ """Extend module of problem."""
self.model = extend(self.model)
self.loss_function = extend(self.loss_function)
- def get_reduction_factor(self, loss, unreduced_loss):
- """Return the factor used to reduce the individual losses."""
+ @staticmethod
+ def __get_reduction_factor(loss: Tensor, unreduced_loss: Tensor) -> float:
+ """Return the factor used to reduce the individual losses.
+
+ Args:
+ loss: Reduced loss.
+ unreduced_loss: Unreduced loss.
+
+ Returns:
+ Reduction factor.
+
+ Raises:
+ RuntimeError: if the reduction factor cannot be determined
+ """
mean_loss = unreduced_loss.flatten().mean()
sum_loss = unreduced_loss.flatten().sum()
if torch.allclose(mean_loss, sum_loss):
- raise RuntimeError(
- "Cannot determine reduction factor. ",
- "Results from 'mean' and 'sum' reduction are identical. ",
- f"'mean': {mean_loss}, 'sum': {sum_loss}",
- )
- if torch.allclose(loss, mean_loss):
+ if unreduced_loss.numel() == 1 and torch.allclose(loss, sum_loss):
+ factor = 1.0
+ else:
+ raise RuntimeError(
+ "Cannot determine reduction factor. ",
+ "Results from 'mean' and 'sum' reduction are identical. ",
+ f"'mean': {mean_loss}, 'sum': {sum_loss}",
+ )
+ elif torch.allclose(loss, mean_loss):
factor = 1.0 / unreduced_loss.numel()
elif torch.allclose(loss, sum_loss):
factor = 1.0
@@ -147,3 +198,68 @@ def get_reduction_factor(self, loss, unreduced_loss):
f"'mean': {mean_loss}, 'sum': {sum_loss}, loss: {loss}",
)
return factor
+
+ def trainable_parameters(self) -> Iterator[Parameter]:
+ """Yield the model's trainable parameters.
+
+ Yields:
+ Model parameter with gradients enabled.
+ """
+ for p in self.model.parameters():
+ if p.requires_grad:
+ yield p
+
+ def collect_data(self, savefield: str) -> List[Any]:
+ """Collect BackPACK attributes from trainable parameters.
+
+ Args:
+ savefield: Attribute name.
+
+ Returns:
+ List of attributes saved under the trainable model parameters.
+
+ Raises:
+ RuntimeError: If a non-differentiable parameter with the attribute is
+ encountered.
+ """
+ data = []
+
+ for p in self.model.parameters():
+ if p.requires_grad:
+ data.append(getattr(p, savefield))
+ else:
+ if hasattr(p, savefield):
+ raise RuntimeError(
+ f"Found non-differentiable parameter with attribute '{savefield}'."
+ )
+
+ return data
+
+ def get_batch_size(self) -> int:
+ """Return the mini-batch size.
+
+ Returns:
+ Mini-batch size.
+ """
+ return self.input.shape[0]
+
+ def compute_reduction_factor(self) -> float:
+ """Compute loss function's reduction factor for aggregating per-sample losses.
+
+ For instance, if ``reduction='mean'`` is used, then the reduction factor
+ is ``1 / N`` where ``N`` is the batch size. With ``reduction='sum'``, it
+ is ``1``.
+
+ Returns:
+ Reduction factor.
+ """
+ _, _, loss = self.forward_pass()
+
+ batch_size = self.get_batch_size()
+ loss_list = torch.zeros(batch_size, device=self.device)
+
+ for n in range(batch_size):
+ _, _, loss_n = self.forward_pass(subsampling=[n])
+ loss_list[n] = loss_n
+
+ return self.__get_reduction_factor(loss, loss_list)
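A sketch of the relation ``compute_reduction_factor`` exploits: the mini-batch loss equals the mean (factor ``1/N``) or the sum (factor ``1``) of the per-sample losses, depending on the reduction (toy data, not part of the patch):

from torch import allclose, rand, randint, stack
from torch.nn import CrossEntropyLoss

N = 5
logits, targets = rand(N, 3), randint(0, 3, (N,))
per_sample = stack([CrossEntropyLoss()(logits[[n]], targets[[n]]) for n in range(N)])

mean_loss = CrossEntropyLoss(reduction="mean")(logits, targets)
sum_loss = CrossEntropyLoss(reduction="sum")(logits, targets)

assert allclose(mean_loss, per_sample.mean())  # reduction factor 1 / N
assert allclose(sum_loss, per_sample.sum())    # reduction factor 1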
diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py
new file mode 100644
index 000000000..6934f754e
--- /dev/null
+++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py
@@ -0,0 +1,283 @@
+"""Test cases for BackPACK extensions for the GGN diagonal.
+
+Includes
+- ``DiagGGNExact``
+- ``DiagGGNMC``
+- ``BatchDiagGGNExact``
+- ``BatchDiagGGNMC``
+
+Shared settings are taken from `test.extensions.secondorder.secondorder_settings`.
+Additional local cases can be defined here through ``LOCAL_SETTINGS``.
+"""
+from test.converter.resnet_cases import ResNet1, ResNet2
+from test.core.derivatives.utils import classification_targets, regression_targets
+from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS
+from test.utils.evaluation_mode import initialize_training_false_recursive
+
+from torch import rand, randint
+from torch.nn import (
+ LSTM,
+ RNN,
+ AdaptiveAvgPool1d,
+ AdaptiveAvgPool2d,
+ AdaptiveAvgPool3d,
+ BatchNorm1d,
+ BatchNorm2d,
+ BatchNorm3d,
+ Conv2d,
+ CrossEntropyLoss,
+ Embedding,
+ Flatten,
+ Identity,
+ Linear,
+ MaxPool2d,
+ MSELoss,
+ ReLU,
+ Sequential,
+ Sigmoid,
+)
+
+from backpack import convert_module_to_backpack
+from backpack.custom_module import branching
+from backpack.custom_module.branching import Parallel
+from backpack.custom_module.permute import Permute
+from backpack.custom_module.reduce_tuple import ReduceTuple
+
+SHARED_SETTINGS = SECONDORDER_SETTINGS
+LOCAL_SETTINGS = []
+##################################################################
+# RNN settings #
+##################################################################
+LOCAL_SETTINGS += [
+ # RNN settings
+ {
+ "input_fn": lambda: rand(8, 5, 6),
+ "module_fn": lambda: Sequential(
+ RNN(input_size=6, hidden_size=3, batch_first=True),
+ ReduceTuple(index=0),
+ Permute(0, 2, 1),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((8, 3 * 5)),
+ },
+ {
+ "input_fn": lambda: rand(4, 3, 5),
+ "module_fn": lambda: Sequential(
+ LSTM(input_size=5, hidden_size=4, batch_first=True),
+ ReduceTuple(index=0),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((4,), 4 * 3),
+ },
+ {
+ "input_fn": lambda: rand(8, 5, 6),
+ "module_fn": lambda: Sequential(
+ RNN(input_size=6, hidden_size=3, batch_first=True),
+ ReduceTuple(index=0),
+ Linear(3, 3),
+ Permute(0, 2, 1),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((8, 5), 3),
+ },
+]
+##################################################################
+# AdaptiveAvgPool settings #
+##################################################################
+LOCAL_SETTINGS += [
+ {
+ "input_fn": lambda: rand(2, 2, 9),
+ "module_fn": lambda: Sequential(
+ Linear(9, 9), AdaptiveAvgPool1d((3,)), Flatten()
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((2, 2 * 3)),
+ },
+ {
+ "input_fn": lambda: rand(2, 2, 6, 8),
+ "module_fn": lambda: Sequential(
+ Linear(8, 8), AdaptiveAvgPool2d((3, 4)), Flatten()
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((2, 2 * 3 * 4)),
+ },
+ {
+ "input_fn": lambda: rand(2, 2, 9, 5, 4),
+ "module_fn": lambda: Sequential(
+ Linear(4, 4), AdaptiveAvgPool3d((3, 5, 2)), Flatten()
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((2, 2 * 3 * 5 * 2)),
+ },
+]
+##################################################################
+# BatchNorm settings #
+##################################################################
+LOCAL_SETTINGS += [
+ {
+ "input_fn": lambda: rand(2, 3, 4),
+ "module_fn": lambda: initialize_training_false_recursive(
+ Sequential(BatchNorm1d(num_features=3), Flatten())
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((2, 4 * 3)),
+ },
+ {
+ "input_fn": lambda: rand(3, 2, 4, 3),
+ "module_fn": lambda: initialize_training_false_recursive(
+ Sequential(BatchNorm2d(num_features=2), Flatten())
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((3, 2 * 4 * 3)),
+ },
+ {
+ "input_fn": lambda: rand(3, 3, 4, 1, 2),
+ "module_fn": lambda: initialize_training_false_recursive(
+ Sequential(BatchNorm3d(num_features=3), Flatten())
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((3, 3 * 4 * 1 * 2)),
+ },
+ {
+ "input_fn": lambda: rand(3, 3, 4, 1, 2),
+ "module_fn": lambda: initialize_training_false_recursive(
+ Sequential(
+ BatchNorm3d(num_features=3),
+ Linear(2, 3),
+ BatchNorm3d(num_features=3),
+ ReLU(),
+ BatchNorm3d(num_features=3),
+ Flatten(),
+ )
+ ),
+ "loss_function_fn": lambda: MSELoss(),
+ "target_fn": lambda: regression_targets((3, 4 * 1 * 3 * 3)),
+ },
+]
+###############################################################################
+# Embedding #
+###############################################################################
+LOCAL_SETTINGS += [
+ {
+ "input_fn": lambda: randint(0, 5, (6,)),
+ "module_fn": lambda: Sequential(
+ Embedding(5, 3),
+ Linear(3, 4),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((6,), 4),
+ },
+ {
+ "input_fn": lambda: randint(0, 3, (3, 2, 2)),
+ "module_fn": lambda: Sequential(
+ Embedding(3, 2),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 2 * 2),
+ "seed": 2,
+ },
+]
+
+
+###############################################################################
+# Branched models #
+###############################################################################
+LOCAL_SETTINGS += [
+ {
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(
+ Linear(10, 5),
+ ReLU(),
+ # skip connection
+ Parallel(
+ Identity(),
+ Linear(5, 5),
+ ),
+ # end of skip connection
+ Sigmoid(),
+ Linear(5, 4),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((3,), 4),
+ "id_prefix": "branching-linear",
+ },
+ {
+ "input_fn": lambda: rand(4, 2, 6, 6),
+ "module_fn": lambda: Sequential(
+ Conv2d(2, 3, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ # skip connection
+ Parallel(
+ Identity(),
+ Sequential(
+ Conv2d(3, 5, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ Conv2d(5, 3, kernel_size=3, stride=1, padding=1),
+ ),
+ ),
+ # end of skip connection
+ MaxPool2d(kernel_size=3, stride=2),
+ Flatten(),
+ Linear(12, 5),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((4,), 5),
+ "id_prefix": "branching-convolution",
+ },
+ {
+ "input_fn": lambda: rand(4, 3, 6, 6),
+ "module_fn": lambda: Sequential(
+ Conv2d(3, 2, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ # skip connection
+ Parallel(
+ Identity(),
+ Sequential(
+ Conv2d(2, 4, kernel_size=3, stride=1, padding=1),
+ Sigmoid(),
+ Conv2d(4, 2, kernel_size=3, stride=1, padding=1),
+ branching.Parallel(
+ Identity(),
+ Sequential(
+ Conv2d(2, 4, kernel_size=3, stride=1, padding=1),
+ ReLU(),
+ Conv2d(4, 2, kernel_size=3, stride=1, padding=1),
+ ),
+ ),
+ ),
+ ),
+ # end of skip connection
+ MaxPool2d(kernel_size=3, stride=2),
+ Flatten(),
+ Linear(8, 5),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
+ "target_fn": lambda: classification_targets((4,), 5),
+ "id_prefix": "nested-branching-convolution",
+ },
+]
+
+###############################################################################
+# Branched models - converter #
+###############################################################################
+LOCAL_SETTINGS += [
+ {
+ "input_fn": lambda: ResNet1.input_test,
+ "module_fn": lambda: convert_module_to_backpack(ResNet1(), True),
+ "loss_function_fn": lambda: ResNet1.loss_test,
+ "target_fn": lambda: ResNet1.target_test,
+ "id_prefix": "ResNet1",
+ },
+ {
+ "input_fn": lambda: rand(ResNet2.input_test),
+ "module_fn": lambda: convert_module_to_backpack(ResNet2().eval(), True),
+ "loss_function_fn": lambda: ResNet2.loss_test,
+ "target_fn": lambda: rand(ResNet2.target_test),
+ "id_prefix": "ResNet2",
+ },
+]
+
+DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS
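Each settings dictionary above is consumed roughly as follows (a usage sketch based on ``test/extensions/problem.py``, not part of the patch):

from test.extensions.problem import make_test_problems

problems = make_test_problems(LOCAL_SETTINGS)

problem = problems[0]
problem.set_up()                          # seed, build model and loss on the device
_, output, loss = problem.forward_pass()  # full mini-batch forward pass
problem.tear_down()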
diff --git a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py b/test/extensions/secondorder/diag_ggn/diaggnn_settings.py
deleted file mode 100644
index e0d5b7bbc..000000000
--- a/test/extensions/secondorder/diag_ggn/diaggnn_settings.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Test configurations to test diag_ggn
-
-The tests are taken from `test.extensions.secondorder.secondorder_settings`,
-but additional custom tests can be defined here by appending it to the list.
-"""
-
-from test.extensions.automated_settings import make_simple_act_setting
-from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS
-
-from torch.nn import ELU, SELU
-
-DiagGGN_SETTINGS = []
-
-SHARED_SETTINGS = SECONDORDER_SETTINGS
-
-LOCAL_SETTINGS = []
-
-###############################################################################
-# test setting: Activation Layers #
-###############################################################################
-activations = [ELU, SELU]
-
-for act in activations:
- for bias in [True, False]:
- LOCAL_SETTINGS.append(make_simple_act_setting(act, bias=bias))
-
-DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS
diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py
index 378c0a2c7..0030fa95b 100644
--- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py
+++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py
@@ -1,8 +1,10 @@
+"""Test BatchDiagGGN extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems
-from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS
+from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS
+from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda
import pytest
@@ -11,16 +13,18 @@
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
-def test_diag_ggn_batch(problem):
- """Test the individual diagonal of Generalized Gauss-Newton/Fisher
+def test_diag_ggn_exact_batch(problem, request):
+ """Test the individual diagonal of Generalized Gauss-Newton/Fisher.
Args:
problem (ExtensionsTestProblem): Problem for extension test.
+ request: pytest request, used to determine whether the test should be skipped
"""
+ skip_adaptive_avg_pool3d_cuda(request)
problem.set_up()
backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch()
- autograd_res = AutogradExtensions(problem).diag_ggn_batch()
+ autograd_res = AutogradExtensions(problem).diag_ggn_exact_batch()
check_sizes_and_values(autograd_res, backpack_res)
problem.tear_down()
@@ -33,8 +37,9 @@ def test_diag_ggn_batch(problem):
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_ggn_mc_batch_light(problem):
- """Test the MC approximation of individual diagonal of
- Generalized Gauss-Newton/Fisher with few mc_samples (light version)
+ """Test the MC approximation of individual diagonal.
+
+ of Generalized Gauss-Newton/Fisher with few mc_samples (light version)
Args:
problem (ExtensionsTestProblem): Problem for extension test.
@@ -42,7 +47,7 @@ def test_diag_ggn_mc_batch_light(problem):
problem.set_up()
backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch()
- mc_samples = 5000
+ mc_samples = 6000
backpack_res_mc_avg = BackpackExtensions(problem).diag_ggn_mc_batch(mc_samples)
check_sizes_and_values(
@@ -54,8 +59,9 @@ def test_diag_ggn_mc_batch_light(problem):
@pytest.mark.montecarlo
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_ggn_mc_batch(problem):
- """Test the MC approximation of individual diagonal of Gauss-Newton
- with more samples (slow version)
+ """Test the MC approximation of individual diagonal.
+
+ of generalized Gauss-Newton with more samples (slow version)
Args:
problem (ExtensionsTestProblem): Problem for extension test.
diff --git a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py
index 858e479b5..0b7ba0469 100644
--- a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py
+++ b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py
@@ -1,8 +1,10 @@
+"""Test DiagGGN extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems
-from test.extensions.secondorder.diag_ggn.diaggnn_settings import DiagGGN_SETTINGS
+from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS
+from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda
import pytest
@@ -11,12 +13,14 @@
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
-def test_diag_ggn(problem):
- """Test the diagonal of Gauss-Newton
+def test_diag_ggn(problem, request):
+ """Test the diagonal of generalized Gauss-Newton.
Args:
problem (ExtensionsTestProblem): Problem for extension test.
+ request: pytest request, used to determine whether the test should be skipped
"""
+ skip_adaptive_avg_pool3d_cuda(request)
problem.set_up()
backpack_res = BackpackExtensions(problem).diag_ggn()
@@ -33,8 +37,9 @@ def test_diag_ggn(problem):
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_ggn_mc_light(problem):
- """Test the MC approximation of Diagonal of Gauss-Newton
- with few mc_samples (light version)
+ """Test the MC approximation of Diagonal of generalized Gauss-Newton.
+
+ with few mc_samples (light version)
Args:
problem (ExtensionsTestProblem): Problem for extension test.
@@ -54,8 +59,9 @@ def test_diag_ggn_mc_light(problem):
@pytest.mark.montecarlo
@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_ggn_mc(problem):
- """Test the MC approximation of Diagonal of Gauss-Newton
- with more samples (slow version)
+ """Test the MC approximation of Diagonal of generalized Gauss-Newton.
+
+ with more samples (slow version)
Args:
problem (ExtensionsTestProblem): Problem for extension test.
diff --git a/test/extensions/secondorder/hbp/kfac_settings.py b/test/extensions/secondorder/hbp/kfac_settings.py
index 895b99cc5..a6d0ece1d 100644
--- a/test/extensions/secondorder/hbp/kfac_settings.py
+++ b/test/extensions/secondorder/hbp/kfac_settings.py
@@ -1,8 +1,13 @@
"""Define test cases for KFAC."""
-from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS
+from test.extensions.secondorder.secondorder_settings import (
+ GROUP_CONV_SETTINGS,
+ LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS,
+)
-SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS
+SHARED_NOT_SUPPORTED_SETTINGS = (
+ GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS
+)
LOCAL_NOT_SUPPORTED_SETTINGS = []
NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS
diff --git a/test/extensions/secondorder/hbp/kflr_settings.py b/test/extensions/secondorder/hbp/kflr_settings.py
index 6b74a2842..de61c5b3b 100644
--- a/test/extensions/secondorder/hbp/kflr_settings.py
+++ b/test/extensions/secondorder/hbp/kflr_settings.py
@@ -1,8 +1,13 @@
"""Define test cases for KFLR."""
-from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS
+from test.extensions.secondorder.secondorder_settings import (
+ GROUP_CONV_SETTINGS,
+ LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS,
+)
-SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS
+SHARED_NOT_SUPPORTED_SETTINGS = (
+ GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS
+)
LOCAL_NOT_SUPPORTED_SETTINGS = []
NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS
diff --git a/test/extensions/secondorder/hbp/kfra_settings.py b/test/extensions/secondorder/hbp/kfra_settings.py
index 5a28ab738..94e65c2b7 100644
--- a/test/extensions/secondorder/hbp/kfra_settings.py
+++ b/test/extensions/secondorder/hbp/kfra_settings.py
@@ -1,8 +1,13 @@
"""Define test cases for KFRA."""
-from test.extensions.secondorder.secondorder_settings import GROUP_CONV_SETTINGS
+from test.extensions.secondorder.secondorder_settings import (
+ GROUP_CONV_SETTINGS,
+ LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS,
+)
-SHARED_NOT_SUPPORTED_SETTINGS = GROUP_CONV_SETTINGS
+SHARED_NOT_SUPPORTED_SETTINGS = (
+ GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS
+)
LOCAL_NOT_SUPPORTED_SETTINGS = []
NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS
diff --git a/test/extensions/secondorder/hbp/test_kflr.py b/test/extensions/secondorder/hbp/test_kflr.py
index 79e46f186..3bb4a900b 100644
--- a/test/extensions/secondorder/hbp/test_kflr.py
+++ b/test/extensions/secondorder/hbp/test_kflr.py
@@ -19,7 +19,7 @@ def test_kflr_not_supported(problem):
"""
problem.set_up()
- with pytest.raises(NotImplementedError):
+ with pytest.raises(RuntimeError):
BackpackExtensions(problem).kflr()
problem.tear_down()
diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py
index f90456134..562be8a1e 100644
--- a/test/extensions/secondorder/secondorder_settings.py
+++ b/test/extensions/secondorder/secondorder_settings.py
@@ -8,14 +8,14 @@
Required entries:
"module_fn" (callable): Contains a model constructed from `torch.nn` layers
"input_fn" (callable): Used for specifying input function
- "target_fn" (callable): Fetches the groundtruth/target classes
+ "target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
Optional entries:
"device" [list(torch.device)]: List of devices to run the test on.
"id_prefix" (str): Prefix to be included in the test name.
- "seed" (int): seed for the random number for torch.rand
+ "seed" (int): seed for the random number for rand
"""
@@ -26,7 +26,7 @@
make_simple_pooling_setting,
)
-import torch
+from torch import device, rand
from torch.nn import (
ELU,
SELU,
@@ -39,12 +39,17 @@
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
+ CrossEntropyLoss,
+ Flatten,
LeakyReLU,
+ Linear,
LogSigmoid,
MaxPool1d,
MaxPool2d,
MaxPool3d,
+ MSELoss,
ReLU,
+ Sequential,
Sigmoid,
Tanh,
)
@@ -56,11 +61,11 @@
###############################################################################
example = {
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(torch.nn.Linear(10, 5)),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 5),
- "device": [torch.device("cpu")],
+ "device": [device("cpu")],
"seed": 0,
"id_prefix": "example",
}
@@ -70,28 +75,22 @@
SECONDORDER_SETTINGS += [
# classification
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), Linear(7, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
"target_fn": lambda: classification_targets((3,), 5),
},
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
},
# Regression
{
- "input_fn": lambda: torch.rand(3, 10),
- "module_fn": lambda: torch.nn.Sequential(
- torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5)
- ),
- "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
+ "input_fn": lambda: rand(3, 10),
+ "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 5)),
},
]
@@ -109,8 +108,8 @@
###############################################################################
# test setting: Pooling Layers #
"""
-Syntax with default parameters:
- - `torch.nn.MaxPoolNd(kernel_size, stride, padding, dilation,
+Syntax with default parameters:
+ - `MaxPoolNd(kernel_size, stride, padding, dilation,
return_indices, ceil_mode)`
"""
###############################################################################
@@ -150,18 +149,18 @@
###############################################################################
# test setting: Convolutional Layers #
"""
-Syntax with default parameters:
- - `torch.nn.ConvNd(in_channels, out_channels,
- kernel_size, stride=1, padding=0, dilation=1,
- groups=1, bias=True, padding_mode='zeros)`
+Syntax with default parameters:
+ - `ConvNd(in_channels, out_channels,
+ kernel_size, stride=1, padding=0, dilation=1,
+ groups=1, bias=True, padding_mode='zeros')`
- - `torch.nn.ConvTransposeNd(in_channels, out_channels,
- kernel_size, stride=1, padding=0, output_padding=0,
+ - `ConvTransposeNd(in_channels, out_channels,
+ kernel_size, stride=1, padding=0, output_padding=0,
groups=1, bias=True, dilation=1, padding_mode='zeros')`
-Note: There are 5 tests added to each `torch.nn.layers`.
-For `torch.nn.ConvTranspose2d` and `torch.nn.ConvTranspose3d`
-only 3 tests are added because they are very memory intensive.
+Note: There are 5 tests added for each layer type.
+For `ConvTranspose2d` and `ConvTranspose3d`
+only 3 tests are added because they are very memory intensive.
"""
###############################################################################
@@ -233,3 +232,70 @@
]
SECONDORDER_SETTINGS += GROUP_CONV_SETTINGS
+
+# linear with additional dimension
+LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS = [
+ # regression
+ {
+ "input_fn": lambda: rand(3, 4, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
+ "target_fn": lambda: regression_targets((3, 8)),
+ "id_prefix": "one-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 5),
+ "module_fn": lambda: Sequential(
+ Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten()
+ ),
+ "loss_function_fn": lambda: MSELoss(reduction="mean"),
+ "target_fn": lambda: regression_targets((3, 16)),
+ "id_prefix": "two-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 3, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: MSELoss(reduction="sum"),
+ "target_fn": lambda: regression_targets((3, 48)),
+ "id_prefix": "three-additional",
+ },
+ # classification
+ {
+ "input_fn": lambda: rand(3, 4, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 8),
+ "id_prefix": "one-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 5),
+ "module_fn": lambda: Sequential(
+ Linear(5, 3), Sigmoid(), Linear(3, 2), Flatten()
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 16),
+ "id_prefix": "two-additional",
+ },
+ {
+ "input_fn": lambda: rand(3, 4, 2, 3, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2), Flatten()),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
+ "target_fn": lambda: classification_targets((3,), 48),
+ "id_prefix": "three-additional",
+ },
+]
+
+SECONDORDER_SETTINGS += LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS
+
+###############################################################################
+# test setting: CrossEntropyLoss #
+###############################################################################
+SECONDORDER_SETTINGS += [
+ {
+ "input_fn": lambda: rand(3, 4, 2, 3, 5),
+ "module_fn": lambda: Sequential(Linear(5, 3), ReLU(), Linear(3, 2)),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
+ "target_fn": lambda: classification_targets((3, 2, 3, 2), 4),
+ "id_prefix": "multi-d-CrossEntropyLoss",
+ },
+]
diff --git a/test/extensions/secondorder/sqrt_ggn/__init__.py b/test/extensions/secondorder/sqrt_ggn/__init__.py
new file mode 100644
index 000000000..6741d2c13
--- /dev/null
+++ b/test/extensions/secondorder/sqrt_ggn/__init__.py
@@ -0,0 +1 @@
+"""Contains tests of ``backpack.extensions.secondorder.sqrt_ggn``."""
diff --git a/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py
new file mode 100644
index 000000000..f6fdc34cd
--- /dev/null
+++ b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py
@@ -0,0 +1,33 @@
+"""Contains test settings for testing SqrtGGN extension."""
+from test.core.derivatives.utils import classification_targets
+from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS
+
+from torch import randint
+from torch.nn import CrossEntropyLoss, Embedding, Flatten, Linear, Sequential
+
+SQRT_GGN_SETTINGS = SECONDORDER_SETTINGS
+
+###############################################################################
+# Embedding #
+###############################################################################
+SQRT_GGN_SETTINGS += [
+ {
+ "input_fn": lambda: randint(0, 5, (6,)),
+ "module_fn": lambda: Sequential(
+ Embedding(5, 3),
+ Linear(3, 4),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((6,), 4),
+ },
+ {
+ "input_fn": lambda: randint(0, 3, (3, 2, 2)),
+ "module_fn": lambda: Sequential(
+ Embedding(3, 2),
+ Flatten(),
+ ),
+ "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
+ "target_fn": lambda: classification_targets((3,), 2 * 2),
+ "seed": 1,
+ },
+]
diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py
new file mode 100644
index 000000000..92c44f152
--- /dev/null
+++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py
@@ -0,0 +1,105 @@
+"""Tests BackPACK's ``SqrtGGNExact`` and ``SqrtGGNMC`` extension."""
+
+from math import isclose
+from test.automated_test import check_sizes_and_values
+from test.extensions.implementation.autograd import AutogradExtensions
+from test.extensions.implementation.backpack import BackpackExtensions
+from test.extensions.problem import ExtensionsTestProblem, make_test_problems
+from test.extensions.secondorder.sqrt_ggn.sqrt_ggn_settings import SQRT_GGN_SETTINGS
+from test.utils.skip_test import skip_large_parameters, skip_subsampling_conflict
+from typing import List, Union
+
+from pytest import fixture, mark
+
+PROBLEMS = make_test_problems(SQRT_GGN_SETTINGS)
+
+SUBSAMPLINGS = [None, [0, 0], [2, 0]]
+SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS]
+
+
+@fixture(params=PROBLEMS, ids=lambda p: p.make_id())
+def problem(request) -> ExtensionsTestProblem:
+ """Set seed, create tested model, loss, data. Finally clean up.
+
+ Args:
+ request (SubRequest): Request for the fixture from a test/fixture function.
+
+ Yields:
+ Test case with deterministically constructed attributes.
+ """
+ case = request.param
+ case.set_up()
+ yield case
+ case.tear_down()
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+def test_ggn_exact(
+ problem: ExtensionsTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Compare exact GGN from BackPACK's matrix square root with autograd.
+
+ Args:
+ problem: Test case with small network whose GGN can be evaluated.
+ subsampling: Indices of active samples. ``None`` uses the full mini-batch.
+ """
+ skip_large_parameters(problem)
+ skip_subsampling_conflict(problem, subsampling)
+
+ autograd_res = AutogradExtensions(problem).ggn(subsampling=subsampling)
+ backpack_res = BackpackExtensions(problem).ggn(subsampling=subsampling)
+
+ check_sizes_and_values(autograd_res, backpack_res)
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+def test_sqrt_ggn_mc_integration(
+ problem: ExtensionsTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Check if MC-approximated GGN matrix square root code executes.
+
+ Note:
+        This test does not check the results for correctness, since such a
+        check requires a large number of MC samples and is therefore expensive.
+        It is performed by ``test_ggn_mc``, which is run less frequently.
+
+ Args:
+ problem: Test case with small network whose GGN can be evaluated.
+ subsampling: Indices of active samples. ``None`` uses the full mini-batch.
+ """
+ skip_large_parameters(problem)
+ skip_subsampling_conflict(problem, subsampling)
+
+ BackpackExtensions(problem).sqrt_ggn_mc(mc_samples=1, subsampling=subsampling)
+
+
+@mark.montecarlo
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+def test_ggn_mc(
+ problem: ExtensionsTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Compare MC-approximated GGN from BackPACK with exact version from autograd.
+
+ Args:
+ problem: Test case with small network whose GGN can be evaluated.
+ subsampling: Indices of active samples. ``None`` uses the full mini-batch.
+ """
+ skip_large_parameters(problem)
+ skip_subsampling_conflict(problem, subsampling)
+
+ autograd_res = AutogradExtensions(problem).ggn(subsampling=subsampling)
+ atol, rtol = 5e-3, 5e-3
+ mc_samples, chunks = 150000, 15
+ backpack_res = BackpackExtensions(problem).ggn_mc(
+ mc_samples, chunks=chunks, subsampling=subsampling
+ )
+
+ # compare normalized entries ∈ [-1; 1] (easier to tune atol)
+ max_val = max(autograd_res.abs().max(), backpack_res.abs().max())
+    # NOTE: The GGN can be exactly zero, e.g. if a ReLU after all parameters zeroes
+    # its input; the ReLU's Jacobian is then zero and cancels the backpropagated GGN.
+ if not isclose(max_val, 0):
+ autograd_res, backpack_res = autograd_res / max_val, backpack_res / max_val
+
+ check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol)
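The exact-GGN comparison above relies on the fact that a matrix square root ``V`` determines the GGN via ``V.T @ V``. A minimal standalone sketch of that reconstruction, with a random factor standing in for BackPACK's square root and shapes chosen only for illustration:

from torch import allclose, manual_seed, rand

manual_seed(0)
V = rand(10, 4)      # assumed square-root factor of shape (N * C, num_params)
ggn = V.T @ V        # reconstructed GGN of shape (num_params, num_params)

assert allclose(ggn, ggn.T)      # symmetric by construction
assert (ggn.diag() >= 0).all()   # diagonal of a Gram matrix is non-negative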
diff --git a/test/extensions/test_hooks.py b/test/extensions/test_hooks.py
index 119afa8be..bc6058f87 100644
--- a/test/extensions/test_hooks.py
+++ b/test/extensions/test_hooks.py
@@ -3,89 +3,194 @@
These tests aim at demonstrating the pitfalls one may run into when using hooks that
iterate over ``module.parameters()``.
"""
-
from test.core.derivatives.utils import classification_targets, get_available_devices
+from typing import Tuple
-import pytest
-import torch
+from pytest import fixture, mark, raises
+from torch import Tensor, manual_seed, rand
+from torch.nn import CrossEntropyLoss, Linear, Module, Sequential
-from backpack import backpack, extend, extensions
+from backpack import backpack, extend
+from backpack.extensions import BatchGrad, DiagGGNExact
+from backpack.extensions.backprop_extension import FAIL_ERROR, BackpropExtension
DEVICES = get_available_devices()
DEVICES_ID = [str(dev) for dev in DEVICES]
+NESTED_SEQUENTIAL = "NESTED_SEQUENTIAL"
+CUSTOM_CONTAINER = "CUSTOM_CONTAINER"
+problem_list = [NESTED_SEQUENTIAL, CUSTOM_CONTAINER]
+
+
+@fixture(params=DEVICES, ids=DEVICES_ID)
+def device(request):
+ """Yields the available device for the test.
+
+ Args:
+ request: pytest request
+
+ Yields:
+ an available device
+ """
+ yield request.param
-def set_up(device):
- """Return extended nested sequential with loss from a forward pass."""
- torch.manual_seed(0)
+
+@fixture(params=problem_list, ids=problem_list)
+def problem(device, request) -> Tuple[Module, Tensor, str]:
+    """Yield an extended model (nested sequential or custom container) and its loss.
+
+ Args:
+ device: available device
+ request: pytest request
+
+ Yields:
+ model, loss and problem_string
+
+ Raises:
+ NotImplementedError: if the problem_string is unknown
+ """
+ problem_string = request.param
+ manual_seed(0)
B = 2
- X = torch.rand(B, 4).to(device)
+ X = rand(B, 4).to(device)
y = classification_targets((B,), 2).to(device)
- model = torch.nn.Sequential(
- torch.nn.Linear(4, 3, bias=False),
- torch.nn.Sequential(
- torch.nn.Linear(3, 2, bias=False),
- ),
- ).to(device)
+ if problem_string == NESTED_SEQUENTIAL:
+ model = Sequential(
+ Linear(4, 3, bias=False),
+ Sequential(
+ Linear(3, 2, bias=False),
+ ),
+ )
+ elif problem_string == CUSTOM_CONTAINER:
+
+ class _MyCustomModule(Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = Linear(4, 3, bias=False)
+ self.linear2 = Linear(3, 2, bias=False)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ return x
+
+ model = _MyCustomModule()
+ else:
+ raise NotImplementedError(
+ f"problem={problem_string} but no test setting for this."
+ )
+
+ model = extend(model.to(device))
+ lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device))
+ loss = lossfunc(model(X), y)
+ yield model, loss, problem_string
+
- model = extend(model)
- lossfunc = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
+@mark.parametrize(
+ "extension", [BatchGrad(), DiagGGNExact()], ids=["BatchGrad", "DiagGGNExact"]
+)
+def test_extension_hook_multiple_parameter_visits(
+ problem, extension: BackpropExtension
+):
+ """Tests whether each parameter is visited exactly once.
- loss = lossfunc(model(X), y)
+    For cases where parameters are visited more than once (e.g. custom containers),
+    it tests that an error is raised.
- return model, loss
+    Furthermore, it tests that first-order extensions run fine in either case,
+    and that second-order extensions raise an error for custom containers.
+ Args:
+ problem: test problem, consisting of model, loss, and problem_string
+ extension: first or second order extension to test
-@pytest.mark.parametrize("device", DEVICES, ids=DEVICES_ID)
-def test_extension_hook_multiple_parameter_visits(device):
- """Extension hooks iterating over parameters may traverse them more than once."""
- model, loss = set_up(device)
+ Raises:
+ NotImplementedError: if the problem_string is unknown
+ """
+ model, loss, problem_string = problem
params_visited = {id(p): 0 for p in model.parameters()}
def count_visits(module):
- """Increase counter in ``params_visited`` for all parameters in ``module``."""
+ """Increase counter in ``params_visited`` for all parameters in ``module``.
+
+ Args:
+ module: the module of which the parameter visits are counted
+ """
for p in module.parameters():
params_visited[id(p)] += 1
- with backpack(extension_hook=count_visits, debug=True):
+ if problem_string == CUSTOM_CONTAINER and extension._fail_mode == FAIL_ERROR:
+ with raises(NotImplementedError):
+ with backpack(extension, extension_hook=count_visits, debug=True):
+ loss.backward()
+ return
+ with backpack(extension, extension_hook=count_visits, debug=True):
loss.backward()
- def check():
- """Raise ``AssertionError`` if a parameter has been visited more than once."""
+ def check_all_parameters_visited_once():
+ """Checks whether all parameters have been visited exactly once.
+
+ Raises:
+ AssertionError: if a parameter hasn't been visited exactly once
+ """
for param_id, visits in params_visited.items():
- if visits == 0:
- raise ValueError(f"Hook never visited param {param_id}")
- elif visits == 1:
- pass
- else:
- raise AssertionError(f"Hook visited param {param_id} {visits} times ")
+ if visits != 1:
+ raise AssertionError(f"Hook visited param {param_id} {visits}≠1 times")
+
+ if problem_string == NESTED_SEQUENTIAL:
+ check_all_parameters_visited_once()
+ elif problem_string == CUSTOM_CONTAINER:
+ with raises(AssertionError):
+ check_all_parameters_visited_once()
+ else:
+ raise NotImplementedError(f"unknown problem_string={problem_string}")
- with pytest.raises(AssertionError):
- check()
+def test_extension_hook_param_before_savefield_exists(problem):
+ """Extension hooks iterating over parameters may get called before BackPACK.
-@pytest.mark.parametrize("device", DEVICES, ids=DEVICES_ID)
-def test_extension_hook_param_before_savefield_exists(device):
- """Extension hooks iterating over parameters may get called before BackPACK."""
- _, loss = set_up(device)
+    In that case, the BackPACK quantities may not have been computed yet, so
+    derived quantities cannot be calculated.
+
+    Nested sequential containers work fine, while custom containers crash.
+
+ Args:
+ problem: problem consisting of model, loss, and problem_string
+
+ Raises:
+ NotImplementedError: if problem_string is unknown
+ """
+ _, loss, problem_string = problem
params_without_grad_batch = []
def check_grad_batch(module):
- """Raise ``AssertionError`` if one parameter misses ``'grad_batch'``."""
+        """Check whether all parameters of the module have a ``grad_batch`` attribute.
+
+ Args:
+ module: the module to check
+
+ Raises:
+ AssertionError: if a parameter does not have grad_batch attribute.
+ """
for p in module.parameters():
if not hasattr(p, "grad_batch"):
params_without_grad_batch.append(id(p))
raise AssertionError(f"Param {id(p)} has no 'grad_batch' attribute")
- # AssertionError is caught inside BackPACK and will raise a RuntimeError
- with pytest.raises(RuntimeError):
- with backpack(
- extensions.BatchGrad(), extension_hook=check_grad_batch, debug=True
- ):
+ if problem_string == NESTED_SEQUENTIAL:
+ with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True):
loss.backward()
- assert len(params_without_grad_batch) > 0
+ assert len(params_without_grad_batch) == 0
+ elif problem_string == CUSTOM_CONTAINER:
+ with raises(AssertionError):
+ with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True):
+ loss.backward()
+ assert len(params_without_grad_batch) > 0
+ else:
+ raise NotImplementedError(f"unknown problem_string={problem_string}")
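The pitfalls exercised by the tests above come from hooks that iterate over ``module.parameters()``, which recurses into submodules and may therefore visit a parameter several times, or before its BackPACK save-field exists. A hedged sketch of a more defensive hook; the ``recurse=False`` and ``hasattr`` guards are conventions chosen here for illustration, not something this PR prescribes:

def make_grad_batch_norm_hook(storage):
    """Return an extension hook that collects per-sample gradient norms."""

    def hook(module):
        for p in module.parameters(recurse=False):  # only the module's own parameters
            if hasattr(p, "grad_batch"):  # the save-field may not exist yet
                storage.append(p.grad_batch.flatten(start_dim=1).norm(dim=1))

    return hook


# usage, mirroring the API exercised in the tests above:
# norms = []
# with backpack(BatchGrad(), extension_hook=make_grad_batch_norm_hook(norms), debug=True):
#     loss.backward()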
diff --git a/test/extensions/utils.py b/test/extensions/utils.py
new file mode 100644
index 000000000..b517a3c8d
--- /dev/null
+++ b/test/extensions/utils.py
@@ -0,0 +1,21 @@
+"""Utility functions for testing BackPACK's extensions."""
+
+from test.extensions.problem import ExtensionsTestProblem
+from typing import List, Union
+
+from pytest import skip
+
+
+def skip_if_subsampling_conflict(
+ problem: ExtensionsTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Skip if some samples in subsampling are not contained in input.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ N = problem.input.shape[0]
+ enough_samples = subsampling is None or N > max(subsampling)
+ if not enough_samples:
+ skip(f"Not enough samples: N={N}, subsampling={subsampling}")
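A worked example of the skip condition above, with a hypothetical batch size of 3:

N = 3
for subsampling in (None, [2, 0], [3]):
    enough_samples = subsampling is None or N > max(subsampling)
    print(subsampling, "->", "run" if enough_samples else "skip")
# prints: None -> run, then [2, 0] -> run, then [3] -> skip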
diff --git a/test/test___init__.py b/test/test___init__.py
index 600b82b0f..16abef1fc 100644
--- a/test/test___init__.py
+++ b/test/test___init__.py
@@ -1,5 +1,6 @@
"""Tests for `backpack.__init__.py`."""
+from contextlib import nullcontext
from test import pytorch_current_memory_usage
from test.core.derivatives.utils import classification_targets, get_available_devices
@@ -12,21 +13,6 @@
DEVICES_ID = [str(dev) for dev in DEVICES]
-# TODO Use contextlib.nullcontext after dropping Python 3.6 support
-class nullcontext:
- """Empty context.
-
- ``contextlib.nullcontext`` is available from Python 3.7 onwards.
- The tests are also executed on Python 3.6.
- """
-
- def __enter__(self):
- pass
-
- def __exit__(self, type, value, traceback):
- pass
-
-
def test_no_io():
"""Check IO is not tracked."""
torch.manual_seed(0)
diff --git a/test/test_batch_first.py b/test/test_batch_first.py
new file mode 100644
index 000000000..ec962dab4
--- /dev/null
+++ b/test/test_batch_first.py
@@ -0,0 +1,25 @@
+"""Tests whether batch axis is always first."""
+from pytest import raises
+
+from backpack.custom_module.permute import Permute
+
+
+def test_permute_batch_axis() -> None:
+    """Verify that ``Permute`` raises a ``ValueError`` if the batch axis is moved."""
+ Permute(0, 1, 2)
+ Permute(0, 2, 1)
+ Permute(0, 2, 3, 1)
+ with raises(ValueError):
+ Permute(1, 0, 2)
+ with raises(ValueError):
+ Permute(2, 0, 1)
+
+ Permute(1, 2, init_transpose=True)
+ Permute(3, 1, init_transpose=True)
+ Permute(2, 1, init_transpose=True)
+ with raises(ValueError):
+ Permute(0, 1, init_transpose=True)
+ with raises(ValueError):
+ Permute(1, 0, init_transpose=True)
+ with raises(ValueError):
+ Permute(2, 0, init_transpose=True)
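The cases above encode a single rule: the batch axis must stay in front. Assuming ``Permute(dims)`` mirrors ``Tensor.permute`` and the ``init_transpose`` form mirrors ``Tensor.transpose`` (an assumption for illustration, not verified against the module's implementation), the accepted and rejected index patterns correspond to:

from torch import rand

x = rand(8, 3, 5)   # batch axis first
x.permute(0, 2, 1)  # keeps axis 0 in front     -> Permute(0, 2, 1) is accepted above
x.permute(1, 0, 2)  # moves the batch axis      -> Permute(1, 0, 2) raises ValueError
x.transpose(1, 2)   # swaps two non-batch axes  -> Permute(1, 2, init_transpose=True) is accepted
x.transpose(0, 1)   # touches the batch axis    -> Permute(0, 1, init_transpose=True) raises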
diff --git a/test/test_retain_graph.py b/test/test_retain_graph.py
new file mode 100644
index 000000000..055f18d14
--- /dev/null
+++ b/test/test_retain_graph.py
@@ -0,0 +1,99 @@
+"""Test autograd functionality like retain_graph."""
+from test.automated_test import check_sizes_and_values
+
+from pytest import raises
+from torch import autograd, manual_seed, ones_like, rand, randint, randn, zeros
+from torch.nn import CrossEntropyLoss, Linear, Module, Sequential
+
+from backpack import backpack, extend
+from backpack.extensions import BatchGrad
+
+
+def test_retain_graph():
+    """Test whether ``retain_graph`` works as expected.
+
+    Performs several forward and backward passes and checks in between
+    whether BackPACK's stored module IO is present or cleared.
+ """
+ manual_seed(0)
+ model = extend(Sequential(Linear(4, 6), Linear(6, 5)))
+ loss_fn = extend(CrossEntropyLoss())
+
+    # after a forward pass, the module IO is not cleared
+ inputs = rand(8, 4)
+ labels = randint(5, (8,))
+ loss = loss_fn(model(inputs), labels)
+ with raises(AssertionError):
+ _check_no_io(model)
+
+    # after a normal backward pass, the module IO should be cleared
+ loss.backward()
+ _check_no_io(model)
+
+    # after a backward pass with retain_graph=True, the module IO is kept
+ loss = loss_fn(model(inputs), labels)
+ with backpack(retain_graph=True):
+ loss.backward(retain_graph=True)
+ with raises(AssertionError):
+ _check_no_io(model)
+
+ # doing several backward passes with retain_graph=True
+ for _ in range(3):
+ with backpack(retain_graph=True):
+ loss.backward(retain_graph=True)
+ with raises(AssertionError):
+ _check_no_io(model)
+
+    # finally, a normal backward pass should clear the module IO again
+ with backpack(BatchGrad()):
+ loss.backward()
+ _check_no_io(model)
+
+
+def _check_no_io(module: Module) -> None:
+ """Checks whether the module is clear of any BackPACK inputs and outputs.
+
+ Args:
+ module: The module to test
+
+ Raises:
+ AssertionError: if the module or any child module has BackPACK inputs or outputs.
+ """
+ for child_module in module.children():
+ _check_no_io(child_module)
+
+ io_strs = ["input0", "output"]
+ if any(hasattr(module, io) for io in io_strs):
+ raise AssertionError(f"IO should be clear, but {module} has one of {io_strs}.")
+
+
+def test_for_loop_replace() -> None:
+ """Application of retain_graph: replace an outer for-loop.
+
+ This test is based on issue #220 opened by Romain3Ch216.
+ It computes per-component individual gradients of a tensor-valued output
+ with a for loop over components, rather than over samples and components.
+ """
+ manual_seed(0)
+ B = 5
+ M = 3
+ h = 2
+
+ x = randn(B, h)
+ fc = extend(Linear(h, M))
+ A = fc(x)
+
+ grad_autograd = zeros(B, M, *fc.weight.shape)
+ for b in range(B):
+ for m in range(M):
+ with backpack(retain_graph=True):
+ grads = autograd.grad(A[b, m], fc.weight, retain_graph=True)
+ grad_autograd[b, m] = grads[0]
+
+ grad_backpack = zeros(B, M, *fc.weight.shape)
+ for i in range(M):
+ with backpack(BatchGrad(), retain_graph=True):
+ A[:, i].backward(ones_like(A[:, i]), retain_graph=True)
+ grad_backpack[:, i] = fc.weight.grad_batch
+
+ check_sizes_and_values(grad_backpack, grad_autograd)
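The payoff of the pattern above, stated with the constants from the test: the naive double loop needs one backward pass per (sample, component) pair, while the BackPACK variant needs one per output component only, since ``BatchGrad`` supplies the per-sample axis via ``grad_batch``.

B, M, h = 5, 3, 2
naive_backward_passes = B * M  # one pass per (sample, component) pair -> 15
backpack_backward_passes = M   # one pass per output component         -> 3
# both gradient tensors above have shape (B, M, *weight.shape) = (5, 3, 3, 2)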
diff --git a/test/test_second_order_warnings.py b/test/test_second_order_warnings.py
index 5b7370552..39591e6ba 100644
--- a/test/test_second_order_warnings.py
+++ b/test/test_second_order_warnings.py
@@ -5,6 +5,8 @@
- using unsupported parameters of the loss
"""
+from test.core.derivatives.utils import classification_targets
+
import pytest
import torch
from torch.nn import CrossEntropyLoss, MSELoss
@@ -29,15 +31,10 @@
]
-def classification_targets(N, num_classes):
- """Create random targets for classes 0, ..., `num_classes - 1`."""
- return torch.randint(size=(N,), low=0, high=num_classes)
-
-
def dummy_cross_entropy(N=5):
y_pred = torch.rand((N, 2))
y_pred.requires_grad = True
- y = classification_targets(N, 2)
+ y = classification_targets((N,), 2)
loss_module = extend(CrossEntropyLoss())
return loss_module(y_pred, y)
diff --git a/test/test_simple_resnet.py b/test/test_simple_resnet.py
deleted file mode 100644
index 27e3a9380..000000000
--- a/test/test_simple_resnet.py
+++ /dev/null
@@ -1,144 +0,0 @@
-"""An example to check if BackPACK' first-order extensions are working for ResNets."""
-
-from test.core.derivatives.utils import classification_targets
-
-import torch
-
-from backpack import backpack, extend, extensions
-
-from .automated_test import check_sizes, check_values
-
-
-def autograd_individual_gradients(X, y, model, loss_func):
- """Individual gradients via for loop with automatic differentiation.
-
- Args:
- X (torch.Tensor): Mini-batch of shape `(N, *)`
- y (torch.Tensor: Labels for `X`
- model (torch.nn.Module): Model for forward pass
- loss_func (torch.nn.Module): Loss function for model prediction
-
- Returns:
- [torch.Tensor]: Individual gradients for samples in the mini-batch
- with respect to the model parameters. Arranged in the same order
- as `model.parameters()`.
- """
- N = X.shape[0]
- reduction_factor = _get_reduction_factor(X, y, model, loss_func)
-
- individual_gradients = [
- torch.zeros(N, *p.shape).to(X.device) for p in model.parameters()
- ]
-
- for n in range(N):
- x_n = X[n].unsqueeze(0)
- y_n = y[n].unsqueeze(0)
-
- f_n = model(x_n)
- l_n = loss_func(f_n, y_n) / reduction_factor
-
- g_n = torch.autograd.grad(l_n, model.parameters())
-
- for idx, g in enumerate(g_n):
- individual_gradients[idx][n] = g
-
- return individual_gradients
-
-
-def _get_reduction_factor(X, y, model, loss_func):
- """Return reduction factor of loss function."""
- N = X.shape[0]
-
- x_0 = X[0].unsqueeze(0)
- y_0 = y[0].unsqueeze(0)
-
- x_0_repeated = x_0.repeat([N if pos == 0 else 1 for pos, _ in enumerate(X.shape)])
- y_0_repeated = y_0.repeat([N if pos == 0 else 1 for pos, _ in enumerate(y.shape)])
-
- individual_loss = loss_func(model(x_0), y_0)
- reduced_loss = loss_func(model(x_0_repeated), y_0_repeated)
-
- return (N * individual_loss / reduced_loss).item()
-
-
-def backpack_individual_gradients(X, y, model, loss_func):
- """Individual gradients with BackPACK.
-
- Args:
- X (torch.Tensor): Mini-batch of shape `(N, *)`
- y (torch.Tensor: Labels for `X`
- model (torch.nn.Module): Model for forward pass
- loss_func (torch.nn.Module): Loss function for model prediction
-
- Returns:
- [torch.Tensor]: Individual gradients for samples in the mini-batch
- with respect to the model parameters. Arranged in the same order
- as `model.parameters()`.
- """
- model = extend(model)
- loss_func = extend(loss_func)
-
- loss = loss_func(model(X), y)
-
- with backpack(extensions.BatchGrad()):
- loss.backward()
-
- individual_gradients = [p.grad_batch for p in model.parameters()]
-
- return individual_gradients
-
-
-class Identity(torch.nn.Module):
- """Identity operation."""
-
- def forward(self, input):
- return input
-
-
-class Parallel(torch.nn.Sequential):
- """Feed input to multiple modules, sum the result.
-
- |-----|
- | -> | f_1 | -> |
- | |-----| |
- | |
- | |-----| |
- x ->| -> | f_2 | -> + -> f₁(x) + f₂(x) + ...
- | |-----| |
- | |
- | |-----| |
- | -> | ... | -> |
- |-----|
-
- """
-
- def forward(self, input):
- """Process input with all modules, sum the output."""
- for idx, module in enumerate(self.children()):
- if idx == 0:
- output = module(input)
- else:
- output = output + module(input)
-
- return output
-
-
-def test_individual_gradients_simple_resnet():
- """Individual gradients for a simple ResNet with autodiff and BackPACK."""
-
- # batch size, feature dimension
- N, D = 2, 5
- # classification
- C = 3
-
- X = torch.rand(N, D)
- y = classification_targets((N,), num_classes=C)
-
- model = Parallel(Identity(), torch.nn.Linear(D, D, bias=True))
- loss_func = torch.nn.CrossEntropyLoss(reduction="sum")
-
- result_autograd = autograd_individual_gradients(X, y, model, loss_func)
- result_backpack = backpack_individual_gradients(X, y, model, loss_func)
-
- check_sizes(result_autograd, result_backpack)
- check_values(result_autograd, result_backpack)
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
index e69de29bb..40711349c 100644
--- a/test/utils/__init__.py
+++ b/test/utils/__init__.py
@@ -0,0 +1,27 @@
+"""Helper functions for tests."""
+
+from typing import List
+
+
+def chunk_sizes(total_size: int, num_chunks: int) -> List[int]:
+ """Return list containing the sizes of chunks.
+
+ Args:
+ total_size: Total computation work.
+        num_chunks: Targeted number of chunks. The result may contain one more
+            chunk than requested (for a remainder) or fewer, if ``total_size``
+            is smaller than ``num_chunks``.
+
+    Returns:
+        List of chunk sizes that sum to ``total_size``.
+ """
+ chunk_size = max(total_size // num_chunks, 1)
+
+ if chunk_size == 1:
+ sizes = total_size * [chunk_size]
+ else:
+ equal, rest = divmod(total_size, chunk_size)
+ sizes = equal * [chunk_size]
+
+ if rest != 0:
+ sizes.append(rest)
+
+ return sizes
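A worked example of the helper above, with values matching the MC test (which splits ``mc_samples=150000`` into ``chunks=15``) plus two edge cases:

# assuming the helper above is importable as test.utils.chunk_sizes
assert chunk_sizes(150_000, 15) == 15 * [10_000]  # evenly divisible work
assert chunk_sizes(10, 3) == [3, 3, 3, 1]         # remainder becomes a final smaller chunk
assert chunk_sizes(2, 5) == [1, 1]                # fewer work items than requested chunks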
diff --git a/test/utils/evaluation_mode.py b/test/utils/evaluation_mode.py
new file mode 100644
index 000000000..f4e57be77
--- /dev/null
+++ b/test/utils/evaluation_mode.py
@@ -0,0 +1,40 @@
+"""Tools for initializing in evaluation mode, especially BatchNorm."""
+from typing import Union
+
+from torch import rand_like
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, Module
+
+
+def initialize_training_false_recursive(module: Module) -> Module:
+ """Initializes a module recursively in evaluation mode.
+
+ Args:
+ module: the module to initialize
+
+ Returns:
+ initialized module in evaluation mode
+ """
+ if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
+ initialize_batch_norm_eval(module)
+ else:
+ for module_child in module.children():
+ initialize_training_false_recursive(module_child)
+ return module.train(False)
+
+
+def initialize_batch_norm_eval(
+ module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
+) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]:
+ """Initializes a BatchNorm module in evaluation mode.
+
+ Args:
+ module: BatchNorm module
+
+ Returns:
+ the initialized BatchNorm module in evaluation mode
+ """
+ module.running_mean = rand_like(module.running_mean)
+ module.running_var = rand_like(module.running_var)
+ module.weight.data = rand_like(module.weight)
+ module.bias.data = rand_like(module.bias)
+ return module.train(False)
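A minimal usage sketch of the helpers above (the module choice is arbitrary, for illustration only):

from torch import rand
from torch.nn import BatchNorm1d, Linear, Sequential

# assuming the function above is importable from test.utils.evaluation_mode
model = Sequential(Linear(4, 3), BatchNorm1d(3))
model = initialize_training_false_recursive(model)

assert not model.training                                      # container is in evaluation mode
assert not any(child.training for child in model.children())   # so are all submodules
model(rand(8, 4))  # eval-mode BatchNorm uses the randomized running statistics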
diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py
new file mode 100644
index 000000000..4f282662f
--- /dev/null
+++ b/test/utils/skip_test.py
@@ -0,0 +1,71 @@
+"""Skip specific tests."""
+
+from test.core.derivatives.problem import DerivativesTestProblem
+from test.extensions.problem import ExtensionsTestProblem
+from typing import List, Union
+
+from pytest import skip
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+
+from backpack.utils import ADAPTIVE_AVG_POOL_BUG
+
+
+def skip_adaptive_avg_pool3d_cuda(request) -> None:
+ """Skips test if AdaptiveAvgPool3d and cuda.
+
+ Args:
+ request: problem request
+ """
+ if ADAPTIVE_AVG_POOL_BUG:
+ if all(
+ string in request.node.callspec.id
+ for string in ["AdaptiveAvgPool3d", "cuda"]
+ ):
+ skip(
+ "Skip test because AdaptiveAvgPool3d does not work on cuda. "
+ "Should be fixed in torch 2.0."
+ )
+
+
+def skip_batch_norm_train_mode_with_subsampling(
+ problem: DerivativesTestProblem, subsampling: Union[List[int], None]
+) -> None:
+ """Skip BatchNorm in train mode when sub-sampling is turned on.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ if isinstance(problem.module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
+        if problem.module.training and subsampling is not None:
+ skip(f"Skipping BatchNorm in train mode with sub-sampling: {subsampling}")
+
+
+def skip_subsampling_conflict(
+ problem: Union[DerivativesTestProblem, ExtensionsTestProblem],
+ subsampling: Union[List[int], None],
+) -> None:
+ """Skip if some samples in subsampling are not contained in input.
+
+ Args:
+ problem: Test case.
+ subsampling: Indices of active samples.
+ """
+ N = problem.get_batch_size()
+ enough_samples = subsampling is None or N > max(subsampling)
+ if not enough_samples:
+ skip("Not enough samples.")
+
+
+def skip_large_parameters(
+ problem: ExtensionsTestProblem, max_num_params: int = 1000
+) -> None:
+ """Skip architectures with too many parameters.
+
+ Args:
+ problem: Test case.
+ max_num_params: Maximum number of model parameters. Default: ``1000``.
+ """
+ num_params = sum(p.numel() for p in problem.trainable_parameters())
+ if num_params > max_num_params:
+ skip(f"Model has too many parameters: {num_params} > {max_num_params}")
diff --git a/test/utils/test_subsampling.py b/test/utils/test_subsampling.py
new file mode 100644
index 000000000..750c06b1f
--- /dev/null
+++ b/test/utils/test_subsampling.py
@@ -0,0 +1,22 @@
+"""Contains tests of sub-sampling functionality."""
+
+from torch import allclose, manual_seed, rand
+
+from backpack.utils.subsampling import subsample
+
+
+def test_subsample():
+ """Test slicing operations for sub-sampling a tensor's batch axis."""
+ manual_seed(0)
+ tensor = rand(3, 4, 5, 6)
+
+ # leave tensor untouched when `subsampling = None`
+ assert id(subsample(tensor)) == id(tensor)
+ assert allclose(subsample(tensor), tensor)
+
+ # slice along correct dimension
+ idx = [2, 0]
+ assert allclose(subsample(tensor, dim=0, subsampling=idx), tensor[idx])
+ assert allclose(subsample(tensor, dim=1, subsampling=idx), tensor[:, idx])
+ assert allclose(subsample(tensor, dim=2, subsampling=idx), tensor[:, :, idx])
+ assert allclose(subsample(tensor, dim=3, subsampling=idx), tensor[:, :, :, idx])