From cd36c863d44b7073d964eda796823c87311c3ce0 Mon Sep 17 00:00:00 2001
From: cj
Date: Mon, 24 Jun 2024 22:14:21 +0800
Subject: [PATCH] Add XPU support to basic GNN examples (#9421)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add XPU support to basic GNN examples (GAT, GCN, GraphSAGE, etc.)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s
---
 CHANGELOG.md                          |  3 +-
 examples/gat.py                       |  9 +++++-
 examples/gcn.py                       |  8 ++---
 examples/graph_sage_unsup.py          |  9 ++++--
 examples/graph_sage_unsup_ppi.py      |  8 ++++-
 examples/pna.py                       |  8 ++++-
 torch_geometric/__init__.py           |  4 +++
 torch_geometric/device.py             | 42 +++++++++++++++++++++++++++
 torch_geometric/testing/decorators.py | 27 ++++-------------
 9 files changed, 84 insertions(+), 34 deletions(-)
 create mode 100644 torch_geometric/device.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b77526e11d3e..3fdbff2bc90d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
 - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
 - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
-- Added documentation on Environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
+- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

 ### Changed

@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
 - Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
 - Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))
+- Added XPU support to basic GNN examples ([#9421](https://github.com/pyg-team/pytorch_geometric/pull/9421))

 ### Deprecated

diff --git a/examples/gat.py b/examples/gat.py
index 1147acdcaca7..09f90011efdd 100644
--- a/examples/gat.py
+++ b/examples/gat.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn.functional as F

+import torch_geometric
 import torch_geometric.transforms as T
 from torch_geometric.datasets import Planetoid
 from torch_geometric.logging import init_wandb, log
@@ -19,7 +20,13 @@
 parser.add_argument('--wandb', action='store_true', help='Track experiment')
 args = parser.parse_args()

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
+
 init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,
            hidden_channels=args.hidden_channels, lr=args.lr, device=device)

diff --git a/examples/gcn.py b/examples/gcn.py
index 948beb8497df..2842ca41be98 100644
--- a/examples/gcn.py
+++ b/examples/gcn.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn.functional as F

+import torch_geometric
 import torch_geometric.transforms as T
 from torch_geometric.datasets import Planetoid
 from torch_geometric.logging import init_wandb, log
@@ -19,12 +20,7 @@
 parser.add_argument('--wandb', action='store_true', help='Track experiment')
 args = parser.parse_args()

-if torch.cuda.is_available():
-    device = torch.device('cuda')
-elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-    device = torch.device('mps')
-else:
-    device = torch.device('cpu')
+device = torch_geometric.device('auto')

 init_wandb(
     name=f'GCN-{args.dataset}',
diff --git a/examples/graph_sage_unsup.py b/examples/graph_sage_unsup.py
index d523cdd7485e..822c005994fa 100644
--- a/examples/graph_sage_unsup.py
+++ b/examples/graph_sage_unsup.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from sklearn.linear_model import LogisticRegression

+import torch_geometric
 import torch_geometric.transforms as T
 from torch_geometric.datasets import Planetoid
 from torch_geometric.loader import LinkNeighborLoader
@@ -22,8 +23,12 @@
     neg_sampling_ratio=1.0,
     num_neighbors=[10, 10],
 )
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
 data = data.to(device, 'x', 'edge_index')

 model = GraphSAGE(
diff --git a/examples/graph_sage_unsup_ppi.py b/examples/graph_sage_unsup_ppi.py
index dd085dfe5d98..19dd73273113 100644
--- a/examples/graph_sage_unsup_ppi.py
+++ b/examples/graph_sage_unsup_ppi.py
@@ -8,6 +8,7 @@
 from sklearn.metrics import f1_score
 from sklearn.multioutput import MultiOutputClassifier

+import torch_geometric
 from torch_geometric.data import Batch
 from torch_geometric.datasets import PPI
 from torch_geometric.loader import DataLoader, LinkNeighborLoader
@@ -29,7 +30,12 @@
 val_loader = DataLoader(val_dataset, batch_size=2)
 test_loader = DataLoader(test_dataset, batch_size=2)

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
 model = GraphSAGE(
     in_channels=train_dataset.num_features,
     hidden_channels=64,
diff --git a/examples/pna.py b/examples/pna.py
index 4697f49d7121..6e680f82df34 100644
--- a/examples/pna.py
+++ b/examples/pna.py
@@ -5,6 +5,7 @@
 from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
 from torch.optim.lr_scheduler import ReduceLROnPlateau

+import torch_geometric
 from torch_geometric.datasets import ZINC
 from torch_geometric.loader import DataLoader
 from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
@@ -66,7 +67,12 @@ def forward(self, x, edge_index, edge_attr, batch):
         return self.mlp(x)


-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
 model = Net().to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
diff --git a/torch_geometric/__init__.py b/torch_geometric/__init__.py
index 381c3524e168..d32562dfd562 100644
--- a/torch_geometric/__init__.py
+++ b/torch_geometric/__init__.py
@@ -3,6 +3,7 @@
 from .edge_index import EdgeIndex
 from .seed import seed_everything
 from .home import get_home_dir, set_home_dir
+from .device import is_mps_available, is_xpu_available, device
 from .isinstance import is_torch_instance
 from .debug import is_debug_enabled, debug, set_debug

@@ -33,6 +34,9 @@
     'set_home_dir',
     'compile',
     'is_compiling',
+    'is_mps_available',
+    'is_xpu_available',
+    'device',
     'is_torch_instance',
     'is_debug_enabled',
     'debug',
diff --git a/torch_geometric/device.py b/torch_geometric/device.py
new file mode 100644
index 000000000000..524ac52e0255
--- /dev/null
+++ b/torch_geometric/device.py
@@ -0,0 +1,42 @@
+from typing import Any
+
+import torch
+
+
+def is_mps_available() -> bool:
+    r"""Returns a bool indicating if MPS is currently available."""
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        try:  # Github CI may not have access to MPS hardware. Confirm:
+            torch.empty(1, device='mps')
+            return True
+        except Exception:
+            return False
+    return False
+
+
+def is_xpu_available() -> bool:
+    r"""Returns a bool indicating if XPU is currently available."""
+    if hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return True
+    try:
+        import intel_extension_for_pytorch as ipex
+        return ipex.xpu.is_available()
+    except ImportError:
+        return False
+
+
+def device(device: Any) -> torch.device:
+    r"""Returns a :class:`torch.device`.
+
+    If :obj:`"auto"` is specified, returns the optimal device depending on
+    available hardware.
+ """ + if device != 'auto': + return torch.device(device) + if torch.cuda.is_available(): + return torch.device('cuda') + if is_mps_available(): + return torch.device('mps') + if is_xpu_available(): + return torch.device('xpu') + return torch.device('cpu') diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py index e0b16e138d2d..9aaa7a99b4e6 100644 --- a/torch_geometric/testing/decorators.py +++ b/torch_geometric/testing/decorators.py @@ -9,6 +9,7 @@ from packaging.requirements import Requirement from packaging.version import Version +import torch_geometric from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE from torch_geometric.visualization.graph import has_graphviz @@ -113,13 +114,8 @@ def onlyCUDA(func: Callable) -> Callable: def onlyXPU(func: Callable) -> Callable: r"""A decorator to skip tests if XPU is not found.""" import pytest - try: - import intel_extension_for_pytorch as ipex - xpu_available = ipex.xpu.is_available() - except ImportError: - xpu_available = False return pytest.mark.skipif( - not xpu_available, + not torch_geometric.is_xpu_available(), reason="XPU not available", )(func) @@ -227,23 +223,10 @@ def withDevice(func: Callable) -> Callable: if torch.cuda.is_available(): devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0')) - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - try: # Github CI may not have access to MPS hardware. Confirm: - torch.empty(1, device='mps') - devices.append(pytest.param(torch.device('mps:0'), id='mps')) - except RuntimeError: - pass - - if not hasattr(torch, 'xpu'): - try: - import intel_extension_for_pytorch as ipex - xpu_available = ipex.xpu.is_available() - except ImportError: - xpu_available = False - else: - xpu_available = torch.xpu.is_available() + if torch_geometric.is_mps_available(): + devices.append(pytest.param(torch.device('mps:0'), id='mps')) - if xpu_available: + if torch_geometric.is_xpu_available(): devices.append(pytest.param(torch.device('xpu:0'), id='xpu')) # Additional devices can be registered through environment variables: