Skip to content

Commit 3f74183

Browse files
committed
Add helper function
1 parent cfd2d1e commit 3f74183

File tree

3 files changed

+88
-31
lines changed

3 files changed

+88
-31
lines changed

monai/losses/dice.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818
import numpy as np
1919
import torch
20-
import torch.linalg as LA
2120
import torch.nn as nn
2221
import torch.nn.functional as F
2322
from torch.nn.modules.loss import _Loss
2423

2524
from monai.losses.focal_loss import FocalLoss
2625
from monai.losses.spatial_mask import MaskedLoss
26+
from monai.losses.utils import compute_tp_fp_fn
2727
from monai.networks import one_hot
2828
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
2929

@@ -67,6 +67,7 @@ def __init__(
6767
smooth_dr: float = 1e-5,
6868
batch: bool = False,
6969
weight: Sequence[float] | float | int | torch.Tensor | None = None,
70+
soft_label: bool = False,
7071
) -> None:
7172
"""
7273
Args:
@@ -98,6 +99,7 @@ def __init__(
9899
of the sequence should be the same as the number of classes. If not ``include_background``,
99100
the number of classes should not include the background category class 0).
100101
The value/values should be no less than 0. Defaults to None.
102+
soft_label: whether the target contains non-binary values or not
101103
102104
Raises:
103105
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -123,6 +125,7 @@ def __init__(
123125
weight = torch.as_tensor(weight) if weight is not None else None
124126
self.register_buffer("class_weight", weight)
125127
self.class_weight: None | torch.Tensor
128+
self.soft_label = soft_label
126129

127130
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
128131
"""
@@ -183,22 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
183186
# reducing spatial dimensions and batch
184187
reduce_axis = [0] + reduce_axis
185188

186-
if self.squared_pred:
187-
ground_o = torch.sum(target**2, dim=reduce_axis)
188-
pred_o = torch.sum(input**2, dim=reduce_axis)
189-
difference = LA.vector_norm(input - target, ord=2, dim=reduce_axis) ** 2
190-
else:
191-
ground_o = torch.sum(target, dim=reduce_axis)
192-
pred_o = torch.sum(input, dim=reduce_axis)
193-
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
194-
195-
denominator = ground_o + pred_o
196-
intersection = (denominator - difference) / 2
197-
198-
if self.jaccard:
199-
denominator = 2.0 * (denominator - intersection)
189+
ord = 2 if self.squared_pred else 1
190+
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
191+
if not self.jaccard:
192+
fp *= 0.5
193+
fn *= 0.5
194+
numerator = 2 * tp + self.smooth_nr
195+
denominator = 2 * (tp + fp + fn) + self.smooth_dr
200196

201-
f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
197+
f = 1 - numerator / denominator
202198

