diff --git a/penzai/treescope/handlers/shared_value_postprocessor.py b/penzai/treescope/handlers/shared_value_postprocessor.py index 11e1691..e6708b1 100644 --- a/penzai/treescope/handlers/shared_value_postprocessor.py +++ b/penzai/treescope/handlers/shared_value_postprocessor.py @@ -21,6 +21,7 @@ import contextlib import dataclasses import io +import types from typing import Any, Optional, Sequence import jax @@ -35,12 +36,21 @@ class _SharedObjectTracker: """Helper object to track IDs we've seen before. + This object keeps track of nodes we've encountered while rendering the root + object, so that we can detect when the same object is encountered in multiple + places. To ensure that we don't have false positives, we also store the nodes + themselves until the rendering finishes. (Usually, rendered Python objects + will be kept alive due to being part of the rendered tree, but some handlers + may render temporary objects that are discarded, at which point their object + IDs are sometimes re-used by other temporary objects.) + Attributes: - seen_at_least_once: Set of node IDs we've seen at least one time. + seen_at_least_once: Map with node IDs as keys and nodes as values, + containing nodes we've seen at least once. seen_more_than_once: Set of node IDs we've seen more than once. """ - seen_at_least_once: set[int] + seen_at_least_once: dict[int, Any] seen_more_than_once: set[int] @@ -245,16 +255,36 @@ def setup_shared_value_context() -> contextlib.AbstractContextManager[None]: This should be included in the `context_builders` argument to any renderer that checks for shared values. """ - return _shared_object_ids_seen.set_scoped(_SharedObjectTracker(set(), set())) + return _shared_object_ids_seen.set_scoped(_SharedObjectTracker({}, set())) # Types that can have multiple references in the same object without it being # necessary or important to highlight the shared reference. 
_SAFE_TO_SHARE_TYPES = { jax.Array, + types.FunctionType, + types.MethodType, + types.ModuleType, + type, + type(None), + type(NotImplemented), + type(Ellipsis), } +def _is_safe_to_share(node: Any) -> bool: + """Returns whether a node is assumed immutable or otherwise safe to share.""" + # According to the Python data model, "If a class defines mutable objects and + # implements an __eq__() method, it should not implement __hash__()". So, if + # we find an object that implements __eq__ and __hash__, we can generally + # assume it is immutable. + return isinstance(node, tuple(_SAFE_TO_SHARE_TYPES)) or ( + type(node).__hash__ is not None + and type(node).__hash__ is not object.__hash__ + and type(node).__eq__ is not object.__eq__ + ) + + def check_for_shared_values( node: Any, path: tuple[Any, ...] | None, @@ -297,11 +327,8 @@ def check_for_shared_values( node_id = id(node) # For types that we know are immutable, it's not necessary to render shared - # references in a special way. (Hashable objects can _technically_ be - # modified but we trust the user to know what they are doing if so.) - safe_to_share = ( - hasattr(node, "__hash__") and node.__hash__ is not None - ) or isinstance(node, tuple(_SAFE_TO_SHARE_TYPES)) + # references in a special way. + safe_to_share = _is_safe_to_share(node) # Render the node normally. rendering = node_renderer(node, path) @@ -311,7 +338,7 @@ def check_for_shared_values( if node_id in shared_object_tracker.seen_at_least_once: shared_object_tracker.seen_more_than_once.add(node_id) else: - shared_object_tracker.seen_at_least_once.add(node_id) + shared_object_tracker.seen_at_least_once[node_id] = node # Wrap it in a shared value wrapper; this will check to see if the same # node was seen more than once, and add an annotation if so.