Skip to content

Commit

Permalink
Update Penzai docs for V2 API and treescope dependency.
Browse files Browse the repository at this point in the history
Updates the Penzai documentation to refer to the new V2 API as stable, and removes
outdated documentation of the old V1 API. Also adds pointers to the new Treescope
package.

PiperOrigin-RevId: 653991096
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 26, 2024
1 parent 5bdc18f commit 84a732b
Show file tree
Hide file tree
Showing 35 changed files with 2,152 additions and 25,223 deletions.
112 changes: 58 additions & 54 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,65 +22,63 @@ each useable independently:

* A superpowered interactive Python pretty-printer:

* `penzai.treescope` (``pz.ts``): A drop-in replacement for the ordinary
IPython/Colab renderer. It's designed to help understand Penzai models and
other deeply-nested JAX pytrees, with built-in support for visualizing
arbitrary-dimensional NDArrays.
* [Treescope](https://treescope.readthedocs.io/en/stable/) (`pz.ts`):
A drop-in replacement for the ordinary IPython/Colab renderer, originally
a part of Penzai but now available as a standalone package. It's designed to
help understand Penzai models and other deeply-nested JAX pytrees, with
built-in support for visualizing arbitrary-dimensional NDArrays.

* A set of JAX tree and array manipulation utilities:

* `penzai.core.selectors` (``pz.select``): A pytree swiss-army-knife,
generalizing JAX's ``.at[...].set(...)`` syntax to arbitrary type-driven
* `penzai.core.selectors` (`pz.select`): A pytree swiss-army-knife,
generalizing JAX's `.at[...].set(...)` syntax to arbitrary type-driven
pytree traversals, and making it easy to do complex rewrites or
on-the-fly patching of Penzai models and other data structures.

* `penzai.core.named_axes` (``pz.nx``): A lightweight named axis system which
* `penzai.core.named_axes` (`pz.nx`): A lightweight named axis system which
lifts ordinary JAX functions to vectorize over named axes, and allows you to
seamlessly switch between named and positional programming styles without
having to learn a new array API.

* A declarative combinator-based neural network library, where models are
represented as easy-to-modify data structures:

* `penzai.deprecated.v1.nn` (``pz.nn``): An alternative to other neural network libraries like
* `penzai.nn` (`pz.nn`): An alternative to other neural network libraries like
Flax, Haiku, Keras, or Equinox, which exposes the full structure of your model's
forward pass in the model pytree. This means you can see everything your model
forward pass using declarative combinators. Like Equinox, models are
represented as JAX PyTrees, which means you can see everything your model
does by pretty printing it, and inject new runtime logic with `jax.tree_util`.
Like Equinox, there's no magic: models are just callable pytrees under the
hood.
However, `penzai.nn` models may also contain mutable variables at the leaves
of the tree, allowing them to keep track of mutable state and parameter
sharing.

* `penzai.deprecated.v1.data_effects` (``pz.de``): An opt-in system for side arguments, random
numbers, and state variables that is built on pytree traversal and puts you
in control, without getting in the way of writing or using your model.
* A modular implementation of common Transformer architectures, to support
research into interpretability, model surgery, and training dynamics:

* **(NEW)** `penzai.experimental.v2`: An improved version of `penzai.deprecated.v1.nn` with
less boilerplate, including first-class support for mutable state and
parameter sharing.

* An implementation of the Gemma open-weights model using modular components and
named axes, built to enable interpretability and model surgery research.

* **(NEW)** The V2 version also supports Llama, Mistral, and GPT-NeoX / Pythia
models!
* `penzai.models.transformer`: A reference Transformer implementation that
can load the pre-trained weights for the Gemma, Llama, Mistral, and
GPT-NeoX / Pythia architectures. Built using modular components and named
axes, to simplify complex model-manipulation workflows.

Documentation on Penzai can be found at
[https://penzai.readthedocs.io](https://penzai.readthedocs.io).

> [!IMPORTANT]
> Penzai currently has two versions of its neural network API: the original
> "V1" API, and a new "V2" API located in `penzai.experimental.v2`.
>
> The V2 API aims to be simpler and more flexible, by introducing first-class
> support for mutable state and parameter sharing, and removing unnecessary
> boilerplate. It also includes a more flexible transformer implementation with
> support for more pretrained model variants. You can read about the
> differences between the two APIs in the
> Penzai 0.2 includes a number of breaking changes to the neural network API.
> These changes are intended to simplify common workflows
> by introducing first-class support for mutable state and parameter sharing
> and removing unnecessary boilerplate. You can read about the differences
> between the old "V1" API and the current "V2" API in the
> ["Changes in the V2 API"][v2_differences] overview.
>
> We plan to stabilize the V2 API and move it out of experimental in release
> ``0.2.0``, replacing the V1 API. If you wish to keep the V1 behavior, we
> recommend pinning the ``0.1.x`` release series (e.g. ``penzai>=0.1,<0.2``)
> to avoid breaking changes.
> If you are currently using the V1 API and have not yet converted to the V2
> system, you can instead keep the old behavior by importing from the
> `penzai.deprecated.v1` submodule, e.g. ::
>
> ```python
> from penzai.deprecated.v1 import pz
> from penzai.deprecated.v1.example_models import simple_mlp
> ```
[v2_differences]: https://penzai.readthedocs.io/en/stable/guides/v2_differences.html
Expand All @@ -100,32 +98,29 @@ and import it using

```python
import penzai
from penzai.deprecated.v1 import pz
from penzai import pz
```

(`penzai.pz` is an *alias namespace*, which makes it easier to reference
common Penzai objects.)

When working in an Colab or IPython notebook, we recommend also configuring
Penzai as the default pretty printer, and enabling some utilities for
interactive use:
Treescope (Penzai's companion pretty-printer) as the default pretty printer, and
enabling some utilities for interactive use:

```python
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

# Optional: enables automatic array visualization
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)
```

Here's how you could initialize and visualize a simple neural network:

```python
from penzai.deprecated.v1.example_models import simple_mlp
mlp = pz.nn.initialize_parameters(
simple_mlp.MLP.from_config([8, 32, 32, 8]),
jax.random.key(42),
from penzai.models import simple_mlp
mlp = simple_mlp.MLP.from_config(
name="mlp",
init_base_rng=jax.random.key(0),
feature_sizes=[8, 32, 32, 8]
)

# Models and arrays are visualized automatically when you output them from a
Expand All @@ -137,23 +132,32 @@ Here's how you could capture and extract the activations after the elementwise
nonlinearities:

```python
mlp_with_captured_activations = pz.de.CollectingSideOutputs.handling(
@pz.pytree_dataclass
class AppendIntermediate(pz.nn.Layer):
saved: pz.StateVariable[list[Any]]
def __call__(self, x: Any, **unused_side_inputs) -> Any:
self.saved.value = self.saved.value + [x]
return x

var = pz.StateVariable(value=[], label="my_intermediates")

# Make a copy of the model that saves its activations:
saving_model = (
pz.select(mlp)
.at_instances_of(pz.nn.Elementwise)
.insert_after(pz.de.TellIntermediate())
.insert_after(AppendIntermediate(var))
)

output, intermediates = mlp_with_captured_activations(
pz.nx.ones({"features": 8})
)
output = saving_model(pz.nx.ones({"features": 8}))
intermediates = var.value
```

To learn more about how to build and manipulate neural networks with Penzai,
we recommend starting with the "How to Think in Penzai" tutorial ([V1 API version][how_to_think_1], [V2 API version][how_to_think_2]), or one
of the other tutorials in the [Penzai documentation][].

[how_to_think_1]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[how_to_think_2]: https://penzai.readthedocs.io/en/stable/notebooks/v2_how_to_think_in_penzai.html
[how_to_think_2]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[Penzai documentation]: https://penzai.readthedocs.io


Expand Down
13 changes: 6 additions & 7 deletions docs/_autogen_root.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,18 @@
:recursive:

penzai.core
penzai.nn
penzai.models
penzai.toolshed

penzai.deprecated.v1.core
penzai.deprecated.v1.nn
penzai.deprecated.v1.data_effects
penzai.deprecated.v1.example_models
penzai.toolshed

penzai.experimental.v2.core
penzai.experimental.v2.nn
penzai.experimental.v2.models
penzai.experimental.v2.toolshed
penzai.deprecated.v1.toolshed

.. toctree::
:hidden:

notebooks/induction_heads_2B
notebooks/v2_induction_heads_2B
_include/_glue_figures
3 changes: 1 addition & 2 deletions docs/_include/_glue_figures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@
"outputs": [],
"source": [
"pz.ts.register_as_default(streaming=False)\n",
"pz.ts.register_autovisualize_magic()\n",
"pz.enable_interactive_context()"
"pz.ts.register_autovisualize_magic()"
]
},
{
Expand Down
Loading

0 comments on commit 84a732b

Please sign in to comment.