
Port of PureJaxRL to equinox #926

Open
dorjeduck opened this issue Jan 6, 2025 · 5 comments

Comments

dorjeduck commented Jan 6, 2025

I’ve been tinkering with porting PureJaxRL (which is Flax/Linen-based) over to Equinox. I’m still new to JAX, so my first attempt is pretty raw—if you’re curious, it’s here: https://github.com/dorjeduck/eqxrl.

I really like the philosophy behind Equinox, so thanks for making such a great library! That said, I’ve noticed that my port seems slower than the Flax/Linen version. Is this common when comparing the two libraries, or is it most likely due to my beginner's way of implementing Equinox-based tools? I tried to add warm-up rounds and to keep compilation out of the timed benchmark periods, but I may well be missing things here.

I haven’t tried NNX yet; its pragmatism seems to lead it too far away from the JAX way of thinking for my taste.

Sorry for dropping this as an issue—I didn’t see a better place to ask, and it seems like others do this here too. Appreciate any advice!

Contributor

lockwo commented Jan 6, 2025

The most surprising thing to me at first glance is that a trivial 3-layer MLP is almost 100% slower (i.e., takes nearly twice as long) with Equinox than with Flax. Even setting the other comparisons aside, this one is small and direct (just a few lines of Equinox code versus a few lines of Flax), which makes the gap surprising. I would also recommend isolating at least that part in an MVC. I will take more of a look tomorrow.

I am also working on an Equinox-based RL library (https://github.com/lockwo/NARLL; I know it's private right now, but the link will work once I open it up), so this is definitely of interest to me.

Author

dorjeduck commented Jan 6, 2025

Great to hear about NARLL; looking forward to it.

I also find the benchmarks surprising, to the extent that I think there must be a mistake in my implementation, but I haven't found it yet.

I actually didn't intend to frame my question too much around my project, as I am a JAX beginner and don't want to ask people to correct my code. What I am most interested in is others' experience comparing the performance of Equinox and Flax/Linen, as this must be something people have looked into.

Owner

patrick-kidger commented Jan 6, 2025

They should usually both get the exact same performance!

The reason for this is that they usually end up expressing pretty much the same JAX-level computation graph, and at that point it's all in the hands of the jit compiler.

Or put another way, Equinox and Flax both just help you organise your code; they don't change what it compiles to.

On benchmarking - there are a couple of common mistakes that can be made here: things like measuring the cost of compilation (run your program once to compile it before timing it), or missing a jax.block_until_ready on the output.
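The two pitfalls above can be guarded against with a small timing helper that reports the first (compiling) call separately from steady-state calls. A minimal sketch in plain JAX; `step` and `benchmark` are hypothetical names standing in for whatever is being measured.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical workload standing in for the function being benchmarked.
    return jnp.tanh(x @ x.T).sum()


def benchmark(fn, *args, n_iters=50):
    """Return (first_call_seconds, mean_steady_state_seconds) for a jitted fn."""
    # First call: includes tracing and XLA compilation. Blocking makes the
    # timer wait for the result instead of stopping after async dispatch.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call = time.perf_counter() - t0

    # Later calls with the same shapes/dtypes reuse the cached executable.
    t0 = time.perf_counter()
    for _ in range(n_iters):
        jax.block_until_ready(fn(*args))
    steady_state = (time.perf_counter() - t0) / n_iters
    return first_call, steady_state


x = jnp.ones((512, 512))
first_call, steady_state = benchmark(step, x)
print(f"first call (incl. compile): {first_call:.4f}s")
print(f"steady state: {steady_state * 1e3:.3f} ms/call")
```

Passing arguments with a new shape or dtype triggers a fresh compilation, so warm-up calls should use the same shapes as the timed calls.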

Author

dorjeduck commented Jan 6, 2025

Thanks, Patrick, for the clarification. I do use jax.block_until_ready and precompile, but there are other aspects I will have to look into. Looking forward to working with your library.

Contributor

lockwo commented Jan 6, 2025

I isolated some of the MLP-specific code here: #928

Labels: None yet
Projects: None yet
Development: No branches or pull requests
Participants: 3