Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 15, 2024
1 parent 684cfe1 commit 0db6a8d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
18 changes: 12 additions & 6 deletions test/transforms/test_multiple_virtual_nodes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import copy

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import MultipleVirtualNodes

import copy

# modified the tests in test_virtual_node.py


def test_multiple_virtual_nodes():
print("Test 1: Random assignments")
assert str(MultipleVirtualNodes(n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()'
assert str(MultipleVirtualNodes(
n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()'

x = torch.randn(4, 16)
edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]])
Expand All @@ -30,7 +32,9 @@ def test_multiple_virtual_nodes():
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert first_3_col == [[2, 0, 2],
[3, 1,
0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)

virtual_nodes = {4, 5, 6}
Expand Down Expand Up @@ -78,7 +82,9 @@ def validate_edge_index(edge_index):
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert first_3_col == [[2, 0, 2],
[3, 1,
0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)
validate_edge_index(data.edge_index.tolist())

Expand All @@ -97,4 +103,4 @@ def validate_edge_index(edge_index):

if __name__ == '__main__':
test_multiple_virtual_nodes()
print("All tests passed successfully!")
print("All tests passed successfully!")
31 changes: 18 additions & 13 deletions torch_geometric/transforms/multiple_virtual_nodes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import copy
import numpy as np

import numpy as np
import pymetis
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

import pymetis

@functional_transform('multiple_virtual_nodes')
class MultipleVirtualNodes(BaseTransform):
Expand All @@ -32,11 +32,7 @@ class MultipleVirtualNodes(BaseTransform):
modified from VirtualNode class
"""
def __init__(
self,
n_to_add: int,
clustering: bool
) -> None:
def __init__(self, n_to_add: int, clustering: bool) -> None:
self.n = n_to_add
self.clustering = clustering

Expand All @@ -58,18 +54,27 @@ def forward(self, data: Data) -> Data:
assignments = np.array_split(permute, self.n)
else:
# Run METIS, as suggested in the paper. each node is assigned to 1 partition
adjacency_list = [np.array([], dtype=int) for _ in range(num_nodes)]
adjacency_list = [
np.array([], dtype=int) for _ in range(num_nodes)
]
for node1, node2 in zip(row, col):
adjacency_list[node1.item()] = np.append(adjacency_list[node1.item()], node2.item())
adjacency_list[node1.item()] = np.append(
adjacency_list[node1.item()], node2.item())
# membership is a list like [1, 1, 1, 0, 1, 0,...] for self.n = 2
n_cuts, membership = pymetis.part_graph(self.n, adjacency=adjacency_list)
n_cuts, membership = pymetis.part_graph(self.n,
adjacency=adjacency_list)
membership = np.array(membership)
assignments = [np.where(membership == v_node)[0] for v_node in range(self.n)]
assignments = [
np.where(membership == v_node)[0] for v_node in range(self.n)
]

arange = torch.from_numpy(np.concatenate(assignments))
# accounts for uneven splitting
full = torch.cat([torch.full((len(assignments[i]),), num_nodes + i, device=row.device) for i in range(self.n)],
dim=-1)
full = torch.cat([
torch.full(
(len(assignments[i]), ), num_nodes + i, device=row.device)
for i in range(self.n)
], dim=-1)

# Update edge index
row = torch.cat([row, arange, full], dim=0)
Expand Down

0 comments on commit 0db6a8d

Please sign in to comment.