
Modularizing encoder/decoder in Autoencoder class #10

Merged (2 commits), Oct 26, 2024

Conversation

@dreamer2368 (Collaborator) commented on Oct 10, 2024

Using DistributedDataParallel for data parallelism requires access to the encoder and decoder as torch.nn.Module objects. The current Autoencoder class provides the encoder and decoder only as member functions, and DistributedDataParallel cannot use custom member functions other than forward.

  • lasdi.latent_space.MultiLayerPerceptron is now provided as a distinct module for a vanilla MLP.
  • lasdi.latent_space.Autoencoder simply contains two MultiLayerPerceptrons as encoder and decoder.
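
As an illustration of the DistributedDataParallel point, here is a minimal sketch of how the new sub-modules can be wrapped. The single-process 'gloo' group, the physics/config objects, and the batch u are assumptions made for the sake of the example, not part of this PR:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from lasdi.latent_space import Autoencoder

# Single-process 'gloo' group purely for illustration; a real run would be
# launched with torchrun, which supplies the rank and world size.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

# physics and config are assumed to be set up as elsewhere in LaSDI.
autoencoder = Autoencoder(physics, config)

# encoder and decoder are now torch.nn.Module objects (MultiLayerPerceptron),
# so DistributedDataParallel can wrap each one and call its forward() directly.
ddp_encoder = DDP(autoencoder.encoder)
ddp_decoder = DDP(autoencoder.decoder)

u = torch.rand([4, 1] + physics.qgrid_size)   # a small batch of solution snapshots
u_hat = ddp_decoder(ddp_encoder(u))           # round trip through the wrapped modules
```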

Per @punkduckable, we should implement multihead attention properly as a layer rather than as an activation function. Since this PR simply translates the current implementation, that change is tracked as issue #13.
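
For reference, the layer-based approach that issue #13 points toward would use torch.nn.MultiheadAttention directly, with explicit query/key/value inputs mapping a sequence to a sequence. A rough sketch with illustrative dimensions only:

```python
import torch

# illustrative dimensions only
batch, seq_len, embed_dim, num_heads = 2, 10, 16, 4

# attention as a standalone layer, not an entry in the activation dictionary
attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

queries = torch.rand(batch, seq_len, embed_dim)
context = torch.rand(batch, seq_len, embed_dim)   # keys/values, possibly from a different sequence

# a layer maps sequences to sequences and takes query/key/value explicitly,
# unlike an element-wise activation function
y, attn_weights = attn(queries, context, context)
print(y.shape)   # torch.Size([2, 10, 16])
```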

@dreamer2368 added the RFR (Ready for Review) label on Oct 10, 2024
@punkduckable (Collaborator) left a comment

Overall, the code looks quite good (I went through it in detail yesterday and agree with pretty much all of the changes). There is one thing I would like to see changed, however: multiheaded attention. This isn't really an activation function, even though torch's API classifies it as such; it implements the attention layer of a transformer, which is designed for mapping finite sequences to finite sequences. I do not think it makes sense in this context. To make this point even clearer, notice that the apply_attention function uses the input matrix, x, for the keys, queries, and values, which is strange to say the least. I think this should be removed from the MLP class. That would mean removing "multihead" from the activation dictionary, removing num_heads as an argument to the MLP initializer, and removing the apply_attention method. Below is a copy of latent_space.py with these changes implemented:

```python
import torch
import numpy as np

# activation dict
act_dict = {'ELU': torch.nn.ELU,
            'hardshrink': torch.nn.Hardshrink,
            'hardsigmoid': torch.nn.Hardsigmoid,
            'hardtanh': torch.nn.Hardtanh,
            'hardswish': torch.nn.Hardswish,
            'leakyReLU': torch.nn.LeakyReLU,
            'logsigmoid': torch.nn.LogSigmoid,
            'PReLU': torch.nn.PReLU,
            'ReLU': torch.nn.ReLU,
            'ReLU6': torch.nn.ReLU6,
            'RReLU': torch.nn.RReLU,
            'SELU': torch.nn.SELU,
            'CELU': torch.nn.CELU,
            'GELU': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
            'SiLU': torch.nn.SiLU,
            'mish': torch.nn.Mish,
            'softplus': torch.nn.Softplus,
            'softshrink': torch.nn.Softshrink,
            'tanh': torch.nn.Tanh,
            'tanhshrink': torch.nn.Tanhshrink,
            'threshold': torch.nn.Threshold,
            }

def initial_condition_latent(param_grid, physics, autoencoder):
    '''
    Outputs the initial condition in the latent space: Z0 = encoder(U0)
    '''
    n_param = param_grid.shape[0]
    Z0 = []

    sol_shape = [1, 1] + physics.qgrid_size

    for i in range(n_param):
        u0 = physics.initial_condition(param_grid[i])
        u0 = u0.reshape(sol_shape)
        u0 = torch.Tensor(u0)
        z0 = autoencoder.encoder(u0)
        z0 = z0[0, 0, :].detach().numpy()
        Z0.append(z0)

    return Z0

class MultiLayerPerceptron(torch.nn.Module):

    def __init__(self, layer_sizes,
                 act_type='sigmoid', reshape_index=None, reshape_shape=None,
                 threshold=0.1, value=0.0):
        super(MultiLayerPerceptron, self).__init__()

        # including input, hidden, output layers
        self.n_layers = len(layer_sizes)
        self.layer_sizes = layer_sizes

        # Linear features between layers
        self.fcs = []
        for k in range(self.n_layers - 1):
            self.fcs += [torch.nn.Linear(layer_sizes[k], layer_sizes[k + 1])]
        self.fcs = torch.nn.ModuleList(self.fcs)
        self.init_weight()

        # Reshape input or output layer
        assert((reshape_index is None) or (reshape_index in [0, -1]))
        assert((reshape_shape is None) or (np.prod(reshape_shape) == layer_sizes[reshape_index]))
        self.reshape_index = reshape_index
        self.reshape_shape = reshape_shape

        # Initialize activation function
        self.act_type = act_type
        if act_type == "threshold":
            self.act = act_dict[act_type](threshold, value)
        else:
            self.act = act_dict[act_type]()
        return

    def forward(self, x):
        if (self.reshape_index == 0):
            # make sure the input has a proper shape
            assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
            # we use torch.Tensor.view instead of torch.Tensor.reshape,
            # in order to avoid data copying.
            x = x.view(list(x.shape[:-len(self.reshape_shape)]) + [self.layer_sizes[self.reshape_index]])

        for i in range(self.n_layers - 2):
            x = self.fcs[i](x)  # apply linear layer
            x = self.act(x)     # apply activation function

        x = self.fcs[-1](x)

        if (self.reshape_index == -1):
            # we use torch.Tensor.view instead of torch.Tensor.reshape,
            # in order to avoid data copying.
            x = x.view(list(x.shape[:-1]) + self.reshape_shape)

        return x

    def init_weight(self):
        # TODO(kevin): support other initializations?
        for fc in self.fcs:
            torch.nn.init.xavier_uniform_(fc.weight)
        return

class Autoencoder(torch.nn.Module):

    def __init__(self, physics, config):
        super(Autoencoder, self).__init__()

        self.qgrid_size = physics.qgrid_size
        self.space_dim = np.prod(self.qgrid_size)
        hidden_units = config['hidden_units']
        n_z = config['latent_dimension']
        self.n_z = n_z

        layer_sizes = [self.space_dim] + hidden_units + [n_z]
        # grab relevant initialization values from config
        act_type = config['activation'] if 'activation' in config else 'sigmoid'
        threshold = config["threshold"] if "threshold" in config else 0.1
        value = config["value"] if "value" in config else 0.0

        self.encoder = MultiLayerPerceptron(layer_sizes, act_type,
                                            reshape_index=0, reshape_shape=self.qgrid_size,
                                            threshold=threshold, value=value)

        self.decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type,
                                            reshape_index=-1, reshape_shape=self.qgrid_size,
                                            threshold=threshold, value=value)

        return

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def export(self):
        dict_ = {'autoencoder_param': self.cpu().state_dict()}
        return dict_

    def load(self, dict_):
        self.load_state_dict(dict_['autoencoder_param'])
        return
```
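
For anyone trying out the classes above, here is a minimal usage sketch; the PhysicsStub class and the config values are illustrative stand-ins for a real LaSDI physics object and configuration, not part of the package:

```python
import numpy as np
import torch

class PhysicsStub:
    # hypothetical stand-in exposing only what Autoencoder needs
    qgrid_size = [32]                          # solution grid shape (a list)

    def initial_condition(self, param):
        return np.zeros(self.qgrid_size)       # dummy initial condition

config = {'hidden_units': [64, 32],            # illustrative layer widths
          'latent_dimension': 5,
          'activation': 'tanh'}

autoencoder = Autoencoder(PhysicsStub(), config)

u = torch.rand([1, 1] + PhysicsStub.qgrid_size)   # a single snapshot, shaped like sol_shape above
u_hat = autoencoder(u)                            # encode to 5 latent variables, then decode
print(u_hat.shape)                                # torch.Size([1, 1, 32])
```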

@dreamer2368 (Collaborator, Author) commented:

> Overall, the code looks quite good … I think this should be removed from the MLP class. … Below is a copy of latent_space.py with these changes implemented:

@punkduckable, thanks for implementing a new version of this. As I noted in the PR description, this PR simply translates the current implementation. While we could add the change right here, I suggest making it a future PR to avoid code conflicts with the upcoming PRs #11 through #16. For the record, I have also filed this as issue #13.

@dreamer2368 (Collaborator, Author) commented:

> Overall, the code looks quite good … I think this should be removed from the MLP class. … Below is a copy of latent_space.py with these changes implemented:

@punkduckable, I saw that you already implemented this in the later PR #15. It doesn't make sense to duplicate here a feature that will be merged later anyway.

@dreamer2368 dismissed punkduckable's stale review on October 24, 2024 (17:48):

The requested change is already implemented in another PR #15.

@punkduckable (Collaborator) left a comment

Looks good to me! We will make the attention changes in PR 15.

@dreamer2368 merged commit 4932179 into main on Oct 26, 2024.