diff --git a/tests/conf/oscd.yaml b/tests/conf/oscd.yaml new file mode 100644 index 00000000000..d44d9c9a459 --- /dev/null +++ b/tests/conf/oscd.yaml @@ -0,0 +1,15 @@ +model: + class_path: ChangeDetectionTask + init_args: + loss: 'bce' + model: 'unet' + backbone: 'resnet18' + in_channels: 13 +data: + class_path: OSCDDataModule + init_args: + batch_size: 2 + patch_size: 16 + val_split_pct: 0.5 + dict_kwargs: + root: 'tests/data/oscd' diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py deleted file mode 100644 index e67bd6d5678..00000000000 --- a/tests/datamodules/test_oscd.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import pytest -from _pytest.fixtures import SubRequest -from lightning.pytorch import Trainer - -from torchgeo.datamodules import OSCDDataModule -from torchgeo.datasets import OSCD - - -class TestOSCDDataModule: - @pytest.fixture(params=[OSCD.all_bands, OSCD.rgb_bands]) - def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands = request.param - root = os.path.join('tests', 'data', 'oscd') - dm = OSCDDataModule( - root=root, - download=True, - bands=bands, - batch_size=1, - patch_size=2, - val_split_pct=0.5, - num_workers=0, - ) - dm.prepare_data() - dm.trainer = Trainer(accelerator='cpu', max_epochs=1) - return dm - - def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('fit') - if datamodule.trainer: - datamodule.trainer.training = True - batch = next(iter(datamodule.train_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert 
batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 - - def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('validate') - if datamodule.trainer: - datamodule.trainer.validating = True - batch = next(iter(datamodule.val_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - if datamodule.val_split_pct > 0.0: - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 - - def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('test') - if datamodule.trainer: - datamodule.trainer.testing = True - batch = next(iter(datamodule.test_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 711392f7fc4..5c10508e3bb 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -66,19 +66,15 @@ def dataset( def test_getitem(self, dataset: OSCD) -> None: x = dataset[0] assert isinstance(x, dict) 
- assert isinstance(x['image1'], torch.Tensor) - assert x['image1'].ndim == 3 - assert isinstance(x['image2'], torch.Tensor) - assert x['image2'].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].ndim == 4 assert isinstance(x['mask'], torch.Tensor) assert x['mask'].ndim == 2 if dataset.bands == OSCD.rgb_bands: - assert x['image1'].shape[0] == 3 - assert x['image2'].shape[0] == 3 + assert x['image'].shape[1] == 3 else: - assert x['image1'].shape[0] == 13 - assert x['image2'].shape[0] == 13 + assert x['image'].shape[1] == 13 def test_len(self, dataset: OSCD) -> None: if dataset.split == 'train': diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index a3ce098ae7d..920e8ee6abc 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -6,8 +6,8 @@ from pathlib import Path import pytest +import timm import torch -import torchvision from _pytest.fixtures import SubRequest from torch import Tensor from torch.nn.modules import Module @@ -22,8 +22,9 @@ def fast_dev_run(request: SubRequest) -> bool: @pytest.fixture(scope='package') -def model() -> Module: - model: Module = torchvision.models.resnet18(weights=None) +def model(request: SubRequest) -> Module: + in_channels = getattr(request, 'param', 3) + model: Module = timm.create_model('resnet18', in_chans=in_channels) return model diff --git a/tests/trainers/test_change.py b/tests/trainers/test_change.py new file mode 100644 index 00000000000..e760d8b6ef7 --- /dev/null +++ b/tests/trainers/test_change.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +from pathlib import Path +from typing import Any, cast + +import pytest +import segmentation_models_pytorch as smp +import timm +import torch +import torch.nn as nn +from pytest import MonkeyPatch +from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum + +from torchgeo.datamodules import MisconfigurationException +from torchgeo.main import main +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import ChangeDetectionTask + + +class ChangeDetectionTestModel(Module): + def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return cast(torch.Tensor, self.conv1(x)) + + +def create_model(**kwargs: Any) -> Module: + return ChangeDetectionTestModel(**kwargs) + + +class TestChangeDetectionTask: + @pytest.mark.parametrize('name', ['oscd']) + def test_trainer( + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool + ) -> None: + config = os.path.join('tests', 'conf', name + '.yaml') + + monkeypatch.setattr(smp, 'Unet', create_model) + + args = [ + '--config', + config, + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', + str(fast_dev_run), + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', + ] + + main(['fit', *args]) + try: + main(['test', *args]) + except MisconfigurationException: + pass + try: + main(['predict', *args]) + except MisconfigurationException: + pass + + @pytest.fixture + def weights(self) -> WeightsEnum: + return ResNet18_Weights.SENTINEL2_ALL_MOCO + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + # multiply in_chans by 2 since images are concatenated + model = timm.create_model( + 
weights.meta['model'], in_chans=weights.meta['in_chans'] * 2 + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + @pytest.mark.parametrize('model', [6], indirect=True) + def test_weight_file(self, checkpoint: str) -> None: + ChangeDetectionTask(backbone='resnet18', weights=checkpoint) + + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=mocked_weights.meta['model'], + weights=mocked_weights, + in_channels=mocked_weights.meta['in_chans'], + ) + + def test_weight_str(self, mocked_weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=mocked_weights.meta['model'], + weights=str(mocked_weights), + in_channels=mocked_weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_enum_download(self, weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=weights.meta['model'], + weights=weights, + in_channels=weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_str_download(self, weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=weights.meta['model'], + weights=str(weights), + in_channels=weights.meta['in_chans'], + ) + + def test_invalid_model(self) -> None: + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + ChangeDetectionTask(model='invalid_model') + + def test_invalid_loss(self) -> None: + match = "Loss type 'invalid_loss' is not valid." 
+ with pytest.raises(ValueError, match=match): + ChangeDetectionTask(loss='invalid_loss') + + @pytest.mark.parametrize('model_name', ['unet']) + @pytest.mark.parametrize( + 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] + ) + def test_freeze_backbone(self, model_name: str, backbone: str) -> None: + model = ChangeDetectionTask( + model=model_name, backbone=backbone, freeze_backbone=True + ) + assert all( + [param.requires_grad is False for param in model.model.encoder.parameters()] + ) + assert all([param.requires_grad for param in model.model.decoder.parameters()]) + assert all( + [ + param.requires_grad + for param in model.model.segmentation_head.parameters() + ] + ) + + @pytest.mark.parametrize('model_name', ['unet']) + def test_freeze_decoder(self, model_name: str) -> None: + model = ChangeDetectionTask(model=model_name, freeze_decoder=True) + assert all( + [param.requires_grad is False for param in model.model.decoder.parameters()] + ) + assert all([param.requires_grad for param in model.model.encoder.parameters()]) + assert all( + [ + param.requires_grad + for param in model.model.segmentation_head.parameters() + ] + ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 8db1dd7061a..5cd898e0699 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -85,8 +85,10 @@ def __init__( self.std = torch.tensor([STD[b] for b in self.bands]) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, batch_size), + K.VideoSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + ), data_keys=None, keepdim=True, ) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 28f7714a7c6..d261cbbf291 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -150,7 +150,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files['images1']) image2 = 
self._load_image(files['images2']) mask = self._load_target(str(files['mask'])) - sample = {'image1': image1, 'image2': image2, 'mask': mask} + image = torch.stack(tensors=[image1, image2], dim=0) + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -324,8 +325,8 @@ def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': ) return array - image1 = get_masked(sample['image1']) - image2 = get_masked(sample['image2']) + image1 = get_masked(sample['image'][0]) + image2 = get_masked(sample['image'][1]) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) axs[0].axis('off') diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index ee69bff0021..0dbb9987557 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -5,6 +5,7 @@ from .base import BaseTask from .byol import BYOLTask +from .change import ChangeDetectionTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask from .iobench import IOBenchTask @@ -14,8 +15,10 @@ from .simclr import SimCLRTask __all__ = ( + # Supervised 'BYOLTask', 'BaseTask', + 'ChangeDetectionTask', 'ClassificationTask', 'IOBenchTask', 'MoCoTask', diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py new file mode 100644 index 00000000000..4788dda91c5 --- /dev/null +++ b/torchgeo/trainers/change.py @@ -0,0 +1,239 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +"""Trainers for change detection.""" + +import os +from typing import Any + +import segmentation_models_pytorch as smp +import torch +import torch.nn as nn +from torch import Tensor +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAccuracy, + BinaryF1Score, + BinaryJaccardIndex, +) +from torchvision.models._api import WeightsEnum + +from ..models import FCSiamConc, FCSiamDiff, get_weight +from . import utils +from .base import BaseTask + + +class ChangeDetectionTask(BaseTask): + """Change Detection. Currently supports binary change between two timesteps.""" + + def __init__( + self, + model: str = 'unet', + backbone: str = 'resnet50', + weights: WeightsEnum | str | bool | None = None, + in_channels: int = 3, + pos_weight: Tensor | None = None, + loss: str = 'bce', + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + freeze_decoder: bool = False, + ) -> None: + """Initialize a new ChangeDetectionTask instance. + + Args: + model: Name of the model to use. + backbone: Name of the `timm + <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp + <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False or + None for random weights, or the path to a saved model state dict. FCN + model does not support pretrained weights. Pretrained ViT weight enums + are not supported yet. + in_channels: Number of input channels to model. + pos_weight: A weight of positive examples, only used with 'bce' loss. + loss: Name of the loss function, currently supports + 'bce', 'jaccard', or 'focal' loss. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to fine-tune the + decoder and segmentation head. + freeze_decoder: Freeze the decoder network to linear probe + the segmentation head. + + .. 
versionadded:: 0.7 + """ + self.weights = weights + super().__init__() + + def configure_losses(self) -> None: + """Initialize the loss criterion. + + Raises: + ValueError: If *loss* is invalid. + """ + loss: str = self.hparams['loss'] + if loss == 'bce': + self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.hparams['pos_weight']) + elif loss == 'jaccard': + self.criterion = smp.losses.JaccardLoss(mode='binary') + elif loss == 'focal': + self.criterion = smp.losses.FocalLoss(mode='binary', normalized=True) + else: + raise ValueError( + f"Loss type '{loss}' is not valid. " + "Currently, supports 'bce', 'jaccard', or 'focal' loss." + ) + + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + metrics = MetricCollection( + { + 'accuracy': BinaryAccuracy(), + 'jaccard': BinaryJaccardIndex(), + 'f1': BinaryF1Score(), + } + ) + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') + + def configure_models(self) -> None: + """Initialize the model. + + Raises: + ValueError: If *model* is invalid. + """ + model: str = self.hparams['model'] + backbone: str = self.hparams['backbone'] + weights = self.weights + in_channels: int = self.hparams['in_channels'] + num_classes = 1 + + if model == 'unet': + self.model = smp.Unet( + encoder_name=backbone, + encoder_weights='imagenet' if weights is True else None, + in_channels=in_channels * 2, # images are concatenated + classes=num_classes, + ) + elif model == 'fcsiamdiff': + self.model = FCSiamDiff( + in_channels=in_channels, + classes=num_classes, + encoder_weights='imagenet' if weights is True else None, + ) + elif model == 'fcsiamconc': + self.model = FCSiamConc( + in_channels=in_channels, + classes=num_classes, + encoder_weights='imagenet' if weights is True else None, + ) + else: + raise ValueError( + f"Model type '{model}' is not valid. 
" + "Currently, only supports 'unet', 'fcsiamdiff, and 'fcsiamconc'." + ) + + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.model.encoder.load_state_dict(state_dict) + + # Freeze backbone + if self.hparams['freeze_backbone'] and model in ['unet']: + for param in self.model.encoder.parameters(): + param.requires_grad = False + + # Freeze decoder + if self.hparams['freeze_decoder'] and model in ['unet']: + for param in self.model.decoder.parameters(): + param.requires_grad = False + + def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: + """Compute the loss and additional metrics for the given stage. + + Args: + batch: The output of your DataLoader._ + batch_idx: Integer displaying index of this batch._ + stage: The current stage. + + Returns: + The loss tensor. + """ + model: str = self.hparams['model'] + x = batch['image'] + y = batch['mask'] + y = y.unsqueeze(dim=1) # channel dim for binary loss functions/metrics + if model == 'unet': + x = x.flatten(start_dim=1, end_dim=2) + y_hat = self(x) + + loss: Tensor = self.criterion(y_hat, y.to(torch.float)) + self.log(f'{stage}_loss', loss) + + # Retrieve the correct metrics based on the stage + metrics = getattr(self, f'{stage}_metrics', None) + if metrics: + metrics(y_hat, y) + self.log_dict({f'{k}': v for k, v in metrics.compute().items()}) + + return loss + + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Compute the training loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + + Returns: + The loss tensor. 
+ """ + loss = self._shared_step(batch, batch_idx, 'train') + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + """Compute the validation loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ + self._shared_step(batch, batch_idx, 'val') + + def test_step(self, batch: Any, batch_idx: int) -> None: + """Compute the test loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ + self._shared_step(batch, batch_idx, 'test') + + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the predicted class. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. + + Returns: + Output predicted class. + """ + model: str = self.hparams['model'] + threshold = 0.5 + x = batch['image'] + if model == 'unet': + x = x.flatten(start_dim=1, end_dim=2) + y_hat: Tensor = self(x) + y_hat_hard = (nn.functional.sigmoid(y_hat) > threshold).int() + return y_hat_hard diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..e04389d8b18 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -102,7 +102,13 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]: batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # Torchmetrics does not support masks with a channel dimension - if 'mask' in batch and batch['mask'].shape[1] == 1: + # Kornia adds a temporal dimension to mask when passed through VideoSequential. 
+ if 'mask' in batch and batch['mask'].ndim == 5: + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () c h w -> b c h w') + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + elif 'mask' in batch and batch['mask'].shape[1] == 1: batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') if 'masks' in batch and batch['masks'].ndim == 4: batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w')