
Issues trying to mask DiceCELoss #3917

@demcbs

Description

Goal
I am trying to segment two objects in 3D images with a single network. However, the mask for one of the objects is sometimes missing. In that case I want to ignore the loss for that object.

Method
To do this, I set the outputs equal to the masks for the missing object, so that its loss is 0. I then multiply by a correction factor for the loss averaging, so that these perfect predictions do not make the loss too optimistic.
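A minimal sketch of the trick and its correction factor. It uses `torch.nn.functional.binary_cross_entropy` in place of the NumPy BCE from the reproduction code and small toy tensors; both are illustrative assumptions. PyTorch's BCE clamps its log terms, so a channel whose predictions equal its 0/1 masks contributes exactly zero. Here the correction factor is the total number of elements divided by the number of elements in channels that still have masks. With two equally sized channels and one of them missing, it is 2, the same value as `trick_preds.numel()/object2_masks.numel()` in the issue's code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy batch: 4 samples, 2 object channels, small 3D volume.
preds = torch.rand(4, 2, 8, 8, 4)                      # probabilities in [0, 1)
masks = torch.randint(0, 2, (4, 2, 8, 8, 4)).float()
missing = 1                                            # channel whose masks are "missing"

# The trick: copy the masks into the predictions of the missing channel,
# so that channel contributes zero BCE.
trick_preds = preds.clone()
trick_preds[:, missing] = masks[:, missing]

combined = F.binary_cross_entropy(trick_preds, masks)        # mean over all elements
present = F.binary_cross_entropy(preds[:, 0], masks[:, 0])   # loss of the object with masks

# Correction factor: total elements / elements in channels that still have masks.
correction = trick_preds.numel() / preds[:, 0].numel()
corrected = combined * correction

print(correction, present.item(), corrected.item())
```

Because every channel has the same number of elements, the mean over the whole tensor is the average of the per-channel means. Multiplying by the correction factor therefore recovers the loss of the channel that has masks.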

Possibly related: before using this trick, I tried to work with the unreduced output of the loss function (`reduction='none'`, as in the last line of the code below) instead. However, DiceCELoss throws an error because the shapes of the Dice and CE parts do not match when they are summed.
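A plain-PyTorch sketch of this kind of shape mismatch. It rests on an unconfirmed assumption about DiceCELoss's internals: that the unreduced Dice term keeps one value per sample and channel, while the unreduced CE term, like `torch.nn.CrossEntropyLoss` with `reduction="none"`, keeps one value per voxel and drops the channel dimension. The soft Dice below is written inline for illustration and is not MONAI's implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 2, 8, 8, 4)                     # (B, C, H, W, D)
masks = torch.randint(0, 2, (4, 2, 8, 8, 4)).float()

# Unreduced soft Dice, reduced over the spatial dims only: one value per (sample, channel).
probs = torch.sigmoid(logits)
spatial = (2, 3, 4)
dice = 1 - 2 * (probs * masks).sum(spatial) / (probs.sum(spatial) + masks.sum(spatial))

# Unreduced softmax CE over the channels: one value per voxel, channel dim removed.
ce = F.cross_entropy(logits, masks.argmax(dim=1), reduction="none")

print(dice.shape, ce.shape)  # torch.Size([4, 2]) torch.Size([4, 8, 8, 4])

# (4, 2) and (4, 8, 8, 4) cannot be broadcast together, so the sum fails.
try:
    total = dice + ce
    summed = True
except RuntimeError as err:
    summed = False
    print("cannot sum:", err)
```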

Expected behavior

Note: this should be fully reproducible with the code below; the prints contain the necessary information.

Consider a batch in which all masks for one object are missing. If I split the batch along the channel dimension and calculate the loss for each object separately, the object with masks should yield some value x, and the object with missing masks should yield 0 after applying the trick described above. If the trick works properly, the loss calculated over the whole batch without splitting it first should then also be x, once the correction factor is applied.

I find that this holds when I use DiceLoss. It also holds for a basic binary cross-entropy calculation (apart from a small offset that avoids ln(0)). With DiceCELoss, however, it does not hold, and I am not sure why. The call that triggers the DiceCELoss summing error is included at the end.
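One possible explanation, offered as an assumption rather than a confirmed diagnosis: the DiceCELoss documentation refers to `torch.nn.CrossEntropyLoss` for its CE term, which is a softmax cross-entropy across channels, not a per-channel binary cross-entropy. If so, the CE term computed on two channels together is a different quantity from the CE terms computed on each channel alone, and on a single channel it is always zero. The sketch below shows this with plain PyTorch; converting the masks to class indices with `argmax` is an illustrative choice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 2, 8, 8, 4)
masks = torch.randint(0, 2, (4, 2, 8, 8, 4)).float()

# Softmax CE over the two channels, with a class-index target (argmax over channels).
ce_two_channel = F.cross_entropy(logits, masks.argmax(dim=1))

# The same CE on one channel at a time: with a single class, log_softmax is
# identically 0, so the loss is 0 regardless of the predictions.
ce_single = F.cross_entropy(logits[:, 0:1], torch.zeros(4, 8, 8, 4, dtype=torch.long))

# Per-channel binary CE averaged over channels: what the split comparison implicitly assumes.
bce_mean = F.binary_cross_entropy_with_logits(logits, masks)

print(ce_two_channel.item(), ce_single.item(), bce_mean.item())
```

Under this assumption, the split Dice+CE values carry no CE contribution while the combined value does, so the split and combined numbers cannot be expected to agree the way they do for DiceLoss or per-channel BCE.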

```python
from monai.losses import DiceLoss, DiceCELoss
import torch
import numpy as np

def BinaryCrossEntropy(y_pred, y_true):
    # Plain mean binary cross-entropy; the 1e-7 offset avoids ln(0).
    # Argument order (prediction, target) matches the calls loss_function(preds, masks) below.
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    term_0 = (1 - y_true) * np.log(1 - y_pred + 1e-7)
    term_1 = y_true * np.log(y_pred + 1e-7)
    return -np.mean(term_0 + term_1)

def loss_split_and_combined(preds, masks, loss_function, trick_multiplier=None):
    # Compute the loss per object (channel) and over both channels together.
    object1_preds = preds[:, 0:1, ...]
    object2_preds = preds[:, 1:2, ...]

    object1_masks = masks[:, 0:1, ...]
    object2_masks = masks[:, 1:2, ...]

    if trick_multiplier is None:
        print("Plain: ")
        print("Object 1, Object 2, Combined")
        print(loss_function(object1_preds, object1_masks), loss_function(object2_preds, object2_masks), loss_function(preds, masks))
    else:
        # Rescale the combined loss to undo averaging over the zero-loss channel.
        print("After trick: ")
        print("Object 1, Object 2, Combined")
        print(loss_function(object1_preds, object1_masks), loss_function(object2_preds, object2_masks), loss_function(preds, masks) * trick_multiplier)

preds = torch.rand(4, 2, 160, 160, 16)
masks = torch.randint(high=2, size=(4, 2, 160, 160, 16))
masks[:, 1:2, ...] = 0  # object 2 masks are "missing" for the whole batch

# This is not the normal trick calculation, but good for an example if all object 2 masks in the batch are missing
object1_preds = preds[:, 0:1, ...]
object2_masks = masks[:, 1:2, ...]
trick_preds = torch.zeros(preds.shape)
trick_preds[:, 0:1, ...] = object1_preds
trick_preds[:, 1:2, ...] = object2_masks  # predictions == masks, so zero loss for object 2
trick_multiplier = trick_preds.numel() / object2_masks.numel()

print("\nDICE LOSS")
print("-" * 40)
loss_split_and_combined(preds, masks, DiceLoss(sigmoid=False))
loss_split_and_combined(trick_preds, masks, DiceLoss(sigmoid=False), trick_multiplier=trick_multiplier)
print("\nBINARY CROSS ENTROPY LOSS")
print("-" * 40)
loss_split_and_combined(preds, masks, BinaryCrossEntropy)
loss_split_and_combined(trick_preds, masks, BinaryCrossEntropy, trick_multiplier=trick_multiplier)
print("\nDICE CE LOSS")
print("-" * 40)
loss_split_and_combined(preds, masks, DiceCELoss(sigmoid=False))
loss_split_and_combined(trick_preds, masks, DiceCELoss(sigmoid=False), trick_multiplier=trick_multiplier)

# Raises the shape-mismatch error when the Dice and CE parts are summed.
DiceCELoss(sigmoid=True, reduction='none')(preds, masks)
```

Additional context
I am of course also open to other solutions to obtain my original goal.
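One alternative to the trick, sketched under explicit assumptions: instead of editing the predictions, compute an unreduced loss per (sample, channel) pair and average only over the pairs whose mask is present. `masked_dice_bce` and the `present` flags are hypothetical names introduced for illustration. The CE part here is a per-channel sigmoid BCE (`binary_cross_entropy_with_logits`), a swapped-in choice that differs from what DiceCELoss computes, and the soft Dice is a simple hand-written version, not MONAI's DiceLoss.

```python
import torch
import torch.nn.functional as F

def masked_dice_bce(logits, masks, present, eps=1e-5):
    """Hypothetical helper: mean of (soft Dice + BCE) over (sample, channel)
    pairs whose mask is present; pairs with present == 0 are ignored.

    logits, masks: (B, C, *spatial); present: (B, C), 1 = mask available.
    """
    probs = torch.sigmoid(logits)
    dims = tuple(range(2, logits.ndim))                          # spatial dims
    intersection = (probs * masks).sum(dims)
    denom = probs.sum(dims) + masks.sum(dims)
    dice = 1.0 - (2.0 * intersection + eps) / (denom + eps)      # (B, C)
    bce = F.binary_cross_entropy_with_logits(
        logits, masks, reduction="none").mean(dims)              # (B, C)
    per_pair = dice + bce
    weight = present.to(per_pair.dtype)
    # The clamp avoids division by zero if no mask at all is present in the batch.
    return (per_pair * weight).sum() / weight.sum().clamp(min=1.0)

torch.manual_seed(0)
logits = torch.randn(4, 2, 8, 8, 4)
masks = torch.randint(0, 2, (4, 2, 8, 8, 4)).float()
present = torch.ones(4, 2)
present[:, 1] = 0  # all object 2 masks missing, as in the issue's example

loss_masked = masked_dice_bce(logits, masks, present)
loss_obj1 = masked_dice_bce(logits[:, 0:1], masks[:, 0:1], torch.ones(4, 1))
print(loss_masked.item(), loss_obj1.item())
```

Because missing pairs get zero weight and are excluded from the denominator, no correction factor is needed, and the same code covers batches in which only some samples lack a mask for an object.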

Edit: the code didn't run as-is after all, sorry. I added

```python
object1_preds = preds[:, 0:1, ...]
object2_masks = masks[:, 1:2, ...]
```

before creating `trick_preds`.
