Port of PureJaxRL to equinox #926
The most surprising thing about this to me at first glance is that a trivial 3-layer MLP is almost 100% slower with Equinox than with Flax. Even without any other comparisons, this comparison is quite small and direct (just a few lines of Equinox code and a few lines of Flax) and surprising (I would also recommend isolating at least that part in an MVC). I will take more of a look tomorrow. I am also working on an Equinox-based RL library (https://github.com/lockwo/NARLL; I know it's private right now, but when I open it up the link will work), so this is definitely of interest to me.
Great to hear about NARLL, looking forward to it. I also find the benchmarks surprising, to the extent that I think there must be a mistake in the way I implemented it, but I haven't found it yet. I didn't intend to frame my question too much around my project, as I am a JAX beginner and don't want to ask people to correct my code. What I am most interested in is others' experience comparing the performance of Equinox and Flax/Linen, as this must be something people have looked into...
They should usually both get the exact same performance! The reason for this is that they usually end up expressing pretty much the same JAX-level computation graph, and at that point it's all in the hands of the JIT compiler. Or put another way, Equinox and Flax both just help you to organise your code, not what it compiles to. On benchmarking: there are a couple of common mistakes that can be made here, such as measuring the cost of compilation (run your program once to compile it before timing it), or missing a
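The benchmarking advice above can be sketched as a small timing helper. This is a hedged illustration in plain JAX, not code from Equinox or Flax; the name `benchmark` and its `n_warmup`/`n_iters` parameters are hypothetical. It relies on two facts about JAX: the first call of a jitted function includes tracing and compilation, and dispatch is asynchronous, so results are forced with `jax.block_until_ready` before the clock is read.

```python
import time
from typing import Any, Callable

import jax
import jax.numpy as jnp


def benchmark(fn: Callable[..., Any], *args: Any, n_warmup: int = 3, n_iters: int = 100) -> float:
    """Mean wall-clock seconds per call of fn(*args), with compilation excluded."""
    # Warm-up: the first call traces and compiles a jitted fn, so keep it out of the timing.
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        # JAX dispatches work asynchronously; block so the timer covers the computation itself.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    @jax.jit
    def dense_tanh(x, w):
        return jnp.tanh(x @ w)

    x = jnp.ones((1024, 64))
    w = jnp.ones((64, 64))
    print(f"{benchmark(dense_tanh, x, w) * 1e6:.1f} us per call")
```

Passing the Equinox and Flax forward passes through the same helper with identical input shapes and dtypes keeps the comparison about the compiled computation; changing shapes or dtypes between the warm-up and the timed loop would trigger a recompilation inside the measurement.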
Thanks, Patrick, for this clarification. I use
I isolated some of the MLP-specific code here: #928
I’ve been tinkering with porting PureJaxRL (which is Flax/Linen-based) over to Equinox. I’m still new to JAX, so my first attempt is pretty raw; if you’re curious, it’s here: https://github.com/dorjeduck/eqxrl.
I really like the philosophy behind Equinox, so thanks for making such a great library! That said, I’ve noticed that my port seems slower than the Flax/Linen version. Is this common when comparing the two libraries, or is it most likely due to my beginner implementation of the Equinox-based tools? I tried to add warm-up rounds and to keep compilation out of the timed periods, but I may well be missing something.
I haven’t tried NNX yet; its pragmatism seems to take it too far from the JAX way of thinking for my taste.
Sorry for dropping this as an issue; I didn’t see a better place to ask, and it seems like others do this here too. Appreciate any advice!