GeneralizedDiceLoss is not differentiable #3618
Description
Describe the bug
The code used for weight clamping in the forward method of the GeneralizedDiceLoss module is not differentiable:
w = self.w_func(ground_o.float())
for b in w:
    infs = torch.isinf(b)
    b[infs] = 0.0
    b[infs] = torch.max(b)
When computing the loss gradients, PyTorch throws the following error:
RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
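The mechanism behind this error can be isolated without MONAI. The following is a minimal sketch, not the library code: iterating over a tensor goes through unbind, which returns several views at once, and autograd rejects an in-place write to such a view when the tensor is part of the graph. The names x, w, and raised are illustrative only.

```python
import torch

# A non-leaf tensor tracked by autograd, standing in for the weights `w`.
x = torch.ones(2, 3, requires_grad=True)
w = x * 2.0

raised = False
try:
    for b in w:          # iterating a tensor unbinds it into row views
        mask = b > 1.0
        b[mask] = 0.0    # in-place write to a view produced by unbind
except RuntimeError as err:
    raised = True
    print(err)
```

In this sketch the exception is raised at the in-place assignment itself, before any call to backward().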
To Reproduce
The bug can be reproduced using the following Python script:
import torch
from monai.losses.dice import GeneralizedDiceLoss
prediction = torch.zeros(2, 3, 3, 3)
target = torch.zeros(2, 3, 3, 3)
prediction.requires_grad = True
target.requires_grad = True
generalized_dice_loss = GeneralizedDiceLoss()
loss = generalized_dice_loss(prediction, target)
Note that the error does not occur if requires_grad is set to False for prediction and target.
Expected behavior
The forward method of the GeneralizedDiceLoss module should be differentiable.
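One possible out-of-place replacement for the clamping loop is sketched below. It is an illustration under stated assumptions, not necessarily the fix adopted in MONAI: the helper name clamp_inf_weights is hypothetical, and w is assumed to be a 2-D tensor of shape (batch, classes). It mirrors the original logic (infinite entries are zeroed, then set to the maximum of their row) using torch.where instead of indexed assignment.

```python
import torch


def clamp_inf_weights(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: replace inf entries with the per-row finite maximum.

    Assumes `w` has shape (batch, classes). No tensor is modified in place.
    """
    infs = torch.isinf(w)
    # Zero out the infinite entries so they do not dominate the row maximum.
    w_finite = torch.where(infs, torch.zeros_like(w), w)
    # Largest remaining value in each row, kept as a column for broadcasting.
    row_max = w_finite.max(dim=1, keepdim=True).values
    # Put the row maximum wherever the original value was infinite.
    return torch.where(infs, row_max.expand_as(w), w_finite)


w = torch.tensor(
    [[1.0, float("inf"), 3.0], [float("inf"), float("inf"), 0.5]],
    requires_grad=True,
)
out = clamp_inf_weights(w * 1.0)
print(out)  # values: [[1.0, 3.0, 3.0], [0.5, 0.5, 0.5]]
out.sum().backward()  # no RuntimeError: nothing was written in place
```

Because every step returns a new tensor, autograd can trace the operation even when the input requires gradients.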
Environment
Ensuring you use the relevant python executable, please paste the output of:
python -c 'import monai; monai.config.print_debug_info()'
================================
Printing MONAI config...
================================
MONAI version: 0.8.0
Numpy version: 1.22.0
Pytorch version: 1.10.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 714d00dffe6653e21260160666c4c201ab66511b
Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
================================
Printing system config...
================================
`psutil` required for `print_system_info`
================================
Printing GPU config...
================================
Num GPUs: 0
Has CUDA: False
cuDNN enabled: False
Additional context