Skip to content

Commit

Permalink
Merge pull request #314 from laserkelvin/scatter-add-refactor
Browse files Browse the repository at this point in the history
Scatter add refactor
  • Loading branch information
laserkelvin authored Nov 8, 2024
2 parents 86c6261 + 6e6ccd1 commit 2caaee3
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 15 deletions.
16 changes: 8 additions & 8 deletions matsciml/models/pyg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,28 @@

# load models if we have PyG installed
if _has_pyg:
from matsciml.models.pyg.cgcnn import CGCNN
from matsciml.models.pyg.egnn import EGNN
from matsciml.models.pyg.faenet import FAENet
from matsciml.models.pyg.mace import MACE, ScaleShiftMACE

__all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE"]
__all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"]

# these packages need additional pyg dependencies
if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
from matsciml.models.pyg.dimenet import DimeNetWrap
from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap
from matsciml.models.pyg.dimenet import DimeNetWrap # noqa: F401
from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap # noqa: F401

__all__.extend(["DimeNetWrap", "DimeNetPlusPlusWrap"])
else:
logger.warning(
"Missing torch_sparse and torch_scatter; DimeNet models will not be available."
)
if package_registry["torch_scatter"]:
from matsciml.models.pyg.forcenet import ForceNet
from matsciml.models.pyg.schnet import SchNetWrap
from matsciml.models.pyg.faenet import FAENet
from matsciml.models.pyg.forcenet import ForceNet # noqa: F401
from matsciml.models.pyg.schnet import SchNetWrap # noqa: F401
from matsciml.models.pyg.cgcnn import CGCNN # noqa: F401

