diff --git a/torch_geometric/nn/conv/hypergraph_conv.py b/torch_geometric/nn/conv/hypergraph_conv.py index 411330d3ec21..74db56af3678 100644 --- a/torch_geometric/nn/conv/hypergraph_conv.py +++ b/torch_geometric/nn/conv/hypergraph_conv.py @@ -174,9 +174,9 @@ def forward(self, x: Tensor, hyperedge_index: Tensor, alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) alpha = F.leaky_relu(alpha, self.negative_slope) if self.attention_mode == 'node': - alpha = softmax(alpha, hyperedge_index[1], num_nodes=x.size(0)) + alpha = softmax(alpha, hyperedge_index[1], num_nodes=num_edges) else: - alpha = softmax(alpha, hyperedge_index[0], num_nodes=x.size(0)) + alpha = softmax(alpha, hyperedge_index[0], num_nodes=num_nodes) alpha = F.dropout(alpha, p=self.dropout, training=self.training) D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0],