Skip to content

Commit

Permalink
Model weights download (#56)
Browse files Browse the repository at this point in the history
* Changed custom models location in new module 'malpolon.models.custom_models'. This includes glc24 pre_extracted MME model and multi_modal.py. For MME: classificationsystem and nn module have been split in 2 files to allow calling MME from model_builder without triggering a circular import through check_model. Updated examples consequently.

* Fix: state_dict altered during training.
- state_dict contains a loss parameter pos_weight as key loss.pos_weight. This key is created when the loss is instantiated by GenericPredictionSystem. However, this loss parameter was accessed and modified during the _step() process, which also alters the state_dict. Consequently, when loading the model by its checkpoint, there would be a value mismatch and the model would not load to resume training. This has been fixed by restoring the initial value of the loss parameter within the _step() function before the return statement.
- 'positive_weigh_factor' model hyperparameter has been deleted and replaced by loss parameter 'pos_weight', which achieves the same purpose. In the config file, 'positive_weigh_factor' model key has been substituted for subkey 'pos_weight' nested under 'loss_kwargs' nested in the optimizer section

* Cleaned remainings of previous commit testing

* Added download weight option for all classification system and updated checkopoint_path call for MME example

* Fixed wrong checkpoint_path path initialization behavior.
- glc24_cnn_multimodal_ensemble: updated example config file and main script to new checkpoint_path behavior, in both training and inference runs
- standard_prediction_systems.py: Fixed wrong checkpoint_path path initialization behavior
- glc2024_pre_extracted_prediction_system.py: added missing checkpoint_path argument and removed checkpoint_path setter as it is carried out by GenericPredictionSystem

* Updated example cnn_on_rgbnir_torchgeo following checkpoint_path update

* Updated example cnn_on_rgbnir_concat following checkpoint_path update

* Updated example cnn_on_rgbnir_glc23_patches following checkpoint_path update

* Reset yaml file glc23 example

* Fixed wrong variable assignment in exmaples micro_geolifeclef2022/cnn_on_rgb_nir_patches  and micro_geolifeclef2022/cnn_on_rgb_patches

* Added predict run part in example geolifeclef2022/cnn_on_rgb_patches and updated main script following checkpoint_path update.
- data_module: Added more flexibility for predictions without targets
- geolifeclef2022 dataset: Added default -1 value for targets in predict mode to comply with standard_prediction_system predict() method

* Updated glc22 and microglc22 examples following checkpoint_path update, and added inference part in the run section for those which didn't have one. Added input argument in custom GLC22 datamodules + model output in prediction mode, to such extent.

* Updated CIFAR-10 example following checkpoint_path update

* Updated all inference examples following checkpoint_path update

* Removed duplicate import

* Updated code docstrings

* Fixed task value from binary to multilabel (doesn't change behavior)

* Added 'malpolon' as model providers.
- model_builder: Added provider method and created new dictionary with model names as keys, and local imports of models as values

- data_module: Added posisblity of applying no activation function when running inference, so as to output the model's logits. Enhanced CSV export method's info prints.

- glc2024_multimodal_ensemble_model: Added new init argument and class attribute 'pretrained' which the datmaodule uses to determine whether to download pretrained weights (formerly: a standalone 'weights_download' variable was used by the datamodule). Added docstrings.

- glc2024_pre_extracted_prediction_system: Changed handling behavior of the model's loss during '_step()' to prevent overwritting the loss parameter during training which resulted in a de-synchronization of the state_dcit() before and after running the model (since loss parameters are automatically added as learnable parameters)
- glc24_cnn_multimodal_ensemble.yaml: Updated config file accordingly. Cleaned config file with correct values.
- glc24_cnn_multimodal_ensemble.py: Updated MME main srcipt accordingly. Changed activation function of inference run from softmax() to sigmoid()

* Updated glc22 tests following class getter changes

* Removed commented dict
  • Loading branch information
tlarcher authored Aug 13, 2024
1 parent 57edabc commit 44e9d68
Show file tree
Hide file tree
Showing 38 changed files with 642 additions and 322 deletions.
24 changes: 17 additions & 7 deletions examples/benchmarks/cifar-10/cnn_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@ class Cifar10Datamodule(BaseDataModule):
num_workers : int
Number of workers to use for loading data.
"""

def __init__(self, dataset_path: str, batch_size: int, num_workers: int, **kwargs):
def __init__(self,
dataset_path: str,
train_batch_size: int,
inference_batch_size: int,
num_workers: int,
**kwargs):
super().__init__()
self.dataset_path = dataset_path
self.batch_size = batch_size
self.train_batch_size = train_batch_size
self.inference_batch_size = inference_batch_size
self.num_workers = num_workers
self.__dict__.update(kwargs)
self.cifar10_train = None
Expand Down Expand Up @@ -104,15 +109,19 @@ def main(cfg: DictConfig) -> None:
hydra config dictionary created from the .yaml config file
associated with this script.
"""
# Loggers
log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
logger_csv = pl.loggers.CSVLogger(log_dir, name="", version="")
logger_csv.log_hyperparams(cfg)
logger_tb = pl.loggers.TensorBoardLogger(log_dir, name="tensorboard_logs", version="")
logger_tb.log_hyperparams(cfg)

# Datamodule & Model
datamodule = Cifar10Datamodule(**cfg.data, **cfg.task)
model = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task)
classif_system = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task,
checkpoint_path=cfg.run.checkpoint_path)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
Expand All @@ -126,9 +135,10 @@ def main(cfg: DictConfig) -> None:
]
trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer)

