[ENH] xLSTMTime implementation #1709

Open
wants to merge 3 commits into base: main
Empty file.
Empty file.
105 changes: 105 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import math


class mLSTMCell(nn.Module):
    """Matrix LSTM (mLSTM) cell used by the xLSTMTime model."""

    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None):
        super(mLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layer_norm = layer_norm

        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # query/key/value projections
        self.Wq = nn.Linear(input_size, hidden_size)
        self.Wk = nn.Linear(input_size, hidden_size)
        self.Wv = nn.Linear(input_size, hidden_size)

        # input, forget, and output gate projections
        self.Wi = nn.Linear(input_size, hidden_size)
        self.Wf = nn.Linear(input_size, hidden_size)
        self.Wo = nn.Linear(input_size, hidden_size)

        self.Wq.to(self.device)
        self.Wk.to(self.device)
        self.Wv.to(self.device)
        self.Wi.to(self.device)
        self.Wf.to(self.device)
        self.Wo.to(self.device)

        self.dropout = nn.Dropout(dropout)
        self.dropout.to(self.device)

        if layer_norm:
            self.ln_q = nn.LayerNorm(hidden_size)
            self.ln_k = nn.LayerNorm(hidden_size)
            self.ln_v = nn.LayerNorm(hidden_size)
            self.ln_i = nn.LayerNorm(hidden_size)
            self.ln_f = nn.LayerNorm(hidden_size)
            self.ln_o = nn.LayerNorm(hidden_size)

            self.ln_q.to(self.device)
            self.ln_k.to(self.device)
            self.ln_v.to(self.device)
            self.ln_i.to(self.device)
            self.ln_f.to(self.device)
            self.ln_o.to(self.device)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, h_prev, c_prev, n_prev):
        """Single-step forward pass; all inputs are (batch_size, ...) tensors."""
        x = x.to(self.device)
        h_prev = h_prev.to(self.device)
        c_prev = c_prev.to(self.device)
        n_prev = n_prev.to(self.device)

        batch_size = x.size(0)
        assert x.dim() == 2, f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
        assert h_prev.size() == (batch_size, self.hidden_size), f"h_prev shape mismatch: {h_prev.size()}"
        assert c_prev.size() == (batch_size, self.hidden_size), f"c_prev shape mismatch: {c_prev.size()}"
        assert n_prev.size() == (batch_size, self.hidden_size), f"n_prev shape mismatch: {n_prev.size()}"

        x = self.dropout(x)
        h_prev = self.dropout(h_prev)

        # query, key (scaled), and value projections
        q = self.Wq(x)
        k = self.Wk(x) / math.sqrt(self.hidden_size)
        v = self.Wv(x)

        if self.layer_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)
            v = self.ln_v(v)

        i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
        f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
        o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

        # outer product k v^T, reduced back to a vector-valued memory update
        k_expanded = k.unsqueeze(-1)
        v_expanded = v.unsqueeze(-2)

        kv_interaction = k_expanded @ v_expanded

        kv_sum = kv_interaction.sum(dim=1)

        # gated cell-state and normalizer-state updates
        c = f * c_prev + i * kv_sum
        n = f * n_prev + i * k

        epsilon = 1e-8
        normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
        h = o * self.tanh(c * normalized_n)

        return h, c, n

    def init_hidden(self, batch_size):
        """
        Initialize hidden, cell, and normalization states.
        """
        shape = (batch_size, self.hidden_size)
        return (
            torch.zeros(shape, device=self.device),
            torch.zeros(shape, device=self.device),
            torch.zeros(shape, device=self.device),
        )
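
For reference, a minimal usage sketch of the cell above (not part of the diff): it steps `mLSTMCell` over a random sequence one time step at a time. The tensor sizes are illustrative assumptions; the constructor arguments and the `(h, c, n)` state tuple match the code in this file.

```python
import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell

batch_size, seq_len, input_size, hidden_size = 4, 12, 8, 16  # illustrative sizes
cell = mLSTMCell(input_size, hidden_size, dropout=0.1, layer_norm=True)

x = torch.randn(seq_len, batch_size, input_size, device=cell.device)
h, c, n = cell.init_hidden(batch_size)  # each state is (batch_size, hidden_size)

for t in range(seq_len):
    # each step consumes a (batch_size, input_size) slice plus the previous states
    h, c, n = cell(x[t], h, c, n)

print(h.shape)  # torch.Size([4, 16])
```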
83 changes: 83 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell


class mLSTMLayer(nn.Module):
    def __init__(
        self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, device=None
    ):
        super(mLSTMLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.layer_norm = layer_norm
        self.residual_conn = residual_conn
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dropout = nn.Dropout(dropout).to(self.device)

        self.cells = nn.ModuleList(
            [
                mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device)
                for i in range(num_layers)
            ]
        )

    def init_hidden(self, batch_size):
        """
        Initialize hidden, cell, and normalization states for all layers.
        """
        hidden_states, cell_states, norm_states = zip(
            *[self.cells[i].init_hidden(batch_size) for i in range(self.num_layers)]
        )

        return (
            torch.stack(hidden_states).to(self.device),
            torch.stack(cell_states).to(self.device),
            torch.stack(norm_states).to(self.device),
        )

    def forward(self, x, h=None, c=None, n=None):
        """
        Forward pass for the mLSTM layer.
        """
        # (seq_len, batch_size, input_size) -> (batch_size, seq_len, input_size)
        x = x.to(self.device).transpose(0, 1)
        batch_size, seq_len, _ = x.size()

        if h is None or c is None or n is None:
            h, c, n = self.init_hidden(batch_size)

        outputs = []

        for t in range(seq_len):
            layer_input = x[:, t, :]
            next_hidden_states = []
            next_cell_states = []
            next_norm_states = []

            for i, cell in enumerate(self.cells):
                h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i])

                # residual connection between stacked cells (skipped for the first cell)
                if self.residual_conn and i > 0:
                    h_i = h_i + layer_input

                layer_input = h_i

                next_hidden_states.append(h_i)
                next_cell_states.append(c_i)
                next_norm_states.append(n_i)

            h = torch.stack(next_hidden_states).to(self.device)
            c = torch.stack(next_cell_states).to(self.device)
            n = torch.stack(next_norm_states).to(self.device)

            outputs.append(h[-1])

        # (batch_size, seq_len, hidden_size) -> (seq_len, batch_size, hidden_size)
        output = torch.stack(outputs, dim=1)
        output = output.transpose(0, 1)

        return output, (h, c, n)
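
A minimal usage sketch of the layer (not part of the diff), assuming the `(seq_len, batch_size, input_size)` input layout implied by the initial `transpose(0, 1)`; sizes are illustrative.

```python
import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer

layer = mLSTMLayer(input_size=8, hidden_size=16, num_layers=2, dropout=0.1)

x = torch.randn(12, 4, 8, device=layer.device)  # (seq_len, batch_size, input_size)
output, (h, c, n) = layer(x)                    # states default to zeros

print(output.shape)  # torch.Size([12, 4, 16]) -> (seq_len, batch_size, hidden_size)
print(h.shape)       # torch.Size([2, 4, 16])  -> (num_layers, batch_size, hidden_size)
```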
38 changes: 38 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/network.py
@@ -0,0 +1,38 @@
import torch.nn as nn
import torch
from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer


class mLSTMNetwork(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        output_size,
        dropout=0.0,
        use_layer_norm=True,
        use_residual=True,
        device=None,
    ):
        super(mLSTMNetwork, self).__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.mlstm_layer = mLSTMLayer(
            input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, self.device
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, h=None, c=None, n=None):
        """
        Forward pass through the mLSTM network.
        """
        output, (h, c, n) = self.mlstm_layer(x, h, c, n)

        output = self.fc(output[-1])

        return output, (h, c, n)

    def init_hidden(self, batch_size):
        """Initialize hidden, cell, and normalization states."""
        return self.mlstm_layer.init_hidden(batch_size)
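
And a sketch of the full network (not part of the diff): the stacked layers run over the sequence and the last time step's hidden state is projected to `output_size`; sizes are illustrative.

```python
import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork

net = mLSTMNetwork(input_size=8, hidden_size=16, num_layers=2, output_size=1)

x = torch.randn(12, 4, 8, device=net.device)  # (seq_len, batch_size, input_size)
y, (h, c, n) = net(x)                         # y comes from the final time step

print(y.shape)  # torch.Size([4, 1])
```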
Empty file.
94 changes: 94 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import math


class sLSTMCell(nn.Module):
    """Stabilized LSTM Cell"""

    def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None):
        super(sLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_layer_norm = use_layer_norm
        self.eps = 1e-6

        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device)
        self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device)

        if use_layer_norm:
            self.ln_cell = nn.LayerNorm(hidden_size).to(self.device)
            self.ln_hidden = nn.LayerNorm(hidden_size).to(self.device)
            self.ln_input = nn.LayerNorm(4 * hidden_size).to(self.device)
            self.ln_hidden_update = nn.LayerNorm(4 * hidden_size).to(self.device)

        self.dropout_layer = nn.Dropout(dropout).to(self.device)

        self.reset_parameters()

        self.grad_clip = 5.0

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.to(self.device)

    def reset_parameters(self):
        """Initialize all parameters uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]"""
        std = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def normalized_exp_gate(self, pre_gate):
        """Compute normalized exponential gate activation"""
        centered = pre_gate - torch.mean(pre_gate, dim=1, keepdim=True)
        exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0))
        normalizer = torch.sum(exp_val, dim=1, keepdim=True) + self.eps
        return exp_val / normalizer

    def forward(self, x, h_prev, c_prev):
        """Forward pass with stabilized exponential gating"""
        x = x.to(self.device)
        h_prev = h_prev.to(self.device)
        c_prev = c_prev.to(self.device)

        x = self.dropout_layer(x)
        h_prev = self.dropout_layer(h_prev)

        gates_x = self.input_weights(x)
        gates_h = self.hidden_weights(h_prev)

        if self.use_layer_norm:
            gates_x = self.ln_input(gates_x)
            gates_h = self.ln_hidden_update(gates_h)

        gates = gates_x + gates_h
        i, f, g, o = gates.chunk(4, dim=1)

        # exponential input/forget gates, renormalized so they sum to one
        i = self.normalized_exp_gate(i)
        f = self.normalized_exp_gate(f)
        gate_sum = i + f
        i = i / (gate_sum + self.eps)
        f = f / (gate_sum + self.eps)

        c_tilde = self.tanh(g)
        c = f * c_prev + i * c_tilde
        if self.use_layer_norm:
            c = self.ln_cell(c)

        o = self.sigmoid(o)
        c_out = self.tanh(c)
        if self.use_layer_norm:
            c_out = self.ln_hidden(c_out)
        h = o * c_out

        return h, c

    def init_hidden(self, batch_size):
        return (
            torch.zeros(batch_size, self.hidden_size, device=self.device),
            torch.zeros(batch_size, self.hidden_size, device=self.device),
        )
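
A minimal usage sketch for the sLSTM cell (not part of the diff): unlike the mLSTM cell it carries only an `(h, c)` state pair. Sizes are illustrative.

```python
import torch

from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell

batch_size, seq_len, input_size, hidden_size = 4, 12, 8, 16  # illustrative sizes
cell = sLSTMCell(input_size, hidden_size, dropout=0.1, use_layer_norm=True)

x = torch.randn(seq_len, batch_size, input_size, device=cell.device)
h, c = cell.init_hidden(batch_size)  # each state is (batch_size, hidden_size)

for t in range(seq_len):
    h, c = cell(x[t], h, c)

print(h.shape)  # torch.Size([4, 16])
```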