Skip to content

Commit

Permalink
Deployed 88e99b4 with MkDocs version: 1.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lebrice committed Jun 27, 2024
1 parent 975d52a commit 7b5806c
Show file tree
Hide file tree
Showing 20 changed files with 44 additions and 259 deletions.
4 changes: 1 addition & 3 deletions 404.html
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="/ResearchTemplate/examples/SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="/ResearchTemplate/examples/examples/">Examples</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="/ResearchTemplate/examples/jax/">Jax</a>
<li class="toctree-l1"><a class="reference internal" href="/ResearchTemplate/examples/jax/">Using Jax</a>
</li>
</ul>
<ul>
Expand Down
4 changes: 1 addition & 3 deletions SUMMARY/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../examples/SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../examples/examples/">Examples</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../examples/jax/">Jax</a>
<li class="toctree-l1"><a class="reference internal" href="../examples/jax/">Using Jax</a>
</li>
</ul>
<ul>
Expand Down
4 changes: 1 addition & 3 deletions contributing/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../examples/SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../examples/examples/">Examples</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../examples/jax/">Jax</a>
<li class="toctree-l1"><a class="reference internal" href="../examples/jax/">Using Jax</a>
</li>
</ul>
<ul>
Expand Down
175 changes: 0 additions & 175 deletions examples/SUMMARY/index.html

This file was deleted.

29 changes: 4 additions & 25 deletions examples/examples/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,15 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1 current"><a class="reference internal current" href="#">Examples</a>
<ul class="current">
<li class="toctree-l2"><a class="reference internal" href="#simple-run">Simple run</a>
</li>
<li class="toctree-l2"><a class="reference internal" href="#running-a-hyper-parameter-sweep-on-a-slurm-cluster">Running a Hyper-Parameter sweep on a SLURM cluster</a>
</li>
<li class="toctree-l2"><a class="reference internal" href="#using-jax">Using Jax</a>
<ul>
<li class="toctree-l3"><a class="reference internal" href="#example-algorithm-that-uses-jax">Example Algorithm that uses Jax</a>
</li>
<li class="toctree-l3"><a class="reference internal" href="#example-datamodule-that-uses-jax">Example datamodule that uses Jax</a>
</li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../jax/">Jax</a>
<li class="toctree-l1"><a class="reference internal" href="../jax/">Using Jax</a>
</li>
</ul>
<ul>
Expand Down Expand Up @@ -152,23 +142,12 @@ <h2 id="simple-run">Simple run<a class="headerlink" href="#simple-run" title="Pe
<h2 id="running-a-hyper-parameter-sweep-on-a-slurm-cluster">Running a Hyper-Parameter sweep on a SLURM cluster<a class="headerlink" href="#running-a-hyper-parameter-sweep-on-a-slurm-cluster" title="Permanent link">#</a></h2>
<div class="highlight"><pre><span></span><code>python<span class="w"> </span>project/main.py<span class="w"> </span><span class="nv">experiment</span><span class="o">=</span>cluster_sweep_example
</code></pre></div>
<h2 id="using-jax">Using Jax<a class="headerlink" href="#using-jax" title="Permanent link">#</a></h2>
<p>You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.</p>
<p>How does this work?
Well, we use <a href="https://www.github.com/lebrice/torch_jax_interop">torch-jax-interop</a>, another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details.</p>
<h3 id="example-algorithm-that-uses-jax">Example Algorithm that uses Jax<a class="headerlink" href="#example-algorithm-that-uses-jax" title="Permanent link">#</a></h3>
<p>You can use Jax for your training step, but not the entire training loop (since that is handled by Lightning).
There are a few good reasons why you should let Lightning handle the training loop, most notably the fact that it handles all the logging, checkpointing, and other stuff that you'd lose if you swapped out the entire training framework for something based on Jax.</p>
<p>In this <a href="https://www.github.com/mila-iqia/ResearchTemplate/tree/master/project/algorithms/jax_algo.py">example Jax algorithm</a>,
a Neural network written in Jax (using <a href="https://flax.readthedocs.io/en/latest/">flax</a>) is wrapped using the <code>torch_jax_interop.JaxFunction</code>, so that its parameters are learnable. The parameters are saved on the LightningModule as nn.Parameters (which use the same underlying memory as the jax arrays). In this example, the loss function is written in PyTorch, while the network forward and backward passes are written in Jax.</p>
<h3 id="example-datamodule-that-uses-jax">Example datamodule that uses Jax<a class="headerlink" href="#example-datamodule-that-uses-jax" title="Permanent link">#</a></h3>
<p>(todo)</p>

