Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove node-hyperlink and copypaste-fallback support, tweak path copying. #68

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/_autogen_root.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

penzai.treescope.arrayviz
penzai.treescope.autovisualize
penzai.treescope.copypaste_fallback
penzai.treescope.default_renderer
penzai.treescope.figures
penzai.treescope.repr_lib
Expand Down
7 changes: 0 additions & 7 deletions docs/api/treescope.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,3 @@ Rendering to Strings
default_renderer.render_to_text
default_renderer.render_to_html


Utility Types
^^^^^^^^^^^^^

.. autosummary::

copypaste_fallback.NotRoundtrippable
24 changes: 19 additions & 5 deletions notebooks/how_to_think_in_penzai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"import dataclasses\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from typing import Any, Callable, Sequence"
]
},
Expand Down Expand Up @@ -177,7 +178,20 @@
"source": [
"Try clicking to expand or collapse different sublayers! We've turned on automatic array visualization, so if you expand one of the parameters, you can immediately visualize its shape and array data.\n",
"\n",
"Importantly, this isn't just a pretty visualization of the model, it's actually a **fully-roundtrippable specification of the model structure**. You can press `r` to enable roundtrip mode, and then directly copy and execute the pretty-printed output:\n"
"Importantly, this isn't just a pretty visualization of the model, it's actually a **fully-roundtrippable specification of the model structure**. You can press `r` to enable roundtrip mode, which adds fully-qualified names to every type. Then, if you remove the arrays first, you can directly copy and execute the pretty-printed output:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I7OJhe-RfNT2"
},
"outputs": [],
"source": [
"# We can use `eval_shape` to remove array data and just keep the shapes.\n",
"# Try pressing `r` and copying the below output:\n",
"jax.eval_shape(lambda: mlp)"
]
},
{
Expand All @@ -192,18 +206,18 @@
" sublayers=[\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[\n",
" penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(8, 32) ≈-0.0019 ±0.22 [≥-0.38, ≤0.38] nonzero:256>', original_id=23148094748192, original_type=jax.Array)), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),\n",
" penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(8, 32), dtype=np.dtype('float32'))), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),\n",
" penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),\n",
" penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148094743968, original_type=jax.Array)), name='Affine_0.AddBias.bias'), new_axis_names=()),\n",
" penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_0.AddBias.bias'), new_axis_names=()),\n",
" ],\n",
" ),\n",
" penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 32) ≈-0.0037 ±0.18 [≥-0.31, ≤0.31] nonzero:1_024>', original_id=23147983056352, original_type=jax.Array)), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32,) ≈0.0 ±0.0 [≥0.0, ≤0.0] zero:32>', original_id=23148002433952, original_type=jax.Array)), name='Affine_1.AddBias.bias'), new_axis_names=())],\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.dtype('float32'))), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_1.AddBias.bias'), new_axis_names=())],\n",
" ),\n",
" penzai.nn.basic_ops.Elementwise(fn=jax.nn.relu),\n",
" penzai.nn.linear_and_affine.Affine( # Sequential\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='<jax.Array float32(32, 8) ≈-0.0052 ±0.23 [≥-0.38, ≤0.39] nonzero:256>', original_id=23148002427616, original_type=jax.Array)), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)', original_id=23147983059168, original_type=jax.Array)), name='Affine_2.AddBias.bias'), new_axis_names=())],\n",
" sublayers=[penzai.nn.linear_and_affine.Linear(weights=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=jax.ShapeDtypeStruct(shape=(32, 8), dtype=np.dtype('float32'))), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)), penzai.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)), penzai.nn.linear_and_affine.AddBias(bias=penzai.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=jax.ShapeDtypeStruct(shape=(8,), dtype=np.dtype('float32'))), name='Affine_2.AddBias.bias'), new_axis_names=())],\n",
" ),\n",
" ],\n",
")\n",
Expand Down
Loading
Loading