Skip to content

Commit

Permalink
Add XPU support to basic gnn examples(GraphSage)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaojun-zhang committed Jun 13, 2024
1 parent b7a4cfd commit 6558067
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 7 additions & 2 deletions examples/graph_sage_unsup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.testing.device import is_xpu_avaliable

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
Expand All @@ -22,8 +23,12 @@
neg_sampling_ratio=1.0,
num_neighbors=[10, 10],
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif is_xpu_avaliable():
device = torch.device('xpu')
else:
device = torch.device('cpu')
data = data.to(device, 'x', 'edge_index')

model = GraphSAGE(
Expand Down
8 changes: 7 additions & 1 deletion examples/graph_sage_unsup_ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader, LinkNeighborLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.testing.device import is_xpu_avaliable

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')
train_dataset = PPI(path, split='train')
Expand All @@ -29,7 +30,12 @@
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif is_xpu_avaliable():
device = torch.device('xpu')
else:
device = torch.device('cpu')
model = GraphSAGE(
in_channels=train_dataset.num_features,
hidden_channels=64,
Expand Down

0 comments on commit 6558067

Please sign in to comment.