Skip to content

Commit

Permalink
Add __penzai_repr__ extension method and repr_lib helpers.
Browse files Browse the repository at this point in the history
The __penzai_repr__ method can be used to add a handler for a custom
type without modifying the global handler list. This is designed to make
it easier to extend treescope to support custom types.

Also adds a high-level helper library `repr_lib` to make it easier to
implement __penzai_repr__.

PiperOrigin-RevId: 637942807
  • Loading branch information
danieldjohnson authored and Penzai Developers committed May 28, 2024
1 parent a1ba9b6 commit 50b2762
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 0 deletions.
2 changes: 2 additions & 0 deletions penzai/treescope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
from . import default_renderer
from . import figures
from . import foldable_representation
from . import formatting_util
from . import handlers
from . import html_compression
from . import html_escaping
from . import renderer
from . import repr_lib
from . import selection_rendering
from . import treescope_ipython
3 changes: 3 additions & 0 deletions penzai/treescope/default_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from penzai.treescope.handlers import builtin_atom_handler
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers import canonical_alias_postprocessor
from penzai.treescope.handlers import extension_method_handler
from penzai.treescope.handlers import function_reflection_handlers
from penzai.treescope.handlers import generic_pytree_handler
from penzai.treescope.handlers import generic_repr_handler
Expand Down Expand Up @@ -77,6 +78,8 @@
layer_handler.handle_layers,
# Other pz.Struct instances.
struct_handler.handle_structs,
# Objects with their own handlers.
extension_method_handler.handle_via_penzai_repr_method,
# NDArrays.
ndarray_handler.handle_ndarrays,
# Reflection of functions and classes.
Expand Down
22 changes: 22 additions & 0 deletions penzai/treescope/formatting_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for formatting and rendering."""

# pylint: disable=g-multiple-import,g-importing-member,unused-import

from penzai.core.formatting_util import (
color_from_string,
oklch_color,
)
65 changes: 65 additions & 0 deletions penzai/treescope/handlers/extension_method_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handler for custom types via the __penzai_repr__ method."""

from __future__ import annotations

from typing import Any

from penzai.treescope import renderer
from penzai.treescope.foldable_representation import part_interface