# Run
if cfg.run.predict:
model_loaded = ClassificationSystem.load_from_checkpoint(cfg.run.checkpoint_path,
model=model.model,
model=classif_system.model,
hparams_preprocess=False)

# Option 1: Predict on the entire test dataset (Pytorch Lightning)
Expand All @@ -149,8 +159,8 @@ def main(cfg: DictConfig) -> None:
df.to_csv(os.path.join(log_dir, 'scores_test_dataset.csv'), index=False)
print('Test dataset prediction (extract) : ', predictions[:1])
else:
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.run.checkpoint_path)
trainer.validate(model, datamodule=datamodule)
trainer.fit(classif_system, datamodule=datamodule, ckpt_path=cfg.run.checkpoint_path)
trainer.validate(classif_system, datamodule=datamodule)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/benchmarks/cifar-10/config/cnn_cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ hydra:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

run:
predict: false
checkpoint_path: # "outputs/cnn_cifar10/30_epochs_no_pretrained/last.ckpt"
predict: true
checkpoint_path: "outputs/cnn_cifar10/30_epochs_pretrained/last.ckpt"

data:
num_classes: &num_classes 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -43,11 +44,13 @@ def __init__(
inference_batch_size: int = 256,
num_workers: int = 8,
download: bool = False,
task: str = 'classification_multiclass',
):
super().__init__(train_batch_size, inference_batch_size, num_workers)
self.dataset_path = dataset_path
self.minigeolifeclef = minigeolifeclef
self.download = download
self.task = task

@property
def train_transform(self):
Expand Down Expand Up @@ -96,17 +99,19 @@ def get_dataset(self, split, transform, **kwargs):

@hydra.main(version_base="1.3", config_path="config", config_name="mono_modal_3_channels_model")
def main(cfg: DictConfig) -> None:

# Loggers
log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
logger_csv = pl.loggers.CSVLogger(log_dir, name="", version="")
logger_csv.log_hyperparams(cfg)
logger_tb = pl.loggers.TensorBoardLogger(Path(log_dir)/Path(cfg.loggers.log_dir_name), name=cfg.loggers.exp_name, version="")
logger_tb.log_hyperparams(cfg)

datamodule = GeoLifeCLEF2022DataModule(**cfg.data)

model = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task)
# Datamodule & Model
datamodule = GeoLifeCLEF2022DataModule(**cfg.data, **cfg.task)
classif_system = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task,
checkpoint_path=cfg.run.checkpoint_path)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
Expand All @@ -119,9 +124,23 @@ def main(cfg: DictConfig) -> None:
),
]
trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer)
trainer.fit(model, datamodule=datamodule)

trainer.validate(model, datamodule=datamodule)
# Run
if cfg.run.predict:
model_loaded = ClassificationSystem.load_from_checkpoint(classif_system.checkpoint_path,
model=classif_system.model,
hparams_preprocess=False)

