All the changes from this semester #3

Open · wants to merge 14 commits into base: main
3 changes: 3 additions & 0 deletions README.md
@@ -0,0 +1,3 @@
# Canonical Network
This builds on an existing repo by [_Kaba et al._](https://github.com/oumarkaba/canonical_network/tree/main/canonical_network), which implements the canonicalization functions from their paper
[_Equivariance with Learned Canonicalization_](https://arxiv.org/abs/2211.06489). Changes include a new Transformer implementation, documentation of functions, and various bug fixes.
202 changes: 173 additions & 29 deletions canonical_network/models/euclideangraph_base_models.py

Large diffs are not rendered by default.

84 changes: 69 additions & 15 deletions canonical_network/models/euclideangraph_model.py
@@ -3,29 +3,31 @@
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.transforms import RotateAxisAngle, Rotate, random_rotations
import torchmetrics.functional as tmf
import wandb

from canonical_network.models.vn_layers import *
from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets, BaseEuclideangraphModel
from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets, BaseEuclideangraphModel, Transformer
from canonical_network.utils import define_hyperparams, dict_to_object

# Input dim is 6 because location and velocity vectors are concatenated.
NBODY_HYPERPARAMS = {
"learning_rate": 1e-3,
"weight_decay": 1e-12,
"patience": 1000,
"hidden_dim": 32,
"learning_rate": 1e-3, #1e-3
"weight_decay": 1e-8,
"patience": 1000, #1000
"hidden_dim": 8, #32
"input_dim": 6,
"in_node_nf": 1,
"in_edge_nf": 2,
"num_layers": 4,
"out_dim": 1,
"canon_num_layers": 4,
"canon_hidden_dim": 16,
"num_layers": 4, #4
"out_dim": 4,
"canon_num_layers": 2,
"canon_hidden_dim": 32,
"canon_layer_pooling": "mean",
"canon_final_pooling": "mean",
"canon_nonlinearity": "relu",
"canon_feature": "p",
"canon_feature": "pv",
"canon_translation": False,
"canon_angular_feature": 0,
"canon_dropout": 0.5,
@@ -34,11 +36,16 @@
"final_pooling": "mean",
"nonlinearity": "relu",
"angular_feature": "pv",
"dropout": 0,
"dropout": 0.5, #0
"nheads": 8,
"ff_hidden": 128,
}
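# The "nheads" and "ff_hidden" entries above parameterize the new Transformer prediction
# head added by this PR. The actual Transformer class lives in
# canonical_network/models/euclideangraph_base_models.py (its diff is not rendered above),
# so the standalone sketch below is only an illustration of how hidden_dim, nheads,
# ff_hidden, num_layers, and dropout could plausibly be wired into a torch.nn encoder;
# the output dimension of 3 (coordinates) is an assumption, not taken from this diff.
import torch.nn as nn

class TransformerSketch(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=8, nheads=8, ff_hidden=128, num_layers=4, dropout=0.5):
        super().__init__()
        self.embed = nn.Linear(input_dim, hidden_dim)  # embed concatenated (location, velocity) features
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nheads, dim_feedforward=ff_hidden,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.out = nn.Linear(hidden_dim, 3)  # assumed 3-D coordinate output

    def forward(self, x):
        # x: batch_size x n_nodes x input_dim
        return self.out(self.encoder(self.embed(x)))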


class EuclideangraphCanonFunction(pl.LightningModule):
"""
Returns rotation matrix and translation vectors for canonicalization
following eqns (9) and (10) in https://arxiv.org/pdf/2211.06489.pdf.
"""
def __init__(self, hyperparams):
super().__init__()
self.model_type = hyperparams.canon_model_type
@@ -73,9 +80,22 @@ def __init__(self, hyperparams):
}[self.model_type]()

def forward(self, nodes, loc, edges, vel, edge_attr, charges):
"""
Returns rotation matrix and translation vectors, which are denoted as O and t respectively in eqn. 10
of https://arxiv.org/pdf/2211.06489.pdf.

Args:
`nodes`: Norms of velocity vectors. Shape: (n_nodes*batch_size) x 1
`loc`: Starting locations of nodes. Shape: (n_nodes*batch_size) x 3
`edges`: List of length 2 holding the source and target node indices of each edge. Shape of each element: (n_edges*batch_size), i.e. 2000 in this n-body setup
`vel`: Starting velocities of nodes. Shape: (n_nodes*batch_size) x 3
`edge_attr`: Products of charges and squared relative distances between adjacent nodes (each have their own column). Shape: (n_edges*batch_size) x 2
`charges`: Charges of nodes. Shape: (n_nodes * batch_size) x 1
"""
# (n_nodes * batch_size) x 3 x 3, (n_nodes * batch_size) x 3
rotation_vectors, translation_vectors = self.model(nodes, loc, edges, vel, edge_attr, charges)

rotation_matrix = self.modified_gram_schmidt(rotation_vectors)
# Apply modified Gram-Schmidt to orthonormalize the predicted vectors into a rotation matrix
rotation_matrix = self.modified_gram_schmidt(rotation_vectors) # (n_nodes * batch_size) x 3 x 3

return rotation_matrix, translation_vectors
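# The body of modified_gram_schmidt is collapsed in the hunk below. For reference only, a
# minimal batched Gram-Schmidt that orthonormalizes three predicted 3-vectors into a
# rotation matrix could look like this sketch; the PR's actual modified_gram_schmidt may
# differ, e.g. in vector ordering or in how a right-handed (det = +1) frame is enforced.
def gram_schmidt_sketch(vectors):
    # vectors: (B, 3, 3); vectors[:, i] is the i-th predicted 3-vector for each batch element.
    u1 = vectors[:, 0] / vectors[:, 0].norm(dim=-1, keepdim=True)
    v2 = vectors[:, 1] - (vectors[:, 1] * u1).sum(dim=-1, keepdim=True) * u1
    u2 = v2 / v2.norm(dim=-1, keepdim=True)
    u3 = torch.cross(u1, u2, dim=-1)  # orthogonal to u1, u2 and gives det = +1
    return torch.stack([u1, u2, u3], dim=1)  # (B, 3, 3)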

@@ -104,6 +124,9 @@ def modified_gram_schmidt(self, vectors):


class EuclideangraphPredFunction(pl.LightningModule):
"""
Defines a neural network that makes predictions after canonicalization.
"""
def __init__(self, hyperparams):
super().__init__()
self.model_type = hyperparams.pred_model_type
@@ -112,19 +135,26 @@ def __init__(self, hyperparams):
self.input_dim = hyperparams.input_dim
self.in_node_nf = hyperparams.in_node_nf
self.in_edge_nf = hyperparams.in_edge_nf
self.ff_hidden = hyperparams.ff_hidden
self.nheads = hyperparams.nheads
self.dropout = hyperparams.dropout

model_hyperparams = {
"num_layers": self.num_layers,
"hidden_dim": self.hidden_dim,
"input_dim": self.input_dim,
"in_node_nf": self.in_node_nf,
"in_edge_nf": self.in_edge_nf,
"ff_hidden": self.ff_hidden,
"nheads": self.nheads,
"dropout": self.dropout,
}

self.model = {
"GNN": lambda: GNN(define_hyperparams(model_hyperparams)),
"EGNN": lambda: EGNN_vel(define_hyperparams(model_hyperparams)),
"vndeepsets": lambda: VNDeepSets(define_hyperparams(model_hyperparams)),
"Transformer": lambda: Transformer(define_hyperparams(model_hyperparams))
}[self.model_type]()
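# Note: the dict-of-lambdas above instantiates only the entry named by pred_model_type,
# so setting pred_model_type to "Transformer" constructs the new Transformer head
# (defined in euclideangraph_base_models.py) and leaves the other branches unbuilt.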

def forward(self, nodes, loc, edges, vel, edge_attr, charges):
@@ -145,17 +175,41 @@ def __init__(self, hyperparams):
self.canon_function.freeze()

def forward(self, nodes, loc, edges, vel, edge_attr, charges):
"""
Returns predicted coordinates.

Args:
`nodes`: Norms of velocity vectors. Shape: (n_nodes*batch_size) x 1
`loc`: Starting locations of nodes. Shape: (n_nodes*batch_size) x coord_dim
`edges`: List of length 2 holding the source and target node indices of each edge. Shape of each element: (n_edges*batch_size), i.e. 2000 in this n-body setup
`vel`: Starting velocities of nodes. Shape: (n_nodes*batch_size) x vel_dim
`edge_attr`: Products of charges and squared relative distances between adjacent nodes (each have their own column). Shape: (n_edges*batch_size) x 2
`charges`: Charges of nodes. Shape: (n_nodes * batch_size) x 1
"""
# Rotation and translation vectors from eqn (10) in https://arxiv.org/pdf/2211.06489.pdf.
# Shapes: (n_nodes * batch_size) x 3 x 3 and (n_nodes * batch_size) x 3
# i.e. one rotation matrix and one translation vector for each node.
# QUESTION: is the rotation matrix the same for all nodes in a batch?
rotation_matrix, translation_vectors = self.canon_function(nodes, loc, edges, vel, edge_attr, charges)
rotation_matrix_inverse = rotation_matrix.transpose(1, 2)
rotation_matrix_inverse = rotation_matrix.transpose(1, 2) # Inverse of a rotation matrix is its transpose.

# Canonicalizes coordinates by rotating node coordinates and translation vectors by inverse rotation.
# Shape: (n_nodes * batch_size) x coord_dim.
# loc[:, None, :] adds a singleton dimension so loc can be batch-multiplied (bmm) with rotation_matrix_inverse. Shape: (n_nodes * batch_size) x 1 x 3
canonical_loc = (
torch.bmm(loc[:, None, :], rotation_matrix_inverse).squeeze()
- torch.bmm(translation_vectors[:, None, :], rotation_matrix_inverse).squeeze()
)
# Canonicalizes velocities.
# Shape: (n_nodes * batch_size) x vel_dim.
canonical_vel = torch.bmm(vel[:, None, :], rotation_matrix_inverse).squeeze()

# Makes prediction on canonical inputs.
# Shape: (n_nodes * batch_size) x coord_dim.
position_prediction = self.pred_function(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges)

# Applies rotation to predictions, following equation (10) from https://arxiv.org/pdf/2211.06489.pdf
# Shape: (n_nodes * batch_size) x coord_dim.
position_prediction = (
torch.bmm(position_prediction[:, None, :], rotation_matrix).squeeze() + translation_vectors
)
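# A standalone numeric sanity check (not part of this diff) for the canonicalize /
# de-canonicalize round trip above: if the prediction network were the identity, mapping
# the canonical coordinates back with the same rotation and translation recovers the
# original locations, mirroring eqn (10) of the paper. Shapes follow the per-node
# convention documented in the docstring; n_nodes=5 is an illustrative choice.
def _round_trip_sketch(n_nodes=5):
    loc = torch.randn(n_nodes, 3)
    rotation_matrix = random_rotations(n_nodes)  # (n_nodes, 3, 3), from pytorch3d
    translation_vectors = torch.randn(n_nodes, 3)
    rotation_matrix_inverse = rotation_matrix.transpose(1, 2)
    canonical_loc = (
        torch.bmm(loc[:, None, :], rotation_matrix_inverse).squeeze()
        - torch.bmm(translation_vectors[:, None, :], rotation_matrix_inverse).squeeze()
    )
    recovered = torch.bmm(canonical_loc[:, None, :], rotation_matrix).squeeze() + translation_vectors
    assert torch.allclose(recovered, loc, atol=1e-5)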