</div>
</div><footer>
<div class="rst-footer-buttons" role="navigation" aria-label="Footer Navigation">
<a href="../SUMMARY/" class="btn btn-neutral float-left" title="SUMMARY"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../jax/" class="btn btn-neutral float-right" title="Jax">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="../../reference/project/utils/types/" class="btn btn-neutral float-left" title="Types"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a href="../jax/" class="btn btn-neutral float-right" title="Using Jax">Next <span class="icon icon-circle-arrow-right"></span></a>
</div>

<hr/>
Expand All @@ -195,7 +174,7 @@ <h3 id="example-datamodule-that-uses-jax">Example datamodule that uses Jax<a cla
</span>


<span><a href="../SUMMARY/" style="color: #fcfcfc">&laquo; Previous</a></span>
<span><a href="../../reference/project/utils/types/" style="color: #fcfcfc">&laquo; Previous</a></span>


<span><a href="../jax/" style="color: #fcfcfc">Next &raquo;</a></span>
Expand Down
28 changes: 21 additions & 7 deletions examples/jax/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><link rel="canonical" href="https://mila-iqia.github.io/ResearchTemplate/examples/jax/" />
<link rel="shortcut icon" href="../../img/favicon.ico" />
<title>Jax - Research Project Template</title>
<title>Using Jax - Research Project Template</title>
<link rel="stylesheet" href="../../css/theme.css" />
<link rel="stylesheet" href="../../css/theme_extra.css" />
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.8.0/styles/github.min.css" />
<link href="../../assets/_mkdocstrings.css" rel="stylesheet" />

<script>
// Current page data
var mkdocs_page_name = "Jax";
var mkdocs_page_name = "Using Jax";
var mkdocs_page_input_path = "examples/jax.md";
var mkdocs_page_url = "/ResearchTemplate/examples/jax/";
</script>
Expand Down Expand Up @@ -73,11 +73,15 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../examples/">Examples</a>
</li>
<li class="toctree-l1 current"><a class="reference internal current" href="#">Jax</a>
<li class="toctree-l1 current"><a class="reference internal current" href="#">Using Jax</a>
<ul class="current">
<li class="toctree-l2"><a class="reference internal" href="#example-algorithm-that-uses-jax">Example Algorithm that uses Jax</a>
</li>
<li class="toctree-l2"><a class="reference internal" href="#example-datamodule-that-uses-jax">Example datamodule that uses Jax</a>
</li>
</ul>
</li>
</ul>
<ul>
Expand Down Expand Up @@ -111,7 +115,7 @@
<ul class="wy-breadcrumbs">
<li><a href="../.." class="icon icon-home" aria-label="Docs"></a></li>
<li class="breadcrumb-item">Examples</li>
<li class="breadcrumb-item active">Jax</li>
<li class="breadcrumb-item active">Using Jax</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
Expand All @@ -120,7 +124,17 @@
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div class="section" itemprop="articleBody">


<h1 id="using-jax">Using Jax<a class="headerlink" href="#using-jax" title="Permanent link">#</a></h1>
<p>You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.</p>
<p>How does this work?
Well, we use <a href="https://www.github.com/lebrice/torch_jax_interop">torch-jax-interop</a>, another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details.</p>
<h2 id="example-algorithm-that-uses-jax">Example Algorithm that uses Jax<a class="headerlink" href="#example-algorithm-that-uses-jax" title="Permanent link">#</a></h2>
<p>You can use Jax for your training step, but not the entire training loop (since that is handled by Lightning).
There are a few good reasons why you should let Lightning handle the training loop, most notably the fact that it handles all the logging, checkpointing, and other stuff that you'd lose if you swapped out the entire training framework for something based on Jax.</p>
<p>In this <a href="https://www.github.com/mila-iqia/ResearchTemplate/tree/master/project/algorithms/jax_algo.py">example Jax algorithm</a>,
a Neural network written in Jax (using <a href="https://flax.readthedocs.io/en/latest/">flax</a>) is wrapped using the <code>torch_jax_interop.JaxFunction</code>, so that its parameters are learnable. The parameters are saved on the LightningModule as nn.Parameters (which use the same underlying memory as the jax arrays). In this example, the loss function is written in PyTorch, while the network forward and backward passes are written in Jax.</p>
<h2 id="example-datamodule-that-uses-jax">Example datamodule that uses Jax<a class="headerlink" href="#example-datamodule-that-uses-jax" title="Permanent link">#</a></h2>
<p>(todo)</p>

</div>
</div><footer>
Expand Down
4 changes: 1 addition & 3 deletions getting_started/install/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,9 @@
</ul>
<p class="caption"><span class="caption-text">Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../examples/SUMMARY/">SUMMARY</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/examples/">Examples</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/jax/">Jax</a>
<li class="toctree-l1"><a class="reference internal" href="../../examples/jax/">Using Jax</a>
</li>
</ul>
<ul>
Expand Down
Loading

0 comments on commit 7b5806c

Please sign in to comment.