
timm bits #804

Open

rwightman wants to merge 101 commits into main from bits_and_tpu
Commits (101)
12d9a6d
First timm.bits commit, add initial abstractions, WIP updates to trai…
rwightman Apr 21, 2021
76de984
Fix some bugs with XLA support, logger, add hacky xla dist launch scr…
rwightman Apr 21, 2021
938716c
Fix import issue, use devenv for dist info in parser_tfds
rwightman Apr 21, 2021
aa92d7b
Major timm.bits update. Updater and DeviceEnv now dataclasses, after_…
rwightman May 17, 2021
74d2829
Merge branch 'master' into bits_and_tpu
rwightman May 17, 2021
6d90fcf
Fix distribute_bn and model_ema
rwightman May 18, 2021
cbd4ee7
Fix model init for XLA, remove some prints.
rwightman May 18, 2021
72ca831
Back to using strings for the enum translation, forgot about import dep
rwightman May 19, 2021
4210d92
Merge branch 'master' into bits_and_tpu
rwightman May 20, 2021
5b9c69e
Add basic training resume based on legacy code
rwightman May 22, 2021
91ab0b6
Add proper TrainState checkpoint save/load. Some reorg/refactoring an…
rwightman Jun 4, 2021
b57a03b
Merge branch 'master' into bits_and_tpu
rwightman Jun 4, 2021
f411724
Fix checkpoint delete issue. Add README about bits and initial Pytorc…
rwightman Jun 4, 2021
c3db5f5
Worker hack for TFDS eval, add TPU env var setting.
rwightman Jun 4, 2021
6b2d9c2
Another bits/README.md update
rwightman Jun 4, 2021
cc870df
Update README.md
rwightman Jun 4, 2021
ee2b8f4
Update README.md
rwightman Jun 4, 2021
5c5cadf
Update README.md
rwightman Jun 4, 2021
847b4af
Update README.md
rwightman Jun 6, 2021
56ed0a0
Merge branch 'vit_and_bit_test_fixes' into bits_and_tpu
rwightman Jun 23, 2021
5e95ced
timm bits checkpoint support for avg_checkpoints.py
rwightman Jun 23, 2021
40457e5
Transforms, augmentation work for bits, add RandomErasing support for…
rwightman Aug 13, 2021
c06c739
Merge branch 'master' into bits_and_tpu
rwightman Aug 13, 2021
b974d85
Merge branch 'bits_and_tpu' of github.com:rwightman/pytorch-image-mod…
rwightman Aug 13, 2021
cb621e0
Remove print, arg order
rwightman Aug 13, 2021
f98662b
Merge branch 'master' into bits_and_tpu
rwightman Aug 18, 2021
b76b48e
Update optimizer creation for master optimizer changes
rwightman Aug 18, 2021
0d82876
Add comment for reference re PyTorch XLA 'race' issue
rwightman Aug 18, 2021
b0265ef
Merge branch 'master' into bits_and_tpu
rwightman Aug 18, 2021
f4fb068
Merge branch 'master' into bits_and_tpu
rwightman Aug 19, 2021
2ee398d
Merge branch 'master' into bits_and_tpu
rwightman Aug 19, 2021
f2e1468
Add force-cpu flag for train/validate, fix CPU fallback for device in…
rwightman Aug 22, 2021
c2f02b0
Merge remote-tracking branch 'origin/attn_update' into bits_and_tpu
rwightman Sep 5, 2021
3581aff
Update train.py with some flags related to scheduler tweaks, fix best…
rwightman Sep 5, 2021
25d52ea
Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
rwightman Sep 25, 2021
52c481e
Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
rwightman Sep 27, 2021
1fdc7af
Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
rwightman Oct 1, 2021
3b6ba76
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Oct 22, 2021
690f31d
Post merge cleanup, restore previous unwrap fn
rwightman Oct 24, 2021
a45186a
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Oct 27, 2021
59a3409
Update README.md
rwightman Nov 4, 2021
07693f8
Validation fix since we don't have multi-GPU DataParallel support yet
rwightman Nov 10, 2021
406c486
Merge remote-tracking branch 'origin/more_datasets' into bits_and_tpu
rwightman Nov 11, 2021
80ca078
Fix a few bugs and formatting/naming issues
rwightman Nov 11, 2021
d9b0b3d
device arg wasn't removed from PrefetcherCuda instantiation of RE
rwightman Nov 12, 2021
4f33855
Fixes and improvements for metrics, tfds parser, loader / transform h…
rwightman Nov 12, 2021
871cef4
version 0.5.1
rwightman Nov 12, 2021
809c7bb
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Nov 21, 2021
cad170e
Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
rwightman Dec 1, 2021
0e212e8
Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
rwightman Dec 1, 2021
820ae99
Fix load_state_dict to handle None ema entries
rwightman Dec 3, 2021
69e90dc
Merge branch 'norm_norm_norm' into bits_and_tpu
rwightman Dec 5, 2021
ff0f709
Testing TFDS shuffle across epochs
rwightman Dec 8, 2021
7bbbd5e
EvoNorm and GroupNormAct options for debugging TPU / XLA concerns
rwightman Dec 8, 2021
66daee4
Last change wasn't complete, missed adding full evo_norm changeset
rwightman Dec 8, 2021
88a5b54
A few small evonorm tweaks for convergence comparisons
rwightman Dec 10, 2021
4d7a554
Remove inplace sigmoid for consistency with other impl
rwightman Dec 10, 2021
1f54a1f
Add C16 and E8 EvoNormS0 configs for RegNetZ BYOB nets
rwightman Dec 11, 2021
57fca2b
Fix c16_evos stem / first conv setup
rwightman Dec 11, 2021
d829858
Significant norm update
rwightman Dec 14, 2021
1c21cac
Add drop args to benchmark.py
rwightman Dec 14, 2021
40f4745
Merge branch 'norm_norm_norm' into bits_and_tpu
rwightman Dec 14, 2021
cbc4f33
Merge branch 'norm_norm_norm' into bits_and_tpu
rwightman Dec 16, 2021
0012bf7
Merge branch 'norm_norm_norm' into bits_and_tpu
rwightman Dec 17, 2021
4c8bb29
Remove bn-tf arg
rwightman Dec 17, 2021
7eb7e73
File will not stay deleted
rwightman Dec 17, 2021
066e490
Merge branch 'norm_norm_norm' into bits_and_tpu
rwightman Jan 28, 2022
f82fb6b
Add base lr w/ linear and sqrt scaling to train script
rwightman Jan 28, 2022
7148039
Tweak base lr log
rwightman Jan 28, 2022
fafece2
Allow changing base lr batch size from 256 via arg
rwightman Jan 28, 2022
a16ea1e
Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
rwightman Mar 1, 2022
c639a86
Change TFDS default to full re-shuffle (init) each epoch (for now)
rwightman Mar 1, 2022
10fa42b
Merge branch 'ChristophReich1996-master' into bits_and_tpu
rwightman Mar 1, 2022
bb85b09
swin v2 fixup for latest changes on norm_norm_norm / bits_and_tpu branch
rwightman Mar 1, 2022
15cc9ea
Fix Swin v2 tuple type hint
rwightman Mar 1, 2022
3fce010
Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
rwightman Mar 1, 2022
da2796a
Add webdataset (WDS) support, update TFDS to make some naming in pars…
rwightman Mar 8, 2022
a444d4b
Add alternative label support to WDS for imagenet22k/12k split, add 2…
rwightman Mar 9, 2022
229ac6b
Fix alternate label handling in WDS parser to skip invalid alt labels
rwightman Mar 12, 2022
7eeaf52
use gopen in wds to open info file in case it's at a url/gs location
rwightman Mar 12, 2022
ab16a35
Add log and continue handler for WDS errors, fix args.num_gpu for val…
rwightman Mar 16, 2022
ef57561
Fix some TPU (XLA) issues with swin transformer v2
rwightman Mar 16, 2022
59ffab5
Fix mistake in wds sample slicing
rwightman Mar 17, 2022
5e1be34
Add ImageNet-22k/12k TFDS dataset defs
rwightman Mar 18, 2022
95739b4
Fix partially removed alt_lable impl from TFDS variant of ImageNet22/12k
rwightman Mar 18, 2022
749856c
Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
rwightman Mar 22, 2022
1ba0ec4
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Mar 26, 2022
754e114
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Apr 4, 2022
c76d772
Add support for different TFDS `BuilderConfig`s
dedeswim Apr 28, 2022
9c321be
Merge pull request #1239 from dedeswim/parser-fix
rwightman Apr 28, 2022
dff3373
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman May 25, 2022
1186fc9
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Jul 13, 2022
6fe0199
version 0.8.x for bits_and_tpu branch
rwightman Jul 13, 2022
bd6d377
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Jul 28, 2022
5a40c6a
Fix issue with torchvision's ImageNet
dedeswim Aug 17, 2022
f07dfe0
Merge remote-tracking branch 'origin/more_vit' into bits_and_tpu
rwightman Aug 24, 2022
1dced60
Merge branch 'rwightman:bits_and_tpu' into bits_and_tpu
dedeswim Aug 26, 2022
87bfb05
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Sep 2, 2022
38594ef
Merge remote-tracking branch 'origin/master' into bits_and_tpu
rwightman Sep 7, 2022
b4ea69c
Merge pull request #1414 from dedeswim/bits_and_tpu
rwightman Oct 14, 2022
a25bf97
Update README.md
rwightman Nov 2, 2022
Files changed
1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@

 Thanks to the following for hardware support:
 * TPU Research Cloud (TRC) (https://sites.research.google/trc/about/)
+* TPU support can be found on the [`bits_and_tpu`](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/) branch, w/ some setup help [here](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)
 * Nvidia (https://www.nvidia.com/en-us/)

 And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face.
31 changes: 14 additions & 17 deletions clean_checkpoint.py
@@ -13,14 +13,16 @@
 import hashlib
 import shutil
+from collections import OrderedDict

 from timm.models.helpers import load_state_dict
 from timm.utils import setup_default_logging

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
-parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
@@ -30,28 +32,25 @@

 def main():
     args = parser.parse_args()
     setup_default_logging()

     if os.path.exists(args.output):
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)

-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
-
-
-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
-    if checkpoint and os.path.isfile(checkpoint):
-        print("=> Loading checkpoint '{}'".format(checkpoint))
-        state_dict = load_state_dict(checkpoint, use_ema=use_ema)
-        new_state_dict = {}
+    if args.checkpoint and os.path.isfile(args.checkpoint):
+        print("=> Loading checkpoint '{}'".format(args.checkpoint))
+        state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
+        new_state_dict = OrderedDict()
         for k, v in state_dict.items():
-            if clean_aux_bn and 'aux_bn' in k:
+            if args.clean_aux_bn and 'aux_bn' in k:
                 # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
                 # load with the unmodified model using BatchNorm2d.
                 continue
             name = k[7:] if k.startswith('module.') else k
             new_state_dict[name] = v
-        print("=> Loaded state_dict from '{}'".format(checkpoint))
+        print("=> Loaded state_dict from '{}'".format(args.checkpoint))

         try:
             torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
@@ -61,19 +60,17 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
         with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()

-        if output:
-            checkpoint_root, checkpoint_base = os.path.split(output)
+        if args.output:
+            checkpoint_root, checkpoint_base = os.path.split(args.output)
             checkpoint_base = os.path.splitext(checkpoint_base)[0]
         else:
             checkpoint_root = ''
-            checkpoint_base = os.path.splitext(checkpoint)[0]
+            checkpoint_base = os.path.splitext(args.checkpoint)[0]
         final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
         shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
-        return final_filename
     else:
-        print("Error: Checkpoint ({}) doesn't exist".format(checkpoint))
-        return ''
+        print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))


 if __name__ == '__main__':
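The cleaning pass above is easy to misread in diff form: it strips the `module.` prefix that DataParallel/DDP wrappers add to every key, and optionally drops `aux_bn` keys so checkpoints trained with SplitBN load into a plain BatchNorm2d model. A minimal standalone sketch of the same idea — the checkpoint path and the EMA key handling are assumptions based on timm's usual checkpoint layout, not the PR's code verbatim:

```python
from collections import OrderedDict

import torch


def strip_wrappers(state_dict, clean_aux_bn=False):
    # Remove DataParallel/DDP 'module.' prefixes; optionally drop SplitBN aux keys.
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        if clean_aux_bn and 'aux_bn' in k:
            continue  # remaining BN keys then load into an unmodified BatchNorm2d model
        cleaned[k[7:] if k.startswith('module.') else k] = v
    return cleaned


checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')  # hypothetical path
# Prefer EMA weights when present, mirroring the --use-ema flag above.
state_dict = checkpoint.get('state_dict_ema') or checkpoint.get('state_dict', checkpoint)
torch.save(strip_wrappers(state_dict), 'cleaned.pth')
```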
4 changes: 2 additions & 2 deletions inference.py
@@ -13,7 +13,7 @@
 import torch

 from timm.models import create_model, apply_test_time_pool
-from timm.data import ImageDataset, create_loader, resolve_data_config
+from timm.data import ImageDataset, create_loader_v2, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging

 torch.backends.cudnn.benchmark = True
@@ -82,7 +82,7 @@ def main():
     else:
         model = model.cuda()

-    loader = create_loader(
+    loader = create_loader_v2(
         ImageDataset(args.data),
         input_size=config['input_size'],
         batch_size=args.batch_size,
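The only functional change to inference.py is the swap from create_loader to create_loader_v2, the loader factory used on the bits branch. A hedged sketch of standalone usage, assuming only the arguments visible in the diff (dataset, input_size, batch_size); the model name and data path are placeholders:

```python
import torch

from timm.data import ImageDataset, create_loader_v2, resolve_data_config
from timm.models import create_model

model = create_model('resnet50', pretrained=True).eval()
config = resolve_data_config({}, model=model)  # derive input size etc. from the model

loader = create_loader_v2(
    ImageDataset('path/to/images'),  # hypothetical image folder
    input_size=config['input_size'],
    batch_size=32,
)

with torch.no_grad():
    for batch, _ in loader:
        print(model(batch).shape)
        break
```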
66 changes: 66 additions & 0 deletions launch_xla.py
@@ -0,0 +1,66 @@
"""
Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla.

`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.

"""


import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description="PyTorch distributed training launch helper utility"
"that will spawn up multiple distributed processes")

# Optional arguments for the launch helper
parser.add_argument("--num-devices", type=int, default=1,
help="The number of XLA devices to use for distributed training")

# positional
parser.add_argument(
"script", type=str,
help="The full path to the single device training script to be launched"
"in parallel, followed by all the arguments for the training script")

# rest from the training program
parser.add_argument('script_args', nargs=REMAINDER)
return parser.parse_args()


def main():
args = parse_args()

# set PyTorch distributed related environmental variables
# current_env = os.environ.copy()
# current_env["MASTER_ADDR"] = args.master_addr
# current_env["MASTER_PORT"] = str(args.master_port)
# current_env["WORLD_SIZE"] = str(dist_world_size)
# if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
# current_env["OMP_NUM_THREADS"] = str(1)

script_abs = os.path.abspath(args.script)
script_base, script_rel = os.path.split(script_abs)
sys.path.append(script_base)
mod = importlib.import_module(os.path.splitext(script_rel)[0])

sys.argv = [args.script] + args.script_args

xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)


if __name__ == "__main__":
main()
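Worth noting how this launcher is wired: it appends the script's directory to sys.path, imports the script as a module, rewrites sys.argv, and hands mod._mp_entry to xmp.spawn, so any training script driven by `python launch_xla.py --num-devices 8 train.py ...` must expose an `_mp_entry` function. A minimal compatible script might look like this — a sketch under that assumption, with a hypothetical file name and body:

```python
# my_train.py - hypothetical minimal script runnable via launch_xla.py
import torch
import torch_xla.core.xla_model as xm


def train():
    device = xm.xla_device()  # this process's XLA device
    x = torch.randn(8, 3, 224, 224, device=device)
    print(f'ordinal {xm.get_ordinal()}/{xm.xrt_world_size()}: mean={x.mean().item():.4f}')


def _mp_entry(index, *args):
    # launch_xla.py imports this module and passes _mp_entry to xmp.spawn(),
    # which calls it once per device with the process index.
    train()


if __name__ == '__main__':
    train()
```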