Resolve conflicts
michailramp committed Jan 14, 2025
2 parents f366c9e + 60a1141 commit c9b4ff4
Showing 10 changed files with 321 additions and 69 deletions.
14 changes: 13 additions & 1 deletion .github/labeler.yml
@@ -1,6 +1,14 @@
installation:
- changed-files:
- any-glob-to-any-file: ["pyproject.toml"]

ci:
- changed-files:
- any-glob-to-any-file: [".github/**/*", "codecov.yaml", ".pre-commit-config.yaml"]

documentation:
- changed-files:
- any-glob-to-any-file: "docs/**/*"
- any-glob-to-any-file: ["docs/**/*", "readthedocs.yml", "README.MD"]

example:
- changed-files:
@@ -34,6 +42,10 @@ transform:
- changed-files:
- any-glob-to-any-file: "torch_geometric/transforms/**/*"

metrics:
- changed-files:
- any-glob-to-any-file: "torch_geometric/metrics/**/*"

utils:
- changed-files:
- any-glob-to-any-file: "torch_geometric/utils/**/*"
1 change: 1 addition & 0 deletions .github/workflows/full_testing.yml
@@ -12,6 +12,7 @@ jobs:
runs-on: ${{ matrix.os }}

strategy:
max-parallel: 10
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-14]
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -8,8 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added the `TokenGT` model along with `AddOrthonormalNodeIdentifiers` transform and example usage ([#9834](https://github.com/pyg-team/pytorch_geometric/pull/9834))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
- Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
@@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606))
- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))

@@ -37,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756))
- Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766))
- Fixed `WebQSDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665))
- Fixed `is_node_attr()` and `is_edge_attr()` errors when `cat_dim` is a tuple ([#9895](https://github.com/pyg-team/pytorch_geometric/issues/9895))

### Removed

2 changes: 1 addition & 1 deletion examples/llm/molecule_gpt.py
@@ -167,7 +167,7 @@ def adjust_learning_rate(param_group, LR, epoch):
f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501
)
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()

print(f"Total Training Time: {time.time() - start_time:2f}s")
# Test
3 changes: 3 additions & 0 deletions readthedocs.yml
@@ -1,5 +1,8 @@
version: 2

sphinx:
configuration: docs/source/conf.py

build:
os: ubuntu-22.04
tools:
56 changes: 41 additions & 15 deletions test/data/lightning/test_datamodule.py
@@ -18,6 +18,7 @@
MyFeatureStore,
MyGraphStore,
get_random_edge_index,
has_package,
onlyCUDA,
onlyFullTest,
onlyNeighborSampler,
@@ -114,12 +115,20 @@ def expect_rank_zero_user_warning(match: str):
num_workers=3, shuffle=True)
assert 'shuffle' not in datamodule.kwargs
old_x = train_dataset._data.x.clone()
assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)')
if has_package('pytorch_lightning>=2.5.0'):
datamodule_repr = ('{Train dataloader: size=50}\n'
'{Validation dataloader: size=30}\n'
'{Test dataloader: size=10}\n'
'{Predict dataloader: size=98}')
else:
datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)')
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
trainer.test(model, datamodule)
new_x = train_dataset._data.x
@@ -133,10 +142,17 @@ def expect_rank_zero_user_warning(match: str):
log_every_n_steps=1)

datamodule = LightningDataset(train_dataset, batch_size=5)
assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)')
if has_package('pytorch_lightning>=2.5.0'):
datamodule_repr = ('{Train dataloader: size=50}\n'
'{Validation dataloader: None}\n'
'{Test dataloader: None}\n'
'{Predict dataloader: None}')
else:
datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)')
assert str(datamodule) == datamodule_repr

with expect_rank_zero_user_warning("defined a `validation_step`"):
trainer.fit(model, datamodule)
@@ -231,11 +247,21 @@ def test_lightning_node_data(get_dataset, strategy_type, loader):
num_workers=num_workers, **kwargs)

old_x = data.x.clone().cpu()
assert str(datamodule) == (f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, {kwargs_repr}'
f'pin_memory={loader != "full"}, '
f'persistent_workers={loader != "full"})')
flag = loader != 'full'
if has_package('pytorch_lightning>=2.5.0'):
datamodule_repr = (
'{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n'
'{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n'
'{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n'
'{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}')
else:
datamodule_repr = (f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, {kwargs_repr}'
f'pin_memory={flag}, '
f'persistent_workers={flag})')
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
trainer.test(model, datamodule)
new_x = data.x.cpu()
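
The version switch used throughout these tests can be read as the pattern below. This is a minimal standard-library sketch, not the project's `has_package` helper: `release_tuple`, `at_least`, `installed_at_least` and `expected_repr` are hypothetical names introduced for illustration, and the parsing deliberately ignores pre-release suffixes such as `rc1`.

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(text):
    """Turn '2.5.0' or '2.5.0.post1' into a tuple of its leading integers."""
    parts = []
    for piece in text.split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def at_least(found, minimum):
    """Compare two version strings after padding them to equal length."""
    a, b = release_tuple(found), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0, ) * (width - len(a))
    b += (0, ) * (width - len(b))
    return a >= b


def installed_at_least(package, minimum):
    """Hypothetical stand-in for a `has_package('name>=x.y.z')` check."""
    try:
        return at_least(version(package), minimum)
    except PackageNotFoundError:
        return False


def expected_repr(new_style):
    """Pick the string the test compares `str(datamodule)` against."""
    if new_style:
        return ('{Train dataloader: size=50}\n'
                '{Validation dataloader: None}\n'
                '{Test dataloader: None}\n'
                '{Predict dataloader: None}')
    return ('LightningDataset(train_dataset=MUTAG(50), '
            'batch_size=5, num_workers=0, '
            'pin_memory=True, '
            'persistent_workers=False)')


# In a test, the branch would be chosen from the installed version:
datamodule_repr = expected_repr(
    installed_at_least('pytorch_lightning', '2.5.0'))
```

Keeping both expected strings in the test lets the suite pass on either side of the version boundary, at the cost of asserting on a representation that the upstream library controls.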
96 changes: 83 additions & 13 deletions test/metrics/test_link_pred_metric.py
@@ -6,6 +6,7 @@
from torch_geometric.metrics import (
LinkPredF1,
LinkPredMAP,
LinkPredMetricCollection,
LinkPredMRR,
LinkPredNDCG,
LinkPredPrecision,
@@ -25,7 +26,7 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k):

pred = torch.rand(num_src_nodes, num_dst_nodes)
pred[row, col] += 0.3 # Offset positive links by a little.
top_k_pred_mat = pred.topk(k, dim=1)[1]
pred_index_mat = pred.topk(k, dim=1)[1]

metric = LinkPredPrecision(k)
assert str(metric) == f'LinkPredPrecision(k={k})'
@@ -39,7 +40,7 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k):
arange[node_id] = torch.arange(node_id.numel())
y_batch = arange[y_batch]

metric.update(top_k_pred_mat[node_id], (y_batch, y_index))
metric.update(pred_index_mat[node_id], (y_batch, y_index))

out = metric.compute()
metric.reset()
@@ -48,66 +49,135 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k):
for i in range(num_src_nodes): # Naive computation per node:
y_index = col[row == i]
if y_index.numel() > 0:
mask = torch.isin(top_k_pred_mat[i], y_index)
mask = torch.isin(pred_index_mat[i], y_index)
precision = float(mask.sum() / k)
values.append(precision)
expected = torch.tensor(values).mean()
assert torch.allclose(out, expected)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :k - 1], edge_label_index)
metric.compute()
metric.reset()


def test_recall():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])

metric = LinkPredRecall(k=2)
assert str(metric) == 'LinkPredRecall(k=2)'
metric.update(pred_mat, edge_label_index)
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()

assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5))

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()


def test_f1():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])

metric = LinkPredF1(k=2)
assert str(metric) == 'LinkPredF1(k=2)'
metric.update(pred_mat, edge_label_index)
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()
assert float(result) == pytest.approx(0.6500)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()


def test_map():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])

metric = LinkPredMAP(k=2)
assert str(metric) == 'LinkPredMAP(k=2)'
metric.update(pred_mat, edge_label_index)
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()
assert float(result) == pytest.approx(0.6250)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()


def test_ndcg():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]])

metric = LinkPredNDCG(k=2)
assert str(metric) == 'LinkPredNDCG(k=2)'
metric.update(pred_mat, edge_label_index)
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()

assert float(result) == pytest.approx(0.6934264)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()


def test_mrr():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])

metric = LinkPredMRR(k=2)
assert str(metric) == 'LinkPredMRR(k=2)'
metric.update(pred_mat, edge_label_index)
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()

assert float(result) == pytest.approx((1 + 0.5 + 0) / 3)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()


@pytest.mark.parametrize('num_src_nodes', [10])
@pytest.mark.parametrize('num_dst_nodes', [50])
@pytest.mark.parametrize('num_edges', [200])
def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges):
metrics = [
LinkPredMAP(k=10),
LinkPredPrecision(k=100),
LinkPredRecall(k=50),
]

row = torch.randint(0, num_src_nodes, (num_edges, ))
col = torch.randint(0, num_dst_nodes, (num_edges, ))
edge_label_index = torch.stack([row, col], dim=0)

pred = torch.rand(num_src_nodes, num_dst_nodes)
pred[row, col] += 0.3 # Offset positive links by a little.
pred_index_mat = pred.argsort(dim=1)

metric_collection = LinkPredMetricCollection(metrics)
assert str(metric_collection) == (
'LinkPredMetricCollection([\n'
' LinkPredMAP@10: LinkPredMAP(k=10),\n'
' LinkPredPrecision@100: LinkPredPrecision(k=100),\n'
' LinkPredRecall@50: LinkPredRecall(k=50),\n'
'])')
assert metric_collection.max_k == 100

expected = {}
for metric in metrics:
metric.update(pred_index_mat[:, :metric.k], edge_label_index)
out = metric.compute()
expected[f'{metric.__class__.__name__}@{metric.k}'] = out
metric.reset()

metric_collection.update(pred_index_mat, edge_label_index)
assert metric_collection.compute() == expected
metric_collection.reset()
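
The expected values in `test_recall` and `test_mrr`, and the `Name@k` keys checked in the collection test, can be reproduced with plain Python lists. The sketch below is an illustration of the arithmetic only and is not the `torch_geometric.metrics` implementation: `positives_by_source`, `recall_at_k`, `mrr_at_k` and `compute_collection` are hypothetical helpers, and the `(name, function, k)` triple layout is an assumption made for brevity.

```python
def positives_by_source(edge_label_index):
    """Group ground-truth destination ids by source node id."""
    rows, cols = edge_label_index
    truth = {}
    for src, dst in zip(rows, cols):
        truth.setdefault(src, set()).add(dst)
    return truth


def recall_at_k(pred_index_mat, edge_label_index, k):
    """Mean fraction of a node's positives found in its top-k predictions."""
    truth = positives_by_source(edge_label_index)
    scores = []
    for src, ranked in enumerate(pred_index_mat):
        positives = truth.get(src)
        if not positives:
            continue  # nodes without any positive link are skipped
        hits = sum(1 for dst in ranked[:k] if dst in positives)
        scores.append(hits / len(positives))
    return sum(scores) / len(scores)


def mrr_at_k(pred_index_mat, edge_label_index, k):
    """Mean reciprocal rank of the first hit within the top-k predictions."""
    truth = positives_by_source(edge_label_index)
    scores = []
    for src, ranked in enumerate(pred_index_mat):
        positives = truth.get(src)
        if not positives:
            continue
        score = 0.0
        for rank, dst in enumerate(ranked[:k], start=1):
            if dst in positives:
                score = 1.0 / rank
                break
        scores.append(score)
    return sum(scores) / len(scores)


def compute_collection(metrics, pred_index_mat, edge_label_index):
    """Evaluate (name, function, k) triples and key the results as 'name@k'."""
    return {
        f'{name}@{k}': fn(pred_index_mat, edge_label_index, k)
        for name, fn, k in metrics
    }


# Same inputs as `test_recall`: node 0 scores 2/3, node 1 has no positives
# and is skipped, node 2 scores 1/2, so the mean is 0.5 * (2 / 3 + 0.5).
pred_index_mat = [[1, 0], [1, 2], [0, 2]]
edge_label_index = ([0, 0, 0, 2, 2], [0, 1, 2, 2, 1])
recall = recall_at_k(pred_index_mat, edge_label_index, k=2)

# Same inputs as `test_mrr`: reciprocal ranks are 1, 1/2 and 0.
mrr = mrr_at_k([[1, 0], [1, 2], [0, 2], [0, 1]],
               ([0, 0, 2, 2, 3], [0, 1, 2, 1, 2]), k=2)

metrics = [('Recall', recall_at_k, 2), ('MRR', mrr_at_k, 1)]
max_k = max(k for _, _, k in metrics)  # widest ranking any metric needs
results = compute_collection(metrics, pred_index_mat, edge_label_index)
```

Each metric slices the shared ranking down to its own `k`, which is why a single prediction matrix of width `max_k` is enough to feed every metric in the collection.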
8 changes: 8 additions & 0 deletions torch_geometric/data/storage.py
@@ -806,6 +806,10 @@ def is_node_attr(self, key: str) -> bool:
return False

cat_dim = self._parent().__cat_dim__(key, value, self)

if not isinstance(cat_dim, int):
return False

num_nodes, num_edges = self.num_nodes, self.num_edges

if value.shape[cat_dim] != num_nodes:
@@ -852,6 +856,10 @@ def is_edge_attr(self, key: str) -> bool:
return False

cat_dim = self._parent().__cat_dim__(key, value, self)

if not isinstance(cat_dim, int):
return False

num_nodes, num_edges = self.num_nodes, self.num_edges

if value.shape[cat_dim] != num_edges:
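
The guard added to `is_node_attr()` and `is_edge_attr()` matters because a shape cannot be indexed with a tuple. The following is a simplified sketch using plain tuples in place of tensor shapes; `has_size_along` is a hypothetical helper, not part of `torch_geometric`, and the tuple `(0, 1)` is only an assumed example of a non-integer `cat_dim`.

```python
def has_size_along(shape, cat_dim, expected):
    """Return True if `shape` has `expected` entries along `cat_dim`."""
    if not isinstance(cat_dim, int):
        # A tuple-valued `cat_dim` cannot select a single dimension, so the
        # value is treated as neither a node nor an edge attribute.
        return False
    return shape[cat_dim] == expected


# Without the guard, indexing a shape with a tuple fails outright.
shape = (4, 4)
try:
    shape[(0, 1)]
    unguarded_error = None
except TypeError as exc:
    unguarded_error = type(exc).__name__

node_like = has_size_along((4, 16), 0, 4)     # integer cat_dim: checked
tuple_case = has_size_along((4, 4), (0, 1), 4)  # tuple cat_dim: rejected
```

Returning `False` early keeps the check conservative: a value with a tuple `cat_dim` is simply not classified, rather than raising from inside an attribute query.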
2 changes: 2 additions & 0 deletions torch_geometric/metrics/__init__.py
@@ -1,6 +1,7 @@
# flake8: noqa

from .link_pred import (
LinkPredMetricCollection,
LinkPredPrecision,
LinkPredRecall,
LinkPredF1,
@@ -10,6 +11,7 @@
)

link_pred_metrics = [
'LinkPredMetricCollection',
'LinkPredPrecision',
'LinkPredRecall',
'LinkPredF1',