Skip to content

Commit

Permalink
Merge branch 'master' into development
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jun 26, 2023
2 parents d2304e5 + ffa6068 commit 67ae8fe
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 86 deletions.
6 changes: 5 additions & 1 deletion backpack/utils/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,24 @@ def get_mnist_dataloader(batch_size: int = 64, shuffle: bool = True) -> DataLoad


def load_one_batch_mnist(
batch_size: int = 64, shuffle: bool = True
batch_size: int = 64, shuffle: bool = True, flat: bool = False
) -> Tuple[Tensor, Tensor]:
"""Return a single mini-batch (inputs, labels) from MNIST.
Args:
batch_size: Mini-batch size. Default: ``64``.
shuffle: Randomly shuffle the data. Default: ``True``.
flat: Flatten chanel and returns a matrix ``[batch_size x 784]``
Returns:
A single batch (inputs, labels) from MNIST.
"""
dataloader = get_mnist_dataloader(batch_size, shuffle)
X, y = next(iter(dataloader))

if flat:
X = X.reshape(X.shape[0], -1)

return X, y


Expand Down
76 changes: 38 additions & 38 deletions docs/examples.html

Large diffs are not rendered by default.

36 changes: 20 additions & 16 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">

"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">load_mnist_data</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>


<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_mnist_data</span><span class="p">()</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
Expand All @@ -127,10 +127,11 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">
and the variance with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">load_mnist_data</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span><span class="p">,</span> <span class="n">Variance</span></span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">Variance</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_mnist_data</span><span class="p">()</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
Expand All @@ -150,10 +151,11 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">
and the second moment with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">load_mnist_data</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span><span class="p">,</span> <span class="n">SumGradSquared</span></span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">SumGradSquared</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_mnist_data</span><span class="p">()</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
Expand All @@ -173,10 +175,11 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">
and the diagonal of the Gauss-Newton with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">load_mnist_data</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span><span class="p">,</span> <span class="n">DiagGGNExact</span></span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">DiagGGNExact</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_mnist_data</span><span class="p">()</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
Expand All @@ -196,10 +199,11 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">
and KFAC with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">load_mnist_data</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span><span class="p">,</span> <span class="n">KFAC</span></span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">KFAC</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_mnist_data</span><span class="p">()</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>
Expand Down Expand Up @@ -292,14 +296,14 @@ <h1 style="margin-top:auto; margin-bottom:auto; display:block">
<hr />

<p><strong>Install with</strong></p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip install backpack-for-pytorch
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip install backpack-for-pytorch
</code></pre></div></div>

<hr />

<p>If you use BackPACK in your research, please cite <float style="float:right"><a href="/assets/dangel2020backpack.bib">download bibtex</a></float></p>

<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@inproceedings{dangel2020backpack,
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@inproceedings{dangel2020backpack,
title = {BackPACK: Packing more into Backprop},
author = {Felix Dangel and Frederik Kunstner and Philipp Hennig},
booktitle = {International Conference on Learning Representations},
Expand Down
1 change: 1 addition & 0 deletions docs_src/CNAME
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
backpack.pt
40 changes: 23 additions & 17 deletions docs_src/README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
**Building the web version**

Requirements: [Jekyll](https://jekyllrb.com/docs/installation/) and [Sphinx](https://www.sphinx-doc.org/en/1.8/usage/installation.html)
Requirements: [Jekyll](https://jekyllrb.com/docs/installation/) and [Sphinx](https://www.sphinx-doc.org/en/1.8/usage/installation.html)
and installing the jekyll dependencies (`bundle install` in `docs_src/splash`)

Full build to output results in `../docs`
```
bash buildweb.sh
```
- Full build to output results in `../docs`
```
bash buildweb.sh
```

Local build of the Jekyll splash page
```
cd splash
bundle exec jekyll server
```
and go to `localhost:4000/backpack`
- Local build of the Jekyll splash page
```
cd splash
bundle exec jekyll server
```
and go to `localhost:4000/backpack`

Note: The code examples on backpack.pt are defined with HTML tags in
`splash/_includes/code-samples.html`.
There are no python source file to generate them.
Test manually by copy-pasting from the resulting page.

Local build of the documentation
```
cd rtd
make
```
and open `/docs_src/rtd_output/index.html`
- Local build of the documentation
```
cd rtd
make
```
and open `/docs_src/rtd_output/index.html`



1 change: 1 addition & 0 deletions docs_src/buildweb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ cd splash
bundle exec jekyll build -d "../../docs"
cd ..
touch ../docs/.nojekyll
cp CNAME ../docs/CNAME
Loading

0 comments on commit 67ae8fe

Please sign in to comment.