Getting State Dictionary #922

Open
krypticmouse opened this issue Jan 2, 2025 · 6 comments
Labels
question User queries

Comments

@krypticmouse

krypticmouse commented Jan 2, 2025

Hi 👋

I've been trying to get the state dictionary from an equinox model and for that I've built this function:

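import equinox as eqx
import jax.tree_util as jtu


def get_state_dict(model):
    # Keep only the array leaves; every other leaf becomes None and is skipped.
    arrays, _ = eqx.partition(model, eqx.is_array)
    paths_and_values = jtu.tree_flatten_with_path(arrays)[0]

    state_dict = {}
    for path, value in paths_and_values:
        path_parts = []
        for p in path:
            if hasattr(p, 'name'):  # attribute access, e.g. a Module field
                path_parts.append(p.name)
            elif hasattr(p, 'idx'):  # position in a list or tuple
                path_parts.append(str(p.idx))
            elif hasattr(p, 'key'):  # dictionary key
                path_parts.append(str(p.key))
        path_str = '.'.join(path_parts)
        state_dict[path_str] = value

    return state_dict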

I wanted to know:
(a) Is there a better or more elegant way to extract a state_dict from an Equinox model?
(b) Is there native support for setting model params by passing a state_dict?

Thanks!

@patrick-kidger
Owner

patrick-kidger commented Jan 2, 2025

So Equinox has a slightly different approach to this: namely, the model itself is its own 'state dictionary'. Or rather, 'state pytree'. Basically, we take a slightly simpler approach in which we just don't cast things to dictionaries at all! And so for example to answer your part (b), we would normally load model params using something like eqx.tree_deserialise_leaves.

I'm guessing you probably care about this in the context of serialization. In this case you might like the API docs on serialization, and this full example on the topic.

Nonetheless, to actually answer your question (a): if you really do want to convert a pytree to a dictionary, then I think what you're doing here looks good to me! :)

patrick-kidger added the "question" (User queries) label Jan 2, 2025
@krypticmouse
Author

krypticmouse commented Jan 2, 2025

Hi Patrick!

Thank you so much for the answer!! I went through the serialization docs. The reason I want to avoid the pytree approach is that I'm building a lib where I want to save and load the model from safetensors stored in PyTorch format on the HF model hub. So using deserialisation directly doesn't work with the existing safetensors format. Maybe I'm doing it wrong?

For this use case, I'd want the saved Equinox model weights to be loadable in a PyTorch model as well, and vice versa. Hence the issue arose.

Thank you so much again!! It really helped!!

@patrick-kidger
Owner

Ah, if you're trying to explicitly convert to a PyTorch-like format then indeed something like what you're doing is appropriate.

The conversion back again is probably also doable by tree-mapping-with-path over your model, and looking up the path in your state dict.

@krypticmouse
Author

That's a great idea!! Yeah, I'll do it that way then!!

Thank you so much for the help and for maintaining this awesome lib 🙏

@lockwo
Contributor

lockwo commented Jan 2, 2025

I think there's also a project for PyTorch to equinox (#714) and https://github.com/Artur-Galstyan/statedict2pytree?tab=readme-ov-file that might be useful

@dlwh
Contributor

dlwh commented Jan 3, 2025

@krypticmouse fwiw I have equinox-trees-to-state-dict in Haliax (https://haliax.readthedocs.io/en/latest/state-dict/) and the logic for sharding models and uploading to HF (via safetensors) in Levanter
