Getting State Dictionary #922
So Equinox has a slightly different approach to this: the model itself is its own 'state dictionary', or rather its own 'state pytree'. Basically, we take a simpler approach in which we just don't cast things to dictionaries at all! So, for example, to answer your part (b), we would normally load model params using something like … I'm guessing you probably care about this in the context of serialization, in which case you might like the API docs on serialization, and this full example on the topic. Nonetheless, to actually answer your question (a): if you really do want to convert a pytree to a dictionary, then what you're doing here looks good to me! :)
Hi Patrick! Thank you so much for the answer!! I went through the serialization docs. The reason I want to avoid the pytree-based format is that I'm building a lib that saves and loads models from safetensors files stored in PyTorch format on the HF model hub, and using deserialisation directly doesn't work with the existing safetensors format (maybe I'm doing it wrong?). For my use case, I'd like the saved Equinox model weights to be loadable into a PyTorch model as well, and vice versa. That's where the issue arose. Thank you so much again!! It really helped!!
Ah, if you're trying to explicitly convert to a PyTorch-like format, then indeed something like what you're doing is appropriate. The conversion back again is probably also doable by tree-mapping-with-path over your model and looking up each path in your state dict.
That's a great idea!! Yeah, I'll do it that way then!! Thank you so much for the help and for maintaining this awesome lib 🙏
I think there's also a project for PyTorch-to-Equinox conversion (#714), and https://github.com/Artur-Galstyan/statedict2pytree?tab=readme-ov-file, that might be useful.
@krypticmouse fwiw, I have equinox-trees-to-state-dict conversion in Haliax (https://haliax.readthedocs.io/en/latest/state-dict/), and the logic for sharding models and uploading them to HF (via safetensors) in Levanter.
Hi 👋
I've been trying to get the state dictionary from an equinox model and for that I've built this function:
I wanted to know:
(a) whether there is a better or more elegant way to extract a state_dict from an Equinox model.
(b) whether there is native support for setting model params by passing a state_dict.
Thanks!