Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincenzooo authored and Vincenzooo committed Dec 11, 2024
1 parent 6197e8f commit 9c5c19e
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/nn/conv/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ def test_remove_diag_sparse_tensor():
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_index2 = torch.tensor([[1, 2, 3], [0, 1, 1]])

adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4))
if torch_geometric.typing.WITH_TORCH_SPARSE:
adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4))

assert torch_sparse.remove_diag(adj1.t()) == adj2.t()
assert torch_sparse.remove_diag(adj1.t()) == adj2.t()

0 comments on commit 9c5c19e

Please sign in to comment.