# Option 1: Predict on the entire test dataset (Pytorch Lightning)
predictions = model_loaded.predict(datamodule, trainer)
preds, probas = datamodule.predict_logits_to_class(predictions,
np.arange(datamodule.get_test_dataset().n_classes))
datamodule.export_predict_csv(preds, probas,
out_dir=log_dir, out_name='predictions_test_dataset', top_k=3, return_csv=True)
print('Test dataset prediction (extract) : ', predictions[:1])
else:
trainer.fit(classif_system, datamodule=datamodule, ckpt_path=classif_system.checkpoint_path)
trainer.validate(classif_system, datamodule=datamodule)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
Expand All @@ -22,7 +23,8 @@
MiniGeoLifeCLEF2022Dataset)
from malpolon.data.environmental_raster import PatchExtractor
from malpolon.logging import Summary
from malpolon.models.multi_modal import HomogeneousMultiModalModel
from malpolon.models.custom_models.multi_modal import \
HomogeneousMultiModalModel
from malpolon.models.standard_prediction_systems import ClassificationSystem


Expand All @@ -45,10 +47,12 @@ def __init__(
train_batch_size: int = 32,
inference_batch_size: int = 256,
num_workers: int = 8,
task: str = 'classification_multiclass',
):
super().__init__(train_batch_size, inference_batch_size, num_workers)
self.dataset_path = dataset_path
self.minigeolifeclef = minigeolifeclef
self.task = task

@property
def train_transform(self):
Expand Down Expand Up @@ -117,29 +121,32 @@ def __init__(
num_outputs: int,
cfg_optimizer: DictConfig,
cfg_task: DictConfig,
checkpoint_path: str = None,
):
model = HomogeneousMultiModalModel(
["rgb", "temperature"],
modalities_model,
torch.nn.LazyLinear(num_outputs),
)

super().__init__(model, **cfg_optimizer, **cfg_task)
super().__init__(model, **cfg_optimizer, **cfg_task, checkpoint_path=checkpoint_path)


@hydra.main(version_base="1.3", config_path="config", config_name="homogeneous_multi_modal_model")
def main(cfg: DictConfig) -> None:

# Loggers
log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
logger_csv = pl.loggers.CSVLogger(log_dir, name="", version="")
logger_csv.log_hyperparams(cfg)
logger_tb = pl.loggers.TensorBoardLogger(Path(log_dir)/Path(cfg.loggers.log_dir_name), name=cfg.loggers.exp_name, version="")
logger_tb.log_hyperparams(cfg)

datamodule = GeoLifeCLEF2022DataModule(**cfg.data)

model = CustomClassificationSystem(**cfg.model, cfg_optimizer=cfg.optimizer, cfg_task=cfg.task)
# Datamodule & Model
datamodule = GeoLifeCLEF2022DataModule(**cfg.data, **cfg.task)
classif_system = CustomClassificationSystem(**cfg.model, cfg_optimizer=cfg.optimizer, cfg_task=cfg.task,
checkpoint_path=cfg.run.checkpoint_path)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
Expand All @@ -152,9 +159,22 @@ def main(cfg: DictConfig) -> None:
),
]
trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer)
trainer.fit(model, datamodule=datamodule)

trainer.validate(model, datamodule=datamodule)
if cfg.run.predict:
model_loaded = CustomClassificationSystem.load_from_checkpoint(classif_system.checkpoint_path,
model=classif_system.model,
hparams_preprocess=False)

# Option 1: Predict on the entire test dataset (Pytorch Lightning)
predictions = model_loaded.predict(datamodule, trainer)
preds, probas = datamodule.predict_logits_to_class(predictions,
np.arange(datamodule.get_test_dataset().n_classes))
datamodule.export_predict_csv(preds, probas,
out_dir=log_dir, out_name='predictions_test_dataset', top_k=3, return_csv=True)
print('Test dataset prediction (extract) : ', predictions[:1])
else:
trainer.fit(classif_system, datamodule=datamodule)
trainer.validate(classif_system, datamodule=datamodule)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as pl
from cnn_on_rgb_patches import ClassificationSystem
from omegaconf import DictConfig
Expand Down Expand Up @@ -43,10 +44,12 @@ def __init__(
train_batch_size: int = 32,
inference_batch_size: int = 256,
num_workers: int = 8,
task: str = 'classification_multiclass',
):
super().__init__(train_batch_size, inference_batch_size, num_workers)
self.dataset_path = dataset_path
self.minigeolifeclef = minigeolifeclef
self.task = task

@property
def train_transform(self):
Expand Down Expand Up @@ -100,17 +103,19 @@ def get_dataset(self, split, transform, **kwargs):

