From 711690ca509007326c8a02df43dc7a06283bf356 Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Fri, 4 Oct 2024 10:03:29 +0200 Subject: [PATCH 1/6] DataModules: add configurable args to dataloader --- torchgeo/datamodules/geo.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e8e3aedd194..d616d3c8b6d 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -38,6 +38,8 @@ def __init__( dataset_class: type[Dataset[dict[str, Tensor]]], batch_size: int = 1, num_workers: int = 0, + pin_memory: bool = False, + prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new BaseDataModule instance. @@ -46,6 +48,8 @@ def __init__( dataset_class: Class used to instantiate a new dataset. batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. + pin_memory: Whether to pin memory in data loaders. + prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__() @@ -53,6 +57,8 @@ def __init__( self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers + self.pin_memory = pin_memory + self.prefetch_factor = prefetch_factor self.kwargs = kwargs # Datasets @@ -178,6 +184,8 @@ def __init__( patch_size: int | tuple[int, int] = 64, length: int | None = None, num_workers: int = 0, + pin_memory: bool = False, + prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new GeoDataModule instance. @@ -188,6 +196,8 @@ def __init__( patch_size: Size of each patch, either ``size`` or ``(height, width)``. length: Length of each training epoch. num_workers: Number of workers for parallel data loading. + pin_memory: Whether to pin memory in data loaders. + prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__(dataset_class, batch_size, num_workers, **kwargs) @@ -287,6 +297,8 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -371,6 +383,8 @@ def __init__( dataset_class: type[NonGeoDataset], batch_size: int = 1, num_workers: int = 0, + pin_memory: bool = False, + prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new NonGeoDataModule instance. @@ -379,6 +393,8 @@ def __init__( dataset_class: Class used to instantiate a new dataset. batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. + pin_memory: Whether to pin memory in data loaders. + prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__(dataset_class, batch_size, num_workers, **kwargs) @@ -431,6 +447,8 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: From bc789d89c9dc77e5a300d511cacef826bf4530df Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Tue, 8 Oct 2024 13:30:07 +0200 Subject: [PATCH 2/6] use prefixed kwargs in dataloader --- torchgeo/datamodules/geo.py | 22 +++------------------- torchgeo/datamodules/utils.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index d616d3c8b6d..507940d7c15 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -21,7 +21,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential -from .utils import MisconfigurationException +from .utils import MisconfigurationException, get_prefixed_kwargs class BaseDataModule(LightningDataModule): @@ -38,8 +38,6 @@ def __init__( dataset_class: type[Dataset[dict[str, Tensor]]], batch_size: int = 1, num_workers: int = 0, - pin_memory: bool = False, - prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new BaseDataModule instance. @@ -48,8 +46,6 @@ def __init__( dataset_class: Class used to instantiate a new dataset. batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. - pin_memory: Whether to pin memory in data loaders. - prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__() @@ -57,8 +53,6 @@ def __init__( self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers - self.pin_memory = pin_memory - self.prefetch_factor = prefetch_factor self.kwargs = kwargs # Datasets @@ -184,8 +178,6 @@ def __init__( patch_size: int | tuple[int, int] = 64, length: int | None = None, num_workers: int = 0, - pin_memory: bool = False, - prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new GeoDataModule instance. @@ -196,8 +188,6 @@ def __init__( patch_size: Size of each patch, either ``size`` or ``(height, width)``. length: Length of each training epoch. num_workers: Number of workers for parallel data loading. - pin_memory: Whether to pin memory in data loaders. - prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__(dataset_class, batch_size, num_workers, **kwargs) @@ -297,8 +287,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - pin_memory=self.pin_memory, - prefetch_factor=self.prefetch_factor, + **get_prefixed_kwargs('dataloader_', **self.kwargs), ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -383,8 +372,6 @@ def __init__( dataset_class: type[NonGeoDataset], batch_size: int = 1, num_workers: int = 0, - pin_memory: bool = False, - prefetch_factor: int | None = None, **kwargs: Any, ) -> None: """Initialize a new NonGeoDataModule instance. @@ -393,8 +380,6 @@ def __init__( dataset_class: Class used to instantiate a new dataset. batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. - pin_memory: Whether to pin memory in data loaders. - prefetch_factor: Number of samples to prefetch. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ super().__init__(dataset_class, batch_size, num_workers, **kwargs) @@ -447,8 +432,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - pin_memory=self.pin_memory, - prefetch_factor=self.prefetch_factor, + **get_prefixed_kwargs('dataloader_', **self.kwargs), ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 4c3aab63b61..51a07bcfcaa 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -169,3 +169,16 @@ def group_shuffle_split( test_idxs.append(i) return train_idxs, test_idxs + + +def get_prefixed_kwargs(prefix: str, **kwargs: dict[str, Any]) -> dict[str, Any]: + """Get kwargs with a specific prefix. + + Args: + prefix: Prefix to filter kwargs by. + **kwargs: Keyword arguments to filter. + + Returns: + Dictionary of kwargs with the specified prefix. + """ + return {k.replace(prefix, ''): v for k, v in kwargs.items() if k.startswith(prefix)} From 8257b8ea154fb85107867ed537bd6574b13eda00 Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Tue, 8 Oct 2024 13:34:46 +0200 Subject: [PATCH 3/6] fix type hint --- torchgeo/datamodules/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 51a07bcfcaa..6a8070c6271 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -171,7 +171,7 @@ def group_shuffle_split( return train_idxs, test_idxs -def get_prefixed_kwargs(prefix: str, **kwargs: dict[str, Any]) -> dict[str, Any]: +def get_prefixed_kwargs(prefix: str, **kwargs: Any) -> dict[str, Any]: """Get kwargs with a specific prefix. Args: From b5291a654cbb8964615630a8dd38a64fe98610cb Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Tue, 8 Oct 2024 19:05:52 +0200 Subject: [PATCH 4/6] ensure no breaking change in kwargs --- torchgeo/datamodules/geo.py | 8 ++++---- torchgeo/datamodules/utils.py | 23 ++++++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 507940d7c15..cd21d8ac546 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -21,7 +21,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential -from .utils import MisconfigurationException, get_prefixed_kwargs +from .utils import MisconfigurationException, split_kwargs class BaseDataModule(LightningDataModule): @@ -53,7 +53,7 @@ def __init__( self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers - self.kwargs = kwargs + self.dataloader_kwargs, self.kwargs = split_kwargs('dataloader_', **kwargs) # Datasets self.dataset: Dataset[dict[str, Tensor]] | None = None @@ -287,7 +287,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - **get_prefixed_kwargs('dataloader_', **self.kwargs), + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -432,7 +432,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - **get_prefixed_kwargs('dataloader_', **self.kwargs), + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 6a8070c6271..5ec9519cc1d 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -171,14 +171,27 @@ def group_shuffle_split( return train_idxs, test_idxs -def get_prefixed_kwargs(prefix: str, **kwargs: Any) -> dict[str, Any]: - """Get kwargs with a specific prefix. +def split_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]: + """Split kwargs into prefixed and other kwargs. Args: - prefix: Prefix to filter kwargs by. + *prefixes: Prefixes to filter kwargs by. **kwargs: Keyword arguments to filter. Returns: - Dictionary of kwargs with the specified prefix. + Tuple of prefixed kwargs and other kwargs. """ - return {k.replace(prefix, ''): v for k, v in kwargs.items() if k.startswith(prefix)} + prefixed_kwargs: list[dict[str, Any]] = [{} for _ in prefixes] + other_kwargs: dict[str, Any] = {} + + for key, value in kwargs.items(): + matched = False + for i, prefix in enumerate(prefixes): + if key.startswith(prefix): + prefixed_kwargs[i][key[len(prefix) :]] = value + matched = True + break + if not matched: + other_kwargs[key] = value + + return *prefixed_kwargs, other_kwargs From 88cedbc0e1c94d98d52c8e84f2107cd668c38d2b Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Tue, 8 Oct 2024 19:19:58 +0200 Subject: [PATCH 5/6] fix DataModule docstring --- torchgeo/datamodules/geo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index cd21d8ac546..bfcebe9dc0b 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -46,7 +46,8 @@ def __init__( dataset_class: Class used to instantiate a new dataset. batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. - **kwargs: Additional keyword arguments passed to ``dataset_class`` + **kwargs: Additional keyword arguments passed to the ``DataLoader`` + if prefixed with 'dataloader_', else passed to ``dataset_class``. """ super().__init__() From cdace63c4bddc3a479aac98d555bce84e0cb3dcb Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Wed, 9 Oct 2024 07:52:12 +0200 Subject: [PATCH 6/6] add test --- tests/datamodules/test_utils.py | 19 ++++++++++++++++++- torchgeo/datamodules/geo.py | 6 ++++-- torchgeo/datamodules/utils.py | 2 +- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py index fcb1f82f584..3f8dba3c5dd 100644 --- a/tests/datamodules/test_utils.py +++ b/tests/datamodules/test_utils.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from torchgeo.datamodules.utils import group_shuffle_split +from torchgeo.datamodules.utils import group_shuffle_split, split_prefixed_kwargs def test_group_shuffle_split() -> None: @@ -44,3 +44,20 @@ def test_group_shuffle_split() -> None: assert len(set(train_indices1) & set(test_indices1)) == 0 assert len(set(groups[train_indices1])) == 2 + + +def test_split_prefixed_kwargs() -> None: + kwargs = { + 'testprefix1_param1': 10, + 'testprefix1_param2': 20, + 'testprefix2_param3': 30, + 'other_param': 40, + } + + testprefix1_kwargs, testprefix2_kwargs, other_kwargs = split_prefixed_kwargs( + 'testprefix1_', 'testprefix2_', **kwargs + ) + + assert testprefix1_kwargs == {'param1': 10, 'param2': 20} + assert testprefix2_kwargs == {'param3': 30} + assert other_kwargs == {'other_param': 40} diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index bfcebe9dc0b..1c3afe4ae4d 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -21,7 +21,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential -from .utils import MisconfigurationException, split_kwargs +from .utils import MisconfigurationException, split_prefixed_kwargs class BaseDataModule(LightningDataModule): @@ -54,7 +54,9 @@ def __init__( self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers - self.dataloader_kwargs, self.kwargs = split_kwargs('dataloader_', **kwargs) + self.dataloader_kwargs, self.kwargs = split_prefixed_kwargs( + 'dataloader_', **kwargs + ) # Datasets self.dataset: Dataset[dict[str, Tensor]] | None = None diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 5ec9519cc1d..a79e7b3c887 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -171,7 +171,7 @@ def group_shuffle_split( return train_idxs, test_idxs -def split_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]: +def split_prefixed_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]: """Split kwargs into prefixed and other kwargs. Args: