Skip to content

Commit dccde47

Browse files
committed
Add mean loss and resolve formatting
1 parent 0067953 commit dccde47

File tree

2 files changed

+249
-119
lines changed

2 files changed

+249
-119
lines changed

monai/losses/segcalib.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import math
15+
import warnings
1516

1617
import torch
1718
import torch.nn as nn
@@ -21,6 +22,16 @@
2122
from monai.utils import pytorch_after
2223

2324

25+
def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor:
26+
mean_kernel = torch.ones([ksize, ksize]) / (ksize**2)
27+
return mean_kernel
28+
29+
30+
def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor:
31+
mean_kernel = torch.ones([ksize, ksize, ksize]) / (ksize**3)
32+
return mean_kernel
33+
34+
2435
def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor:
2536
x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize)
2637
y_grid = x_grid.t()
@@ -101,6 +112,50 @@ def forward(self, x):
101112
return self.svls_layer(x) / self.svls_kernel.sum()
102113

103114

115+
class MeanFilter(torch.nn.Module):
116+
def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor:
117+
super(MeanFilter, self).__init__()
118+
119+
if dim == 2:
120+
self.svls_kernel = get_mean_kernel_2d(ksize=ksize)
121+
svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize)
122+
svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1)
123+
padding = int(ksize / 2)
124+
125+
self.svls_layer = torch.nn.Conv2d(
126+
in_channels=channels,
127+
out_channels=channels,
128+
kernel_size=ksize,
129+
groups=channels,
130+
bias=False,
131+
padding=padding,
132+
padding_mode="replicate",
133+
)
134+
self.svls_layer.weight.data = svls_kernel_2d
135+
self.svls_layer.weight.requires_grad = False
136+
137+
if dim == 3:
138+
self.svls_kernel = get_mean_kernel_3d(ksize=ksize)
139+
svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize)
140+
svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1)
141+
padding = int(ksize / 2)
142+
143+
self.svls_layer = torch.nn.Conv3d(
144+
in_channels=channels,
145+
out_channels=channels,
146+
kernel_size=ksize,
147+
groups=channels,
148+
bias=False,
149+
padding=padding,
150+
padding_mode="replicate",
151+
)
152+
self.svls_layer.weight.data = svls_kernel_3d
153+
self.svls_layer.weight.requires_grad = False
154+
155+
def forward(self, x):
156+
return self.svls_layer(x) / self.svls_kernel.sum()
157+
158+
104159
class NACLLoss(_Loss):
105160
"""
106161
Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
@@ -118,6 +173,7 @@ def __init__(
118173
classes: int,
119174
dim: int,
120175
kernel_size: int = 3,
176+
kernel_ops: str = "mean",
121177
distance_type: str = "l1",
122178
alpha: float = 0.1,
123179
sigma: float = 1.0,
@@ -133,6 +189,9 @@ def __init__(
133189

134190
super().__init__()
135191

192+
if kernel_ops not in ["mean", "gaussian"]:
193+
raise ValueError("Kernel ops must be either mean or gaussian")
194+
136195
if dim not in [2, 3]:
137196
raise ValueError("Supoorts 2d and 3d")
138197

@@ -146,7 +205,10 @@ def __init__(
146205
self.alpha = alpha
147206
self.ks = kernel_size
148207

149-
self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)
208+
if kernel_ops == "mean":
209+
self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes)
210+
if kernel_ops == "gaussian":
211+
self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)
150212

151213
self.old_pt_ver = not pytorch_after(1, 10)
152214

@@ -173,24 +235,16 @@ def __init__(
173235
# return self.cross_entropy(input, target) # type: ignore[no-any-return]
174236

175237
def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
176-
177238
if self.dim == 2:
178-
179-
oh_labels = (
180-
F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
181-
)
239+
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
182240
rmask = self.svls_layer(oh_labels)
183241

184242
if self.dim == 3:
185-
186-
oh_labels = (
187-
F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
188-
)
243+
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
189244
rmask = self.svls_layer(oh_labels)
190245

191246
return rmask
192247

193-
194248
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
195249
loss_ce = self.cross_entropy(inputs, targets)
196250

0 commit comments

Comments
 (0)