diff --git a/README.md b/README.md
new file mode 100644
index 0000000..fd7c886
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# Canonical Network
+This builds on the existing repo by [_Kaba et al._](https://github.com/oumarkaba/canonical_network/tree/main/canonical_network), which implements the canonicalization functions from their paper
+[_Equivariance with Learned Canonicalization_](https://arxiv.org/abs/2211.06489). Changes include a new Transformer implementation, documentation of functions, and various bug fixes.
diff --git a/canonical_network/models/euclideangraph_base_models.py b/canonical_network/models/euclideangraph_base_models.py
index cddedaf..dc366af 100644
--- a/canonical_network/models/euclideangraph_base_models.py
+++ b/canonical_network/models/euclideangraph_base_models.py
@@ -8,18 +8,21 @@ import torchmetrics.functional as tmf
 import wandb
 import torch_scatter as ts
-
+import math
 from canonical_network.models.gcl import E_GCL_vel, GCL
 from canonical_network.models.vn_layers import VNLinearLeakyReLU, VNLinear, VNLeakyReLU, VNSoftplus
 from canonical_network.models.set_base_models import SequentialMultiple

+# Parent class of all the models in this file.
 class BaseEuclideangraphModel(pl.LightningModule):
     def __init__(self, hyperparams):
         super().__init__()
         self.learning_rate = hyperparams.learning_rate if hasattr(hyperparams, "learning_rate") else None
         self.weight_decay = hyperparams.weight_decay if hasattr(hyperparams, "weight_decay") else 0.0
         self.patience = hyperparams.patience if hasattr(hyperparams, "patience") else 100
+        # Each input has 5 particles and the graph is fully connected (K5), so this list defines all 20 directed edges:
+        # the vertex self.edges[0][i] has an edge to the vertex self.edges[1][i].
         self.edges = [
             [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
             [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3],
@@ -37,18 +40,32 @@ def __init__(self, hyperparams):
         self.dummy_edge_attr = torch.zeros(40, 2, device=self.device, dtype=torch.float)

     def training_step(self, batch, batch_idx):
+        """
+        Performs one training step.
+
+        Args:
+            `batch`: a list of tensors [loc, vel, edge_attr, charges, loc_end]
+                `loc`: batch_size x n_nodes x 3
+                `vel`: batch_size x n_nodes x 3
+                `edge_attr`: batch_size x n_edges x 1
+                `charges`: batch_size x n_nodes x 1
+                `loc_end`: batch_size x n_nodes x 3
+            `batch_idx`: index of the batch
+        """
+
         batch_size, n_nodes, _ = batch[0].size()
-        batch = [d.view(-1, d.size(2)) for d in batch]
+        batch = [d.view(-1, d.size(2)) for d in batch]  # flattens each tensor into a 2D matrix
         loc, vel, edge_attr, charges, loc_end = batch
-        edges = self.get_edges(batch_size, n_nodes)
+        edges = self.get_edges(batch_size, n_nodes)  # list of two tensors, each of length num_edges * batch_size (num_edges is always 20, since G = K5)
-        nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach()
+        nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach()  # norms of the velocity vectors
         rows, cols = edges
         loc_dist = torch.sum((loc[rows] - loc[cols]) ** 2, 1).unsqueeze(1)  # relative distances among locations
         edge_attr = torch.cat([edge_attr, loc_dist], 1).detach()  # concatenate all edge properties
-        outputs = self(nodes, loc.detach(), edges, vel, edge_attr, charges)
+        outputs = self(nodes, loc.detach(), edges, vel, edge_attr, charges)  # calls self.forward(...)
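+        # Optional sanity check (illustrative, not part of the original code): the model
+        # output must align elementwise with loc_end for the loss below, e.g.
+        # assert outputs.shape == loc_end.shape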
+        # outputs and loc_end are both (n_nodes * batch_size) x 3
         loss = self.loss(outputs, loc_end)

         metrics = {"train/loss": loss}
@@ -57,6 +74,19 @@
         return loss

     def validation_step(self, batch, batch_idx):
+        """
+        Performs one validation step.
+
+        Args:
+            `batch`: a list of tensors [loc, vel, edge_attr, charges, loc_end]
+                `loc`: batch_size x n_nodes x 3
+                `vel`: batch_size x n_nodes x 3
+                `edge_attr`: batch_size x n_edges x 1
+                `charges`: batch_size x n_nodes x 1
+                `loc_end`: batch_size x n_nodes x 3
+            `batch_idx`: index of the batch
+        """
         batch_size, n_nodes, _ = batch[0].size()
         batch = [d.view(-1, d.size(2)) for d in batch]
         loc, vel, edge_attr, charges, loc_end = batch
@@ -79,7 +109,7 @@
         return loss

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-12)
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             optimizer, patience=self.patience, factor=0.5, min_lr=1e-6, mode="max"
         )
@@ -106,18 +136,26 @@ def validation_epoch_end(self, validation_step_outputs):
         # wandb.save(model_filename)

     def get_edges(self, batch_size, n_nodes):
+        """
+        Returns a length-2 list of index tensors, where the vertex edges[0][i] is adjacent to the vertex edges[1][i].
+
+        Args:
+            `batch_size`: int, defined in `train_nbody.HYPERPARAMS`
+            `n_nodes`: number of nodes in each sample.
+        """
         edges = [torch.LongTensor(self.edges[0]).to(self.device), torch.LongTensor(self.edges[1]).to(self.device)]
         if batch_size == 1:
             return edges
         elif batch_size > 1:
             rows, cols = [], []
+            # Offsets the vertex indices of sample i by n_nodes * i, so that rows and cols index directly into the flattened batch.
             for i in range(batch_size):
                 rows.append(edges[0] + n_nodes * i)
                 cols.append(edges[1] + n_nodes * i)
             edges = [torch.cat(rows), torch.cat(cols)]
         return edges

-
+# EGNN with velocity, based on equation (7) of https://arxiv.org/pdf/2102.09844.pdf.
 class EGNN_vel(BaseEuclideangraphModel):
     def __init__(self, hyperparams):
         super(EGNN_vel, self).__init__(hyperparams)
@@ -139,9 +177,9 @@ def __init__(self, hyperparams):
         self.add_module(
             "gcl_%d" % 0,
             E_GCL_vel(
-                self.hidden_dim,
-                self.hidden_dim,
-                self.hidden_dim,
+                input_nf=self.hidden_dim,
+                output_nf=self.hidden_dim,
+                hidden_dim=self.hidden_dim,
                 edges_in_d=hyperparams.in_edge_nf,
                 act_fn=self.act_fn,
                 coords_weight=self.coords_weight,
@@ -155,9 +193,9 @@ def __init__(self, hyperparams):
         self.add_module(
             "gcl_%d" % i,
             E_GCL_vel(
-                self.hidden_dim,
-                self.hidden_dim,
-                self.hidden_dim,
+                input_nf=self.hidden_dim,
+                output_nf=self.hidden_dim,
+                hidden_dim=self.hidden_dim,
                 edges_in_d=hyperparams.in_edge_nf,
                 act_fn=self.act_fn,
                 coords_weight=self.coords_weight,
@@ -171,9 +209,9 @@ def __init__(self, hyperparams):
         self.add_module(
             "gcl_%d" % (self.n_layers - 1),
             E_GCL_vel(
-                self.hidden_dim,
-                self.hidden_dim,
-                self.hidden_dim,
+                input_nf=self.hidden_dim,
+                output_nf=self.hidden_dim,
+                hidden_dim=self.hidden_dim,
                 edges_in_d=hyperparams.in_edge_nf,
                 act_fn=self.act_fn,
                 coords_weight=self.coords_weight,
@@ -186,12 +224,23 @@
         )

     def forward(self, h, x, edges, vel, edge_attr, _):
-        h = self.embedding(h)
+        """
+        Returns the predicted node coordinates.
+
+        Args:
+            `h`: Norms of velocity vectors. Shape: (n_nodes * batch_size) x 1
+            `x`: Coordinates of nodes. Shape: (n_nodes * batch_size) x coord_dim
+            `edges`: Length-2 list of index tensors, where the vertex edges[0][i] is adjacent to edges[1][i].
+            `vel`: Velocities of nodes. Shape: (n_nodes * batch_size) x vel_dim
+            `edge_attr`: Products of charges along edges. Shape: batch_size x n_edges x 1
+        """
+        h = self.embedding(h)  # Node embeddings. (n_nodes * batch_size) x hidden_dim
+        # Applies each layer of the EGNN in turn.
         for i in range(0, self.n_layers):
             h, x, _ = self._modules["gcl_%d" % i](h, edges, x, vel, edge_attr=edge_attr)
-        return x.squeeze(2)
+        return x.squeeze(2)  # Predicted coordinates

+# Model based on https://arxiv.org/pdf/2102.09844.pdf, equations 3-6.
 class GNN(BaseEuclideangraphModel):
     def __init__(self, hyperparams):
         super(GNN, self).__init__(hyperparams)
@@ -208,9 +257,9 @@ def __init__(self, hyperparams):
         self.add_module(
             "gcl_%d" % i,
             GCL(
-                self.hidden_dim,
-                self.hidden_dim,
-                self.hidden_dim,
+                input_nf=self.hidden_dim,
+                output_nf=self.hidden_dim,
+                hidden_dim=self.hidden_dim,
                 edges_in_nf=2,
                 act_fn=self.act_fn,
                 attention=self.attention,
@@ -224,13 +273,23 @@ def __init__(self, hyperparams):
         self.embedding = nn.Sequential(nn.Linear(self.input_dim, self.hidden_dim))

     def forward(self, nodes, loc, edges, vel, edge_attr, _):
-        nodes = torch.cat([loc, vel], dim=1)
-        h = self.embedding(nodes)
+        """
+        Returns the predicted node coordinates.
+
+        Args:
+            `nodes`: Norms of velocity vectors (unused; overwritten below). Shape: (n_nodes * batch_size) x 1
+            `loc`: Coordinates of nodes. Shape: (n_nodes * batch_size) x coord_dim
+            `edges`: Length-2 list of index tensors, where the vertex edges[0][i] is adjacent to edges[1][i].
+            `vel`: Velocities of nodes. Shape: (n_nodes * batch_size) x vel_dim
+            `edge_attr`: Products of charges along edges. Shape: batch_size x n_edges x 1
+        """
+        nodes = torch.cat([loc, vel], dim=1)  # (n_nodes * batch_size) x (coord_dim + vel_dim)
+        h = self.embedding(nodes)  # (n_nodes * batch_size) x hidden_dim
         # h, _ = self._modules["gcl_0"](h, edges, edge_attr=edge_attr)
         for i in range(0, self.n_layers):
             h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr)
+        # h is (n_nodes * batch_size) x hidden_dim; the decoder maps it to (n_nodes * batch_size) x 3.
         # return h
-        return self.decoder(h)
+        return self.decoder(h)  # (n_nodes * batch_size) x 3

 class VNDeepSets(BaseEuclideangraphModel):
@@ -274,6 +333,9 @@ def forward(self, nodes, loc, edges, vel, edge_attr, charges):
         mean_loc = ts.scatter(loc, batch_indices, 0, reduce=self.layer_pooling)
         mean_loc = mean_loc.repeat(5, 1, 1).transpose(0, 1).reshape(-1, 3)
         canonical_loc = loc - mean_loc
+        # Feature codes: p = position,
+        # v = velocity,
+        # a = angular feature (the cross product of position and velocity).
         if self.canon_feature == "p":
             features = torch.stack([canonical_loc], dim=2)
         if self.canon_feature == "pv":
@@ -288,18 +350,19 @@
             features = torch.stack([canonical_loc, vel, angular, canonical_loc * charges], dim=2)

         x, _ = self.first_set_layer(features, edges)
-        x, _ = self.set_layers(x, edges)
+        x, _ = self.set_layers(x, edges)  # (n_nodes * batch_size) x 3 x hidden_dim

         if self.prediction_mode:
             output = self.output_layer(x)
             output = output.squeeze()
             return output
+        # This branch runs when the model is used as a canonicalization function.
         else:
-            x = ts.scatter(x, batch_indices, 0, reduce=self.final_pooling)
-            output = self.output_layer(x)
+            x = ts.scatter(x, batch_indices, 0, reduce=self.final_pooling)  # batch_size x 3 x hidden_dim
+            output = self.output_layer(x)  # batch_size x 3 x 4

-            output = output.repeat(5, 1, 1, 1).transpose(0, 1)
-            output = output.reshape(-1, 3, 4)
+            output = output.repeat(5, 1, 1, 1).transpose(0, 1)  # batch_size x n_nodes x 3 x 4
+            output = output.reshape(-1, 3, 4)  # (batch_size * n_nodes) x 3 x 4

             rotation_vectors = output[:, :, :3]
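+            # The 4 output channels per node split into a 3-vector frame (columns 0-2), which
+            # modified_gram_schmidt later orthonormalizes into a rotation, and a translation (column 3).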
             translation_vectors = output[:, :, 3:] if self.canon_translation else 0.0
@@ -331,6 +394,9 @@ def __init__(self, in_channels, out_channels, nonlinearity, pooling="sum", resid
         self.nonlinear_function = VNLeakyReLU(out_channels, share_nonlinearity=False)

     def forward(self, x, edges):
+        # x holds the node features; their composition depends on canon_feature
+        # (see VNDeepSets.forward).
         edges_1 = edges[0]
         edges_2 = edges[1]
@@ -348,3 +414,81 @@ def forward(self, x, edges):
         output = output + x
         return output, edges
+
+class Transformer(BaseEuclideangraphModel):
+    def __init__(self, hyperparams):
+        super(Transformer, self).__init__(hyperparams)
+        self.model = "Transformer"
+        self.hidden_dim = hyperparams.hidden_dim
+        self.input_dim = hyperparams.input_dim
+        self.n_layers = hyperparams.num_layers
+        self.ff_hidden = hyperparams.ff_hidden
+        self.act_fn = nn.ReLU()
+        self.dropout = hyperparams.dropout
+        self.nhead = hyperparams.nheads
+
+        self.coord_embedding = nn.Linear(1, self.hidden_dim)
+
+        self.pos_encoder = PositionalEncoding(hidden_dim=self.hidden_dim, dropout=self.dropout)
+
+        self.charge_embedding = nn.Embedding(2, self.hidden_dim)
+
+        encoder_layer = nn.TransformerEncoderLayer(d_model=7*self.hidden_dim, nhead=self.nhead, dim_feedforward=self.ff_hidden, batch_first=True)
+        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=self.n_layers)
+
+        self.decoder = nn.Sequential(
+            nn.Linear(in_features=7*self.hidden_dim, out_features=7*self.hidden_dim),
+            self.act_fn,
+            nn.Linear(in_features=7*self.hidden_dim, out_features=3)
+        )
+
+    def forward(self, nodes, loc, edges, vel, edge_attr, charges):
+        """
+        Forward pass through the Transformer model.
+
+        Args:
+            `nodes`: Norms of velocity vectors. Shape: (n_nodes * batch_size) x 1
+            `loc`: Starting locations of nodes. Shape: (n_nodes * batch_size) x 3
+            `edges`: Length-2 list, where each element is a tensor of length n_edges * batch_size (2000 for K5 with batch_size = 100)
+            `vel`: Starting velocities of nodes. Shape: (n_nodes * batch_size) x 3
+            `edge_attr`: Products of charges and squared relative distances between adjacent nodes (one column each). Shape: (n_edges * batch_size) x 2
+            `charges`: Charges of nodes. Shape: (n_nodes * batch_size) x 1
+        """
+        # Positional encodings
+        pos_encodings = torch.cat([loc, vel], dim=1).unsqueeze(2)  # (n_nodes * batch_size) x 6 x 1
+        pos_encodings = self.coord_embedding(pos_encodings)  # (n_nodes * batch_size) x 6 x hidden_dim (bug fix: the embedding was previously added to itself, which only doubled it)
+        pos_encodings = self.pos_encoder(pos_encodings)  # (n_nodes * batch_size) x 6 x hidden_dim
+        # Charge embeddings
+        charges[charges == -1] = 0  # remap -1 to 0 so charges can index into nn.Embedding
+        charges = charges.long()
+        charges = self.charge_embedding(charges)  # (n_nodes * batch_size) x 1 x hidden_dim
+        nodes = torch.cat([pos_encodings, charges], dim=1)  # (n_nodes * batch_size) x 7 x hidden_dim
+        nodes = nodes.view(-1, 5, nodes.shape[1] * nodes.shape[2])  # batch_size x n_nodes x (7 * hidden_dim)
+        h = self.encoder(nodes)  # batch_size x n_nodes x (7 * hidden_dim)
+        h = h.view(-1, h.shape[2])  # (n_nodes * batch_size) x (7 * hidden_dim)
+        h = self.decoder(h)
+        return h  # predicted coordinates, (n_nodes * batch_size) x 3
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, hidden_dim, dropout):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        self.hidden_dim = hidden_dim
+        div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim)).view(1, 1, int(hidden_dim / 2))  # 1 x 1 x (hidden_dim / 2)
+        self.register_buffer('div_term', div_term)  # registered so it moves with the module (e.g. to the GPU) and is saved in the state_dict
+
+    def forward(self, x):
+        """
+        Returns the positional encoding of the coordinates and velocities.
+        Args:
+            `x`: Concatenated velocity and coordinate vectors. Shape: (n_nodes * batch_size) x 6 x 1
+        Output:
+            `pe`: Positional encoding of x. Shape: (n_nodes * batch_size) x 6 x hidden_dim
+        """
+        pe = torch.zeros(x.shape[0], x.shape[1], self.hidden_dim).to(x.device)  # (n_nodes * batch_size) x 6 x hidden_dim
+        sin_terms = torch.sin(x * self.div_term)  # (n_nodes * batch_size) x 6 x (hidden_dim / 2), by broadcasting
+        pe[:, :, 0::2] = sin_terms
+        pe[:, :, 1::2] = torch.cos(x * self.div_term)  # one encoding per input dimension (i.e. for x, y, z, vx, vy, vz)
+        return self.dropout(pe)
\ No newline at end of file
diff --git a/canonical_network/models/euclideangraph_model.py b/canonical_network/models/euclideangraph_model.py
index ffa7c94..023bab0 100644
--- a/canonical_network/models/euclideangraph_model.py
+++ b/canonical_network/models/euclideangraph_model.py
@@ -3,29 +3,31 @@ import torch.nn.functional as F
 import pytorch_lightning as pl
 from pytorch3d.transforms import RotateAxisAngle, Rotate, random_rotations
 import torchmetrics.functional as tmf
 import wandb

 from canonical_network.models.vn_layers import *
-from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets, BaseEuclideangraphModel
+from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets, BaseEuclideangraphModel, Transformer
 from canonical_network.utils import define_hyperparams, dict_to_object

+# Input dim is 6 because the location and velocity vectors are concatenated.
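+# Note: the Transformer flattens 7 embedded channels per node into one token, so its encoder
+# d_model is 7 * hidden_dim, and nn.TransformerEncoderLayer requires d_model to be divisible
+# by nhead (with the values below, 7 * 8 = 56 is divisible by nheads = 8).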
 NBODY_HYPERPARAMS = {
     "learning_rate": 1e-3,
-    "weight_decay": 1e-12,
+    "weight_decay": 1e-8,
     "patience": 1000,
-    "hidden_dim": 32,
+    "hidden_dim": 8,  # originally 32
     "input_dim": 6,
     "in_node_nf": 1,
     "in_edge_nf": 2,
     "num_layers": 4,
-    "out_dim": 1,
-    "canon_num_layers": 4,
-    "canon_hidden_dim": 16,
+    "out_dim": 4,
+    "canon_num_layers": 2,
+    "canon_hidden_dim": 32,
     "canon_layer_pooling": "mean",
     "canon_final_pooling": "mean",
     "canon_nonlinearity": "relu",
-    "canon_feature": "p",
+    "canon_feature": "pv",
     "canon_translation": False,
     "canon_angular_feature": 0,
     "canon_dropout": 0.5,
@@ -34,11 +36,16 @@
     "final_pooling": "mean",
     "nonlinearity": "relu",
     "angular_feature": "pv",
-    "dropout": 0,
+    "dropout": 0.5,  # originally 0
+    "nheads": 8,
+    "ff_hidden": 128,
 }
-
 class EuclideangraphCanonFunction(pl.LightningModule):
+    """
+    Computes the rotation matrices and translation vectors used for canonicalization,
+    following eqns. (9) and (10) in https://arxiv.org/pdf/2211.06489.pdf.
+    """
     def __init__(self, hyperparams):
         super().__init__()
         self.model_type = hyperparams.canon_model_type
@@ -73,9 +80,22 @@ def __init__(self, hyperparams):
         }[self.model_type]()

     def forward(self, nodes, loc, edges, vel, edge_attr, charges):
+        """
+        Returns the rotation matrices and translation vectors, denoted O and t respectively
+        in eqn. (10) of https://arxiv.org/pdf/2211.06489.pdf.
+
+        Args:
+            `nodes`: Norms of velocity vectors. Shape: (n_nodes * batch_size) x 1
+            `loc`: Starting locations of nodes. Shape: (n_nodes * batch_size) x 3
+            `edges`: Length-2 list, where each element is a tensor of length n_edges * batch_size
+            `vel`: Starting velocities of nodes. Shape: (n_nodes * batch_size) x 3
+            `edge_attr`: Products of charges and squared relative distances between adjacent nodes (one column each). Shape: (n_edges * batch_size) x 2
+            `charges`: Charges of nodes. Shape: (n_nodes * batch_size) x 1
+        """
+        # (n_nodes * batch_size) x 3 x 3, (n_nodes * batch_size) x 3
         rotation_vectors, translation_vectors = self.model(nodes, loc, edges, vel, edge_attr, charges)
-
-        rotation_matrix = self.modified_gram_schmidt(rotation_vectors)
+        # Apply Gram-Schmidt to orthonormalize the predicted vectors into a rotation matrix.
+        rotation_matrix = self.modified_gram_schmidt(rotation_vectors)  # (n_nodes * batch_size) x 3 x 3

         return rotation_matrix, translation_vectors
@@ -104,6 +124,9 @@ def modified_gram_schmidt(self, vectors):

 class EuclideangraphPredFunction(pl.LightningModule):
+    """
+    Defines the neural network that makes predictions after canonicalization.
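+    Wraps one of GNN, EGNN_vel, VNDeepSets, or Transformer, selected by `pred_model_type` (see the model dict below).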
+ """ def __init__(self, hyperparams): super().__init__() self.model_type = hyperparams.pred_model_type @@ -112,6 +135,9 @@ def __init__(self, hyperparams): self.input_dim = hyperparams.input_dim self.in_node_nf = hyperparams.in_node_nf self.in_edge_nf = hyperparams.in_edge_nf + self.ff_hidden = hyperparams.ff_hidden + self.nheads = hyperparams.nheads + self.dropout = hyperparams.dropout model_hyperparams = { "num_layers": self.num_layers, @@ -119,12 +145,16 @@ def __init__(self, hyperparams): "input_dim": self.input_dim, "in_node_nf": self.in_node_nf, "in_edge_nf": self.in_edge_nf, + "ff_hidden": self.ff_hidden, + "nheads": self.nheads, + "dropout": self.dropout, } self.model = { "GNN": lambda: GNN(define_hyperparams(model_hyperparams)), "EGNN": lambda: EGNN_vel(define_hyperparams(model_hyperparams)), "vndeepsets": lambda: VNDeepSets(define_hyperparams(model_hyperparams)), + "Transformer": lambda: Transformer(define_hyperparams(model_hyperparams)) }[self.model_type]() def forward(self, nodes, loc, edges, vel, edge_attr, charges): @@ -145,17 +175,41 @@ def __init__(self, hyperparams): self.canon_function.freeze() def forward(self, nodes, loc, edges, vel, edge_attr, charges): + """ + Returns predicted coordinates. + + Args: + `nodes`: Norms of velocity vectors. Shape: (n_nodes*batch_size) x 1 + `loc`: Starting locations of nodes. Shape: (n_nodes*batch_size) x coord_dim + `edges`: list of length 2, where each element is a 2000 dimensional tensor + `vel`: Starting velocities of nodes. Shape: (n_nodes*batch_size) x vel_dim + `edge_attr`: Products of charges and squared relative distances between adjacent nodes (each have their own column). Shape: (n_edges*batch_size) x 2 + `charges`: Charges of nodes . Shape: (n_nodes * batch_size) x 1 + """ + # Rotation and translation vectors from eqn (10) in https://arxiv.org/pdf/2211.06489.pdf. + # Shapes: (n_nodes * batch_size) x 3 x 3 and (n_nodes * batch_size) x 3 + # ie. One rotation matrix and one translation vector for each node. + # QUESTION: IS ROTATION MATRIX THE SAME FOR ALL NODES IN A BATCH? rotation_matrix, translation_vectors = self.canon_function(nodes, loc, edges, vel, edge_attr, charges) - rotation_matrix_inverse = rotation_matrix.transpose(1, 2) + rotation_matrix_inverse = rotation_matrix.transpose(1, 2) # Inverse of a rotation matrix is its transpose. + # Canonicalizes coordinates by rotating node coordinates and translation vectors by inverse rotation. + # Shape: (n_nodes * batch_size) x coord_dim. + # loc[:,None, :] adds a dimension to loc, so that it can be multiplied with rotation_matrix_inverse. (n_nodes * batch_size) x 1 x 3 canonical_loc = ( torch.bmm(loc[:, None, :], rotation_matrix_inverse).squeeze() - torch.bmm(translation_vectors[:, None, :], rotation_matrix_inverse).squeeze() ) + # Canonicalizes velocities. + # Shape: (n_nodes * batch_size) x vel_dim. canonical_vel = torch.bmm(vel[:, None, :], rotation_matrix_inverse).squeeze() + # Makes prediction on canonical inputs. + # Shape: (n_nodes * batch_size) x coord_dim. position_prediction = self.pred_function(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges) + # Applies rotation to predictions, following equation (10) from https://arxiv.org/pdf/2211.06489.pdf + # Shape: (n_nodes * batch_size) x coord_dim. 
         position_prediction = (
             torch.bmm(position_prediction[:, None, :], rotation_matrix).squeeze() + translation_vectors
         )
diff --git a/canonical_network/models/gcl.py b/canonical_network/models/gcl.py
index 0b648e5..b0357fa 100644
--- a/canonical_network/models/gcl.py
+++ b/canonical_network/models/gcl.py
@@ -22,14 +22,14 @@ def __init__(self, nin, nout, nh):
     def forward(self, x):
         return self.net(x)

-
+# Parent class for all the following GNN layers.
 class GCL_basic(nn.Module):
     """Graph Neural Net with global state and fixed number of nodes per graph.
     Args:
-        hidden_dim: Number of hidden units.
-        num_nodes: Maximum number of nodes (for self-attentive pooling).
-        global_agg: Global aggregation function ('attn' or 'sum').
-        temp: Softmax temperature.
+        `hidden_dim`: Number of hidden units.
+        `num_nodes`: Maximum number of nodes (for self-attentive pooling).
+        `global_agg`: Global aggregation function ('attn' or 'sum').
+        `temp`: Softmax temperature.
     """

     def __init__(self):
@@ -42,19 +42,33 @@ def node_model(self, h, edge_index, edge_attr):
         pass

     def forward(self, x, edge_index, edge_attr=None):
+        """
+        Based on equation (2) in https://arxiv.org/pdf/2102.09844.pdf.
+        A matrix of edge features is created and used to update the node features (m and h in the paper).
+
+        Args:
+            `x`: Matrix of node embeddings. Shape: (n_nodes * batch_size) x hidden_dim
+            `edge_index`: Length-2 list of tensors containing the indices of adjacent nodes; each of shape (n_edges * batch_size).
+        """
         row, col = edge_index
+        # phi_e in the paper: returns the matrix m of edge features used to update the node embeddings.
+        # x[row], x[col] are the embeddings of adjacent nodes; edge_attr holds the edge attributes/features.
         edge_feat = self.edge_model(x[row], x[col], edge_attr)
-        x = self.node_model(x, edge_index, edge_feat)
+        x = self.node_model(x, edge_index, edge_feat)  # updates the node embeddings (phi_h in the paper)
         return x, edge_feat

-
+# Graph convolutional layer,
+# based on equation (2) of https://arxiv.org/pdf/2102.09844.pdf.
 class GCL(GCL_basic):
     """Graph Neural Net with global state and fixed number of nodes per graph.

     Args:
-        hidden_dim: Number of hidden units.
-        num_nodes: Maximum number of nodes (for self-attentive pooling).
-        global_agg: Global aggregation function ('attn' or 'sum').
-        temp: Softmax temperature.
+        `hidden_dim`: Number of hidden units.
+        `num_nodes`: Maximum number of nodes (for self-attentive pooling).
+        `global_agg`: Global aggregation function ('attn' or 'sum').
+        `temp`: Softmax temperature.
     """

     def __init__(
@@ -73,7 +87,7 @@ def __init__(
         self.attention = attention
         self.t_eq = t_eq
         self.recurrent = recurrent
-        input_edge_nf = input_nf * 2
+        input_edge_nf = input_nf * 2  # doubled because the source and target embeddings are concatenated
         self.edge_mlp = nn.Sequential(
             nn.Linear(input_edge_nf + edges_in_nf, hidden_dim, bias=bias),
             act_fn,
@@ -93,33 +107,49 @@ def __init__(
         # self.gru = nn.GRUCell(hidden_dim, hidden_dim)

     def edge_model(self, source, target, edge_attr):
-        edge_in = torch.cat([source, target], dim=1)
+        """
+        Returns the matrix m from eqn. (2) of the paper, of shape (n_edges * batch_size) x hidden_dim.
+
+        Args:
+            `source`: Embeddings of the nodes at the start of each edge. Shape: (n_edges * batch_size) x input_nf
+            `target`: Embeddings of the nodes at the end of each edge. Shape: (n_edges * batch_size) x input_nf
+            `edge_attr`: Attributes of the edges. Shape: (n_edges * batch_size) x edge_attr_dim
+        """
+        edge_in = torch.cat([source, target], dim=1)  # (n_edges * batch_size) x input_edge_nf
         if edge_attr is not None:
-            edge_in = torch.cat([edge_in, edge_attr], dim=1)
-        out = self.edge_mlp(edge_in)
+            edge_in = torch.cat([edge_in, edge_attr], dim=1)  # (n_edges * batch_size) x (input_edge_nf + edges_in_nf)
+        out = self.edge_mlp(edge_in)  # m from the paper, (n_edges * batch_size) x hidden_dim
         if self.attention:
             att = self.att_mlp(torch.abs(source - target))
             out = out * att
-        return out
+        return out  # (n_edges * batch_size) x hidden_dim

     def node_model(self, h, edge_index, edge_attr):
+        """
+        Returns the updated node embeddings, h, from the paper. Shape: (n_nodes * batch_size) x output_nf
+
+        Args:
+            `h`: Current node embeddings. Shape: (n_nodes * batch_size) x input_nf
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `edge_attr`: Attributes of the edges (the output of edge_model). Shape: (n_edges * batch_size) x hidden_dim
+        """
         row, col = edge_index
+        # m_i from the paper, where m_i = the sum of the edge features of the edges adjacent to node i. Shape: (n_nodes * batch_size) x hidden_dim
         agg = unsorted_segment_sum(data=edge_attr, segment_ids=row, num_segments=h.size(0))
         out = torch.cat([h, agg], dim=1)
-        out = self.node_mlp(out)
+        out = self.node_mlp(out)  # phi_h from the paper. Shape: (n_nodes * batch_size) x output_nf
         if self.recurrent:
             out = out + h
             # out = self.gru(out, h)
-        return out
-
+        return out  # Shape: (n_nodes * batch_size) x output_nf

 class GCL_rf(GCL_basic):
     """Graph Neural Net with global state and fixed number of nodes per graph.
     Args:
-        hidden_dim: Number of hidden units.
-        num_nodes: Maximum number of nodes (for self-attentive pooling).
-        global_agg: Global aggregation function ('attn' or 'sum').
-        temp: Softmax temperature.
+        `hidden_dim`: Number of hidden units.
+        `num_nodes`: Maximum number of nodes (for self-attentive pooling).
+        `global_agg`: Global aggregation function ('attn' or 'sum').
+        `temp`: Softmax temperature.
     """

     def __init__(self, nf=64, edge_attr_nf=0, reg=0, act_fn=nn.LeakyReLU(0.2), clamp=False):
@@ -148,13 +178,15 @@ def node_model(self, x, edge_index, edge_attr):
         return x_out

+# Equivariant graph convolutional layer,
+# based on equations (3)-(6) of https://arxiv.org/pdf/2102.09844.pdf.
 class E_GCL(nn.Module):
     """Graph Neural Net with global state and fixed number of nodes per graph.
     Args:
-        hidden_dim: Number of hidden units.
-        num_nodes: Maximum number of nodes (for self-attentive pooling).
-        global_agg: Global aggregation function ('attn' or 'sum').
-        temp: Softmax temperature.
+        `hidden_dim`: Number of hidden units.
+        `num_nodes`: Maximum number of nodes (for self-attentive pooling).
+        `global_agg`: Global aggregation function ('attn' or 'sum').
+        `temp`: Softmax temperature.
     """

     def __init__(
@@ -198,7 +230,7 @@ def __init__(
             nn.Linear(hidden_dim + input_nf + nodes_att_dim, hidden_dim), act_fn, nn.Linear(hidden_dim, output_nf)
         )

-        layer = nn.Linear(hidden_dim, num_vectors_in * num_vectors_out, bias=False)
+        layer = nn.Linear(hidden_dim, num_vectors_in * num_vectors_out, bias=False)  # outputs one weight per (input vector, output vector) channel pair
         torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

         self.clamp = clamp
@@ -218,50 +250,90 @@ def __init__(
         # self.gru = nn.GRUCell(hidden_dim, hidden_dim)

     def edge_model(self, source, target, radial, edge_attr):
+        """
+        Returns the matrix m from eqn. (3) of the paper, of shape (n_edges * batch_size) x hidden_dim.
+
+        Args:
+            `source`: Embeddings of the nodes at the start of each edge. Shape: (n_edges * batch_size) x input_nf
+            `target`: Embeddings of the nodes at the end of each edge. Shape: (n_edges * batch_size) x input_nf
+            `radial`: Squared distances between the coordinates of adjacent nodes. Shape: (n_edges * batch_size) x 1
+            `edge_attr`: Attributes of the edges. Shape: (n_edges * batch_size) x edge_attr_dim
+        """
         if edge_attr is None:  # Unused.
             out = torch.cat([source, target, radial], dim=1)
         else:
-            out = torch.cat([source, target, radial, edge_attr], dim=1)
-        out = self.edge_mlp(out)
+            out = torch.cat([source, target, radial, edge_attr], dim=1)  # concatenates the inputs to be passed into phi_e
+        out = self.edge_mlp(out)  # phi_e from eqn. (3). Shape: (n_edges * batch_size) x hidden_dim
         if self.attention:
             att_val = self.att_mlp(out)
             out = out * att_val
-        return out
-
-    def node_model(self, x, edge_index, edge_attr, node_attr):
-        row, col = edge_index
-        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
+        return out  # Shape: (n_edges * batch_size) x hidden_dim
+
+    def node_model(self, h, edge_index, edge_attr, node_attr):
+        """
+        Returns a tuple of the updated node embeddings h from eqn. (6) and the concatenated phi_h input from eqn. (5).
+        Shapes: (n_nodes * batch_size) x output_nf and (n_nodes * batch_size) x (input_nf + hidden_dim)
+
+        Args:
+            `h`: Node feature embeddings. Shape: (n_nodes * batch_size) x input_nf
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `edge_attr`: Attributes of the edges; the matrix m from eqn. (3) (the output of edge_model). Shape: (n_edges * batch_size) x hidden_dim
+            `node_attr`: Optional extra node attributes. Shape: (n_nodes * batch_size) x node_attr_dim
+        """
+        row, col = edge_index  # indices of adjacent nodes
+        agg = unsorted_segment_sum(edge_attr, row, num_segments=h.size(0))  # (n_nodes * batch_size) x hidden_dim; m_i from the paper
         if node_attr is not None:
-            agg = torch.cat([x, agg, node_attr], dim=1)
+            agg = torch.cat([h, agg, node_attr], dim=1)
         else:
-            agg = torch.cat([x, agg], dim=1)
-        out = self.node_mlp(agg)
+            agg = torch.cat([h, agg], dim=1)  # concatenated input for phi_h. (n_nodes * batch_size) x (input_nf + hidden_dim)
+        # phi_h from eqn. (6); updates the node feature embeddings. Shape: (n_nodes * batch_size) x output_nf
+        out = self.node_mlp(agg)
         if self.recurrent:
-            out = x + out
-        return out, agg
+            out = h + out
+        return out, agg  # Shapes: (n_nodes * batch_size) x output_nf, (n_nodes * batch_size) x (input_nf + hidden_dim)

     def coord_model(self, coord, edge_index, coord_diff, radial, edge_feat):
-        row, col = edge_index
-        coord_matrix = self.coord_mlp(edge_feat).view(-1, self.num_vectors_in, self.num_vectors_out)
+        """
+        Returns the updated coordinate embeddings from eqn. (4). Shape: (n_nodes * batch_size) x coord_dim x num_vectors_out
+
+        Args:
+            `coord`: Coordinates of the nodes. Shape: (n_nodes * batch_size) x coord_dim
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `coord_diff`: Differences between the coordinates of adjacent nodes. Shape: (n_edges * batch_size) x coord_dim
+            `radial`: Squared distances between the coordinates of adjacent nodes. Shape: (n_edges * batch_size) x 1
+            `edge_feat`: The matrix m from eqn. (3). Shape: (n_edges * batch_size) x hidden_dim
+        """
+        row, col = edge_index  # indices of adjacent nodes
+        # phi_x(m_ij) from eqn. (4). Shape: (n_edges * batch_size) x num_vectors_in x num_vectors_out
+        coord_matrix = self.coord_mlp(edge_feat).view(-1, self.num_vectors_in, self.num_vectors_out)
         if coord_diff.dim() == 2:
             coord_diff = coord_diff.unsqueeze(2)
-        coord = coord.unsqueeze(2).repeat(1, 1, self.num_vectors_out)
+        coord = coord.unsqueeze(2).repeat(1, 1, self.num_vectors_out)  # (n_nodes * batch_size) x coord_dim x num_vectors_out
         # coord_diff = coord_diff / radial.unsqueeze(1)
-        trans = torch.einsum("bij,bci->bcj", coord_matrix, coord_diff)
+        trans = torch.einsum("bij,bci->bcj", coord_matrix, coord_diff)  # (n_edges * batch_size) x coord_dim x num_vectors_out
         trans = torch.clamp(
             trans, min=-100, max=100
-        )  # This is never activated but just in case it case it explosed it may save the train
+        )  # This is rarely active, but the clamp may save training if the values explode.
+        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))  # (n_nodes * batch_size) x coord_dim x num_vectors_out; the aggregation from eqn. (4)
         if self.last_layer:
             coord = coord.mean(dim=2, keepdim=True) + agg * self.coords_weight
         else:
-            coord += agg * self.coords_weight
-        return coord
+            coord += agg * self.coords_weight  # update the coordinate embeddings following eqn. (4)
+        return coord

     def coord2radial(self, edge_index, coord):
-        row, col = edge_index
-        coord_diff = coord[row] - coord[col]
-        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
+        """
+        Returns a tuple of the squared distances and the differences between the coordinates of adjacent vertices.
+        Shapes: (n_edges * batch_size) x 1 and (n_edges * batch_size) x coord_dim
+
+        Args:
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `coord`: Coordinates of the nodes. Shape: (n_nodes * batch_size) x coord_dim
+        """
+        row, col = edge_index  # indices of adjacent nodes
+        coord_diff = coord[row] - coord[col]  # differences between the coords of adjacent nodes. Shape: (n_edges * batch_size) x coord_dim
+        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)  # squared distances. Shape: (n_edges * batch_size) x 1

         if self.norm_diff:
             norm = torch.sqrt(radial) + 1
@@ -270,20 +342,29 @@ def coord2radial(self, edge_index, coord):
         if radial.dim() == 3:
             radial = radial.squeeze(1)

-        return radial, coord_diff
+        return radial, coord_diff  # squared distances and coordinate differences

     def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
-        row, col = edge_index
-        radial, coord_diff = self.coord2radial(edge_index, coord)
-
-        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
-        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat)
+        """
+        Based on equations (3)-(6) in https://arxiv.org/pdf/2102.09844.pdf.
+        Updates the node feature and coordinate embeddings.
+
+        Args:
+            `h`: Node feature embeddings. Shape: (n_nodes * batch_size) x hidden_dim
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `coord`: Node coordinates. Shape: (n_nodes * batch_size) x coord_dim
+        """
+        row, col = edge_index  # indices of adjacent nodes
+        # squared distances and differences: (n_edges * batch_size) x 1, (n_edges * batch_size) x coord_dim
+        radial, coord_diff = self.coord2radial(edge_index, coord)
+        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)  # (n_edges * batch_size) x hidden_dim
+        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat)  # updated coord embeddings from eqn. (4). Shape: (n_nodes * batch_size) x coord_dim x 1
         h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
         # coord = self.node_coord_model(h, coord)
         # x = self.node_model(x, edge_index, x[col], u, batch)  # GCN
         return h, coord, edge_attr

-
+# Based on section 3.2 of https://arxiv.org/pdf/2102.09844.pdf.
 class E_GCL_vel(E_GCL):
     """Graph Neural Net with global state and fixed number of nodes per graph.
     Args:
@@ -333,17 +414,28 @@ def __init__(
         )

     def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
-        row, col = edge_index
+        """
+        Based on section 3.2 of https://arxiv.org/pdf/2102.09844.pdf.
+        Updates the node feature, coordinate, and velocity embeddings.
+
+        Args:
+            `h`: Node feature embeddings. Shape: (n_nodes * batch_size) x hidden_dim
+            `edge_index`: Length-2 list of index tensors of adjacent nodes, each of shape (n_edges * batch_size)
+            `coord`: Node coordinates. Shape: (n_nodes * batch_size) x coord_dim
+            `vel`: Node velocities. Shape: (n_nodes * batch_size) x vel_dim
+        """
+        row, col = edge_index  # indices of adjacent nodes
+        # squared distances and differences: (n_edges * batch_size) x 1, (n_edges * batch_size) x coord_dim
         radial, coord_diff = self.coord2radial(edge_index, coord)
-        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
-        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat)
-
-        coord_vel_matrix = self.coord_mlp_vel(h).view(-1, self.num_vectors_in, self.num_vectors_out)
+        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)  # (n_edges * batch_size) x hidden_dim
+        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat)  # updated coord embeddings from eqn. (4). Shape: (n_nodes * batch_size) x coord_dim x 1
+        # phi_v from eqn. (7). Shape: (n_nodes * batch_size) x (num_vectors_in * num_vectors_out)
+        coord_vel_matrix = self.coord_mlp_vel(h).view(-1, self.num_vectors_in, self.num_vectors_out)
         if vel.dim() == 2:
             vel = vel.unsqueeze(2)
-        coord += torch.einsum("bij,bci->bcj", coord_vel_matrix, vel)
-        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
+        coord += torch.einsum("bij,bci->bcj", coord_vel_matrix, vel)  # eqn. (7)
+        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)  # updates the node embeddings
         # coord = self.node_coord_model(h, coord)
         # x = self.node_model(x, edge_index, x[col], u, batch)  # GCN
         return h, coord, edge_attr
@@ -394,7 +486,7 @@ def node_model(self, x, edge_index, edge_m):

 def unsorted_segment_sum(data, segment_ids, num_segments):
     """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`."""
-    result_shape = (num_segments, data.size(1))
+    result_shape = (num_segments, data.size(1))
     result = data.new_full(result_shape, 0)  # Init empty result tensor.
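+    # Illustrative example (hypothetical values, not from the original code):
+    # data = [[1.], [2.], [3.]], segment_ids = [0, 0, 1], num_segments = 2
+    # sums rows by segment id  ->  result = [[3.], [3.]]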
     segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
     result.scatter_add_(0, segment_ids, data)
diff --git a/canonical_network/prepare/nbody_data.py b/canonical_network/prepare/nbody_data.py
index 06cb1b0..bb8b686 100644
--- a/canonical_network/prepare/nbody_data.py
+++ b/canonical_network/prepare/nbody_data.py
@@ -42,21 +42,30 @@ def preprocess(self, loc, vel, edges, charges):
         # cast to torch and swap n_nodes <--> n_features dimensions
         loc, vel = torch.Tensor(loc).transpose(2, 3), torch.Tensor(vel).transpose(2, 3)
         n_nodes = loc.size(2)
-        loc = loc[0:self.max_samples, :, :, :]  # limit number of samples
-        vel = vel[0:self.max_samples, :, :, :]  # speed when starting the trajectory
-        charges = charges[0:self.max_samples]
+        loc = loc[0:self.max_samples, :, :, :]  # limit number of samples; max_samples x 49 x 5 x 3
+        vel = vel[0:self.max_samples, :, :, :]  # speed when starting the trajectory; max_samples x 49 x 5 x 3
+        charges = charges[0:self.max_samples]  # max_samples x 5 x 1
         edge_attr = []

+        # edges is currently 10000 x 5 x 5; each M = edges[i, :, :] appears to be a symmetric matrix
+        # with M[j][k] = charges[i][j] * charges[i][k].
         # Initialize edges and edge_attributes
         rows, cols = [], []
         for i in range(n_nodes):
             for j in range(n_nodes):
+                # For i != j, append charge(node_i) * charge(node_j) for all kept samples,
+                # and record the (row, col) combination.
                 if i != j:
-                    edge_attr.append(edges[:, i, j])
+                    # The original edge_attr.append(edges[:, i, j]) kept all 10000 samples;
+                    # slicing to max_samples drops the unneeded rows and saves memory.
+                    edge_attr.append(edges[:self.max_samples, i, j])
                     rows.append(i)
                     cols.append(j)
+        # After the loop, edge_attr lists the charge products between all pairs of distinct nodes
+        # (20 x max_samples, i.e. one product per edge),
+        # where edge_attr[i] is the product of the charges of nodes rows[i] and cols[i].
         edges = [rows, cols]
-        edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)  # swap n_nodes <--> batch_size and add nf dimension
+        edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)  # swap n_nodes <--> batch_size and add nf dimension -> max_samples x 20 x 1

         return torch.Tensor(loc), torch.Tensor(vel), torch.Tensor(edge_attr), edges, torch.Tensor(charges)
diff --git a/canonical_network/train_nbody.py b/canonical_network/train_nbody.py
index 5e1fc3a..9a3a307 100644
--- a/canonical_network/train_nbody.py
+++ b/canonical_network/train_nbody.py
@@ -3,15 +3,28 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 import wandb
 import os
+import torch

 from canonical_network.prepare.nbody_data import NBodyDataModule
 from canonical_network.models.euclideangraph_model import NBODY_HYPERPARAMS, EuclideanGraphModel
-from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets
+from canonical_network.models.euclideangraph_base_models import EGNN_vel, GNN, VNDeepSets, Transformer
+
+# Change the model here.
+HYPERPARAMS = {"model": "Transformer",
+               "canon_model_type": "vndeepsets",
+               "pred_model_type": "GNN",
+               "batch_size": 100,
+               "dryrun": False,
+               "use_wandb": False,
+               "checkpoint": False,
+               "num_epochs": 10000,
+               "num_workers": 12,
+               "auto_tune": False,
+               "seed": 0}

-HYPERPARAMS = {"model": "vndeepsets", "canon_model_type": "vndeepsets", "pred_model_type": "GNN", "batch_size": 100, "dryrun": False, "use_wandb": False, "checkpoint": False, "num_epochs": 1000, "num_workers":0, "auto_tune":False, "seed": 0}

 def train_nbody():
-    hyperparams = HYPERPARAMS | NBODY_HYPERPARAMS
+    hyperparams = HYPERPARAMS | NBODY_HYPERPARAMS  # merges the dicts; on duplicate keys, the NBODY_HYPERPARAMS value wins

     if not hyperparams["use_wandb"]:
         print('Wandb disabled for logging.')
@@ -21,14 +34,16 @@
         os.environ["WANDB_MODE"] = "online"
         wandb.login()

-    wandb.init(config=hyperparams, entity="", project="canonical_network-nbody")
-    wandb_logger = WandbLogger(project="canonical_network-nbody")
+    wandb.init(config=hyperparams, entity="symmetry_group", project="canonical_network-nbody-transformer")
+    wandb_logger = WandbLogger(project="canonical_network-nbody-transformer")
     hyperparams = wandb.config
+    # wandb.config is passed to the model as its hyperparameters;
+    # its values can now be accessed with the dot operator,
+    # e.g. hyperparams.hidden_dim
     nbody_hyperparams = hyperparams

     pl.seed_everything(nbody_hyperparams.seed)
-
     nbody_data = NBodyDataModule(nbody_hyperparams)

     checkpoint_callback = ModelCheckpoint(dirpath="canonical_network/results/nbody/model_saves", filename=nbody_hyperparams.model + "_" + wandb.run.name + "_{epoch}_{valid/loss:.3f}", monitor="valid/loss", mode="min")
@@ -36,16 +51,21 @@
     early_stop_lr_callback = EarlyStopping(monitor="lr", min_delta=0.0, patience=10000, verbose=True, mode="min", stopping_threshold=1.1e-6)
     callbacks = [checkpoint_callback, early_stop_lr_callback, early_stop_metric_callback] if nbody_hyperparams.checkpoint else [early_stop_lr_callback, early_stop_metric_callback]

-    model = {"euclideangraph_model": lambda: EuclideanGraphModel(nbody_hypeyparams), "EGNN": lambda: EGNN_vel(nbody_hypeyparams), "GNN": lambda: GNN(nbody_hypeyparams), "vndeepsets": lambda: VNDeepSets(nbody_hypeyparams)}[nbody_hypeyparams.model]()
+    # Instantiates the chosen model with the merged hyperparams.
+    model = {"euclideangraph_model": lambda: EuclideanGraphModel(nbody_hyperparams),
+             "EGNN": lambda: EGNN_vel(nbody_hyperparams),
+             "GNN": lambda: GNN(nbody_hyperparams),
+             "vndeepsets": lambda: VNDeepSets(nbody_hyperparams),
+             "Transformer": lambda: Transformer(nbody_hyperparams),
+             }[nbody_hyperparams.model]()

     if nbody_hyperparams.auto_tune:
-        trainer = pl.Trainer(fast_dev_run=nbody_hypeyparams.dryrun, max_epochs=nbody_hypeyparams.num_epochs, accelerator="auto", auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, callbacks=callbacks, deterministic=False)
+        trainer = pl.Trainer(fast_dev_run=nbody_hyperparams.dryrun, max_epochs=nbody_hyperparams.num_epochs, accelerator="auto", auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, callbacks=callbacks, deterministic=False, log_every_n_steps=30)
         trainer.tune(model, datamodule=nbody_data, enable_checkpointing=nbody_hyperparams.checkpoint)
     elif nbody_hyperparams.dryrun:
-        trainer = pl.Trainer(fast_dev_run=False, max_epochs=2, accelerator="auto", limit_train_batches=10, limit_val_batches=10, logger=wandb_logger, callbacks=callbacks, deterministic=False, enable_checkpointing=nbody_hypeyparams.checkpoint)
+        trainer = pl.Trainer(fast_dev_run=False, max_epochs=2, accelerator="auto", limit_train_batches=10, limit_val_batches=10, logger=wandb_logger, callbacks=callbacks, deterministic=False, enable_checkpointing=nbody_hyperparams.checkpoint, log_every_n_steps=30)
     else:
-        trainer = pl.Trainer(fast_dev_run=nbody_hypeyparams.dryrun, max_epochs=nbody_hypeyparams.num_epochs, accelerator="auto", logger=wandb_logger, callbacks=callbacks, deterministic=False, enable_checkpointing=nbody_hypeyparams.checkpoint)
-
+        trainer = pl.Trainer(fast_dev_run=nbody_hyperparams.dryrun, max_epochs=nbody_hyperparams.num_epochs, accelerator="auto", logger=wandb_logger, callbacks=callbacks, deterministic=False, enable_checkpointing=nbody_hyperparams.checkpoint, log_every_n_steps=30)
     trainer.fit(model, datamodule=nbody_data)
diff --git a/pyproject.toml b/pyproject.toml
index cd1097f..3b4e6b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,8 +7,7 @@ version = "0.1.0"
 [tool.poetry.dependencies]
 einops = "^0.4.1"
 python = "^3.9"
-torch = "1.10.1"
-torch-scatter = "^2.0.9"
+torch = "2.1.2"
 numpy = "^1.22.4"
 pytorch-lightning = "^1.6.4"
 torchmetrics = "^0.9.1"
diff --git a/setup.md b/setup.md
new file mode 100644
index 0000000..b65b440
--- /dev/null
+++ b/setup.md
@@ -0,0 +1,46 @@
+# Setup Guide
+**Note:** The following issues occurred on the Windows Subsystem for Linux (WSL). I was successful in following these instructions in the exact order listed below; results may vary.
+
+1. Start by installing `pytorch3d`. To do this, execute the following commands. I recommend copying these instructions exactly, even if you already have `torchvision` and `cuda` installed, since `pytorch3d` is very picky about the versions of packages.
+```{console}
+conda create -n pytorch3d python=3.9
+conda activate pytorch3d
+conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
+conda install -c fvcore -c iopath -c conda-forge fvcore iopath
+```
+Next, run the following.
+```{console}
+conda install pytorch3d -c pytorch3d
+```
+
+Once this is done, a conda environment named `pytorch3d` will have been created; all the following commands were executed inside this environment.
+
+2. Next, you may run into some of the following errors:
+
+   - `ModuleNotFoundError: No module named 'canonical_network'`
+
+     To solve this, navigate to `/canonical_network` and run the command `pip install -e .`
+
+   - `ERROR: Could not find a version that satisfies the requirement torch==1.10.1 (from canonical-network) (from versions: 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2)
+     ERROR: No matching distribution found for torch==1.10.1`
+
+     To solve this, open `pyproject.toml` and change the torch version to the one currently installed.
+
+   - `ERROR: Failed building wheel for torch-scatter`
+
+     To solve this, delete the torch-scatter line in `pyproject.toml`.
+
+3. Next, install the required packages.
+```{console}
+conda install lightning -c conda-forge
+pip install wandb
+conda install conda-forge::kornia
+conda install conda-forge::torch-scatter
+```
+
+4. Finally, upload the dataset directory `canonical_network/canonical_network/data/n_body_system/dataset` (it is not currently in the GitHub repo).
+
+5. (Optional) Rename the environment.
+```
+conda rename -n pytorch3d new-env-name
+```
\ No newline at end of file
diff --git a/sweep.yaml b/sweep.yaml
new file mode 100644
index 0000000..b06bcd4
--- /dev/null
+++ b/sweep.yaml
@@ -0,0 +1,12 @@
+program: /home/jikael/mcgill/comp396/canonical_network/canonical_network/canonical_network/train_nbody.py
+entity: symmetry_group
+method: bayes
+metric:
+  goal: minimize
+  name: valid/loss.min
+name: Transformer Optimization 3
+parameters:
+  learning_rate:
+    distribution: log_uniform_values
+    min: 1e-4
+    max: 1e-3
\ No newline at end of file
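A sweep defined by `sweep.yaml` is typically launched with the standard wandb CLI; the `<sweep_id>` placeholder below is printed by the first command (the entity and project are taken from the sweep config and your wandb settings):
```{console}
wandb sweep sweep.yaml
wandb agent <entity>/<project>/<sweep_id>
```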