From cdd3a63e45ba3b508cd29df0e2b3079136f4e8f0 Mon Sep 17 00:00:00 2001 From: elilaird Date: Wed, 14 Feb 2024 10:25:59 -0600 Subject: [PATCH 1/9] added log_softmax to utils with tests --- test/utils/test_log_softmax.py | 156 ++++++++++++++++++ torch_geometric/typing.py | 68 ++++---- torch_geometric/utils/__init__.py | 218 ++++++++++++++------------ torch_geometric/utils/_log_softmax.py | 87 ++++++++++ 4 files changed, 400 insertions(+), 129 deletions(-) create mode 100644 test/utils/test_log_softmax.py create mode 100644 torch_geometric/utils/_log_softmax.py diff --git a/test/utils/test_log_softmax.py b/test/utils/test_log_softmax.py new file mode 100644 index 000000000000..feda7bbdab89 --- /dev/null +++ b/test/utils/test_log_softmax.py @@ -0,0 +1,156 @@ +import pytest +import torch + +import torch_geometric.typing +from torch_geometric.profile import benchmark +from torch_geometric.utils import log_softmax + +CALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_LOG_SOFTMAX + or torch_geometric.typing.WITH_TORCH_SCATTER) + +ATOL = 1e-4 +RTOL = 1e-4 + + +def test_log_softmax(): + src = torch.tensor([1.0, 1.0, 1.0, 1.0]) + index = torch.tensor([0, 0, 1, 2]) + ptr = torch.tensor([0, 2, 3, 4]) + + out = log_softmax(src, index) + assert torch.allclose(out, torch.tensor([-0.6931, -0.6931, 0.0000, + 0.0000]), atol=ATOL, rtol=RTOL) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose(log_softmax(src, ptr=ptr), out, atol=ATOL, + rtol=RTOL) + else: + with pytest.raises(NotImplementedError, match="requires 'index'"): + log_softmax(src, ptr=ptr) + + src = src.view(-1, 1) + out = log_softmax(src, index) + assert torch.allclose( + out, + torch.tensor([[-0.6931], [-0.6931], [0.0000], [0.0000]]), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose(log_softmax(src, None, ptr), out, atol=ATOL, + rtol=RTOL) + + jit = torch.jit.script(log_softmax) + assert torch.allclose(jit(src, index), out, atol=ATOL, rtol=RTOL) + + +def test_log_softmax_backward(): + src_sparse = torch.rand(4, 8, requires_grad=True) + index = torch.tensor([0, 0, 1, 1]) + src_dense = src_sparse.clone().detach().view(2, 2, src_sparse.size(-1)) + src_dense.requires_grad_(True) + + out_sparse = log_softmax(src_sparse, index) + out_sparse.sum().backward() + out_dense = torch.log_softmax(src_dense, dim=1) + out_dense.sum().backward() + + assert torch.allclose(out_sparse, out_dense.view_as(out_sparse), atol=ATOL) + assert torch.allclose(src_sparse.grad, src_dense.grad.view_as(src_sparse), + atol=ATOL) + + +def test_log_softmax_dim(): + index = torch.tensor([0, 0, 0, 0]) + ptr = torch.tensor([0, 4]) + + src = torch.randn(4) + assert torch.allclose( + log_softmax(src, index, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + + src = torch.randn(4, 16) + assert torch.allclose( + log_softmax(src, index, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + + src = torch.randn(4, 4) + assert torch.allclose( + log_softmax(src, index, dim=-1), + torch.log_softmax(src, dim=-1), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=-1), + torch.log_softmax(src, dim=-1), + atol=ATOL, 
+ rtol=RTOL, + ) + + src = torch.randn(4, 4, 16) + assert torch.allclose( + log_softmax(src, index, dim=1), + torch.log_softmax(src, dim=1), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=1), + torch.log_softmax(src, dim=1), + atol=ATOL, + rtol=RTOL, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--backward", action="store_true") + args = parser.parse_args() + + num_nodes, num_edges = 10_000, 200_000 + x = torch.randn(num_edges, 64, device=args.device) + index = torch.randint(num_nodes, (num_edges, ), device=args.device) + + compiled_log_softmax = torch.compile(log_softmax) + + def dense_softmax(x, index): + x = x.view(num_nodes, -1, x.size(-1)) + return x.softmax(dim=-1) + + def dense_log_softmax(x, index): + x = x.view(num_nodes, -1, x.size(-1)) + return torch.log_softmax(x, dim=-1) + + benchmark( + funcs=[dense_log_softmax, log_softmax, compiled_log_softmax], + func_names=["Dense Log Softmax", "Vanilla", "Compiled"], + args=(x, index), + num_steps=50 if args.device == "cpu" else 500, + num_warmups=10 if args.device == "cpu" else 100, + backward=args.backward, + ) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index e7a6a90f82e7..be80f078408b 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -8,46 +8,48 @@ import torch from torch import Tensor -WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 -WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1 -WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2 -WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3 -WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11 -WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12 -WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13 +WITH_PT20 = int(torch.__version__.split(".")[0]) >= 2 +WITH_PT21 = WITH_PT20 and int(torch.__version__.split(".")[1]) >= 1 +WITH_PT22 = WITH_PT20 and int(torch.__version__.split(".")[1]) >= 2 +WITH_PT23 = WITH_PT20 and int(torch.__version__.split(".")[1]) >= 3 +WITH_PT111 = WITH_PT20 or int(torch.__version__.split(".")[1]) >= 11 +WITH_PT112 = WITH_PT20 or int(torch.__version__.split(".")[1]) >= 12 +WITH_PT113 = WITH_PT20 or int(torch.__version__.split(".")[1]) >= 13 -WITH_WINDOWS = os.name == 'nt' +WITH_WINDOWS = os.name == "nt" MAX_INT64 = torch.iinfo(torch.int64).max -if not hasattr(torch, 'sparse_csc'): +if not hasattr(torch, "sparse_csc"): torch.sparse_csc = torch.sparse_coo try: import pyg_lib # noqa + WITH_PYG_LIB = True - WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul') - WITH_SEGMM = hasattr(pyg_lib.ops, 'segment_matmul') - if WITH_SEGMM and 'pytest' in sys.modules and torch.cuda.is_available(): + WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, "grouped_matmul") + WITH_SEGMM = hasattr(pyg_lib.ops, "segment_matmul") + if WITH_SEGMM and "pytest" in sys.modules and torch.cuda.is_available(): # NOTE `segment_matmul` is currently bugged on older NVIDIA cards which # let our GPU tests on CI crash. Try if this error is present on the # current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary. # TODO Drop this code block once `segment_matmul` is fixed. 
try: - x = torch.randn(3, 4, device='cuda') - ptr = torch.tensor([0, 2, 3], device='cuda') - weight = torch.randn(2, 4, 4, device='cuda') + x = torch.randn(3, 4, device="cuda") + ptr = torch.tensor([0, 2, 3], device="cuda") + weight = torch.randn(2, 4, 4, device="cuda") out = pyg_lib.ops.segment_matmul(x, ptr, weight) except RuntimeError: WITH_GMM = False WITH_SEGMM = False - WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add') - WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr') - WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort') - WITH_METIS = hasattr(pyg_lib, 'partition') - WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature( + WITH_SAMPLED_OP = hasattr(pyg_lib.ops, "sampled_add") + WITH_SOFTMAX = hasattr(pyg_lib.ops, "softmax_csr") + WITH_LOG_SOFTMAX = hasattr(pyg_lib.ops, "log_softmax_csr") + WITH_INDEX_SORT = hasattr(pyg_lib.ops, "index_sort") + WITH_METIS = hasattr(pyg_lib, "partition") + WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ("edge_time" in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) - WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature( + WITH_WEIGHTED_NEIGHBOR_SAMPLE = ("edge_weight" in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -59,6 +61,7 @@ WITH_SEGMM = False WITH_SAMPLED_OP = False WITH_SOFTMAX = False + WITH_LOG_SOFTMAX = False WITH_INDEX_SORT = False WITH_METIS = False WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False @@ -66,6 +69,7 @@ try: import torch_scatter # noqa + WITH_TORCH_SCATTER = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -76,8 +80,9 @@ try: import torch_cluster # noqa + WITH_TORCH_CLUSTER = True - WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__ + WITH_TORCH_CLUSTER_BATCH_SIZE = "batch_size" in torch_cluster.knn.__doc__ except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn(f"An issue occurred while importing 'torch-cluster'. 
" @@ -93,6 +98,7 @@ def __getattr__(self, key: str) -> Any: try: import torch_spline_conv # noqa + WITH_TORCH_SPLINE_CONV = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -104,6 +110,7 @@ def __getattr__(self, key: str) -> Any: try: import torch_sparse # noqa from torch_sparse import SparseStorage, SparseTensor + WITH_TORCH_SPARSE = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -156,7 +163,7 @@ def from_edge_index( sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, is_sorted: bool = False, trust_data: bool = False, - ) -> 'SparseTensor': + ) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") @property @@ -165,7 +172,7 @@ def storage(self) -> SparseStorage: @classmethod def from_dense(self, mat: Tensor, - has_value: bool = True) -> 'SparseTensor': + has_value: bool = True) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def size(self, dim: int) -> int: @@ -181,11 +188,11 @@ def has_value(self) -> bool: raise ImportError("'SparseTensor' requires 'torch-sparse'") def set_value(self, value: Optional[Tensor], - layout: Optional[str] = None) -> 'SparseTensor': + layout: Optional[str] = None) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def fill_value(self, fill_value: float, - dtype: Optional[torch.dtype] = None) -> 'SparseTensor': + dtype: Optional[torch.dtype] = None) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: @@ -235,6 +242,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, try: import torch_frame # noqa + WITH_TORCH_FRAME = True from torch_frame import TensorFrame except Exception: @@ -247,6 +255,7 @@ class TensorFrame: # type: ignore try: import intel_extension_for_pytorch # noqa + WITH_IPEX = True except Exception: WITH_IPEX = False @@ -265,6 +274,7 @@ def __init__( def t(self) -> Tensor: # Only support accessing its transpose: from torch_geometric.utils import to_torch_csr_tensor + size = self.size return to_torch_csr_tensor( self.edge_index.flip([0]), @@ -284,15 +294,15 @@ def t(self) -> Tensor: # Only support accessing its transpose: NodeOrEdgeType = Union[NodeType, EdgeType] -DEFAULT_REL = 'to' -EDGE_TYPE_STR_SPLIT = '__' +DEFAULT_REL = "to" +EDGE_TYPE_STR_SPLIT = "__" class EdgeTypeStr(str): r"""A helper class to construct serializable edge types by merging an edge type tuple into a single string. 
""" - def __new__(cls, *args: Any) -> 'EdgeTypeStr': + def __new__(cls, *args: Any) -> "EdgeTypeStr": if isinstance(args[0], (list, tuple)): # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: args = tuple(args[0]) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index e00d169e4086..7bae4208d5a9 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -8,16 +8,21 @@ from .functions import cumsum from ._degree import degree from ._softmax import softmax +from ._log_softmax import log_softmax from ._lexsort import lexsort from ._sort_edge_index import sort_edge_index from ._coalesce import coalesce from .undirected import is_undirected, to_undirected -from .loop import (contains_self_loops, remove_self_loops, - segregate_self_loops, add_self_loops, - add_remaining_self_loops, get_self_loop_attr) +from .loop import ( + contains_self_loops, + remove_self_loops, + segregate_self_loops, + add_self_loops, + add_remaining_self_loops, + get_self_loop_attr, +) from .isolated import contains_isolated_nodes, remove_isolated_nodes -from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph, - bipartite_subgraph) +from ._subgraph import get_num_hops, subgraph, k_hop_subgraph, bipartite_subgraph from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path from ._homophily import homophily from ._assortativity import assortativity @@ -28,10 +33,16 @@ from ._to_dense_batch import to_dense_batch from ._to_dense_adj import to_dense_adj from .nested import to_nested_tensor, from_nested_tensor -from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor, - to_torch_coo_tensor, to_torch_csr_tensor, - to_torch_csc_tensor, to_torch_sparse_tensor, - to_edge_index) +from .sparse import ( + dense_to_sparse, + is_sparse, + is_torch_sparse_tensor, + to_torch_coo_tensor, + to_torch_csr_tensor, + to_torch_csc_tensor, + to_torch_sparse_tensor, + to_edge_index, +) from ._spmm import spmm from ._unbatch import unbatch, unbatch_edge_index from ._one_hot import one_hot @@ -45,11 +56,17 @@ from .convert import to_cugraph, from_cugraph from .convert import to_dgl, from_dgl from .smiles import from_smiles, to_smiles -from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, - barabasi_albert_graph) -from ._negative_sampling import (negative_sampling, batched_negative_sampling, - structured_negative_sampling, - structured_negative_sampling_feasible) +from .random import ( + erdos_renyi_graph, + stochastic_blockmodel_graph, + barabasi_albert_graph, +) +from ._negative_sampling import ( + negative_sampling, + batched_negative_sampling, + structured_negative_sampling, + structured_negative_sampling_feasible, +) from .augmentation import shuffle_node, mask_feature, add_random_edge from ._tree_decomposition import tree_decomposition from .embedding import get_embeddings @@ -58,94 +75,95 @@ from ._train_test_split_edges import train_test_split_edges __all__ = [ - 'scatter', - 'group_argsort', - 'segment', - 'index_sort', - 'cumsum', - 'degree', - 'softmax', - 'lexsort', - 'sort_edge_index', - 'coalesce', - 'is_undirected', - 'to_undirected', - 'contains_self_loops', - 'remove_self_loops', - 'segregate_self_loops', - 'add_self_loops', - 'add_remaining_self_loops', - 'get_self_loop_attr', - 'contains_isolated_nodes', - 'remove_isolated_nodes', - 'get_num_hops', - 'subgraph', - 'bipartite_subgraph', - 'k_hop_subgraph', - 'dropout_node', - 'dropout_edge', - 'dropout_path', - 'dropout_adj', - 'homophily', - 
'assortativity', - 'get_laplacian', - 'get_mesh_laplacian', - 'mask_select', - 'index_to_mask', - 'mask_to_index', - 'select', - 'narrow', - 'to_dense_batch', - 'to_dense_adj', - 'to_nested_tensor', - 'from_nested_tensor', - 'dense_to_sparse', - 'is_torch_sparse_tensor', - 'is_sparse', - 'to_torch_coo_tensor', - 'to_torch_csr_tensor', - 'to_torch_csc_tensor', - 'to_torch_sparse_tensor', - 'to_edge_index', - 'spmm', - 'unbatch', - 'unbatch_edge_index', - 'one_hot', - 'normalized_cut', - 'grid', - 'geodesic_distance', - 'to_scipy_sparse_matrix', - 'from_scipy_sparse_matrix', - 'to_networkx', - 'from_networkx', - 'to_networkit', - 'from_networkit', - 'to_trimesh', - 'from_trimesh', - 'to_cugraph', - 'from_cugraph', - 'to_dgl', - 'from_dgl', - 'from_smiles', - 'to_smiles', - 'erdos_renyi_graph', - 'stochastic_blockmodel_graph', - 'barabasi_albert_graph', - 'negative_sampling', - 'batched_negative_sampling', - 'structured_negative_sampling', - 'structured_negative_sampling_feasible', - 'shuffle_node', - 'mask_feature', - 'add_random_edge', - 'tree_decomposition', - 'get_embeddings', - 'trim_to_layer', - 'get_ppr', - 'train_test_split_edges', + "scatter", + "group_argsort", + "segment", + "index_sort", + "cumsum", + "degree", + "softmax", + "log_softmax" + "lexsort", + "sort_edge_index", + "coalesce", + "is_undirected", + "to_undirected", + "contains_self_loops", + "remove_self_loops", + "segregate_self_loops", + "add_self_loops", + "add_remaining_self_loops", + "get_self_loop_attr", + "contains_isolated_nodes", + "remove_isolated_nodes", + "get_num_hops", + "subgraph", + "bipartite_subgraph", + "k_hop_subgraph", + "dropout_node", + "dropout_edge", + "dropout_path", + "dropout_adj", + "homophily", + "assortativity", + "get_laplacian", + "get_mesh_laplacian", + "mask_select", + "index_to_mask", + "mask_to_index", + "select", + "narrow", + "to_dense_batch", + "to_dense_adj", + "to_nested_tensor", + "from_nested_tensor", + "dense_to_sparse", + "is_torch_sparse_tensor", + "is_sparse", + "to_torch_coo_tensor", + "to_torch_csr_tensor", + "to_torch_csc_tensor", + "to_torch_sparse_tensor", + "to_edge_index", + "spmm", + "unbatch", + "unbatch_edge_index", + "one_hot", + "normalized_cut", + "grid", + "geodesic_distance", + "to_scipy_sparse_matrix", + "from_scipy_sparse_matrix", + "to_networkx", + "from_networkx", + "to_networkit", + "from_networkit", + "to_trimesh", + "from_trimesh", + "to_cugraph", + "from_cugraph", + "to_dgl", + "from_dgl", + "from_smiles", + "to_smiles", + "erdos_renyi_graph", + "stochastic_blockmodel_graph", + "barabasi_albert_graph", + "negative_sampling", + "batched_negative_sampling", + "structured_negative_sampling", + "structured_negative_sampling_feasible", + "shuffle_node", + "mask_feature", + "add_random_edge", + "tree_decomposition", + "get_embeddings", + "trim_to_layer", + "get_ppr", + "train_test_split_edges", ] # `structured_negative_sampling_feasible` is a long name and thus destroys the # documentation rendering. 
We remove it for now from the documentation: classes = copy.copy(__all__) -classes.remove('structured_negative_sampling_feasible') +classes.remove("structured_negative_sampling_feasible") diff --git a/torch_geometric/utils/_log_softmax.py b/torch_geometric/utils/_log_softmax.py new file mode 100644 index 000000000000..79667e3c014d --- /dev/null +++ b/torch_geometric/utils/_log_softmax.py @@ -0,0 +1,87 @@ +import torch + +from torch_geometric.utils import scatter, segment +from torch_geometric.utils.num_nodes import maybe_num_nodes + + +def log_softmax( + src: torch.Tensor, + index: torch.Optional[torch.Tensor] = None, + ptr: torch.Optional[torch.Tensor] = None, + num_nodes: torch.Optional[int] = None, + dim: int = 0, +) -> torch.Tensor: + r"""Computes a sparsely evaluated log_softmax. + + Given a value tensor :attr:`src`, this function first groups the values + along the specified dimension based on the indices specified in :attr:`index` + or sorted inputs in CSR representation given by :attr:`ptr`, and then proceeds + to compute the log_softmax individually for each group. + + The log_softmax operation is defined as the logarithm of the softmax + probabilities, which can provide numerical stability improvements over + separately computing softmax followed by a logarithm. + + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + log_softmax. When specified, `src` values are grouped by `index` to + compute the log_softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the log_softmax based on + sorted inputs in CSR representation. This allows for efficient + computation over contiguous ranges of nodes. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.*, the maximum value + + 1 of :attr:`index`. This is required when `index` is specified to + determine the dimension size for scattering operations. + (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. For most use + cases, this should be set to 0, as log_softmax is typically applied + along the node dimension in graph neural networks. 
(default: :obj:`0`)
+
+    :rtype: :class:`Tensor`
+
+    Examples:
+        >>> src = torch.tensor([1., 2., 3., 4.])
+        >>> index = torch.tensor([0, 0, 1, 1])
+        >>> ptr = torch.tensor([0, 2, 4])
+        >>> log_softmax(src, index)
+        tensor([-1.3133, -0.3133, -1.3133, -0.3133])
+
+        >>> log_softmax(src, None, ptr)
+        tensor([-1.3133, -0.3133, -1.3133, -0.3133])
+
+        >>> src = torch.randn(4, 4)
+        >>> ptr = torch.tensor([0, 4])
+        >>> log_softmax(src, index, dim=-1)
+        tensor([[-1.3130, -0.6931, -0.3130, -1.3130],
+                [-1.0408, -0.0408, -0.0408, -1.0408],
+                [-0.5514, -0.5514, -0.1542, -0.5514],
+                [-0.7520, -0.7520, -0.1542, -0.7520]])
+    """
+    if ptr is not None:
+        dim = dim + src.dim() if dim < 0 else dim
+        size = ([1] * dim) + [-1]
+        count = ptr[1:] - ptr[:-1]
+        ptr = ptr.view(size)
+        src_max = segment(src.detach(), ptr, reduce="max")
+        src_max = src_max.repeat_interleave(count, dim=dim)
+        out = src - src_max
+        out_exp = out.exp()
+        out_sum = segment(out_exp, ptr, reduce="sum") + 1e-16
+        out_sum = out_sum.repeat_interleave(count, dim=dim)
+        log_out_sum = out_sum.log()
+        out = out - log_out_sum
+    elif index is not None:
+        N = maybe_num_nodes(index, num_nodes)
+        src_max = scatter(src.detach(), index, dim, dim_size=N, reduce="max")
+        out = src - src_max.index_select(dim, index)
+        out_exp = out.exp()
+        out_sum = scatter(out_exp, index, dim, dim_size=N,
+                          reduce="sum") + 1e-16
+        out_sum = out_sum.index_select(dim, index)
+        log_out_sum = out_sum.log()
+        out = out - log_out_sum
+    else:
+        raise NotImplementedError
+
+    return out

From dc2e6c5cbfb4e4f4ced9035d2c7c9bf32407640c Mon Sep 17 00:00:00 2001
From: elilaird
Date: Wed, 14 Feb 2024 10:42:14 -0600
Subject: [PATCH 2/9] added changelog entry

---
 CHANGELOG.md | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 454028d5a23a..be29e9a72f89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added implementation of `log_softmax` in `torch_geometric.utils` ([#8909](https://github.com/pyg-team/pytorch_geometric/pull/8909))
 - Added an example for recommender systems, including k-NN search and retrieval metrics ([#8546](https://github.com/pyg-team/pytorch_geometric/pull/8546))
 - Added multi-GPU evaluation in distributed sampling example ([#8880](https://github.com/pyg-team/pytorch_geometric/pull/8880))
 - Added end-to-end example for distributed CPU training ([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713))
@@ -141,7 +142,7 @@
- Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925)) - Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917)) - Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) -- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) +- Added support for floating-point slicing in `Dataset`, _e.g._, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895)) - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827)) - Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864)) @@ -159,7 +160,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656)) - Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649)) - Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647)) -- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) +- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) - Added a `LightGCN` example on the `AmazonBook` dataset ([7603](https://github.com/pyg-team/pytorch_geometric/pull/7603)) - Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594)) - Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601)) @@ -219,7 +220,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) +- Fixed `HeteroConv` for layers that have a non-default argument order, _e.g._, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) - Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163)) - Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145)) - Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143)) @@ -228,7 +229,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942)) - Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737)) - Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955)) -- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956) +- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956) - Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953)) - Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941)) - Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923)) From 41719c9bfb317bb4fae7c7efcbd6590d8a926220 Mon Sep 17 00:00:00 2001 From: elilaird Date: Wed, 14 Feb 2024 10:56:20 -0600 Subject: [PATCH 3/9] fixed missing comma and formatted --- torch_geometric/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 7bae4208d5a9..35db2ded1b06 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -82,7 +82,7 @@ "cumsum", "degree", "softmax", - "log_softmax" + "log_softmax", "lexsort", "sort_edge_index", "coalesce", From 7e07bfe2a60f4ab0c1ac99982d34a4f0e9aabc67 Mon Sep 17 00:00:00 2001 From: elilaird Date: Wed, 14 Feb 2024 11:03:26 -0600 Subject: [PATCH 4/9] replace torch.Optional with typing.Optional --- torch_geometric/utils/_log_softmax.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/torch_geometric/utils/_log_softmax.py b/torch_geometric/utils/_log_softmax.py index 79667e3c014d..dedac00db87f 100644 --- a/torch_geometric/utils/_log_softmax.py +++ b/torch_geometric/utils/_log_softmax.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch_geometric.utils import scatter, segment @@ -6,17 +8,17 @@ def log_softmax( src: torch.Tensor, - index: torch.Optional[torch.Tensor] = None, - ptr: torch.Optional[torch.Tensor] = None, - num_nodes: torch.Optional[int] = None, + index: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + num_nodes: Optional[int] = None, dim: int = 0, ) -> torch.Tensor: r"""Computes a sparsely evaluated log_softmax. Given a value tensor :attr:`src`, this function first groups the values - along the specified dimension based on the indices specified in :attr:`index` - or sorted inputs in CSR representation given by :attr:`ptr`, and then proceeds - to compute the log_softmax individually for each group. + along the specified dimension based on the indices specified in + :attr:`index` or sorted inputs in CSR representation given by :attr:`ptr`, + and then proceeds to compute the log_softmax individually for each group. 
The log_softmax operation is defined as the logarithm of the softmax probabilities, which can provide numerical stability improvements over @@ -30,13 +32,14 @@ def log_softmax( ptr (LongTensor, optional): If given, computes the log_softmax based on sorted inputs in CSR representation. This allows for efficient computation over contiguous ranges of nodes. (default: :obj:`None`) - num_nodes (int, optional): The number of nodes, *i.e.*, the maximum value - + 1 of :attr:`index`. This is required when `index` is specified to - determine the dimension size for scattering operations. + num_nodes (int, optional): The number of nodes, *i.e.*, the maximum + value + 1 of :attr:`index`. This is required when `index` is + specified to determine the dimension for scattering operations. (default: :obj:`None`) dim (int, optional): The dimension in which to normalize. For most use cases, this should be set to 0, as log_softmax is typically applied - along the node dimension in graph neural networks. (default: :obj:`0`) + along the node dimension in graph neural networks. + (default: :obj:`0`) :rtype: :class:`Tensor` From f63f1b4b19dd77d94c7432eb31138ebec6a73ccf Mon Sep 17 00:00:00 2001 From: elilaird Date: Wed, 14 Feb 2024 11:11:11 -0600 Subject: [PATCH 5/9] fixed tab error in docstring for documentation --- torch_geometric/utils/_log_softmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/utils/_log_softmax.py b/torch_geometric/utils/_log_softmax.py index dedac00db87f..6e7310ad0797 100644 --- a/torch_geometric/utils/_log_softmax.py +++ b/torch_geometric/utils/_log_softmax.py @@ -18,7 +18,7 @@ def log_softmax( Given a value tensor :attr:`src`, this function first groups the values along the specified dimension based on the indices specified in :attr:`index` or sorted inputs in CSR representation given by :attr:`ptr`, - and then proceeds to compute the log_softmax individually for each group. + and then proceeds to compute the log_softmax individually for each group. 
The log_softmax operation is defined as the logarithm of the softmax probabilities, which can provide numerical stability improvements over From 8c28f407c7d24c98e9b03464f954121aacf6a93a Mon Sep 17 00:00:00 2001 From: elilaird Date: Wed, 14 Feb 2024 12:26:59 -0600 Subject: [PATCH 6/9] fixed nonimplementederror and scatter compile check --- torch_geometric/utils/_log_softmax.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_geometric/utils/_log_softmax.py b/torch_geometric/utils/_log_softmax.py index 6e7310ad0797..38f54baadfc1 100644 --- a/torch_geometric/utils/_log_softmax.py +++ b/torch_geometric/utils/_log_softmax.py @@ -2,6 +2,8 @@ import torch +import torch_geometric.typing +from torch_geometric import is_compiling from torch_geometric.utils import scatter, segment from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -61,7 +63,8 @@ def log_softmax( [-0.5514, -0.5514, -0.1542, -0.5514], [-0.7520, -0.7520, -0.1542, -0.7520]]) """ - if ptr is not None: + if (ptr is not None and torch_geometric.typing.WITH_TORCH_SCATTER + and not is_compiling()): dim = dim + src.dim() if dim < 0 else dim size = ([1] * dim) + [-1] count = ptr[1:] - ptr[:-1] @@ -85,6 +88,7 @@ def log_softmax( log_out_sum = out_sum.log() out = out - log_out_sum else: - raise NotImplementedError + raise NotImplementedError( + "'log_softmax' requires 'index' to be specified") return out From a838f7722e8b7caa9f7faceedca0f469d6bf726b Mon Sep 17 00:00:00 2001 From: elilaird Date: Fri, 1 Mar 2024 09:43:26 -0600 Subject: [PATCH 7/9] fixed flake8 formatting issue --- torch_geometric/utils/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 35db2ded1b06..7bd2a32092a4 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -22,7 +22,12 @@ get_self_loop_attr, ) from .isolated import contains_isolated_nodes, remove_isolated_nodes -from ._subgraph import get_num_hops, subgraph, k_hop_subgraph, bipartite_subgraph +from ._subgraph import ( + get_num_hops, + subgraph, + k_hop_subgraph, + bipartite_subgraph, +) from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path from ._homophily import homophily from ._assortativity import assortativity From ecc32576f7aaf92e57d4f498d71b7444d53ead96 Mon Sep 17 00:00:00 2001 From: Eli Laird Date: Mon, 17 Jun 2024 23:02:33 -0500 Subject: [PATCH 8/9] Update __init__.py group_cat not used. 
--- torch_geometric/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 5cf42c00875e..7bd2a32092a4 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -2,7 +2,7 @@ import copy -from ._scatter import scatter, group_argsort, group_cat +from ._scatter import scatter, group_argsort from ._segment import segment from ._index_sort import index_sort from .functions import cumsum From c48cbe93246a31c6dd3c084f1e98ce54dee29161 Mon Sep 17 00:00:00 2001 From: Eli Laird Date: Mon, 17 Jun 2024 23:13:36 -0500 Subject: [PATCH 9/9] Update __init__.py --- torch_geometric/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 7bd2a32092a4..7b7e0296ede6 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -2,7 +2,7 @@ import copy -from ._scatter import scatter, group_argsort +from ._scatter import scatter, group_argsort, group_cat from ._segment import segment from ._index_sort import index_sort from .functions import cumsum @@ -82,6 +82,7 @@ __all__ = [ "scatter", "group_argsort", + "group_cat", "segment", "index_sort", "cumsum",
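
For quick reference, the snippet below is a minimal usage sketch of the grouped `log_softmax` utility this series introduces. It is illustrative only: it mirrors the behavior exercised in `test/utils/test_log_softmax.py` and assumes a `torch_geometric` installation with these patches applied.

import torch

import torch_geometric.typing
from torch_geometric.utils import log_softmax

# Four source values split into two groups: {1., 2.} and {3., 4.}.
src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])
ptr = torch.tensor([0, 2, 4])  # the same grouping in CSR form

out = log_softmax(src, index)

# Each group matches a dense log-softmax computed over that group alone:
expected = torch.cat([
    torch.log_softmax(src[:2], dim=0),
    torch.log_softmax(src[2:], dim=0),
])
assert torch.allclose(out, expected, atol=1e-4)
print(out)  # tensor([-1.3133, -0.3133, -1.3133, -0.3133])

# The CSR (`ptr`) path is only taken when grouped segment reductions are
# available (see the `WITH_TORCH_SCATTER` check added in patch 6); without
# them, `index` must be supplied or a `NotImplementedError` is raised.
if torch_geometric.typing.WITH_TORCH_SCATTER:
    assert torch.allclose(log_softmax(src, ptr=ptr), out, atol=1e-4)

Computing the result directly in log-space (subtracting the per-group maximum and the log of the summed exponentials) is what gives the numerical-stability benefit over `softmax(...).log()` that the docstring describes.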