Wrong multi-class label requirement for the compute_roc_auc function  #3095

@plu-project

Describe the bug
monai.metrics.compute_roc_auc rejects multi-class labels that are not already one-hot encoded.

To Reproduce

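from monai.metrics import compute_roc_auc
import torch

# Integer class labels, shape [10, 1] (not one-hot)
y = torch.randint(4, size=(10, 1))
# Per-class scores, shape [10, 4]; softmax makes each row sum to 1 and
# avoids the division by zero that an all-zero randint row would cause
y_pred = torch.softmax(torch.rand(10, 4), dim=1)
compute_roc_auc(y_pred, y)  # raises, because y is not one-hot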

Expected behavior
The documentation says that a label of shape [16, 1] will be converted into shape [16, 2] (where 2 is inferred from y_pred). Either the function should convert the label into one-hot format internally, or the documentation should explicitly require one-hot labels.
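
Until the function or its documentation changes, a caller-side workaround is to one-hot encode the labels before passing them in. The following is a minimal sketch, assuming integer class indices in a [N, 1] tensor and using torch.nn.functional.one_hot rather than any MONAI helper. The final compute_roc_auc call is left as a comment so the snippet runs with PyTorch alone.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 4

# Integer class labels with shape [N, 1], as in the reproduction above.
y = torch.randint(num_classes, size=(10, 1))

# Per-class scores with shape [N, C]; softmax makes each row sum to 1.
y_pred = torch.softmax(torch.rand(10, num_classes), dim=1)

# Drop the trailing singleton dimension, then expand to one-hot:
# [N, 1] -> [N] -> [N, C]. The cast matches the dtype of y_pred.
y_onehot = F.one_hot(y.squeeze(dim=-1), num_classes=y_pred.shape[1]).to(y_pred.dtype)

# The shapes now match, so the shape mismatch should no longer be an issue:
# from monai.metrics import compute_roc_auc
# compute_roc_auc(y_pred, y_onehot)
```

AUC for a class is undefined when that class never appears in y, so a small random batch like this one may still be rejected for that separate reason.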

Additional context
Potential fix that converts the label internally:
Add from monai.networks.utils import one_hot,
and insert

if y_pred_ndim == 2 and y.ndimension() == 1:
    y = one_hot(y, num_classes=y_pred.shape[1])

after

if y_pred_ndim == 2 and y_pred.shape[1] == 1:
    y_pred = y_pred.squeeze(dim=-1)
    y_pred_ndim = 1
if y_ndim == 2 and y.shape[1] == 1:
    y = y.squeeze(dim=-1)
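
The placement matters: with the new branch after the squeeze steps, a [N, 1] label is first flattened to [N] and only then expanded to [N, C]. The sketch below replays that order in a standalone function to check the resulting shapes. normalize_inputs is a hypothetical name, not MONAI's code, and torch.nn.functional.one_hot stands in for monai.networks.utils.one_hot so the snippet needs only PyTorch.

```python
import torch
import torch.nn.functional as F


def normalize_inputs(y_pred: torch.Tensor, y: torch.Tensor):
    """Hypothetical helper replaying the proposed shape handling (not MONAI code)."""
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()

    # Existing steps quoted in the issue: drop trailing singleton dimensions.
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    # Proposed addition: class indices [N] -> one-hot [N, C]
    # when the predictions are multi-class.
    if y_pred_ndim == 2 and y.ndimension() == 1:
        y = F.one_hot(y.long(), num_classes=y_pred.shape[1]).to(y_pred.dtype)

    return y_pred, y


y_pred = torch.softmax(torch.rand(10, 4), dim=1)

# Both [N, 1] and [N] integer labels end up as [N, C].
for labels in (torch.randint(4, size=(10, 1)), torch.randint(4, size=(10,))):
    _, y_out = normalize_inputs(y_pred, labels)
    assert y_out.shape == y_pred.shape

# The binary case is untouched: [N, 1] scores and [N, 1] labels both become [N].
p_bin, y_bin = normalize_inputs(torch.rand(10, 1), torch.randint(2, size=(10, 1)))
assert p_bin.shape == y_bin.shape == (10,)
```

Labels that are already one-hot, with shape [N, C], skip every branch and pass through unchanged.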

Labels: enhancement
