## 🐛 Bug
First, thanks a lot for providing the only GPU-optimized implementation of CTC with unbounded label/target length!

The bug I found is that although the loss itself is computed correctly (when compared to the CPU implementations in PyTorch and TensorFlow), the gradients are wrong once the label length grows above 896 for double inputs (logits) or above 1024 for floats.

Judging from those numbers alone, I suspect this is related to this line from `LossCTC.cu`:

```cpp
std::is_same<scalar_t, float>::value ? 1024 : 896
```
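To check that the cutoff matches those constants exactly, here is a minimal probe (my own sketch, not part of the repro below; `grad_diff` is a hypothetical helper) that compares GPU and CPU gradients right at the suspected boundaries:

```python
import torch
from torch import nn

def grad_diff(target_len, dtype=torch.float, time_step=4000,
              vocab_size=3, batch_size=1):
    """Return the norm of (GPU grad - CPU grad) for one random CTC problem."""
    ctc = nn.CTCLoss(reduction='sum')
    x = nn.Parameter(torch.randn(time_step, batch_size, vocab_size,
                                 dtype=dtype).cuda())
    y = torch.randint(1, vocab_size, (batch_size, target_len)).cuda()
    xl = torch.full((batch_size,), time_step, dtype=torch.long).cuda()
    yl = torch.full((batch_size,), target_len, dtype=torch.long).cuda()
    ctc(x, y, xl, yl).backward()
    g_gpu = x.grad.clone()
    x.grad.zero_()
    # The CPU loss still back-propagates into the CUDA parameter via .cpu().
    ctc(x.cpu(), y.cpu(), xl.cpu(), yl.cpu()).backward()
    return (g_gpu - x.grad).norm().item()

# Expect a small numerical difference at the limit and a large jump just above it.
for dtype, limit in [(torch.float, 1024), (torch.double, 896)]:
    print(dtype, grad_diff(limit, dtype), grad_diff(limit + 1, dtype))
```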
## To Reproduce
Use the following snippet:
```python
from torch import nn
import torch

time_step = 4000   # Input sequence length
vocab_size = 3     # Number of classes
batch_size = 4     # Batch size
target_sql = 700   # Target sequence length

ctc_loss = nn.CTCLoss(reduction='sum')

for j in range(10):
    print('\nLogits length : ', time_step, 'Label length : ', target_sql)
    print('-------------------')
    # Insert .double() before .cuda() to test the double case
    x = nn.Parameter(torch.randn(time_step, batch_size, vocab_size).cuda())
    y = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_sql), dtype=torch.long).cuda()
    x_lengths = torch.full(size=(batch_size,), fill_value=time_step, dtype=torch.long).cuda()
    y_lengths = torch.full(size=(batch_size,), fill_value=target_sql, dtype=torch.long).cuda()
    loss1 = ctc_loss(x, y, x_lengths, y_lengths)
    loss2 = ctc_loss(x.cpu(), y.cpu(), x_lengths.cpu(), y_lengths.cpu())
    loss1.backward()
    tg1 = x.grad.clone().detach()
    x.grad.zero_()
    loss2.backward()
    tg2 = x.grad.clone().detach()
    x.grad.zero_()
    print('Grads Diff : ', torch.norm(tg1 - tg2).item())
    print('Losses Diff : ', torch.norm(loss1.cpu() - loss2).item())
    target_sql += 100
```
The output for the float case is the following; notice the huge difference in gradients starting from target length 1100. More debugging shows that 1024 works fine but 1025 is corrupted:
```
Logits length :  4000 Label length :  700
-------------------
Grads Diff :  0.03249376639723778
Losses Diff :  0.0

Logits length :  4000 Label length :  800
-------------------
Grads Diff :  0.03427853062748909
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  900
-------------------
Grads Diff :  0.03299879655241966
Losses Diff :  0.0

Logits length :  4000 Label length :  1000
-------------------
Grads Diff :  0.03270421177148819
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1100
-------------------
Grads Diff :  71.10806274414062
Losses Diff :  0.0

Logits length :  4000 Label length :  1200
-------------------
Grads Diff :  70.92623901367188
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1300
-------------------
Grads Diff :  71.46635437011719
Losses Diff :  0.0009765625

Logits length :  4000 Label length :  1400
-------------------
Grads Diff :  72.2627182006836
Losses Diff :  0.0

Logits length :  4000 Label length :  1500
-------------------
Grads Diff :  72.44979858398438
Losses Diff :  0.0

Logits length :  4000 Label length :  1600
-------------------
Grads Diff :  73.44546508789062
Losses Diff :  0.0009765625
```
The output for the double case is the following; notice the huge difference in gradients starting from target length 900. More debugging shows that 896 works fine but 897 is corrupted:
```
Logits length :  4000 Label length :  700
-------------------
Grads Diff :  8.477306146763702e-11
Losses Diff :  0.0

Logits length :  4000 Label length :  800
-------------------
Grads Diff :  8.588737980344711e-11
Losses Diff :  0.0

Logits length :  4000 Label length :  900
-------------------
Grads Diff :  71.26078984351248
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1000
-------------------
Grads Diff :  70.95303213362187
Losses Diff :  0.0

Logits length :  4000 Label length :  1100
-------------------
Grads Diff :  70.60947779103446
Losses Diff :  0.0

Logits length :  4000 Label length :  1200
-------------------
Grads Diff :  70.94595045695341
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1300
-------------------
Grads Diff :  71.45715679703945
Losses Diff :  1.8189894035458565e-12

Logits length :  4000 Label length :  1400
-------------------
Grads Diff :  71.74132504186824
Losses Diff :  0.0

Logits length :  4000 Label length :  1500
-------------------
Grads Diff :  72.6907375164959
Losses Diff :  0.0

Logits length :  4000 Label length :  1600
-------------------
Grads Diff :  73.49384136194594
Losses Diff :  0.0
```
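Until the kernel is fixed, a possible stopgap (my own sketch, assuming the CPU path is correct, as the matching losses above suggest; `safe_ctc_loss` and `_MAX_GPU_TARGET` are names I made up) is to route long-target batches through the CPU implementation:

```python
import torch
from torch import nn

# Thresholds inferred from the experiments above and the constant in LossCTC.cu.
_MAX_GPU_TARGET = {torch.float32: 1024, torch.float64: 896}

def safe_ctc_loss(ctc, log_probs, targets, input_lengths, target_lengths):
    """Fall back to the CPU CTC kernel when any target is too long for the
    GPU kernel to back-propagate correctly."""
    limit = _MAX_GPU_TARGET.get(log_probs.dtype, 896)
    if log_probs.is_cuda and int(target_lengths.max()) > limit:
        # Gradients still flow back to the CUDA tensor through .cpu().
        return ctc(log_probs.cpu(), targets.cpu(),
                   input_lengths.cpu(), target_lengths.cpu())
    return ctc(log_probs, targets, input_lengths, target_lengths)
```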
## Environment
- PyTorch Version (e.g., 1.0): the problem reproduces on 1.1, 1.2, and master
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Python version: 3.6
- CUDA/cuDNN version: CUDA 10.0.130, cuDNN 7.4.2
- GPU models and configuration: tested on V100 & K80