Commit 9d1b5ab
Merge pull request #253 from cshjin/fix-test_tsagcn
Fix `test_tsagcn`
SherylHYX authored Oct 14, 2024
2 parents bdcf2b2 + c7dcb2d commit 9d1b5ab
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions test/attention_test.py
@@ -557,9 +557,9 @@ def test_mtgnn():
     num_sub = int(num_nodes / num_split)
     for j in range(num_split):
         if j != num_split - 1:
-            id = perm[j * num_sub : (j + 1) * num_sub]
+            id = perm[j * num_sub: (j + 1) * num_sub]
         else:
-            id = perm[j * num_sub :]
+            id = perm[j * num_sub:]
         tx = trainx[:, :, id, :]
         output = model(tx, A_tilde, idx=id)
         output = output.transpose(1, 3)
@@ -652,9 +652,9 @@ def test_mtgnn():
     trainx = trainx.transpose(1, 3)
     for j in range(num_split):
         if j != num_split - 1:
-            id = perm[j * num_sub : (j + 1) * num_sub]
+            id = perm[j * num_sub: (j + 1) * num_sub]
         else:
-            id = perm[j * num_sub :]
+            id = perm[j * num_sub:]
         tx = trainx[:, :, id, :]
         output = model(tx, A_tilde, idx=id)
         output = output.transpose(1, 3)
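
Both hunks above only adjust the whitespace around the slice colon; the node-split logic itself is unchanged. For context, a minimal standalone sketch of that indexing pattern, where the tensor sizes and variable values are illustrative assumptions rather than values taken from the test:

    import torch

    num_nodes, num_split = 20, 3
    perm = torch.randperm(num_nodes)       # random node order for this pass
    num_sub = int(num_nodes / num_split)   # nodes per sub-batch
    for j in range(num_split):
        if j != num_split - 1:
            idx = perm[j * num_sub: (j + 1) * num_sub]
        else:
            idx = perm[j * num_sub:]       # last split takes the remainder
        # the test slices the node dimension with this index,
        # e.g. tx = trainx[:, :, idx, :], and passes idx to the model

Each node lands in exactly one sub-batch per permutation, which keeps per-step memory bounded while still covering the whole graph.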
@@ -690,6 +690,7 @@ def test_tsagcn():
     # (bs, seq, nodes, f_in) -> (bs, f_in, seq, nodes)
     # also be sure to pass in a contiguous tensor (the created in create_mock_batch() is not!)
     batch = batch.permute(0, 3, 1, 2).contiguous()
+    edge_index = edge_index.to(device)

     stride = 2
     aagcn_adaptive = AAGCN(
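
The functional change in this file is the added edge_index = edge_index.to(device) line: the test now moves the graph onto the same device as the permuted batch before building the model, so AAGCN does not mix CPU and GPU tensors when it builds its dense adjacency. A minimal sketch of the intent; the import path, constructor arguments, and mock shapes below are assumptions for illustration, not copied from the test:

    import torch
    from torch_geometric_temporal.nn.attention import AAGCN  # assumed import path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # mock input: (batch, seq, nodes, features) -> (batch, features, seq, nodes)
    batch = torch.rand(8, 10, 6, 4).permute(0, 3, 1, 2).contiguous().to(device)
    # keep the graph on the same device as the node features
    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]).to(device)

    model = AAGCN(in_channels=4, out_channels=16, edge_index=edge_index,
                  num_nodes=6, stride=2).to(device)
    out = model(batch)  # forward pass runs entirely on `device`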
2 changes: 1 addition & 1 deletion torch_geometric_temporal/nn/attention/tsagcn.py
@@ -27,7 +27,7 @@ def __init__(self, edge_index: list, num_nodes: int):
         self.A = self.get_spatial_graph(self.num_nodes)

     def get_spatial_graph(self, num_nodes):
-        self_mat = torch.eye(num_nodes)
+        self_mat = torch.eye(num_nodes, device=self.edge_index.device)
         inward_mat = torch.squeeze(to_dense_adj(self.edge_index))
         inward_mat_norm = F.normalize(inward_mat, dim=0, p=1)
         outward_mat = inward_mat.transpose(0, 1)
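
The library-side half of the fix creates the identity matrix on the same device as edge_index, so the identity, inward, and outward graphs that get_spatial_graph combines all live on one device. A standalone sketch of the patched helper; the lines after the last visible context line (normalising the outward matrix and stacking the result) are an assumption about how the function continues, not part of the diff:

    import torch
    import torch.nn.functional as F
    from torch_geometric.utils import to_dense_adj

    def get_spatial_graph(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        # build the identity on edge_index's device so a CUDA edge_index
        # no longer collides with a CPU identity matrix
        self_mat = torch.eye(num_nodes, device=edge_index.device)
        inward_mat = torch.squeeze(to_dense_adj(edge_index))
        inward_mat_norm = F.normalize(inward_mat, dim=0, p=1)
        outward_mat = inward_mat.transpose(0, 1)
        # assumed continuation: normalise the outward graph and stack
        outward_mat_norm = F.normalize(outward_mat, dim=0, p=1)
        return torch.stack((self_mat, inward_mat_norm, outward_mat_norm))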
