Optimizer / scheduler rework [backward compatible].
- malpolon.models.utils: changed the behavior of check_optimizer() and added check_scheduler() so that users can declare one or several optimizers (each optionally with one scheduler, possibly carrying an lr_scheduler_config descriptor) in their config files.
- malpolon.models.standard_prediction_systems: changed the instantiation of optimizer(s) and scheduler(s) in class GenericPredictionSystem. The class attributes are now lists of instantiated optimizers (respectively, of lr_scheduler_config dictionaries). Updated configure_optimizers() to return one dictionary per optimizer, bundling each optimizer with its optional scheduler in the format expected by Lightning (see the sketch below and https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers).
- malpolon.tests.test_models.utils: added the corresponding unit tests, covering both valid scenarios and edge cases of incorrect user inputs in the config file.
- sentinel-2a-rgbnir_bioclim example: updated the config file to match the changes described above.
tlarcher committed Oct 14, 2024
1 parent 6d0a373 commit 293595b
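
For reference, a minimal sketch (not part of this diff) of the return format that the reworked configure_optimizers() produces, following the Lightning convention linked above. The toy model, optimizer and scheduler below are illustrative placeholders, not objects from the commit:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(8, 2)  # toy model, illustration only
opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
sch = ReduceLROnPlateau(opt, threshold=0.001)

# One entry per optimizer; the 'lr_scheduler' key is dropped when no
# scheduler is configured for that optimizer.
expected_return = [{
    'optimizer': opt,
    'lr_scheduler': {'scheduler': sch, 'monitor': 'loss/val'},
}]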
Showing 17 changed files with 618 additions and 43 deletions.
@@ -68,18 +68,30 @@ model:

optimizer:
optimizer:
# sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
weight_decay: 0
momentum: 0.9
nesterov: true
scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
reduce_lr_on_plateau:
# callable: 'lr_scheduler.reduce_lr_on_plateau'
kwargs:
threshold: 0.001
lr_scheduler_config:
scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
# adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# weight_decay: 0
# momentum: 0.9
# nesterov: true
# scheduler:
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
metrics:
# binary_accuracy:
# # callable: 'Fmetrics.accuracy'
15 changes: 11 additions & 4 deletions malpolon/models/standard_prediction_systems.py
@@ -36,7 +36,7 @@ def __init__(
model: Union[torch.nn.Module, Mapping],
loss: torch.nn.modules.loss._Loss,
optimizer: Union[torch.optim.Optimizer, Mapping],
scheduler: Union[torch.optim.Optimizer, Mapping] = None,
scheduler: Union[torch.optim.Optimizer] = None,
metrics: Optional[dict[str, Callable]] = None,
save_hyperparameters: Optional[bool] = True,
):
@@ -70,7 +70,7 @@ def __init__(
self.checkpoint_path = None if not hasattr(self, 'checkpoint_path') else self.checkpoint_path # Avoids overwriting the attribute. This class will need to be re-written properly alongside ClassificationSystem
self.model = check_model(model)
self.optimizer, config_scheduler = check_optimizer(optimizer, self.model)
self.scheduler = check_scheduler(config_scheduler, self.optimizer) if (isinstance(optimizer, torch.optim.Optimizer) and scheduler is not None) else config_scheduler
self.scheduler = config_scheduler if scheduler is None else check_scheduler(scheduler, self.optimizer)
self.loss = check_loss(loss)
self.metrics = metrics or {}
if len(self.optimizer) > 1:
@@ -191,7 +191,13 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self(x)

def configure_optimizers(self) -> dict:
return self.optimizer, self.scheduler
res = []
for i, opt in enumerate(self.optimizer):
tmp = {'optimizer': opt, 'lr_scheduler': self.scheduler[i]}
if tmp['lr_scheduler'] is None:
tmp.pop('lr_scheduler')
res.append(tmp)
return res

@staticmethod
def state_dict_replace_key(
@@ -441,5 +447,6 @@ def __init__(
"accuracy": {'callable': Fmetrics.classification.binary_accuracy,
'kwargs': {}}
}

# from torch.optim import Optimizer, lr_scheduler
# sch = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
super().__init__(model, loss, optimizer, metrics=metrics)
70 changes: 39 additions & 31 deletions malpolon/models/utils.py
@@ -145,7 +145,7 @@ def check_model(model: Union[nn.Module, Mapping]) -> nn.Module:


def check_scheduler(scheduler: Union[LRScheduler, dict],
optimizer: optim.Optimizer) -> LRScheduler:
optimizer: optim.Optimizer) -> dict:
"""Ensure input scheduler is a pytorch scheduler.
Input can either be an Omegaconf mapping (passed through a hydra config
@@ -162,34 +162,44 @@ def check_scheduler(scheduler: Union[LRScheduler, dict],
Returns
-------
LRScheduler
list of instantiated scheduler(s)
dict
dictionary of LR scheduler config
"""
if scheduler is None:
return

lr_sch_config = {'scheduler': None}

if isinstance(scheduler, LRScheduler):
return [scheduler]
lr_sch_config['scheduler'] = scheduler
return lr_sch_config

try:
k, v = next(iter(scheduler.items())) # Get 1st key & value of scheduler dict as there can only be 1 scheduler per optimizer
if 'lr_scheduler_config' in v and v['lr_scheduler_config'] is not None:
lr_sch_config = lr_sch_config | v['lr_scheduler_config']
if 'callable' in v:
scheduler[k]['callable'] = eval(v['callable'])
v['callable'] = eval(v['callable'])
else:
scheduler[k]['callable'] = SCHEDULER_CALLABLES[k]
scheduler = scheduler[k]['callable'](optimizer, **scheduler[k]['kwargs'])
v['callable'] = SCHEDULER_CALLABLES[k]
scheduler = v['callable'](optimizer, **v['kwargs'])
except ValueError as e:
print('\n[WARNING]: Please make sure you have registered'
print('\n[ERROR]: Please make sure you have registered'
' a dict-like value to your "scheduler" key in your'
' config file. Defaulting scheduler to None.\n')
' config file.\n')
print(e, '\n')
scheduler = None
raise e
except KeyError as e:
print('\n[WARNING]: Please make sure the name of your scheduler'
print('\n[ERROR]: Please make sure the name of your scheduler'
' registered in your config file match an entry'
' in constant SCHEDULER_CALLABLES.'
' Defaulting scheduler to None.\n')
' in constant SCHEDULER_CALLABLES; or that you have provided a'
' callable function if your scheduler\'s name is not pre-registered'
' in SCHEDULER_CALLABLES.\n')
print(e, '\n')
scheduler = None
raise e

return scheduler
lr_sch_config['scheduler'] = scheduler
return lr_sch_config


def check_optimizer(optimizer: Union[Optimizer, OmegaConf],
@@ -217,7 +227,7 @@ def check_optimizer(optimizer: Union[Optimizer, OmegaConf],
scheduler_list = []

if isinstance(optimizer, Optimizer):
return [optimizer], scheduler_list
return [optimizer], [None]

try:
if optimizer is not None:
@@ -229,25 +239,23 @@
else:
optimizer[k]['callable'] = OPTIMIZERS_CALLABLES[k]
optim_list.append(optimizer[k]['callable'](model.parameters(), **optimizer[k]['kwargs']))
if 'scheduler' in v and v['scheduler'] is not None:
scheduler_list.append(check_scheduler(v['scheduler'], optim_list[-1]))

except ValueError as e:
print('\n[WARNING]: Please make sure you have registered'
' a dict-like value to your "optimizer" key in your'
' config file. Defaulting optimizer to None.\n')
scheduler_list.append(check_scheduler(v.get('scheduler'), optim_list[-1]))
except (TypeError, ValueError) as e:
print('\n[ERROR]: Please make sure you have registered'
' a non-empty dict-like value to your "optimizer" key in your'
' config file. Your optimizer dict might be empty (NoneType).')
print(e, '\n')
optimizer = None
raise e
except KeyError as e:
print('\n[WARNING]: Please make sure the name of your optimizer'
print('\n[ERROR]: Please make sure the name of your optimizer'
' registered in your config file match an entry'
' in constant OPTIMIZERS_CALLABLES.'
' Defaulting optimizer to None.\n')
' in constant OPTIMIZERS_CALLABLES; or that you have provided a'
' callable function if your optimizer\'s name is not pre-registered'
' in OPTIMIZERS_CALLABLES.\n'
' Please make sure your optimizer\'s and scheduler\'s kwargs keys'
' are valid.\n')
print(e, '\n')
optimizer = None

if len(optim_list) > 1 and len(scheduler_list) >= 1:
assert len(optim_list) == len(scheduler_list), "When using multiple optimizers, there should be as many schedulers as there are optimizers, or none at all."
raise e

return optim_list, scheduler_list
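
A minimal usage sketch of the reworked helpers (not part of the diff). The exact mapping handed to check_optimizer() depends on how the calling code slices the Hydra config; a plain dict mirroring the YAML above is assumed here, and the toy model is illustrative:

import torch
from malpolon.models.utils import check_optimizer

model = torch.nn.Linear(8, 2)  # toy model, illustration only
opt_cfg = {
    'sgd': {
        'kwargs': {'lr': 1e-2, 'momentum': 0.9, 'nesterov': True},
        'scheduler': {
            'reduce_lr_on_plateau': {
                'kwargs': {'threshold': 0.001},
                'lr_scheduler_config': {'monitor': 'loss/val'},
            },
        },
    },
}
optimizers, scheduler_configs = check_optimizer(opt_cfg, model)
# optimizers        -> [SGD(...)]
# scheduler_configs -> [{'scheduler': ReduceLROnPlateau(...), 'monitor': 'loss/val'}]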

30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/1_opt_0_sch.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
weight_decay: 0
momentum: 0.9
nesterov: true
# scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
# lr_scheduler_config:
# scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
# monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
# adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# scheduler:
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/1_opt_1_sch.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
weight_decay: 0
momentum: 0.9
nesterov: true
scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
reduce_lr_on_plateau:
# callable: 'lr_scheduler.reduce_lr_on_plateau'
kwargs:
threshold: 0.001
lr_scheduler_config:
scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
# adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# scheduler:
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/2_opt_1_sch.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
weight_decay: 0
momentum: 0.9
nesterov: true
# scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
# lr_scheduler_config:
# scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
# monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
scheduler:
reduce_lr_on_plateau:
# callable: 'lr_scheduler.reduce_lr_on_plateau'
kwargs:
threshold: 0.001
30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/2_opt_2_sch.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
weight_decay: 0
momentum: 0.9
nesterov: true
scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
reduce_lr_on_plateau:
# callable: 'lr_scheduler.reduce_lr_on_plateau'
kwargs:
threshold: 0.001
lr_scheduler_config:
scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# callable: 'optim.SGD'
kwargs:
lr: 1e-2
scheduler:
reduce_lr_on_plateau:
# callable: 'lr_scheduler.reduce_lr_on_plateau'
kwargs:
threshold: 0.001
30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/error_1.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
# sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# weight_decay: 0
# momentum: 0.9
# nesterov: true
# scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
# lr_scheduler_config:
# scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
# monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
# adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# scheduler:
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
30 changes: 30 additions & 0 deletions malpolon/tests/data/test_models_utils_configs/error_2.yaml
@@ -0,0 +1,30 @@
hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
optimizer:
sgd: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# weight_decay: 0
# momentum: 0.9
# nesterov: true
# scheduler: # Optional, delete key or leave empty to not use any learning rate scheduler
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001
# lr_scheduler_config:
# scheduler: reduce_lr_on_plateau # Optional, the scheduler to use is the parent key
# monitor: loss/val # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
# adam: # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
# # callable: 'optim.SGD'
# kwargs:
# lr: 1e-2
# scheduler:
# reduce_lr_on_plateau:
# # callable: 'lr_scheduler.reduce_lr_on_plateau'
# kwargs:
# threshold: 0.001