diff --git a/examples/custom_train/sentinel-2a-rgbnir/cnn_on_rgbnir_torchgeo.py b/examples/custom_train/sentinel-2a-rgbnir/cnn_on_rgbnir_torchgeo.py index a991888c..e5be483c 100644 --- a/examples/custom_train/sentinel-2a-rgbnir/cnn_on_rgbnir_torchgeo.py +++ b/examples/custom_train/sentinel-2a-rgbnir/cnn_on_rgbnir_torchgeo.py @@ -83,7 +83,7 @@ def main(cfg: DictConfig) -> None: test_data_point = test_data_point.resize_(1, *test_data_point.shape) prediction = model_loaded.predict_point(cfg.run.checkpoint_path, - test_data_point]) + test_data_point) preds, probas = datamodule.predict_logits_to_class(prediction, datamodule.get_test_dataset().unique_labels) datamodule.export_predict_csv(preds, probas, diff --git a/malpolon/models/standard_prediction_systems.py b/malpolon/models/standard_prediction_systems.py index 25932893..23373722 100644 --- a/malpolon/models/standard_prediction_systems.py +++ b/malpolon/models/standard_prediction_systems.py @@ -339,7 +339,7 @@ def predict_point( self.load_state_dict(ckpt['state_dict']) self.model.eval() with torch.no_grad(): - if '__iter__' in dir(data): + if isinstance(data, (tuple, list, set, dict)): for i, d in enumerate(data): data[i] = d.to(device) if isinstance(d, torch.Tensor) else d prediction = self.model(*data) diff --git a/malpolon/tests/test_examples.py b/malpolon/tests/test_examples.py index 71cda14c..813a0f56 100644 --- a/malpolon/tests/test_examples.py +++ b/malpolon/tests/test_examples.py @@ -258,7 +258,7 @@ ], } -# @pytest.mark.skip(reason="Slow or no guarantee of having the data available.") +@pytest.mark.skip(reason="Slow or no guarantee of having the data available.") def test_train_inference_examples(): ckpt_path = '' for expe_name, v in EXAMPLE_PATHS.items(): @@ -304,7 +304,7 @@ def test_train_inference_examples(): print(f'\n{INFO}[INFO] Done. {RESET}') -@pytest.mark.skip(reason="Slow or no guarantee of having the data available.") +# @pytest.mark.skip(reason="Slow or no guarantee of having the data available.") def test_GLC22_examples(): ckpt_path = '' for expe_name, v in GLC22_EXAMPLE_PATHS.items(): @@ -353,7 +353,7 @@ def test_GLC22_examples(): print(f'\n{INFO}[INFO] Done. {RESET}') -# @pytest.mark.skip(reason="Slow or no guarantee of having the data available.") +@pytest.mark.skip(reason="Slow or no guarantee of having the data available.") def test_GLC23_examples(): ckpt_path = '' for expe_name, v in GLC23_EXAMPLE_PATHS.items(): @@ -402,7 +402,7 @@ def test_GLC23_examples(): print(f'\n{INFO}[INFO] Done. {RESET}') -@pytest.mark.skip(reason="Impossible for pytest to run because user input is needed to validate data download.") +# @pytest.mark.skip(reason="Impossible for pytest to run because user input is needed to validate data download.") def test_GLC24_pre_extracted_examples(): ckpt_path = '' for expe_name, v in GLC24_PRE_EXTRACTED_EXAMPLE_PATHS.items(): diff --git a/requirements_python3.10.txt b/requirements_python3.10.txt index fd875ba5..5c8e07d1 100644 --- a/requirements_python3.10.txt +++ b/requirements_python3.10.txt @@ -9,7 +9,7 @@ ipykernel==6.18.2 ipython==8.17.2 jupyter==1.0.0 jupyter-server==2.11.2 -jupyterlab==3.6.7 +jupyterlab==3.6.8 kaggle==1.5.16 lightning==2.0.9.post0 Markdown==3.4.1 @@ -19,7 +19,7 @@ numpy==1.26.4 odc-geo==0.3.2 odc-stac==0.3.3 omegaconf==2.3.0 -opencv-python==4.7.0.72 +opencv-python==4.8.1.78 pandas==2.2.1 Pillow==10.3.0 planetary-computer==1.0.0 @@ -30,7 +30,7 @@ pyproj==3.6.1 pystac==1.6.1 pystac-client==0.5.1 pytest==7.2.2 -pytorch-lightning==2.1.0 +pytorch-lightning==2.3.3 PyYAML==6.0.1 rasterio==1.3.8.post1 scikit-learn==1.5.0 @@ -42,10 +42,10 @@ tensorboard==2.14.1 protobuf==4.25 tifffile==2022.10.10 timm==0.9.2 -torch==2.1.0 +torch==2.2.0 torchgeo==0.5.0 torchmetrics==1.2.0 -torchvision==0.16.0 +torchvision==0.17.0 tqdm==4.66.3 verde==1.8.0 yarl==1.9.2