Need Forwarding with state. #22
https://github.com/sieusaoml/xLSTM-custom-block
Yes, I want to do something similar. But in the code, is it only sLSTM that can be initialized with the previous hidden state? Can't mLSTM be initialized with the previous state?
The step() method and the forward() method of mLSTMLayer use different types of conv1d forward passes, so I think if you want to use a hidden state, you need to call step() token by token instead of forwarding all of the tokens at the same time.
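A minimal sketch of what that token-by-token loop could look like, assuming the layer or block stack exposes a `step(x, state)` method that returns `(output, new_state)` as described above (the helper name `forward_token_by_token` is hypothetical, and the exact state type depends on the implementation):

```python
import torch

def forward_token_by_token(block_stack, x: torch.Tensor):
    # x: (batch, seq_len, embedding_dim)
    outputs = []
    state = None  # no previous hidden state at the start
    for t in range(x.size(1)):
        x_t = x[:, t:t + 1, :]                 # keep the sequence dim: (batch, 1, embedding_dim)
        y_t, state = block_stack.step(x_t, state)  # assumes step(x, state) -> (y, new_state)
        outputs.append(y_t)
    return torch.cat(outputs, dim=1), state
```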
Yes, I am not asking to forward all of the tokens at the same time. In fact, my original model was an LSTM, which processes each token in a loop. I just want to replace this LSTM with xLSTM. But it seems that step() is used during inference, right? May I ask if it can backpropagate normally during training? Will the in-place operations lead to backpropagation errors?
mLSTMLayer can be used with the previous hidden state, but backpropagating the gradient in my test with context_length=1 raises an error.
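For reference, a hedged sketch of how one might check whether gradients flow through the step path. It reuses the hypothetical `forward_token_by_token` loop sketched above, and the config values are illustrative only, not taken from this issue:

```python
import torch
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig, mLSTMBlockConfig, mLSTMLayerConfig

cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4)
    ),
    context_length=1,   # the setting reported to trigger the gradient error
    num_blocks=2,
    embedding_dim=128,
)
block_stack = xLSTMBlockStack(cfg)

x = torch.randn(2, 4, 128, requires_grad=True)   # (batch, seq_len, embedding_dim)
y, _ = forward_token_by_token(block_stack, x)    # step loop from the sketch above
y.sum().backward()                               # fails here if backprop through step() breaks
print(x.grad is not None)
```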
When training, it runs without state:
```python
def forward(self, idx: torch.Tensor) -> torch.Tensor:
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    x = self.xlstm_block_stack(x)
    logits = self.lm_head(x)
    return logits
```
Can you give a “forward with state” version?
```python
def forward(self, idx: torch.Tensor, state) -> torch.Tensor:
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    x = self.xlstm_block_stack(x, state)
    logits = self.lm_head(x)
    return logits
```
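One possible shape for such a method, as a sketch only: it assumes the block stack's `step(x_t, state)` carries the recurrent state (per the discussion above, the regular `forward()` path may not accept a state argument), and it returns the state so the caller can pass it back in on the next call:

```python
def forward_with_state(self, idx: torch.Tensor, state=None):
    # Hypothetical "forward with state"; not the library's official API.
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    outputs = []
    for t in range(x.size(1)):                  # process one token at a time
        y_t, state = self.xlstm_block_stack.step(x[:, t:t + 1, :], state)
        outputs.append(y_t)
    x = torch.cat(outputs, dim=1)
    logits = self.lm_head(x)
    return logits, state                        # return state so the caller can carry it forward
```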