Optimizer / scheduler rework [backward compatible].
- malpolon.models.utils: changed the behavior of check_optimizer() and added check_scheduler() so that users can specify one or several optimizers (and, optionally, one scheduler per optimizer, possibly with an lr_scheduler_config descriptor) via their config files.
- malpolon.models.standard_prediction_systems: changed how optimizer(s) and scheduler(s) are instantiated in class GenericPredictionSystem. The class attributes are now lists of instantiated optimizers (respectively, of lr_scheduler_config dictionaries). Updated the method configure_optimizers() to return a dictionary containing all the optimizers and schedulers (cf. https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers; see the sketch below).
- malpolon.tests.test_models_utils: added the corresponding unit tests, covering both valid scenarios and edge cases of incorrect user inputs in the config file.
- sentinel-2a-rgbnir_bioclim example: updated the config file to match the changes described above.
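For reference, here is a minimal sketch of the optimizer-plus-lr_scheduler_config dictionary format that Lightning's configure_optimizers() accepts (assuming Lightning 2.x import paths; this is illustrative, not malpolon's actual code):

```python
# Illustrative sketch only (not malpolon's code): the dict format accepted by
# LightningModule.configure_optimizers() for one optimizer and one scheduler.
import torch
import lightning.pytorch as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(), lr=1e-2, weight_decay=0, momentum=0.9, nesterov=True
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.001)
        return {
            "optimizer": optimizer,
            # "lr_scheduler" takes a scheduler or an lr_scheduler_config dict;
            # ReduceLROnPlateau needs a "monitor" key to know which metric to track.
            "lr_scheduler": {"scheduler": scheduler, "monitor": "loss/val"},
        }
```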
Showing 17 changed files with 618 additions and 43 deletions.
malpolon/tests/data/test_models_utils_configs/1_opt_0_sch.yaml (30 additions & 0 deletions)
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
        weight_decay: 0
        momentum: 0.9
        nesterov: true
      # scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
      #   reduce_lr_on_plateau:
      #     # callable: 'lr_scheduler.reduce_lr_on_plateau'
      #     kwargs:
      #       threshold: 0.001
      #     lr_scheduler_config:
      #       scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
      #       monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    # adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
    #   # callable: 'optim.SGD'
    #   kwargs:
    #     lr: 1e-2
    #   scheduler:
    #     reduce_lr_on_plateau:
    #       # callable: 'lr_scheduler.reduce_lr_on_plateau'
    #       kwargs:
    #         threshold: 0.001
```
malpolon/tests/data/test_models_utils_configs/1_opt_1_sch.yaml (30 additions & 0 deletions)
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
        weight_decay: 0
        momentum: 0.9
        nesterov: true
      scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
        reduce_lr_on_plateau:
          # callable: 'lr_scheduler.reduce_lr_on_plateau'
          kwargs:
            threshold: 0.001
          lr_scheduler_config:
            scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
            monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    # adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
    #   # callable: 'optim.SGD'
    #   kwargs:
    #     lr: 1e-2
    #   scheduler:
    #     reduce_lr_on_plateau:
    #       # callable: 'lr_scheduler.reduce_lr_on_plateau'
    #       kwargs:
    #         threshold: 0.001
```
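To make the mapping concrete, here is a hypothetical sketch of how a config like 1_opt_1_sch.yaml translates into PyTorch objects; malpolon's check_optimizer()/check_scheduler() do this for real, so the code below is purely illustrative:

```python
# Hypothetical illustration: turning the 1_opt_1_sch.yaml section above into
# an optimizer and an lr_scheduler_config dict. Not malpolon's actual code.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("1_opt_1_sch.yaml")
model = torch.nn.Linear(32, 2)

sgd_cfg = cfg.optimizer.optimizer.sgd
kwargs = OmegaConf.to_container(sgd_cfg.kwargs)
kwargs["lr"] = float(kwargs["lr"])  # YAML 1.1 parses '1e-2' as a string, not a float
optimizer = torch.optim.SGD(model.parameters(), **kwargs)

sch_cfg = sgd_cfg.scheduler.reduce_lr_on_plateau
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, **OmegaConf.to_container(sch_cfg.kwargs)
)

# The lr_scheduler_config block becomes Lightning's lr_scheduler dict; the
# instantiated scheduler replaces the 'reduce_lr_on_plateau' name string.
lr_scheduler_config = OmegaConf.to_container(sch_cfg.lr_scheduler_config)
lr_scheduler_config["scheduler"] = scheduler
```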
malpolon/tests/data/test_models_utils_configs/2_opt_1_sch.yaml (30 additions & 0 deletions)
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
        weight_decay: 0
        momentum: 0.9
        nesterov: true
      # scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
      #   reduce_lr_on_plateau:
      #     # callable: 'lr_scheduler.reduce_lr_on_plateau'
      #     kwargs:
      #       threshold: 0.001
      #     lr_scheduler_config:
      #       scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
      #       monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
      scheduler:
        reduce_lr_on_plateau:
          # callable: 'lr_scheduler.reduce_lr_on_plateau'
          kwargs:
            threshold: 0.001
```
malpolon/tests/data/test_models_utils_configs/2_opt_2_sch.yaml (30 additions & 0 deletions)
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
        weight_decay: 0
        momentum: 0.9
        nesterov: true
      scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
        reduce_lr_on_plateau:
          # callable: 'lr_scheduler.reduce_lr_on_plateau'
          kwargs:
            threshold: 0.001
          lr_scheduler_config:
            scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
            monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # callable: 'optim.SGD'
      kwargs:
        lr: 1e-2
      scheduler:
        reduce_lr_on_plateau:
          # callable: 'lr_scheduler.reduce_lr_on_plateau'
          kwargs:
            threshold: 0.001
```
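With two optimizers, each paired with its own scheduler as in 2_opt_2_sch.yaml, Lightning's "two lists" return format matches the commit's description of list-valued class attributes. A sketch under the same assumptions as above (illustrative, not malpolon's actual code):

```python
# Illustrative sketch: two optimizers, each paired with an lr_scheduler_config.
# This method belongs inside a LightningModule subclass (torch imported as above).
# NB: in Lightning 2.x, stepping multiple optimizers requires manual
# optimization (self.automatic_optimization = False).
def configure_optimizers(self):
    opt_sgd = torch.optim.SGD(self.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
    opt_adam = torch.optim.Adam(self.parameters(), lr=1e-2)
    sch_sgd = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_sgd, threshold=0.001)
    sch_adam = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_adam, threshold=0.001)
    # "Two lists" format: [optimizers], [lr_scheduler_configs]
    return [opt_sgd, opt_adam], [
        {"scheduler": sch_sgd, "monitor": "loss/val"},
        {"scheduler": sch_adam, "monitor": "loss/val"},
    ]
```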
malpolon/tests/data/test_models_utils_configs/error_1.yaml (30 additions & 0 deletions): invalid input case where the nested optimizer key is left empty (every entry commented out)
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    # sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
    #   # callable: 'optim.SGD'
    #   kwargs:
    #     lr: 1e-2
    #     weight_decay: 0
    #     momentum: 0.9
    #     nesterov: true
    #   scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
    #     reduce_lr_on_plateau:
    #       # callable: 'lr_scheduler.reduce_lr_on_plateau'
    #       kwargs:
    #         threshold: 0.001
    #       lr_scheduler_config:
    #         scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
    #         monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    # adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
    #   # callable: 'optim.SGD'
    #   kwargs:
    #     lr: 1e-2
    #   scheduler:
    #     reduce_lr_on_plateau:
    #       # callable: 'lr_scheduler.reduce_lr_on_plateau'
    #       kwargs:
    #         threshold: 0.001
```
malpolon/tests/data/test_models_utils_configs/error_2.yaml (30 additions & 0 deletions): invalid input case where an optimizer name (sgd) is given without any body
```yaml
hydra:
  run:
    dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

optimizer:
  optimizer:
    sgd:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
      # # callable: 'optim.SGD'
      # kwargs:
      #   lr: 1e-2
      #   weight_decay: 0
      #   momentum: 0.9
      #   nesterov: true
      # scheduler:  # Optional, delete key or leave empty to not use any learning rate scheduler
      #   reduce_lr_on_plateau:
      #     # callable: 'lr_scheduler.reduce_lr_on_plateau'
      #     kwargs:
      #       threshold: 0.001
      #     lr_scheduler_config:
      #       scheduler: reduce_lr_on_plateau  # Optional, the scheduler to use is the parent key
      #       monitor: loss/val  # ['loss/train', 'loss/val', '<metric>/train', '<metric>/val', ...]
    # adam:  # ['adam', 'sgd', 'adamw', 'adadelta', 'adagrad', 'adamax', 'rmsprop']
    #   # callable: 'optim.SGD'
    #   kwargs:
    #     lr: 1e-2
    #   scheduler:
    #     reduce_lr_on_plateau:
    #       # callable: 'lr_scheduler.reduce_lr_on_plateau'
    #       kwargs:
    #         threshold: 0.001
```
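The two error configs feed the negative cases of the new unit tests. A hypothetical pytest sketch of such a case follows; the check_optimizer() signature and the exception type are assumptions for illustration, not malpolon's documented API:

```python
# Hypothetical negative-test sketch; the exact function signature and exception
# type used in malpolon.tests.test_models_utils may differ.
import pytest
from omegaconf import OmegaConf
from malpolon.models.utils import check_optimizer


@pytest.mark.parametrize("path", ["error_1.yaml", "error_2.yaml"])
def test_invalid_optimizer_config(path):
    cfg = OmegaConf.load(f"malpolon/tests/data/test_models_utils_configs/{path}")
    with pytest.raises(ValueError):  # assumed exception type
        check_optimizer(cfg.optimizer)  # assumed signature
```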