Hi all,

I'm working on a JAX implementation of a HiPPO-gated-RNN. I wasn't quite sure how to interpret the diagram in the paper linked below, and I can't quite square it with the torch implementation, also linked below. The code seems more sensible to me than the paper: it makes sense to me to have the SSM see the raw data unfiltered, whereas placing the nonlinear gated action in front seems like it might jeopardize the unobstructed flow of gradients along the sequence.
The version from the diagram in the paper works quite well in my use case, though the torch version seems to converge more quickly. Just curious if I'm misreading something here, or what your latest thinking on these matters is.
EDIT: I've had very good experience with the paper version in terms of avoiding exploding gradients; the code version converges faster and more smoothly initially, but I do observe that the gated unit can explode on longer trajectories. Lots of things to explore here, I suppose; deep/stacked SSMs with pointwise nonlinearities have not worked for me so far.
import jax.numpy as jnp
import flax.linen as nn


class GatedHippo(nn.Module):
    """Linear state-space model coupled with a gated nonlinear module."""
    ssm: nn.Module
    rnn: nn.Module = MGUCell()  # minimal-gated-unit cell, defined elsewhere

    @nn.compact
    def __call__(self, carry, inputs):
        """How I read the diagram in the paper https://arxiv.org/pdf/2008.07669.pdf"""
        # Gated RNN acts first on the raw input; the SSM filters the RNN's output.
        rnn_carry, rnn_output = self.rnn(carry['rnn'], jnp.concatenate([carry['ssm'], inputs]))
        ssm_carry, ssm_output = self.ssm(carry['ssm'], rnn_output)
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output  # should we use ssm_output here, lest it go unused?

    # Alternative reading (shadows the definition above; shown side by side):
    @nn.compact
    def __call__(self, carry, inputs):  # noqa: F811
        """How I read the torch code https://github.com/HazyResearch/hippo-code/blob/201148256bd2b71cb07668dc00075420cfd4c567/model/model.py#L79"""
        # SSM sees the raw input unfiltered; the gated RNN consumes the SSM output.
        ssm_carry, ssm_output = self.ssm(carry['ssm'], inputs)
        rnn_carry, rnn_output = self.rnn(carry['rnn'], jnp.concatenate([ssm_output, inputs]))
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output
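For reference, this is how I unroll the cell over a sequence. A minimal sketch under my own assumptions: MySSMCell, the feature sizes, and the zero-initialized carry are hypothetical placeholders, not anything from the paper or the repo.

import jax

# Unroll GatedHippo along the time axis with nn.scan; weights are broadcast
# (shared) across steps, and the carry dict is threaded through the scan.
ScannedGatedHippo = nn.scan(
    GatedHippo,
    variable_broadcast='params',    # one set of weights shared across time steps
    split_rngs={'params': False},
    in_axes=0, out_axes=0,          # scan over the leading (time) axis of inputs
)

def init_carry(ssm_dim, rnn_dim):
    # Hypothetical carry shapes; the real ones depend on the concrete cells.
    return {'ssm': jnp.zeros(ssm_dim), 'rnn': jnp.zeros(rnn_dim)}

# model = ScannedGatedHippo(ssm=MySSMCell())              # MySSMCell: placeholder
# carry = init_carry(ssm_dim=64, rnn_dim=64)
# params = model.init(jax.random.PRNGKey(0), carry, xs)   # xs: (time, features)
# carry, outputs = model.apply(params, carry, xs)         # outputs: (time, rnn_dim)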
The version in the torch code is the one used in all experiments in the original paper. Any differences from the paper figure are probably due to a different interpretation. To be honest, the RNN cell was somewhat arbitrary, so there are a lot of reasonable alternatives. The original HiPPO-RNN cell has long been abandoned in favor of the S4 approach.
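For anyone landing here, a rough sketch of the pattern that reply points to, as I understand it (my own reading, not the official S4 code): keep the recurrence itself purely linear, and apply any nonlinearity pointwise between stacked layers rather than inside the loop. The diagonal real-valued parameterization and the residual connection here are simplifying assumptions.

class DiagonalLinearSSM(nn.Module):
    features: int

    @nn.compact
    def __call__(self, u):  # u: (time, features)
        # Diagonal, real-valued transition, log-parameterized so the
        # discretized eigenvalues stay in (0, 1) and the state decays.
        log_a = self.param('log_a', nn.initializers.zeros, (self.features,))
        b = self.param('b', nn.initializers.ones, (self.features,))
        c = self.param('c', nn.initializers.ones, (self.features,))
        abar = jnp.exp(-jnp.exp(log_a))  # in (0, 1): stable linear recurrence

        def step(x, u_t):
            x = abar * x + b * u_t       # linear state update: no gate inside
            return x, c * x

        _, y = jax.lax.scan(step, jnp.zeros(self.features), u)
        return nn.gelu(y + u)            # pointwise nonlinearity + residual between layers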