Skip to content

Commit

Permalink
Fix detection of mutable objects for repeated object annotations.
Browse files Browse the repository at this point in the history
Python objects are hashable even if they are mutable, as long as they
compare equal by identity only. However, this wasn't correctly detected
by the shared value detection logic in treescope; such objects were
incorrectly treated as immutable.

Additionally, in rare cases reference counting in Python could cause
the same object ID to be reused, leading to a spurious repeated object
warning. By keeping references to the possibly-repeated objects until
finishing rendering, we can avoid this issue.

PiperOrigin-RevId: 637319549
  • Loading branch information
danieldjohnson authored and Penzai Developers committed May 28, 2024
1 parent 4d08e2b commit db6ae64
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions penzai/treescope/handlers/shared_value_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
import dataclasses
import io
import types
from typing import Any, Optional, Sequence

import jax
Expand All @@ -35,12 +36,21 @@
class _SharedObjectTracker:
"""Helper object to track IDs we've seen before.
This object keeps track of nodes we've encountered while rendering the root
object, so that we can detect when the same object is encountered in multiple
places. To ensure that we don't have false positives, we also store the nodes
themselves until the rendering finishes. (Usually, rendered Python objects
will be kept alive due to being part of the rendered tree, but some handlers
may render temporary objects that are discarded, at which point their object
IDs are sometimes re-used by other temporary objects.)
Attributes:
seen_at_least_once: Set of node IDs we've seen at least one time.
seen_at_least_once: Map with node IDs as keys and nodes as values,
containing nodes we've seen at least once.
seen_more_than_once: Set of node IDs we've seen more than once.
"""

seen_at_least_once: set[int]
seen_at_least_once: dict[int, Any]
seen_more_than_once: set[int]


Expand Down Expand Up @@ -245,16 +255,36 @@ def setup_shared_value_context() -> contextlib.AbstractContextManager[None]:
This should be included in the `context_builders` argument to any renderer
that checks for shared values.
"""
return _shared_object_ids_seen.set_scoped(_SharedObjectTracker(set(), set()))
return _shared_object_ids_seen.set_scoped(_SharedObjectTracker({}, set()))


# Types that can have multiple references in the same object without it being
# necessary or important to highlight the shared reference.
#
# `_is_safe_to_share` converts this set to a tuple and passes it to
# `isinstance`, so instances of subclasses of these types are covered too.
_SAFE_TO_SHARE_TYPES = {
# Array values.
jax.Array,
# Functions, bound methods and modules.
types.FunctionType,
types.MethodType,
types.ModuleType,
# Classes themselves (any instance of `type`).
type,
# Types of the `None`, `NotImplemented` and `Ellipsis` singletons.
type(None),
type(NotImplemented),
type(Ellipsis),
}


def _is_safe_to_share(node: Any) -> bool:
  """Returns whether repeated references to `node` need no special rendering.

  A node qualifies if it is an instance of one of `_SAFE_TO_SHARE_TYPES`, or if
  its type customizes both hashing and equality.

  Args:
    node: The object to inspect.

  Returns:
    True if the node may appear in several places without being flagged as a
    shared value, False otherwise.
  """
  # Explicitly allow-listed types are always fine to share.
  if isinstance(node, tuple(_SAFE_TO_SHARE_TYPES)):
    return True

  node_type = type(node)
  hash_impl = node_type.__hash__

  # Unhashable types (`__hash__` set to None) and types that still use the
  # default identity-based hash give no evidence of immutability.
  if hash_impl is None or hash_impl is object.__hash__:
    return False

  # The Python data model says a class that defines mutable objects and
  # implements `__eq__` should not implement `__hash__`. A type that overrides
  # both is therefore assumed to be immutable.
  return node_type.__eq__ is not object.__eq__


def check_for_shared_values(
node: Any,
path: tuple[Any, ...] | None,
Expand Down Expand Up @@ -297,11 +327,8 @@ def check_for_shared_values(
node_id = id(node)

# For types that we know are immutable, it's not necessary to render shared
# references in a special way. (Hashable objects can _technically_ be
# modified but we trust the user to know what they are doing if so.)
safe_to_share = (
hasattr(node, "__hash__") and node.__hash__ is not None
) or isinstance(node, tuple(_SAFE_TO_SHARE_TYPES))
# references in a special way.
safe_to_share = _is_safe_to_share(node)

# Render the node normally.
rendering = node_renderer(node, path)
Expand All @@ -311,7 +338,7 @@ def check_for_shared_values(
if node_id in shared_object_tracker.seen_at_least_once:
shared_object_tracker.seen_more_than_once.add(node_id)
else:
shared_object_tracker.seen_at_least_once.add(node_id)
shared_object_tracker.seen_at_least_once[node_id] = node

# Wrap it in a shared value wrapper; this will check to see if the same
# node was seen more than once, and add an annotation if so.
Expand Down

0 comments on commit db6ae64

Please sign in to comment.