__all__.extend(["ForceNet", "SchNetWrap", "FAENet"])
__all__.extend(["ForceNet", "SchNetWrap", "FAENet", "CGCNN"])
else:
logger.warning(
"Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
Expand Down
7 changes: 2 additions & 5 deletions matsciml/models/pyg/faenet/layers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from __future__ import annotations

from typing import Tuple, Union

import pandas as pd
import torch
import torch.nn as nn
from mendeleev.fetch import fetch_ionization_energies, fetch_table
from torch import nn
from torch.nn import Embedding, Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.norm import GraphNorm
from torch_scatter import scatter

from matsciml.models.pyg.faenet.helper import *
from matsciml.models.pyg.scatter import scatter_sum


class PhysEmbedding(nn.Module):
Expand Down Expand Up @@ -508,7 +505,7 @@ def forward(
h = h * alpha

# Global pooling
out = scatter(h, batch, dim=0, reduce="add")
out = scatter_sum(h, batch, dim=0)

return out

Expand Down
178 changes: 178 additions & 0 deletions matsciml/models/pyg/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
###########################################################################################
# Implementation of MACE models and other models based on E(3)-Equivariant MPNNs
# (https://github.com/ACEsuit/mace)
# Original Authors: Ilyes Batatia, Gregor Simm
# Integrated into matsciml by Vaibhav Bihani, Sajid Mannan
# Refactors and improved docstrings by Kelvin Lee
# This program is distributed under the MIT License
###########################################################################################
"""basic scatter_sum operations from torch_scatter from
https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
PyTorch plans to move these features into the main repo, but until then,
to make installation simpler, we need this pure python set of wrappers
that don't require installing PyTorch C++ extensions.
See https://github.com/pytorch/pytorch/issues/63780.
"""

from __future__ import annotations

from typing import Optional

import torch

__all__ = ["scatter_sum", "scatter_std", "scatter_mean"]


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Expand ``src`` so that it has the same shape as ``other``.

    A one-dimensional ``src`` is first aligned with axis ``dim`` of
    ``other`` by prepending ``dim`` singleton axes. Trailing singleton
    axes are then appended until the ranks agree, and the result is
    expanded (without copying) to the shape of ``other``.

    Parameters
    ----------
    src : torch.Tensor
        Tensor to broadcast into a new shape.
    other : torch.Tensor
        Tensor whose shape is the broadcast target.
    dim : int
        Axis of ``other`` that a 1D ``src`` is aligned with; negative
        values are counted from the last axis of ``other``.

    Returns
    -------
    torch.Tensor
        View of ``src`` expanded to the same shape as ``other``.
    """
    # resolve a negative axis against the rank of the target tensor
    axis = dim if dim >= 0 else other.dim() + dim

    if src.dim() == 1:
        # push the single axis of ``src`` to position ``axis``
        prepended = 0
        while prepended < axis:
            src = src.unsqueeze(0)
            prepended += 1

    # pad on the right until both tensors have the same rank
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)

    return src.expand_as(other)


@torch.jit.script
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
    reduce: str = "sum",
) -> torch.Tensor:
    """
    Apply a scatter operation with sum reduction, from ``src``
    to ``out`` at indices ``index`` along the specified
    dimension.

    The function will apply a ``_broadcast`` with ``index``
    to reshape it to the same as ``src`` first, then, unless
    ``out`` is given, allocate a zero-filled tensor based on the
    expected final shape (depending on ``dim``).

    Parameters
    ----------
    src : torch.Tensor
        Tensor containing source values to scatter add.
    index : torch.Tensor
        Indices for the scatter add operation.
    dim : int, optional
        Dimension to apply the scatter add operation, by default -1
    out : torch.Tensor, optional
        Output tensor to accumulate the scatter sum result into, by
        default None, which will create a tensor with the correct
        shape within this function. When provided, it is modified
        in place and ``dim_size`` is not consulted.
    dim_size : int, optional
        Size of the output along ``dim``, by default None, which
        will then use ``index.max() + 1`` (or 0 when ``index`` is
        empty).
    reduce : str, optional
        Kept for signature compatibility with ``torch_scatter.scatter``.
        Only a sum reduction is implemented, so the value must be
        ``"sum"`` or its alias ``"add"``.

    Returns
    -------
    torch.Tensor
        Resulting scatter sum output.

    Raises
    ------
    AssertionError
        If ``reduce`` is neither ``"sum"`` nor ``"add"``.
    """
    # call sites migrated from ``torch_scatter.scatter`` commonly pass
    # ``reduce="add"``, which denotes the same sum reduction
    assert (
        reduce == "sum" or reduce == "add"
    ), "scatter_sum only implements reduce='sum' (alias 'add')"
    index = _broadcast(index, src, dim)
    # both branches return directly so that ``out`` is a plain tensor at
    # every use under TorchScript's Optional refinement
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            # nothing to scatter: the reduced axis is empty
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)


@torch.jit.script
def scatter_std(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    unbiased: bool = True,
) -> torch.Tensor:
    """
    Compute the standard deviation of the values in ``src`` that share
    the same ``index`` along ``dim``.

    The per-index mean is obtained from two ``scatter_sum`` calls (values
    and element counts), the squared deviations from that mean are
    scatter-summed, and the result is divided by the element count and
    square-rooted.

    Parameters
    ----------
    src : torch.Tensor
        Tensor containing the source values.
    index : torch.Tensor
        Indices assigning each source value to an output slot.
    dim : int, optional
        Dimension along which to reduce, by default -1
    out : torch.Tensor, optional
        Tensor the squared deviations are accumulated into, by default
        None. When provided, its size along ``dim`` overrides
        ``dim_size``.
    dim_size : int, optional
        Size of the output along ``dim``, by default None, in which case
        ``scatter_sum`` determines it.
    unbiased : bool, optional
        If True, divide by ``count - 1`` (clamped to at least 1) instead
        of ``count``, by default True

    Returns
    -------
    torch.Tensor
        Per-index standard deviation.
    """
    # a caller-provided buffer dictates the extent of the reduced axis
    if out is not None:
        dim_size = out.size(dim)

    axis = dim if dim >= 0 else src.dim() + dim

    # ``index`` can have lower rank than ``src``; count along its last axis then
    count_axis = index.dim() - 1 if index.dim() <= axis else axis

    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(unit, index, count_axis, dim_size=dim_size)

    index = _broadcast(index, src, axis)
    total = scatter_sum(src, index, axis, dim_size=dim_size)
    # slots that received no element get a count of 1 to avoid dividing by zero
    count = _broadcast(count, total, axis).clamp(1)
    mean = total / count

    deviation = src - mean.gather(axis, index)
    squared = deviation * deviation
    accumulated = scatter_sum(squared, index, axis, out, dim_size)

    if unbiased:
        count = (count - 1).clamp_(1)

    # the small constant is added to the denominator before the division
    return (accumulated / (count + 1e-6)).sqrt()


@torch.jit.script
def scatter_mean(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Average the values in ``src`` that share the same ``index`` along
    ``dim``.

    The values are first reduced with ``scatter_sum`` and then divided,
    in place, by the number of elements scattered into each slot. Slots
    that received no element are divided by 1.

    Parameters
    ----------
    src : torch.Tensor
        Tensor containing the source values.
    index : torch.Tensor
        Indices assigning each source value to an output slot.
    dim : int, optional
        Dimension along which to reduce, by default -1
    out : torch.Tensor, optional
        Tensor to accumulate into, by default None. When provided it is
        modified in place, including the final division.
    dim_size : int, optional
        Size of the output along ``dim``, by default None, in which case
        ``scatter_sum`` determines it.

    Returns
    -------
    torch.Tensor
        Per-index mean; floating point results use true division, all
        other dtypes use floor division.
    """
    summed = scatter_sum(src, index, dim, out, dim_size)
    # counts must span the same number of slots as the summed values
    dim_size = summed.size(dim)

    count_axis = dim + src.dim() if dim < 0 else dim
    # ``index`` can have lower rank than ``src``; count along its last axis then
    if index.dim() <= count_axis:
        count_axis = index.dim() - 1

    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(unit, index, count_axis, None, dim_size)
    # empty slots would otherwise trigger a division by zero
    count.clamp_(1)
    count = _broadcast(count, summed, dim)

    if not summed.is_floating_point():
        summed.div_(count, rounding_mode="floor")
    else:
        summed.true_divide_(count)
    return summed
9 changes: 7 additions & 2 deletions matsciml/models/pyg/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
from torch_geometric.nn import SchNet
from torch_scatter import scatter
from matsciml.models.pyg.scatter import scatter_sum, scatter_mean

from matsciml.common.utils import conditional_grad, get_pbc_distances, radius_graph_pbc

Expand Down Expand Up @@ -81,6 +81,11 @@ def __init__(
cutoff=cutoff,
readout=readout,
)
# map literal readout choice to functions
if readout == "add":
self.readout = scatter_sum
else:
self.readout = scatter_mean

@conditional_grad(torch.enable_grad())
def _forward(self, data):
Expand Down Expand Up @@ -124,7 +129,7 @@ def _forward(self, data):
h = self.lin2(h)

batch = torch.zeros_like(z) if batch is None else batch
energy = scatter(h, batch, dim=0, reduce=self.readout)
energy = self.readout(h, batch, dim=0)
else:
energy = super().forward(z, pos, batch)
return energy
Expand Down

0 comments on commit 2caaee3

Please sign in to comment.