1212from __future__ import annotations
1313
1414import math
15+ import warnings
1516
1617import torch
1718import torch .nn as nn
2122from monai .utils import pytorch_after
2223
2324
25+ def get_mean_kernel_2d (ksize : int = 3 ) -> torch .Tensor :
26+ mean_kernel = torch .ones ([ksize , ksize ]) / (ksize ** 2 )
27+ return mean_kernel
28+
29+
30+ def get_mean_kernel_3d (ksize : int = 3 ) -> torch .Tensor :
31+ mean_kernel = torch .ones ([ksize , ksize , ksize ]) / (ksize ** 3 )
32+ return mean_kernel
33+
34+
2435def get_gaussian_kernel_2d (ksize : int = 3 , sigma : float = 1.0 ) -> torch .Tensor :
2536 x_grid = torch .arange (ksize ).repeat (ksize ).view (ksize , ksize )
2637 y_grid = x_grid .t ()
@@ -101,6 +112,50 @@ def forward(self, x):
101112 return self .svls_layer (x ) / self .svls_kernel .sum ()
102113
103114
115+ class MeanFilter (torch .nn .Module ):
116+ def __init__ (self , dim : int = 3 , ksize : int = 3 , channels : int = 0 ) -> torch .Tensor :
117+ super (MeanFilter , self ).__init__ ()
118+
119+ if dim == 2 :
120+ self .svls_kernel = get_mean_kernel_2d (ksize = ksize )
121+ svls_kernel_2d = self .svls_kernel .view (1 , 1 , ksize , ksize )
122+ svls_kernel_2d = svls_kernel_2d .repeat (channels , 1 , 1 , 1 )
123+ padding = int (ksize / 2 )
124+
125+ self .svls_layer = torch .nn .Conv2d (
126+ in_channels = channels ,
127+ out_channels = channels ,
128+ kernel_size = ksize ,
129+ groups = channels ,
130+ bias = False ,
131+ padding = padding ,
132+ padding_mode = "replicate" ,
133+ )
134+ self .svls_layer .weight .data = svls_kernel_2d
135+ self .svls_layer .weight .requires_grad = False
136+
137+ if dim == 3 :
138+ self .svls_kernel = get_mean_kernel_3d (ksize = ksize )
139+ svls_kernel_3d = self .svls_kernel .view (1 , 1 , ksize , ksize )
140+ svls_kernel_3d = svls_kernel_3d .repeat (channels , 1 , 1 , 1 )
141+ padding = int (ksize / 2 )
142+
143+ self .svls_layer = torch .nn .Conv3d (
144+ in_channels = channels ,
145+ out_channels = channels ,
146+ kernel_size = ksize ,
147+ groups = channels ,
148+ bias = False ,
149+ padding = padding ,
150+ padding_mode = "replicate" ,
151+ )
152+ self .svls_layer .weight .data = svls_kernel_3d
153+ self .svls_layer .weight .requires_grad = False
154+
155+ def forward (self , x ):
156+ return self .svls_layer (x ) / self .svls_kernel .sum ()
157+
158+
104159class NACLLoss (_Loss ):
105160 """
106161 Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
@@ -118,6 +173,7 @@ def __init__(
118173 classes : int ,
119174 dim : int ,
120175 kernel_size : int = 3 ,
176+ kernel_ops : str = "mean" ,
121177 distance_type : str = "l1" ,
122178 alpha : float = 0.1 ,
123179 sigma : float = 1.0 ,
@@ -133,6 +189,9 @@ def __init__(
133189
134190 super ().__init__ ()
135191
192+ if kernel_ops not in ["mean" , "gaussian" ]:
193+ raise ValueError ("Kernel ops must be either mean or gaussian" )
194+
136195 if dim not in [2 , 3 ]:
137196 raise ValueError ("Supoorts 2d and 3d" )
138197
@@ -146,7 +205,10 @@ def __init__(
146205 self .alpha = alpha
147206 self .ks = kernel_size
148207
149- self .svls_layer = GaussianFilter (dim = dim , ksize = kernel_size , sigma = sigma , channels = classes )
208+ if kernel_ops == "mean" :
209+ self .svls_layer = MeanFilter (dim = dim , ksize = kernel_size , channels = classes )
210+ if kernel_ops == "gaussian" :
211+ self .svls_layer = GaussianFilter (dim = dim , ksize = kernel_size , sigma = sigma , channels = classes )
150212
151213 self .old_pt_ver = not pytorch_after (1 , 10 )
152214
@@ -173,24 +235,16 @@ def __init__(
173235 # return self.cross_entropy(input, target) # type: ignore[no-any-return]
174236
175237 def get_constr_target (self , mask : torch .Tensor ) -> torch .Tensor :
176-
177238 if self .dim == 2 :
178-
179- oh_labels = (
180- F .one_hot (mask .to (torch .int64 ), num_classes = self .nc ).contiguous ().permute (0 , 3 , 1 , 2 ).float ()
181- )
239+ oh_labels = F .one_hot (mask .to (torch .int64 ), num_classes = self .nc ).contiguous ().permute (0 , 3 , 1 , 2 ).float ()
182240 rmask = self .svls_layer (oh_labels )
183241
184242 if self .dim == 3 :
185-
186- oh_labels = (
187- F .one_hot (mask .to (torch .int64 ), num_classes = self .nc ).contiguous ().permute (0 , 4 , 1 , 2 , 3 ).float ()
188- )
243+ oh_labels = F .one_hot (mask .to (torch .int64 ), num_classes = self .nc ).contiguous ().permute (0 , 4 , 1 , 2 , 3 ).float ()
189244 rmask = self .svls_layer (oh_labels )
190245
191246 return rmask
192247
193-
194248 def forward (self , inputs : torch .Tensor , targets : torch .Tensor ) -> torch .Tensor :
195249 loss_ce = self .cross_entropy (inputs , targets )
196250
0 commit comments