Skip to content

Commit 7deb2cc

Browse files
authored
Update nacl_loss.py
1 parent d33f435 commit 7deb2cc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

monai/losses/nacl_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
8787
Converts the mask to one hot represenation and applies the spatial filter.
8888
8989
Args:
90-
mask: the shape should be BHW[D]
90+
mask: the shape should be BH[WD].
9191
9292
Returns:
93-
torch.Tensor: the shape would be BNHW[D], N being number of classes.
93+
torch.Tensor: the shape would be BNH[WD], N being number of classes.
9494
"""
9595
rmask: torch.Tensor
9696

@@ -109,8 +109,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
109109
Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
110110
111111
Args:
112-
inputs: the shape should be BNHW[D], where N is the number of classes.
113-
targets: the shape should be BHW[D].
112+
inputs: the shape should be BNH[WD], where N is the number of classes.
113+
targets: the shape should be BH[WD].
114114
115115
Returns:
116116
torch.Tensor: value of the loss.
@@ -122,7 +122,7 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
122122
>>> input = torch.rand(B, N, H, W)
123123
>>> target = torch.randint(0, N, (B, H, W))
124124
>>> criterion = NACLLoss(classes = N, dim = 2)
125-
>>> loss = self(input, target)
125+
>>> loss = criterion(input, target)
126126
"""
127127

128128
loss_ce = self.cross_entropy(inputs, targets)

0 commit comments

Comments
 (0)