CTCLoss cuda backend computes wrong gradient when target (i.e. label) length is greater than 896 for double inputs or 1024 for float inputs #27442

Description

@ASDen

🐛 Bug

First, many thanks for providing the only GPU-optimized implementation of CTC with unbounded label/target length!

The bug I found is that although the loss itself is computed correctly (it matches the PyTorch CPU implementation and TensorFlow), the gradients are wrong once the label length exceeds 896 for double inputs (logits) or 1024 for float inputs.

Just judging from those numbers, I think this may be related to this line from LossCTC.cu:
std::is_same<scalar_t, float>::value ? 1024 : 896
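
The exact boundary can be pinned down with a check along these lines (cuda_cpu_grad_diff is an ad-hoc helper written for illustration, not something from the codebase); as the outputs below also show, the mismatch appears once the target length crosses 1024 for float and 896 for double:

import torch
from torch import nn

def cuda_cpu_grad_diff(target_len, double=False, time_step=4000, batch_size=4, vocab_size=3):
    # Ad-hoc check: L2 norm of the difference between the CUDA and CPU
    # CTC gradients for a single target length
    ctc = nn.CTCLoss(reduction='sum')
    x = torch.randn(time_step, batch_size, vocab_size)
    if double:
        x = x.double()
    x = nn.Parameter(x.cuda())
    y = torch.randint(low=1, high=vocab_size, size=(batch_size, target_len), dtype=torch.long).cuda()
    xl = torch.full(size=(batch_size,), fill_value=time_step, dtype=torch.long)
    yl = torch.full(size=(batch_size,), fill_value=target_len, dtype=torch.long)

    ctc(x, y, xl.cuda(), yl.cuda()).backward()   # CUDA backend
    g_cuda = x.grad.clone()
    x.grad.zero_()
    ctc(x.cpu(), y.cpu(), xl, yl).backward()     # CPU backend
    g_cpu = x.grad.clone()
    x.grad.zero_()
    return torch.norm(g_cuda - g_cpu).item()

for L in (1023, 1024, 1025):   # float: 1024 is fine, 1025 is corrupted
    print('float ', L, cuda_cpu_grad_diff(L))
for L in (895, 896, 897):      # double: 896 is fine, 897 is corrupted
    print('double', L, cuda_cpu_grad_diff(L, double=True))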

To Reproduce

Use the following snippet:

from torch import nn
import torch

time_step  = 4000  # Input sequence length
vocab_size = 3     # Number of classes
batch_size = 4     # Batch size
target_sql = 700   # Target sequence length

ctc_loss = nn.CTCLoss(reduction='sum')

for j in range(10):

  print('\nLogits length : ', time_step, 'Label length : ', target_sql)
  print('-------------------')

  # Leaf CUDA parameter so gradients accumulate in x.grad;
  # append .double() after .cuda() to test the double case
  x = nn.Parameter(torch.randn(time_step, batch_size, vocab_size).cuda())
  # Non-blank labels (0 is reserved for the blank in CTCLoss)
  y = torch.randint(low=1, high=vocab_size-1, size=(batch_size, target_sql), dtype=torch.long).cuda()

  x_lengths = torch.full(size=(batch_size,), fill_value=time_step, dtype=torch.long).cuda()
  y_lengths = torch.full(size=(batch_size,), fill_value=target_sql, dtype=torch.long).cuda()

  loss1 = ctc_loss(x, y, x_lengths, y_lengths)                          # CUDA backend
  loss2 = ctc_loss(x.cpu(), y.cpu(), x_lengths.cpu(), y_lengths.cpu())  # CPU backend

  loss1.backward()
  tg1 = x.grad.clone().detach()
  x.grad.zero_()

  loss2.backward()
  tg2 = x.grad.clone().detach()
  x.grad.zero_()

  print('Grads  Diff : ', torch.norm(tg1 - tg2).item())
  print('Losses Diff : ', torch.norm(loss1.cpu() - loss2).item())

  target_sql += 100

The output for the float case is the following; notice the huge jump in the gradient difference starting at a target length of 1100. More debugging shows that 1024 works fine but 1025 is corrupted.

Logits length :  4000 Label length :  700
-------------------
Grads  Diff :  0.03249376639723778
Losses Diff :  0.0

Logits length :  4000 Label length :  800
-------------------
Grads  Diff :  0.03427853062748909
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  900
-------------------
Grads  Diff :  0.03299879655241966
Losses Diff :  0.0

Logits length :  4000 Label length :  1000
-------------------
Grads  Diff :  0.03270421177148819
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1100
-------------------
Grads  Diff :  71.10806274414062
Losses Diff :  0.0

Logits length :  4000 Label length :  1200
-------------------
Grads  Diff :  70.92623901367188
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1300
-------------------
Grads  Diff :  71.46635437011719
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1400
-------------------
Grads  Diff :  72.2627182006836
Losses Diff :  0.0

Logits length :  4000 Label length :  1500
-------------------
Grads  Diff :  72.44979858398438
Losses Diff :  0.0

Logits length :  4000 Label length :  1600
-------------------
Grads  Diff :  73.44546508789062
Losses Diff :  0.0009765625

The output for the double case is the following; notice the huge jump in the gradient difference starting at a target length of 900. More debugging shows that 896 works fine but 897 is corrupted.

Logits length :  4000 Label length :  700
-------------------
Grads  Diff :  8.477306146763702e-11
Losses Diff :  0.0

Logits length :  4000 Label length :  800
-------------------
Grads  Diff :  8.588737980344711e-11
Losses Diff :  0.0

Logits length :  4000 Label length :  900
-------------------
Grads  Diff :  71.26078984351248
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1000
-------------------
Grads  Diff :  70.95303213362187
Losses Diff :  0.0

Logits length :  4000 Label length :  1100
-------------------
Grads  Diff :  70.60947779103446
Losses Diff :  0.0

Logits length :  4000 Label length :  1200
-------------------
Grads  Diff :  70.94595045695341
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1300
-------------------
Grads  Diff :  71.45715679703945
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1400
-------------------
Grads  Diff :  71.74132504186824
Losses Diff :  0.0

Logits length :  4000 Label length :  1500
-------------------
Grads  Diff :  72.6907375164959
Losses Diff :  0.0

Logits length :  4000 Label length :  1600
-------------------
Grads  Diff :  73.49384136194594
Losses Diff :  0.0
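
Until the CUDA kernel is fixed, one possible stopgap is to fall back to the CPU implementation whenever the longest target in the batch exceeds the dtype-dependent limit; gradients still flow back to the CUDA tensor through the .cpu() copies, just as in the comparison snippet above, at the cost of a slower backward for those batches. This is only a sketch: safe_ctc_loss is a hypothetical wrapper, and the 1024/896 limits are simply the thresholds observed above.

import torch
from torch import nn

def safe_ctc_loss(log_probs, targets, input_lengths, target_lengths, **kwargs):
    # Hypothetical wrapper: route the loss through the CPU backend when the
    # CUDA backward is suspected to be wrong (target length > 1024 for float,
    # > 896 for double, per the observations above); lengths are assumed to
    # be int64 tensors as in the snippet above
    limit = 1024 if log_probs.dtype == torch.float32 else 896
    if log_probs.is_cuda and int(target_lengths.max()) > limit:
        loss = nn.functional.ctc_loss(log_probs.cpu(), targets.cpu(),
                                      input_lengths.cpu(), target_lengths.cpu(),
                                      **kwargs)
        return loss.to(log_probs.device)
    return nn.functional.ctc_loss(log_probs, targets, input_lengths,
                                  target_lengths, **kwargs)

# Usage, mirroring the snippet above:
# loss = safe_ctc_loss(x, y, x_lengths, y_lengths, reduction='sum')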

Environment

  • PyTorch Version (e.g., 1.0): the problem was reproduced on 1.1, 1.2, and master
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.6
  • CUDA/cuDNN version: CUDA: 10.0.130, cuDNN: 7.4.2
  • GPU models and configuration: tested on V100 & K80

Metadata

Labels

module: cuda (Related to torch.cuda, and CUDA support in general)
module: loss (Problem is related to loss function)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