203199
num_of_classes = target.shape[1]
204200
if self.class_weight is not None and num_of_classes != 1:
@@ -282,6 +278,7 @@ def __init__(
282278
smooth_nr: float = 1e-5,
283279
smooth_dr: float = 1e-5,
284280
batch: bool = False,
281+
soft_label: bool = False,
285282
) -> None:
286283
"""
287284
Args:
@@ -305,6 +302,7 @@ def __init__(
305302
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
306303
Defaults to False, intersection over union is computed from each item in the batch.
307304
If True, the class-weighted intersection and union areas are first summed across the batches.
305+
soft_label: whether the target contains non-binary values or not
308306
309307
Raises:
310308
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -329,6 +327,7 @@ def __init__(
329327
self.smooth_nr = float(smooth_nr)
330328
self.smooth_dr = float(smooth_dr)
331329
self.batch = batch
330+
self.soft_label = soft_label
332331

333332
def w_func(self, grnd):
334333
if self.w_type == str(Weight.SIMPLE):
@@ -381,13 +380,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
381380
if self.batch:
382381
reduce_axis = [0] + reduce_axis
383382

384-
ground_o = torch.sum(target, reduce_axis)
385-
pred_o = torch.sum(input, reduce_axis)
386-
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
387-
388-
denominator = ground_o + pred_o
389-
intersection = (denominator - difference) / 2
383+
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
384+
fp *= 0.5
385+
fn *= 0.5
386+
denominator = 2 * (tp + fp + fn)
390387

388+
ground_o = torch.sum(target, reduce_axis)
391389
w = self.w_func(ground_o.float())
392390
infs = torch.isinf(w)
393391
if self.batch:
@@ -399,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
399397
w = w + infs * max_values
400398

401399
final_reduce_dim = 0 if self.batch else 1
402-
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
400+
numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
403401
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
404402
f: torch.Tensor = 1.0 - (numer / denom)
405403

monai/losses/tversky.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from collections.abc import Callable
1616

1717
import torch
18-
import torch.linalg as LA
1918
from torch.nn.modules.loss import _Loss
2019

20+
from monai.losses.utils import compute_tp_fp_fn
2121
from monai.networks import one_hot
2222
from monai.utils import LossReduction
2323

@@ -50,6 +50,7 @@ def __init__(
5050
smooth_nr: float = 1e-5,
5151
smooth_dr: float = 1e-5,
5252
batch: bool = False,
53+
soft_label: bool = False,
5354
) -> None:
5455
"""
5556
Args:
@@ -74,6 +75,7 @@ def __init__(
7475
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
7576
Defaults to False, a Dice loss value is computed independently from each item in the batch
7677
before any `reduction`.
78+
soft_label: whether the target contains non-binary values or not
7779
7880
Raises:
7981
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -97,6 +99,7 @@ def __init__(
9799
self.smooth_nr = float(smooth_nr)
98100
self.smooth_dr = float(smooth_dr)
99101
self.batch = batch
102+
self.soft_label = soft_label
100103

101104
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
102105
"""
@@ -144,13 +147,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
144147
# reducing spatial dimensions and batch
145148
reduce_axis = [0] + reduce_axis
146149

147-
pred_o = torch.sum(input, reduce_axis)
148-
ground_o = torch.sum(target, reduce_axis)
149-
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
150-
151-
tp = (pred_o + ground_o - difference) / 2
152-
fp = self.alpha * (pred_o - tp)
153-
fn = self.beta * (ground_o - tp)
150+
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
151+
fp *= self.alpha
152+
fn *= self.beta
154153
numerator = tp + self.smooth_nr
155154
denominator = tp + fp + fn + self.smooth_dr
156155

monai/losses/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import warnings
15+
16+
import torch
17+
import torch.linalg as LA
18+
19+
20+
def compute_tp_fp_fn(
    input: torch.Tensor,
    target: torch.Tensor,
    reduce_axis: list[int],
    ord: int,
    soft_label: bool,
    decoupled: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the true-positive, false-positive and false-negative terms shared by the
    Dice, Jaccard and Tversky losses, with optional support for soft (non-binary) labels.

    Adapted from:
        https://github.com/zifuwanggg/JDTLosses

    Args:
        input: the prediction tensor (typically probabilities in ``[0, 1]``).
        target: the ground-truth tensor; expected binary unless ``soft_label=True``.
        reduce_axis: the dimensions over which sums/norms are reduced.
        ord: norm order used by the soft-label formulation; ``1`` for plain
            Dice/Tversky, ``2`` for the squared-prediction Dice variant.
        soft_label: whether the target contains non-binary values.
        decoupled: if True, derive fp/fn from the marginal sums minus tp
            (Dice/Jaccard form); if False, compute them directly from products
            with the complements (Tversky form). With binary labels both forms agree.

    Returns:
        A tuple ``(tp, fp, fn)`` of tensors reduced over ``reduce_axis``.
    """
    # Heuristic sanity check: more than two distinct target values means the
    # labels are soft, so the caller should have passed ``soft_label=True``.
    # NOTE(review): torch.unique forces a device sync on CUDA tensors; presumably
    # acceptable here since it only guards a warning — confirm for hot paths.
    if torch.unique(target).shape[0] > 2 and not soft_label:
        warnings.warn("soft labels are used, but `soft_label == False`.")

    if ord == 1 and not soft_label:
        # Original binary-label formulation: tp is the plain overlap.
        tp = torch.sum(input * target, dim=reduce_axis)
        if decoupled:
            # Dice/Jaccard form: fp/fn recovered from the marginal sums.
            fp = torch.sum(input, dim=reduce_axis) - tp
            fn = torch.sum(target, dim=reduce_axis) - tp
        else:
            # Tversky form: direct products with the complements.
            fp = torch.sum(input * (1 - target), dim=reduce_axis)
            fn = torch.sum((1 - input) * target, dim=reduce_axis)
    else:
        # Soft-label-safe formulation via vector norms (JDT losses).
        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)

        if ord > 1:
            # Undo the norm's root so each term is a sum of |.|**ord,
            # matching the squared-prediction Dice terms.
            pred_o = torch.pow(pred_o, exponent=ord)
            ground_o = torch.pow(ground_o, exponent=ord)
            difference = torch.pow(difference, exponent=ord)

        # Inclusion–exclusion: |a| + |b| - |a - b| = 2 * overlap.
        tp = (pred_o + ground_o - difference) / 2
        fp = pred_o - tp
        fn = ground_o - tp

    return tp, fp, fn

0 commit comments

Comments
 (0)