@hydra.main(version_base="1.3", config_path="config", config_name="mono_modal_3_channels_model")
def main(cfg: DictConfig) -> None:

# Loggers
log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
logger_csv = pl.loggers.CSVLogger(log_dir, name="", version="")
logger_csv.log_hyperparams(cfg)
logger_tb = pl.loggers.TensorBoardLogger(Path(log_dir)/Path(cfg.loggers.log_dir_name), name=cfg.loggers.exp_name, version="")
logger_tb.log_hyperparams(cfg)

datamodule = GeoLifeCLEF2022DataModule(**cfg.data)

model = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task)
# Datamodule & Model
datamodule = GeoLifeCLEF2022DataModule(**cfg.data, **cfg.task)
classif_system = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task,
checkpoint_path=cfg.run.checkpoint_path)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
Expand All @@ -123,9 +128,23 @@ def main(cfg: DictConfig) -> None:
),
]
trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer)
trainer.fit(model, datamodule=datamodule)

trainer.validate(model, datamodule=datamodule)
# Run
if cfg.run.predict:
model_loaded = ClassificationSystem.load_from_checkpoint(classif_system.checkpoint_path,
model=classif_system.model,
hparams_preprocess=False)

# Option 1: Predict on the entire test dataset (Pytorch Lightning)
predictions = model_loaded.predict(datamodule, trainer)
preds, probas = datamodule.predict_logits_to_class(predictions,
np.arange(datamodule.get_test_dataset().n_classes))
datamodule.export_predict_csv(preds, probas,
out_dir=log_dir, out_name='predictions_test_dataset', top_k=3, return_csv=True)
print('Test dataset prediction (extract) : ', predictions[:1])
else:
trainer.fit(classif_system, datamodule=datamodule)
trainer.validate(classif_system, datamodule=datamodule)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

run:
predict: false
checkpoint_path:

trainer:
# gpus: 1 # Deprecated since pytorchlightning 1.7, removed in 2.0. Replaced by the 2 next attributes
accelerator: 'gpu'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ hydra:
run:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

run:
predict: false
checkpoint_path:

trainer:
# gpus: 1 # Deprecated since pytorchlightning 1.7, removed in 2.0. Replaced by the 2 next attributes
accelerator: 'gpu'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,31 +166,36 @@ def main(cfg: DictConfig) -> None:
hydra config dictionary created from the .yaml config file
associated with this script.
"""
# Loggers
log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
logger_csv = pl.loggers.CSVLogger(log_dir, name="", version="")
logger_csv.log_hyperparams(cfg)
logger_tb = pl.loggers.TensorBoardLogger(log_dir, name="tensorboard_logs", version="")
logger_tb.log_hyperparams(cfg)

# Datamodule & Model
datamodule = Sentinel2PatchesDataModule(**cfg.data, **cfg.task)
model = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task)
classif_system = ClassificationSystem(cfg.model, **cfg.optimizer, **cfg.task,
checkpoint_path=cfg.run.checkpoint_path)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
dirpath=log_dir,
filename="checkpoint-{epoch:02d}-{step}-{" + f"{next(iter(model.metrics.keys()))}/val" + ":.4f}",
monitor=f"{next(iter(model.metrics.keys()))}/val",
filename="checkpoint-{epoch:02d}-{step}-{" + f"{next(iter(classif_system.metrics.keys()))}/val" + ":.4f}",
monitor=f"{next(iter(classif_system.metrics.keys()))}/val",
mode="max",
save_on_train_epoch_end=True,
save_last=True,
),
]

trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, num_sanity_val_steps=0, **cfg.trainer)

# Run
if cfg.run.predict:
model_loaded = ClassificationSystem.load_from_checkpoint(cfg.run.checkpoint_path,
model=model.model,
model=classif_system.model,
hparams_preprocess=False)

# Option 1: Predict on the entire test dataset (Pytorch Lightning)
Expand Down Expand Up @@ -225,8 +230,8 @@ def main(cfg: DictConfig) -> None:
print('Point prediction : ', prediction.shape, prediction)
else:
CrashHandler(trainer)
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.run.checkpoint_path)
trainer.validate(model, datamodule=datamodule)
trainer.fit(classif_system, datamodule=datamodule, ckpt_path=cfg.run.checkpoint_path)
trainer.validate(classif_system, datamodule=datamodule)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 44e9d68

Please sign in to comment.