Skip to content

Commit

Permalink
Merge pull request #6 from arnab39/nbody
Browse files Browse the repository at this point in the history
Added Nbody
  • Loading branch information
arnab39 authored Mar 14, 2024
2 parents 8777201 + b164ef4 commit 955d8e9
Show file tree
Hide file tree
Showing 23 changed files with 2,119 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,6 @@ wandb/
# Docs
docs/api/
docs/_build/

# NBody Data
examples/nbody/data/n_body_system/dataset/*.npy
Empty file added equiadapt/nbody/__init__.py
Empty file.
158 changes: 158 additions & 0 deletions equiadapt/nbody/canonicalization/euclidean_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization


class EuclideanGroupNBody(ContinuousGroupCanonicalization):
"""
A class representing the continuous group for N-body canonicalization.
Args:
canonicalization_network (torch.nn.Module): The canonicalization network.
canonicalization_hyperparams (dict): Hyperparameters for the canonicalization.
Attributes:
canonicalization_info_dict (dict): A dictionary containing the group element information.
"""

def __init__(
self,
canonicalization_network: torch.nn.Module,
canonicalization_hyperparams: dict,
) -> None:
super().__init__(canonicalization_network)

def forward(
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
Forward pass of the continuous group.
Args:
nodes: Node attributes.
**kwargs: Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.
Returns:
The result of the canonicalization.
"""
return self.canonicalize(x, None, **kwargs)

def get_groupelement(
self,
nodes: torch.Tensor,
loc: torch.Tensor,
edges: torch.Tensor,
vel: torch.Tensor,
edge_attr: torch.Tensor,
charges: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""
Get the group element information.
Args:
nodes: Nodes data.
loc: Location data.
edges: Edges data.
vel: Velocity data.
edge_attr: Edge attributes data.
charges: Charges data.
Returns:
A dictionary containing the group element information.
"""
group_element_dict: Dict[str, torch.Tensor] = {}
rotation_vectors, translation_vectors = self.canonicalization_network(
nodes, loc, edges, vel, edge_attr, charges
)
rotation_matrix = self.modified_gram_schmidt(rotation_vectors)

# Check whether canonicalization_info_dict is already defined
if not hasattr(self, "canonicalization_info_dict"):
self.canonicalization_info_dict = {}

group_element_dict["rotation_matrix"] = rotation_matrix
group_element_dict["translation_vectors"] = translation_vectors
group_element_dict["rotation_matrix_inverse"] = rotation_matrix.transpose(
1, 2
) # Inverse of a rotation matrix is its transpose.

self.canonicalization_info_dict["group_element"] = group_element_dict

return group_element_dict

def canonicalize(
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
Canonicalize the input data.
Args:
nodes: Node attributes.
targets: Target data.
**kwargs: Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.
Returns:
The canonicalized location and velocity.
"""
self.device = x.device

loc, edges, vel, edge_attr, charges = kwargs.values()

group_element_dict = self.get_groupelement(
x, loc, edges, vel, edge_attr, charges
)
translation_vectors = group_element_dict["translation_vectors"]
rotation_matrix_inverse = group_element_dict["rotation_matrix_inverse"]

# Canonicalizes coordinates by rotating node coordinates and translation vectors by inverse rotation.
# Shape: (n_nodes * batch_size) x coord_dim.
canonical_loc = (
torch.bmm(loc[:, None, :], rotation_matrix_inverse).squeeze()
- torch.bmm(
translation_vectors[:, None, :], rotation_matrix_inverse
).squeeze()
)
# Canonicalizes velocities.
# Shape: (n_nodes * batch_size) x vel_dim.
canonical_vel = torch.bmm(vel[:, None, :], rotation_matrix_inverse).squeeze()

return canonical_loc, canonical_vel

def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
"""This method takes as input the canonicalized output and returns the original output."""
rotation_matrix, translation_vectors, _ = self.canonicalization_info_dict[
"group_element"
].values()
loc = (
torch.bmm(x_canonicalized_out[:, None, :], rotation_matrix).squeeze()
+ translation_vectors
)
return loc

def modified_gram_schmidt(self, vectors: torch.Tensor) -> torch.Tensor:
"""
Apply the modified Gram-Schmidt process to the input vectors.
Args:
vectors: Input vectors.
Returns:
The orthonormalized vectors.
"""
v1 = vectors[:, 0]
v1 = v1 / torch.norm(v1, dim=1, keepdim=True)
v2 = vectors[:, 1] - torch.sum(vectors[:, 1] * v1, dim=1, keepdim=True) * v1
v2 = v2 / torch.norm(v2, dim=1, keepdim=True)
v3 = vectors[:, 2] - torch.sum(vectors[:, 2] * v1, dim=1, keepdim=True) * v1
v3 = v3 - torch.sum(v3 * v2, dim=1, keepdim=True) * v2
v3 = v3 / torch.norm(v3, dim=1, keepdim=True)
return torch.stack([v1, v2, v3], dim=1)
Loading

0 comments on commit 955d8e9

Please sign in to comment.