Tuned default args
Tuned default arguments for both LSTNet and SkipGRU.
Condensed arithmetic in SkipGRU.
Chase-Grajeda committed Jun 13, 2024
1 parent af0e1fb commit 360874e
Showing 2 changed files with 31 additions and 23 deletions.
12 changes: 8 additions & 4 deletions bayesflow/experimental/networks/lstnet/lstnet.py
@@ -21,16 +21,20 @@ class LSTNet(keras.Model):
 
     def __init__(
         self,
-        cnn_out: int,
+        cnn_out: int = 128,
         kernel_size: int = 4,
         kernel_initializer: str = "glorot_uniform",
         kernel_regularizer: regularizers.Regularizer | None = None,
         activation: str = "relu",
         gru_out: int = 64,
-        resnet_out: int = 32,
+        skip_outs: list[int] = [32],
         skip_steps: list[int] = [2],
+        resnet_out: int = 32,
         **kwargs
     ):
+        if len(skip_outs) != len(skip_steps):
+            raise ValueError("skip_outs must have same length as skip_steps")
+
         super().__init__(**keras_kwargs(kwargs))
 
         # Define model
@@ -43,13 +47,13 @@ def __init__(
             kernel_regularizer=kernel_regularizer
         )
         self.bnorm = layers.BatchNormalization()
-        self.skip_gru = SkipGRU(gru_out, skip_steps)
+        self.skip_gru = SkipGRU(gru_out, skip_outs, skip_steps)
         self.resnet = ResNet(width=resnet_out)
 
         # Aggregate layers. In: (batch, time steps, num series)
         self.model.add(self.conv1)    # -> (batch, reduced time steps, cnn_out)
         self.model.add(self.bnorm)    # -> (batch, reduced time steps, cnn_out)
-        self.model.add(self.skip_gru) # -> (batch, gru_out)
+        self.model.add(self.skip_gru) # -> (batch, _)
         self.model.add(self.resnet)   # -> (batch, resnet_out)
 
     def call(self, x: Tensor) -> Tensor:
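With this commit, every LSTNet constructor argument carries a default, so the summary network can be instantiated with no arguments. A minimal usage sketch (not part of the commit), assuming LSTNet is importable from the file path above and that inputs follow the (batch, time steps, num series) convention noted in the comments; all sizes below are made up:

import keras
from bayesflow.experimental.networks.lstnet.lstnet import LSTNet  # assumed import path

# All-default construction after this commit:
# cnn_out=128, kernel_size=4, gru_out=64, skip_outs=[32], skip_steps=[2], resnet_out=32
net = LSTNet()

# Hypothetical input: batch of 8 series, 64 time steps, 3 observed variables
x = keras.ops.zeros((8, 64, 3))
out = net(x)  # -> (batch, resnet_out) = (8, 32)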
42 changes: 23 additions & 19 deletions bayesflow/experimental/networks/lstnet/skip_gru.py
@@ -1,40 +1,44 @@
 import keras
 from keras.saving import register_keras_serializable
-from keras import layers, Sequential
+from keras import layers
 from bayesflow.experimental.types import Tensor
 from bayesflow.experimental.utils import keras_kwargs
 
 @register_keras_serializable(package="bayesflow.networks.skip_gru")
 class SkipGRU(keras.Model):
-    def __init__(self, gru_out: int, skip_steps: list[int], **kwargs):
+    """
+    Implements a Skip GRU layer as described in [1]
+    [1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
+    2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
+    Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
+    TODO: Add proper docstring
+    """
+    def __init__(self, gru_out: int, skip_outs: list[int], skip_steps: list[int], **kwargs):
         super().__init__(**keras_kwargs(kwargs))
         self.gru_out = gru_out
         self.skip_steps = skip_steps
         self.gru = layers.GRU(gru_out)
-        self.skip_grus = [layers.GRU(gru_out) for _ in range(len(self.skip_steps))]
+        self.skip_grus = [layers.GRU(skip_outs[i]) for i in range(len(self.skip_steps))]
 
     def call(self, x: Tensor) -> Tensor:
         # Standard GRU
-        gru = self.gru(x) # -> (batch, gru_out)
+        # In: (batch, reduced time steps, cnn_out)
+        sgru = self.gru(x) # -> (batch, gru_out)
 
         # Skip GRU
         for i, skip_step in enumerate(self.skip_steps):
             # Reshape, remove skipped time points
             skip_length = x.shape[1] // skip_step
-            s = x[:, -skip_length * skip_step:, :] # -> (batch, shrunk time steps, cnn_out)
-            s1 = keras.ops.reshape(s, (-1, s.shape[2], skip_length, skip_step)) # -> (batch, cnn_out, skip_length, skip_step)
-            s2 = keras.ops.transpose(s1, [0, 3, 2, 1]) # -> (batch, skip_step, skip_length, cnn_out)
-            s3 = keras.ops.reshape(s2, (-1, s2.shape[2], s2.shape[3])) # -> (batch * skip_step, skip_length, cnn_out)
-
-            # GRU on remaining data
-            s4 = self.skip_grus[i](s3) # -> (batch * skip_step, gru_out)
-            s5 = keras.ops.reshape(s4, (-1, skip_step * s4.shape[1])) # -> (batch, skip_step * gru_out)
+            s = x[:, -skip_length * skip_step:, :]                              # -> (batch, shrunk time steps, cnn_out)
+            s = keras.ops.reshape(s, (-1, s.shape[2], skip_length, skip_step))  # -> (batch, cnn_out, skip_length, skip_step)
+            s = keras.ops.transpose(s, [0, 3, 2, 1])                            # -> (batch, skip_step, skip_length, cnn_out)
+            s = keras.ops.reshape(s, (-1, s.shape[2], s.shape[3]))              # -> (batch * skip_step, skip_length, cnn_out)
 
-            # Concat
-            gru = keras.ops.concatenate([gru, s5], axis=1) # -> (batch, gru_out * skip_step * 2)
+            # Reapply GRU, add to working tensor
+            s = self.skip_grus[i](s)                                # -> (batch * skip_step, skip_outs[i])
+            s = keras.ops.reshape(s, (-1, skip_step * s.shape[1]))  # -> (batch, skip_step * skip_outs[i])
+            sgru = keras.ops.concatenate([sgru, s], axis=1)
 
-        return gru
+        return sgru
 
     def build(self, input_shape):
         self.call(keras.ops.zeros(input_shape))
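The condensed loop threads each reshape through a single working variable s in place of the former s1–s5 chain. The index arithmetic is easiest to check with concrete shapes; a standalone sketch (not part of the commit) using hypothetical sizes, batch 8, 30 time steps, cnn_out 128, and one skip_step of 4:

import keras

batch, T, C = 8, 30, 128      # hypothetical sizes
skip_step = 4
skip_length = T // skip_step  # 7 time points survive per skip offset

x = keras.ops.zeros((batch, T, C))
s = x[:, -skip_length * skip_step:, :]                              # (8, 28, 128): drop the 2 oldest steps
s = keras.ops.reshape(s, (-1, s.shape[2], skip_length, skip_step))  # (8, 128, 7, 4)
s = keras.ops.transpose(s, [0, 3, 2, 1])                            # (8, 4, 7, 128)
s = keras.ops.reshape(s, (-1, s.shape[2], s.shape[3]))              # (32, 7, 128)

# A GRU with skip_outs[i] units maps (32, 7, 128) to (32, skip_outs[i]);
# reshaping that to (8, 4 * skip_outs[i]) folds the skip offsets back into
# each original sample before concatenation onto sgru.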
