diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index fdcef5450d1..3569bbfef8c 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -26,6 +26,11 @@ L8 Biome .. autoclass:: L8BiomeDataModule +MMFlood +^^^^^^^^ + +.. autoclass:: MMFloodDataModule + NAIP ^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index c60b08f6666..6c5c57ff176 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -142,6 +142,10 @@ Landsat .. autoclass:: Landsat2 .. autoclass:: Landsat1 +MMFlood +^^^^^^^ +.. autoclass:: MMFlood + NAIP ^^^^ diff --git a/docs/api/datasets/geo_datasets.csv b/docs/api/datasets/geo_datasets.csv index 4bb5788609e..2bfafc39c39 100644 --- a/docs/api/datasets/geo_datasets.csv +++ b/docs/api/datasets/geo_datasets.csv @@ -20,6 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30" `LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5 `Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30 +`MMFlood`_,"Imagery,DEM,Masks","Sentinel, MapZen/TileZen, OpenStreetMap",MIT,"2,147x2,313",20 `NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2 `NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10 `NLCD`_,Masks,Landsat,"public domain",-,30 diff --git a/tests/conf/mmflood.yaml b/tests/conf/mmflood.yaml new file mode 100644 index 00000000000..1a6301080a7 --- /dev/null +++ b/tests/conf/mmflood.yaml @@ -0,0 +1,19 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 4 + num_classes: 2 + num_filters: 1 + ignore_index: 255 +data: + class_path: MMFloodDataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/mmflood' + patch_size: 8 + include_dem: True + include_hydro: True diff --git a/tests/data/mmflood/activations.json b/tests/data/mmflood/activations.json new file mode 100644 index 00000000000..f9f2e3e901a --- /dev/null +++ b/tests/data/mmflood/activations.json @@ -0,0 +1 @@ +{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}} \ No newline at end of file diff --git a/tests/data/mmflood/activations.tar.000.gz.part b/tests/data/mmflood/activations.tar.000.gz.part new file mode 100644 index 00000000000..af2e6bf26ee Binary files /dev/null and b/tests/data/mmflood/activations.tar.000.gz.part differ diff --git a/tests/data/mmflood/activations.tar.001.gz.part b/tests/data/mmflood/activations.tar.001.gz.part new file mode 100644 index 00000000000..c9d27cc9ef1 Binary files /dev/null and b/tests/data/mmflood/activations.tar.001.gz.part differ diff --git a/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-0.tif b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-0.tif new file mode 100644 index 00000000000..457a4772988 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-1.tif b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-1.tif new file mode 100644 index 00000000000..d75cdfbd4ec Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-2.tif b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-2.tif new file mode 100644 index 00000000000..0a3e7fde60a Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/DEM/EMSR000-2.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-0.tif b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-0.tif new file mode 100644 index 00000000000..26281f81091 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-1.tif b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-1.tif new file mode 100644 index 00000000000..ac5f9f9cf7a Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-2.tif b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-2.tif new file mode 100644 index 00000000000..cc78438487f Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/mask/EMSR000-2.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-0.tif b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-0.tif new file mode 100644 index 00000000000..0a587035d2d Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-1.tif b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-1.tif new file mode 100644 index 00000000000..ec6dbed7b70 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-2.tif b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-2.tif new file mode 100644 index 00000000000..6bf55230734 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR000-0/s1_raw/EMSR000-2.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-0.tif b/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-0.tif new file mode 100644 index 00000000000..5e3e246d191 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-1.tif b/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-1.tif new file mode 100644 index 00000000000..5f4d785cf25 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/DEM/EMSR001-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-0.tif b/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-0.tif new file mode 100644 index 00000000000..134e9479d74 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-1.tif b/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-1.tif new file mode 100644 index 00000000000..a0c043f0d31 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/hydro/EMSR001-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-0.tif b/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-0.tif new file mode 100644 index 00000000000..3cb5a7fcc63 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-1.tif b/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-1.tif new file mode 100644 index 00000000000..9e12af8b2b7 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/mask/EMSR001-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-0.tif b/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-0.tif new file mode 100644 index 00000000000..5e06184db96 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-1.tif b/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-1.tif new file mode 100644 index 00000000000..4a7aadce480 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR001-0/s1_raw/EMSR001-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-0.tif b/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-0.tif new file mode 100644 index 00000000000..9c7ce27bbda Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-1.tif b/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-1.tif new file mode 100644 index 00000000000..54b32ade880 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/DEM/EMSR003-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/hydro/EMSR003-0.tif b/tests/data/mmflood/activations/EMSR003-0/hydro/EMSR003-0.tif new file mode 100644 index 00000000000..61ece234e90 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/hydro/EMSR003-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-0.tif b/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-0.tif new file mode 100644 index 00000000000..d414fddd446 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-1.tif b/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-1.tif new file mode 100644 index 00000000000..f8bc4908ac5 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/mask/EMSR003-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-0.tif b/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-0.tif new file mode 100644 index 00000000000..44771a2b0dc Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-1.tif b/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-1.tif new file mode 100644 index 00000000000..675768a5910 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR003-0/s1_raw/EMSR003-1.tif differ diff --git a/tests/data/mmflood/activations/EMSR004-0/DEM/EMSR004-0.tif b/tests/data/mmflood/activations/EMSR004-0/DEM/EMSR004-0.tif new file mode 100644 index 00000000000..c761aed58f4 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR004-0/DEM/EMSR004-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR004-0/hydro/EMSR004-0.tif b/tests/data/mmflood/activations/EMSR004-0/hydro/EMSR004-0.tif new file mode 100644 index 00000000000..02b5da0c1b3 Binary files /dev/null and b/tests/data/mmflood/activations/EMSR004-0/hydro/EMSR004-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR004-0/mask/EMSR004-0.tif b/tests/data/mmflood/activations/EMSR004-0/mask/EMSR004-0.tif new file mode 100644 index 00000000000..b686190d2bc Binary files /dev/null and b/tests/data/mmflood/activations/EMSR004-0/mask/EMSR004-0.tif differ diff --git a/tests/data/mmflood/activations/EMSR004-0/s1_raw/EMSR004-0.tif b/tests/data/mmflood/activations/EMSR004-0/s1_raw/EMSR004-0.tif new file mode 100644 index 00000000000..52a2e7ca04d Binary files /dev/null and b/tests/data/mmflood/activations/EMSR004-0/s1_raw/EMSR004-0.tif differ diff --git a/tests/data/mmflood/data.py b/tests/data/mmflood/data.py new file mode 100644 index 00000000000..4a684d401d5 --- /dev/null +++ b/tests/data/mmflood/data.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import os +import tarfile + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + + +def generate_data( + path: str, filename: str, height: int, width: int, include_hydro: bool = False +) -> None: + max_value = 1000.0 + min_value = 0.0 + interval = max_value - min_value + folders = ['s1_raw', 'DEM', 'mask', 'hydro'] + profile = { + 'driver': 'GTiff', + 'dtype': 'float32', + 'nodata': None, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( + 0.0001287974837883981, + 0.0, + 14.438064999669106, + 0.0, + -8.989523639880024e-05, + 45.71617928533084, + ), + 'blockysize': 1, + 'tiled': False, + 'interleave': 'pixel', + 'height': height, + 'width': width, + } + data = { + 's1_raw': np.random.rand(2, height, width).astype(np.float32) * interval + - min_value, + 'DEM': np.random.rand(1, height, width).astype(np.float32) * interval + - min_value, + 'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype( + np.uint8 + ), + } + + if include_hydro: + data['hydro'] = ( + np.random.rand(1, height, width).astype(np.float32) * interval - min_value + ) + + for folder in folders: + folder_path = os.path.join(path, folder) + os.makedirs(folder_path, exist_ok=True) + filepath = os.path.join(folder_path, filename) + profile2 = profile.copy() + profile2['count'] = 2 if folder == 's1_raw' else 1 + if folder in data: + with rasterio.open(filepath, mode='w', **profile2) as src: + src.write(data[folder]) + + +def generate_tar_gz(src: str, dst: str) -> None: + with tarfile.open(dst, 'w:gz') as tar: + tar.add(src, arcname=src) + + +def split_tar(path: str, dst: str, nparts: int) -> None: + fstats = os.stat(path) + size = fstats.st_size + chunk = size // nparts + + with open(path, 'rb') as fp: + for idx in range(nparts): + part_path = os.path.join(dst, f'activations.tar.{idx:03}.gz.part') + + bytes_to_write = chunk if idx < nparts - 1 else size - fp.tell() + with open(part_path, 'wb') as dst_fp: + dst_fp.write(fp.read(bytes_to_write)) + + +def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None: + folders_splits = [ + ('EMSR000', 'train'), + ('EMSR001', 'train'), + ('EMSR003', 'val'), + ('EMSR004', 'test'), + ] + num_files = {'EMSR000': 3, 'EMSR001': 2, 'EMSR003': 2, 'EMSR004': 1} + num_hydro = {'EMSR001': 2, 'EMSR003': 1, 'EMSR004': 1} + metadata = {} + for folder, split in folders_splits: + data = {} + data['title'] = 'Test flood' + data['type'] = 'Flood' + data['country'] = 'N/A' + data['start'] = '2014-11-06T17:57:00' + data['end'] = '2015-01-29T12:47:04' + data['lat'] = 45.82427031690563 + data['lon'] = 14.484407562009336 + data['subset'] = split + data['delineations'] = [f'{folder}_00'] + + count_hydro = 0 + + dst_folder = os.path.join(datapath, f'{folder}-0') + for idx in range(num_files[folder]): + include_hydro = count_hydro < num_hydro.get(folder, 0) + generate_data( + dst_folder, + filename=f'{folder}-{idx}.tif', + height=16, + width=16, + include_hydro=include_hydro, + ) + if include_hydro: + count_hydro += 1 + + metadata[folder] = data + + generate_tar_gz(src='activations', dst='activations.tar.gz') + split_tar(path='activations.tar.gz', dst='.', nparts=2) + os.remove('activations.tar.gz') + with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp: + json.dump(metadata, fp) + + +if __name__ == '__main__': + datapath = os.path.join(os.getcwd(), 'activations') + metadatapath = os.getcwd() + + generate_folders_and_metadata(datapath, metadatapath) diff --git a/tests/datasets/test_mmflood.py b/tests/datasets/test_mmflood.py new file mode 100644 index 00000000000..ce29c43b55f --- /dev/null +++ b/tests/datasets/test_mmflood.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from itertools import product +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from rasterio.crs import CRS + +from torchgeo.datasets import ( + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + MMFlood, + UnionDataset, +) + + +class TestMMFlood: + @pytest.fixture( + params=product([True, False], [True, False], ['train', 'val', 'test']) + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> MMFlood: + dataset_root = os.path.join('tests', 'data', 'mmflood/') + url = os.path.join(dataset_root) + + monkeypatch.setattr(MMFlood, 'url', url) + monkeypatch.setattr(MMFlood, '_nparts', 2) + + include_dem, include_hydro, split = request.param + root = tmp_path + return MMFlood( + root, + split=split, + include_dem=include_dem, + include_hydro=include_hydro, + transforms=nn.Identity(), + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: MMFlood) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + nchannels = 2 + + # If DEM is included and hydro is included, check if 4 channels are present, + # If only one between DEM or hydro is included, check if 3 channels are present + # 2 otherwise + if dataset.include_dem: + nchannels += 1 + if dataset.include_hydro: + nchannels += 1 + assert x['image'].size(0) == nchannels + + def test_len(self, dataset: MMFlood) -> None: + if dataset.split == 'train': + if not dataset.include_hydro: + assert len(dataset) == 5 + else: + assert len(dataset) == 2 + elif dataset.split == 'val': + if not dataset.include_hydro: + assert len(dataset) == 2 + else: + assert len(dataset) == 1 + else: + assert len(dataset) == 1 + + def test_and(self, dataset: MMFlood) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: MMFlood) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_already_downloaded(self, dataset: MMFlood) -> None: + MMFlood(root=dataset.root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + MMFlood(tmp_path) + + def test_plot(self, dataset: MMFlood) -> None: + x = dataset[dataset.bounds] + dataset.plot(x, suptitle='Test') + plt.close() + + def test_plot_prediction(self, dataset: MMFlood) -> None: + x = dataset[dataset.bounds] + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') + plt.close() + + def test_invalid_query(self, dataset: MMFlood) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match='query: .* not found in index with bounds:' + ): + dataset[query] diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 4bdd966a1bb..a21f0e5f4c8 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -65,6 +65,7 @@ class TestSemanticSegmentationTask: 'landcoverai', 'landcoverai100', 'loveda', + 'mmflood', 'naipchesapeake', 'potsdam2d', 'sen12ms_all', diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 6dd7231e3df..be49fa97463 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -29,6 +29,7 @@ from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule from .loveda import LoveDADataModule +from .mmflood import MMFloodDataModule from .naip import NAIPChesapeakeDataModule from .nasa_marine_debris import NASAMarineDebrisDataModule from .oscd import OSCDDataModule @@ -87,6 +88,7 @@ 'LandCoverAI100DataModule', 'LandCoverAIDataModule', 'LoveDADataModule', + 'MMFloodDataModule', 'MisconfigurationException', 'NAIPChesapeakeDataModule', 'NASAMarineDebrisDataModule', diff --git a/torchgeo/datamodules/mmflood.py b/torchgeo/datamodules/mmflood.py new file mode 100644 index 00000000000..ddd6804cded --- /dev/null +++ b/torchgeo/datamodules/mmflood.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""MMFlood datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from kornia.constants import DataKey, Resample +from torch import Tensor + +from ..datasets import MMFlood +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..samplers.utils import _to_tuple +from .geo import GeoDataModule + + +class MMFloodDataModule(GeoDataModule): + """LightningDataModule implementation for the MMFlood dataset. + + .. versionadded:: 0.7 + """ + + # Computed over train set + # VV, VH, dem, hydro + median = torch.tensor([0.116051525, 0.025692634, 86.0, 0.0]) + std = torch.tensor([2.405442, 0.22719479, 242.74359, 0.1482505053281784]) + + def __init__( + self, + batch_size: int = 32, + patch_size: int | tuple[int, int] = 512, + length: int | None = None, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new MMFloodDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.MMFlood`. + """ + super().__init__( + MMFlood, + batch_size=batch_size, + patch_size=patch_size, + length=length, + num_workers=num_workers, + **kwargs, + ) + avg, std = self._get_mean_std( + dem=kwargs.get('include_dem', False), + hydro=kwargs.get('include_hydro', False), + ) + + # Using median for normalization for better stability, + # as stated by the original authors + self.train_aug = K.AugmentationSequential( + K.RandomResizedCrop(_to_tuple(self.patch_size), p=0.8, scale=(0.5, 1.0)), + K.Normalize(avg, std), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomRotation90((0, 3), p=0.5), + K.RandomElasticTransform(sigma=(50, 50)), + keepdim=True, + data_keys=None, + extra_args={ + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} + }, + ) + + self.aug = K.AugmentationSequential( + K.Normalize(avg, std), keepdim=True, data_keys=None + ) + + def _get_mean_std( + self, dem: bool = False, hydro: bool = False + ) -> tuple[Tensor, Tensor]: + """Retrieve mean and standard deviation tensors used for normalization. + + Args: + dem: True if DEM data is loaded + hydro: True if hydrography data is loaded + + Returns: + mean and standard deviation tensors + """ + idxs = [0, 1] # VV, VH + if dem: + idxs.append(2) + if hydro: + idxs.append(3) + return self.median[idxs], self.std[idxs] + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', 'predict'. + """ + if stage in ['fit']: + self.train_dataset = MMFlood(**self.kwargs, split='train') + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ['fit', 'validate']: + self.val_dataset = MMFlood(**self.kwargs, split='val') + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ['test']: + self.test_dataset = MMFlood(**self.kwargs, split='test') + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0e522c09976..8f238abd916 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -88,6 +88,7 @@ from .mdas import MDAS from .millionaid import MillionAID from .mmearth import MMEarth +from .mmflood import MMFlood from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris from .nccm import NCCM @@ -245,6 +246,7 @@ 'Landsat9', 'LoveDA', 'MMEarth', + 'MMFlood', 'MapInWild', 'MillionAID', 'NASAMarineDebris', diff --git a/torchgeo/datasets/mmflood.py b/torchgeo/datasets/mmflood.py new file mode 100644 index 00000000000..0330a8839da --- /dev/null +++ b/torchgeo/datasets/mmflood.py @@ -0,0 +1,382 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""MMFlood dataset.""" + +from __future__ import annotations + +import os +from collections.abc import Callable +from glob import glob +from typing import ClassVar, Literal + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.figure import Figure +from rasterio.crs import CRS +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import IntersectionDataset, RasterDataset +from .utils import BoundingBox, Path, download_url, extract_archive + + +class MMFloodComponent(RasterDataset): + """Base component for MMFlood dataset.""" + + def __init__( + self, + subfolders: list[str], + content: Literal['s1_raw', 'DEM', 'hydro', 'mask'], + root: Path = 'data', + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + cache: bool = False, + ) -> None: + """Initialize MMFloodComponent dataset instance. + + Args: + subfolders: list of directories to be loaded + content: specifies which component to load + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + """ + self.content = content + self.is_image = content != 'mask' + paths = [] + for s in subfolders: + paths += glob(os.path.join(root, '**', f'{s}*-*', self.content, '*.tif')) + paths = sorted(paths) + super().__init__(paths, crs, res, transforms=transforms, cache=cache) + + +class MMFloodIntersection(IntersectionDataset): + """Intersection dataset used to merge two or more MMFloodComponents.""" + + def __init__( + self, + dataset1: MMFloodComponent | MMFloodIntersection, + dataset2: MMFloodComponent | MMFloodIntersection, + ) -> None: + """Initialize a new MMFloodIntersection instance. + + Args: + dataset1: the first dataset to merge + dataset2: the second dataset to merge + """ + # if hydro component is passed, it should always be passed as dataset2 + super().__init__(dataset1, dataset2) + + def _merge_dataset_indices(self) -> None: + """Create a new R-tree out of the individual indices from Sentinel-1, DEM and hydrography datasets.""" + _, ds2 = self.datasets + # Always use index of ds2, since it either coincides with ds1 index + # or refers to hydro, which represents only a subset of the dataset + self.index = ds2.index + + +class MMFlood(IntersectionDataset): + """MMFlood dataset. + + `MMFlood `__ dataset is a multimodal + flood delineation dataset. Sentinel-1 data is matched with masks and DEM data for all + available tiles. If hydrography maps are loaded, only a subset of the dataset is loaded, + since only 1,012 Sentinel-1 tiles have a corresponding hydrography map. + Some Sentinel-1 tiles have missing data, which are automatically set to 0. + Corresponding pixels in masks are set to 255 and should be ignored in performance computation. + + Dataset features: + + * 1,748 Sentinel-1 tiles of varying pixel dimensions + * multimodal dataset + * 95 flood events from 42 different countries + * includes DEMs + * includes hydrography maps (available for 1,012 tiles out of 1,748) + * flood delineation maps (ground truth) is obtained from Copernicus EMS + + Dataset classes: + + 0. no flood + 1. flood + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.1109/ACCESS.2022.3205419 + + .. versionadded:: 0.7 + """ + + url = 'https://huggingface.co/datasets/links-ads/mmflood/resolve/24ca097306c9e50ad0711903c11e1ba13ea1bedc/' + _ignore_index = 255 + _nparts = 11 + + metadata: ClassVar[dict[str, str]] = { + 'part_file': 'activations.tar.{part}.gz.part', + 'filename': 'activations.tar.gz', + 'directory': 'activations', + 'metadata_file': 'activations.json', + } + _splits: ClassVar[set[str]] = {'train', 'val', 'test'} + _md5: ClassVar[dict[str, str]] = { + 'activations.json': 'de33a3ac7e55a0051ada21cbdfbb4745', + 'activations.tar.gz': '3cd4c4fe7506aa40263f74639d85ccce', + 'activations.tar.000.gz.part': 'a8424653edca6e79999831bdda53d4dc', + 'activations.tar.001.gz.part': '517def8760d3ce86885c7600c77a1d6c', + 'activations.tar.002.gz.part': '6797b97121f5b98ff58fde7491f584b2', + 'activations.tar.003.gz.part': 'e69d2a6b1746ef869d1da4d22018a71a', + 'activations.tar.004.gz.part': '0ccf7ea69ea6c0e88db1b1015ec3361e', + 'activations.tar.005.gz.part': '8ef6765afe20f254b1e752d7a2742fda', + 'activations.tar.006.gz.part': '3f330a44b66511b7a95f4a555f8b793a', + 'activations.tar.007.gz.part': '1d2046b5f3c473c3681a05dc94b29b86', + 'activations.tar.008.gz.part': 'f386b5acf78f8ae34592404c6c7ec43c', + 'activations.tar.009.gz.part': 'dd5317a3c0d33de815beadb9850baa38', + 'activations.tar.010.gz.part': '5a14a7e3f916c5dcf288c2ca88daf4d0', + } + + def __init__( + self, + root: Path = 'data', + crs: CRS | None = None, + res: float | None = None, + split: str = 'train', + include_dem: bool = False, + include_hydro: bool = False, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + cache: bool = False, + ) -> None: + """Initialize a new MMFlood dataset instance. + + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + split: train/val/test split to load + include_dem: If True, DEM data is concatenated after Sentinel-1 bands. + include_hydro: If True, hydrography data is concatenated as last channel. + Only a smaller subset of the original dataset is loaded in this case. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + cache: if True, cache file handle to speed up repeated sampling + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + + """ + assert split in self._splits + + self.root = root + self.split = split + self.include_dem = include_dem + self.include_hydro = include_hydro + self.transforms = transforms + self.download = download + self.checksum = checksum + # Verify integrity of the dataset + self._verify() + self.metadata_df = pd.read_json( + os.path.join(self.root, self.metadata['metadata_file']) + ).transpose() + + split_subfolders = self.metadata_df[ + self.metadata_df['subset'] == self.split + ].index.tolist() + self.image: MMFloodComponent | MMFloodIntersection = MMFloodComponent( + split_subfolders, 's1_raw', root, crs, res, cache=cache + ) + if include_dem: + dem = MMFloodComponent(split_subfolders, 'DEM', root, crs, res, cache=cache) + self.image = MMFloodIntersection(self.image, dem) + if include_hydro: + hydro = MMFloodComponent( + split_subfolders, 'hydro', root, crs, res, cache=cache + ) + self.image = MMFloodIntersection(self.image, hydro) + self.mask = MMFloodComponent( + split_subfolders, 'mask', root, crs, res, cache=cache + ) + + super().__init__(self.image, self.mask, transforms=transforms) + + def _merge_tar_files(self) -> None: + """Merge part tar gz files.""" + dst_filename = self.metadata['filename'] + dst_path = os.path.join(self.root, dst_filename) + + print('Merging separate part files...') + with open(dst_path, 'wb') as dst_fp: + for idx in range(self._nparts): + part_filename = f'activations.tar.{idx:03}.gz.part' + part_path = os.path.join(self.root, part_filename) + print(f'Processing file {part_path!s}') + + with open(part_path, 'rb') as part_fp: + dst_fp.write(part_fp.read()) + + def __getitem__(self, query: BoundingBox) -> dict[str, Tensor]: + """Retrieve image/mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image, mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + data = super().__getitem__(query) + missing_data = data['image'].isnan().any(dim=0) + # Set all pixel values of invalid areas to 0, all mask values to 255 + data['image'][:, missing_data] = 0 + data['mask'][missing_data] = self._ignore_index + return data + + def _merge_dataset_indices(self) -> None: + """Create a new R-tree out of the individual indices from Sentinel-1, DEM and hydrography datasets.""" + ds1, _ = self.datasets + # Use ds1 index + self.index = ds1.index + + def _download(self) -> None: + """Download the dataset.""" + + def _check_and_download(filename: str, url: str) -> None: + path = os.path.join(self.root, filename) + if not os.path.exists(path): + md5 = self._md5[filename] if self.checksum else None + download_url(url, self.root, filename, md5) + return + + filename = self.metadata['filename'] + filepath = os.path.join(self.root, filename) + if not os.path.exists(filepath): + for idx in range(self._nparts): + part_file = f'activations.tar.{idx:03}.gz.part' + url = self.url + part_file + + _check_and_download(part_file, url) + + _check_and_download( + self.metadata['metadata_file'], self.url + self.metadata['metadata_file'] + ) + + def _extract(self) -> None: + """Extract the dataset.""" + filepath = os.path.join(self.root, self.metadata['filename']) + if str(filepath).endswith('.tar.gz'): + extract_archive(filepath) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + dirpath = os.path.join(self.root, self.metadata['directory']) + metadata_filepath = os.path.join(self.root, self.metadata['metadata_file']) + # Check if both metadata file and directory exist + if os.path.isdir(dirpath) and os.path.isfile(metadata_filepath): + return + if not self.download: + raise DatasetNotFoundError(self) + self._download() + self._merge_tar_files() + self._extract() + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + show_mask = 'mask' in sample + image = sample['image'][[0, 1]].permute(1, 2, 0).numpy() + ncols = 1 + show_predictions = 'prediction' in sample + if self.include_dem: + dem_idx = -2 if self.include_hydro else -1 + dem = sample['image'][dem_idx].squeeze(0).numpy() + ncols += 1 + if self.include_hydro: + hydro = sample['image'][-1].squeeze(0).numpy() + ncols += 1 + if show_mask: + mask = sample['mask'].numpy() + # Set ignore_index values to 0 + mask[mask == self._ignore_index] = 0 + ncols += 1 + if show_predictions: + pred = sample['prediction'].numpy() + ncols += 1 + + # Compute False Color image, from Sentinel1 plot function + co_polarization = image[..., 0] # transmit == receive + cross_polarization = image[..., 1] # transmit != receive + ratio = co_polarization / cross_polarization + + # https://gis.stackexchange.com/a/400780/123758 + co_polarization = np.clip(co_polarization / 0.3, a_min=0, a_max=1) + cross_polarization = np.clip(cross_polarization / 0.05, a_min=0, a_max=1) + ratio = np.clip(ratio / 25, a_min=0, a_max=1) + + image = np.stack((co_polarization, cross_polarization, ratio), axis=-1) + + # Generate the figure + fig, axs = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4)) + axs[0].imshow(image) + axs[0].axis('off') + axs_idx = 1 + if self.include_dem: + axs[axs_idx].imshow(dem, cmap='gray') + axs[axs_idx].axis('off') + axs_idx += 1 + if self.include_hydro: + axs[axs_idx].imshow(hydro, cmap='gray') + axs[axs_idx].axis('off') + axs_idx += 1 + if show_mask: + axs[axs_idx].imshow(mask, cmap='gray') + axs[axs_idx].axis('off') + axs_idx += 1 + if show_predictions: + axs[axs_idx].imshow(pred, cmap='gray') + axs[axs_idx].axis('off') + + if show_titles: + axs[0].set_title('Image') + axs_idx = 1 + if self.include_dem: + axs[axs_idx].set_title('DEM') + axs_idx += 1 + if self.include_hydro: + axs[axs_idx].set_title('Hydrography Map') + axs_idx += 1 + if show_mask: + axs[axs_idx].set_title('Mask') + axs_idx += 1 + if show_predictions: + axs[axs_idx].set_title('Prediction') + + if suptitle is not None: + plt.suptitle(suptitle) + return fig