Add XPU support to basic GNN examples #9421

Merged: 10 commits merged on Jun 24, 2024
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
-- Added documentation on Environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
+- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

### Changed

@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))
+- Added XPU support to basic GNN examples ([#9421](https://github.com/pyg-team/pytorch_geometric/pull/9421))

### Deprecated

examples/gat.py (9 changes: 8 additions & 1 deletion)
@@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F

+import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
@@ -19,7 +20,13 @@
parser.add_argument('--wandb', action='store_true', help='Track experiment')
args = parser.parse_args()

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')

init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,
           hidden_channels=args.hidden_channels, lr=args.lr, device=device)

examples/gcn.py (8 changes: 2 additions & 6 deletions)
@@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F

+import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
@@ -19,12 +20,7 @@
parser.add_argument('--wandb', action='store_true', help='Track experiment')
args = parser.parse_args()

-if torch.cuda.is_available():
-    device = torch.device('cuda')
-elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-    device = torch.device('mps')
-else:
-    device = torch.device('cpu')
+device = torch_geometric.device('auto')

init_wandb(
    name=f'GCN-{args.dataset}',
examples/graph_sage_unsup.py (9 changes: 7 additions & 2 deletions)
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression

+import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader
@@ -22,8 +23,12 @@
    neg_sampling_ratio=1.0,
    num_neighbors=[10, 10],
)

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
data = data.to(device, 'x', 'edge_index')

model = GraphSAGE(
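A side note on the `data = data.to(device, 'x', 'edge_index')` line kept as context in this hunk: `Data.to` accepts attribute names, in which case only the named tensors are transferred. A minimal sketch with toy data (the tensors here are illustrative, not taken from the example):

```python
import torch

from torch_geometric.data import Data

data = Data(
    x=torch.randn(3, 4),                              # node features
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),  # COO connectivity
    y=torch.tensor([0, 1, 0]),                        # node labels
)

# Only 'x' and 'edge_index' are moved to the target device; 'y' stays put:
data = data.to('cpu', 'x', 'edge_index')
```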
examples/graph_sage_unsup_ppi.py (8 changes: 7 additions & 1 deletion)
@@ -8,6 +8,7 @@
from sklearn.metrics import f1_score
from sklearn.multioutput import MultiOutputClassifier

+import torch_geometric
from torch_geometric.data import Batch
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader, LinkNeighborLoader
@@ -29,7 +30,12 @@
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=64,
examples/pna.py (8 changes: 7 additions & 1 deletion)
@@ -5,6 +5,7 @@
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

+import torch_geometric
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
@@ -66,7 +67,12 @@ def forward(self, x, edge_index, edge_attr, batch):
        return self.mlp(x)


-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
torch_geometric/__init__.py (4 changes: 4 additions & 0 deletions)
@@ -3,6 +3,7 @@
from .edge_index import EdgeIndex
from .seed import seed_everything
from .home import get_home_dir, set_home_dir
+from .device import is_mps_available, is_xpu_available, device
from .isinstance import is_torch_instance
from .debug import is_debug_enabled, debug, set_debug

@@ -33,6 +34,9 @@
    'set_home_dir',
    'compile',
    'is_compiling',
+    'is_mps_available',
+    'is_xpu_available',
+    'device',
    'is_torch_instance',
    'is_debug_enabled',
    'debug',
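With these exports, the helpers become part of the top-level namespace. A quick sanity check (a sketch; the printed values depend on the machine it runs on):

```python
import torch_geometric

# Each probe degrades gracefully to False when the backend is missing:
print(torch_geometric.is_mps_available())  # e.g. False on a Linux runner
print(torch_geometric.is_xpu_available())  # False unless an Intel GPU is usable
                                           # via torch.xpu or IPEX
print(torch_geometric.device('auto'))      # e.g. device(type='cpu') as fallback
```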
torch_geometric/device.py (42 changes: 42 additions & 0 deletions)
@@ -0,0 +1,42 @@
from typing import Any

import torch


def is_mps_available() -> bool:
    r"""Returns a bool indicating if MPS is currently available."""
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        try:  # Github CI may not have access to MPS hardware. Confirm:
            torch.empty(1, device='mps')
            return True
        except Exception:
            return False
    return False


def is_xpu_available() -> bool:
    r"""Returns a bool indicating if XPU is currently available."""
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return True
    try:
        import intel_extension_for_pytorch as ipex
        return ipex.xpu.is_available()
    except ImportError:
        return False


def device(device: Any) -> torch.device:
    r"""Returns a :class:`torch.device`.

    If :obj:`"auto"` is specified, returns the optimal device depending on
    available hardware.
    """
    if device != 'auto':
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device('cuda')
    if is_mps_available():
        return torch.device('mps')
    if is_xpu_available():
        return torch.device('xpu')
    return torch.device('cpu')
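Taken together, `device('auto')` gives the one-liner used in `examples/gcn.py` above. A minimal usage sketch (the `Linear` model is a stand-in, not part of the PR):

```python
import torch

import torch_geometric

# Resolves to 'cuda' > 'mps' > 'xpu' > 'cpu', per the priority order above:
device = torch_geometric.device('auto')

model = torch.nn.Linear(16, 2).to(device)
x = torch.randn(8, 16, device=device)
out = model(x)  # runs on whichever backend was selected

# Explicit device strings still pass straight through:
assert torch_geometric.device('cpu') == torch.device('cpu')
```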
torch_geometric/testing/decorators.py (27 changes: 5 additions & 22 deletions)
@@ -9,6 +9,7 @@
from packaging.requirements import Requirement
from packaging.version import Version

+import torch_geometric
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.visualization.graph import has_graphviz

@@ -113,13 +114,8 @@ def onlyCUDA(func: Callable) -> Callable:
def onlyXPU(func: Callable) -> Callable:
    r"""A decorator to skip tests if XPU is not found."""
    import pytest
-    try:
-        import intel_extension_for_pytorch as ipex
-        xpu_available = ipex.xpu.is_available()
-    except ImportError:
-        xpu_available = False
    return pytest.mark.skipif(
-        not xpu_available,
+        not torch_geometric.is_xpu_available(),
        reason="XPU not available",
    )(func)

@@ -227,23 +223,10 @@ def withDevice(func: Callable) -> Callable:
    if torch.cuda.is_available():
        devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))

-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        try:  # Github CI may not have access to MPS hardware. Confirm:
-            torch.empty(1, device='mps')
-            devices.append(pytest.param(torch.device('mps:0'), id='mps'))
-        except RuntimeError:
-            pass
-
-    if not hasattr(torch, 'xpu'):
-        try:
-            import intel_extension_for_pytorch as ipex
-            xpu_available = ipex.xpu.is_available()
-        except ImportError:
-            xpu_available = False
-    else:
-        xpu_available = torch.xpu.is_available()
+    if torch_geometric.is_mps_available():
+        devices.append(pytest.param(torch.device('mps:0'), id='mps'))

-    if xpu_available:
+    if torch_geometric.is_xpu_available():
        devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))

    # Additional devices can be registered through environment variables:
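For completeness, a sketch of how the reworked decorator reads at a test site (assuming `onlyXPU` is re-exported via `torch_geometric.testing`, as the other decorators in this module are):

```python
import torch

from torch_geometric.testing import onlyXPU


@onlyXPU  # skipped unless torch_geometric.is_xpu_available() returns True
def test_add_on_xpu():
    x = torch.randn(4, 8, device='xpu')
    assert (x + x).device.type == 'xpu'
```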