-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
36 lines (27 loc) · 987 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# File to store loss functions used to train UNet
import numpy as np
import torch
def get_dice_per_class(pred, target, smooth=1e-3):
    """Return the soft Dice coefficient for each class as a 1-D tensor.

    Dimension 1 of ``pred``/``target`` is treated as the class axis; every
    other dimension (batch and any spatial dims) is pooled together, so each
    class gets a single score computed over the whole batch.

    pred: tensor shaped (batch, n_classes, ...); real values are allowed.
    target: tensor expected to have the same shape as ``pred``.
    smooth: additive constant in both numerator and denominator. It keeps the
        score finite when a class is absent from both ``pred`` and ``target``
        (the score is then exactly 1.0). Defaults to the original 1e-3.
    Returns a tensor of shape (n_classes,).
    """
    # Move the class axis to the front, then collapse everything else so each
    # row holds all values for one class: shape (n_classes, batch * ...).
    pred_flat = torch.flatten(torch.swapaxes(pred, 0, 1), start_dim=1)
    target_flat = torch.flatten(torch.swapaxes(target, 0, 1), start_dim=1)

    intersection = (pred_flat * target_flat).sum(dim=1)
    total = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    return (2 * intersection + smooth) / (total + smooth)
def dice_coeff(pred, target):
    """Return the mean soft Dice coefficient across classes.

    This definition generalizes to real-valued ``pred`` and ``target`` and is
    built only from tensor operations, so it should be differentiable.

    pred: tensor with the batch as its first dimension and classes as its
        second (the axis ``get_dice_per_class`` scores per class).
    target: tensor laid out like ``pred``.
    Returns a 0-d tensor: the unweighted average of the per-class Dice scores.
    """
    per_class = get_dice_per_class(pred, target)
    # Unweighted mean: every class contributes equally to the final score.
    return per_class.mean(dim=0)