Skip to content

Commit

Permalink
chore: cleanup auxdata -> metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 27, 2024
1 parent 7e0a810 commit f213a17
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 34 deletions.
4 changes: 2 additions & 2 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ class PyTreeSpec {
// Arity for non-Leaf types.
ssize_t arity = 0;

// Kind-specific auxiliary data.
// Kind-specific metadata.
// For a NamedTuple/PyStructSequence, contains the tuple type object.
// For a Dict, contains a sorted list of keys.
// For a OrderedDict, contains a list of keys.
// For a DefaultDict, contains a tuple of (default_factory, sorted list of keys).
// For a Deque, contains the `maxlen` attribute.
// For a Custom type, contains the auxiliary data returned by the `flatten_func` function.
// For a Custom type, contains the metadata returned by the `flatten_func` function.
py::object node_data{};

// The tuple of path entries.
Expand Down
4 changes: 2 additions & 2 deletions optree/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def tree_flatten(self) -> tuple[ # type: ignore[override]
Callable[..., Any],
tuple[str, str],
]:
"""Flatten the :class:`partial` instance to children and auxiliary data."""
"""Flatten the :class:`partial` instance to children and metadata."""
return (self.args, self.keywords), self.func, ('args', 'keywords')

@classmethod
Expand All @@ -157,7 +157,7 @@ def tree_unflatten( # type: ignore[override]
metadata: Callable[..., Any],
children: tuple[tuple[T, ...], dict[str, T]],
) -> Self:
"""Unflatten the children and auxiliary data into a :class:`partial` instance."""
"""Unflatten the children and metadata into a :class:`partial` instance."""
args, keywords = children
return cls(metadata, *args, **keywords)

Expand Down
30 changes: 15 additions & 15 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2438,7 +2438,7 @@ def tree_flatten_one_level(
tuple[Any, ...],
Callable[[MetaData, list[PyTree[T]]], PyTree[T]],
]:
"""Flatten the pytree one level, returning a 4-tuple of children, auxiliary data, path entries, and an unflatten function.
"""Flatten the pytree one level, returning a 4-tuple of children, metadata, path entries, and an unflatten function.
See also :func:`tree_flatten`, :func:`tree_flatten_with_path`.
Expand Down Expand Up @@ -2468,9 +2468,9 @@ def tree_flatten_one_level(
Returns:
A 4-tuple ``(children, metadata, entries, unflatten_func)``. The first element is a list of
one-level children of the pytree node. The second element is the auxiliary data used to
one-level children of the pytree node. The second element is the metadata used to
reconstruct the pytree node. The third element is a tuple of path entries to the children.
The fourth element is a function that can be used to unflatten the auxiliary data and
The fourth element is a function that can be used to unflatten the metadata and
children back to the pytree node.
""" # pylint: disable=line-too-long
node_type = type(tree)
Expand Down Expand Up @@ -3209,11 +3209,11 @@ def _prefix_error(
):
yield lambda name: ValueError(
f'pytree structure error: different types at key path\n'
f' {{name}}{accessor.codify("") if accessor else " tree root"}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {accessor.codify(name) if accessor else name + " tree root"}\n'
f'At that key path, the prefix pytree {name} has a subtree of type\n'
f' {type(prefix_tree)}\n'
f'but at the same key path the full pytree has a subtree of different type\n'
f' {type(full_tree)}.'.format(name=name),
f' {type(full_tree)}.',
)
return # don't look for more errors in this subtree

Expand Down Expand Up @@ -3254,15 +3254,15 @@ def _prefix_error(
key_difference += f'\nextra key(s):\n {extra_keys}'
yield lambda name: ValueError(
f'pytree structure error: different pytree keys at key path\n'
f' {{name}}{accessor.codify("") if accessor else " tree root"}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {accessor.codify(name) if accessor else name + " tree root"}\n'
f'At that key path, the prefix pytree {name} has a subtree of type\n'
f' {prefix_tree_type}\n'
f'with {len(prefix_tree_keys)} key(s)\n'
f' {prefix_tree_keys}\n'
f'but at the same key path the full pytree has a subtree of type\n'
f' {full_tree_type}\n'
f'but with {len(full_tree_keys)} key(s)\n'
f' {full_tree_keys}{key_difference}'.format(name=name),
f' {full_tree_keys}{key_difference}',
)
return # don't look for more errors in this subtree

Expand All @@ -3272,12 +3272,12 @@ def _prefix_error(
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
f'pytree structure error: different numbers of pytree children at key path\n'
f' {{name}}{accessor.codify("") if accessor else " tree root"}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {accessor.codify(name) if accessor else name + " tree root"}\n'
f'At that key path, the prefix pytree {name} has a subtree of type\n'
f' {prefix_tree_type}\n'
f'with {len(prefix_tree_children)} children, '
f'but at the same key path the full pytree has a subtree of the same '
f'type but with {len(full_tree_children)} children.'.format(name=name),
f'type but with {len(full_tree_children)} children.',
)
return # don't look for more errors in this subtree

Expand All @@ -3303,16 +3303,16 @@ def _prefix_error(
)
yield lambda name: ValueError(
f'pytree structure error: different pytree metadata at key path\n'
f' {{name}}{accessor.codify("") if accessor else " tree root"}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {accessor.codify(name) if accessor else name + " tree root"}\n'
f'At that key path, the prefix pytree {name} has a subtree of type\n'
f' {prefix_tree_type}\n'
f'with metadata\n'
f' {prefix_tree_metadata_repr}\n'
f'but at the same key path the full pytree has a subtree of the same '
f'type but with metadata\n'
f' {full_tree_metadata_repr}\n'
f'so the diff in the metadata at these pytree nodes is\n'
f'{metadata_diff}'.format(name=name),
f'{metadata_diff}',
)
return # don't look for more errors in this subtree

Expand Down
12 changes: 6 additions & 6 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ def register_pytree_node(
cls (type): A Python type to treat as an internal pytree node.
flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
and returning a triple or optionally a pair, with (1) an iterable for the children to be
flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
path entries to the corresponding children. If the entries are not provided or given by
flattened recursively, and (2) some hashable metadata to be stored in the treespec and
to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
entries to the corresponding children. If the entries are not provided or given by
:data:`None`, then `range(len(children))` will be used.
unflatten_func (callable): A function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treespec, and the unflattened children.
The function should return an instance of ``cls``.
unflatten_func (callable): A function taking two arguments: the metadata that was returned
by ``flatten_func`` and stored in the treespec, and the unflattened children. The
function should return an instance of ``cls``.
path_entry_type (type, optional): The type of the path entry to be used in the treespec.
(default: :class:`AutoEntry`)
namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
Expand Down
9 changes: 3 additions & 6 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def tree_flatten(
# With optionally implemented path entries
tuple[Children[T], MetaData, Iterable[Any] | None]
):
"""Flatten the custom pytree node into children and auxiliary data."""
"""Flatten the custom pytree node into children and metadata."""

@classmethod
def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTreeNode[T]:
"""Unflatten the children and auxiliary data into the custom pytree node."""
"""Unflatten the children and metadata into the custom pytree node."""


_UnionType = type(Union[int, str])
Expand Down Expand Up @@ -452,10 +452,7 @@ def is_structseq_class(cls: type) -> bool:
# Check the type does not allow subclassing
if platform.python_implementation() == 'PyPy':
try:
# pylint: disable-next=too-few-public-methods
class _(cls): # noqa: N801
pass

types.new_class('subclass', bases=(cls,))
except (AssertionError, TypeError):
return True
return False
Expand Down
6 changes: 3 additions & 3 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ namespace optree {
py::dict dict{};
const scoped_critical_section2 cs{node.node_data, node.original_keys};
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of auxiliary data mismatch.");
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch.");
}
const auto keys = (node.kind != PyTreeKind::DefaultDict
? py::reinterpret_borrow<py::list>(node.node_data)
Expand Down Expand Up @@ -1186,7 +1186,7 @@ std::string PyTreeSpec::ToStringImpl() const {

case PyTreeKind::DefaultDict: {
const scoped_critical_section cs(node.node_data);
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of auxiliary data mismatch.");
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch.");
const py::object default_factory = TupleGetItem(node.node_data, 0);
const auto keys = TupleGetItemAs<py::list>(node.node_data, 1);
EXPECT_EQ(ListGetSize(keys),
Expand Down Expand Up @@ -1339,7 +1339,7 @@ std::string PyTreeSpec::ToString() const {
case PyTreeKind::DefaultDict: {
const scoped_critical_section cs{node.node_data};
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of auxiliary data mismatch.");
EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch.");
const py::object default_factory = TupleGetItem(node.node_data, 0);
data_hash = EVALUATE_WITH_LOCK_HELD(py::hash(default_factory), default_factory);
}
Expand Down

0 comments on commit f213a17

Please sign in to comment.