From 9fcb92fe60bb70b98985d53320484254a4ff375b Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 6 Sep 2024 18:43:31 +0200 Subject: [PATCH] Added glc24 pre_extracted habitat example in inference version --- .../config/glc24_cnn_multimodal_ensemble.yaml | 2 +- ...glc24_cnn_multimodal_ensemble_habitat.yaml | 4 +- .../glc24_cnn_multimodal_ensemble_habitat.py | 44 ++++++++++++------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble.yaml b/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble.yaml index 81beda82..29894766 100644 --- a/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble.yaml +++ b/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble.yaml @@ -4,7 +4,7 @@ hydra: run: predict_type: 'test_dataset' # choose from ['test_dataset', 'test_point'] - checkpoint_path: "outputs_training/glc24_cnn_multimodal_ensemble/runOK_2024-08-12_11-50-01/last.ckpt" + checkpoint_path: ??? data: root: "dataset/geolifeclef-2024/" diff --git a/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble_habitat.yaml b/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble_habitat.yaml index 6a9965c2..8ebabbc9 100644 --- a/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble_habitat.yaml +++ b/examples/inference/geolifeclef2024_pre_extracted/config/glc24_cnn_multimodal_ensemble_habitat.yaml @@ -3,8 +3,8 @@ hydra: dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S} run: - predict: false - checkpoint_path: # "outputs/glc24_cnn_multimodal_ensemble_habitat/2024-08-30_20-21-26_multiclass/last.ckpt" + predict_type: false + checkpoint_path: ??? data: root: "dataset/geolifeclef-2024_habitats/" diff --git a/examples/inference/geolifeclef2024_pre_extracted/glc24_cnn_multimodal_ensemble_habitat.py b/examples/inference/geolifeclef2024_pre_extracted/glc24_cnn_multimodal_ensemble_habitat.py index 5fa914aa..e71f11bb 100644 --- a/examples/inference/geolifeclef2024_pre_extracted/glc24_cnn_multimodal_ensemble_habitat.py +++ b/examples/inference/geolifeclef2024_pre_extracted/glc24_cnn_multimodal_ensemble_habitat.py @@ -65,39 +65,51 @@ def main(cfg: DictConfig): checkpoint_path=cfg.run.checkpoint_path, weights_dir=log_dir, num_classes=cfg.data.num_classes) # multiclass + model_loaded = ClassificationSystemGLC24.load_from_checkpoint(classif_system.checkpoint_path, + model=classif_system.model, + hparams_preprocess=False, + strict=False) # Lightning Trainer callbacks = [ Summary(), ModelCheckpoint( dirpath=log_dir, - filename="checkpoint-{epoch:02d}-{step}-{loss/val:.4f}", - monitor="loss/val", + filename="checkpoint-{epoch:02d}-{step}-{" + f"loss/val" + ":.4f}", + monitor=f"loss/val", mode="min", - save_on_train_epoch_end=True, - save_last=True, - every_n_train_steps=75, ), ] trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer, deterministic=True) # Run - if cfg.run.predict: - model_loaded = ClassificationSystemGLC24.load_from_checkpoint(classif_system.checkpoint_path, - model=classif_system.model, - hparams_preprocess=False, - strict=False) - + if cfg.run.predict_type == 'test_dataset': + # Option 1: Predict on the entire test dataset (Pytorch Lightning) predictions = model_loaded.predict(datamodule, trainer) preds, probas = datamodule.predict_logits_to_class(predictions, - np.arange(cfg.data.num_classes)) + np.arange(cfg.data.num_classes), + activation_fn=torch.nn.Sigmoid()) datamodule.export_predict_csv(preds, probas, - out_dir=log_dir, out_name='predictions_test_dataset', top_k=None, return_csv=True) + out_dir=log_dir, out_name='predictions_test_dataset', top_k=25, return_csv=True) print('Test dataset prediction (extract) : ', predictions[:1]) - else: - trainer.fit(classif_system, datamodule=datamodule, ckpt_path=classif_system.checkpoint_path) - trainer.validate(classif_system, datamodule=datamodule) + elif cfg.run.predict_type == 'test_point': + # Option 2: Predict 1 data point (Pytorch) + test_data = datamodule.get_test_dataset() + test_data_point = list(test_data[0][:3]) + for i, d in enumerate(test_data_point): + test_data_point[i] = d.unsqueeze(0) + query_point = {'observation_id': [test_data[0][-1]], + 'lon': None, 'lat': None, + 'crs': None, + 'species_id': [test_data[0][-1]]} + prediction = model_loaded.predict_point(cfg.run.checkpoint_path, + test_data_point) + preds, probas = datamodule.predict_logits_to_class(prediction, + np.arange(cfg.data.num_classes)) + datamodule.export_predict_csv(preds, probas, + out_dir=log_dir, out_name='prediction_point', single_point_query=query_point, return_csv=True) + print('Point prediction : ', prediction.shape, prediction) if __name__ == "__main__":