Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 14, 2024
1 parent 91ad640 commit ec79da3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
17 changes: 5 additions & 12 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,27 +362,20 @@ def __repr__(self) -> str:
arg_repr = str(len(self)) if len(self) > 1 else ''
return f'{self.__class__.__name__}({arg_repr})'

def get_summary(self, fmt: str = "psql") -> Any:
r"""Collects summary statistics for the dataset.
Args:
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`psql`)
"""
def get_summary(self) -> Any:
r"""Collects summary statistics for the dataset."""
from torch_geometric.data.summary import Summary
return Summary.from_dataset(self, fmt=fmt)
return Summary.from_dataset(self)

def print_summary(self, fmt: str = "psql") -> None:
r"""Prints summary statistics of the dataset to the console.
Args:
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`psql`)
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
"""
print(str(self.get_summary(fmt=fmt)))
print(self.get_summary().format(fmt=fmt))

def to_datapipe(self) -> Any:
r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
Expand Down
24 changes: 14 additions & 10 deletions torch_geometric/data/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def from_data(

@dataclass(repr=False)
class Summary:
fmt: str
name: str
num_graphs: int
num_nodes: Stats
Expand All @@ -56,7 +55,6 @@ def from_dataset(
dataset: Dataset,
progress_bar: Optional[bool] = None,
per_type: bool = True,
fmt: str = "psql",
) -> Self:
r"""Creates a summary of a :class:`~torch_geometric.data.Dataset`
object.
Expand All @@ -70,9 +68,6 @@ def from_dataset(
per_type (bool, optional): If set to :obj:`True`, will separate
statistics per node and edge type (only applicable in
heterogeneous graph datasets). (default: :obj:`True`)
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`psql`)
"""
name = dataset.__class__.__name__

Expand Down Expand Up @@ -114,7 +109,6 @@ def from_dataset(
}

return cls(
fmt=fmt,
name=name,
num_graphs=len(dataset),
num_nodes=Stats.from_data(num_nodes),
Expand All @@ -123,7 +117,14 @@ def from_dataset(
num_edges_per_type=num_edges_per_type,
)

def __repr__(self) -> str:
def format(self, fmt: str = "psql") -> str:
r"""Formats summary statistics of the dataset.
Args:
fmt (str, optional): Summary tables format. Available table formats
can be found `here <https://github.com/astanin/python-tabulate?
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
"""
from tabulate import tabulate

body = f'{self.name} (#graphs={self.num_graphs}):\n'
Expand All @@ -133,7 +134,7 @@ def __repr__(self) -> str:
for field in Stats.__dataclass_fields__:
row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
content.append(row)
body += tabulate(content, headers='firstrow', tablefmt=self.fmt)
body += tabulate(content, headers='firstrow', tablefmt=fmt)

if self.num_nodes_per_type is not None:
content = [['']]
Expand All @@ -146,7 +147,7 @@ def __repr__(self) -> str:
]
content.append(row)
body += "\nNumber of nodes per node type:\n"
body += tabulate(content, headers='firstrow', tablefmt=self.fmt)
body += tabulate(content, headers='firstrow', tablefmt=fmt)

if self.num_edges_per_type is not None:
content = [['']]
Expand All @@ -162,6 +163,9 @@ def __repr__(self) -> str:
]
content.append(row)
body += "\nNumber of edges per edge type:\n"
body += tabulate(content, headers='firstrow', tablefmt=self.fmt)
body += tabulate(content, headers='firstrow', tablefmt=fmt)

return body

def __repr__(self) -> str:
return self.format()

0 comments on commit ec79da3

Please sign in to comment.