From cb85ff97e063885900dca3d3e11d96d4543ca46c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20=C5=A0m=C3=ADd?= Date: Thu, 19 Mar 2020 22:23:26 +0100 Subject: [PATCH 1/3] add trainer debugging option "overfit_single_batch" This reduces the training set to a single batch and turns off the validation set. Training on a single batch should quickly overfit and reach accuracy 1.0. This is a recommended step for debugging neural networks, see https://twitter.com/karpathy/status/1013244313327681536 --- base/base_data_loader.py | 13 +++++++++++-- base/base_trainer.py | 1 + train.py | 6 +++++- trainer/trainer.py | 5 +++-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/base/base_data_loader.py b/base/base_data_loader.py index 91a0d98a..27c296ae 100644 --- a/base/base_data_loader.py +++ b/base/base_data_loader.py @@ -8,14 +8,23 @@ class BaseDataLoader(DataLoader): """ Base class for all data loaders """ - def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): + def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate, + single_batch=False): self.validation_split = validation_split self.shuffle = shuffle self.batch_idx = 0 self.n_samples = len(dataset) - self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) + if not single_batch: + self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) + else: + idx_full = np.arange(self.n_samples) + np.random.seed(0) + np.random.shuffle(idx_full) + self.sampler = SubsetRandomSampler(idx_full[:batch_size]) + self.valid_sampler = None + self.shuffle = None self.init_kwargs = { 'dataset': dataset, diff --git a/base/base_trainer.py b/base/base_trainer.py index 93e8b670..9bc6469c 100644 --- a/base/base_trainer.py +++ b/base/base_trainer.py @@ -26,6 +26,7 @@ def __init__(self, model, criterion, metric_ftns, optimizer, config): self.epochs = cfg_trainer['epochs'] self.save_period = cfg_trainer['save_period'] self.monitor = cfg_trainer.get('monitor', 'off') + self.overfit_single_batch = cfg_trainer.get('overfit_single_batch', False) # configuration to monitor model performance and save best if self.monitor == 'off': diff --git a/train.py b/train.py index 004d5354..923eebc0 100644 --- a/train.py +++ b/train.py @@ -17,11 +17,15 @@ torch.backends.cudnn.benchmark = False np.random.seed(SEED) + def main(config): logger = config.get_logger('train') # setup data_loader instances - data_loader = config.init_obj('data_loader', module_data) + if config['trainer'].get('overfit_single_batch', False): + data_loader = config.init_obj('data_loader', module_data, single_batch=True) + else: + data_loader = config.init_obj('data_loader', module_data) valid_data_loader = data_loader.split_validation() # build model architecture, then print to console diff --git a/trainer/trainer.py b/trainer/trainer.py index d87ea834..c5153666 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -59,8 +59,9 @@ def _train_epoch(self, epoch): loss.item())) self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) - if batch_idx == self.len_epoch: + if batch_idx == self.len_epoch or self.overfit_single_batch: break + log = self.train_metrics.result() if self.do_validation: @@ -81,7 +82,7 @@ def _valid_epoch(self, epoch): self.model.eval() self.valid_metrics.reset() with torch.no_grad(): - for batch_idx, (data, target) in enumerate(self.valid_data_loader): + for batch_idx, (data, target, _) in enumerate(self.valid_data_loader): data, target = data.to(self.device), target.to(self.device) output = self.model(data) From 2229f45170a08b3600c96b794daad6d0339f8fb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20=C5=A0m=C3=ADd?= Date: Sun, 22 Mar 2020 00:12:51 +0100 Subject: [PATCH 2/3] simplify single batch overfitting --- train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 923eebc0..545c341a 100644 --- a/train.py +++ b/train.py @@ -22,10 +22,8 @@ def main(config): logger = config.get_logger('train') # setup data_loader instances - if config['trainer'].get('overfit_single_batch', False): - data_loader = config.init_obj('data_loader', module_data, single_batch=True) - else: - data_loader = config.init_obj('data_loader', module_data) + data_loader = config.init_obj('data_loader', module_data, + single_batch=config['trainer'].get('overfit_single_batch', False)) valid_data_loader = data_loader.split_validation() # build model architecture, then print to console From b2d551252346c582f8c6d85adc7b82bd39eb82b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20=C5=A0m=C3=ADd?= Date: Sun, 22 Mar 2020 00:32:47 +0100 Subject: [PATCH 3/3] add overfit_single_batch option to sample config.json --- config.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config.json b/config.json index 0339e6ab..c77f496c 100644 --- a/config.json +++ b/config.json @@ -45,6 +45,7 @@ "monitor": "min val_loss", "early_stop": 10, - "tensorboard": true + "tensorboard": true, + "overfit_single_batch": false } }