Skip to content

Commit

Permalink
GLC23: added code documentation, cleaned class arguments, updated exa…
Browse files Browse the repository at this point in the history
…mples consequently.
  • Loading branch information
tlarcher committed Feb 22, 2024
1 parent 9d49f36 commit a37d163
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from torchgeo.samplers import Units
from torchvision import transforms

from malpolon.data.data_module import BaseDataModule
Expand All @@ -21,7 +22,6 @@
from malpolon.logging import Summary
from malpolon.models import ClassificationSystem
from malpolon.models.utils import CrashHandler
from torchgeo.samplers import Units


class Sentinel2PatchesDataModule(BaseDataModule):
Expand Down Expand Up @@ -94,7 +94,6 @@ def get_dataset(self, split, transform, target_transform=None, **kwargs):
providers=[jpp_rgbnir],
transform=transform,
target_transform=target_transform,
id_getitem='patchID',
item_columns=['lat', 'lon', 'patchID'],
**kwargs
)
Expand Down
5 changes: 2 additions & 3 deletions examples/kaggle/geolifeclef2023/example_patch_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ def main():
# create dataset
dataset = PatchesDataset(occurrences=data_path + 'Presence_only_occurrences/Presences_only_train_sample.csv',
providers=[p_hfp_d, p_bioclim, p_hfp_s, p_rgb],
item_columns=['lat', 'lon', 'patchID'],
id_getitem='patchID')
item_columns=['lat', 'lon', 'patchID'])
dataset_multi = PatchesDatasetMultiLabel(occurrences=data_path + 'Presence_only_occurrences/Presences_only_train_sample.csv',
providers=[p_hfp_d, p_bioclim, p_hfp_s, p_rgb],
item_columns=['lat', 'lon', 'patchID'],
id_getitem='patchID')

# print random tensors from dataset
ids = [random.randint(0, len(dataset) - 1) for i in range(1)]
ids = [random.randint(0, len(dataset) - 1) for i in range(5)]
for i in ids:
tensor, label = dataset[i]
label_multi = dataset_multi[i][1]
Expand Down
Loading

0 comments on commit a37d163

Please sign in to comment.