Skip to content

Commit

Permalink
Fix gat_norm when edge_index is SparseTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincenzooo authored and Vincenzooo committed Dec 11, 2024
1 parent c19ada3 commit 76cab5d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
41 changes: 23 additions & 18 deletions torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,31 @@ def gat_norm( # noqa: F811
adj_t = adj_t.fill_value(1., dtype=dtype)

deg = torch_sparse.sum(adj_t, dim=1)
att_mat = edge_index.copy()
att_mat = att_mat.set_value(edge_weight)
att_mat = torch_sparse.mul(att_mat,
deg) # unnormalized attention matrix
att_mat = adj_t.set_value(edge_weight)
# repeat_interleave for any dtype tensor
num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1
repeat_deg = deg.view(-1).unsqueeze(1).repeat(1, num_heads).flatten()
# perform mat multiplication with sparse tensor
# for multiple attention heads
att_mat = torch_sparse.mul(att_mat, repeat_deg.view(-1, 1, num_heads))

# Add self-loop, also called renormalization trick
att_mat = torch_sparse.fill_diag(att_mat, fill_value)
deg_tilde = deg + 1
deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5)
deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0.)
# repeat_interleave for any dtype tensor
repeat_deg_tilde_inv = deg_tilde_inv_sqrt.view(-1).unsqueeze(1).repeat(
1, num_heads).flatten()
repeat_deg_tilde_inv = repeat_deg_tilde_inv

att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(-1, 1))
att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(1, -1))

return adj_t, att_mat.to_dense()
att_mat = torch_sparse.mul(att_mat,
repeat_deg_tilde_inv.view(-1, 1, num_heads))
att_mat = torch_sparse.mul(att_mat,
repeat_deg_tilde_inv.view(1, -1, num_heads))
# should never be None, only for typing purpose
alpha = att_mat.storage.value()
return adj_t, alpha if alpha is not None else edge_weight

if is_torch_sparse_tensor(edge_index):
assert edge_index.size(0) == edge_index.size(1)
Expand Down Expand Up @@ -284,9 +294,6 @@ def __init__(
kwargs.setdefault('aggr', 'add')
super().__init__(node_dim=0, **kwargs)

if normalize:
add_self_loops = False

self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
Expand Down Expand Up @@ -479,7 +486,7 @@ def forward( # noqa: F811
alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
alpha = (alpha_src, alpha_dst)

if self.add_self_loops:
if self.add_self_loops and not self.normalize:
if isinstance(edge_index, Tensor):
# We only want to add self-loops for nodes that appear both as
# source and target nodes:
Expand Down Expand Up @@ -511,7 +518,7 @@ def forward( # noqa: F811
edge_index, edge_attr = remove_self_loops(
edge_index, edge_attr)
elif isinstance(edge_index, SparseTensor):
edge_index = torch_sparse.fill_diag(edge_index, 0.0)
edge_index = torch_sparse.remove_diag(edge_index)

# edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr,
Expand All @@ -525,11 +532,9 @@ def forward( # noqa: F811
num_nodes = min(num_nodes, x_dst.size(0))
num_nodes = min(size) if size is not None else num_nodes

edge_index, alpha = gat_norm(edge_index,
alpha,
num_nodes=num_nodes,
flow=self.flow,
dtype=alpha.dtype) # yapf: disable
edge_index, alpha = gat_norm(edge_index, alpha,
num_nodes=num_nodes, flow=self.flow,
dtype=alpha.dtype) # noqa: F811

# propagate_type: (x: OptPairTensor, alpha: Tensor)
out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
Expand Down
5 changes: 1 addition & 4 deletions torch_geometric/nn/conv/gatv2_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ def __init__(
):
super().__init__(node_dim=0, **kwargs)

if normalize:
add_self_loops = False

self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
Expand Down Expand Up @@ -311,7 +308,7 @@ def forward( # noqa: F811
assert x_l is not None
assert x_r is not None

if self.add_self_loops:
if self.add_self_loops and not self.normalize:
if isinstance(edge_index, Tensor):
num_nodes = x_l.size(0)
if x_r is not None:
Expand Down
4 changes: 4 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def fill_diag(src: SparseTensor, fill_value: float,
k: int = 0) -> SparseTensor:
raise ImportError("'fill_diag' requires 'torch-sparse'")

@staticmethod
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
raise ImportError("'remove_diag' requires 'torch-sparse'")

@staticmethod
def masked_select_nnz(src: SparseTensor, mask: Tensor,
layout: Optional[str] = None) -> SparseTensor:
Expand Down

0 comments on commit 76cab5d

Please sign in to comment.