Skip to content

Commit

Permalink
Merge pull request huggingface#2394 from huggingface/non_reentrant_ckpt
Browse files Browse the repository at this point in the history
Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override
  • Loading branch information
rwightman authored Jan 6, 2025
2 parents 131518c + 155f6e7 commit 6f80214
Show file tree
Hide file tree
Showing 24 changed files with 94 additions and 54 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

## What's New

## Jan 6, 2025
* Add `torch.utils.checkpoint.checkpoint()` wrapper in `timm.models` that defaults `use_reentrant=False`, unless `TIMM_REENTRANT_CKPT=1` is set in env.

## Dec 31, 2024
* `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune. https://huggingface.co/models?search=convnext_nano%20r384
* Add AIM-v2 encoders from https://github.com/apple/ml-aim, see on Hub: https://huggingface.co/models?search=timm%20aimv2
Expand Down
3 changes: 2 additions & 1 deletion timm/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
set_reentrant_ckpt, use_reentrant_ckpt
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
Expand Down
18 changes: 17 additions & 1 deletion timm/layers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

__all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
'set_reentrant_ckpt', 'use_reentrant_ckpt'
]

# Set to True if prefer to have layers with no jit optimization (includes activations)
Expand All @@ -34,6 +35,12 @@
_USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)


# Default for re-entrant activation checkpointing: disabled (off) unless the
# TIMM_REENTRANT_CKPT environment variable turns it on.
# NOTE: bool() of any non-empty string is True, so bool('0') / bool('false')
# would wrongly enable the flag. Parse the text explicitly: unset, empty, or a
# recognised false-like value disables it; any other value enables it.
_USE_REENTRANT_CKPT = os.environ.get('TIMM_REENTRANT_CKPT', '').strip().lower() not in (
    '', '0', 'false', 'no', 'off',
)


def is_no_jit():
return _NO_JIT

Expand Down Expand Up @@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
_USE_FUSED_ATTN = 1
else:
_USE_FUSED_ATTN = 0


def use_reentrant_ckpt() -> bool:
    """Return the current module-level re-entrant checkpointing setting.

    Returns:
        The value currently held in the module global ``_USE_REENTRANT_CKPT``.
    """
    enabled = _USE_REENTRANT_CKPT
    return enabled


def set_reentrant_ckpt(enable: bool = True):
    """Set the module-level re-entrant checkpointing flag.

    Args:
        enable: New value stored in the module global ``_USE_REENTRANT_CKPT``
            (defaults to True).
    """
    # Rebind the module global rather than creating a function-local name.
    global _USE_REENTRANT_CKPT
    _USE_REENTRANT_CKPT = enable
2 changes: 1 addition & 1 deletion timm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
group_modules, group_parameters, checkpoint_seq, checkpoint, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
from ._prune import adapt_model_from_string
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
Expand Down
3 changes: 1 addition & 2 deletions timm/models/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import Format, _assert

from ._manipulate import checkpoint

__all__ = [
'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
Expand Down
55 changes: 43 additions & 12 deletions timm/models/_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import re
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union

import torch
import torch.utils.checkpoint
from torch import nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import use_reentrant_ckpt


__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']


def model_parameters(model: nn.Module, exclude_head: bool = False):
Expand Down Expand Up @@ -183,13 +186,35 @@ def flatten_modules(
yield name, module


def checkpoint(
        function,
        *args,
        use_reentrant: Optional[bool] = None,
        **kwargs,
):
    """Call ``torch.utils.checkpoint.checkpoint`` with a resolved ``use_reentrant``.

    Args:
        function: Callable forwarded to ``torch.utils.checkpoint.checkpoint``.
        *args: Positional arguments forwarded unchanged.
        use_reentrant: Explicit re-entrant setting. When None, the value
            returned by ``use_reentrant_ckpt()`` is used instead.
        **kwargs: Keyword arguments forwarded unchanged.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns.
    """
    # An explicit argument wins; only None falls back to the shared setting.
    reentrant = use_reentrant_ckpt() if use_reentrant is None else use_reentrant
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=reentrant, **kwargs)


def checkpoint_seq(
functions,
x,
every=1,
flatten=False,
skip_last=False,
preserve_rng_state=True
every: int = 1,
flatten: bool = False,
skip_last: bool = False,
use_reentrant: Optional[bool] = None,
):
r"""A helper function for checkpointing sequential models.
Expand All @@ -215,10 +240,9 @@ def checkpoint_seq(
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
flatten: flatten nn.Sequential of nn.Sequentials
skip_last: skip checkpointing the last function in the sequence if True
use_reentrant: Use re-entrant checkpointing
Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
Expand All @@ -227,6 +251,9 @@ def checkpoint_seq(
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2)
"""
if use_reentrant is None:
use_reentrant = use_reentrant_ckpt()

def run_function(start, end, functions):
def forward(_x):
for j in range(start, end + 1):
Expand All @@ -247,7 +274,11 @@ def forward(_x):
end = -1
for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
x = torch.utils.checkpoint.checkpoint(
run_function(start, end, functions),
x,
use_reentrant=use_reentrant,
)
if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x)
return x
Expand Down
3 changes: 1 addition & 2 deletions timm/models/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid


from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model

__all__ = ['Beit']
Expand Down
5 changes: 2 additions & 3 deletions timm/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch.jit.annotations import List

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
from ._manipulate import MATCH_PREV_GROUP, checkpoint
from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['DenseNet']
Expand Down Expand Up @@ -60,7 +59,7 @@ def call_checkpoint_bottleneck(self, x):
def closure(*xs):
return self.bottleneck_fn(xs)

return cp.checkpoint(closure, *x)
return checkpoint(closure, *x)

@torch.jit._overload_method # noqa: F811
def forward(self, x):
Expand Down
3 changes: 1 addition & 2 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
Expand All @@ -51,7 +50,7 @@
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint_seq, checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['EfficientNet', 'EfficientNetFeatures']
Expand Down
2 changes: 1 addition & 1 deletion timm/models/eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
Expand All @@ -39,6 +38,7 @@

from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model

__all__ = ['Eva']
Expand Down
3 changes: 1 addition & 2 deletions timm/models/focalnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model

__all__ = ['FocalNet']
Expand Down
3 changes: 1 addition & 2 deletions timm/models/gcvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._manipulate import named_apply, checkpoint
from ._registry import register_model, generate_default_cfgs

__all__ = ['GlobalContextVit']
Expand Down
3 changes: 1 addition & 2 deletions timm/models/hiera.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
Expand All @@ -39,7 +38,7 @@
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._manipulate import named_apply, checkpoint


__all__ = ['Hiera']
Expand Down
3 changes: 1 addition & 2 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
Expand All @@ -21,7 +20,7 @@
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
from ._manipulate import checkpoint_seq, checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['MobileNetV3', 'MobileNetV3Features']
Expand Down
4 changes: 2 additions & 2 deletions timm/models/mvitv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from typing import Union, List, Tuple, Optional

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs

__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this

Expand Down
2 changes: 1 addition & 1 deletion timm/models/pvt_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs

__all__ = ['PyramidVisionTransformerV2']
Expand Down
4 changes: 2 additions & 2 deletions timm/models/swin_transformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
resample_patch_embed, ndgrid, get_act_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
Expand Down
3 changes: 1 addition & 2 deletions timm/models/swin_transformer_v2_cr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model

__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
Expand Down
Loading

0 comments on commit 6f80214

Please sign in to comment.