Skip to content

Commit

Permalink
Added glc24 pre_extracted habitat example in inference version
Browse files: browse the repository at this point in the history
  • Loading branch information
tlarcher committed Sep 6, 2024
1 parent a62731d commit 9fcb92f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ hydra:

run:
predict_type: 'test_dataset' # choose from ['test_dataset', 'test_point']
checkpoint_path: "outputs_training/glc24_cnn_multimodal_ensemble/runOK_2024-08-12_11-50-01/last.ckpt"
checkpoint_path: ???

data:
root: "dataset/geolifeclef-2024/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ hydra:
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}

run:
predict: false
checkpoint_path: # "outputs/glc24_cnn_multimodal_ensemble_habitat/2024-08-30_20-21-26_multiclass/last.ckpt"
predict_type: false
checkpoint_path: ???

data:
root: "dataset/geolifeclef-2024_habitats/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,39 +65,51 @@ def main(cfg: DictConfig):
checkpoint_path=cfg.run.checkpoint_path,
weights_dir=log_dir,
num_classes=cfg.data.num_classes) # multiclass
model_loaded = ClassificationSystemGLC24.load_from_checkpoint(classif_system.checkpoint_path,
model=classif_system.model,
hparams_preprocess=False,
strict=False)

# Lightning Trainer
callbacks = [
Summary(),
ModelCheckpoint(
dirpath=log_dir,
filename="checkpoint-{epoch:02d}-{step}-{loss/val:.4f}",
monitor="loss/val",
filename="checkpoint-{epoch:02d}-{step}-{" + f"loss/val" + ":.4f}",
monitor=f"loss/val",
mode="min",
save_on_train_epoch_end=True,
save_last=True,
every_n_train_steps=75,
),
]
trainer = pl.Trainer(logger=[logger_csv, logger_tb], callbacks=callbacks, **cfg.trainer, deterministic=True)

# Run
if cfg.run.predict:
model_loaded = ClassificationSystemGLC24.load_from_checkpoint(classif_system.checkpoint_path,
model=classif_system.model,
hparams_preprocess=False,
strict=False)

if cfg.run.predict_type == 'test_dataset':
# Option 1: Predict on the entire test dataset (Pytorch Lightning)
predictions = model_loaded.predict(datamodule, trainer)
preds, probas = datamodule.predict_logits_to_class(predictions,
np.arange(cfg.data.num_classes))
np.arange(cfg.data.num_classes),
activation_fn=torch.nn.Sigmoid())
datamodule.export_predict_csv(preds, probas,
out_dir=log_dir, out_name='predictions_test_dataset', top_k=None, return_csv=True)
out_dir=log_dir, out_name='predictions_test_dataset', top_k=25, return_csv=True)
print('Test dataset prediction (extract) : ', predictions[:1])

else:
trainer.fit(classif_system, datamodule=datamodule, ckpt_path=classif_system.checkpoint_path)
trainer.validate(classif_system, datamodule=datamodule)
elif cfg.run.predict_type == 'test_point':
# Option 2: Predict 1 data point (Pytorch)
test_data = datamodule.get_test_dataset()
test_data_point = list(test_data[0][:3])
for i, d in enumerate(test_data_point):
test_data_point[i] = d.unsqueeze(0)
query_point = {'observation_id': [test_data[0][-1]],
'lon': None, 'lat': None,
'crs': None,
'species_id': [test_data[0][-1]]}
prediction = model_loaded.predict_point(cfg.run.checkpoint_path,
test_data_point)
preds, probas = datamodule.predict_logits_to_class(prediction,
np.arange(cfg.data.num_classes))
datamodule.export_predict_csv(preds, probas,
out_dir=log_dir, out_name='prediction_point', single_point_query=query_point, return_csv=True)
print('Point prediction : ', prediction.shape, prediction)


if __name__ == "__main__":
Expand Down

0 comments on commit 9fcb92f

Please sign in to comment.