From 5bdc18f331b11dbc49ed71cd59671340cb0176f6 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Fri, 26 Jul 2024 14:50:45 -0700 Subject: [PATCH] Stabilize and move penzai.experimental.v2 to the top level. This change moves the V2 implementation of Penzai's neural net system out of experimental and into the main penzai namespace, along with the corresponding toolshed modules and example models. For compatibility, the penzai.experimental.v2 aliases will remain, and will redirect to the stable locations (implemented using dynamic stub modules). PiperOrigin-RevId: 656536074 --- docs/_include/_glue_figures.ipynb | 6 +- docs/api/penzai.experimental.v2.rst | 2 +- docs/guides/howto_reference.md | 6 +- docs/guides/v2_differences.md | 6 +- notebooks/v2_how_to_think_in_penzai.ipynb | 6 +- notebooks/v2_induction_heads.ipynb | 4 +- notebooks/v2_induction_heads_2B.ipynb | 4 +- notebooks/v2_jitting_and_sharding.ipynb | 16 +- notebooks/v2_lora_from_scratch.ipynb | 10 +- .../v2 => }/core/auto_order_types.py | 0 .../v2 => }/core/random_stream.py | 2 +- .../{experimental/v2 => }/core/variables.py | 2 +- penzai/deprecated/v1/pz/__init__.py | 2 +- .../deprecated/v1/toolshed/sharding_util.py | 2 +- penzai/experimental/v2/__init__.py | 219 ++++++++++++++++++ penzai/experimental/v2/pz/__init__.py | 91 -------- penzai/experimental/v2/pz/nn.py | 95 -------- penzai/experimental/v2/pz/ts.py | 54 ----- .../v2 => }/models/simple_mlp.py | 2 +- .../v2 => }/models/transformer/__init__.py | 0 .../v2 => }/models/transformer/model_parts.py | 2 +- .../models/transformer/sampling_mode.py | 4 +- .../transformer/simple_decoding_loop.py | 4 +- .../models/transformer/variants/__init__.py | 0 .../models/transformer/variants/gemma.py | 6 +- .../models/transformer/variants/gpt_neox.py | 4 +- .../models/transformer/variants/llama.py | 4 +- .../transformer/variants/llamalike_common.py | 4 +- .../models/transformer/variants/mistral.py | 4 +- .../nn/_treescope_handlers/layer_handler.py | 4 +- penzai/{experimental/v2 => }/nn/attention.py | 6 +- penzai/{experimental/v2 => }/nn/basic_ops.py | 2 +- .../{experimental/v2 => }/nn/combinators.py | 2 +- penzai/{experimental/v2 => }/nn/dropout.py | 4 +- penzai/{experimental/v2 => }/nn/embeddings.py | 6 +- penzai/{experimental/v2 => }/nn/grouping.py | 2 +- penzai/{experimental/v2 => }/nn/layer.py | 5 +- .../{experimental/v2 => }/nn/layer_stack.py | 4 +- .../v2 => }/nn/linear_and_affine.py | 6 +- penzai/{experimental/v2 => }/nn/parameters.py | 2 +- .../v2 => }/nn/standardization.py | 6 +- penzai/pz/__init__.py | 15 +- penzai/pz/nn.py | 22 +- .../v2 => }/toolshed/basic_training.py | 2 +- .../toolshed/gradient_checkpointing.py | 2 +- .../v2 => }/toolshed/isolate_submodel.py | 2 +- .../v2 => }/toolshed/jit_wrapper.py | 2 +- penzai/{experimental/v2 => }/toolshed/lora.py | 2 +- .../v2 => }/toolshed/model_rewiring.py | 2 +- .../v2 => }/toolshed/save_intermediates.py | 2 +- .../v2 => }/toolshed/sharding_util.py | 2 +- .../v2 => }/toolshed/unflaxify.py | 2 +- .../auto_order_types_test.py | 2 +- tests/{ => core}/misc_util_test.py | 0 tests/{ => core}/named_axes_test.py | 0 tests/{ => core}/partitioning_test.py | 0 tests/{ => core}/selectors_test.py | 0 tests/{ => core}/shapecheck_test.py | 0 .../struct_pytree_dataclass_test.py | 0 .../{experimental => core}/variables_test.py | 2 +- tests/experimental/models/__init__.py | 13 -- tests/experimental/nn/__init__.py | 13 -- tests/experimental/toolshed/__init__.py | 13 -- .../v2/core => tests/models}/__init__.py | 0 .../models/simple_mlp_test.py | 6 +- .../models/transformer_consistency_test.py | 8 +- .../models/transformer_llamalike_test.py | 8 +- tests/{experimental => nn}/__init__.py | 0 tests/{experimental => }/nn/basic_ops_test.py | 2 +- tests/{experimental => }/nn/embedding_test.py | 2 +- tests/{experimental => }/nn/grouping_test.py | 2 +- .../{experimental => }/nn/layer_stack_test.py | 2 +- tests/{experimental => }/nn/layer_test.py | 2 +- .../nn/linear_and_affine_test.py | 2 +- .../{experimental => }/nn/parameters_test.py | 2 +- .../nn/standardization_test.py | 2 +- .../toolshed/gradient_checkpointing_test.py | 6 +- .../toolshed/isolate_submodel_test.py | 6 +- .../toolshed/jit_wrapper_test.py | 6 +- .../{experimental => }/toolshed/lora_test.py | 8 +- .../toolshed/model_rewiring_test.py | 4 +- .../toolshed/save_intermediates_test.py | 6 +- .../toolshed/unflaxify_test.py | 4 +- 83 files changed, 363 insertions(+), 423 deletions(-) rename penzai/{experimental/v2 => }/core/auto_order_types.py (100%) rename penzai/{experimental/v2 => }/core/random_stream.py (97%) rename penzai/{experimental/v2 => }/core/variables.py (99%) create mode 100644 penzai/experimental/v2/__init__.py delete mode 100644 penzai/experimental/v2/pz/__init__.py delete mode 100644 penzai/experimental/v2/pz/nn.py delete mode 100644 penzai/experimental/v2/pz/ts.py rename penzai/{experimental/v2 => }/models/simple_mlp.py (98%) rename penzai/{experimental/v2 => }/models/transformer/__init__.py (100%) rename penzai/{experimental/v2 => }/models/transformer/model_parts.py (99%) rename penzai/{experimental/v2 => }/models/transformer/sampling_mode.py (98%) rename penzai/{experimental/v2 => }/models/transformer/simple_decoding_loop.py (96%) rename penzai/{experimental/v2 => }/models/transformer/variants/__init__.py (100%) rename penzai/{experimental/v2 => }/models/transformer/variants/gemma.py (97%) rename penzai/{experimental/v2 => }/models/transformer/variants/gpt_neox.py (99%) rename penzai/{experimental/v2 => }/models/transformer/variants/llama.py (94%) rename penzai/{experimental/v2 => }/models/transformer/variants/llamalike_common.py (99%) rename penzai/{experimental/v2 => }/models/transformer/variants/mistral.py (95%) rename penzai/{experimental/v2 => }/nn/_treescope_handlers/layer_handler.py (96%) rename penzai/{experimental/v2 => }/nn/attention.py (99%) rename penzai/{experimental/v2 => }/nn/basic_ops.py (98%) rename penzai/{experimental/v2 => }/nn/combinators.py (98%) rename penzai/{experimental/v2 => }/nn/dropout.py (98%) rename penzai/{experimental/v2 => }/nn/embeddings.py (98%) rename penzai/{experimental/v2 => }/nn/grouping.py (99%) rename penzai/{experimental/v2 => }/nn/layer.py (97%) rename penzai/{experimental/v2 => }/nn/layer_stack.py (99%) rename penzai/{experimental/v2 => }/nn/linear_and_affine.py (99%) rename penzai/{experimental/v2 => }/nn/parameters.py (98%) rename penzai/{experimental/v2 => }/nn/standardization.py (97%) rename penzai/{experimental/v2 => }/toolshed/basic_training.py (99%) rename penzai/{experimental/v2 => }/toolshed/gradient_checkpointing.py (98%) rename penzai/{experimental/v2 => }/toolshed/isolate_submodel.py (99%) rename penzai/{experimental/v2 => }/toolshed/jit_wrapper.py (98%) rename penzai/{experimental/v2 => }/toolshed/lora.py (99%) rename penzai/{experimental/v2 => }/toolshed/model_rewiring.py (99%) rename penzai/{experimental/v2 => }/toolshed/save_intermediates.py (99%) rename penzai/{experimental/v2 => }/toolshed/sharding_util.py (99%) rename penzai/{experimental/v2 => }/toolshed/unflaxify.py (99%) rename tests/{experimental => core}/auto_order_types_test.py (96%) rename tests/{ => core}/misc_util_test.py (100%) rename tests/{ => core}/named_axes_test.py (100%) rename tests/{ => core}/partitioning_test.py (100%) rename tests/{ => core}/selectors_test.py (100%) rename tests/{ => core}/shapecheck_test.py (100%) rename tests/{ => core}/struct_pytree_dataclass_test.py (100%) rename tests/{experimental => core}/variables_test.py (99%) delete mode 100644 tests/experimental/models/__init__.py delete mode 100644 tests/experimental/nn/__init__.py delete mode 100644 tests/experimental/toolshed/__init__.py rename {penzai/experimental/v2/core => tests/models}/__init__.py (100%) rename tests/{experimental => }/models/simple_mlp_test.py (95%) rename tests/{experimental => }/models/transformer_consistency_test.py (95%) rename tests/{experimental => }/models/transformer_llamalike_test.py (96%) rename tests/{experimental => nn}/__init__.py (100%) rename tests/{experimental => }/nn/basic_ops_test.py (96%) rename tests/{experimental => }/nn/embedding_test.py (97%) rename tests/{experimental => }/nn/grouping_test.py (98%) rename tests/{experimental => }/nn/layer_stack_test.py (99%) rename tests/{experimental => }/nn/layer_test.py (98%) rename tests/{experimental => }/nn/linear_and_affine_test.py (99%) rename tests/{experimental => }/nn/parameters_test.py (98%) rename tests/{experimental => }/nn/standardization_test.py (97%) rename tests/{experimental => }/toolshed/gradient_checkpointing_test.py (92%) rename tests/{experimental => }/toolshed/isolate_submodel_test.py (96%) rename tests/{experimental => }/toolshed/jit_wrapper_test.py (92%) rename tests/{experimental => }/toolshed/lora_test.py (94%) rename tests/{experimental => }/toolshed/model_rewiring_test.py (98%) rename tests/{experimental => }/toolshed/save_intermediates_test.py (95%) rename tests/{experimental => }/toolshed/unflaxify_test.py (98%) diff --git a/docs/_include/_glue_figures.ipynb b/docs/_include/_glue_figures.ipynb index 02f107f..292b4b6 100644 --- a/docs/_include/_glue_figures.ipynb +++ b/docs/_include/_glue_figures.ipynb @@ -22,7 +22,7 @@ "import myst_nb\n", "\n", "import penzai\n", - "from penzai.experimental.v2 import pz\n", + "from penzai import pz\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -53,8 +53,8 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.models.transformer.variants import gemma\n", - "from penzai.experimental.v2.models.transformer.variants import llamalike_common" + "from penzai.models.transformer.variants import gemma\n", + "from penzai.models.transformer.variants import llamalike_common" ] }, { diff --git a/docs/api/penzai.experimental.v2.rst b/docs/api/penzai.experimental.v2.rst index 5edaa9e..691a643 100644 --- a/docs/api/penzai.experimental.v2.rst +++ b/docs/api/penzai.experimental.v2.rst @@ -19,7 +19,7 @@ You can read about the V2 design in the guide :doc:`"How to Think in Penzai (v2 To use the V2 API, we suggest importing the `pz` alias namespace from `penzai.experimental.v2.pz`: :: - from penzai.experimental.v2 import pz + from penzai import pz The rest of this page lists the main components used in the V2 API. diff --git a/docs/guides/howto_reference.md b/docs/guides/howto_reference.md index ae3da66..bdf44a8 100644 --- a/docs/guides/howto_reference.md +++ b/docs/guides/howto_reference.md @@ -5,7 +5,7 @@ This notebook is a guide to accomplishing a variety of tasks with Penzai, using For this guide, we assume you have imported the experimental V2 version of the `pz` alias namespace: ``` -from penzai.experimental.v2 import pz +from penzai import pz ``` ## Visualization @@ -223,7 +223,7 @@ Penzai's Gemma implementation includes a conversion utility that converts the [" ```python import kagglehub import orbax.checkpoint -from penzai.experimental.v2.models.transformer import variants +from penzai.models.transformer import variants weights_dir = kagglehub.model_download('google/gemma/Flax/7b') ckpt_path = os.path.join(weights_dir, '7b') @@ -239,7 +239,7 @@ Penzai also includes re-implementations of the architectures used by [Llama](htt ```python import transformers -from penzai.experimental.v2.models.transformer import variants +from penzai.models.transformer import variants # To load a Llama model: hf_model = transformers.LlamaForCausalLM.from_pretrained(...) diff --git a/docs/guides/v2_differences.md b/docs/guides/v2_differences.md index 9fe14ad..878c032 100644 --- a/docs/guides/v2_differences.md +++ b/docs/guides/v2_differences.md @@ -202,8 +202,8 @@ from penzai.deprecated.v1.example_models import simple_mlp import penzai.toolshed # New -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp +from penzai import pz +from penzai.models import simple_mlp import penzai.experimental.v2.toolshed ``` @@ -381,7 +381,7 @@ model = gemma.model_core.GemmaTransformer.from_pretrained(flax_params_dict) # (model is an instance of GemmaTransformer) # New -from penzai.experimental.v2.models.transformer import variants +from penzai.models.transformer import variants model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict) # (model is an instance of TransformerLM) ``` diff --git a/notebooks/v2_how_to_think_in_penzai.ipynb b/notebooks/v2_how_to_think_in_penzai.ipynb index 2ff663c..ec0e70d 100644 --- a/notebooks/v2_how_to_think_in_penzai.ipynb +++ b/notebooks/v2_how_to_think_in_penzai.ipynb @@ -91,8 +91,8 @@ "outputs": [], "source": [ "import penzai\n", - "from penzai.experimental.v2 import pz\n", - "from penzai.experimental.v2.models import simple_mlp" + "from penzai import pz\n", + "from penzai.models import simple_mlp" ] }, { @@ -1277,7 +1277,7 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.toolshed import basic_training\n", + "from penzai.toolshed import basic_training\n", "import optax" ] }, diff --git a/notebooks/v2_induction_heads.ipynb b/notebooks/v2_induction_heads.ipynb index 2031eee..7a00733 100644 --- a/notebooks/v2_induction_heads.ipynb +++ b/notebooks/v2_induction_heads.ipynb @@ -153,9 +153,9 @@ "outputs": [], "source": [ "import penzai\n", - "from penzai.experimental.v2 import pz\n", + "from penzai import pz\n", "\n", - "from penzai.experimental.v2.models import transformer" + "from penzai.models import transformer" ] }, { diff --git a/notebooks/v2_induction_heads_2B.ipynb b/notebooks/v2_induction_heads_2B.ipynb index fe82074..f536448 100644 --- a/notebooks/v2_induction_heads_2B.ipynb +++ b/notebooks/v2_induction_heads_2B.ipynb @@ -165,9 +165,9 @@ "outputs": [], "source": [ "import penzai\n", - "from penzai.experimental.v2 import pz\n", + "from penzai import pz\n", "\n", - "from penzai.experimental.v2.models import transformer" + "from penzai.models import transformer" ] }, { diff --git a/notebooks/v2_jitting_and_sharding.ipynb b/notebooks/v2_jitting_and_sharding.ipynb index ec5aa58..b0b17e8 100644 --- a/notebooks/v2_jitting_and_sharding.ipynb +++ b/notebooks/v2_jitting_and_sharding.ipynb @@ -135,7 +135,7 @@ "outputs": [], "source": [ "import penzai\n", - "from penzai.experimental.v2 import pz" + "from penzai import pz" ] }, { @@ -146,9 +146,9 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.models import transformer\n", - "from penzai.experimental.v2.models import simple_mlp\n", - "from penzai.experimental.v2.toolshed import basic_training" + "from penzai.models import transformer\n", + "from penzai.models import simple_mlp\n", + "from penzai.toolshed import basic_training" ] }, { @@ -610,7 +610,7 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.toolshed import jit_wrapper" + "from penzai.toolshed import jit_wrapper" ] }, { @@ -1181,7 +1181,7 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.toolshed import sharding_util" + "from penzai.toolshed import sharding_util" ] }, { @@ -1336,8 +1336,8 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.toolshed import sharding_util\n", - "from penzai.experimental.v2.models.transformer.variants import llamalike_common" + "from penzai.toolshed import sharding_util\n", + "from penzai.models.transformer.variants import llamalike_common" ] }, { diff --git a/notebooks/v2_lora_from_scratch.ipynb b/notebooks/v2_lora_from_scratch.ipynb index 0063258..bb284cb 100644 --- a/notebooks/v2_lora_from_scratch.ipynb +++ b/notebooks/v2_lora_from_scratch.ipynb @@ -142,7 +142,7 @@ "outputs": [], "source": [ "import penzai\n", - "from penzai.experimental.v2 import pz" + "from penzai import pz" ] }, { @@ -164,11 +164,11 @@ }, "outputs": [], "source": [ - "from penzai.experimental.v2.models import transformer\n", - "from penzai.experimental.v2.models import simple_mlp\n", + "from penzai.models import transformer\n", + "from penzai.models import simple_mlp\n", "from penzai.toolshed import token_visualization\n", - "from penzai.experimental.v2.toolshed import basic_training\n", - "from penzai.experimental.v2.toolshed import jit_wrapper" + "from penzai.toolshed import basic_training\n", + "from penzai.toolshed import jit_wrapper" ] }, { diff --git a/penzai/experimental/v2/core/auto_order_types.py b/penzai/core/auto_order_types.py similarity index 100% rename from penzai/experimental/v2/core/auto_order_types.py rename to penzai/core/auto_order_types.py diff --git a/penzai/experimental/v2/core/random_stream.py b/penzai/core/random_stream.py similarity index 97% rename from penzai/experimental/v2/core/random_stream.py rename to penzai/core/random_stream.py index 6a88801..5497727 100644 --- a/penzai/experimental/v2/core/random_stream.py +++ b/penzai/core/random_stream.py @@ -18,7 +18,7 @@ import jax from penzai.core import struct -from penzai.experimental.v2.core import variables +from penzai.core import variables @struct.pytree_dataclass diff --git a/penzai/experimental/v2/core/variables.py b/penzai/core/variables.py similarity index 99% rename from penzai/experimental/v2/core/variables.py rename to penzai/core/variables.py index fb801b6..07899ed 100644 --- a/penzai/experimental/v2/core/variables.py +++ b/penzai/core/variables.py @@ -77,9 +77,9 @@ import jax import jax.numpy as jnp +from penzai.core import auto_order_types from penzai.core import selectors from penzai.core import struct -from penzai.experimental.v2.core import auto_order_types T = TypeVar("T") diff --git a/penzai/deprecated/v1/pz/__init__.py b/penzai/deprecated/v1/pz/__init__.py index 22ee589..76a0407 100644 --- a/penzai/deprecated/v1/pz/__init__.py +++ b/penzai/deprecated/v1/pz/__init__.py @@ -66,4 +66,4 @@ from . import de from . import nn -from penzai.experimental.v2.pz import ts # pylint: disable=g-bad-import-order +from penzai.pz import ts # pylint: disable=g-bad-import-order diff --git a/penzai/deprecated/v1/toolshed/sharding_util.py b/penzai/deprecated/v1/toolshed/sharding_util.py index 5ee420b..fd43751 100644 --- a/penzai/deprecated/v1/toolshed/sharding_util.py +++ b/penzai/deprecated/v1/toolshed/sharding_util.py @@ -21,7 +21,7 @@ import jax from penzai.deprecated.v1 import pz -from penzai.experimental.v2.toolshed import sharding_util as sharding_util_v2 +from penzai.toolshed import sharding_util as sharding_util_v2 PyTreeOfArrays = Any PyTreeOfNamedArrays = Any diff --git a/penzai/experimental/v2/__init__.py b/penzai/experimental/v2/__init__.py new file mode 100644 index 0000000..7fbbc74 --- /dev/null +++ b/penzai/experimental/v2/__init__.py @@ -0,0 +1,219 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redirector module for Penzai V2. + +Penzai's V2 neural network API was originally defined in +`penzai.experimental.v2`. However, with the release of Penzai V2, the API was +moved to the top-level `penzai` namespace. This module is a redirector to the +new location. +""" + +from __future__ import annotations + +import warnings as _warnings + +_warnings.warn( + "Accessing Penzai's V2 API from `penzai.experimental.v2` is" + " deprecated. As of the 0.2.0 release, the V2 API is now available in" + " the top-level `penzai` namespace. Please use `penzai` directly" + " instead." +) + + +def _make_redirector(stub_module_name: str, target_module_name: str | None): + """Builds a redirector module for the given target module.""" + + # pylint: disable=g-import-not-at-top + import importlib + import importlib.machinery + import importlib.util + import sys + # pylint: enable=g-import-not-at-top + + spec = importlib.machinery.ModuleSpec(name=stub_module_name, loader=None) + mod = importlib.util.module_from_spec(spec) + mod.__doc__ = f"""Redirector module for {stub_module_name}.""" + + if target_module_name is not None: + + def redirecting_getattr(name: str): + if name.startswith("__"): + raise AttributeError( + f"module {repr(stub_module_name)} has no attribute '{name}'" + ) + + reference_module = importlib.import_module(target_module_name) + return getattr(reference_module, name) + + def redirecting_dir(): + reference_module = importlib.import_module(target_module_name) + return dir(reference_module) + + mod.__getattr__ = redirecting_getattr + mod.__dir__ = redirecting_dir + + stub_parent_name, mod_name = stub_module_name.rsplit(".", 1) + setattr(sys.modules[stub_parent_name], mod_name, mod) + + sys.modules[stub_module_name] = mod + + +_make_redirector("penzai.experimental.v2.core", None) +_make_redirector( + "penzai.experimental.v2.core.auto_order_types", + "penzai.core.auto_order_types", +) +_make_redirector( + "penzai.experimental.v2.core.random_stream", + "penzai.core.random_stream", +) +_make_redirector( + "penzai.experimental.v2.core.variables", + "penzai.core.variables", +) + +_make_redirector("penzai.experimental.v2.models", None) +_make_redirector( + "penzai.experimental.v2.models.simple_mlp", + "penzai.models.simple_mlp", +) + +_make_redirector("penzai.experimental.v2.models.transformer", None) +_make_redirector( + "penzai.experimental.v2.models.transformer.model_parts", + "penzai.models.transformer.model_parts", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.sampling_mode", + "penzai.models.transformer.sampling_mode", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.simple_decoding_loop", + "penzai.models.transformer.simple_decoding_loop", +) + +_make_redirector("penzai.experimental.v2.models.transformer.variants", None) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.", + "penzai.models.transformer.variants.", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.gemma", + "penzai.models.transformer.variants.gemma", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.gpt_neox", + "penzai.models.transformer.variants.gpt_neox", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.llamalike_common", + "penzai.models.transformer.variants.llamalike_common", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.llama", + "penzai.models.transformer.variants.llama", +) +_make_redirector( + "penzai.experimental.v2.models.transformer.variants.mistral", + "penzai.models.transformer.variants.mistral", +) + +_make_redirector("penzai.experimental.v2.nn", None) +_make_redirector( + "penzai.experimental.v2.nn.attention", + "penzai.nn.attention", +) +_make_redirector( + "penzai.experimental.v2.nn.basic_ops", + "penzai.nn.basic_ops", +) +_make_redirector( + "penzai.experimental.v2.nn.combinators", + "penzai.nn.combinators", +) +_make_redirector( + "penzai.experimental.v2.nn.dropout", + "penzai.nn.dropout", +) +_make_redirector( + "penzai.experimental.v2.nn.embeddings", + "penzai.nn.embeddings", +) +_make_redirector( + "penzai.experimental.v2.nn.grouping", + "penzai.nn.grouping", +) +_make_redirector( + "penzai.experimental.v2.nn.layer", + "penzai.nn.layer", +) +_make_redirector( + "penzai.experimental.v2.nn.layer_stack", + "penzai.nn.layer_stack", +) +_make_redirector( + "penzai.experimental.v2.nn.linear_and_affine", + "penzai.nn.linear_and_affine", +) +_make_redirector( + "penzai.experimental.v2.nn.parameters", + "penzai.nn.parameters", +) +_make_redirector( + "penzai.experimental.v2.nn.standardization", + "penzai.nn.standardization", +) + +_make_redirector("penzai.experimental.v2.pz", "penzai.pz") +_make_redirector("penzai.experimental.v2.pz.nn", "penzai.pz.nn") +_make_redirector("penzai.experimental.v2.pz.ts", "penzai.pz.ts") + +_make_redirector("penzai.experimental.v2.toolshed", None) +_make_redirector( + "penzai.experimental.v2.toolshed.basic_training", + "penzai.toolshed.basic_training", +) +_make_redirector( + "penzai.experimental.v2.toolshed.gradient_checkpointing", + "penzai.toolshed.gradient_checkpointing", +) +_make_redirector( + "penzai.experimental.v2.toolshed.isolate_submodel", + "penzai.toolshed.isolate_submodel", +) +_make_redirector( + "penzai.experimental.v2.toolshed.jit_wrapper", + "penzai.toolshed.jit_wrapper", +) +_make_redirector( + "penzai.experimental.v2.toolshed.lora", + "penzai.toolshed.lora", +) +_make_redirector( + "penzai.experimental.v2.toolshed.model_rewiring", + "penzai.toolshed.model_rewiring", +) +_make_redirector( + "penzai.experimental.v2.toolshed.save_intermediates", + "penzai.toolshed.save_intermediates", +) +_make_redirector( + "penzai.experimental.v2.toolshed.sharding_util", + "penzai.toolshed.sharding_util", +) +_make_redirector( + "penzai.experimental.v2.toolshed.unflaxify", + "penzai.toolshed.unflaxify", +) diff --git a/penzai/experimental/v2/pz/__init__.py b/penzai/experimental/v2/pz/__init__.py deleted file mode 100644 index 4ec7651..0000000 --- a/penzai/experimental/v2/pz/__init__.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module of aliases for common penzai classes and functions.""" - -# pylint: disable=g-multiple-import,g-importing-member,unused-import - -import penzai.core.named_axes as nx -from penzai.core.partitioning import ( - NotInThisPartition, - combine, -) -from penzai.core.selectors import ( - Selection, - select, -) -import penzai.core.shapecheck as chk -from penzai.core.struct import ( - Struct, - StructStaticMetadata, - PyTreeDataclassSafetyError, - is_pytree_dataclass_type, - is_pytree_node_field, - pytree_dataclass, -) -# pylint: disable=redefined-builtin -from penzai.core.syntactic_sugar import ( - slice, -) -# pylint: enable=redefined-builtin -from penzai.core.tree_util import ( - pretty_keystr, -) - -from penzai.experimental.v2.core.auto_order_types import ( - AutoOrderedAcrossTypes, -) -from penzai.experimental.v2.core.random_stream import ( - RandomStream, -) -from penzai.experimental.v2.core.variables import ( - VariableConflictError, - UnboundVariableError, - VariableLabel, - AbstractVariable, - AbstractVariableValue, - AbstractVariableSlot, - unbind_variables, - bind_variables, - freeze_variables, - variable_jit, - Parameter, - ParameterValue, - ParameterSlot, - AutoStateVarLabel, - ScopedStateVarLabel, - scoped_auto_state_var_labels, - StateVariable, - StateVariableValue, - StateVariableSlot, - unbind_params, - freeze_params, - unbind_state_vars, - freeze_state_vars, -) -from penzai.treescope._compatibility_setup import ( - show, - disable_interactive_context, - enable_interactive_context, -) -from treescope.context import ( - ContextualValue, -) -from treescope.dataclass_util import ( - dataclass_from_attributes, - init_takes_fields, -) - -from . import nn -from . import ts # pylint: disable=g-bad-import-order diff --git a/penzai/experimental/v2/pz/nn.py b/penzai/experimental/v2/pz/nn.py deleted file mode 100644 index 08f15dc..0000000 --- a/penzai/experimental/v2/pz/nn.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module of aliases for penzai neural networks.""" - -# pylint: disable=g-multiple-import,g-importing-member,unused-import - -from penzai.experimental.v2.nn.attention import ( - ApplyExplicitAttentionMask, - ApplyCausalAttentionMask, - ApplyCausalSlidingWindowAttentionMask, - Attention, - KVCachingAttention, -) -from penzai.experimental.v2.nn.basic_ops import ( - CastToDType, - Elementwise, - Softmax, -) -from penzai.experimental.v2.nn.combinators import ( - Residual, - BranchAndAddTogether, - BranchAndMultiplyTogether, -) -from penzai.experimental.v2.nn.dropout import ( - DisabledDropout, - maybe_dropout, - StochasticDropout, -) -from penzai.experimental.v2.nn.embeddings import ( - EmbeddingTable, - EmbeddingLookup, - EmbeddingDecode, - ApplyRoPE, - ApplyRoPEToSubset, -) -from penzai.experimental.v2.nn.grouping import ( - CheckedSequential, - CheckStructure, - Identity, - inline_anonymous_sequentials, - inline_groups, - is_anonymous_sequential, - is_sequential_or_named, - NamedGroup, - Sequential, -) -from penzai.experimental.v2.nn.layer import ( - Layer, -) -from penzai.experimental.v2.nn.layer_stack import ( - LayerStackVarBehavior, - LayerStackGetAttrKey, - LayerStack, - layerstack_axes_from_keypath, -) -from penzai.experimental.v2.nn.linear_and_affine import ( - AddBias, - ConstantRescale, - NamedEinsum, - Affine, - Linear, - LinearOperatorWeightInitializer, - LinearInPlace, - RenameAxes, - contract, - variance_scaling_initializer, - xavier_normal_initializer, - xavier_uniform_initializer, - constant_initializer, - zero_initializer, -) -from penzai.experimental.v2.nn.parameters import ( - ParameterLike, - derive_param_key, - make_parameter, - assert_no_parameter_slots, -) -from penzai.experimental.v2.nn.standardization import ( - LayerNorm, - Standardize, - RMSLayerNorm, - RMSStandardize, -) diff --git a/penzai/experimental/v2/pz/ts.py b/penzai/experimental/v2/pz/ts.py deleted file mode 100644 index 1bc989b..0000000 --- a/penzai/experimental/v2/pz/ts.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Common aliases for treescope.""" - -# pylint: disable=g-multiple-import,g-importing-member,unused-import - -from penzai.treescope._compatibility_setup import ( - display, - register_as_default, - basic_interactive_setup, -) -from treescope import ( - ArrayAutovisualizer, - active_autovisualizer, - active_expansion_strategy, - active_renderer, - Autovisualizer, - ChildAutovisualizer, - default_diverging_colormap, - default_magic_autovisualizer, - default_sequential_colormap, - integer_digitbox, - IPythonVisualization, - register_autovisualize_magic, - register_context_manager_magic, - render_array_sharding, - render_array, - render_to_html, - render_to_text, - using_expansion_strategy, -) -from treescope.figures import ( - inline, - indented, - with_font_size, - with_color, - bolded, - styled, - text_on_color, -) - -vocab_autovisualizer = ArrayAutovisualizer.for_tokenizer diff --git a/penzai/experimental/v2/models/simple_mlp.py b/penzai/models/simple_mlp.py similarity index 98% rename from penzai/experimental/v2/models/simple_mlp.py rename to penzai/models/simple_mlp.py index 1763217..5c445ed 100644 --- a/penzai/experimental/v2/models/simple_mlp.py +++ b/penzai/models/simple_mlp.py @@ -19,7 +19,7 @@ from typing import Callable import jax -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass(has_implicitly_inherited_fields=True) diff --git a/penzai/experimental/v2/models/transformer/__init__.py b/penzai/models/transformer/__init__.py similarity index 100% rename from penzai/experimental/v2/models/transformer/__init__.py rename to penzai/models/transformer/__init__.py diff --git a/penzai/experimental/v2/models/transformer/model_parts.py b/penzai/models/transformer/model_parts.py similarity index 99% rename from penzai/experimental/v2/models/transformer/model_parts.py rename to penzai/models/transformer/model_parts.py index ad57078..12e807f 100644 --- a/penzai/experimental/v2/models/transformer/model_parts.py +++ b/penzai/models/transformer/model_parts.py @@ -59,7 +59,7 @@ import dataclasses import jax -from penzai.experimental.v2 import pz +from penzai import pz @dataclasses.dataclass diff --git a/penzai/experimental/v2/models/transformer/sampling_mode.py b/penzai/models/transformer/sampling_mode.py similarity index 98% rename from penzai/experimental/v2/models/transformer/sampling_mode.py rename to penzai/models/transformer/sampling_mode.py index 08542da..cf175a7 100644 --- a/penzai/experimental/v2/models/transformer/sampling_mode.py +++ b/penzai/models/transformer/sampling_mode.py @@ -37,8 +37,8 @@ import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import model_parts +from penzai import pz +from penzai.models.transformer import model_parts @pz.pytree_dataclass diff --git a/penzai/experimental/v2/models/transformer/simple_decoding_loop.py b/penzai/models/transformer/simple_decoding_loop.py similarity index 96% rename from penzai/experimental/v2/models/transformer/simple_decoding_loop.py rename to penzai/models/transformer/simple_decoding_loop.py index 042f132..af794f1 100644 --- a/penzai/experimental/v2/models/transformer/simple_decoding_loop.py +++ b/penzai/models/transformer/simple_decoding_loop.py @@ -22,8 +22,8 @@ import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import sampling_mode +from penzai import pz +from penzai.models.transformer import sampling_mode def temperature_sample_pyloop( diff --git a/penzai/experimental/v2/models/transformer/variants/__init__.py b/penzai/models/transformer/variants/__init__.py similarity index 100% rename from penzai/experimental/v2/models/transformer/variants/__init__.py rename to penzai/models/transformer/variants/__init__.py diff --git a/penzai/experimental/v2/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py similarity index 97% rename from penzai/experimental/v2/models/transformer/variants/gemma.py rename to penzai/models/transformer/variants/gemma.py index 63c0b05..eeb7796 100644 --- a/penzai/experimental/v2/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -26,9 +26,9 @@ from typing import Any import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import model_parts -from penzai.experimental.v2.models.transformer.variants import llamalike_common +from penzai import pz +from penzai.models.transformer import model_parts +from penzai.models.transformer.variants import llamalike_common def gemma_from_pretrained_checkpoint( diff --git a/penzai/experimental/v2/models/transformer/variants/gpt_neox.py b/penzai/models/transformer/variants/gpt_neox.py similarity index 99% rename from penzai/experimental/v2/models/transformer/variants/gpt_neox.py rename to penzai/models/transformer/variants/gpt_neox.py index 1524b94..6077584 100644 --- a/penzai/experimental/v2/models/transformer/variants/gpt_neox.py +++ b/penzai/models/transformer/variants/gpt_neox.py @@ -34,8 +34,8 @@ import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import model_parts +from penzai import pz +from penzai.models.transformer import model_parts @dataclasses.dataclass diff --git a/penzai/experimental/v2/models/transformer/variants/llama.py b/penzai/models/transformer/variants/llama.py similarity index 94% rename from penzai/experimental/v2/models/transformer/variants/llama.py rename to penzai/models/transformer/variants/llama.py index 185fab7..d27257c 100644 --- a/penzai/experimental/v2/models/transformer/variants/llama.py +++ b/penzai/models/transformer/variants/llama.py @@ -18,8 +18,8 @@ from typing import Any -from penzai.experimental.v2.models.transformer import model_parts -from penzai.experimental.v2.models.transformer.variants import llamalike_common +from penzai.models.transformer import model_parts +from penzai.models.transformer.variants import llamalike_common LlamaForCausalLM = Any diff --git a/penzai/experimental/v2/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py similarity index 99% rename from penzai/experimental/v2/models/transformer/variants/llamalike_common.py rename to penzai/models/transformer/variants/llamalike_common.py index 96a0b7d..f6cc013 100644 --- a/penzai/experimental/v2/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -37,8 +37,8 @@ import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import model_parts +from penzai import pz +from penzai.models.transformer import model_parts @dataclasses.dataclass(frozen=True) diff --git a/penzai/experimental/v2/models/transformer/variants/mistral.py b/penzai/models/transformer/variants/mistral.py similarity index 95% rename from penzai/experimental/v2/models/transformer/variants/mistral.py rename to penzai/models/transformer/variants/mistral.py index 64ed25b..fc6daf7 100644 --- a/penzai/experimental/v2/models/transformer/variants/mistral.py +++ b/penzai/models/transformer/variants/mistral.py @@ -18,8 +18,8 @@ from typing import Any -from penzai.experimental.v2.models.transformer import model_parts -from penzai.experimental.v2.models.transformer.variants import llamalike_common +from penzai.models.transformer import model_parts +from penzai.models.transformer.variants import llamalike_common MistralForCausalLM = Any diff --git a/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py b/penzai/nn/_treescope_handlers/layer_handler.py similarity index 96% rename from penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py rename to penzai/nn/_treescope_handlers/layer_handler.py index ca9d973..74dfe52 100644 --- a/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py +++ b/penzai/nn/_treescope_handlers/layer_handler.py @@ -19,8 +19,8 @@ import dataclasses from penzai.core._treescope_handlers import struct_handler -from penzai.experimental.v2.nn import grouping -from penzai.experimental.v2.nn import layer +from penzai.nn import grouping +from penzai.nn import layer from treescope import formatting_util from treescope import renderers from treescope import rendering_parts diff --git a/penzai/experimental/v2/nn/attention.py b/penzai/nn/attention.py similarity index 99% rename from penzai/experimental/v2/nn/attention.py rename to penzai/nn/attention.py index 94199a5..7369f97 100644 --- a/penzai/experimental/v2/nn/attention.py +++ b/penzai/nn/attention.py @@ -30,9 +30,9 @@ import jax.numpy as jnp from penzai.core import named_axes from penzai.core import struct -from penzai.experimental.v2.core import variables -from penzai.experimental.v2.nn import layer as layer_base -from penzai.experimental.v2.nn import layer_stack +from penzai.core import variables +from penzai.nn import layer as layer_base +from penzai.nn import layer_stack @struct.pytree_dataclass diff --git a/penzai/experimental/v2/nn/basic_ops.py b/penzai/nn/basic_ops.py similarity index 98% rename from penzai/experimental/v2/nn/basic_ops.py rename to penzai/nn/basic_ops.py index 7de3f95..51211b7 100644 --- a/penzai/experimental/v2/nn/basic_ops.py +++ b/penzai/nn/basic_ops.py @@ -22,7 +22,7 @@ import jax from penzai.core import named_axes from penzai.core import struct -from penzai.experimental.v2.nn import layer +from penzai.nn import layer @struct.pytree_dataclass diff --git a/penzai/experimental/v2/nn/combinators.py b/penzai/nn/combinators.py similarity index 98% rename from penzai/experimental/v2/nn/combinators.py rename to penzai/nn/combinators.py index 89e69d0..7399f01 100644 --- a/penzai/experimental/v2/nn/combinators.py +++ b/penzai/nn/combinators.py @@ -24,7 +24,7 @@ from typing import Any from penzai.core import struct -from penzai.experimental.v2.nn import layer as layer_base +from penzai.nn import layer as layer_base @struct.pytree_dataclass diff --git a/penzai/experimental/v2/nn/dropout.py b/penzai/nn/dropout.py similarity index 98% rename from penzai/experimental/v2/nn/dropout.py rename to penzai/nn/dropout.py index 93f2dd5..b5ad815 100644 --- a/penzai/experimental/v2/nn/dropout.py +++ b/penzai/nn/dropout.py @@ -22,8 +22,8 @@ import jax.numpy as jnp from penzai.core import named_axes from penzai.core import struct -from penzai.experimental.v2.nn import grouping -from penzai.experimental.v2.nn import layer +from penzai.nn import grouping +from penzai.nn import layer @struct.pytree_dataclass diff --git a/penzai/experimental/v2/nn/embeddings.py b/penzai/nn/embeddings.py similarity index 98% rename from penzai/experimental/v2/nn/embeddings.py rename to penzai/nn/embeddings.py index 1cf07b3..e7fff97 100644 --- a/penzai/experimental/v2/nn/embeddings.py +++ b/penzai/nn/embeddings.py @@ -27,9 +27,9 @@ from penzai.core import named_axes from penzai.core import struct from penzai.core import syntactic_sugar -from penzai.experimental.v2.nn import layer as layer_base -from penzai.experimental.v2.nn import linear_and_affine -from penzai.experimental.v2.nn import parameters +from penzai.nn import layer as layer_base +from penzai.nn import linear_and_affine +from penzai.nn import parameters _slice = syntactic_sugar.slice diff --git a/penzai/experimental/v2/nn/grouping.py b/penzai/nn/grouping.py similarity index 99% rename from penzai/experimental/v2/nn/grouping.py rename to penzai/nn/grouping.py index f3ae618..6b38dce 100644 --- a/penzai/experimental/v2/nn/grouping.py +++ b/penzai/nn/grouping.py @@ -22,7 +22,7 @@ from penzai.core import selectors from penzai.core import shapecheck from penzai.core import struct -from penzai.experimental.v2.nn import layer as layer_base +from penzai.nn import layer as layer_base @struct.pytree_dataclass diff --git a/penzai/experimental/v2/nn/layer.py b/penzai/nn/layer.py similarity index 97% rename from penzai/experimental/v2/nn/layer.py rename to penzai/nn/layer.py index f0a9185..507f50e 100644 --- a/penzai/experimental/v2/nn/layer.py +++ b/penzai/nn/layer.py @@ -39,7 +39,7 @@ from typing import Any, Iterable from penzai.core import struct -from penzai.experimental.v2.core import variables as vars_lib +from penzai.core import variables as vars_lib class Layer(struct.Struct, abc.ABC): @@ -155,5 +155,6 @@ def stateless_call( return result, tuple(var.freeze() for var in mut_vars) def __treescope_repr__(self, path: str | None, subtree_renderer: Any): - from penzai.experimental.v2.nn._treescope_handlers import layer_handler # pylint: disable=g-import-not-at-top + from penzai.nn._treescope_handlers import layer_handler # pylint: disable=g-import-not-at-top + return layer_handler.handle_layer(self, path, subtree_renderer) diff --git a/penzai/experimental/v2/nn/layer_stack.py b/penzai/nn/layer_stack.py similarity index 99% rename from penzai/experimental/v2/nn/layer_stack.py rename to penzai/nn/layer_stack.py index 25d42ff..b349786 100644 --- a/penzai/experimental/v2/nn/layer_stack.py +++ b/penzai/nn/layer_stack.py @@ -27,8 +27,8 @@ from penzai.core import selectors from penzai.core import struct from penzai.core import tree_util as pz_tree_util -from penzai.experimental.v2.core import variables -from penzai.experimental.v2.nn import layer as layer_base +from penzai.core import variables +from penzai.nn import layer as layer_base class LayerStackVarBehavior(enum.Enum): diff --git a/penzai/experimental/v2/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py similarity index 99% rename from penzai/experimental/v2/nn/linear_and_affine.py rename to penzai/nn/linear_and_affine.py index 45748ce..4dfe7f2 100644 --- a/penzai/experimental/v2/nn/linear_and_affine.py +++ b/penzai/nn/linear_and_affine.py @@ -27,9 +27,9 @@ from penzai.core import named_axes from penzai.core import shapecheck from penzai.core import struct -from penzai.experimental.v2.nn import grouping -from penzai.experimental.v2.nn import layer as layer_base -from penzai.experimental.v2.nn import parameters +from penzai.nn import grouping +from penzai.nn import layer as layer_base +from penzai.nn import parameters NamedArray = named_axes.NamedArray diff --git a/penzai/experimental/v2/nn/parameters.py b/penzai/nn/parameters.py similarity index 98% rename from penzai/experimental/v2/nn/parameters.py rename to penzai/nn/parameters.py index 5d8bf43..5d4d099 100644 --- a/penzai/experimental/v2/nn/parameters.py +++ b/penzai/nn/parameters.py @@ -35,7 +35,7 @@ import jax from penzai.core import selectors -from penzai.experimental.v2.core import variables +from penzai.core import variables T = TypeVar("T") diff --git a/penzai/experimental/v2/nn/standardization.py b/penzai/nn/standardization.py similarity index 97% rename from penzai/experimental/v2/nn/standardization.py rename to penzai/nn/standardization.py index 5900683..4bfb4b9 100644 --- a/penzai/experimental/v2/nn/standardization.py +++ b/penzai/nn/standardization.py @@ -22,9 +22,9 @@ import jax.numpy as jnp from penzai.core import named_axes from penzai.core import struct -from penzai.experimental.v2.nn import grouping -from penzai.experimental.v2.nn import layer as layer_base -from penzai.experimental.v2.nn import linear_and_affine +from penzai.nn import grouping +from penzai.nn import layer as layer_base +from penzai.nn import linear_and_affine NamedArray = named_axes.NamedArray diff --git a/penzai/pz/__init__.py b/penzai/pz/__init__.py index 4ec7651..6db776b 100644 --- a/penzai/pz/__init__.py +++ b/penzai/pz/__init__.py @@ -16,11 +16,17 @@ # pylint: disable=g-multiple-import,g-importing-member,unused-import +from penzai.core.auto_order_types import ( + AutoOrderedAcrossTypes, +) import penzai.core.named_axes as nx from penzai.core.partitioning import ( NotInThisPartition, combine, ) +from penzai.core.random_stream import ( + RandomStream, +) from penzai.core.selectors import ( Selection, select, @@ -42,14 +48,7 @@ from penzai.core.tree_util import ( pretty_keystr, ) - -from penzai.experimental.v2.core.auto_order_types import ( - AutoOrderedAcrossTypes, -) -from penzai.experimental.v2.core.random_stream import ( - RandomStream, -) -from penzai.experimental.v2.core.variables import ( +from penzai.core.variables import ( VariableConflictError, UnboundVariableError, VariableLabel, diff --git a/penzai/pz/nn.py b/penzai/pz/nn.py index 08f15dc..0a96503 100644 --- a/penzai/pz/nn.py +++ b/penzai/pz/nn.py @@ -16,36 +16,36 @@ # pylint: disable=g-multiple-import,g-importing-member,unused-import -from penzai.experimental.v2.nn.attention import ( +from penzai.nn.attention import ( ApplyExplicitAttentionMask, ApplyCausalAttentionMask, ApplyCausalSlidingWindowAttentionMask, Attention, KVCachingAttention, ) -from penzai.experimental.v2.nn.basic_ops import ( +from penzai.nn.basic_ops import ( CastToDType, Elementwise, Softmax, ) -from penzai.experimental.v2.nn.combinators import ( +from penzai.nn.combinators import ( Residual, BranchAndAddTogether, BranchAndMultiplyTogether, ) -from penzai.experimental.v2.nn.dropout import ( +from penzai.nn.dropout import ( DisabledDropout, maybe_dropout, StochasticDropout, ) -from penzai.experimental.v2.nn.embeddings import ( +from penzai.nn.embeddings import ( EmbeddingTable, EmbeddingLookup, EmbeddingDecode, ApplyRoPE, ApplyRoPEToSubset, ) -from penzai.experimental.v2.nn.grouping import ( +from penzai.nn.grouping import ( CheckedSequential, CheckStructure, Identity, @@ -56,16 +56,16 @@ NamedGroup, Sequential, ) -from penzai.experimental.v2.nn.layer import ( +from penzai.nn.layer import ( Layer, ) -from penzai.experimental.v2.nn.layer_stack import ( +from penzai.nn.layer_stack import ( LayerStackVarBehavior, LayerStackGetAttrKey, LayerStack, layerstack_axes_from_keypath, ) -from penzai.experimental.v2.nn.linear_and_affine import ( +from penzai.nn.linear_and_affine import ( AddBias, ConstantRescale, NamedEinsum, @@ -81,13 +81,13 @@ constant_initializer, zero_initializer, ) -from penzai.experimental.v2.nn.parameters import ( +from penzai.nn.parameters import ( ParameterLike, derive_param_key, make_parameter, assert_no_parameter_slots, ) -from penzai.experimental.v2.nn.standardization import ( +from penzai.nn.standardization import ( LayerNorm, Standardize, RMSLayerNorm, diff --git a/penzai/experimental/v2/toolshed/basic_training.py b/penzai/toolshed/basic_training.py similarity index 99% rename from penzai/experimental/v2/toolshed/basic_training.py rename to penzai/toolshed/basic_training.py index fa6e918..b28b901 100644 --- a/penzai/experimental/v2/toolshed/basic_training.py +++ b/penzai/toolshed/basic_training.py @@ -27,7 +27,7 @@ import jax import optax -from penzai.experimental.v2 import pz +from penzai import pz ModelPyTree = Any AuxOutPyTree = Any diff --git a/penzai/experimental/v2/toolshed/gradient_checkpointing.py b/penzai/toolshed/gradient_checkpointing.py similarity index 98% rename from penzai/experimental/v2/toolshed/gradient_checkpointing.py rename to penzai/toolshed/gradient_checkpointing.py index a1e5cc3..d12faa6 100644 --- a/penzai/experimental/v2/toolshed/gradient_checkpointing.py +++ b/penzai/toolshed/gradient_checkpointing.py @@ -31,7 +31,7 @@ from typing import Any, Callable import jax -from penzai.experimental.v2 import pz +from penzai import pz def _flat_stateless_call( diff --git a/penzai/experimental/v2/toolshed/isolate_submodel.py b/penzai/toolshed/isolate_submodel.py similarity index 99% rename from penzai/experimental/v2/toolshed/isolate_submodel.py rename to penzai/toolshed/isolate_submodel.py index f5ad3d9..6b3920a 100644 --- a/penzai/experimental/v2/toolshed/isolate_submodel.py +++ b/penzai/toolshed/isolate_submodel.py @@ -28,7 +28,7 @@ from typing import Any -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass diff --git a/penzai/experimental/v2/toolshed/jit_wrapper.py b/penzai/toolshed/jit_wrapper.py similarity index 98% rename from penzai/experimental/v2/toolshed/jit_wrapper.py rename to penzai/toolshed/jit_wrapper.py index 4495ea3..f0d3e69 100644 --- a/penzai/experimental/v2/toolshed/jit_wrapper.py +++ b/penzai/toolshed/jit_wrapper.py @@ -27,7 +27,7 @@ from typing import Any -from penzai.experimental.v2 import pz +from penzai import pz @pz.variable_jit diff --git a/penzai/experimental/v2/toolshed/lora.py b/penzai/toolshed/lora.py similarity index 99% rename from penzai/experimental/v2/toolshed/lora.py rename to penzai/toolshed/lora.py index 4b84204..73b95f8 100644 --- a/penzai/experimental/v2/toolshed/lora.py +++ b/penzai/toolshed/lora.py @@ -32,7 +32,7 @@ from typing import Any import jax -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass(has_implicitly_inherited_fields=True) diff --git a/penzai/experimental/v2/toolshed/model_rewiring.py b/penzai/toolshed/model_rewiring.py similarity index 99% rename from penzai/experimental/v2/toolshed/model_rewiring.py rename to penzai/toolshed/model_rewiring.py index a7b4a76..8fa62b2 100644 --- a/penzai/experimental/v2/toolshed/model_rewiring.py +++ b/penzai/toolshed/model_rewiring.py @@ -29,7 +29,7 @@ import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass diff --git a/penzai/experimental/v2/toolshed/save_intermediates.py b/penzai/toolshed/save_intermediates.py similarity index 99% rename from penzai/experimental/v2/toolshed/save_intermediates.py rename to penzai/toolshed/save_intermediates.py index bab25be..1aca1e6 100644 --- a/penzai/experimental/v2/toolshed/save_intermediates.py +++ b/penzai/toolshed/save_intermediates.py @@ -33,7 +33,7 @@ import dataclasses from typing import Any -from penzai.experimental.v2 import pz +from penzai import pz # Disable false-positive pylint warning. # pylint: disable=assigning-non-slot diff --git a/penzai/experimental/v2/toolshed/sharding_util.py b/penzai/toolshed/sharding_util.py similarity index 99% rename from penzai/experimental/v2/toolshed/sharding_util.py rename to penzai/toolshed/sharding_util.py index de73242..9100d50 100644 --- a/penzai/experimental/v2/toolshed/sharding_util.py +++ b/penzai/toolshed/sharding_util.py @@ -20,7 +20,7 @@ from typing import Any, Callable import jax -from penzai.experimental.v2 import pz +from penzai import pz PyTreeOfArrays = Any PyTreeOfNamedArrays = Any diff --git a/penzai/experimental/v2/toolshed/unflaxify.py b/penzai/toolshed/unflaxify.py similarity index 99% rename from penzai/experimental/v2/toolshed/unflaxify.py rename to penzai/toolshed/unflaxify.py index 1f148bb..b7af2a4 100644 --- a/penzai/experimental/v2/toolshed/unflaxify.py +++ b/penzai/toolshed/unflaxify.py @@ -34,7 +34,7 @@ import flax import flax.typing import jax -from penzai.experimental.v2 import pz +from penzai import pz from treescope import formatting_util diff --git a/tests/experimental/auto_order_types_test.py b/tests/core/auto_order_types_test.py similarity index 96% rename from tests/experimental/auto_order_types_test.py rename to tests/core/auto_order_types_test.py index 2896f74..4d29603 100644 --- a/tests/experimental/auto_order_types_test.py +++ b/tests/core/auto_order_types_test.py @@ -16,7 +16,7 @@ import dataclasses from absl.testing import absltest import jax -from penzai.experimental.v2.core import auto_order_types +from penzai.core import auto_order_types class AutoOrderTypesTest(absltest.TestCase): diff --git a/tests/misc_util_test.py b/tests/core/misc_util_test.py similarity index 100% rename from tests/misc_util_test.py rename to tests/core/misc_util_test.py diff --git a/tests/named_axes_test.py b/tests/core/named_axes_test.py similarity index 100% rename from tests/named_axes_test.py rename to tests/core/named_axes_test.py diff --git a/tests/partitioning_test.py b/tests/core/partitioning_test.py similarity index 100% rename from tests/partitioning_test.py rename to tests/core/partitioning_test.py diff --git a/tests/selectors_test.py b/tests/core/selectors_test.py similarity index 100% rename from tests/selectors_test.py rename to tests/core/selectors_test.py diff --git a/tests/shapecheck_test.py b/tests/core/shapecheck_test.py similarity index 100% rename from tests/shapecheck_test.py rename to tests/core/shapecheck_test.py diff --git a/tests/struct_pytree_dataclass_test.py b/tests/core/struct_pytree_dataclass_test.py similarity index 100% rename from tests/struct_pytree_dataclass_test.py rename to tests/core/struct_pytree_dataclass_test.py diff --git a/tests/experimental/variables_test.py b/tests/core/variables_test.py similarity index 99% rename from tests/experimental/variables_test.py rename to tests/core/variables_test.py index 539c0d2..24f23d0 100644 --- a/tests/experimental/variables_test.py +++ b/tests/core/variables_test.py @@ -18,7 +18,7 @@ import chex import jax import jax.numpy as jnp -from penzai.experimental.v2.core import variables +from penzai.core import variables class VariablesTest(parameterized.TestCase): diff --git a/tests/experimental/models/__init__.py b/tests/experimental/models/__init__.py deleted file mode 100644 index 5f20c8b..0000000 --- a/tests/experimental/models/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/experimental/nn/__init__.py b/tests/experimental/nn/__init__.py deleted file mode 100644 index 5f20c8b..0000000 --- a/tests/experimental/nn/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/experimental/toolshed/__init__.py b/tests/experimental/toolshed/__init__.py deleted file mode 100644 index 5f20c8b..0000000 --- a/tests/experimental/toolshed/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/penzai/experimental/v2/core/__init__.py b/tests/models/__init__.py similarity index 100% rename from penzai/experimental/v2/core/__init__.py rename to tests/models/__init__.py diff --git a/tests/experimental/models/simple_mlp_test.py b/tests/models/simple_mlp_test.py similarity index 95% rename from tests/experimental/models/simple_mlp_test.py rename to tests/models/simple_mlp_test.py index 15e8afb..9bb4859 100644 --- a/tests/experimental/models/simple_mlp_test.py +++ b/tests/models/simple_mlp_test.py @@ -18,9 +18,9 @@ import jax import jax.numpy as jnp import optax -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import basic_training +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import basic_training class SimpleMlpTest(absltest.TestCase): diff --git a/tests/experimental/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py similarity index 95% rename from tests/experimental/models/transformer_consistency_test.py rename to tests/models/transformer_consistency_test.py index 122fa0b..cc6a166 100644 --- a/tests/experimental/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -19,10 +19,10 @@ import chex import jax.numpy as jnp import numpy as np -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer.variants import gpt_neox -from penzai.experimental.v2.models.transformer.variants import llama -from penzai.experimental.v2.models.transformer.variants import mistral +from penzai import pz +from penzai.models.transformer.variants import gpt_neox +from penzai.models.transformer.variants import llama +from penzai.models.transformer.variants import mistral import torch import transformers diff --git a/tests/experimental/models/transformer_llamalike_test.py b/tests/models/transformer_llamalike_test.py similarity index 96% rename from tests/experimental/models/transformer_llamalike_test.py rename to tests/models/transformer_llamalike_test.py index aeda885..4b0af26 100644 --- a/tests/experimental/models/transformer_llamalike_test.py +++ b/tests/models/transformer_llamalike_test.py @@ -18,10 +18,10 @@ from absl.testing import parameterized import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models.transformer import sampling_mode -from penzai.experimental.v2.models.transformer import simple_decoding_loop -from penzai.experimental.v2.models.transformer.variants import llamalike_common +from penzai import pz +from penzai.models.transformer import sampling_mode +from penzai.models.transformer import simple_decoding_loop +from penzai.models.transformer.variants import llamalike_common class LlamalikeTransformerTest(parameterized.TestCase): diff --git a/tests/experimental/__init__.py b/tests/nn/__init__.py similarity index 100% rename from tests/experimental/__init__.py rename to tests/nn/__init__.py diff --git a/tests/experimental/nn/basic_ops_test.py b/tests/nn/basic_ops_test.py similarity index 96% rename from tests/experimental/nn/basic_ops_test.py rename to tests/nn/basic_ops_test.py index 17d0456..af69bfe 100644 --- a/tests/experimental/nn/basic_ops_test.py +++ b/tests/nn/basic_ops_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest import jax -from penzai.experimental.v2 import pz +from penzai import pz class BasicOpsTest(absltest.TestCase): diff --git a/tests/experimental/nn/embedding_test.py b/tests/nn/embedding_test.py similarity index 97% rename from tests/experimental/nn/embedding_test.py rename to tests/nn/embedding_test.py index c2a05ed..5f1abc4 100644 --- a/tests/experimental/nn/embedding_test.py +++ b/tests/nn/embedding_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest import jax -from penzai.experimental.v2 import pz +from penzai import pz class EmbeddingTest(absltest.TestCase): diff --git a/tests/experimental/nn/grouping_test.py b/tests/nn/grouping_test.py similarity index 98% rename from tests/experimental/nn/grouping_test.py rename to tests/nn/grouping_test.py index e18c04a..abb3f06 100644 --- a/tests/experimental/nn/grouping_test.py +++ b/tests/nn/grouping_test.py @@ -18,7 +18,7 @@ from unittest import mock from absl.testing import absltest from absl.testing import parameterized -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass diff --git a/tests/experimental/nn/layer_stack_test.py b/tests/nn/layer_stack_test.py similarity index 99% rename from tests/experimental/nn/layer_stack_test.py rename to tests/nn/layer_stack_test.py index cf3bdba..34250e7 100644 --- a/tests/experimental/nn/layer_stack_test.py +++ b/tests/nn/layer_stack_test.py @@ -18,7 +18,7 @@ from absl.testing import absltest import chex import jax -from penzai.experimental.v2 import pz +from penzai import pz @pz.pytree_dataclass diff --git a/tests/experimental/nn/layer_test.py b/tests/nn/layer_test.py similarity index 98% rename from tests/experimental/nn/layer_test.py rename to tests/nn/layer_test.py index 5f8ad93..d58c83b 100644 --- a/tests/experimental/nn/layer_test.py +++ b/tests/nn/layer_test.py @@ -15,7 +15,7 @@ """Tests for layer helpers.""" from absl.testing import absltest -from penzai.experimental.v2 import pz +from penzai import pz class LayerTest(absltest.TestCase): diff --git a/tests/experimental/nn/linear_and_affine_test.py b/tests/nn/linear_and_affine_test.py similarity index 99% rename from tests/experimental/nn/linear_and_affine_test.py rename to tests/nn/linear_and_affine_test.py index 2e4ce42..3c62467 100644 --- a/tests/experimental/nn/linear_and_affine_test.py +++ b/tests/nn/linear_and_affine_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest import chex import jax -from penzai.experimental.v2 import pz +from penzai import pz class LinearAndAffineTest(absltest.TestCase): diff --git a/tests/experimental/nn/parameters_test.py b/tests/nn/parameters_test.py similarity index 98% rename from tests/experimental/nn/parameters_test.py rename to tests/nn/parameters_test.py index bf1d1b6..e317b53 100644 --- a/tests/experimental/nn/parameters_test.py +++ b/tests/nn/parameters_test.py @@ -20,7 +20,7 @@ from absl.testing import parameterized import chex import jax -from penzai.experimental.v2 import pz +from penzai import pz class NNParametersTest(parameterized.TestCase): diff --git a/tests/experimental/nn/standardization_test.py b/tests/nn/standardization_test.py similarity index 97% rename from tests/experimental/nn/standardization_test.py rename to tests/nn/standardization_test.py index 6185aed..917596c 100644 --- a/tests/experimental/nn/standardization_test.py +++ b/tests/nn/standardization_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest import jax -from penzai.experimental.v2 import pz +from penzai import pz class StandardizationTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/gradient_checkpointing_test.py b/tests/toolshed/gradient_checkpointing_test.py similarity index 92% rename from tests/experimental/toolshed/gradient_checkpointing_test.py rename to tests/toolshed/gradient_checkpointing_test.py index bdd787c..d3d282d 100644 --- a/tests/experimental/toolshed/gradient_checkpointing_test.py +++ b/tests/toolshed/gradient_checkpointing_test.py @@ -18,9 +18,9 @@ import chex import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import gradient_checkpointing +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import gradient_checkpointing class GradientCheckpointingTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/isolate_submodel_test.py b/tests/toolshed/isolate_submodel_test.py similarity index 96% rename from tests/experimental/toolshed/isolate_submodel_test.py rename to tests/toolshed/isolate_submodel_test.py index 2652681..0ceb496 100644 --- a/tests/experimental/toolshed/isolate_submodel_test.py +++ b/tests/toolshed/isolate_submodel_test.py @@ -17,9 +17,9 @@ from absl.testing import absltest import chex import jax -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import isolate_submodel +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import isolate_submodel class IsolateSubmodelTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/jit_wrapper_test.py b/tests/toolshed/jit_wrapper_test.py similarity index 92% rename from tests/experimental/toolshed/jit_wrapper_test.py rename to tests/toolshed/jit_wrapper_test.py index 58f21dc..921db23 100644 --- a/tests/experimental/toolshed/jit_wrapper_test.py +++ b/tests/toolshed/jit_wrapper_test.py @@ -18,9 +18,9 @@ import chex import jax import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import jit_wrapper +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import jit_wrapper class JitWrapperTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/lora_test.py b/tests/toolshed/lora_test.py similarity index 94% rename from tests/experimental/toolshed/lora_test.py rename to tests/toolshed/lora_test.py index 25646fc..bcb08b0 100644 --- a/tests/experimental/toolshed/lora_test.py +++ b/tests/toolshed/lora_test.py @@ -18,10 +18,10 @@ import jax import jax.numpy as jnp import optax -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import basic_training -from penzai.experimental.v2.toolshed import lora +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import basic_training +from penzai.toolshed import lora class LoraTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/model_rewiring_test.py b/tests/toolshed/model_rewiring_test.py similarity index 98% rename from tests/experimental/toolshed/model_rewiring_test.py rename to tests/toolshed/model_rewiring_test.py index 15d107f..35a0346 100644 --- a/tests/experimental/toolshed/model_rewiring_test.py +++ b/tests/toolshed/model_rewiring_test.py @@ -17,8 +17,8 @@ from absl.testing import absltest import chex import jax.numpy as jnp -from penzai.experimental.v2 import pz -from penzai.experimental.v2.toolshed import model_rewiring +from penzai import pz +from penzai.toolshed import model_rewiring class ModelRewiringTest(absltest.TestCase): diff --git a/tests/experimental/toolshed/save_intermediates_test.py b/tests/toolshed/save_intermediates_test.py similarity index 95% rename from tests/experimental/toolshed/save_intermediates_test.py rename to tests/toolshed/save_intermediates_test.py index 03f1406..e53a55c 100644 --- a/tests/experimental/toolshed/save_intermediates_test.py +++ b/tests/toolshed/save_intermediates_test.py @@ -17,9 +17,9 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from penzai.experimental.v2 import pz -from penzai.experimental.v2.models import simple_mlp -from penzai.experimental.v2.toolshed import save_intermediates +from penzai import pz +from penzai.models import simple_mlp +from penzai.toolshed import save_intermediates class SaveIntermediatesTest(parameterized.TestCase): diff --git a/tests/experimental/toolshed/unflaxify_test.py b/tests/toolshed/unflaxify_test.py similarity index 98% rename from tests/experimental/toolshed/unflaxify_test.py rename to tests/toolshed/unflaxify_test.py index 06365d2..48fc0de 100644 --- a/tests/experimental/toolshed/unflaxify_test.py +++ b/tests/toolshed/unflaxify_test.py @@ -18,8 +18,8 @@ import chex import flax.linen import jax -from penzai.experimental.v2 import pz -from penzai.experimental.v2.toolshed import unflaxify +from penzai import pz +from penzai.toolshed import unflaxify class UnflaxifyTest(absltest.TestCase):