From 8e0a60458a80caffd904e2add2dc92d15c53a240 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sun, 15 Dec 2024 16:30:23 -0500 Subject: [PATCH] [nnx] add tabulate --- flax/nnx/__init__.py | 1 + flax/nnx/filterlib.py | 4 +- flax/nnx/summary.py | 214 ++++++++++++++++++++++++++++++++++++++++ flax/nnx/variablelib.py | 2 +- uv.lock | 38 +++---- 5 files changed, 238 insertions(+), 21 deletions(-) create mode 100644 flax/nnx/summary.py diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index fcb15f0608..d85271c465 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -167,4 +167,5 @@ from .extract import to_tree as to_tree from .extract import from_tree as from_tree from .extract import NodeStates as NodeStates +from .summary import tabulate as tabulate from . import traversals as traversals \ No newline at end of file diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 63ed371be9..1028efb2b1 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate: else: raise TypeError(f'Invalid collection filter: {filter:!r}. ') -def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]: +def filters_to_predicates( + filters: tp.Sequence[Filter], +) -> tuple[Predicate, ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py new file mode 100644 index 0000000000..17dd088c4b --- /dev/null +++ b/flax/nnx/summary.py @@ -0,0 +1,214 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pytype: skip-file + + +import io +import typing as tp +from itertools import groupby +from types import MappingProxyType + +import jax +import rich.console +import rich.table +import rich.text +import yaml +import jax.numpy as jnp + +from flax.nnx import graph, rnglib, variablelib + + +def tabulate( + obj, + depth: int = 2, + table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), +) -> str: + """Create a table summary of the states. + + Args: + obj: The object to summarize. + depth: The depth of the table. + + Returns: + A string representing the states in a table. + """ + state = graph.state(obj) + graph_map = dict(graph.iter_graph(obj)) + flat_state = sorted(state.flat_state()) + + def key_fn( + path_state: tuple[graph.PathParts, variablelib.VariableState[tp.Any]], + ): + path, _ = path_state + if len(path) <= depth: + return path[:-1] + return path[:depth] + + rows = groupby(flat_state, key_fn) + table = sorted((path, list(flat_states)) for path, flat_states in rows) + + state_types = {variable_state.type for _, variable_state in flat_state} + # replace RngKey and RngCount with RngState + if rnglib.RngKey in state_types: + state_types.remove(rnglib.RngKey) + state_types.add(rnglib.RngState) + if rnglib.RngCount in state_types: + state_types.remove(rnglib.RngCount) + state_types.add(rnglib.RngState) + # sort based on MRO + state_types = _sort_variable_types(state_types) + + rich_table = rich.table.Table( + show_header=True, + show_lines=True, + show_footer=True, + title=f'{type(obj).__name__} Summary', + **table_kwargs, + ) + + rich_table.add_column('path', **column_kwargs) + rich_table.add_column('type', **column_kwargs) + + for state_type in state_types: + rich_table.add_column(state_type.__name__, **column_kwargs) + + for key_path, row_states in table: + row: list[str] = [] + node = graph_map[key_path] + type_state_groups = variablelib.split_flat_state(row_states, state_types) + path_str = '/'.join(map(str, key_path)) + node_type = type(node).__name__ + row.extend([path_str, node_type]) + + for state_type, type_path_and_states in zip(state_types, type_state_groups): + attributes = {} + for state_path, variable_state in type_path_and_states: + if len(state_path) == len(key_path) + 1: + name = str(state_path[-1]) + value = variable_state.value + value_repr = _render_array(value) if _has_shape_dtype(value) else '' + metadata = variable_state.get_metadata() + + if metadata: + attributes[name] = { + 'value': value_repr, + **metadata, + } + elif value_repr: + attributes[name] = value_repr + + if attributes: + col_repr = _as_yaml_str(attributes) + '\n\n' + else: + col_repr = '' + + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + col_repr += f'[bold]{_size_and_bytes_repr(size_, bytes_)}[/bold]' + row.append(col_repr) + + rich_table.add_row(*row) + + rich_table.columns[1].footer = rich.text.Text.from_markup( + 'Total', justify='right' + ) + flat_states = variablelib.split_flat_state(flat_state, state_types) + + for i, (state_type, type_path_and_states) in enumerate( + zip(state_types, flat_states) + ): + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + size_repr = _size_and_bytes_repr(size_, bytes_) + rich_table.columns[i + 2].footer = size_repr + + rich_table.caption_style = 'bold' + rich_table.caption = ( + f'\nTotal Parameters: {_size_and_bytes_repr(*_size_and_bytes(state))}' + ) + + return '\n' + _get_rich_repr(rich_table, console_kwargs) + '\n' + + +def _get_rich_repr(obj, console_kwargs): + f = io.StringIO() + console = rich.console.Console(file=f, **console_kwargs) + console.print(obj) + return f.getvalue() + + +def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: + leaves = jax.tree.leaves(pytree) + size = sum(x.size for x in leaves if hasattr(x, 'size')) + num_bytes = sum( + x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') + ) + return size, num_bytes + + +def _size_and_bytes_repr(size: int, num_bytes: int) -> str: + if not size: + return '' + bytes_repr = _bytes_repr(num_bytes) + return f'{size:,} [dim]({bytes_repr})[/dim]' + + +def _bytes_repr(num_bytes): + count, units = ( + (f'{num_bytes / 1e9 :,.1f}', 'GB') + if num_bytes > 1e9 + else (f'{num_bytes / 1e6 :,.1f}', 'MB') + if num_bytes > 1e6 + else (f'{num_bytes / 1e3 :,.1f}', 'KB') + if num_bytes > 1e3 + else (f'{num_bytes:,}', 'B') + ) + + return f'{count} {units}' + + +def _has_shape_dtype(value): + return hasattr(value, 'shape') and hasattr(value, 'dtype') + + +def _as_yaml_str(value) -> str: + if (hasattr(value, '__len__') and len(value) == 0) or value is None: + return '' + + file = io.StringIO() + yaml.safe_dump( + value, + file, + default_flow_style=False, + indent=2, + sort_keys=False, + explicit_end=False, + ) + return file.getvalue().replace('\n...', '').replace("'", '').strip() + + +def _render_array(x): + shape, dtype = jnp.shape(x), jnp.result_type(x) + shape_repr = ','.join(str(x) for x in shape) + return f'[dim]{dtype}[/dim][{shape_repr}]' + + +def _sort_variable_types(types: tp.Iterable[type]): + def _variable_parents_count(t: type): + return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) + + type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} + return sorted(types, key=lambda t: type_sort_key[t]) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..85b441711b 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -911,7 +911,7 @@ def wrapper(*args): def split_flat_state( flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], - filters: tuple[filterlib.Filter, ...], + filters: tp.Sequence[filterlib.Filter], ) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]: predicates = filterlib.filters_to_predicates(filters) # we have n + 1 states, where n is the number of predicates diff --git a/uv.lock b/uv.lock index e08e2dbf53..52a0caa4a6 100644 --- a/uv.lock +++ b/uv.lock @@ -3,13 +3,13 @@ requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] [[package]] @@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 } wheels = [ @@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 } wheels = [ @@ -1202,7 +1202,7 @@ name = "ipython" version = "8.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1431,7 +1431,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ @@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 } wheels = [ @@ -2606,7 +2606,7 @@ name = "pytest" version = "8.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -3684,7 +3684,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },