Skip to content

Commit

Permalink
Update flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincenzooo authored and Vincenzooo committed Dec 13, 2024
1 parent 9c5c19e commit d4864ee
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def gat_norm( # noqa: F811
if not adj_t.has_value():
adj_t = adj_t.fill_value(1., dtype=dtype)

deg = torch_sparse.sum(adj_t, dim=1)
# idx = col if flow == 'source_to_target' else row
dim = 1 if flow == 'source_to_target' else 0
deg = torch_sparse.sum(adj_t, dim=dim)

att_mat = adj_t.set_value(edge_weight)
# repeat_interleave for any dtype tensor
num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1
Expand Down Expand Up @@ -487,6 +490,8 @@ def forward( # noqa: F811
alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
alpha = (alpha_src, alpha_dst)

# Skipping add_self_loops when normalize
# as we are performing it on gat_norm function
if self.add_self_loops and not self.normalize:
if isinstance(edge_index, Tensor):
# We only want to add self-loops for nodes that appear both as
Expand Down

0 comments on commit d4864ee

Please sign in to comment.