Merge release 1.5.0 from f-dangel/development (#247)
f-dangel authored Feb 15, 2022
2 parents 1da7e53 + 6127649 commit 0ab9421
Showing 49 changed files with 1,213 additions and 52 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,30 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: "3.x"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
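For reference, a minimal sketch of the same release steps run locally; it assumes ``setuptools``, ``wheel``, and ``twine`` are installed and that ``TWINE_USERNAME``/``TWINE_PASSWORD`` are exported in the environment.

from glob import glob
from subprocess import run

# Hypothetical local dry run of the workflow's build-and-publish steps.
run(["python", "setup.py", "sdist", "bdist_wheel"], check=True)  # build artifacts into dist/
run(["twine", "upload", *glob("dist/*")], check=True)  # upload all built distributions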
74 changes: 74 additions & 0 deletions backpack/core/derivatives/pad.py
@@ -0,0 +1,74 @@
"""Contains derivatives of N-dimensional padding."""

from typing import List, Sequence, Tuple

from torch import Tensor

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.custom_module.pad import Pad


class PadDerivatives(BaseDerivatives):
"""Derivatives of Pad."""

def _jac_t_mat_prod(
self,
module: Pad,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
self.no_pad_batch_axis(module)

return self.unpad(mat, module.pad, module.mode, module.value)

@staticmethod
def no_pad_batch_axis(module: Pad):
"""Assert the batch axis is not padded.
Args:
module: Pad module.
Raises:
ValueError: If the batch axis is padded.
"""
num_pad_axes = len(module.pad) // 2
if num_pad_axes == module.input0.dim():
raise ValueError("Padding the batch axis is not supported.")

@staticmethod
def unpad(tensor: Tensor, pad: Sequence[int], mode: str, value: float) -> Tensor:
"""Remove padding from a tensor.
Undoes the operation ``torch.nn.functional.pad``.
Args:
pad: Tuple of even length specifying the padding.
mode: Padding mode.
value: Fill value for constant padding.
Returns:
Unpadded tensor.
Raises:
NotImplementedError: If padding mode is not constant.
"""
if mode != "constant":
raise NotImplementedError("Only mode='constant' is supported.")

pad_axes = len(pad) // 2
unaffected = tensor.dim() - pad_axes

no_slice = [slice(None) for _ in range(unaffected)]
unpad_slice = []

for affected in range(pad_axes):
pad_start, pad_end = pad[2 * affected : 2 * affected + 2]
dim = tensor.shape[tensor.dim() - 1 - affected]
unpad_slice.insert(0, slice(pad_start, dim - pad_end))

return tensor[no_slice + unpad_slice]

def hessian_is_zero(self, module: Pad) -> bool: # noqa: D102
return True
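A quick check that ``unpad`` undoes ``torch.nn.functional.pad`` (a minimal sketch; the tensor shape and padding values are arbitrary):

from torch import allclose, rand
from torch.nn.functional import pad

from backpack.core.derivatives.pad import PadDerivatives

x = rand(2, 3, 4, 5)  # e.g. a (batch, channel, H, W) tensor
padding = (1, 2, 0, 3)  # pad last dim by (1, 2) and second-to-last by (0, 3)
padded = pad(x, padding, mode="constant", value=0.0)

# ``unpad`` strips the padding again, recovering the original tensor.
assert allclose(PadDerivatives.unpad(padded, padding, "constant", 0.0), x)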
47 changes: 47 additions & 0 deletions backpack/core/derivatives/slicing.py
@@ -0,0 +1,47 @@
"""Contains derivatives of slicing operation."""
from typing import List, Tuple

from torch import Tensor, zeros

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.custom_module.slicing import Slicing
from backpack.utils.subsampling import subsample


class SlicingDerivatives(BaseDerivatives):
"""Derivatives of Slicing."""

def _jac_t_mat_prod(
self,
module: Slicing,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
self.no_slice_batch_axis(module)

input0 = module.input0
result_shape = (mat.shape[0], *subsample(input0, subsampling=subsampling).shape)
result = zeros(result_shape, device=input0.device, dtype=input0.dtype)
result[(slice(None),) + module.slice_info] = mat

return result

@staticmethod
def no_slice_batch_axis(module: Slicing):
"""Assert the batch axis is not sliced.
Args:
module: Slicing module.
Raises:
ValueError: If the batch axis is sliced.
"""
slice_batch_axis = module.slice_info[0]

if slice_batch_axis != slice(None):
raise ValueError("Slicing the batch axis is not supported.")

def hessian_is_zero(self, module: Slicing) -> bool: # noqa: D102
return True
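The scatter in ``_jac_t_mat_prod`` can be illustrated standalone (a sketch with made-up shapes that mirrors the code above rather than calling it, since ``module.input0`` is only populated during a BackPACK backward pass):

from torch import rand, zeros

slice_info = (slice(None), slice(0, 2))  # keep batch axis, take first two features
x = rand(3, 5)
v = rand(7, *x[slice_info].shape)  # 7 vectors living in the sliced output space

# The transposed Jacobian of x -> x[slice_info] embeds each vector back into
# a zero tensor of the input's shape; entries outside the slice get zero.
result = zeros(7, *x.shape)
result[(slice(None),) + slice_info] = v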
4 changes: 2 additions & 2 deletions backpack/core/derivatives/tanh.py
@@ -20,7 +20,7 @@ def df(
         subsampling: List[int] = None,
     ) -> Tensor:
         output = subsample(module.output, subsampling=subsampling)
-        return 1.0 - output ** 2
+        return 1.0 - output**2

     def d2f(self, module, g_inp, g_out):
-        return -2.0 * module.output * (1.0 - module.output ** 2)
+        return -2.0 * module.output * (1.0 - module.output**2)
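Both formulas in this hunk can be verified against autograd (a quick sketch):

from torch import allclose, rand, tanh
from torch.autograd import grad

x = rand(10, requires_grad=True)
output = tanh(x)
(dx,) = grad(output.sum(), x, create_graph=True)
assert allclose(dx, 1.0 - output**2)  # first derivative of tanh

(d2x,) = grad(dx.sum(), x)
assert allclose(d2x, -2.0 * output * (1.0 - output**2))  # second derivative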
37 changes: 37 additions & 0 deletions backpack/custom_module/pad.py
@@ -0,0 +1,37 @@
"""Module version of ``torch.nn.functional.pad``."""

from typing import Sequence

from torch import Tensor
from torch.nn import Module
from torch.nn.functional import pad


class Pad(Module):
"""Module version of ``torch.nn.functional.pad`` (N-dimensional padding)."""

def __init__(self, pad: Sequence[int], mode: str = "constant", value: float = 0.0):
"""Store padding hyperparameters.
See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.
Args:
pad: Tuple of even length specifying the padding.
mode: Padding mode. Default ``'constant'``.
value: Fill value for constant padding. Default ``0.0``.
"""
super().__init__()
self.pad = pad
self.mode = mode
self.value = value

def forward(self, input: Tensor) -> Tensor:
"""Pad the input tensor.
Args:
input: Input tensor.
Returns:
Padded input tensor.
"""
return pad(input, self.pad, mode=self.mode, value=self.value)
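A minimal usage sketch for the new module (shapes are arbitrary):

from torch import rand

from backpack.custom_module.pad import Pad

module = Pad((1, 1, 2, 0), mode="constant", value=0.0)
x = rand(4, 3, 8, 8)  # (batch, channel, H, W)
print(module(x).shape)  # torch.Size([4, 3, 10, 10])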
31 changes: 31 additions & 0 deletions backpack/custom_module/slicing.py
@@ -0,0 +1,31 @@
"""Custom module to perform tensor slicing."""

from typing import Tuple, Union

from torch import Tensor
from torch.nn import Module


class Slicing(Module):
"""Module that slices a tensor."""

def __init__(self, slice_info: Tuple[Union[slice, int]]):
"""Store the slicing object.
Args:
slice_info: Argument that is passed to the slicing operator in the
forward pass.
"""
super().__init__()
self.slice_info = slice_info

def forward(self, input: Tensor) -> Tensor:
"""Slice the input tensor.
Args:
input: Input tensor.
Returns:
Sliced input tensor.
"""
return input[self.slice_info]
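A minimal usage sketch (the batch axis must stay untouched for the derivatives above to apply):

from torch import rand

from backpack.custom_module.slicing import Slicing

module = Slicing((slice(None), slice(1, 3)))  # equivalent to x[:, 1:3]
x = rand(4, 5)
print(module(x).shape)  # torch.Size([4, 2])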
2 changes: 1 addition & 1 deletion backpack/extensions/firstorder/batch_l2_grad/linear.py
@@ -49,4 +49,4 @@ def weight(
             X = module.input0.flatten(start_dim=1, end_dim=-2)
             return einsum("nmi,nmj,nki,nkj->n", dE_dY, X, dE_dY, X)
         else:
-            return einsum("ni,nj->n", g_out[0] ** 2, module.input0 ** 2)
+            return einsum("ni,nj->n", g_out[0] ** 2, module.input0**2)
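The else-branch exploits that a linear layer's per-sample weight gradient is the outer product of output gradient and input, so its squared norm factorizes. A sketch with made-up shapes (``g`` and ``x`` stand in for ``g_out[0]`` and ``module.input0``):

from torch import allclose, einsum, rand

g, x = rand(4, 3), rand(4, 5)  # per-sample output gradients and inputs
per_sample = einsum("ni,nj->nij", g, x)  # explicit per-sample weight gradients
norms = per_sample.flatten(start_dim=1).pow(2).sum(1)

# ||g_n x_n^T||^2 = ||g_n||^2 * ||x_n||^2, which the einsum computes directly.
assert allclose(norms, einsum("ni,nj->n", g**2, x**2))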
2 changes: 1 addition & 1 deletion backpack/extensions/firstorder/sum_grad_squared/linear.py
@@ -27,4 +27,4 @@ def weight(self, ext, module, g_inp, g_out, backproped):
             X = module.input0.flatten(start_dim=1, end_dim=-2)
             return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X)
         else:
-            return einsum("ni,nj->ij", g_out[0] ** 2, module.input0 ** 2)
+            return einsum("ni,nj->ij", g_out[0] ** 2, module.input0**2)
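Analogously, the element-wise sum of squared per-sample gradients factorizes without materializing the outer products (same made-up shapes as in the previous sketch):

from torch import allclose, einsum, rand

g, x = rand(4, 3), rand(4, 5)
explicit = einsum("ni,nj->nij", g, x).pow(2).sum(0)  # sum_n of (g_n x_n^T)**2

assert allclose(explicit, einsum("ni,nj->ij", g**2, x**2))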
8 changes: 8 additions & 0 deletions backpack/extensions/secondorder/diag_ggn/__init__.py
@@ -48,8 +48,10 @@
 )

 from backpack.custom_module.branching import SumModule
+from backpack.custom_module.pad import Pad
 from backpack.custom_module.permute import Permute
 from backpack.custom_module.scale_module import ScaleModule
+from backpack.custom_module.slicing import Slicing
 from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
 from backpack.extensions.secondorder.hbp import LossHessianStrategy

@@ -69,10 +71,12 @@
     flatten,
     linear,
     losses,
+    pad,
     padding,
     permute,
     pooling,
     rnn,
+    slicing,
 )


@@ -143,6 +147,8 @@ def __init__(self, loss_hessian_strategy: str, savefield: str):
             BatchNorm2d: batchnorm_nd.DiagGGNBatchNormNd(),
             BatchNorm3d: batchnorm_nd.DiagGGNBatchNormNd(),
             Embedding: embedding.DiagGGNEmbedding(),
+            Pad: pad.DiagGGNPad(),
+            Slicing: slicing.DiagGGNSlicing(),
         },
     )

@@ -266,6 +272,8 @@ def __init__(self, loss_hessian_strategy: str, savefield: str):
             BatchNorm2d: batchnorm_nd.BatchDiagGGNBatchNormNd(),
             BatchNorm3d: batchnorm_nd.BatchDiagGGNBatchNormNd(),
             Embedding: embedding.BatchDiagGGNEmbedding(),
+            Pad: pad.DiagGGNPad(),
+            Slicing: slicing.DiagGGNSlicing(),
         },
     )
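With these mappings in place, ``DiagGGNExact`` should backpropagate through models containing the custom modules. An end-to-end sketch (model, shapes, and data are made up):

from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.custom_module.pad import Pad
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

model = extend(
    Sequential(Pad((1, 1)), Slicing((slice(None), slice(0, 8))), Linear(8, 3))
)
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(rand(4, 6)), randint(3, (4,)))
with backpack(DiagGGNExact()):
    loss.backward()
print(model[2].weight.diag_ggn_exact.shape)  # torch.Size([3, 8])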

12 changes: 12 additions & 0 deletions backpack/extensions/secondorder/diag_ggn/pad.py
@@ -0,0 +1,12 @@
"""Contains ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Pad`` module."""

from backpack.core.derivatives.pad import PadDerivatives
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule


class DiagGGNPad(DiagGGNBaseModule):
"""``DiagGGN{Exact, MC}`` extension for ``backpack.custom_modules.pad.Pad``."""

def __init__(self):
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module."""
super().__init__(PadDerivatives())
12 changes: 12 additions & 0 deletions backpack/extensions/secondorder/diag_ggn/slicing.py
@@ -0,0 +1,12 @@
"""Holds ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Slicing`` module."""

from backpack.core.derivatives.slicing import SlicingDerivatives
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule


class DiagGGNSlicing(DiagGGNBaseModule):
"""``DiagGGN{Exact, MC}`` for ``backpack.custom_modules.slicing.Slicing``."""

def __init__(self):
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module."""
super().__init__(SlicingDerivatives())
8 changes: 8 additions & 0 deletions backpack/extensions/secondorder/diag_hessian/__init__.py
@@ -31,6 +31,8 @@
     ZeroPad2d,
 )

+from backpack.custom_module.pad import Pad
+from backpack.custom_module.slicing import Slicing
 from backpack.extensions.secondorder.base import SecondOrderBackpropExtension

 from . import (
@@ -45,8 +47,10 @@
     flatten,
     linear,
     losses,
+    pad,
     padding,
     pooling,
+    slicing,
 )


@@ -92,6 +96,8 @@ def __init__(self):
             LogSigmoid: activations.DiagHLogSigmoid(),
             ELU: activations.DiagHELU(),
             SELU: activations.DiagHSELU(),
+            Pad: pad.DiagHPad(),
+            Slicing: slicing.DiagHSlicing(),
         },
     )

@@ -139,5 +145,7 @@ def __init__(self):
             LogSigmoid: activations.DiagHLogSigmoid(),
             ELU: activations.DiagHELU(),
             SELU: activations.DiagHSELU(),
+            Pad: pad.DiagHPad(),
+            Slicing: slicing.DiagHSlicing(),
         },
     )
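The Hessian diagonal works the same way once the mappings above are registered; a compact sketch mirroring the ``DiagGGNExact`` example (model and shapes are again made up):

from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.custom_module.pad import Pad
from backpack.extensions import DiagHessian

model = extend(Sequential(Pad((1, 1)), Linear(8, 3)))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(rand(4, 6)), randint(3, (4,)))
with backpack(DiagHessian()):
    loss.backward()
print(model[1].weight.diag_h.shape)  # torch.Size([3, 8])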
12 changes: 12 additions & 0 deletions backpack/extensions/secondorder/diag_hessian/pad.py
@@ -0,0 +1,12 @@
"""Contains ``DiagH`` extension for BackPACK's custom ``Pad`` module."""

from backpack.core.derivatives.pad import PadDerivatives
from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule


class DiagHPad(DiagHBaseModule):
"""``DiagH`` extension for ``backpack.custom_modules.pad.Pad``."""

def __init__(self):
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module."""
super().__init__(PadDerivatives())
12 changes: 12 additions & 0 deletions backpack/extensions/secondorder/diag_hessian/slicing.py
@@ -0,0 +1,12 @@
"""Contains ``DiagH`` extension for BackPACK's custom ``Slicing`` module."""

from backpack.core.derivatives.slicing import SlicingDerivatives
from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule


class DiagHSlicing(DiagHBaseModule):
"""``DiagH`` extension for ``backpack.custom_modules.slicing.Slicing``."""

def __init__(self):
"""Pass derivatives for ``backpack.custom_modules.slicing.Slicing`` module."""
super().__init__(SlicingDerivatives())
