Skip to content

Commit

Permalink
Update changlelog, fix import in pprgo test
Browse files Browse the repository at this point in the history
  • Loading branch information
ryspark committed Dec 11, 2024
1 parent 0eb2673 commit 7042d89
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Added `PPRGo` implementation and example ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))
- Allow top-k sparsification in `utils.get_ppr` and `transforms.GDC` ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))

### Changed

Expand Down
8 changes: 4 additions & 4 deletions test/nn/models/test_pprgo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from torch_geometric.datasets import KarateClub
from torch_geometric.nn.models.pprgo import PPRGo, prune_features
from torch_geometric.nn.models.pprgo import PPRGo, pprgo_prune_features


@pytest.mark.parametrize('n_layers', [1, 4])
Expand All @@ -22,7 +22,7 @@ def test_pprgo_forward(n_layers, dropout):
torch.randint(0, num_nodes, [num_edges])
], dim=0)

# Mimic the behavior of prune_features manually
# Mimic the behavior of pprgo_prune_features manually
# i.e., we expect node_embeds to be |V| x d
node_embeds = torch.rand((num_nodes, num_features))
node_embeds = node_embeds[edge_index[1], :]
Expand All @@ -38,7 +38,7 @@ def test_pprgo_karate():
data = KarateClub()[0]
num_nodes = data.num_nodes

data = prune_features(data)
data = pprgo_prune_features(data)
data.edge_weight = torch.ones((data.edge_index.shape[1], ))

assert data.x.shape[0] == data.edge_index.shape[1]
Expand All @@ -56,7 +56,7 @@ def test_pprgo_inference(n_power_iters, frac_predict, batch_size):
data = KarateClub()[0]
num_nodes = data.num_nodes

data = prune_features(data)
data = pprgo_prune_features(data)
data.edge_weight = torch.rand(data.edge_index.shape[1])

num_classes = 16
Expand Down

0 comments on commit 7042d89

Please sign in to comment.