Skip to content

Commit

Permalink
Remove node-hyperlink and copypaste-fallback support, tweak path copy…
Browse files Browse the repository at this point in the history
…ing.

Removes two complex features in Treescope which were only rarely used:

- Node hyperlinks: Previously, tree nodes could link to each other by path. However, this
  added complexity to every rendering and is unused in the v2 Penzai neural network design.
- Copypaste fallbacks: To try to ensure Treescope renderings were valid Python syntax even
  for unrecognized types, custom `repr` implementations were originally wrapped in a
  "NotRoundtrippable" wrapper in roundtrip mode. However, this was only partially supported
  and cannot be guaranteed now that types can define their own Treescope handlers. Penzai
  models are still roundtrippable if they don't have array data.

Also changes the "Copy path" buttons to be less verbose, by removing the "(lambda root: root...)"
boilerplate.

PiperOrigin-RevId: 651758769
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 18, 2024
1 parent c746fd3 commit 4dc99dd
Show file tree
Hide file tree
Showing 16 changed files with 125 additions and 461 deletions.
1 change: 0 additions & 1 deletion docs/_autogen_root.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

penzai.treescope.arrayviz
penzai.treescope.autovisualize
penzai.treescope.copypaste_fallback
penzai.treescope.default_renderer
penzai.treescope.figures
penzai.treescope.repr_lib
Expand Down
7 changes: 0 additions & 7 deletions docs/api/treescope.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,3 @@ Rendering to Strings
default_renderer.render_to_text
default_renderer.render_to_html


Utility Types
^^^^^^^^^^^^^

.. autosummary::

copypaste_fallback.NotRoundtrippable
24 changes: 19 additions & 5 deletions notebooks/how_to_think_in_penzai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"import dataclasses\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from typing import Any, Callable, Sequence"
]
},
Expand Down Expand Up @@ -177,7 +178,20 @@
"source": [
"Try clicking to expand or collapse different sublayers! We've turned on automatic array visualization, so if you expand one of the parameters, you can immediately visualize its shape and array data.\n",
"\n",
"Importantly, this isn't just a pretty visualization of the model, it's actually a **fully-roundtrippable specification of the model structure**. You can press `r` to enable roundtrip mode, and then directly copy and execute the pretty-printed output:\n"
"Importantly, this isn't just a pretty visualization of the model, it's actually a **fully-roundtrippable specification of the model structure**. You can press `r` to enable roundtrip mode, which adds fully-qualified names to every type. Then, if you remove the arrays first, you can directly copy and execute the pretty-printed output:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I7OJhe-RfNT2"
},
"outputs": [],
"source": [
"# We can use `eval_shape` to remove array data and just keep the shapes.\n",
"# Try pressing `r` and copying the below output:\n",
"jax.eval_shape(lambda: mlp)"
]
},
{
Expand All @@ -192,18 +206,18 @@
" sublayers=[\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[\n",
" penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(8, 32) ≈-0.0019 ±0.22 [≥-0.38, ≤0.38] nonzero:256>', original_id=23148094748192, original_type=jax.Array)), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),\n",
" penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(8, 32), dtype=np.dtype('float32'))), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),\n",
" penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),\n",
" penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148094743968, original_type=jax.Array)), name='Affine_0.AddBias.bias'), new_axis_names=()),\n",
" penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_0.AddBias.bias'), new_axis_names=()),\n",
" ],\n",
" ),\n",
" penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 32) ≈-0.0037 ±0.18 [≥-0.31, ≤0.31] nonzero:1_024>', original_id=23147983056352, original_type=jax.Array)), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148002433952, original_type=jax.Array)), name='Affine_1.AddBias.bias'), new_axis_names=())],\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.dtype('float32'))), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_1.AddBias.bias'), new_axis_names=())],\n",
" ),\n",
" penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 8) ≈-0.0052 ±0.23 [≥-0.38, ≤0.39] nonzero:256>', original_id=23148002427616, original_type=jax.Array)), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)', original_id=23147983059168, original_type=jax.Array)), name='Affine_2.AddBias.bias'), new_axis_names=())],\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=jax.ShapeDtypeStruct(shape=(32, 8), dtype=np.dtype('float32'))), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=jax.ShapeDtypeStruct(shape=(8,), dtype=np.dtype('float32'))), name='Affine_2.AddBias.bias'), new_axis_names=())],\n",
" ),\n",
" ],\n",
")\n",
Expand Down
Loading

0 comments on commit 4dc99dd

Please sign in to comment.