From 0c92260cf94c962d2abcb11ad1d30f41e333afe3 Mon Sep 17 00:00:00 2001 From: deniselj24 Date: Sun, 15 Dec 2024 01:40:14 -0800 Subject: [PATCH] using PyG clustering --- .../transforms/test_multiple_virtual_nodes.py | 111 ++++++++++++------ .../transforms/multiple_virtual_nodes.py | 53 +++++---- 2 files changed, 104 insertions(+), 60 deletions(-) diff --git a/test/transforms/test_multiple_virtual_nodes.py b/test/transforms/test_multiple_virtual_nodes.py index 4831788b14a1..b4af9ea30839 100644 --- a/test/transforms/test_multiple_virtual_nodes.py +++ b/test/transforms/test_multiple_virtual_nodes.py @@ -1,17 +1,16 @@ -import copy import torch from torch_geometric.data import Data from torch_geometric.transforms import MultipleVirtualNodes -# modified the tests in test_virtual_node.py +import copy +# modified the tests in test_virtual_node.py def test_multiple_virtual_nodes(): print("Test 1: Random assignments") - assert str(MultipleVirtualNodes( - n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()' + assert str(MultipleVirtualNodes(n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()' x = torch.randn(4, 16) edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]]) @@ -32,31 +31,49 @@ def test_multiple_virtual_nodes(): assert data.x[4:].abs().sum() == 0 first_3_col = [row[:3] for row in data.edge_index.tolist()] - assert first_3_col == [[2, 0, 2], - [3, 1, - 0]] # check that the original edges are unchanged - assert data.edge_index.size() == (2, 11) + assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged + num_total_edges = 11 + assert data.edge_index.size() == (2, num_total_edges) virtual_nodes = {4, 5, 6} - def validate_edge_index(edge_index): - source_nodes = edge_index[0][3:11] - target_nodes = edge_index[1][3:11] - source_counts = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0} - target_counts = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0} - for source in source_nodes: - source_counts[source] += 1 + def validate_edge_index(edge_index, n_original_nodes, n_added_nodes): + source_nodes = edge_index[0][n_original_nodes - 1:num_total_edges] + target_nodes = edge_index[1][n_original_nodes - 1:num_total_edges] + source_counts = {i: 0 for i in range(n_original_nodes + n_added_nodes)} + target_counts = {i: 0 for i in range(n_original_nodes + n_added_nodes)} + original_node_indices = [i for i in range(n_original_nodes)] + virtual_node_indices = [n_original_nodes + i for i in range(n_added_nodes)] + + for i in range(len(source_nodes)): + source_counts[source_nodes[i]] += 1 + # check that virtual nodes are not connected to each other and original nodes are not connected to each other + if source_nodes[i] in virtual_node_indices: + assert target_nodes[i] in original_node_indices + else: + assert target_nodes[i] in virtual_node_indices for target in target_nodes: target_counts[target] += 1 - assert source_counts[4] + source_counts[5] + source_counts[6] == 4 - assert target_counts[4] + target_counts[5] + target_counts[6] == 4 - for virtual in virtual_nodes: - assert source_counts[virtual] > 0 - assert source_counts[virtual] < 3 - assert target_counts[virtual] > 0 - assert target_counts[virtual] < 3 - validate_edge_index(data.edge_index.tolist()) + total_virtual_source_count = 0 + total_virtual_target_count = 0 + # check virtual nodes' edges have been added in the correct way + for i in range(n_added_nodes): + assert source_counts[n_original_nodes + i] > 0 + assert source_counts[n_original_nodes + i] < 3 + assert target_counts[n_original_nodes + i] > 0 + assert target_counts[n_original_nodes + i] < 3 + total_virtual_source_count += source_counts[n_original_nodes + i] + total_virtual_target_count += target_counts[n_original_nodes + i] + # check original nodes + for j in range(n_original_nodes): + assert source_counts[j] == 1 + assert target_counts[j] == 1 + + assert total_virtual_source_count == n_original_nodes + assert total_virtual_target_count == n_original_nodes + + validate_edge_index(data.edge_index.tolist(), 4, 3) assert data.edge_weight.size() == (11, ) assert torch.allclose(data.edge_weight[:3], edge_weight) @@ -72,35 +89,59 @@ def validate_edge_index(edge_index): print("Test 1 passed\nTest 2: Clustering Assignments") - # Test 2: clustering assignments and uneven split + # Test 2: clustering assignments + data = MultipleVirtualNodes(n_to_add=2, clustering=True)(original_data) + + assert len(data) == 6 + assert data.x.size() == (6, 16) + assert torch.allclose(data.x[:4], x) + assert data.x[4:].abs().sum() == 0 + + first_3_col = [row[:3] for row in data.edge_index.tolist()] + assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged + assert data.edge_index.size() == (2, num_total_edges) + validate_edge_index(data.edge_index.tolist(), 4, 2) + + assert data.edge_weight.size() == (num_total_edges, ) + assert torch.allclose(data.edge_weight[:3], edge_weight) + assert data.edge_weight[3:].abs().sum() == 8 + + assert data.edge_attr.size() == (num_total_edges, 8) + assert torch.allclose(data.edge_attr[:3], edge_attr) + assert data.edge_attr[3:].abs().sum() == 0 + + assert data.num_nodes == 6 + + assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + + print("Test 2 passed\nTest 3: Clustering Assignments with an empty cluster. Should add 2 virtual nodes instead of the specified 3") + + # Test 2: clustering assignments data = MultipleVirtualNodes(n_to_add=3, clustering=True)(original_data) assert len(data) == 6 - print(data.x) - assert data.x.size() == (7, 16) + assert data.x.size() == (6, 16) assert torch.allclose(data.x[:4], x) assert data.x[4:].abs().sum() == 0 first_3_col = [row[:3] for row in data.edge_index.tolist()] - assert first_3_col == [[2, 0, 2], - [3, 1, - 0]] # check that the original edges are unchanged - assert data.edge_index.size() == (2, 11) - validate_edge_index(data.edge_index.tolist()) + assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged + assert data.edge_index.size() == (2, num_total_edges) + validate_edge_index(data.edge_index.tolist(), 4, 2) - assert data.edge_weight.size() == (11, ) + assert data.edge_weight.size() == (num_total_edges, ) assert torch.allclose(data.edge_weight[:3], edge_weight) assert data.edge_weight[3:].abs().sum() == 8 - assert data.edge_attr.size() == (11, 8) + assert data.edge_attr.size() == (num_total_edges, 8) assert torch.allclose(data.edge_attr[:3], edge_attr) assert data.edge_attr[3:].abs().sum() == 0 - assert data.num_nodes == 7 + assert data.num_nodes == 6 assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] if __name__ == '__main__': test_multiple_virtual_nodes() - print("All tests passed successfully!") + print("All tests passed successfully!") \ No newline at end of file diff --git a/torch_geometric/transforms/multiple_virtual_nodes.py b/torch_geometric/transforms/multiple_virtual_nodes.py index eba88a14bc67..967bad4914ca 100644 --- a/torch_geometric/transforms/multiple_virtual_nodes.py +++ b/torch_geometric/transforms/multiple_virtual_nodes.py @@ -1,14 +1,13 @@ import copy - import numpy as np -import pymetis + import torch from torch import Tensor from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform - +from torch_geometric.data import ClusterData, ClusterLoader @functional_transform('multiple_virtual_nodes') class MultipleVirtualNodes(BaseTransform): @@ -29,12 +28,19 @@ class MultipleVirtualNodes(BaseTransform): Hyperparameters: n_to_add: number of virtual nodes (int). default = 1 clustering: whether the clustering algorithm is used to assign virtual nodes (bool). default = False, means that assignment is random. - + recursive: only used if clustering is set to True.If set to :obj:`True`, will use multilevel recursive bisection instead of multilevel k-way partitioning. default = False + modified from VirtualNode class """ - def __init__(self, n_to_add: int, clustering: bool) -> None: + def __init__( + self, + n_to_add: int, + clustering: bool=False, + recursive: bool=False + ) -> None: self.n = n_to_add self.clustering = clustering + self.recursive = recursive def forward(self, data: Data) -> Data: assert data.edge_index is not None @@ -53,28 +59,25 @@ def forward(self, data: Data) -> Data: permute = np.random.permutation(num_nodes) assignments = np.array_split(permute, self.n) else: - # Run METIS, as suggested in the paper. each node is assigned to 1 partition - adjacency_list = [ - np.array([], dtype=int) for _ in range(num_nodes) - ] - for node1, node2 in zip(row, col): - adjacency_list[node1.item()] = np.append( - adjacency_list[node1.item()], node2.item()) - # membership is a list like [1, 1, 1, 0, 1, 0,...] for self.n = 2 - n_cuts, membership = pymetis.part_graph(self.n, - adjacency=adjacency_list) - membership = np.array(membership) - assignments = [ - np.where(membership == v_node)[0] for v_node in range(self.n) - ] + # run clustering algorithm to assign virtual node i to nodes in cluster i + clustered_data = ClusterData(data, self.n, recursive=self.recursive) + partition = clustered_data.partition + assignments = [] + for i in range(self.n): + # get nodes in cluster i + start = int(partition.partptr[i]) + end = int(partition.partptr[i + 1]) + cluster_nodes = partition.node_perm[start:end] + if len(cluster_nodes) == 0: + print(f"Cluster {i + 1} is empty after running the METIS algorithm with recursion set to {self.recursive}. Decreasing number of virtual nodes added to {self.n - 1}.") + self.n -= 1 # decrease the number of virtual nodes to add + continue + assignments.append(np.array(cluster_nodes)) arange = torch.from_numpy(np.concatenate(assignments)) # accounts for uneven splitting - full = torch.cat([ - torch.full( - (len(assignments[i]), ), num_nodes + i, device=row.device) - for i in range(self.n) - ], dim=-1) + full = torch.cat([torch.full((len(assignments[i]),), num_nodes + i, device=row.device) for i in range(self.n)], + dim=-1) # Update edge index row = torch.cat([row, arange, full], dim=0) @@ -121,4 +124,4 @@ def forward(self, data: Data) -> Data: if 'num_nodes' in data: data.num_nodes = num_nodes + self.n - return data + return data \ No newline at end of file