Skip to content

Commit

Permalink
Ensure OnDiskDataset can operate on DataLoader(num_workers>0) (#8092
Browse files Browse the repository at this point in the history
)
  • Loading branch information
rusty1s authored Sep 29, 2023
1 parent b3b3d78 commit 68552e7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088))
- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092))
- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938)
- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894)
- Added `edge_attr` support to `ResGatedGraphConv` ([#8048](https://github.com/pyg-team/pytorch_geometric/pull/8048))
Expand Down
5 changes: 3 additions & 2 deletions test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ def test_dataloader(num_workers, device):
assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]


def test_dataloader_on_disk_dataset(tmp_path):
@pytest.mark.parametrize('num_workers', num_workers_list)
def test_dataloader_on_disk_dataset(tmp_path, num_workers):
dataset = OnDiskDataset(tmp_path)
data1 = Data(x=torch.randn(3, 8))
data2 = Data(x=torch.randn(4, 8))
dataset.extend([data1, data2])

loader = DataLoader(dataset, batch_size=2)
loader = DataLoader(dataset, batch_size=2, num_workers=num_workers)
assert len(loader) == 1
batch = next(iter(loader))
assert batch.num_nodes == 7
Expand Down

0 comments on commit 68552e7

Please sign in to comment.