CTCLoss with empty target doesn't work well #18215

@t-vi

Description

🐛 Bug

CTCLoss doesn't provide the correct gradient when the target sequence is empty.

To Reproduce

import torch

# Batch of two sequences: T=2 time steps, N=2 samples, C=3 classes (class 0 is the blank).
probs = torch.randn(2, 2, 3, dtype=torch.double).log_softmax(-1).requires_grad_()
labels = torch.tensor([1, 2])
label_sizes = [2, 0]  # the second sample has an empty target
sizes = [2, 2]
loss = torch.nn.functional.ctc_loss(probs, labels, sizes, label_sizes, reduction='sum', zero_infinity=True)
loss2 = torch.nn.functional.ctc_loss(probs, labels, sizes, label_sizes, reduction='none', zero_infinity=True)
grad, = torch.autograd.grad(loss, probs)

# The same computation on the GPU, for comparison with the CPU results.
probs_gpu = probs.detach().cuda().requires_grad_()
loss_gpu = torch.nn.functional.ctc_loss(probs_gpu, labels.cuda(), sizes, label_sizes, reduction='sum', zero_infinity=True)
loss2_gpu = torch.nn.functional.ctc_loss(probs_gpu, labels.cuda(), sizes, label_sizes, reduction='none', zero_infinity=True)
grad_gpu, = torch.autograd.grad(loss_gpu, probs_gpu)

print('loss:', loss, loss_gpu)
print('loss2:', loss2, loss2_gpu)
print('grad:', grad, "\n", grad_gpu)

# gradcheck on both devices; with the wrong gradient these return False.
print("grad_check cpu: ",
      torch.autograd.gradcheck(
          lambda logits: torch.nn.functional.ctc_loss(
              logits.log_softmax(-1), labels, sizes, label_sizes,
              reduction='sum', zero_infinity=True),
          (torch.randn(2, 2, 3, dtype=torch.double, requires_grad=True),),
          raise_exception=False))
print("grad_check gpu: ",
      torch.autograd.gradcheck(
          lambda logits: torch.nn.functional.ctc_loss(
              logits.log_softmax(-1), labels.cuda(), sizes, label_sizes,
              reduction='sum', zero_infinity=True),
          (torch.randn(2, 2, 3, dtype=torch.double, device='cuda', requires_grad=True),),
          raise_exception=False))

Also, the default reduction ('mean') doesn't play well with zero-length targets: it divides each per-sample loss by its target length, which is a division by zero for an empty target.
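A minimal sketch of the 'mean' problem, assuming the documented semantics (each sample's loss is divided by its target length before the batch mean is taken); on versions affected by this issue the division by zero surfaces as inf or nan:

import torch

lp = torch.randn(2, 1, 3, dtype=torch.double).log_softmax(-1)
# A single sample with an empty target; the padded entry is ignored
# because the target length is 0.
targets = torch.tensor([[1]])
# 'mean' divides this sample's (finite) loss by its target length, i.e. by zero.
loss = torch.nn.functional.ctc_loss(lp, targets, [2], [0], reduction='mean')
print(loss)  # inf or nan on affected versions, rather than a finite value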

Expected behavior

Compute the proper loss and gradient. For an empty target the only valid alignment is all blanks, so the loss is the negative sum of the blank log-probabilities over time, and the gradient points in the direction of less "blank".
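For reference, a sketch of the expected semantics for a single empty-target sample (assuming the default blank index 0; the variable names are illustrative, not from the issue):

import torch

T, C = 2, 3  # two time steps, three classes; class 0 is the blank
lp = torch.randn(T, 1, C, dtype=torch.double).log_softmax(-1).requires_grad_()

targets = torch.tensor([[1]])  # padded 2D targets; target length 0, entry ignored
loss = torch.nn.functional.ctc_loss(lp, targets, [T], [0],
                                    reduction='sum', zero_infinity=True)

# With an empty target the only valid path is all blanks, so
# loss = -sum_t log p_blank(t).
reference = -lp[:, 0, 0].sum()
print(loss.item(), reference.item())  # these should agree

# d(loss)/d(log_probs) should then be -1 at the blank index of every time
# step and 0 elsewhere; with the bug, CTCLoss returns something else.
grad, = torch.autograd.grad(loss, lp)
print(grad)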

Acknowledgement

This was pointed out by Evgeni Kirov; thank you for tracking this down!

Metadata

Labels

module: bootcamp (We plan to do a full writeup on the issue, and then get someone to do it for onboarding)
module: derivatives (Related to derivatives of operators)
module: nn (Related to torch.nn)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
