Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
deniselj24 committed Dec 15, 2024
1 parent 816f2fa commit 684cfe1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
22 changes: 15 additions & 7 deletions test/transforms/test_multiple_virtual_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from torch_geometric.data import Data
from torch_geometric.transforms import MultipleVirtualNodes

import copy

# modified the tests in test_virtual_node.py

def test_multiple_virtual_nodes():
print("Test 1: Random assignments")
assert str(MultipleVirtualNodes()) == 'MultipleVirtualNodes()'
assert str(MultipleVirtualNodes(n_to_add=3, clustering=False)) == 'MultipleVirtualNodes()'

x = torch.randn(4, 16)
edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]])
Expand All @@ -17,6 +19,7 @@ def test_multiple_virtual_nodes():
data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight,
edge_attr=edge_attr, num_nodes=x.size(0))

original_data = copy.deepcopy(data)
# random assignments and uneven split
data = MultipleVirtualNodes(n_to_add=3, clustering=False)(data)

Expand All @@ -27,7 +30,7 @@ def test_multiple_virtual_nodes():
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == edge_index # check that the original edges are unchanged
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)

virtual_nodes = {4, 5, 6}
Expand Down Expand Up @@ -65,17 +68,17 @@ def validate_edge_index(edge_index):

print("Test 1 passed\nTest 2: Clustering Assignments")

# clustering assignments and uneven split
data = MultipleVirtualNodes(n_to_add=3, clustering=True)(data)
# Test 2: clustering assignments and uneven split
data = MultipleVirtualNodes(n_to_add=3, clustering=True)(original_data)

assert len(data) == 6

print(data.x)
assert data.x.size() == (7, 16)
assert torch.allclose(data.x[:4], x)
assert data.x[4:].abs().sum() == 0

first_3_col = [row[:3] for row in data.edge_index.tolist()]
assert first_3_col == edge_index # check that the original edges are unchanged
assert first_3_col == [[2, 0, 2], [3, 1, 0]] # check that the original edges are unchanged
assert data.edge_index.size() == (2, 11)
validate_edge_index(data.edge_index.tolist())

Expand All @@ -89,4 +92,9 @@ def validate_edge_index(edge_index):

assert data.num_nodes == 7

assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]


if __name__ == '__main__':
test_multiple_virtual_nodes()
print("All tests passed successfully!")
1 change: 1 addition & 0 deletions torch_geometric/transforms/multiple_virtual_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def forward(self, data: Data) -> Data:
adjacency_list[node1.item()] = np.append(adjacency_list[node1.item()], node2.item())
# membership is a list like [1, 1, 1, 0, 1, 0,...] for self.n = 2
n_cuts, membership = pymetis.part_graph(self.n, adjacency=adjacency_list)
membership = np.array(membership)
assignments = [np.where(membership == v_node)[0] for v_node in range(self.n)]

arange = torch.from_numpy(np.concatenate(assignments))
Expand Down

0 comments on commit 684cfe1

Please sign in to comment.