@@ -87,10 +87,10 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
8787 Converts the mask to one hot representation and applies the spatial filter.
8888
8989 Args:
90- mask: the shape should be BHW[D]
90+ mask: the shape should be BH[WD].
9191
9292 Returns:
93- torch.Tensor: the shape would be BNHW[D], N being number of classes.
93+ torch.Tensor: the shape would be BNH[WD], N being number of classes.
9494 """
9595 rmask: torch.Tensor
9696
@@ -109,8 +109,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
109109 Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
110110
111111 Args:
112- inputs: the shape should be BNHW[D], where N is the number of classes.
113- targets: the shape should be BHW[D].
112+ inputs: the shape should be BNH[WD], where N is the number of classes.
113+ targets: the shape should be BH[WD ].
114114
115115 Returns:
116116 torch.Tensor: value of the loss.
@@ -122,7 +122,7 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
122122 >>> input = torch.rand(B, N, H, W)
123123 >>> target = torch.randint(0, N, (B, H, W))
124124 >>> criterion = NACLLoss(classes = N, dim = 2)
125- >>> loss = self(input, target)
125+ >>> loss = criterion(input, target)
126126 """
127127
128128 loss_ce = self.cross_entropy(inputs, targets)
0 commit comments