From bde8a581f011ee447c9d8eb95eaa0e61fb7d5991 Mon Sep 17 00:00:00 2001 From: "Zhang, Chaojun" Date: Thu, 13 Jun 2024 12:55:59 +0000 Subject: [PATCH] Add XPU support to basic gnn examples(GraphSage) --- examples/graph_sage_unsup.py | 9 +++++++-- examples/graph_sage_unsup_ppi.py | 8 +++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/graph_sage_unsup.py b/examples/graph_sage_unsup.py index d523cdd7485e2..581bb9e5708b4 100644 --- a/examples/graph_sage_unsup.py +++ b/examples/graph_sage_unsup.py @@ -9,6 +9,7 @@ from torch_geometric.datasets import Planetoid from torch_geometric.loader import LinkNeighborLoader from torch_geometric.nn import GraphSAGE +from torch_geometric.testing.device import is_xpu_avaliable dataset = 'Cora' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) @@ -22,8 +23,12 @@ neg_sampling_ratio=1.0, num_neighbors=[10, 10], ) - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif is_xpu_avaliable(): + device = torch.device('xpu') +else: + device = torch.device('cpu') data = data.to(device, 'x', 'edge_index') model = GraphSAGE( diff --git a/examples/graph_sage_unsup_ppi.py b/examples/graph_sage_unsup_ppi.py index dd085dfe5d986..9fc304b48c93a 100644 --- a/examples/graph_sage_unsup_ppi.py +++ b/examples/graph_sage_unsup_ppi.py @@ -12,6 +12,7 @@ from torch_geometric.datasets import PPI from torch_geometric.loader import DataLoader, LinkNeighborLoader from torch_geometric.nn import GraphSAGE +from torch_geometric.testing.device import is_xpu_avaliable path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') train_dataset = PPI(path, split='train') @@ -29,7 +30,12 @@ val_loader = DataLoader(val_dataset, batch_size=2) test_loader = DataLoader(test_dataset, batch_size=2) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif is_xpu_avaliable(): + device = torch.device('xpu') +else: + device = torch.device('cpu') model = GraphSAGE( in_channels=train_dataset.num_features, hidden_channels=64,