diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 998a4fe1ce24..70e860081c6e 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -78,21 +78,31 @@ def gat_norm( # noqa: F811 adj_t = adj_t.fill_value(1., dtype=dtype) deg = torch_sparse.sum(adj_t, dim=1) - att_mat = edge_index.copy() - att_mat = att_mat.set_value(edge_weight) - att_mat = torch_sparse.mul(att_mat, - deg) # unnormalized attention matrix + att_mat = adj_t.set_value(edge_weight) + # repeat_interleave for any dtype tensor + num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1 + repeat_deg = deg.view(-1).unsqueeze(1).repeat(1, num_heads).flatten() + # perform mat multiplication with sparse tensor + # for multiple attention heads + att_mat = torch_sparse.mul(att_mat, repeat_deg.view(-1, 1, num_heads)) # Add self-loop, also called renormalization trick att_mat = torch_sparse.fill_diag(att_mat, fill_value) deg_tilde = deg + 1 deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0.) + # repeat_interleave for any dtype tensor + repeat_deg_tilde_inv = deg_tilde_inv_sqrt.view(-1).unsqueeze(1).repeat( + 1, num_heads).flatten() + repeat_deg_tilde_inv = repeat_deg_tilde_inv - att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(-1, 1)) - att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(1, -1)) - - return adj_t, att_mat.to_dense() + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(-1, 1, num_heads)) + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(1, -1, num_heads)) + # should never be None, only for typing purpose + alpha = att_mat.storage.value() + return adj_t, alpha if alpha is not None else edge_weight if is_torch_sparse_tensor(edge_index): assert edge_index.size(0) == edge_index.size(1) @@ -284,9 +294,6 @@ def __init__( kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) - if normalize: - add_self_loops = False - self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -479,7 +486,7 @@ def forward( # noqa: F811 alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) alpha = (alpha_src, alpha_dst) - if self.add_self_loops: + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): # We only want to add self-loops for nodes that appear both as # source and target nodes: @@ -511,7 +518,7 @@ def forward( # noqa: F811 edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) elif isinstance(edge_index, SparseTensor): - edge_index = torch_sparse.fill_diag(edge_index, 0.0) + edge_index = torch_sparse.remove_diag(edge_index) # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, @@ -525,11 +532,9 @@ def forward( # noqa: F811 num_nodes = min(num_nodes, x_dst.size(0)) num_nodes = min(size) if size is not None else num_nodes - edge_index, alpha = gat_norm(edge_index, - alpha, - num_nodes=num_nodes, - flow=self.flow, - dtype=alpha.dtype) # yapf: disable + edge_index, alpha = gat_norm(edge_index, alpha, + num_nodes=num_nodes, flow=self.flow, + dtype=alpha.dtype) # noqa: F811 # propagate_type: (x: OptPairTensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha, size=size) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index a95117f28feb..982de16cf74e 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -153,9 +153,6 @@ def __init__( ): super().__init__(node_dim=0, **kwargs) - if normalize: - add_self_loops = False - self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -311,7 +308,7 @@ def forward( # noqa: F811 assert x_l is not None assert x_r is not None - if self.add_self_loops: + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): num_nodes = x_l.size(0) if x_r is not None: diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index e0fb657cb294..c6e91f4f1db0 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -248,6 +248,10 @@ def fill_diag(src: SparseTensor, fill_value: float, k: int = 0) -> SparseTensor: raise ImportError("'fill_diag' requires 'torch-sparse'") + @staticmethod + def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: + raise ImportError("'remove_diag' requires 'torch-sparse'") + @staticmethod def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: