Skip to content

Commit

Permalink
revert target dtype to long in dataset and change to float in trainer…
Browse files Browse the repository at this point in the history
… instead.
  • Loading branch information
keves1 committed Jan 8, 2025
1 parent 1839b8e commit 97f17ca
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def __init__(

self.aug = K.AugmentationSequential(
K.VideoSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
),
data_keys=None,
keepdim=True,
Expand Down
11 changes: 7 additions & 4 deletions torchgeo/datasets/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __init__(
root: Path = 'data',
split: str = 'train',
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
transforms: Callable[[dict[str, Tensor]],
dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
Expand Down Expand Up @@ -184,7 +185,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]:
def get_image_paths(ind: int) -> list[str]:
return sorted(
glob.glob(
os.path.join(images_root, region, f'imgs_{ind}_rect', '*.tif')
os.path.join(images_root, region,
f'imgs_{ind}_rect', '*.tif')
),
key=sort_sentinel2_bands,
)
Expand Down Expand Up @@ -223,7 +225,8 @@ def _load_image(self, paths: Sequence[Path]) -> Tensor:
for path in paths:
with Image.open(path) as img:
images.append(np.array(img))
array: np.typing.NDArray[np.int_] = np.stack(images, axis=0).astype(np.int_)
array: np.typing.NDArray[np.int_] = np.stack(
images, axis=0).astype(np.int_)
tensor = torch.from_numpy(array).float()
return tensor

Expand All @@ -241,7 +244,7 @@ def _load_target(self, path: Path) -> Tensor:
array: np.typing.NDArray[np.int_] = np.array(img.convert('L'))
tensor = torch.from_numpy(array)
tensor = torch.clamp(tensor, min=0, max=1)
tensor = tensor.to(torch.float)
tensor = tensor.to(torch.long)
return tensor

def _verify(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/trainers/change.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch import Tensor
from torchmetrics import MetricCollection
Expand Down Expand Up @@ -173,7 +174,7 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor:
x = x.flatten(start_dim=1, end_dim=2)
y_hat = self(x)

loss: Tensor = self.criterion(y_hat, y)
loss: Tensor = self.criterion(y_hat, y.to(torch.float))
self.log(f'{stage}_loss', loss)

# Retrieve the correct metrics based on the stage
Expand Down

0 comments on commit 97f17ca

Please sign in to comment.