def handle_via_penzai_repr_method(
node: Any,
path: tuple[Any, ...] | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
| part_interface.RenderableAndLineAnnotations
| type(NotImplemented)
):
"""Renders a type by calling its __penzai_repr__ method, if it exists.
The __penzai_repr__ method can be used to add treescope support to custom
classes. The method is expected to return a rendering in treescope's internal
intermediate representation.
Currently, the exact structure of the intermediate representation is an
implementation detail and may change in future releases. Instead of building
a rendering directly, most types should use the construction helpers in
`penzai.treescope.repr_lib` to implement this method.
A useful pattern is to only import `penzai.treescope` inside
`__penzai_repr__`. This allows a library to support treescope without
requiring treescope to be a direct dependency of the library.
Args:
node: The node to render.
path: An optional path to this node from the root.
subtree_renderer: The renderer for sutrees of this node.
Returns:
A rendering of this node, if it implements the __penzai_repr__ extension
method.
"""
if (
not isinstance(node, type)
and hasattr(node, "__penzai_repr__")
and callable(node.__penzai_repr__)
):
return node.__penzai_repr__(path, subtree_renderer)
else:
return NotImplemented
262 changes: 262 additions & 0 deletions penzai/treescope/repr_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stable high-level interface for building object reprs.
These functions simplify the process of implmenting `__penzai_repr__` for
custom types, allowing them to integrate with treescope. This interface will be
stable across penzai releases, and may be expanded in the future to support
additional customization.
Note that the exact types of `path` and `subtree_renderer` are subject to
change in future releases. These should always be passed directly from
`__penzai_repr__`.
"""

from __future__ import annotations

from typing import Any, Mapping

import jax
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
from penzai.treescope.foldable_representation import part_interface


def render_object_constructor(
object_type: type[Any],
attributes: Mapping[str, Any],
path: tuple[Any, ...] | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
roundtrippable: bool = False,
color: str | None = None,
) -> (
part_interface.RenderableTreePart
| part_interface.RenderableAndLineAnnotations
):
"""Renders an object in "constructor format", similar to a dataclass.
This produces a rendering like `Foo(bar=1, baz=2)`, where Foo identifies the
type of the object, and bar and baz are the names of the attributes of the
object. It is a *requirement* that these are the actual attributes of the
object, which can be accessed via `obj.bar` or similar; otherwise, the
path renderings will break.
This can be used from within a `__penzai_repr__` implementation via ::
def __penzai_repr__(self, path, subtree_renderer):
return repr_lib.render_object_constructor(
object_type=type(self),
attributes=<dict of attributes here>,
path=path,
subtree_renderer=subtree_renderer,
)
Args:
object_type: The type of the object.
attributes: The attributes of the object, which will be rendered as keyword
arguments to the constructor.
path: The path to the object. When `render_object_constructor` is called
from `__penzai_repr__`, this should come from the `path` argument to
`__penzai_repr__`.
subtree_renderer: The renderer to use to render subtrees. When
`render_object_constructor` is called from `__penzai_repr__`, this
should come from the `subtree_renderer` argument to `__penzai_repr__`.
roundtrippable: Whether evaluating the rendering as Python code will produce
an object that is equal to the original object. This implies that the
keyword arguments are actually the keyword arguments to the constructor,
and not some other attributes of the object.
color: The background color to use for the object rendering. If None, does
not use a background color. A utility for assigning a random color based
on a string key is given in `penzai.treescope.formatting_util`.
Returns:
A rendering of the object, suitable for returning from `__penzai_repr__`.
"""
if roundtrippable:
constructor = basic_parts.siblings(
common_structures.maybe_qualified_type_name(object_type), "("
)
closing_suffix = basic_parts.Text(")")
else:
constructor = basic_parts.siblings(
basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")),
common_structures.maybe_qualified_type_name(object_type),
"(",
)
closing_suffix = basic_parts.siblings(
")",
basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")),
)

children = []
for i, (name, value) in enumerate(attributes.items()):
child_path = (
None if path is None else path + (jax.tree_util.GetAttrKey(name),)
)

if i < len(attributes) - 1:
# Not the last child. Always show a comma, and add a space when
# collapsed.
comma_after = basic_parts.siblings(
",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" "))
)
else:
# Last child: only show the comma when the node is expanded.
comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(","))

child_line = basic_parts.build_full_line_with_annotations(
basic_parts.siblings_with_annotations(
f"{name}=",
subtree_renderer(value, path=child_path),
),
comma_after,
)
children.append(child_line)

return common_structures.build_foldable_tree_node_from_children(
prefix=constructor,
children=children,
suffix=closing_suffix,
path=path,
background_color=color,
)


def render_dictionary_wrapper(
object_type: type[Any],
wrapped_dict: Mapping[str, Any],
path: tuple[Any, ...] | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
roundtrippable: bool = False,
color: str | None = None,
) -> (
part_interface.RenderableTreePart
| part_interface.RenderableAndLineAnnotations
):
"""Renders an object in "wrapped dictionary format".
This produces a rendering like `Foo({"bar": 1, "baz": 2})`, where Foo
identifies the type of the object, and "bar" and "baz" are the keys in the
dictionary that Foo acts like. It is a *requirement* that these are accessible
through `__getitem__`, e.g. as `obj["bar"]` or similar; otherwise, the path
renderings will break.
This can be used from within a `__penzai_repr__` implementation via ::
def __penzai_repr__(self, path, subtree_renderer):
return repr_lib.render_dictionary_wrapper(
object_type=type(self),
wrapped_dict=<dict of items here>,
path=path,
subtree_renderer=subtree_renderer,
)
Args:
object_type: The type of the object.
wrapped_dict: The dictionary that the object wraps.
path: The path to the object. When `render_object_constructor` is called
from `__penzai_repr__`, this should come from the `path` argument to
`__penzai_repr__`.
subtree_renderer: The renderer to use to render subtrees. When
`render_object_constructor` is called from `__penzai_repr__`, this
should come from the `subtree_renderer` argument to `__penzai_repr__`.
roundtrippable: Whether evaluating the rendering as Python code will produce
an object that is equal to the original object. This implies that the
constructor for `object_type` takes a single argument, which is a
dictionary, and that `object_type` then acts like that dictionary.
color: The background color to use for the object rendering. If None, does
not use a background color. A utility for assigning a random color based
on a string key is given in `penzai.treescope.formatting_util`. (By
convention, wrapped dictionaries aren't usually assigned a color in
Penzai.)
Returns:
A rendering of the object, suitable for returning from `__penzai_repr__`.
"""
if roundtrippable:
constructor = basic_parts.siblings(
common_structures.maybe_qualified_type_name(object_type), "({"
)
closing_suffix = basic_parts.Text("})")
else:
constructor = basic_parts.siblings(
basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")),
common_structures.maybe_qualified_type_name(object_type),
"({",
)
closing_suffix = basic_parts.siblings(
"})",
basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")),
)

children = []
for i, (key, value) in enumerate(wrapped_dict.items()):
child_path = None if path is None else path + (jax.tree_util.DictKey(key),)

if i < len(wrapped_dict) - 1:
# Not the last child. Always show a comma, and add a space when
# collapsed.
comma_after = basic_parts.siblings(
",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" "))
)
else:
# Last child: only show the comma when the node is expanded.
comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(","))

key_rendering = subtree_renderer(key)
value_rendering = subtree_renderer(value, path=child_path)

if (
key_rendering.renderable.collapsed_width < 40
and not key_rendering.renderable.foldables_in_this_part()
and (
key_rendering.annotations is None
or key_rendering.annotations.collapsed_width == 0
)
):
# Simple enough to render on one line.
children.append(
basic_parts.siblings_with_annotations(
key_rendering, ": ", value_rendering, comma_after
)
)
else:
# Should render on multiple lines.
children.append(
basic_parts.siblings(
basic_parts.build_full_line_with_annotations(
key_rendering,
":",
basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")),
),
basic_parts.IndentedChildren.build([
basic_parts.siblings_with_annotations(
value_rendering, comma_after
),
basic_parts.FoldCondition(
expanded=basic_parts.VerticalSpace("0.5em")
),
]),
)
)

return common_structures.build_foldable_tree_node_from_children(
prefix=constructor,
children=children,
suffix=closing_suffix,
path=path,
background_color=color,
)

0 comments on commit 50b2762

Please sign in to comment.