Commit

using PyG clustering
deniselj24 committed Dec 15, 2024
1 parent 0db6a8d commit 0c92260
Showing 2 changed files with 104 additions and 60 deletions.
111 changes: 76 additions & 35 deletions test/transforms/test_multiple_virtual_nodes.py
@@ -1,17 +1,16 @@
import copy

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import MultipleVirtualNodes

# modified the tests in test_virtual_node.py
import copy

# modified the tests in test_virtual_node.py

def test_multiple_virtual_nodes():
print("Test 1: Random assignments")
assert str(MultipleVirtualNodes(
n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()'
assert str(MultipleVirtualNodes(n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()'

x = torch.randn(4, 16)
edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]])
@@ -32,31 +31,49 @@ def test_multiple_virtual_nodes():
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == [[2, 0, 2],
[3, 1,
0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
num_total_edges = 11
assert data.edge_index.size() == (2, num_total_edges)

virtual_nodes = {4, 5, 6}

def validate_edge_index(edge_index):
source_nodes = edge_index[0][3:11]
target_nodes = edge_index[1][3:11]
source_counts = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
target_counts = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
for source in source_nodes:
source_counts[source] += 1
def validate_edge_index(edge_index, n_original_nodes, n_added_nodes):
source_nodes = edge_index[0][n_original_nodes - 1:num_total_edges]
target_nodes = edge_index[1][n_original_nodes - 1:num_total_edges]
source_counts = {i: 0 for i in range(n_original_nodes + n_added_nodes)}
target_counts = {i: 0 for i in range(n_original_nodes + n_added_nodes)}
original_node_indices = [i for i in range(n_original_nodes)]
virtual_node_indices = [n_original_nodes + i for i in range(n_added_nodes)]

for i in range(len(source_nodes)):
source_counts[source_nodes[i]] += 1
# check that virtual nodes are not connected to each other and original nodes are not connected to each other
if source_nodes[i] in virtual_node_indices:
assert target_nodes[i] in original_node_indices
else:
assert target_nodes[i] in virtual_node_indices
for target in target_nodes:
target_counts[target] += 1
assert source_counts[4] + source_counts[5] + source_counts[6] == 4
assert target_counts[4] + target_counts[5] + target_counts[6] == 4
for virtual in virtual_nodes:
assert source_counts[virtual] > 0
assert source_counts[virtual] < 3
assert target_counts[virtual] > 0
assert target_counts[virtual] < 3

validate_edge_index(data.edge_index.tolist())
total_virtual_source_count = 0
total_virtual_target_count = 0
# check virtual nodes' edges have been added in the correct way
for i in range(n_added_nodes):
assert source_counts[n_original_nodes + i] > 0
assert source_counts[n_original_nodes + i] < 3
assert target_counts[n_original_nodes + i] > 0
assert target_counts[n_original_nodes + i] < 3
total_virtual_source_count += source_counts[n_original_nodes + i]
total_virtual_target_count += target_counts[n_original_nodes + i]
# check original nodes
for j in range(n_original_nodes):
assert source_counts[j] == 1
assert target_counts[j] == 1

assert total_virtual_source_count == n_original_nodes
assert total_virtual_target_count == n_original_nodes

validate_edge_index(data.edge_index.tolist(), 4, 3)

assert data.edge_weight.size() == (11, )
assert torch.allclose(data.edge_weight[:3], edge_weight)
@@ -72,35 +89,59 @@ def validate_edge_index(edge_index):

print("Test 1 passed\nTest 2: Clustering Assignments")

# Test 2: clustering assignments and uneven split
# Test 2: clustering assignments
data = MultipleVirtualNodes(n_to_add=2, clustering=True)(original_data)

assert len(data) == 6
assert data.x.size() == (6, 16)
assert torch.allclose(data.x[:4], x)
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, num_total_edges)
validate_edge_index(data.edge_index.tolist(), 4, 2)

assert data.edge_weight.size() == (num_total_edges, )
assert torch.allclose(data.edge_weight[:3], edge_weight)
assert data.edge_weight[3:].abs().sum() == 8

assert data.edge_attr.size() == (num_total_edges, 8)
assert torch.allclose(data.edge_attr[:3], edge_attr)
assert data.edge_attr[3:].abs().sum() == 0

assert data.num_nodes == 6

assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

print("Test 2 passed\nTest 3: Clustering Assignments with an empty cluster. Should add 2 virtual nodes instead of the specified 3")

# Test 3: clustering assignments with an empty cluster
data = MultipleVirtualNodes(n_to_add=3, clustering=True)(original_data)

assert len(data) == 6
print(data.x)
assert data.x.size() == (7, 16)
assert data.x.size() == (6, 16)
assert torch.allclose(data.x[:4], x)
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == [[2, 0, 2],
[3, 1,
0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)
validate_edge_index(data.edge_index.tolist())
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, num_total_edges)
validate_edge_index(data.edge_index.tolist(), 4, 2)

assert data.edge_weight.size() == (11, )
assert data.edge_weight.size() == (num_total_edges, )
assert torch.allclose(data.edge_weight[:3], edge_weight)
assert data.edge_weight[3:].abs().sum() == 8

assert data.edge_attr.size() == (11, 8)
assert data.edge_attr.size() == (num_total_edges, 8)
assert torch.allclose(data.edge_attr[:3], edge_attr)
assert data.edge_attr[3:].abs().sum() == 0

assert data.num_nodes == 7
assert data.num_nodes == 6

assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]


if __name__ == '__main__':
test_multiple_virtual_nodes()
print("All tests passed successfully!")
print("All tests passed successfully!")
53 changes: 28 additions & 25 deletions torch_geometric/transforms/multiple_virtual_nodes.py
@@ -1,14 +1,13 @@
import copy

import numpy as np
import pymetis

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

from torch_geometric.data import ClusterData, ClusterLoader

@functional_transform('multiple_virtual_nodes')
class MultipleVirtualNodes(BaseTransform):
@@ -29,12 +28,19 @@ class MultipleVirtualNodes(BaseTransform):
Hyperparameters:
n_to_add: number of virtual nodes to add (int). default = 1
clustering: whether a clustering algorithm is used to assign nodes to virtual nodes (bool). default = False, which means that assignment is random
recursive: only used if clustering is set to True. If set to :obj:`True`, will use multilevel recursive bisection instead of multilevel k-way partitioning. default = False
modified from the VirtualNode class
"""
def __init__(self, n_to_add: int, clustering: bool) -> None:
def __init__(
self,
n_to_add: int,
clustering: bool=False,
recursive: bool=False
) -> None:
self.n = n_to_add
self.clustering = clustering
self.recursive = recursive

def forward(self, data: Data) -> Data:
assert data.edge_index is not None
@@ -53,28 +59,25 @@ def forward(self, data: Data) -> Data:
permute = np.random.permutation(num_nodes)
assignments = np.array_split(permute, self.n)
else:
# Run METIS, as suggested in the paper. each node is assigned to 1 partition
adjacency_list = [
np.array([], dtype=int) for _ in range(num_nodes)
]
for node1, node2 in zip(row, col):
adjacency_list[node1.item()] = np.append(
adjacency_list[node1.item()], node2.item())
# membership is a list like [1, 1, 1, 0, 1, 0,...] for self.n = 2
n_cuts, membership = pymetis.part_graph(self.n,
adjacency=adjacency_list)
membership = np.array(membership)
assignments = [
np.where(membership == v_node)[0] for v_node in range(self.n)
]
# run clustering algorithm to assign virtual node i to nodes in cluster i
clustered_data = ClusterData(data, self.n, recursive=self.recursive)
partition = clustered_data.partition
assignments = []
for i in range(self.n):
# get nodes in cluster i
start = int(partition.partptr[i])
end = int(partition.partptr[i + 1])
cluster_nodes = partition.node_perm[start:end]
if len(cluster_nodes) == 0:
print(f"Cluster {i + 1} is empty after running the METIS algorithm with recursion set to {self.recursive}. Decreasing number of virtual nodes added to {self.n - 1}.")
self.n -= 1 # decrease the number of virtual nodes to add
continue
assignments.append(np.array(cluster_nodes))

arange = torch.from_numpy(np.concatenate(assignments))
# accounts for uneven splitting
full = torch.cat([
torch.full(
(len(assignments[i]), ), num_nodes + i, device=row.device)
for i in range(self.n)
], dim=-1)
full = torch.cat([torch.full((len(assignments[i]),), num_nodes + i, device=row.device) for i in range(self.n)],
dim=-1)

# Update edge index
row = torch.cat([row, arange, full], dim=0)
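
For context on the clustering branch above, a small numpy-only sketch of how a METIS-style membership vector maps onto the partptr / node_perm layout that the new code slices per cluster — the membership values are made up for illustration and ClusterData itself is not used here.

import numpy as np

# node_perm lists node ids grouped by cluster; partptr[i]:partptr[i + 1] delimits cluster i.
membership = np.array([1, 1, 0, 1, 0, 0])  # cluster id per node (6 nodes, 2 clusters)
node_perm = np.argsort(membership, kind='stable')  # [2, 4, 5, 0, 1, 3]
partptr = np.concatenate([[0], np.cumsum(np.bincount(membership, minlength=2))])  # [0, 3, 6]

assignments = [node_perm[partptr[i]:partptr[i + 1]] for i in range(2)]
print(assignments)  # [array([2, 4, 5]), array([0, 1, 3])]
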
@@ -121,4 +124,4 @@ def forward(self, data: Data) -> Data:
if 'num_nodes' in data:
data.num_nodes = num_nodes + self.n

return data
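
As a usage note, here is a minimal sketch of how the updated transform could be applied, illustrating the constructor arguments documented above. The graph values are made up, and the clustering=True call assumes METIS support for ClusterData (e.g. via torch-sparse or pyg-lib) is available in the environment.

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import MultipleVirtualNodes

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
data = Data(x=x, edge_index=edge_index)

# Random assignment of the 6 original nodes to 2 virtual nodes.
out_random = MultipleVirtualNodes(n_to_add=2, clustering=False)(data)
print(out_random.num_nodes)  # 8: 6 original + 2 virtual nodes

# METIS-based assignment: virtual node i is connected to the nodes of cluster i.
out_clustered = MultipleVirtualNodes(n_to_add=2, clustering=True, recursive=False)(data)
print(out_clustered.edge_index.size())  # (2, 5 + 2 * 6), provided no cluster is empty
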
