[RFC] Faster load time for large models #2350

Draft · wants to merge 12 commits into main

Conversation

gau-nernst
Contributor

While playing around with Flux-Redux, which uses siglip-so400m, I noticed that loading time is pretty slow. I think we can bring some of the loading-time optimizations used for LLMs to timm too. Hence, I'm opening this RFC to ask for feedback on whether you think it would be a useful improvement to timm.

As a proof-of-concept, I added meta device to skip weight initialization.

Benchmark script (lmk if you want to use a different way to benchmark weight loading time)

import torch
import timm
import time

# warmup run: downloads/caches the weights so the timed loop measures loading only
model = timm.create_model("vit_so400m_patch14_siglip_378.webli", pretrained=True)

# average creation + weight-loading time over N runs
N = 4
time0 = time.perf_counter()
for _ in range(N):
    model = timm.create_model("vit_so400m_patch14_siglip_378.webli", pretrained=True)
print((time.perf_counter() - time0) / N)

model(torch.randn(1, 3, 378, 378))  # make sure model forward works
Name             Time (s)
Baseline         4.20
w/ meta device   0.72
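
For reference, a rough sketch of the pattern (shown on a toy module, not the actual builder changes):

import torch
import torch.nn as nn

# Construct the module structure under the meta device: no real memory is
# allocated and the usual weight init is effectively skipped.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

# Materialize real (but uninitialized) storage, then let load_state_dict
# overwrite it with the pretrained weights.
model.to_empty(device="cpu")
state_dict = {k: torch.randn_like(v) for k, v in model.state_dict().items()}  # stand-in for a real checkpoint
model.load_state_dict(state_dict)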

Some considerations about meta device

  • The meta device only exists for PyTorch >= 2.3, I think. We need to guard the usage of the meta device against the PyTorch version (see the guard sketch after this list).
  • In some cases we need to bypass the meta device during model init, such as what I did in the ViT model file for calculating stochastic depth rates. I believe there are such cases in other models too.
  • I'm not aware of any other caveats for the meta device.
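
A possible shape for that guard (the helper name and the 2.3 threshold are just placeholders, not final):

import contextlib
import torch

_TORCH_VER = tuple(int(x) for x in torch.__version__.split(".")[:2])

def init_device_context(allow_meta: bool):
    # torch.device objects can be used as context managers in recent PyTorch;
    # fall back to a no-op context on older releases
    if allow_meta and _TORCH_VER >= (2, 3):
        return torch.device("meta")
    return contextlib.nullcontext()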

Do let me know your thoughts and whether I should proceed with this PR.

Apart from using meta device, some other optimizations we can look into (possibly in future PRs):

  • Use torch.load(mmap=True) (I believe safetensors already uses memory-mapping by default?)
  • Use model.load_state_dict(assign=True) (a rough sketch combining both follows below)
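
A rough sketch of the two combined (both kwargs require a reasonably recent PyTorch):

import torch
import torch.nn as nn

model = nn.Linear(768, 768)
torch.save(model.state_dict(), "linear.pth")  # stand-in checkpoint for illustration

# mmap=True avoids reading the whole file into memory upfront; assign=True makes
# the loaded tensors replace the module's parameters instead of being copied into them
state_dict = torch.load("linear.pth", map_location="cpu", mmap=True, weights_only=True)
model.load_state_dict(state_dict, assign=True)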

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rwightman
Collaborator

@gau-nernst I have thought about doing this; it is something that would be worthwhile as vision models are increasing in size, but it's a lil bit of work to make it all work nicely :)

Beyond something like stochastic depth, there are also quite a few other common cases re arange, ones/zeros, etc. for non-persistent buffers (quite common for relative pos embeds), and I'm sure there are other cases.

The backwards compat does need to be addressed; I still try to keep things working back to roughly PyTorch 1.12 or so.

safetensors should be using mmap.

I'm willing to work through the issues with you and figure out a good solution. I had really hoped that PyTorch would have a one-liner by this point that'd deal with all the init/buffer issues without having to pull in extra deps or DIY.

@gau-nernst
Contributor Author

Glad to hear! I will slowly work through the errors and ping you when it's ready for another pass. For now let's focus on using the meta device to skip weight init first.

Beyond something like stochastic depth, there are also quite a few other common cases re arange, ones/zeros, etc. for non-persistent buffers (quite common for relative pos embeds), and I'm sure there are other cases.

Would using numpy instead be OK for some of these cases? At least for stochastic depth, I don't see a problem. I will need to look into the other cases. Also, is it OK to depend on numpy? I saw pyproject.toml doesn't specify numpy, but requirements.txt has numpy, and there are some numpy imports in the repo.

Backward-compat should be do-able, it's just gonna make the code a bit ugly 😅

@rwightman
Collaborator

rwightman commented Dec 1, 2024

@gau-nernst I definitely don't want to start bringing in numpy as an alternative. There's definitely a way to do it properly with torch, I guess we'll see how messy it ends up.

There's also a related issue where you do want to init the model (from scratch, not pretrained) and want to do so on the GPU to avoid the cpu -> gpu step. It's actually a problem for many of the same initialization steps that are a problem with meta devices: initializing the pos embed buffers, etc. can end up very different if you do it on CPU vs GPU due to compounding of float rounding differences, using bfloat16 instead of float32, etc... it's better to force that on CPU even if the model weights are being created on the GPU. Something to keep in mind.
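
For illustration (a made-up sin-cos pos embed module, not actual timm code), the idea is to pin the buffer math to CPU / float32 at construction and let the result move with the module later:

import math
import torch
import torch.nn as nn

class SinCosPosEmbed(nn.Module):
    def __init__(self, num_pos: int, dim: int):
        super().__init__()
        # compute on CPU in float32 so the values don't depend on the default
        # device/dtype; the non-persistent buffer moves with .cuda()/.to() later
        pos = torch.arange(num_pos, device="cpu", dtype=torch.float32).unsqueeze(1)
        freq = torch.exp(torch.arange(0, dim, 2, device="cpu", dtype=torch.float32) * (-math.log(10000.0) / dim))
        emb = torch.cat([torch.sin(pos * freq), torch.cos(pos * freq)], dim=1)
        self.register_buffer("pos_embed", emb, persistent=False)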

@gau-nernst
Contributor Author

gau-nernst commented Dec 2, 2024

I manually added a lot of device="cpu" for stochastic depth and non-persistent buffers. Which test/command should I run to check that everything is working correctly in my local env (it must use pretrained=True to trigger the meta-device behavior)? I tried pytest -vv --forked --durations=0 -m base "tests/test_models.py::test_model_inference" but it only ran 13 tests, so it definitely did not cover everything.

One side note: because of the explicit device="cpu", non-persistent buffers will always default to CPU. Without this PR, they default to the default device (which can be CUDA or others). To preserve the previous behavior, I think we have to do something like check the default device and change it to "cpu" if the default device is "meta" (rough sketch below), which again would take some effort, but is doable. Though I think having non-persistent buffers default to "cpu" is not a big deal - users can (and should) call .cuda() on the model later, before training/inference (or we can do it in build_model_with_cfg()?).
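
Something like this is what I mean by checking the default device (torch.get_default_device() only exists in newer PyTorch, so this is only a sketch):

import torch

def _buffer_device():
    # use the ambient default device for non-persistent buffers, but fall back
    # to CPU when model construction is happening under the meta device
    dev = torch.get_default_device()
    return torch.device("cpu") if dev.type == "meta" else dev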

@gau-nernst
Contributor Author

Hi @rwightman, do you have some time to take another look at this? (and run the CI). Thank you!

@gau-nernst
Contributor Author

@rwightman Sorry for pinging you again. Do you have the time to review this? Thank you.

I also tested with timm/eva_giant_patch14_336.clip_ft_in1k (1B params) and the loading time reduces from 9.8s to 1.1s

@rwightman
Collaborator

@gau-nernst sorry for not checking in here for a bit (I should have at least kicked off the CI), was trying to tick off a few things before end of year, I'll take a closer look at the current state. Though just a warning, still juggling quite a few things so might get side tracked again...

FYI the CI can be run locally; it's really slow to run the full thing, but you can hack it to filter out parts initially, etc.

@gau-nernst
Contributor Author

Thanks for the update. I ran some of the CI tests locally but felt like they did not cover all the changes/models. You would know best whether the existing tests cover the new changes (i.e. specifically timm.create_model(pretrained=True) for all models in timm).

@rwightman
Collaborator

@gau-nernst there is a test that loads all pretrained checkpoints when run outside of the github CI, but it cannot verify correctness / changes in output, only that the weights load. I created some very small test_ models to start verifying outputs but the coverage is minimal at this point.

@rwightman
Collaborator

So I ran through some models, and many of them work: resnet, vit, etc. But as soon as you get into models with non-persistent buffers, there are lots of issues: swin and maxvit crash, and eva02 models don't crash but are producing garbage outputs.

@rwightman
Collaborator

I was thinking about this recently: the faster load time is important, but also related are things like better init compatibility with advanced param sharding, etc. In many cases you want to be able to call the init explicitly after meta-device creation / sharding. Wondering if that might need to be part of the solution anyways? Have a mechanism to call init on any model after creation (instead of just having it via __init__ right now)... AND be able to specify whether trainable params and/or buffers get init'd, etc...

@gau-nernst
Contributor Author

gau-nernst commented Jan 3, 2025

I discovered that when we init non-persistent buffers on the CPU device and then later call .to_empty(device="cpu") on the meta-device model, the non-persistent buffers get replaced with uninitialized memory, hence the issue. Will try to handle this... Persistent buffers and params don't have this issue because they get loaded from the state_dict.

In many cases you want to be able to call the init explicitly after meta-device creation / sharding,

I don't know about other frameworks, but at least for PyTorch's FSDP1/2, they rely on .reset_parameters() (https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md#meta-device-initialization). Anyway, I think we don't need to think too much about this since it's more for init from scratch, while this PR only involves initializing a model with pre-trained weights.

@rwightman
Collaborator

@gau-nernst the FSDP1 way did, but FSDP2 doesn't rely on reset_parameters... I was thinking that might be a model to follow. If I do add FSDP support into timm it will be FSDP2.

import itertools
import torch
# the import location of fully_shard varies by PyTorch version; this is one common path
from torch.distributed._composable.fsdp import fully_shard

with torch.device("meta"):
    model = Transformer()  # Transformer / TransformerBlock are the user's own modules
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")
# Allocate buffers and sharded parameters on GPU
model.to_empty(device="cuda")
# Run user-defined initializers
model.init_weights() # or `model.apply(init_weights)`

Depending on how messy addressing the non-persistent buffers turns out to be, if we had a reliable model.init_weights() fn (or similar), then for this functionality:

  • init with meta context (and don't worry about the buffers)
  • load checkpoint
  • model.init_weights(buffers_only=True) -- whether to re-init both persistent and non-persistent buffers or allow specifying either / both, not sure

And for things like FSDP2,

  • init with meta context
  • shard
  • model.init_weights()

etc,

@rwightman
Collaborator

If you do find a clean approach to the current buffers issue, we can proceed as is, of course.

Ultimately I feel this will need to move in a direction where there is a consistent init_weights interface for the models, though, to address all the use cases of using the meta device and to have an option of decoupling weight init from model.__init__() for various scenarios. That can be done later though.
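
To make that concrete, a purely hypothetical sketch of what such an interface could look like (names and flags are illustrative only, nothing here exists in timm yet):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, dim: int = 32, num_classes: int = 10, init_weights: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.head = nn.Linear(dim, num_classes)
        self.register_buffer("scale", torch.empty(dim), persistent=False)
        if init_weights:
            self.init_weights()

    @torch.no_grad()
    def init_weights(self, buffers_only: bool = False):
        # non-persistent buffers are always rebuilt; trainable params only when asked
        self.scale.copy_(torch.linspace(0.1, 1.0, self.scale.numel()))
        if not buffers_only:
            nn.init.trunc_normal_(self.proj.weight, std=0.02)
            nn.init.zeros_(self.proj.bias)
            nn.init.trunc_normal_(self.head.weight, std=0.02)
            nn.init.zeros_(self.head.bias)

# meta-device flow: build the structure without init, load the checkpoint,
# then re-init only the buffers the checkpoint cannot provide
with torch.device("meta"):
    model = TinyModel(init_weights=False)
model.to_empty(device="cpu")
# ... model.load_state_dict(pretrained_state_dict) would go here ...
model.init_weights(buffers_only=True)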

@gau-nernst
Contributor Author

gau-nernst commented Jan 3, 2025

Pushed a quick fix

        # .to_empty() will also move cpu params/buffers to uninitialized storage.
        # this is problematic for non-persistent buffers, since they don't get loaded
        # from pretrained weights later (not part of state_dict). hence, we have
        # to save them before calling .to_empty() and fill them back after.
        buffers = {k: v for k, v in model.named_buffers() if not v.is_meta}
        model.to_empty(device="cpu")
        for k, v in model.named_buffers():
            if k in buffers:
                v.data = buffers[k]

        # alternative, rely on internal method ._apply()
        # model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)

        # for reference, .to_empty() is implemented with
        # self._apply(lambda t: torch.empty_like(t, device=device))

Tested with Swin and it seems to work correctly. Will manually check other models... The alternative approach using ._apply() is cleaner, but relies on an internal method. Lmk if you prefer one over the other.

Extra thoughts: this meta-device problem concerns params + buffers that don't exist in the pretrained weights. Non-persistent buffers are one case. I believe some checkpoints also don't have a classifier head, such as DINO, which won't be initialized correctly if users set num_classes in timm.create_model() (I checked and confirmed it is an issue). Perhaps you would know of other similar cases? Changing num_classes or in_chans might have issues too, haven't checked...

I think the solution is either:

  1. Detect which params are missing from the pretrained checkpoint, then initialize those params separately (harder, but better UX; a rough sketch follows this list).
  2. Only use the fast meta-device init when no extra arguments are passed to timm.create_model() (more conservative and easier, but less useful).
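
A rough sketch of option 1 on a toy model (ignoring the meta-device materialization step; real timm heads would need their model-specific init rather than reset_parameters()):

import torch
import torch.nn as nn

# toy stand-in: pretend the pretrained checkpoint has no classifier head (like DINO)
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 10))
state_dict = {k: v for k, v in model.state_dict().items() if not k.startswith("1.")}

incompatible = model.load_state_dict(state_dict, strict=False)
missing = set(incompatible.missing_keys)
for name, module in model.named_modules():
    own = [f"{name}.{p}" for p, _ in module.named_parameters(recurse=False)]
    if own and any(k in missing for k in own):
        # re-init only the modules whose params were absent from the checkpoint
        module.reset_parameters()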

@rwightman
Collaborator

@gau-nernst hmm, yeah, it's not just models without the classifier; any use case where num_classes doesn't match and a new, different classifier gets created (the old one is removed from the state_dict) would be a problem, no?

The old classifier is removed from the weights, and it relies on the init to provide a replacement. Also complicating an 'easy' workaround, the head is initialized differently by different model inits, so simply running model.get_classifier().reset_parameters() would result in different behaviour than right now... hmm.

if num_classes != pretrained_cfg['num_classes']:
    for classifier_name in classifiers:
        # completely discard fully connected if model num_classes doesn't match pretrained weights
        state_dict.pop(classifier_name + '.weight', None)
        state_dict.pop(classifier_name + '.bias', None)
    strict = False

Changing in_chans should be fine because it modifies the state_dict input conv weight.

@rwightman
Collaborator

FYI, I'm running some bulk evals to compare against existing numbers.

@gau-nernst
Contributor Author

I added a check to use the meta-device init only when num_classes == 0 or num_classes is not changed. I have a question for you.

# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))

Here, determining num_classes_pretrained requires having an instantiated model. In what case would model.num_classes != pretrained_cfg["num_classes"]? i.e. do we need an instantiated model to determine num_classes? I'm asking because we need to know num_classes before model init so that we can skip the meta device if num_classes differs from pretrained_cfg["num_classes"].

@rwightman
Collaborator

@gau-nernst hmm, so the idea with the num_classes ultimately coming from the model was that it is the ultimate source of truth wrt what config/args will result in the final model classifier shape that needs to match the pretrained weights. That num_classes_pretrained name was unfortunate, not sure why I did that; num_classes_model would have been clearer.

Right now, what's passed in via kwargs to override the default is, I think, transparent for all models; it ends up matching model.num_classes. I'd envisioned scenarios where that might not be the case.

Changing the num_classes is such a common use case here, it would definitely diminish the utility to not support that case.

@rwightman
Collaborator

Something else came up in the larger test: models using BlurPool are failing.

Also, thinking about this a bit more, this solution is going to cause issues with some use cases... namely anyone using torch.set_default_device or device context managers themselves; it will break.

The use of 'cpu' for the tensor init fns, while allowing the meta-device context to be used, will break anyone trying to use a default device or context manager for another, different device. This applies at both the model and layer level; some people do use timm just for specific layers that aren't available elsewhere. And several layers now have device='cpu'...

@gau-nernst
Contributor Author

I will look into the BlurPool models, but I will be travelling soon, so there will be few updates from me until next week.

Also, thinking about this a bit more, this solution is going to cause issues with some use cases... namely anyone using torch.set_default_device or device context managers themselves; it will break.

I also mentioned this previously, but didn't think it would be a big problem. We can always do model.to(torch.get_default_device()) after everything. But for "layer level", it will indeed break people's code...

It seems like you are leaning towards having something like model.init_weights(buffers_only=True). In summary, these are the kinds of tensors we need to handle:

  1. Parameters and persistent buffers: these are part of the state_dict. We can init them on the meta device.
  2. Non-persistent buffers: these are NOT part of the state_dict. We have to init them on the current device, after model.to_empty(), and either before or after model.load_state_dict().
  3. Head parameters when num_classes changes: we need to re-init them based on the original init logic -> have to refactor .init_weights() for all models.
  4. Temporary utility tensors (e.g. torch.linspace() for stochastic depth): OK to create on CPU, since we call .item() on each element later.

Point 3 requires factoring out the head init logic for all models, so there is no workaround. Points 1 and 4 are covered in this PR. This PR can provide a workaround for point 2, though it doesn't look very "nice" (and it also breaks things at the "layer level" mentioned above).

Would you prefer putting this PR on hold while working on refactoring model.init_weights()? I'm guessing that if we do .init_weights() for all models, it will span several PRs, and then we can revisit this later (using the meta device in the model builder).

@rwightman
Collaborator

@gau-nernst yeah, it's at the layer level that I was getting more concerned... there is no way to reliably track down what the uses of various layers are in the wild; it could just break a lot of things.

This is mostly working right now, so that's promising, but I feel a custom init fn with a consistent naming/interface could help avoid the layer and model breakage concerns by allowing patching / skipping those functions, etc... maybe adding some state tracking to handle head re-init. I will think about this a bit more...
