From 1f5e201e0559a8669b0ba0bb58e11327efbce540 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Fri, 14 Jun 2024 08:36:36 +0200 Subject: [PATCH 1/4] Fix `ModuleDict`/`ParameterDict` in combination with `torch.jit.script` (#9424) Fixes https://github.com/pyg-team/pytorch_geometric/issues/9401 --- torch_geometric/nn/module_dict.py | 4 ++-- torch_geometric/nn/parameter_dict.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_geometric/nn/module_dict.py b/torch_geometric/nn/module_dict.py index 2f69e5eb7c53..983619b8db29 100644 --- a/torch_geometric/nn/module_dict.py +++ b/torch_geometric/nn/module_dict.py @@ -1,4 +1,4 @@ -from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Tuple, Union import torch from torch.nn import Module @@ -11,7 +11,7 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ModuleDict(torch.nn.ModuleDict): - CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict)) + CLASS_ATTRS: Final[Tuple[str, ...]] = tuple(dir(torch.nn.ModuleDict)) def __init__( self, diff --git a/torch_geometric/nn/parameter_dict.py b/torch_geometric/nn/parameter_dict.py index a2dafc27c390..15f4f9f57e92 100644 --- a/torch_geometric/nn/parameter_dict.py +++ b/torch_geometric/nn/parameter_dict.py @@ -1,4 +1,4 @@ -from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Tuple, Union import torch from torch.nn import Parameter @@ -11,7 +11,7 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ParameterDict(torch.nn.ParameterDict): - CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict)) + CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict)) def __init__( self, From 8842f7c4d34bdeee5c5bd6a695b3c86cfb9c4b5d Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Thu, 13 Jun 2024 23:59:25 -0700 Subject: [PATCH 2/4] [BUG] Fixing Bug in Version Handling (#9410) I've noticed that the tests with the `withPackage` test decorators ``` @withPackage("torch>=1.12.0") def test_some_test(...): . . . ``` are not being executed, ``` SKIPPED [2] test/nn/conv/test_heat_conv.py:10: Package torch>=1.12.0 not found SKIPPED [2] test/nn/conv/test_hetero_conv.py:181: Package torch>=2.1.0 not found SKIPPED [1] test/nn/conv/test_hgt_conv.py:12: Package torch>=1.12.0 not found SKIPPED [1] test/nn/conv/test_hgt_conv.py:63: Package torch>=1.12.0 not found SKIPPED [1] test/nn/conv/test_hgt_conv.py:114: Package torch>=1.12.0 not found ``` even though I have torch=2.4.0 installed on my machine. To be precise, I am using Torch version `2.4.0a0+f70bd71a48.nv24.06` which has an unexpected version ID format for the for `pytorch_geometric` . The proposed PR uses standard [Python version handling](https://packaging.python.org/en/latest/specifications/version-specifiers/), which covers the dev case previously treated separately. **NOTES:** - I found and fixed a similar issue in `pytorch_frame`, see [PR#410](https://github.com/pyg-team/pytorch-frame/pull/410). - When testing the fix for the current PR, I noticed that 113 (=600-483) more tests were run from the `test` directory ``` Befor:========== 5926 passed, 600 skipped, 503 warnings in 535.56s (0:08:55) After:========== 13 failed, 6030 passed, 483 skipped, 521 warnings in 596.27s (0:09:56) ``` and 13 of them failed. Similar failures may occur in your CI environment. Please consider this during the PR review. --- torch_geometric/testing/decorators.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py index ca185f05a06c..40a21ef9f3ff 100644 --- a/torch_geometric/testing/decorators.py +++ b/torch_geometric/testing/decorators.py @@ -7,6 +7,7 @@ import torch from packaging.requirements import Requirement +from packaging.version import Version from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE from torch_geometric.visualization.graph import has_graphviz @@ -180,12 +181,7 @@ def has_package(package: str) -> bool: if not hasattr(module, '__version__'): return True - version = module.__version__ - # `req.specifier` does not support `.dev` suffixes, e.g., for - # `pyg_lib==0.1.0.dev*`, so we manually drop them: - if '.dev' in version: - version = '.'.join(version.split('.dev')[:-1]) - + version = Version(module.__version__).base_version return version in req.specifier From 03515274bc21ac8fbed6b5c4b8d82cb463b035f0 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Fri, 14 Jun 2024 08:59:41 +0200 Subject: [PATCH 3/4] Introduce `auto-skip-changelog` (#9400) Currently, CI will reset `skip-changelog` label if it is manually specified. Separate the two into different labels. --- .github/labeler.yml | 2 +- .github/workflows/changelog.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 377d3a0d384e..ae2f57cc925a 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -54,5 +54,5 @@ benchmark: - changed-files: - any-glob-to-any-file: ["benchmark/**/*", "torch_geometric/profile/**/*"] -skip-changelog: +auto-skip-changelog: - head-branch: ['^skip', '^pre-commit-ci'] diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 203ad3e922d8..5e6892bfd1f3 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -13,4 +13,4 @@ jobs: - name: Enforce changelog entry uses: dangoslen/changelog-enforcer@v3 with: - skipLabels: 'skip-changelog' + skipLabels: skip-changelog, auto-skip-changelog From 0419f0f2d372056b0d7b939e15069b41b35a98a7 Mon Sep 17 00:00:00 2001 From: m-atalla Date: Fri, 14 Jun 2024 10:24:03 +0300 Subject: [PATCH 4/4] Allow users to modify dataset summary table format (#9408) Currently dataset summary is generated in `psql` table format with no way to modify it. This PR allows users to modify the format of a generated dataset summary. I have also added a link to the [available table formats](https://github.com/astanin/python-tabulate?tab=readme-ov-file#table-format), I'm not sure if this change requires additional testing, but let me know if it does along with any other changes needed. Thank you. --------- Co-authored-by: rusty1s --- CHANGELOG.md | 1 + torch_geometric/data/dataset.py | 12 +++++++++--- torch_geometric/data/summary.py | 18 ++++++++++++++---- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c536a14fa1df..5d59ece83b93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408)) - Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318)) - Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140)) - Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102)) diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index ea769a0714f3..f510d3012f20 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -367,9 +367,15 @@ def get_summary(self) -> Any: from torch_geometric.data.summary import Summary return Summary.from_dataset(self) - def print_summary(self) -> None: - r"""Prints summary statistics of the dataset to the console.""" - print(str(self.get_summary())) + def print_summary(self, fmt: str = "psql") -> None: + r"""Prints summary statistics of the dataset to the console. + + Args: + fmt (str, optional): Summary tables format. Available table formats + can be found `here `__. (default: :obj:`"psql"`) + """ + print(self.get_summary().format(fmt=fmt)) def to_datapipe(self) -> Any: r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`. diff --git a/torch_geometric/data/summary.py b/torch_geometric/data/summary.py index 06b70c30a290..e90b504615e2 100644 --- a/torch_geometric/data/summary.py +++ b/torch_geometric/data/summary.py @@ -117,7 +117,14 @@ def from_dataset( num_edges_per_type=num_edges_per_type, ) - def __repr__(self) -> str: + def format(self, fmt: str = "psql") -> str: + r"""Formats summary statistics of the dataset. + + Args: + fmt (str, optional): Summary tables format. Available table formats + can be found `here `__. (default: :obj:`"psql"`) + """ from tabulate import tabulate body = f'{self.name} (#graphs={self.num_graphs}):\n' @@ -127,7 +134,7 @@ def __repr__(self) -> str: for field in Stats.__dataclass_fields__: row = [field] + [f'{getattr(s, field):.1f}' for s in stats] content.append(row) - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) if self.num_nodes_per_type is not None: content = [['']] @@ -140,7 +147,7 @@ def __repr__(self) -> str: ] content.append(row) body += "\nNumber of nodes per node type:\n" - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) if self.num_edges_per_type is not None: content = [['']] @@ -156,6 +163,9 @@ def __repr__(self) -> str: ] content.append(row) body += "\nNumber of edges per edge type:\n" - body += tabulate(content, headers='firstrow', tablefmt='psql') + body += tabulate(content, headers='firstrow', tablefmt=fmt) return body + + def __repr__(self) -> str: + return self.format()