Commit 23de26d
update of jaccard loss
notprime committed Mar 17, 2024
1 parent c724487 commit 23de26d
Showing 3 changed files with 92 additions and 93 deletions.
8 changes: 4 additions & 4 deletions torchseg/losses/_functional.py
@@ -151,14 +151,14 @@ def soft_jaccard_score(
 ) -> torch.Tensor:
     assert output.size() == target.size()
     if dims is not None:
-        intersection = torch.sum(output * target, dim=dims)
-        cardinality = torch.sum(output + target, dim=dims)
+        intersection = torch.sum(output * target, dim = dims)
+        cardinality = torch.sum(output + target, dim = dims)
     else:
         intersection = torch.sum(output * target)
         cardinality = torch.sum(output + target)

     union = cardinality - intersection
-    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
+    jaccard_score = (intersection + smooth) / (union + eps)
     return jaccard_score
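For reference, a minimal sketch of what the updated score computes on toy inputs (plain PyTorch, values hand-checked):

import torch

output = torch.tensor([0.9, 0.1, 0.8])  # predicted probabilities
target = torch.tensor([1.0, 0.0, 1.0])  # binary ground truth

intersection = torch.sum(output * target)  # 1.7
cardinality = torch.sum(output + target)   # 3.8
union = cardinality - intersection         # 2.1
smooth, eps = 0.0, 1e-7
jaccard_score = (intersection + smooth) / (union + eps)  # ~0.8095

Note the behavioral change: smooth now appears only in the numerator, so for a pair of empty masks the score becomes smooth / eps rather than the previous smooth / max(smooth, eps).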


@@ -181,7 +181,7 @@ def soft_dice_score(
     output_pow = torch.sum(output ** power)
     target_pow = torch.sum(target ** power)
     cardinality = output_pow + target_pow
-    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
+    dice_score = (2.0 * intersection + smooth) / (cardinality + eps)
     return dice_score
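The power argument visible in the context lines generalizes the denominator. A self-contained sketch assuming the signature shown in this hunk; power=2 yields the squared-denominator Dice variant popularized by V-Net:

import torch

def soft_dice_sketch(output, target, smooth=0.0, eps=1e-7, power=1):
    intersection = torch.sum(output * target)
    cardinality = torch.sum(output ** power) + torch.sum(target ** power)
    return (2.0 * intersection + smooth) / (cardinality + eps)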


7 changes: 4 additions & 3 deletions torchseg/losses/dice.py
@@ -6,7 +6,6 @@
 import torch.nn.functional as F

 from ._functional import soft_dice_score, to_tensor
-from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
 from .reductions import LossReduction


@@ -31,14 +30,16 @@ def __init__(
         have shape (B, 1, H, W) but you should set mask_to_one_hot = True.
         Args:
-            - mode: Loss mode 'binary', 'multiclass' or 'multilabel'
             - classes: List of classes that contribute to the loss computation.
                 By default, all channels are included.
             - log_loss: If True, loss computed as `- log(dice_coeff)`,
                 otherwise `1 - dice_coeff`
             - from_logits: If True, assumes input is raw logits
-            - ignore_index: Label that indicates ignored pixels not contributing to the loss
+            - mask_to_one_hot: if set to True, the mask is converted into one-hot format.
+            - power: raise the denominator to the desired power.
+            - reduction: select the reduction to be applied to the loss.
             - smooth: Smoothness constant for the dice coefficient, added to the numerator to avoid zero
+            - ignore_index: Label that indicates ignored pixels not contributing to the loss
             - eps: A small epsilon added to the denominator for numerical stability to avoid nan
                 (denominator will be always greater than or equal to eps)
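As a usage sketch of the documented arguments (the constructor call below is inferred from this docstring and diff, not verified against the released API):

import torch
from torchseg.losses.dice import DiceLoss  # import path taken from this diff

loss_fn = DiceLoss(mask_to_one_hot=True, reduction='mean')
logits = torch.randn(2, 3, 64, 64)          # (B, C, H, W) raw logits
mask = torch.randint(0, 3, (2, 1, 64, 64))  # (B, 1, H, W) integer class mask
loss = loss_fn(logits, mask)                # scalar, since reduction='mean'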
170 changes: 84 additions & 86 deletions torchseg/losses/jaccard.py
@@ -1,122 +1,110 @@
+import warnings
 from typing import Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ._functional import soft_jaccard_score, to_tensor
-from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
+from .reductions import LossReduction

 class JaccardLoss(nn.Module):
     def __init__(
-        self,
-        mode: str,
-        classes: Optional[list[int]] = None,
-        log_loss: bool = False,
-        from_logits: bool = True,
-        smooth: float = 0.0,
-        ignore_index: Optional[int] = None,
-        eps: float = 1e-7,
+        self,
+        classes: Optional[list[int]] = None,
+        log_loss: bool = False,
+        from_logits: bool = True,
+        mask_to_one_hot: bool = False,
+        reduction: str = 'mean',
+        smooth: float = 0.0,
+        ignore_index: Optional[int] = None,
+        eps: float = 1e-7,
     ):
"""Jaccard loss for image segmentation task.
It supports binary, multiclass and multilabel cases
It supports binary, multiclass and multilabel cases.
Ground truth masks should have shape (B, C, H, W) for multiclass and multilabel cases
or (B, 1, H, W) for binary case. For the multiclass case, the ground truth mask can also
have shape (B, 1, H, W) but you should set mask_to_one_hot = True.
Args:
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
classes: List of classes that contribute in loss computation.
- classes: List of classes that contribute in loss computation.
By default, all channels are included.
log_loss: If True, loss computed as `- log(jaccard_coeff)`,
- log_loss: If True, loss computed as `- log(jaccard_coeff)`,
otherwise `1 - jaccard_coeff`
from_logits: If True, assumes input is raw logits
smooth: Smoothness constant for dice coefficient
ignore_index: Label that indicates ignored pixels
- from_logits: If True, assumes input is raw logits.
- mask_to_one_hot: if set to True, the mask is converted into one-hot format.
- reduction: select the reduction to be applied to the loss.
- smooth: Smoothness constant for dice coefficient
- ignore_index: Label that indicates ignored pixels
(does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
- eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
- **y_pred** - torch.Tensor of shape (B, C, H, W),
- **y_true** - torch.Tensor of shape (B, C, H, W) or (B, 1, H, W),
where C is the number of classes.
Reference
https://github.com/BloodAxe/pytorch-toolbelt
https://docs.monai.io/en/stable/_modules/monai/losses/dice.html#DiceLoss
"""
-        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
+        if reduction not in LossReduction.available_reductions():
+            raise ValueError(f'Unsupported reduction: {reduction}, '
+                             f'available options are {LossReduction.available_reductions()}.')
         super().__init__()

-        self.mode = mode
-        if classes is not None:
-            assert (
-                mode != BINARY_MODE
-            ), "Masking classes is not supported with mode=binary"
-            classes = to_tensor(classes, dtype=torch.long)
-
         self.classes = classes
         self.from_logits = from_logits
+        self.mask_to_one_hot = mask_to_one_hot
+        self.reduction = reduction
         self.smooth = smooth
         self.eps = eps
-        self.ignore_index = ignore_index
         self.log_loss = log_loss
-        self.ignore_index = ignore_index
+        self.ignore_index = ignore_index

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         assert y_true.size(0) == y_pred.size(0)

-        if self.from_logits:
-            # Apply activations to get [0..1] class probabilities
-            # Using Log-Exp as this gives more numerically stable
-            # result and does not cause vanishing gradient on
-            # extreme values 0 and 1
-            if self.mode == MULTICLASS_MODE:
-                y_pred = y_pred.log_softmax(dim=1).exp()
-            else:
-                y_pred = F.logsigmoid(y_pred).exp()
-
-        bs = y_true.size(0)
-        num_classes = y_pred.size(1)
-        dims = (0, 2)
-
-        if self.mode == BINARY_MODE:
-            y_true = y_true.view(bs, 1, -1)
-            y_pred = y_pred.view(bs, 1, -1)
-
-            if self.ignore_index is not None:
-                mask = y_true != self.ignore_index
-                y_pred = y_pred * mask
-                y_true = y_true * mask
+        batch_size = y_pred.shape[0]
+        num_classes = y_pred.shape[1]
+        spatial_dims: list[int] = torch.arange(2, len(y_pred.shape)).tolist()

-        if self.mode == MULTICLASS_MODE:
-            y_true = y_true.view(bs, -1)
-            y_pred = y_pred.view(bs, num_classes, -1)
+        if self.ignore_index is not None:
+            mask = y_true != self.ignore_index
+            y_pred = y_pred * mask
+            y_true = y_true * mask

-            if self.ignore_index is not None:
-                mask = y_true != self.ignore_index
-                y_pred = y_pred * mask.unsqueeze(1)
+        if self.classes is not None:
+            if num_classes == 1:
+                warnings.warn("Single channel prediction, masking classes is not supported for Binary Segmentation")
+            else:
+                self.classes = to_tensor(self.classes, dtype = torch.long)

-                y_true = F.one_hot(
-                    (y_true * mask).to(torch.long), num_classes
-                )  # N,H*W -> N,H*W, C
-                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
-            else:
-                y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
-                y_true = y_true.permute(0, 2, 1)  # N, C, H*W
+        if self.from_logits:
+            # Convert logits to class probabilities using Sigmoid for Binary Case
+            # and Softmax for multiclass/multilabel cases.
+            # Using log-exp formulation as it is more numerically stable
+            # and does not cause vanishing gradient.
+            if num_classes == 1:
+                y_pred = F.logsigmoid(y_pred).exp()
+            else:
+                y_pred = F.log_softmax(y_pred, dim = 1).exp()

-        if self.mode == MULTILABEL_MODE:
-            y_true = y_true.view(bs, num_classes, -1)
-            y_pred = y_pred.view(bs, num_classes, -1)
-
-            if self.ignore_index is not None:
-                mask = y_true != self.ignore_index
-                y_pred = y_pred * mask
-                y_true = y_true * mask
-
-        scores = soft_jaccard_score(
-            y_pred,
-            y_true.type(y_pred.dtype),
-            smooth=self.smooth,
-            eps=self.eps,
-            dims=dims,
-        )
+        if self.mask_to_one_hot:
+            # Convert y_true to a one-hot representation to compute the loss
+            if num_classes == 1:
+                warnings.warn("Single channel prediction, 'mask_to_one_hot = True' ignored.")
+            else:
+                # maybe there is a better way to handle this?
+                permute_dims = tuple(dim - 1 for dim in spatial_dims)
+                y_true = F.one_hot(y_true, num_classes).squeeze(dim = 1)  # N, 1, H, W, ... ---> N, H, W, ..., C
+                y_true = y_true.permute(0, -1, *permute_dims)  # N, H, W, ..., C ---> N, C, H, W, ...
+
+        if y_true.shape != y_pred.shape:
+            raise AssertionError(f"Ground truth has different shape ({y_true.shape})"
+                                 f" from predicted mask ({y_pred.shape})")
+
+        # Only reduce spatial dimensions
+        scores = soft_jaccard_score(y_pred,
+                                    y_true.type_as(y_pred),
+                                    smooth = self.smooth,
+                                    eps = self.eps,
+                                    dims = spatial_dims)

         if self.log_loss:
             loss = -torch.log(scores.clamp_min(self.eps))
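To make the N-dimensional one-hot conversion above concrete, a toy walk-through with assumed shapes:

import torch
import torch.nn.functional as F

y_true = torch.randint(0, 3, (2, 1, 4, 4))             # (B, 1, H, W) integer mask
spatial_dims = torch.arange(2, y_true.dim()).tolist()  # [2, 3]
permute_dims = tuple(d - 1 for d in spatial_dims)      # (1, 2)

one_hot = F.one_hot(y_true, 3).squeeze(dim=1)    # (B, 1, H, W, C) -> (B, H, W, C)
one_hot = one_hot.permute(0, -1, *permute_dims)  # (B, H, W, C) -> (B, C, H, W)

The permute(0, -1, *permute_dims) call moves the trailing class channel created by F.one_hot back to position 1, which is what soft_jaccard_score expects before reducing over spatial_dims.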
@@ -127,11 +115,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         # So we zero contribution of channel that does not have true pixels
         # NOTE: A better workaround would be to use loss term `mean(y_pred)`
         # for this case, however it will be a modified jaccard loss

-        mask = y_true.sum(dims) > 0
-        loss *= mask.float()
+        ### same as dice loss, should we remove this?
+        # to delete?
+        # dims = tuple(d for d in range(len(y_true.shape)) if d != 1)
+        # mask = y_true.sum(dims) > 0
+        # loss *= mask.to(loss.dtype)

         if self.classes is not None:
-            loss = loss[self.classes]
+            loss = loss[:, self.classes]

-        return loss.mean()
+        if self.reduction == LossReduction.MEAN:
+            loss = torch.mean(loss)
+        elif self.reduction == LossReduction.SUM:
+            loss = torch.sum(loss)
+        elif self.reduction == LossReduction.NONE:
+            broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_true.shape) - 2)
+            loss = loss.view(broadcast_shape)
+
+        return loss
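The .reductions module referenced above is not part of this commit; a plausible minimal sketch of LossReduction, assuming a simple string-valued enum (an assumption, not the repository's actual file):

from enum import Enum

class LossReduction(str, Enum):
    MEAN = 'mean'
    SUM = 'sum'
    NONE = 'none'

    @classmethod
    def available_reductions(cls) -> list[str]:
        return [r.value for r in cls]

And a hypothetical end-to-end call of the rewritten loss. Argument names mirror this diff; the import path and defaults are assumptions:

import torch
from torchseg.losses.jaccard import JaccardLoss

loss_fn = JaccardLoss(classes=[1, 2], mask_to_one_hot=True, reduction='none')
logits = torch.randn(4, 3, 32, 32)          # (B, C, H, W) raw logits
mask = torch.randint(0, 3, (4, 1, 32, 32))  # (B, 1, H, W), one-hot encoded internally
loss = loss_fn(logits, mask)                # unreduced, broadcastable shape (4, 2, 1, 1)

Passing an unknown string such as reduction='avg' would raise the ValueError added in __init__.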
