[RFC] Faster load time for large models #2350
Conversation
@gau-nernst I have thought about doing this, it is something that would be worthwhile as vision models are increasing in size, but it's a lil bit of work to make it all work nicely :) Beyond something like stochastic depth, there are also quite a few other common cases re arange, ones/zeros, etc. for non-persistent buffers, quite common for relative pos embeds, and I'm sure there are other cases. The backwards compat does need to be addressed, I still try to keep things working back to roughly pytorch 1.12 or so. safetensors should be using mmap. Willing to work through the issues with you and figure out a good solution, I had really hoped that pytorch would have a one liner by this point that'd deal with all the init / buffer issues without having to pull in extra deps or DIY
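To make the non-persistent buffer concern concrete, here is a minimal sketch (module and buffer names are made up for illustration, not timm code): since non-persistent buffers are excluded from the state_dict, loading a checkpoint into a meta-constructed model never materializes them.

```python
import torch
import torch.nn as nn

class RelPosIndex(nn.Module):
    """Hypothetical module with a non-persistent buffer built via arange."""
    def __init__(self, num_positions: int = 196):
        super().__init__()
        # persistent=False means this buffer is NOT part of the state_dict,
        # so a checkpoint load will never overwrite it.
        self.register_buffer("index", torch.arange(num_positions), persistent=False)

with torch.device("meta"):
    module = RelPosIndex()

# Parameters can be materialized via load_state_dict(..., assign=True), but this
# buffer is not in the checkpoint, so it stays on the meta device with no data.
print(module.index.device)  # -> meta
```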
Glad to hear! I will slowly work through the errors and ping you when ready for another pass. For now let's focus on meta device to skip weight init first.
Is using numpy instead ok for some of these cases? At least for stochastic depth, I don't see a problem. Will need to look into other cases. Also, is it ok to depend on numpy? Backward-compat should be do-able, it's just gonna make the code a bit ugly 😅
@gau-nernst I definitely don't want to start bringing in numpy as an alternative. There's definitely a way to do it properly with torch, I guess we'll see how messy it ends up. There's also a related issue where you do want to init the model (from scratch, not pretrained) and want to do so on the GPU to avoid the cpu -> gpu step. It's actually a problem for many of the same initialization steps that are a problem with meta devices; initializing the pos embed buffers, etc. can end up very different if you do it on CPU vs GPU due to compounding of float rounding differences, using bfloat16 instead of float32, etc... it's better to force that onto CPU even if the model weights are being initialized on GPU. Something to keep in mind.
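A small sketch of the pattern being described (the helper below is hypothetical, not timm code): build deterministic buffers on CPU in float32, then cast/move, so the values don't drift when the rest of the model is created on GPU or in bfloat16.

```python
import torch

def sincos_pos_embed(seq_len: int, dim: int) -> torch.Tensor:
    # Hypothetical stand-in for a deterministic pos-embed builder.
    pos = torch.arange(seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(pos, inv_freq)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Compute on CPU in float32 first, then move/cast; building directly on GPU in
# bfloat16 would give slightly different values due to rounding.
pos_embed = sincos_pos_embed(196, 768).to(device=device, dtype=torch.bfloat16)
```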
Manually added a lot of One side note. Because of the explicit
Hi @rwightman, do you have some time to take another look at this? (and run the CI). Thank you! |
@rwightman Sorry for pinging you again. Do you have the time to review this? Thank you. I also tested with
While playing around with Flux-Redux, which uses siglip-so400m, I noticed that loading time is pretty slow. I think we can bring some of the loading-time optimizations for LLMs to timm too. Hence, I'm opening this RFC to ask for feedback on whether you think it is a useful improvement to timm.
As a proof-of-concept, I added meta device to skip weight initialization.
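A rough sketch of the idea (not the actual PR code; the model and checkpoint path are placeholders): construct the module on the meta device so parameter init does no real work, then materialize it from the checkpoint with `assign=True`.

```python
import torch
import torch.nn as nn

# Stand-in model; in the PR this would be a timm model such as the siglip ViT.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1152, 4304), nn.GELU(), nn.Linear(4304, 1152))

state_dict = torch.load("checkpoint.pth", map_location="cpu")
# assign=True is required here: meta tensors have no storage to copy into, so
# the checkpoint tensors are assigned as the new parameters instead.
model.load_state_dict(state_dict, assign=True)
```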
Benchmark script (lmk if you want to use a different way to benchmark weight loading time)
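(The original script isn't reproduced above; a minimal sketch of one way to time pretrained weight loading, with the model name assumed, could look like the following.)

```python
import time
import timm

start = time.perf_counter()
model = timm.create_model("vit_so400m_patch14_siglip_384", pretrained=True)
print(f"create_model with pretrained weights: {time.perf_counter() - start:.2f}s")
```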
Some considerations about meta device
Do let me know your thoughts and whether I should proceed with this PR.
Apart from using meta device, some other optimizations we can look into (possibly in future PRs), sketched briefly after this list:
- `torch.load(mmap=True)` (I believe safetensors already uses memory-map by default?)
- `model.load_state_dict(assign=True)`
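A brief sketch of how the two could fit together (placeholder paths and a stand-in model; both `mmap=True` and `assign=True` need a fairly recent PyTorch, 2.1+ I believe, so older versions would simply omit them):

```python
import torch
import torch.nn as nn
from safetensors.torch import load_file

model = nn.Linear(1024, 1024)  # stand-in for a timm model

# mmap=True maps the checkpoint file lazily instead of reading everything up
# front; safetensors' load_file memory-maps by default.
state_dict = torch.load("checkpoint.pth", map_location="cpu", mmap=True)
# state_dict = load_file("model.safetensors")  # safetensors equivalent

# assign=True reuses the checkpoint tensors directly instead of copying them
# into freshly-initialized parameters.
model.load_state_dict(state_dict, assign=True)
```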