Skip to content

Commit

Permalink
Stabilize and move penzai.experimental.v2 to the top level.
Browse files Browse the repository at this point in the history
This change moves the V2 implementation of Penzai's neural net system
out of experimental and into the main penzai namespace, along with
the corresponding toolshed modules and example models.

For compatibility, the penzai.experimental.v2 aliases will remain, and
will redirect to the stable locations (implemented using dynamic stub
modules).

PiperOrigin-RevId: 656536074
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 26, 2024
1 parent 9ba0c87 commit 5bdc18f
Show file tree
Hide file tree
Showing 83 changed files with 363 additions and 423 deletions.
6 changes: 3 additions & 3 deletions docs/_include/_glue_figures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import myst_nb\n",
"\n",
"import penzai\n",
"from penzai.experimental.v2 import pz\n",
"from penzai import pz\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -53,8 +53,8 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.models.transformer.variants import gemma\n",
"from penzai.experimental.v2.models.transformer.variants import llamalike_common"
"from penzai.models.transformer.variants import gemma\n",
"from penzai.models.transformer.variants import llamalike_common"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/api/penzai.experimental.v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ You can read about the V2 design in the guide :doc:`"How to Think in Penzai (v2

To use the V2 API, we suggest importing the `pz` alias namespace from `penzai.experimental.v2.pz`: ::

from penzai.experimental.v2 import pz
from penzai import pz

The rest of this page lists the main components used in the V2 API.

Expand Down
6 changes: 3 additions & 3 deletions docs/guides/howto_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ This notebook is a guide to accomplishing a variety of tasks with Penzai, using

For this guide, we assume you have imported the experimental V2 version of the `pz` alias namespace:
```
from penzai.experimental.v2 import pz
from penzai import pz
```

## Visualization
Expand Down Expand Up @@ -223,7 +223,7 @@ Penzai's Gemma implementation includes a conversion utility that converts the ["
```python
import kagglehub
import orbax.checkpoint
from penzai.experimental.v2.models.transformer import variants
from penzai.models.transformer import variants

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
Expand All @@ -239,7 +239,7 @@ Penzai also includes re-implementations of the architectures used by [Llama](htt

```python
import transformers
from penzai.experimental.v2.models.transformer import variants
from penzai.models.transformer import variants

# To load a Llama model:
hf_model = transformers.LlamaForCausalLM.from_pretrained(...)
Expand Down
6 changes: 3 additions & 3 deletions docs/guides/v2_differences.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ from penzai.deprecated.v1.example_models import simple_mlp
import penzai.toolshed

# New
from penzai.experimental.v2 import pz
from penzai.experimental.v2.models import simple_mlp
from penzai import pz
from penzai.models import simple_mlp
import penzai.experimental.v2.toolshed
```

Expand Down Expand Up @@ -381,7 +381,7 @@ model = gemma.model_core.GemmaTransformer.from_pretrained(flax_params_dict)
# (model is an instance of GemmaTransformer)

# New
from penzai.experimental.v2.models.transformer import variants
from penzai.models.transformer import variants
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
# (model is an instance of TransformerLM)
```
Expand Down
6 changes: 3 additions & 3 deletions notebooks/v2_how_to_think_in_penzai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@
"outputs": [],
"source": [
"import penzai\n",
"from penzai.experimental.v2 import pz\n",
"from penzai.experimental.v2.models import simple_mlp"
"from penzai import pz\n",
"from penzai.models import simple_mlp"
]
},
{
Expand Down Expand Up @@ -1277,7 +1277,7 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.toolshed import basic_training\n",
"from penzai.toolshed import basic_training\n",
"import optax"
]
},
Expand Down
4 changes: 2 additions & 2 deletions notebooks/v2_induction_heads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@
"outputs": [],
"source": [
"import penzai\n",
"from penzai.experimental.v2 import pz\n",
"from penzai import pz\n",
"\n",
"from penzai.experimental.v2.models import transformer"
"from penzai.models import transformer"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/v2_induction_heads_2B.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@
"outputs": [],
"source": [
"import penzai\n",
"from penzai.experimental.v2 import pz\n",
"from penzai import pz\n",
"\n",
"from penzai.experimental.v2.models import transformer"
"from penzai.models import transformer"
]
},
{
Expand Down
16 changes: 8 additions & 8 deletions notebooks/v2_jitting_and_sharding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"outputs": [],
"source": [
"import penzai\n",
"from penzai.experimental.v2 import pz"
"from penzai import pz"
]
},
{
Expand All @@ -146,9 +146,9 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.models import transformer\n",
"from penzai.experimental.v2.models import simple_mlp\n",
"from penzai.experimental.v2.toolshed import basic_training"
"from penzai.models import transformer\n",
"from penzai.models import simple_mlp\n",
"from penzai.toolshed import basic_training"
]
},
{
Expand Down Expand Up @@ -610,7 +610,7 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.toolshed import jit_wrapper"
"from penzai.toolshed import jit_wrapper"
]
},
{
Expand Down Expand Up @@ -1181,7 +1181,7 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.toolshed import sharding_util"
"from penzai.toolshed import sharding_util"
]
},
{
Expand Down Expand Up @@ -1336,8 +1336,8 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.toolshed import sharding_util\n",
"from penzai.experimental.v2.models.transformer.variants import llamalike_common"
"from penzai.toolshed import sharding_util\n",
"from penzai.models.transformer.variants import llamalike_common"
]
},
{
Expand Down
10 changes: 5 additions & 5 deletions notebooks/v2_lora_from_scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"outputs": [],
"source": [
"import penzai\n",
"from penzai.experimental.v2 import pz"
"from penzai import pz"
]
},
{
Expand All @@ -164,11 +164,11 @@
},
"outputs": [],
"source": [
"from penzai.experimental.v2.models import transformer\n",
"from penzai.experimental.v2.models import simple_mlp\n",
"from penzai.models import transformer\n",
"from penzai.models import simple_mlp\n",
"from penzai.toolshed import token_visualization\n",
"from penzai.experimental.v2.toolshed import basic_training\n",
"from penzai.experimental.v2.toolshed import jit_wrapper"
"from penzai.toolshed import basic_training\n",
"from penzai.toolshed import jit_wrapper"
]
},
{
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import jax
from penzai.core import struct
from penzai.experimental.v2.core import variables
from penzai.core import variables


@struct.pytree_dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@

import jax
import jax.numpy as jnp
from penzai.core import auto_order_types
from penzai.core import selectors
from penzai.core import struct
from penzai.experimental.v2.core import auto_order_types


T = TypeVar("T")
Expand Down
2 changes: 1 addition & 1 deletion penzai/deprecated/v1/pz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@

from . import de
from . import nn
from penzai.experimental.v2.pz import ts # pylint: disable=g-bad-import-order
from penzai.pz import ts # pylint: disable=g-bad-import-order
2 changes: 1 addition & 1 deletion penzai/deprecated/v1/toolshed/sharding_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import jax
from penzai.deprecated.v1 import pz
from penzai.experimental.v2.toolshed import sharding_util as sharding_util_v2
from penzai.toolshed import sharding_util as sharding_util_v2

PyTreeOfArrays = Any
PyTreeOfNamedArrays = Any
Expand Down
Loading

0 comments on commit 5bdc18f

Please sign in to comment.