Skip to content

Commit

Permalink
Merge branch 'master' into gnn-llm-model-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 authored Jun 14, 2024
2 parents 3e6019b + 0419f0f commit 07b747b
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ benchmark:
- changed-files:
- any-glob-to-any-file: ["benchmark/**/*", "torch_geometric/profile/**/*"]

skip-changelog:
auto-skip-changelog:
- head-branch: ['^skip', '^pre-commit-ci']
2 changes: 1 addition & 1 deletion .github/workflows/changelog.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ jobs:
- name: Enforce changelog entry
uses: dangoslen/changelog-enforcer@v3
with:
skipLabels: 'skip-changelog'
skipLabels: skip-changelog, auto-skip-changelog
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408))
- Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318))
- Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140))
- Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102))
Expand Down
12 changes: 9 additions & 3 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,15 @@ def get_summary(self) -> Any:
from torch_geometric.data.summary import Summary
return Summary.from_dataset(self)

def print_summary(self) -> None:
r"""Prints summary statistics of the dataset to the console."""
print(str(self.get_summary()))
def print_summary(self, fmt: str = "psql") -> None:
r"""Prints summary statistics of the dataset to the console.
Args:
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
"""
print(self.get_summary().format(fmt=fmt))

def to_datapipe(self) -> Any:
r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
Expand Down
18 changes: 14 additions & 4 deletions torch_geometric/data/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,14 @@ def from_dataset(
num_edges_per_type=num_edges_per_type,
)

def __repr__(self) -> str:
def format(self, fmt: str = "psql") -> str:
r"""Formats summary statistics of the dataset.
Args:
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
"""
from tabulate import tabulate

body = f'{self.name} (#graphs={self.num_graphs}):\n'
Expand All @@ -127,7 +134,7 @@ def __repr__(self) -> str:
for field in Stats.__dataclass_fields__:
row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
content.append(row)
body += tabulate(content, headers='firstrow', tablefmt='psql')
body += tabulate(content, headers='firstrow', tablefmt=fmt)

if self.num_nodes_per_type is not None:
content = [['']]
Expand All @@ -140,7 +147,7 @@ def __repr__(self) -> str:
]
content.append(row)
body += "\nNumber of nodes per node type:\n"
body += tabulate(content, headers='firstrow', tablefmt='psql')
body += tabulate(content, headers='firstrow', tablefmt=fmt)

if self.num_edges_per_type is not None:
content = [['']]
Expand All @@ -156,6 +163,9 @@ def __repr__(self) -> str:
]
content.append(row)
body += "\nNumber of edges per edge type:\n"
body += tabulate(content, headers='firstrow', tablefmt='psql')
body += tabulate(content, headers='firstrow', tablefmt=fmt)

return body

def __repr__(self) -> str:
return self.format()
4 changes: 2 additions & 2 deletions torch_geometric/nn/module_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Tuple, Union

import torch
from torch.nn import Module
Expand All @@ -11,7 +11,7 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ModuleDict(torch.nn.ModuleDict):
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict))
CLASS_ATTRS: Final[Tuple[str, ...]] = tuple(dir(torch.nn.ModuleDict))

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/parameter_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Tuple, Union

import torch
from torch.nn import Parameter
Expand All @@ -11,7 +11,7 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ParameterDict(torch.nn.ParameterDict):
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict))
CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict))

def __init__(
self,
Expand Down
8 changes: 2 additions & 6 deletions torch_geometric/testing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
from packaging.requirements import Requirement
from packaging.version import Version

from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.visualization.graph import has_graphviz
Expand Down Expand Up @@ -180,12 +181,7 @@ def has_package(package: str) -> bool:
if not hasattr(module, '__version__'):
return True

version = module.__version__
# `req.specifier` does not support `.dev` suffixes, e.g., for
# `pyg_lib==0.1.0.dev*`, so we manually drop them:
if '.dev' in version:
version = '.'.join(version.split('.dev')[:-1])

version = Version(module.__version__).base_version
return version in req.specifier


Expand Down

0 comments on commit 07b747b

Please sign in to comment.