diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..e1969848e4aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `normalize` parameter to `GATConv` and `GATv2Conv` ([#9840](https://github.com/pyg-team/pytorch_geometric/pull/9840)) - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) diff --git a/benchmark/citation/gat.py b/benchmark/citation/gat.py index f9ed5d6071af..47594ed50bfc 100644 --- a/benchmark/citation/gat.py +++ b/benchmark/citation/gat.py @@ -24,6 +24,7 @@ parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') +parser.add_argument('--normalize', action='store_true') args = parser.parse_args() @@ -31,10 +32,11 @@ class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = GATConv(dataset.num_features, args.hidden, - heads=args.heads, dropout=args.dropout) + heads=args.heads, dropout=args.dropout, + normalize=args.normalize) self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes, heads=args.output_heads, concat=False, - dropout=args.dropout) + dropout=args.dropout, normalize=args.normalize) def reset_parameters(self): self.conv1.reset_parameters() diff --git a/examples/gat.py b/examples/gat.py index 09f90011efdd..bf091241f695 100644 --- a/examples/gat.py +++ b/examples/gat.py @@ -17,6 +17,8 @@ parser.add_argument('--heads', type=int, default=8) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=200) +parser.add_argument('--norm_adj', type=bool, default=False, + help="GAT with symmetric normalized adjacency") parser.add_argument('--wandb', action='store_true', help='Track experiment') args = parser.parse_args() @@ -28,7 +30,8 @@ device = torch.device('cpu') init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs, - hidden_channels=args.hidden_channels, lr=args.lr, device=device) + hidden_channels=args.hidden_channels, lr=args.lr, device=device, + norm_adj=args.norm_adj) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures()) @@ -36,12 +39,14 @@ class GAT(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, heads): + def __init__(self, in_channels, hidden_channels, out_channels, heads, + normalize): super().__init__() - self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6) + self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6, + normalize=normalize) # On the Pubmed dataset, use `heads` output heads in `conv2`. 
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, - concat=False, dropout=0.6) + concat=False, dropout=0.6, normalize=normalize) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) @@ -52,7 +57,7 @@ def forward(self, x, edge_index): model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes, - args.heads).to(device) + args.heads, args.norm_adj).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 6549911ac0d4..eeeee2b31145 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -6,9 +6,10 @@ import torch_geometric.typing from torch_geometric.nn import GATConv +from torch_geometric.nn.conv.gat_conv import gat_norm from torch_geometric.testing import is_full_test, withDevice -from torch_geometric.typing import Adj, Size, SparseTensor -from torch_geometric.utils import to_torch_csc_tensor +from torch_geometric.typing import Adj, Size, SparseTensor, torch_sparse +from torch_geometric.utils import to_torch_csc_tensor, to_torch_csr_tensor @pytest.mark.parametrize('residual', [False, True]) @@ -155,6 +156,45 @@ def forward( assert torch.allclose(jit((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj2.t()), out2, atol=1e-6) + # Test GAT normalization: + x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + conv = GATConv(8, 32, heads=2, residual=residual, normalize=True) + assert str(conv) == 'GATConv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) + assert torch.allclose(conv(x1, adj1.t()), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + size: Size = None, + ) -> Tensor: + return self.conv(x, edge_index, size=size) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + def test_gat_conv_with_edge_attr(): x = torch.randn(4, 8) @@ -201,3 +241,48 @@ def test_gat_conv_empty_edge_index(device): conv = GATConv(8, 32, heads=2).to(device) out = conv(x, edge_index) assert out.size() == (0, 64) + + +def test_gat_conv_csc_error(): + x1 = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csr_tensor(edge_index, size=(4, 4)) + + with pytest.raises(ValueError, match="Unexpected sparse tensor layout"): + conv = GATConv(8, 32, heads=2, normalize=True) + assert str(conv) == 'GATConv(8, 32, heads=2)' + _ = conv(x1, adj1.t()) + + +def test_gat_norm_csc_error(): + edge_index = torch.tensor([[1, 2, 3], [0, 1, 1]]) + edge_weight = torch.tensor([[1.0000, 1.0000], [1.2341, 0.9614], + [0.7659, 1.0386]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + with pytest.raises(NotImplementedError, + match="Sparse CSC matrices are not yet supported"): + gat_norm(adj1, edge_weight) + + +def test_gat_conv_bipartite_error(): + x1 = 
torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + with pytest.raises(NotImplementedError, + match="not supported for bipartite message passing"): + conv = GATConv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) + + +def test_remove_diag_sparse_tensor(): + # Used in GAT Normalization + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_index2 = torch.tensor([[1, 2, 3], [0, 1, 1]]) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4)) + + assert torch_sparse.remove_diag(adj1.t()) == adj2.t() diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index 3bca6530eee9..3a0477d8df37 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -141,6 +141,44 @@ def forward( if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) + # Test GAT normalization: + x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + conv = GATv2Conv(8, 32, heads=2, residual=residual, normalize=True) + assert str(conv) == 'GATv2Conv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index), out) + assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + ) -> Tensor: + return self.conv(x, edge_index) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + def test_gatv2_conv_with_edge_attr(): x = torch.randn(4, 8) @@ -163,3 +201,14 @@ def test_gatv2_conv_with_edge_attr(): conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value='mean') out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) + + +def test_gat_conv_bipartite_error(): + x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + with pytest.raises(NotImplementedError, + match="not supported for bipartite message passing"): + conv = GATv2Conv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 720dfb09811c..1ef86100d9e6 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -19,11 +19,16 @@ torch_sparse, ) from torch_geometric.utils import ( + add_remaining_self_loops, add_self_loops, is_torch_sparse_tensor, remove_self_loops, + scatter, softmax, + to_edge_index, + to_torch_csr_tensor, ) +from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.utils.sparse import set_sparse_value if typing.TYPE_CHECKING: @@ -32,6 +37,135 @@ from torch.jit import _overload_method as overload +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # 
type: (Tensor, Tensor, Optional[int], str, Optional[int]) -> Tuple[Tensor, Tensor] # noqa + pass + + +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # type: (SparseTensor, Tensor, Optional[int], str, Optional[int]) -> Tuple[SparseTensor, Tensor] # noqa + pass + + +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # type: (Adj, Tensor, Optional[int], str, Optional[int]) -> Tuple[Adj, Tensor] # noqa + pass + + +def gat_norm( # noqa: F811 + edge_index: Adj, + edge_weight: Tensor, + num_nodes: Optional[int] = None, + flow: str = "source_to_target", + dtype: Optional[torch.dtype] = None, +): + fill_value = 1.0 + + assert flow in ['source_to_target', 'target_to_source'] + + if isinstance(edge_index, SparseTensor): + assert edge_index.size(0) == edge_index.size(1) + + adj_t = edge_index + + if not adj_t.has_value(): + adj_t = adj_t.fill_value(1., dtype=dtype) + + # idx = col if flow == 'source_to_target' else row + dim = 1 if flow == 'source_to_target' else 0 + deg = torch_sparse.sum(adj_t, dim=dim) + + att_mat = adj_t.set_value(edge_weight) + # repeat_interleave for any dtype tensor + num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1 + repeat_deg = deg.view(-1).unsqueeze(1).repeat(1, num_heads).flatten() + # perform mat multiplication with sparse tensor + # for multiple attention heads + att_mat = torch_sparse.mul(att_mat, repeat_deg.view(-1, 1, num_heads)) + + # Add self-loop, also called renormalization trick + att_mat = torch_sparse.fill_diag(att_mat, fill_value) + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0.) + # repeat_interleave for any dtype tensor + repeat_deg_tilde_inv = deg_tilde_inv_sqrt.view(-1).unsqueeze(1).repeat( + 1, num_heads).flatten() + repeat_deg_tilde_inv = repeat_deg_tilde_inv + + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(-1, 1, num_heads)) + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(1, -1, num_heads)) + # should never be None, only for typing purpose + alpha = att_mat.storage.value() + return att_mat.set_value( + None), alpha if alpha is not None else edge_weight + + if is_torch_sparse_tensor(edge_index): + assert edge_index.size(0) == edge_index.size(1) + + if edge_index.layout == torch.sparse_csc: + raise NotImplementedError("Sparse CSC matrices are not yet " + "supported in 'gat_norm'") + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + edge_index, value = to_edge_index(edge_index) + + col, row = edge_index[0], edge_index[1] + idx = col if flow == 'source_to_target' else row + deg = scatter(value, idx, 0, dim_size=num_nodes, reduce='sum') + att_mat = deg[col].view(-1, 1) * edge_weight + + # Add self-loop, also called renormalization trick + edge_index, att_mat = add_remaining_self_loops(edge_index, att_mat, + fill_value=fill_value, + num_nodes=num_nodes) + col, row = edge_index[0], edge_index[1] + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0) + att_mat = deg_tilde_inv_sqrt[row].view( + -1, 1) * att_mat * deg_tilde_inv_sqrt[col].view(-1, 1) + + # Sort edge_index lexicographically + sorted_indices = torch.argsort(edge_index[0] * edge_index.size(1) + + edge_index[1]) + edge_index = edge_index[:, sorted_indices] + att_mat = att_mat[sorted_indices] + + return to_torch_csr_tensor(edge_index), att_mat + + num_nodes = 
maybe_num_nodes(edge_index, num_nodes) + + adj_t = torch.ones((edge_index.size(1), ), dtype=dtype, + device=edge_index.device) + + row, col = edge_index[0], edge_index[1] + idx = col if flow == 'source_to_target' else row + deg = scatter(adj_t, idx, dim=0, dim_size=num_nodes, reduce='sum') + att_mat = deg[col].view(-1, + 1) * edge_weight # unnormalized attention matrix + + # Add self-loop, also called renormalization trick + adj_t, att_mat = add_self_loops(edge_index, att_mat, fill_value=fill_value, + num_nodes=num_nodes) + row, col = adj_t[0], adj_t[1] + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0) + att_mat = deg_tilde_inv_sqrt[row].unsqueeze( + 1) * att_mat * deg_tilde_inv_sqrt[col].unsqueeze(1) + + return adj_t, att_mat + + class GATConv(MessagePassing): r"""The graph attentional operator from the `"Graph Attention Networks" `_ paper. @@ -76,6 +210,21 @@ class GATConv(MessagePassing): If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. + Normalization will be computed as presented in + `"Bag of Tricks for Node Classification with Graph Neural Networks" + `__ paper. + + .. math:: + \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}_{att}} + \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}_{att}} = A_{att} + \mathbf{I}` denotes the + unnormalized attention matrix :math:`\mathbf{\hat{A}_{att}} + = \mathbf{D}\mathbf{\alpha}` + with inserted self-loops, :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` + its diagonal degree matrix based on the adjacency matrix with inserted + self-loops :math:`\hat{A} = A + I`. + Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. @@ -109,6 +258,9 @@ class GATConv(MessagePassing): an additive bias. (default: :obj:`True`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) + normalize (bool, optional): If set to :obj:`True`, will add + self-loops to the input graph and compute symmetric normalization + coefficients on-the-fly. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
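
To make the normalization described in the docstring above concrete, here is a minimal standalone sketch (not part of the change itself) that builds the normalized attention matrix densely on a toy graph with plain PyTorch. The edge list, the random `alpha` values, and the single attention head are illustrative assumptions only; the real `gat_norm` helper works on sparse edge representations and multiple heads.

```python
import torch

# Toy undirected 4-node graph (both edge directions listed, no self-loops):
edge_index = torch.tensor([[0, 1, 1, 2, 1, 3],
                           [1, 0, 2, 1, 3, 1]])
num_nodes, num_edges = 4, edge_index.size(1)

# Per-edge attention coefficients alpha_ij as a GAT layer would produce
# (random positive placeholders, single head for brevity):
alpha = torch.rand(num_edges)

# Node degrees of A (without self-loops):
deg = torch.zeros(num_nodes).scatter_add_(
    0, edge_index[1], torch.ones(num_edges))

# Unnormalized attention matrix A_att = D * alpha, assembled densely:
att = torch.zeros(num_nodes, num_nodes)
att[edge_index[1], edge_index[0]] = deg[edge_index[1]] * alpha

# Renormalization trick: insert self-loops, then scale symmetrically by the
# inverse square root of the self-loop-augmented degree D_tilde = D + I:
att_hat = att + torch.eye(num_nodes)
d_inv_sqrt = (deg + 1.0).pow(-0.5)
norm_att = d_inv_sqrt.view(-1, 1) * att_hat * d_inv_sqrt.view(1, -1)

# `norm_att @ x @ theta` then corresponds to the docstring update
# X' = D_tilde^{-1/2} (D * alpha + I) D_tilde^{-1/2} X Theta.
```
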
@@ -140,6 +292,7 @@ def __init__( fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, residual: bool = False, + normalize: bool = False, **kwargs, ): kwargs.setdefault('aggr', 'add') @@ -155,6 +308,7 @@ def __init__( self.edge_dim = edge_dim self.fill_value = fill_value self.residual = residual + self.normalize = normalize # In case we are operating in bipartite graphs, we apply separate # transformations 'lin_src' and 'lin_dst' to source and target nodes: @@ -336,7 +490,9 @@ def forward( # noqa: F811 alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) alpha = (alpha_src, alpha_dst) - if self.add_self_loops: + # Skipping add_self_loops when normalize + # as we are performing it on gat_norm function + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): # We only want to add self-loops for nodes that appear both as # source and target nodes: @@ -358,10 +514,34 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") + + if isinstance(edge_index, Tensor): + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + elif isinstance(edge_index, SparseTensor): + edge_index = torch_sparse.remove_diag(edge_index) + # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, size=size) + if self.normalize: + num_nodes: Optional[int] = None + if isinstance(edge_index, Tensor): + num_nodes = x_src.size(0) + if x_dst is not None: + num_nodes = min(num_nodes, x_dst.size(0)) + num_nodes = min(size) if size is not None else num_nodes + + edge_index, alpha = gat_norm(edge_index, alpha, + num_nodes=num_nodes, flow=self.flow, + dtype=alpha.dtype) # noqa: F811 + # propagate_type: (x: OptPairTensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha, size=size) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index f3b2f4937e52..f6dcde265bf0 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -7,6 +7,7 @@ from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.conv.gat_conv import gat_norm from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import ( @@ -112,6 +113,9 @@ class GATv2Conv(MessagePassing): (default: :obj:`False`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) + normalize (bool, optional): If set to :obj:`True`, will add + self-loops to the input graph and compute symmetric normalization + coefficients on-the-fly. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
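
A short usage sketch of the new `normalize` option documented above, assuming a build that includes this change is installed. The toy tensors mirror the ones used in the tests, and the commented-out bipartite call illustrates the `NotImplementedError` raised when `normalize=True` is combined with bipartite inputs:

```python
import torch
from torch_geometric.nn import GATConv, GATv2Conv

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [0, 0, 1, 1]])

# Symmetrically normalized attention, self-loops handled inside the layer:
conv = GATConv(8, 32, heads=2, normalize=True)
out = conv(x, edge_index)            # shape [4, 2 * 32]

conv_v2 = GATv2Conv(8, 32, heads=2, normalize=True)
out_v2 = conv_v2(x, edge_index)      # shape [4, 2 * 32]

# Bipartite message passing is rejected when `normalize=True`:
# GATConv((8, 16), 32, heads=2, normalize=True)((x_src, x_dst), edge_index)
# -> NotImplementedError
```
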
@@ -144,6 +148,7 @@ def __init__( bias: bool = True, share_weights: bool = False, residual: bool = False, + normalize: bool = False, **kwargs, ): super().__init__(node_dim=0, **kwargs) @@ -159,6 +164,7 @@ def __init__( self.fill_value = fill_value self.residual = residual self.share_weights = share_weights + self.normalize = normalize if isinstance(in_channels, int): self.lin_l = Linear(in_channels, heads * out_channels, bias=bias, @@ -302,7 +308,7 @@ def forward( # noqa: F811 assert x_l is not None assert x_r is not None - if self.add_self_loops: + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): num_nodes = x_l.size(0) if x_r is not None: @@ -321,10 +327,34 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") + + if isinstance(edge_index, Tensor): + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + elif isinstance(edge_index, SparseTensor): + edge_index = torch_sparse.remove_diag(edge_index) + # edge_updater_type: (x: PairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, x=(x_l, x_r), edge_attr=edge_attr) + if self.normalize: + num_nodes = None + if isinstance(edge_index, Tensor): + num_nodes = x_l.size(0) + if x_r is not None: + num_nodes = min(num_nodes, x_r.size(0)) + edge_index, alpha = gat_norm(edge_index, + alpha, + num_nodes=num_nodes, + flow=self.flow, + dtype=alpha.dtype) # yapf: disable + # propagate_type: (x: PairTensor, alpha: Tensor) out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 468f37abfaed..3698674dd6be 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -242,6 +242,10 @@ def fill_diag(src: SparseTensor, fill_value: float, k: int = 0) -> SparseTensor: raise ImportError("'fill_diag' requires 'torch-sparse'") + @staticmethod + def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: + raise ImportError("'remove_diag' requires 'torch-sparse'") + @staticmethod def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor:
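
For completeness, a minimal sketch of calling the new `gat_norm` helper directly, as the tests earlier in this change do. This assumes a build with this change installed and relies on the return convention visible in the implementation above: the edge index with self-loops inserted and the rescaled attention coefficients.

```python
import torch
from torch_geometric.nn.conv.gat_conv import gat_norm

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 1, 1]])

# One coefficient per edge and per head (2 heads here), e.g. the output of
# `edge_updater` inside `GATConv.forward`:
alpha = torch.rand(edge_index.size(1), 2)

# Returns the edge index with self-loops added and the symmetrically
# normalized attention coefficients:
edge_index, alpha = gat_norm(edge_index, alpha, num_nodes=4)
```
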