[RFC] Faster load time for large models #2350
Conversation
@gau-nernst I have thought about doing this, it is something that would be worthwhile as vision models are increasing in size, but it's a lil bit of work to make it all work nicely :) Beyond something like stochastic depth, there are also quite a few other common cases re arange, ones/zeroes etc. for non-persistent buffers, quite common for relative pos embeds, and I'm sure there are other cases. The backwards compat does need to be addressed, I still try to keep things working back to roughly pytorch 1.12 or so. safetensors should be using mmap. Willing to work through the issues with you and figure out a good solution, I had really hoped that pytorch would have a one-liner by this point that'd deal with all the init / buffer issues without having to pull in extra deps or DIY.
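To make these failure modes concrete, here is a minimal sketch (toy module, not timm code) of what goes wrong when a model is built under a meta-device context:

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Toy module showing two patterns that are awkward under a meta-device context."""

    def __init__(self, dim: int = 8, drop_path_rate: float = 0.1, depth: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Non-persistent buffer (think relative position index): excluded from the
        # state_dict, so it can never be filled back in from pretrained weights
        # and would remain a meta tensor after loading.
        self.register_buffer("rel_index", torch.arange(dim), persistent=False)
        # Stochastic-depth style schedule: .item() on a meta tensor raises.
        self.drop_rates = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

with torch.device("meta"):
    block = ToyBlock()  # fails at the .item() call; rel_index would also stay meta
```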
Glad to hear! I will slowly work through the errors and ping you when ready for another pass. For now let's focus on meta device to skip weight init first.
Is using numpy instead ok for some of these cases? At least for stochastic depth, I don't see a problem. Will need to look into other cases. Also, is it ok to depend on numpy? I saw Backward-compat should be do-able, it's just gonna make the code a bit ugly 😅
@gau-nernst I definitely don't want to start bringing in numpy as an alternative. There's definitely a way to do it properly with torch, I guess we'll see how messy it ends up. There's also a related issue where you do want to init the model (from scratch, not pretrained) and want to do so on the GPU to avoid the cpu -> gpu step. It's actually a problem for many of the same initialization steps that are a problem with meta devices; initializing the pos embed buffers, etc. can end up very different if you do it on CPU vs GPU due to compounding of float rounding differences, using bfloat16 instead of float32, etc... it's better to force that on CPU even if the model weights are being initialized on the GPU. Something to keep in mind.
Manually added a lot of One side note. Because of the explicit
Hi @rwightman, do you have some time to take another look at this (and run the CI)? Thank you!
@rwightman Sorry for pinging you again. Do you have the time to review this? Thank you. I also tested with
@gau-nernst sorry for not checking in here for a bit (I should have at least kicked off the CI), was trying to tick off a few things before end of year, I'll take a closer look at the current state. Though just a warning, still juggling quite a few things so I might get side-tracked again... FYI CI can be run locally, though it's really slow to run the full thing, but you can hack it to filter out parts initially, etc.
Thanks for the update. I ran some of the CI tests locally but felt like they did not cover all the changes/models. You would know best whether existing tests cover the new changes (i.e. specifically
@gau-nernst there is a test that loads all pretrained checkpoints when run outside of the github CI, but it cannot verify correctness / changes in output, only that the weights load. I created some very small
So I ran through some models, and many of them work: resnet, vit, etc. But as soon as you get into models with non-persistent buffers, there are lots of issues. swin and maxvit crash, eva02 models don't crash but are producing garbage outputs.
I was thinking about this recently, the faster load time is important, but also related are things like better init compatibility with advanced param sharding, etc. In many cases you want to be able to call the init explicitly after meta-device creation / sharding, etc. Wondering if that might need to be part of the solution anyways? Have a mechanism to call init on any model after creation (instead of just having it via __init__ right now)... AND be able to specify if trainable params and/or buffers get init, etc...
I discovered that when we init non-persistent buffers on the CPU device, then later call
I don't know about other frameworks, but at least for PyTorch's FSDP1/2, they rely on
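For reference, the FSDP1-era recipe being alluded to looks roughly like this (a sketch, not timm code; it only works when every module that owns params/buffers defines reset_parameters()):

```python
import torch
import torch.nn as nn

# Build on the meta device, materialize real (uninitialized) storage, then let
# each submodule re-run its own init via reset_parameters().
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

model.to_empty(device="cpu")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```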
@gau-nernst the FSDP1 way did, but FSDP2 doesn't rely on reset_parameters... I was thinking that might be a model to follow. If I do add FSDP2 support into timm it will be 2.
Depending on how messy addressing the non-persistent buffers turns out, if we had a reliable
And for things like FSDP2,
etc,
If you do find a clean approach to the current buffers issue, you can proceed as is of course. Ultimately I feel this will need to move in a direction where there is a consistent init_weights interface for the models though, to address all the use cases of using the meta device and having an option of decoupling weight init from
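One possible shape for such an interface, purely as a hypothetical sketch (init_weights / init_buffers below are made-up names, not existing timm API):

```python
import torch.nn as nn

def init_weights(model: nn.Module, include_buffers: bool = True) -> None:
    """Hypothetical post-creation init hook, decoupled from __init__.

    Prefers a per-module init_weights() if a module defines one, otherwise
    falls back to reset_parameters(); optionally re-inits buffers via a
    hypothetical init_buffers() hook.
    """
    for module in model.modules():
        if hasattr(module, "init_weights"):
            module.init_weights()
        elif hasattr(module, "reset_parameters"):
            module.reset_parameters()
        if include_buffers and hasattr(module, "init_buffers"):
            module.init_buffers()
```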
Pushed a quick fix:

```python
# .to_empty() will also move cpu params/buffers to uninitialized storage.
# this is problematic for non-persistent buffers, since they don't get loaded
# from pretrained weights later (not part of state_dict). hence, we have
# to save them before calling .to_empty() and fill them back after.
buffers = {k: v for k, v in model.named_buffers() if not v.is_meta}
model.to_empty(device="cpu")
for k, v in model.named_buffers():
    if k in buffers:
        v.data = buffers[k]

# alternative, rely on internal method ._apply()
# model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)

# for reference, .to_empty() is implemented with
# self._apply(lambda t: torch.empty_like(t, device=device))
```

Tested with Swin and it seems to work correctly. Will manually check other models... The alternative approach using

Extra thoughts. This meta-device problem concerns params+buffers that don't exist in the pretrained weights. Non-persistent buffers are one case. I believe some checkpoints also don't have a classifier head, such as DINO, which won't be initialized correctly if users set

I think the solution is either
|
@gau-nernst hmm, yeah it's not just models without the classifier, any use case where num_classes doesn't match and a new, different classifier gets created (old one removed from state_dict) would be a problem, no? The old classifier is removed from the weights, and it relies on the init to provide a replacement. Also complicating an 'easy' workaround, the head is initialized differently by different model inits, so simply running model.get_classifier().reset_parameters() would result in different behaviour than right now... hmm. (See pytorch-image-models/timm/models/_builder.py, lines 246 to 251 at 131518c.)
Changing in_chans should be fine because it modifies the state_dict input conv weight.
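To make the num_classes concern concrete, a hedged example of the user-facing case under discussion (this illustrates the worry, not a claim about current behaviour):

```python
import timm

# The pretrained checkpoint carries a 1000-class head. Asking for 10 classes
# means timm drops the old head from the state_dict and creates a fresh one,
# which only gets sensible values from weight init. If the model were built on
# the meta device with init skipped, that new head would be left uninitialized.
model = timm.create_model("resnet50", pretrained=True, num_classes=10)
```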
FYI, I'm running some bulk evals to compare against existing numbers.
Added a check to use meta-device init only when (see pytorch-image-models/timm/models/_builder.py, lines 433 to 434 at 131518c)
Here, determining
@gau-nernst hmm, so the idea with the num_classes ultimately coming from the model was that it is the ultimate source of truth wrt what config/args will result in the final model classifier shape that will need to match the pretrained weights. That
Right now what's passed in via kwargs to override the default is, I think, transparent for all models, it ends up matching model.num_classes. I'd envisioned scenarios where that might not be the case. Changing the num_classes is such a common use case here, it would definitely diminish the utility to not support that case.
Something else came up in the larger test: models using BlurPool are failing. Also, thinking about this a bit more, this solution is going to cause issues with some use cases... namely anyone using torch.set_default_device or using device context managers themselves, it will break. The use of 'cpu' for the tensor init fns, while allowing the meta device context to be used, will break anyone trying to use a default device or context manager for another, different device. This applies at both the model and layer level, some people use timm just for specific layers that aren't elsewhere. And several layers now have device='cpu'...
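A minimal illustration of that concern (toy layer, not actual timm code; assumes a CUDA machine for the example):

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Toy layer that hard-codes device='cpu' for a non-persistent buffer so it
    stays materialized under a meta-device context."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim, dim))
        self.register_buffer("idx", torch.arange(dim, device="cpu"), persistent=False)

torch.set_default_device("cuda")  # a user's global default device
layer = ToyLayer()
print(layer.weight.device)  # cuda:0
print(layer.idx.device)     # cpu  <- silently ignores the user's default device
```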
Will look into BlurPool models, but I will be travelling soon, so there will be few updates from me until next week.
I also mentioned this previously, but didn't think it would be a big problem. We can always do
It seems like you are leaning towards having something like
Point 3 requires factoring out the head init logic for all models, so there is no workaround. Points 1 and 4 are covered in this PR. This PR can provide a workaround for point 2, though it doesn't look very "nice" (and it also breaks at the "layer level" mentioned above). Would you prefer putting this PR on hold while working on refactoring
@gau-nernst yeah, it's at the layer level I was getting more concerned... there is no way to reliably track down what the uses of various layers are in the wild, it could just break a lot of things. This is mostly working right now so that's promising, but I feel a custom init fn with consistent naming/interface could help avoid the layer and model breakage concerns by allowing patching / skipping those functions, etc ... maybe adding some state tracking to handle head re-init. I will think about this a bit more...
While playing around with Flux-Redux, which uses siglip-so400m, I noticed that loading time is pretty slow. I think we can bring some of the loading-time optimizations for LLMs to timm too. Hence, I'm opening this RFC to ask for feedback on whether you think it is a useful improvement to timm.
As a proof-of-concept, I added meta device to skip weight initialization.
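Roughly, the flow looks like this, a simplified sketch of the approach rather than the exact code in this PR ("checkpoint.pth" is a placeholder path):

```python
import torch
import timm

# Build on the meta device: no real memory is allocated and no weight-init
# kernels run.
with torch.device("meta"):
    model = timm.create_model("vit_base_patch16_224", pretrained=False)

# Materialize empty storage. Note: this also leaves non-persistent buffers
# uninitialized, which is the main complication discussed in this thread.
model.to_empty(device="cpu")

# Load pretrained weights; mmap=True needs PyTorch >= 2.1, assign=True reuses
# the checkpoint tensors directly instead of copying into existing storage.
state_dict = torch.load("checkpoint.pth", map_location="cpu", mmap=True)
model.load_state_dict(state_dict, assign=True)
```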
Benchmark script (lmk if you want to use a different way to benchmark weight loading time)
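(The actual script was attached to the PR; below is only a rough reconstruction of a minimal timing harness along those lines, with the model name chosen for illustration.)

```python
import time
import timm

def bench_load(model_name: str, n: int = 3) -> float:
    """Time timm.create_model with pretrained weights, best of n runs."""
    best = float("inf")
    for _ in range(n):
        start = time.perf_counter()
        timm.create_model(model_name, pretrained=True)
        best = min(best, time.perf_counter() - start)
    return best

if __name__ == "__main__":
    print(f"load time: {bench_load('vit_so400m_patch14_siglip_384'):.2f}s")
```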
Some considerations about meta device
Do let me know your thoughts and whether I should proceed with this PR.
Apart from using meta device, some other optimizations we can look into, possibly in future PRs (a short sketch of both follows the list):
- torch.load(mmap=True) (I believe safetensors already uses memory-map by default?)
- model.load_state_dict(assign=True)
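For reference, a small sketch of what these two options change relative to the default path ("resnet50.pth" is a placeholder checkpoint path):

```python
import torch
import timm

model = timm.create_model("resnet50", pretrained=False)

# Default path: torch.load reads the whole file into memory, then
# load_state_dict copies each checkpoint tensor into the existing parameters.
sd = torch.load("resnet50.pth", map_location="cpu")
model.load_state_dict(sd)

# mmap=True (PyTorch >= 2.1) memory-maps the file so tensors are paged in
# lazily; assign=True swaps the module's parameters for the checkpoint tensors
# instead of copying into them, skipping one full copy of the weights.
sd = torch.load("resnet50.pth", map_location="cpu", mmap=True)
model.load_state_dict(sd, assign=True)
```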