1717
1818import numpy as np
1919import torch
20- import torch .linalg as LA
2120import torch .nn as nn
2221import torch .nn .functional as F
2322from torch .nn .modules .loss import _Loss
2423
2524from monai .losses .focal_loss import FocalLoss
2625from monai .losses .spatial_mask import MaskedLoss
26+ from monai .losses .utils import compute_tp_fp_fn
2727from monai .networks import one_hot
2828from monai .utils import DiceCEReduction , LossReduction , Weight , look_up_option , pytorch_after
2929
@@ -67,6 +67,7 @@ def __init__(
6767 smooth_dr : float = 1e-5 ,
6868 batch : bool = False ,
6969 weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
70+ soft_label : bool = False ,
7071 ) -> None :
7172 """
7273 Args:
@@ -98,6 +99,7 @@ def __init__(
9899 of the sequence should be the same as the number of classes. If not ``include_background``,
99100 the number of classes should not include the background category class 0).
100101 The value/values should be no less than 0. Defaults to None.
102+ soft_label: whether the target contains non-binary values or not
101103
102104 Raises:
103105 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -123,6 +125,7 @@ def __init__(
123125 weight = torch .as_tensor (weight ) if weight is not None else None
124126 self .register_buffer ("class_weight" , weight )
125127 self .class_weight : None | torch .Tensor
128+ self .soft_label = soft_label
126129
127130 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
128131 """
@@ -183,22 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
183186 # reducing spatial dimensions and batch
184187 reduce_axis = [0 ] + reduce_axis
185188
186- if self .squared_pred :
187- ground_o = torch .sum (target ** 2 , dim = reduce_axis )
188- pred_o = torch .sum (input ** 2 , dim = reduce_axis )
189- difference = LA .vector_norm (input - target , ord = 2 , dim = reduce_axis ) ** 2
190- else :
191- ground_o = torch .sum (target , dim = reduce_axis )
192- pred_o = torch .sum (input , dim = reduce_axis )
193- difference = LA .vector_norm (input - target , ord = 1 , dim = reduce_axis )
194-
195- denominator = ground_o + pred_o
196- intersection = (denominator - difference ) / 2
197-
198- if self .jaccard :
199- denominator = 2.0 * (denominator - intersection )
189+ ord = 2 if self .squared_pred else 1
190+ tp , fp , fn = compute_tp_fp_fn (input , target , reduce_axis , ord , self .soft_label )
191+ if not self .jaccard :
192+ fp *= 0.5
193+ fn *= 0.5
194+ numerator = 2 * tp + self .smooth_nr
195+ denominator = 2 * (tp + fp + fn ) + self .smooth_dr
200196
201- f : torch . Tensor = 1.0 - ( 2.0 * intersection + self . smooth_nr ) / ( denominator + self . smooth_dr )
197+ f = 1 - numerator / denominator
202198
203199 num_of_classes = target .shape [1 ]
204200 if self .class_weight is not None and num_of_classes != 1 :
@@ -282,6 +278,7 @@ def __init__(
282278 smooth_nr : float = 1e-5 ,
283279 smooth_dr : float = 1e-5 ,
284280 batch : bool = False ,
281+ soft_label : bool = False ,
285282 ) -> None :
286283 """
287284 Args:
@@ -305,6 +302,7 @@ def __init__(
305302 batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
306303 Defaults to False, intersection over union is computed from each item in the batch.
307304 If True, the class-weighted intersection and union areas are first summed across the batches.
305+ soft_label: whether the target contains non-binary values or not
308306
309307 Raises:
310308 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -329,6 +327,7 @@ def __init__(
329327 self .smooth_nr = float (smooth_nr )
330328 self .smooth_dr = float (smooth_dr )
331329 self .batch = batch
330+ self .soft_label = soft_label
332331
333332 def w_func (self , grnd ):
334333 if self .w_type == str (Weight .SIMPLE ):
@@ -381,13 +380,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
381380 if self .batch :
382381 reduce_axis = [0 ] + reduce_axis
383382
384- ground_o = torch .sum (target , reduce_axis )
385- pred_o = torch .sum (input , reduce_axis )
386- difference = LA .vector_norm (input - target , ord = 1 , dim = reduce_axis )
387-
388- denominator = ground_o + pred_o
389- intersection = (denominator - difference ) / 2
383+ tp , fp , fn = compute_tp_fp_fn (input , target , reduce_axis , 1 , self .soft_label )
384+ fp *= 0.5
385+ fn *= 0.5
386+ denominator = 2 * (tp + fp + fn )
390387
388+ ground_o = torch .sum (target , reduce_axis )
391389 w = self .w_func (ground_o .float ())
392390 infs = torch .isinf (w )
393391 if self .batch :
@@ -399,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
399397 w = w + infs * max_values
400398
401399 final_reduce_dim = 0 if self .batch else 1
402- numer = 2.0 * (intersection * w ).sum (final_reduce_dim , keepdim = True ) + self .smooth_nr
400+ numer = 2.0 * (tp * w ).sum (final_reduce_dim , keepdim = True ) + self .smooth_nr
403401 denom = (denominator * w ).sum (final_reduce_dim , keepdim = True ) + self .smooth_dr
404402 f : torch .Tensor = 1.0 - (numer / denom )
405403
0 commit comments