Skip to content

Commit

Permalink
Add XPU support to basic GNN examples (#9421)
Browse files Browse the repository at this point in the history
Add XPU support to basic GNN examples(gat,gcn, graphsage. etc.)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
3 people authored Jun 24, 2024
1 parent 9b7874b commit cd36c86
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 34 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Added documentation on Environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

### Changed

Expand All @@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))
- Added XPU support to basic GNN examples ([#9421](https://github.com/pyg-team/pytorch_geometric/pull/9001))

### Deprecated

Expand Down
9 changes: 8 additions & 1 deletion examples/gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
Expand All @@ -19,7 +20,13 @@
parser.add_argument('--wandb', action='store_true', help='Track experiment')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')

init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,
hidden_channels=args.hidden_channels, lr=args.lr, device=device)

Expand Down
8 changes: 2 additions & 6 deletions examples/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
Expand All @@ -19,12 +20,7 @@
parser.add_argument('--wandb', action='store_true', help='Track experiment')
args = parser.parse_args()

if torch.cuda.is_available():
device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')
device = torch_geometric.device('auto')

init_wandb(
name=f'GCN-{args.dataset}',
Expand Down
9 changes: 7 additions & 2 deletions examples/graph_sage_unsup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader
Expand All @@ -22,8 +23,12 @@
neg_sampling_ratio=1.0,
num_neighbors=[10, 10],
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')
data = data.to(device, 'x', 'edge_index')

model = GraphSAGE(
Expand Down
8 changes: 7 additions & 1 deletion examples/graph_sage_unsup_ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.metrics import f1_score
from sklearn.multioutput import MultiOutputClassifier

import torch_geometric
from torch_geometric.data import Batch
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader, LinkNeighborLoader
Expand All @@ -29,7 +30,12 @@
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')
model = GraphSAGE(
in_channels=train_dataset.num_features,
hidden_channels=64,
Expand Down
8 changes: 7 additions & 1 deletion examples/pna.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
Expand Down Expand Up @@ -66,7 +67,12 @@ def forward(self, x, edge_index, edge_attr, batch):
return self.mlp(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
Expand Down
4 changes: 4 additions & 0 deletions torch_geometric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .edge_index import EdgeIndex
from .seed import seed_everything
from .home import get_home_dir, set_home_dir
from .device import is_mps_available, is_xpu_available, device
from .isinstance import is_torch_instance
from .debug import is_debug_enabled, debug, set_debug

Expand Down Expand Up @@ -33,6 +34,9 @@
'set_home_dir',
'compile',
'is_compiling',
'is_mps_available',
'is_xpu_available',
'device',
'is_torch_instance',
'is_debug_enabled',
'debug',
Expand Down
42 changes: 42 additions & 0 deletions torch_geometric/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any

import torch


def is_mps_available() -> bool:
r"""Returns a bool indicating if MPS is currently available."""
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
try: # Github CI may not have access to MPS hardware. Confirm:
torch.empty(1, device='mps')
return True
except Exception:
return False
return False


def is_xpu_available() -> bool:
r"""Returns a bool indicating if XPU is currently available."""
if hasattr(torch, 'xpu') and torch.xpu.is_available():
return True
try:
import intel_extension_for_pytorch as ipex
return ipex.xpu.is_available()
except ImportError:
return False


def device(device: Any) -> torch.device:
r"""Returns a :class:`torch.device`.
If :obj:`"auto"` is specified, returns the optimal device depending on
available hardware.
"""
if device != 'auto':
return torch.device(device)
if torch.cuda.is_available():
return torch.device('cuda')
if is_mps_available():
return torch.device('mps')
if is_xpu_available():
return torch.device('xpu')
return torch.device('cpu')
27 changes: 5 additions & 22 deletions torch_geometric/testing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from packaging.requirements import Requirement
from packaging.version import Version

import torch_geometric
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.visualization.graph import has_graphviz

Expand Down Expand Up @@ -113,13 +114,8 @@ def onlyCUDA(func: Callable) -> Callable:
def onlyXPU(func: Callable) -> Callable:
r"""A decorator to skip tests if XPU is not found."""
import pytest
try:
import intel_extension_for_pytorch as ipex
xpu_available = ipex.xpu.is_available()
except ImportError:
xpu_available = False
return pytest.mark.skipif(
not xpu_available,
not torch_geometric.is_xpu_available(),
reason="XPU not available",
)(func)

Expand Down Expand Up @@ -227,23 +223,10 @@ def withDevice(func: Callable) -> Callable:
if torch.cuda.is_available():
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
try: # Github CI may not have access to MPS hardware. Confirm:
torch.empty(1, device='mps')
devices.append(pytest.param(torch.device('mps:0'), id='mps'))
except RuntimeError:
pass

if not hasattr(torch, 'xpu'):
try:
import intel_extension_for_pytorch as ipex
xpu_available = ipex.xpu.is_available()
except ImportError:
xpu_available = False
else:
xpu_available = torch.xpu.is_available()
if torch_geometric.is_mps_available():
devices.append(pytest.param(torch.device('mps:0'), id='mps'))

if xpu_available:
if torch_geometric.is_xpu_available():
devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))

# Additional devices can be registered through environment variables:
Expand Down

0 comments on commit cd36c86

Please sign in to comment.