diff --git a/.github/labeler.yml b/.github/labeler.yml index 377d3a0d384e..ae2f57cc925a 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -54,5 +54,5 @@ benchmark: - changed-files: - any-glob-to-any-file: ["benchmark/**/*", "torch_geometric/profile/**/*"] -skip-changelog: +auto-skip-changelog: - head-branch: ['^skip', '^pre-commit-ci'] diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 203ad3e922d8..5e6892bfd1f3 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -13,4 +13,4 @@ jobs: - name: Enforce changelog entry uses: dangoslen/changelog-enforcer@v3 with: - skipLabels: 'skip-changelog' + skipLabels: skip-changelog, auto-skip-changelog diff --git a/CHANGELOG.md b/CHANGELOG.md index 77de363257dc..b6502862b51f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408)) - Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318)) - Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140)) - Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102)) diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index ea769a0714f3..f510d3012f20 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -367,9 +367,15 @@ def get_summary(self) -> Any: from torch_geometric.data.summary import Summary return Summary.from_dataset(self) - def print_summary(self) -> None: - r"""Prints summary statistics of the dataset to the console.""" - print(str(self.get_summary())) + def print_summary(self, fmt: str = "psql") -> None: + r"""Prints summary statistics of the dataset to the console. + + Args: + fmt (str, optional): Summary tables format. Available table formats + can be found `here `__. (default: :obj:`"psql"`) + """ + print(self.get_summary().format(fmt=fmt)) def to_datapipe(self) -> Any: r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`. diff --git a/torch_geometric/data/summary.py b/torch_geometric/data/summary.py index 06b70c30a290..e90b504615e2 100644 --- a/torch_geometric/data/summary.py +++ b/torch_geometric/data/summary.py @@ -117,7 +117,14 @@ def from_dataset( num_edges_per_type=num_edges_per_type, ) - def __repr__(self) -> str: + def format(self, fmt: str = "psql") -> str: + r"""Formats summary statistics of the dataset. + + Args: + fmt (str, optional): Summary tables format. Available table formats + can be found `here `__. (default: :obj:`"psql"`) + """ from tabulate import tabulate body = f'{self.name} (#graphs={self.num_graphs}):\n' @@ -127,7 +134,7 @@ def __repr__(self) -> str: for field in Stats.__dataclass_fields__: row = [field] + [f'{getattr(s, field):.1f}' for s in stats] content.append(row) - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) if self.num_nodes_per_type is not None: content = [['']] @@ -140,7 +147,7 @@ def __repr__(self) -> str: ] content.append(row) body += "\nNumber of nodes per node type:\n" - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) if self.num_edges_per_type is not None: content = [['']] @@ -156,6 +163,9 @@ def __repr__(self) -> str: ] content.append(row) body += "\nNumber of edges per edge type:\n" - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) return body + + def __repr__(self) -> str: + return self.format() diff --git a/torch_geometric/nn/module_dict.py b/torch_geometric/nn/module_dict.py index 2f69e5eb7c53..983619b8db29 100644 --- a/torch_geometric/nn/module_dict.py +++ b/torch_geometric/nn/module_dict.py @@ -1,4 +1,4 @@ -from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Tuple, Union import torch from torch.nn import Module @@ -11,7 +11,7 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ModuleDict(torch.nn.ModuleDict): - CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict)) + CLASS_ATTRS: Final[Tuple[str, ...]] = tuple(dir(torch.nn.ModuleDict)) def __init__( self, diff --git a/torch_geometric/nn/parameter_dict.py b/torch_geometric/nn/parameter_dict.py index a2dafc27c390..15f4f9f57e92 100644 --- a/torch_geometric/nn/parameter_dict.py +++ b/torch_geometric/nn/parameter_dict.py @@ -1,4 +1,4 @@ -from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Tuple, Union import torch from torch.nn import Parameter @@ -11,7 +11,7 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ParameterDict(torch.nn.ParameterDict): - CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict)) + CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict)) def __init__( self, diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py index ca185f05a06c..40a21ef9f3ff 100644 --- a/torch_geometric/testing/decorators.py +++ b/torch_geometric/testing/decorators.py @@ -7,6 +7,7 @@ import torch from packaging.requirements import Requirement +from packaging.version import Version from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE from torch_geometric.visualization.graph import has_graphviz @@ -180,12 +181,7 @@ def has_package(package: str) -> bool: if not hasattr(module, '__version__'): return True - version = module.__version__ - # `req.specifier` does not support `.dev` suffixes, e.g., for - # `pyg_lib==0.1.0.dev*`, so we manually drop them: - if '.dev' in version: - version = '.'.join(version.split('.dev')[:-1]) - + version = Version(module.__version__).base_version return version in req.specifier