
[RFC] Faster load time for large models #2350

Draft · wants to merge 9 commits into main
Conversation

gau-nernst (Contributor)
While playing around with Flux-Redux, which uses siglip-so400m, I noticed that the loading time is pretty slow. I think we can bring some of the loading-time optimizations used for LLMs to timm too, so I'm opening this RFC to ask for feedback on whether you think this would be a useful improvement to timm.

As a proof-of-concept, I added meta device to skip weight initialization.
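To make the idea concrete, here is a minimal sketch of the technique (not the actual diff in this PR, and it assumes a recent PyTorch where `torch.device("meta")` works as a context manager and `load_state_dict` accepts `assign`): construct the module under the meta device so parameters get shapes but no storage and the usual weight-init code becomes a no-op, then materialize it directly from the checkpoint. The `build_without_init` helper is purely illustrative.

```python
import torch
import torch.nn as nn

def build_without_init(model_fn, state_dict):
    # Build under the meta device: parameters/buffers get shapes and dtypes
    # but no storage, so weight initialization is effectively skipped.
    with torch.device("meta"):
        model = model_fn()
    # Materialize the model from the checkpoint. assign=True swaps the
    # checkpoint tensors in directly instead of copying into meta parameters.
    model.load_state_dict(state_dict, assign=True)
    return model

# Usage sketch: any callable that builds an nn.Module works here.
sd = nn.Linear(1024, 1024).state_dict()
model = build_without_init(lambda: nn.Linear(1024, 1024), sd)
```

Note that non-persistent buffers are not part of the state dict, so they still need to be created for real at init time; that is exactly the kind of bypass discussed below.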

Benchmark script (lmk if you want to use a different way to benchmark weight loading time)

import time

import timm
import torch

# Warm-up run: downloads/caches the weights so the timed loop below measures
# only model creation + weight loading.
model = timm.create_model("vit_so400m_patch14_siglip_378.webli", pretrained=True)

N = 4
time0 = time.perf_counter()
for _ in range(N):
    model = timm.create_model("vit_so400m_patch14_siglip_378.webli", pretrained=True)
print((time.perf_counter() - time0) / N)

model(torch.randn(1, 3, 378, 378))  # make sure model forward works
| Name | Time (s) |
|---|---|
| Baseline | 4.20 |
| w/ meta device | 0.72 |

Some considerations about the meta device:

  • I think the meta device only exists for PyTorch >= 2.3, so we need to guard its usage against the PyTorch version.
  • In some cases, we need to bypass the meta device during model init, such as what I did in the ViT model file for calculating the stochastic depth schedule (see the sketch after this list). I believe there are such cases in other models too.
  • I'm not aware of any other caveats for meta device.
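As a hedged sketch of what such a bypass could look like (the 2.3 threshold and the linspace-based schedule are assumptions taken from the comments here, not the exact PR code):

```python
import torch

def has_meta_device(min_version=(2, 3)):
    # Version guard; the (2, 3) minimum mirrors the guess above and should be
    # verified before relying on it.
    parts = torch.__version__.split(".")
    try:
        return (int(parts[0]), int(parts[1])) >= min_version
    except ValueError:
        return False

# Stochastic-depth schedule computed with an explicit device="cpu", so this
# small helper tensor stays real even while the rest of the model is being
# constructed under the meta device. Values are illustrative.
depth, drop_path_rate = 24, 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device="cpu")]
```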

Do let me know your thoughts and whether I should proceed with this PR.

Apart from using the meta device, some other optimizations we could look into (possibly in future PRs):

  • Use torch.load(mmap=True) (I believe safetensors already uses memory-mapping by default?)
  • Use model.load_state_dict(assign=True) to avoid an extra copy (a rough sketch combining both follows this list)
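A rough sketch of how these two options fit together, assuming an already constructed timm model and a plain PyTorch checkpoint file (the file name is a placeholder):

```python
import timm
import torch

model = timm.create_model("vit_so400m_patch14_siglip_378", pretrained=False)

# mmap=True memory-maps the checkpoint instead of reading it all into RAM;
# weights_only=True restricts unpickling to plain tensors/containers.
state_dict = torch.load("checkpoint.pth", map_location="cpu", mmap=True, weights_only=True)

# assign=True re-uses the (memory-mapped) checkpoint tensors directly instead
# of copying them element-wise into the freshly initialized parameters.
model.load_state_dict(state_dict, assign=True)
```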


rwightman (Collaborator)

@gau-nernst I have thought about doing this; it's something that would be worthwhile as vision models keep increasing in size, but it's a lil bit of work to make it all work nicely :)

Beyond something like stochastic depth, there are also quite a few other common cases, e.g. arange, ones/zeros, etc. for non-persistent buffers (quite common for relative pos embeds), and I'm sure there are other cases.
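To make the buffer case concrete, a toy example of the pattern being described (illustrative code, not actual timm internals): a non-persistent relative-position index built with arange, which needs an explicit device (or some other bypass) when the surrounding module is constructed on the meta device, because it is excluded from the checkpoint and cannot be restored by load_state_dict.

```python
import torch
import torch.nn as nn

class ToyRelPosIndex(nn.Module):
    def __init__(self, window_size: int):
        super().__init__()
        # Non-persistent buffer: not saved in the state dict, so it must be
        # created "for real" at init time; the explicit device="cpu" keeps it
        # off the meta device while the rest of the model skips init.
        coords = torch.arange(window_size, device="cpu")
        rel_index = coords[None, :] - coords[:, None] + (window_size - 1)
        self.register_buffer("rel_index", rel_index, persistent=False)
```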

The backwards compat does need to be addressed; I still try to keep things working back to roughly PyTorch 1.12 or so.

safetensors should be using mmap.

I'm willing to work through the issues with you and figure out a good solution. I had really hoped that PyTorch would have a one-liner by this point that'd deal with all the init/buffer issues without having to pull in extra deps or DIY.

gau-nernst (Contributor, Author)

Glad to hear! I will slowly work through the errors and ping you when ready for another pass. For now let's focus on meta device to skip weight init first.

> Beyond something like stochastic depth, there are also quite a few other common cases, e.g. arange, ones/zeros, etc. for non-persistent buffers (quite common for relative pos embeds), and I'm sure there are other cases.

Would using numpy instead be ok for some of these cases? At least for stochastic depth, I don't see a problem; I will need to look into the other cases. Also, is it ok to depend on numpy? I saw that pyproject.toml doesn't specify numpy, but requirements.txt has numpy, and there are some numpy imports in the repo.
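For what it's worth, the numpy variant hinted at here would look roughly like this (sketch only; as the next comment makes clear, it wasn't the direction taken):

```python
import numpy as np

# Per-block drop-path rates computed without torch at all, so the meta device
# never sees this calculation. Values are illustrative.
depth, drop_path_rate = 24, 0.1
dpr = np.linspace(0, drop_path_rate, depth).tolist()
```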

Backward-compat should be do-able, it's just gonna make the code a bit ugly 😅

rwightman (Collaborator)

rwightman commented Dec 1, 2024

@gau-nernst I definitely don't want to start bringing in numpy as an alternative. There's a way to do it properly with torch; I guess we'll see how messy it ends up.

There's also a related issue where you do want to init the model (from scratch, not pretrained) and want to do so on the GPU to avoid the CPU -> GPU step. It's actually a problem for many of the same initialization steps that are a problem with meta devices: initializing the pos embed buffers, etc. can end up very different if you do it on CPU vs GPU, due to compounding of float rounding differences, using bfloat16 instead of float32, and so on. It's better to force that onto CPU even if the model weights are being created on GPU. Something to keep in mind.
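As a small, self-contained illustration of the rounding concern (the loop is just a stand-in for any init that compounds many small float ops, e.g. building pos embed tables): doing the math in float32 and casting once at the end generally does not match doing it directly in bfloat16.

```python
import torch

def compounding_init(dtype):
    # Repeated scale-and-add, standing in for an init whose rounding errors
    # accumulate step by step.
    x = torch.zeros(8, dtype=dtype)
    step = torch.full((8,), 0.01, dtype=torch.float32).to(dtype)
    for _ in range(1000):
        x = x * 1.0001 + step
    return x

via_f32 = compounding_init(torch.float32).to(torch.bfloat16)  # high-precision math, cast once
direct_bf16 = compounding_init(torch.bfloat16)                # low-precision math throughout
print((via_f32.float() - direct_bf16.float()).abs().max())
```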

gau-nernst (Contributor, Author)

gau-nernst commented Dec 2, 2024

Manually added a lot of device="cpu" for stochastic depth and non-persistent buffers. Which test/command should I run to check that everything works correctly in my local env (it must use pretrained=True to trigger the meta device behavior)? I tried pytest -vv --forked --durations=0 -m base "tests/test_models.py::test_model_inference" but it only ran 13 tests, so it definitely did not cover everything.

One side note: because of the explicit device="cpu", non-persistent buffers will now always default to CPU. Without this PR, they default to the default device (which can be CUDA or others). To preserve the previous behavior, I think we would have to check the default device and only substitute "cpu" when the default is "meta", which again would take some effort but is doable. Though I think having non-persistent buffers default to "cpu" is not a big deal - users can (and should) call .cuda() on the model later, before training/inference (or we could do it in build_model_with_cfg()?).
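To make the default-device point concrete, a tiny sketch of the behavior difference (assuming a CUDA machine, and using torch.set_default_device just as one way a non-CPU default can arise):

```python
import torch

torch.set_default_device("cuda")

# Without an explicit device, a fresh tensor (e.g. a buffer created at model
# init) follows the global default; with device="cpu" it always lands on CPU
# and only moves when the user later calls model.cuda() / model.to(device).
follows_default = torch.arange(10)
pinned_to_cpu = torch.arange(10, device="cpu")
print(follows_default.device, pinned_to_cpu.device)  # cuda:0 cpu

torch.set_default_device("cpu")  # restore the usual default
```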

gau-nernst (Contributor, Author)

Hi @rwightman, do you have some time to take another look at this? (and run the CI). Thank you!

gau-nernst (Contributor, Author)

@rwightman Sorry for pinging you again. Do you have the time to review this? Thank you.

I also tested with timm/eva_giant_patch14_336.clip_ft_in1k (1B params), and the loading time drops from 9.8 s to 1.1 s.
