Commit

factorized autoencoder.
dreamer2368 committed Oct 23, 2024
1 parent d484837 commit 0daa0b1
Showing 1 changed file with 112 additions and 122 deletions.
234 changes: 112 additions & 122 deletions src/lasdi/latent_space.py
@@ -1,6 +1,32 @@
import torch
import numpy as np

+# activation dict
+act_dict = {'ELU': torch.nn.ELU,
+            'hardshrink': torch.nn.Hardshrink,
+            'hardsigmoid': torch.nn.Hardsigmoid,
+            'hardtanh': torch.nn.Hardtanh,
+            'hardswish': torch.nn.Hardswish,
+            'leakyReLU': torch.nn.LeakyReLU,
+            'logsigmoid': torch.nn.LogSigmoid,
+            'multihead': torch.nn.MultiheadAttention,
+            'PReLU': torch.nn.PReLU,
+            'ReLU': torch.nn.ReLU,
+            'ReLU6': torch.nn.ReLU6,
+            'RReLU': torch.nn.RReLU,
+            'SELU': torch.nn.SELU,
+            'CELU': torch.nn.CELU,
+            'GELU': torch.nn.GELU,
+            'sigmoid': torch.nn.Sigmoid,
+            'SiLU': torch.nn.SiLU,
+            'mish': torch.nn.Mish,
+            'softplus': torch.nn.Softplus,
+            'softshrink': torch.nn.Softshrink,
+            'tanh': torch.nn.Tanh,
+            'tanhshrink': torch.nn.Tanhshrink,
+            'threshold': torch.nn.Threshold,
+            }

def initial_condition_latent(param_grid, physics, autoencoder):

    '''
@@ -23,157 +49,121 @@ def initial_condition_latent(param_grid, physics, autoencoder):
        Z0.append(z0)

    return Z0

-class Autoencoder(torch.nn.Module):
-    # set by physics.qgrid_size
-    qgrid_size = []
-    # prod(qgrid_size)
-    space_dim = -1
-    n_z = -1

-    # activation dict
-    act_dict = {'ELU': torch.nn.ELU,
-                'hardshrink': torch.nn.Hardshrink,
-                'hardsigmoid': torch.nn.Hardsigmoid,
-                'hardtanh': torch.nn.Hardtanh,
-                'hardswish': torch.nn.Hardswish,
-                'leakyReLU': torch.nn.LeakyReLU,
-                'logsigmoid': torch.nn.LogSigmoid,
-                'multihead': torch.nn.MultiheadAttention,
-                'PReLU': torch.nn.PReLU,
-                'ReLU': torch.nn.ReLU,
-                'ReLU6': torch.nn.ReLU6,
-                'RReLU': torch.nn.RReLU,
-                'SELU': torch.nn.SELU,
-                'CELU': torch.nn.CELU,
-                'GELU': torch.nn.GELU,
-                'sigmoid': torch.nn.Sigmoid,
-                'SiLU': torch.nn.SiLU,
-                'mish': torch.nn.Mish,
-                'softplus': torch.nn.Softplus,
-                'softshrink': torch.nn.Softshrink,
-                'tanh': torch.nn.Tanh,
-                'tanhshrink': torch.nn.Tanhshrink,
-                'threshold': torch.nn.Threshold,
-                }

-    def __init__(self, physics, config):
-        super(Autoencoder, self).__init__()

-        self.qgrid_size = physics.qgrid_size
-        self.space_dim = np.prod(self.qgrid_size)
-        hidden_units = config['hidden_units']
-        n_z = config['latent_dimension']
-        self.n_z = n_z
+class MultiLayerPerceptron(torch.nn.Module):

-        n_layers = len(hidden_units)
-        self.n_layers = n_layers
+    def __init__(self, layer_sizes,
+                 act_type='sigmoid', reshape_index=None, reshape_shape=None,
+                 threshold=0.1, value=0.0, num_heads=1):
+        super(MultiLayerPerceptron, self).__init__()

-        fc1_e = torch.nn.Linear(self.space_dim, hidden_units[0])
-        torch.nn.init.xavier_uniform_(fc1_e.weight)
-        self.fc1_e = fc1_e
+        # including input, hidden, output layers
+        self.n_layers = len(layer_sizes)
+        self.layer_sizes = layer_sizes

-        if n_layers > 1:
-            for i in range(n_layers - 1):
-                fc_e = torch.nn.Linear(hidden_units[i], hidden_units[i + 1])
-                torch.nn.init.xavier_uniform_(fc_e.weight)
-                setattr(self, 'fc' + str(i + 2) + '_e', fc_e)
+        # Linear features between layers
+        self.fcs = []
+        for k in range(self.n_layers-1):
+            self.fcs += [torch.nn.Linear(layer_sizes[k], layer_sizes[k + 1])]
+        self.fcs = torch.nn.ModuleList(self.fcs)

-        fc_e = torch.nn.Linear(hidden_units[-1], n_z)
-        torch.nn.init.xavier_uniform_(fc_e.weight)
-        setattr(self, 'fc' + str(n_layers + 1) + '_e', fc_e)
+        # Reshape input or output layer
+        assert((reshape_index is None) or (reshape_index in [0, -1]))
+        assert((reshape_shape is None) or (np.prod(reshape_shape) == layer_sizes[reshape_index]))
+        self.reshape_index = reshape_index
+        self.reshape_shape = reshape_shape

-        act_type = config['activation'] if 'activation' in config else 'sigmoid'
+        # Initialize activation function
+        self.act_type = act_type
+        self.use_multihead = False
        if act_type == "threshold":
-            #grab relevant initialization values from config
-            threshold = config["threshold"] if "threshold" in config else 0.1
-            value = config["value"] if "value" in config else 0.0
-            self.g_e = self.act_dict[act_type](threshold, value)
+            self.act = act_dict[act_type](threshold, value)

        elif act_type == "multihead":
-            #grab relevant initialization values from config
-            num_heads = config['num_heads'] if 'num_heads' in config else 1
-            if n_layers > 1:
-                for i in range(n_layers):
-                    setattr(self, 'a' + str(i + 1), self.act_dict[act_type](hidden_units[i], num_heads))
-            self.g_e = torch.nn.Identity() # No additional activation
+            self.use_multihead = True
+            if (self.n_layers > 3): # if you have more than one hidden layer
+                self.act = []
+                for i in range(self.n_layers-2):
+                    self.act += [act_dict[act_type](layer_sizes[i+1], num_heads)]
+            else:
+                self.act = [torch.nn.Identity()] # No additional activation
+            self.act = torch.nn.ModuleList(self.act) # register the attention (or identity) layers as submodules

        #all other activation functions initialized here
        else:
-            self.g_e = self.act_dict[act_type]()

-        fc1_d = torch.nn.Linear(n_z, hidden_units[-1])
-        torch.nn.init.xavier_uniform_(fc1_d.weight)
-        self.fc1_d = fc1_d

-        if n_layers > 1:
-            for i in range(n_layers - 1, 0, -1):
-                fc_d = torch.nn.Linear(hidden_units[i], hidden_units[i - 1])
-                torch.nn.init.xavier_uniform_(fc_d.weight)
-                setattr(self, 'fc' + str(n_layers - i + 1) + '_d', fc_d)

-        fc_d = torch.nn.Linear(hidden_units[0], self.space_dim)
-        torch.nn.init.xavier_uniform_(fc_d.weight)
-        setattr(self, 'fc' + str(n_layers + 1) + '_d', fc_d)



-    def encoder(self, x):
-        # make sure the input has a proper shape
-        assert(list(x.shape[-len(self.qgrid_size):]) == self.qgrid_size)
-        # we use torch.Tensor.view instead of torch.Tensor.reshape,
-        # in order to avoid data copying.
-        x = x.view(list(x.shape[:-len(self.qgrid_size)]) + [self.space_dim])

-        for i in range(1, self.n_layers + 1):
-            fc = getattr(self, 'fc' + str(i) + '_e')
-            x = fc(x) # apply linear layer
-            if hasattr(self, 'a1'): # test if there is at least one attention layer
+            self.act = act_dict[act_type]()
+        return

+    def forward(self, x):
+        if (self.reshape_index == 0):
+            # make sure the input has a proper shape
+            assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
+            # we use torch.Tensor.view instead of torch.Tensor.reshape,
+            # in order to avoid data copying.
+            x = x.view(list(x.shape[:-len(self.reshape_shape)]) + [self.layer_sizes[self.reshape_index]])

+        for i in range(self.n_layers-2):
+            x = self.fcs[i](x) # apply linear layer
+            if (self.use_multihead):
+                x = self.apply_attention(x, i) # apply attention at layer i
-            x = self.g_e(x) # apply activation function
+            else:
+                x = self.act(x)

-        fc = getattr(self, 'fc' + str(self.n_layers + 1) + '_e')
-        x = fc(x)

-        return x
+        x = self.fcs[-1](x)

+        if (self.reshape_index == -1):
+            # we use torch.Tensor.view instead of torch.Tensor.reshape,
+            # in order to avoid data copying.
+            x = x.view(list(x.shape[:-1]) + self.reshape_shape)

-    def decoder(self, x):
+        return x

+    def apply_attention(self, x, act_idx):
+        x = x.unsqueeze(1) # Add sequence dimension for attention
+        x, _ = self.act[act_idx](x, x, x) # apply attention
+        x = x.squeeze(1) # Remove sequence dimension
+        return x

+    def init_weight(self):
+        # TODO(kevin): support other initializations?
+        for fc in self.fcs:
+            torch.nn.init.xavier_uniform_(fc.weight)
+        return

-        for i in range(1, self.n_layers + 1):
-            fc = getattr(self, 'fc' + str(i) + '_d')
-            x = fc(x) # apply linear layer
-            if hasattr(self, 'a1'): # test if there is at least one attention layer
-                x = self.apply_attention(self, x, self.n_layers - i)
+class Autoencoder(torch.nn.Module):

-            x = self.g_e(x) # apply activation function
+    def __init__(self, physics, config):
+        super(Autoencoder, self).__init__()

-        fc = getattr(self, 'fc' + str(self.n_layers + 1) + '_d')
-        x = fc(x)
+        self.qgrid_size = physics.qgrid_size
+        self.space_dim = np.prod(self.qgrid_size)
+        hidden_units = config['hidden_units']
+        n_z = config['latent_dimension']
+        self.n_z = n_z

-        # we use torch.Tensor.view instead of torch.Tensor.reshape,
-        # in order to avoid data copying.
-        x = x.view(list(x.shape[:-1]) + self.qgrid_size)
+        layer_sizes = [self.space_dim] + hidden_units + [n_z]
+        #grab relevant initialization values from config
+        act_type = config['activation'] if 'activation' in config else 'sigmoid'
+        threshold = config["threshold"] if "threshold" in config else 0.1
+        value = config["value"] if "value" in config else 0.0
+        num_heads = config['num_heads'] if 'num_heads' in config else 1

-        return x
+        self.encoder = MultiLayerPerceptron(layer_sizes, act_type,
+            reshape_index=0, reshape_shape=self.qgrid_size,
+            threshold=threshold, value=value, num_heads=num_heads)

+        self.decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type,
+            reshape_index=-1, reshape_shape=self.qgrid_size,
+            threshold=threshold, value=value, num_heads=num_heads)

+        return

    def forward(self, x):

        x = self.encoder(x)
        x = self.decoder(x)

        return x


-    def apply_attention(self, x, layer):
-        x = x.unsqueeze(1) # Add sequence dimension for attention
-        a = getattr(self, 'a' + str(layer))
-        x, _ = a(x, x, x) # apply attention
-        x = x.squeeze(1) # Remove sequence dimension

-        return x

    def export(self):
        dict_ = {'autoencoder_param': self.cpu().state_dict()}
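
For reference, a minimal usage sketch of the new MultiLayerPerceptron on its own. The layer sizes, grid shape, and import path below are illustrative assumptions, not part of this commit:

import torch

from lasdi.latent_space import MultiLayerPerceptron

# Hypothetical sizes: a 10x10 grid flattened to 100 inputs, two hidden layers, 5 outputs.
mlp = MultiLayerPerceptron([100, 50, 20, 5], act_type='tanh',
                           reshape_index=0, reshape_shape=[10, 10])
mlp.init_weight()           # Xavier-initialize the linear layers

x = torch.randn(8, 10, 10)  # batch of 8 snapshots on the 10x10 grid
z = mlp(x)                  # forward() flattens the grid, then alternates linear and activation layers
print(z.shape)              # torch.Size([8, 5])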
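
Likewise, a rough end-to-end sketch of the refactored Autoencoder. The DummyPhysics stand-in and the config values are hypothetical; in LaSDI the physics object and config come from the surrounding workflow:

import torch

from lasdi.latent_space import Autoencoder

class DummyPhysics:
    # hypothetical stand-in exposing only the attribute Autoencoder reads here
    qgrid_size = [10, 10]

config = {'hidden_units': [50, 20], 'latent_dimension': 5, 'activation': 'tanh'}

autoencoder = Autoencoder(DummyPhysics(), config)
x = torch.randn(8, 10, 10)  # batch of 8 snapshots on the 10x10 grid
x_hat = autoencoder(x)      # encoder MLP -> 5 latent variables -> decoder MLP
print(x_hat.shape)          # torch.Size([8, 10, 10])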
