-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge release 1.5.0 from f-dangel/development (#247)
Merge release 1.5.0 from f-dangel/development
- Loading branch information
Showing
49 changed files
with
1,213 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# This workflows will upload a Python Package using Twine when a release is created | ||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries | ||
|
||
name: Upload Python Package | ||
|
||
on: | ||
release: | ||
types: [created] | ||
|
||
jobs: | ||
deploy: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.x" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel twine | ||
- name: Build and publish | ||
env: | ||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} | ||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} | ||
run: | | ||
python setup.py sdist bdist_wheel | ||
twine upload dist/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Contains derivatives of N-dimensional padding.""" | ||
|
||
from typing import List, Sequence, Tuple | ||
|
||
from torch import Tensor | ||
|
||
from backpack.core.derivatives.basederivatives import BaseDerivatives | ||
from backpack.custom_module.pad import Pad | ||
|
||
|
||
class PadDerivatives(BaseDerivatives): | ||
"""Derivatives of Pad.""" | ||
|
||
def _jac_t_mat_prod( | ||
self, | ||
module: Pad, | ||
g_inp: Tuple[Tensor], | ||
g_out: Tuple[Tensor], | ||
mat: Tensor, | ||
subsampling: List[int] = None, | ||
) -> Tensor: | ||
self.no_pad_batch_axis(module) | ||
|
||
return self.unpad(mat, module.pad, module.mode, module.value) | ||
|
||
@staticmethod | ||
def no_pad_batch_axis(module: Pad): | ||
"""Assert the batch axis is not padded. | ||
Args: | ||
module: Pad module. | ||
Raises: | ||
ValueError: If the batch axis is padded. | ||
""" | ||
num_pad_axes = len(module.pad) // 2 | ||
if num_pad_axes == module.input0.dim(): | ||
raise ValueError("Padding the batch axis is not supported.") | ||
|
||
@staticmethod | ||
def unpad(tensor: Tensor, pad: Sequence[int], mode: str, value: float) -> Tensor: | ||
"""Remove padding from a tensor. | ||
Undoes the operation ``torch.nn.functional.pad``. | ||
Args: | ||
pad: Tuple of even length specifying the padding. | ||
mode: Padding mode. | ||
value: Fill value for constant padding. | ||
Returns: | ||
Unpadded tensor. | ||
Raises: | ||
NotImplementedError: If padding mode is not constant. | ||
""" | ||
if mode != "constant": | ||
raise NotImplementedError("Only mode='constant' is supported.") | ||
|
||
pad_axes = len(pad) // 2 | ||
unaffected = tensor.dim() - pad_axes | ||
|
||
no_slice = [slice(None) for _ in range(unaffected)] | ||
unpad_slice = [] | ||
|
||
for affected in range(pad_axes): | ||
pad_start, pad_end = pad[2 * affected : 2 * affected + 2] | ||
dim = tensor.shape[tensor.dim() - 1 - affected] | ||
unpad_slice.insert(0, slice(pad_start, dim - pad_end)) | ||
|
||
return tensor[no_slice + unpad_slice] | ||
|
||
def hessian_is_zero(self, module: Pad) -> bool: # noqa: D102 | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Contains derivatives of slicing operation.""" | ||
from typing import List, Tuple | ||
|
||
from torch import Tensor, zeros | ||
|
||
from backpack.core.derivatives.basederivatives import BaseDerivatives | ||
from backpack.custom_module.slicing import Slicing | ||
from backpack.utils.subsampling import subsample | ||
|
||
|
||
class SlicingDerivatives(BaseDerivatives): | ||
"""Derivatives of Slicing.""" | ||
|
||
def _jac_t_mat_prod( | ||
self, | ||
module: Slicing, | ||
g_inp: Tuple[Tensor], | ||
g_out: Tuple[Tensor], | ||
mat: Tensor, | ||
subsampling: List[int] = None, | ||
) -> Tensor: | ||
self.no_slice_batch_axis(module) | ||
|
||
input0 = module.input0 | ||
result_shape = (mat.shape[0], *subsample(input0, subsampling=subsampling).shape) | ||
result = zeros(result_shape, device=input0.device, dtype=input0.dtype) | ||
result[(slice(None),) + module.slice_info] = mat | ||
|
||
return result | ||
|
||
@staticmethod | ||
def no_slice_batch_axis(module: Slicing): | ||
"""Assert the batch axis is not sliced. | ||
Args: | ||
module: Slicing module. | ||
Raises: | ||
ValueError: If the batch axis is sliced. | ||
""" | ||
slice_batch_axis = module.slice_info[0] | ||
|
||
if slice_batch_axis != slice(None): | ||
raise ValueError("Slicing the batch axis is not supported.") | ||
|
||
def hessian_is_zero(self, module: Slicing) -> bool: # noqa: D102 | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Module version of ``torch.nn.functional.pad``.""" | ||
|
||
from typing import Sequence | ||
|
||
from torch import Tensor | ||
from torch.nn import Module | ||
from torch.nn.functional import pad | ||
|
||
|
||
class Pad(Module): | ||
"""Module version of ``torch.nn.functional.pad`` (N-dimensional padding).""" | ||
|
||
def __init__(self, pad: Sequence[int], mode: str = "constant", value: float = 0.0): | ||
"""Store padding hyperparameters. | ||
See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html. | ||
Args: | ||
pad: Tuple of even length specifying the padding. | ||
mode: Padding mode. Default ``'constant'``. | ||
value: Fill value for constant padding. Default ``0.0``. | ||
""" | ||
super().__init__() | ||
self.pad = pad | ||
self.mode = mode | ||
self.value = value | ||
|
||
def forward(self, input: Tensor) -> Tensor: | ||
"""Pad the input tensor. | ||
Args: | ||
input: Input tensor. | ||
Returns: | ||
Padded input tensor. | ||
""" | ||
return pad(input, self.pad, mode=self.mode, value=self.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Custom module to perform tensor slicing.""" | ||
|
||
from typing import Tuple, Union | ||
|
||
from torch import Tensor | ||
from torch.nn import Module | ||
|
||
|
||
class Slicing(Module): | ||
"""Module that slices a tensor.""" | ||
|
||
def __init__(self, slice_info: Tuple[Union[slice, int]]): | ||
"""Store the slicing object. | ||
Args: | ||
slice_info: Argument that is passed to the slicing operator in the | ||
forward pass. | ||
""" | ||
super().__init__() | ||
self.slice_info = slice_info | ||
|
||
def forward(self, input: Tensor) -> Tensor: | ||
"""Slice the input tensor. | ||
Args: | ||
input: Input tensor. | ||
Returns: | ||
Sliced input tensor. | ||
""" | ||
return input[self.slice_info] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Contains ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Pad`` module.""" | ||
|
||
from backpack.core.derivatives.pad import PadDerivatives | ||
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule | ||
|
||
|
||
class DiagGGNPad(DiagGGNBaseModule): | ||
"""``DiagGGN{Exact, MC}`` extension for ``backpack.custom_modules.pad.Pad``.""" | ||
|
||
def __init__(self): | ||
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" | ||
super().__init__(PadDerivatives()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Holds ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Slicing`` module.""" | ||
|
||
from backpack.core.derivatives.slicing import SlicingDerivatives | ||
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule | ||
|
||
|
||
class DiagGGNSlicing(DiagGGNBaseModule): | ||
"""``DiagGGN{Exact, MC}`` for ``backpack.custom_modules.slicing.Slicing``.""" | ||
|
||
def __init__(self): | ||
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" | ||
super().__init__(SlicingDerivatives()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Contains ``DiagH`` extension for BackPACK's custom ``Pad`` module.""" | ||
|
||
from backpack.core.derivatives.pad import PadDerivatives | ||
from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule | ||
|
||
|
||
class DiagHPad(DiagHBaseModule): | ||
"""``DiagH`` extension for ``backpack.custom_modules.pad.Pad``.""" | ||
|
||
def __init__(self): | ||
"""Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" | ||
super().__init__(PadDerivatives()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Contains ``DiagH`` extension for BackPACK's custom ``Slicing`` module.""" | ||
|
||
from backpack.core.derivatives.slicing import SlicingDerivatives | ||
from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule | ||
|
||
|
||
class DiagHSlicing(DiagHBaseModule): | ||
"""``DiagH`` extension for ``backpack.custom_modules.slicing.Slicing``.""" | ||
|
||
def __init__(self): | ||
"""Pass derivatives for ``backpack.custom_modules.slicing.Slicing`` module.""" | ||
super().__init__(SlicingDerivatives()) |
Oops, something went wrong.