Code vs paper #10

Open
EelcoHoogendoorn opened this issue Feb 8, 2023 · 1 comment

EelcoHoogendoorn commented Feb 8, 2023

Hi all,

I'm working on a JAX implementation of a HiPPO-gated RNN. I wasn't quite sure how to interpret the diagram in the paper (linked below), and I can't quite reconcile it with the torch implementation (also linked below). The code seems more sensible to me than the paper: it makes sense to me for the SSM to see the raw, unfiltered data, whereas placing the nonlinear gated cell in front of it seems like it might obstruct the flow of gradients along the sequence.
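
To make the two orderings concrete, here is a minimal, framework-agnostic NumPy sketch. It is not the hippo-code or Flax implementation: `linear_ssm_step`, `toy_gated_step`, the weight matrices, and the sizes are all hypothetical stand-ins, and the SSM state is used directly as its output. It illustrates the point about the SSM seeing raw data: in the torch-style ordering the SSM state depends only on the inputs and never on the gated cell's weights, while in the paper-diagram ordering it does.

```python
import numpy as np

N, D, H = 4, 2, 3  # hypothetical sizes: SSM state, input, gated-cell state
rng = np.random.default_rng(0)
A = 0.9 * np.eye(N)              # hypothetical stable SSM transition
B_raw = rng.normal(size=(N, D))  # SSM input matrix for raw inputs (torch-style)
B_rnn = rng.normal(size=(N, H))  # SSM input matrix for cell outputs (paper-style)


def linear_ssm_step(c, u, B):
    """One step of a linear SSM: c' = A c + B u (the state doubles as output)."""
    return A @ c + B @ u


def toy_gated_step(h, x, Wz, Wh):
    """Toy single-gate cell: h' = (1 - z) * h + z * tanh(Wh [x, h])."""
    xh = np.concatenate([x, h])
    z = 1.0 / (1.0 + np.exp(-(Wz @ xh)))  # update gate in (0, 1)
    return (1.0 - z) * h + z * np.tanh(Wh @ xh)


def paper_step(c, h, x, Wz, Wh):
    # Gated cell sees the previous SSM state and the input; the SSM is driven by the cell.
    h = toy_gated_step(h, np.concatenate([c, x]), Wz, Wh)
    c = linear_ssm_step(c, h, B_rnn)
    return c, h


def torch_step(c, h, x, Wz, Wh):
    # SSM sees the raw input; the gated cell sees the SSM output and the input.
    c = linear_ssm_step(c, x, B_raw)
    h = toy_gated_step(h, np.concatenate([c, x]), Wz, Wh)
    return c, h


def run(step, xs, Wz, Wh):
    """Scan a step function over a (T, D) input sequence; return the final (c, h)."""
    c, h = np.zeros(N), np.zeros(H)
    for x in xs:
        c, h = step(c, h, x, Wz, Wh)
    return c, h
```

With two random weight draws of shape `(H, N + D + H)`, the final `torch_step` SSM states are identical, while the final `paper_step` SSM states differ, because in that ordering the nonlinear cell sits between the data and the linear memory.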

That said, the version from the paper's diagram works quite well in my use case, though the torch version seems to converge more quickly. I'm just curious whether I'm misreading something here, or what your latest thinking on these matters is.

EDIT: I've had a very good experience with the paper version in terms of avoiding exploding gradients. The code version converges faster and more smoothly initially, but I do observe the gated unit exploding on longer trajectories. There are lots of things to explore here, I suppose; deep/stacked SSMs with pointwise nonlinearities have not worked for me so far.

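import jax.numpy as jnp
import flax.linen as nn

# MGUCell is not defined in this snippet; it is assumed to be a Flax recurrent
# cell with the usual signature (carry, inputs) -> (new_carry, output).
# The ssm module is assumed to follow the same convention.


class GatedHippoPaper(nn.Module):
    """Linear state-space model coupled with a gated nonlinear cell.

    How I read the diagram in the paper: https://arxiv.org/pdf/2008.07669.pdf
    The gated cell sees the previous SSM state and the inputs; the SSM is
    then driven by the cell's output.
    """
    ssm: nn.Module
    rnn: nn.Module = MGUCell()

    @nn.compact
    def __call__(self, carry, inputs):
        rnn_carry, rnn_output = self.rnn(
            carry['rnn'], jnp.concatenate([carry['ssm'], inputs], axis=-1))
        ssm_carry, ssm_output = self.ssm(carry['ssm'], rnn_output)
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output  # should we use the ssm output here, lest it go unused?


class GatedHippoTorch(nn.Module):
    """Linear state-space model coupled with a gated nonlinear cell.

    How I read the torch code:
    https://github.com/HazyResearch/hippo-code/blob/201148256bd2b71cb07668dc00075420cfd4c567/model/model.py#L79
    The SSM sees the raw inputs; the gated cell sees the SSM output and the inputs.
    """
    ssm: nn.Module
    rnn: nn.Module = MGUCell()

    @nn.compact
    def __call__(self, carry, inputs):
        ssm_carry, ssm_output = self.ssm(carry['ssm'], inputs)
        rnn_carry, rnn_output = self.rnn(
            carry['rnn'], jnp.concatenate([ssm_output, inputs], axis=-1))
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output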
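
The `ssm` module above is left abstract. As one possibility, here is a minimal NumPy sketch of a HiPPO-LegS memory for a single scalar input channel, written with a forward-Euler step, c_{k+1} = (I - A/k) c_k + (1/k) B f_k, where A[n, k] = sqrt(2n+1) sqrt(2k+1) for n > k, A[n, n] = n + 1, and B[n] = sqrt(2n+1). The function names and state size are illustrative assumptions; this is not the hippo-code implementation, which also offers other discretizations. Because the first row of A is [1, 0, ..., 0], the first coefficient is exactly the running mean of the inputs.

```python
import numpy as np


def legs_matrices(N):
    """HiPPO-LegS transition A (N x N, lower triangular) and input vector B (N,)."""
    q = np.sqrt(2 * np.arange(N) + 1.0)
    A = np.tril(np.outer(q, q), k=-1) + np.diag(np.arange(N) + 1.0)
    B = q.copy()
    return A, B


def legs_step(c, f_k, k, A, B):
    """One forward-Euler LegS update at step k >= 1: c <- (I - A/k) c + (1/k) B f_k."""
    return c - (A @ c) / k + B * f_k / k


def legs_scan(fs, N):
    """Run the LegS memory over a 1-D input sequence; return all states, shape (T, N)."""
    A, B = legs_matrices(N)
    c = np.zeros(N)
    states = []
    for k, f_k in enumerate(fs, start=1):
        c = legs_step(c, f_k, k, A, B)
        states.append(c)
    return np.stack(states)
```

For example, `legs_scan(fs, 8)[:, 0]` matches `np.cumsum(fs) / np.arange(1, len(fs) + 1)` up to floating-point error.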

@albertfgu (Contributor)

The version in the torch code is the one used in all experiments in the original paper. Any differences from the paper figure are probably due to a different interpretation. To be honest, the RNN cell was somewhat arbitrary, so there are a lot of reasonable alternatives. The original HiPPO-RNN cell has long been abandoned in favor of the S4 approach.
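
For context on the S4 approach: one property it relies on is that a linear time-invariant SSM, x_k = Ab x_{k-1} + Bb u_k with output y_k = C x_k, can be evaluated either as a recurrence or as a causal convolution of the input with the kernel K = (C Bb, C Ab Bb, C Ab^2 Bb, ...). The minimal NumPy sketch below checks that equivalence with hypothetical random matrices; it omits everything that makes S4 practical, such as its HiPPO-based initialization, its structured parameterization of the state matrix, and its fast kernel computation.

```python
import numpy as np


def ssm_recurrent(Ab, Bb, C, u):
    """Run x_k = Ab x_{k-1} + Bb u_k, y_k = C x_k from a zero state over scalar inputs u."""
    x = np.zeros(Ab.shape[0])
    ys = []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append(C @ x)
    return np.array(ys)


def ssm_kernel(Ab, Bb, C, L):
    """Convolution kernel K[j] = C Ab^j Bb for j = 0..L-1."""
    K = []
    v = Bb.copy()
    for _ in range(L):
        K.append(C @ v)
        v = Ab @ v
    return np.array(K)


def ssm_convolutional(Ab, Bb, C, u):
    """Same outputs via the causal convolution y_k = sum_j K[j] u_{k-j}."""
    K = ssm_kernel(Ab, Bb, C, len(u))
    return np.convolve(u, K)[: len(u)]
```

With a stable `Ab` (for example a small random matrix), `ssm_recurrent` and `ssm_convolutional` agree up to floating-point error. S4 exploits this so that training can use the parallel convolutional form while autoregressive inference can use the recurrence.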
