
Conversation

@yqwangustc

yqwangustc commented Jun 1, 2019

When using the ATen version of CTCLoss, we nondeterministically observe NaN in the output gradient. Initial examination reveals that 1) the NaNs happen rarely and only at the BLANK position, and 2) if we disable the large-batch handling (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L552), the NaNs disappear.
Based on these observations, I believe this line (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L560) might nondeterministically produce NaN.
Unfortunately, because of the nondeterministic behavior, I am not able to provide a repro and thus pinpoint the exact root cause. Until we find it, we'd like to place a stopgap here that zeros out any NaN in the gradient.

Note there is a previous PR that zeros out infinity only (#16199); unfortunately, NaN is not infinity, so it won't help if there is a NaN in the gradient. This PR overloads the zero_infinity option to zero out both infinities and NaNs. Since they serve the same purpose, I think it is fine to reuse the zero_infinity option for now.
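
For context, a minimal Python-level sketch of how the option is used (all shapes, lengths, and the blank index below are made up for illustration); as proposed here, zero_infinity=True would zero out NaN entries in the gradient as well as infinite losses and their gradients:

```python
import torch
import torch.nn as nn

T, N, C, blank = 50, 4, 42, 41          # hypothetical sizes, not from the real job
log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()
targets = torch.randint(1, blank, (N, 15), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(5, 15, (N,), dtype=torch.long)

# zero_infinity already zeros infinite losses and their gradients;
# this PR proposes extending it to also zero NaN entries in the gradient.
ctc = nn.CTCLoss(blank=blank, zero_infinity=True)
ctc(log_probs, targets, input_lengths, target_lengths).backward()
print(torch.isnan(log_probs.grad).any())
```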

Differential Revision: D15591296

@t-vi
Collaborator

t-vi commented Jun 1, 2019

I'm skeptical about this without some repro. My intuition is that if it occasionally produces invalid NaNs, it will likely also produce wrong non-NaN results. Could you please capture the inputs to CTC loss that produce NaNs (i.e. keep them around, and when you get NaN gradients, save them)?
If you use 2-d targets: is this something that might be helped by #20971?
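
(For what it's worth, a hedged sketch of such a capture step; `model`, `ctc_loss`, and the batch fields are placeholders for the real training job:)

```python
import torch

def training_step(model, ctc_loss, batch, step):
    log_probs = model(batch["features"]).log_softmax(2)
    log_probs.retain_grad()                      # keep the grad of this non-leaf tensor
    loss = ctc_loss(log_probs, batch["targets"],
                    batch["input_lengths"], batch["target_lengths"])
    loss.backward()
    # Dump the exact CTC inputs the moment the gradient turns NaN.
    if torch.isnan(log_probs.grad).any():
        torch.save({"log_probs": log_probs.detach().cpu(),
                    "targets": batch["targets"].cpu(),
                    "input_lengths": batch["input_lengths"].cpu(),
                    "target_lengths": batch["target_lengths"].cpu()},
                   "ctc_nan_repro_step_%d.pt" % step)
    return loss.item()
```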

@t-vi
Collaborator

t-vi commented Jun 1, 2019

Oh, and thanks for investigating this. It's awesome to see that you tracked this down!

@yqwangustc
Author

@t-vi yes, creating a consistent repro is the most important thing here. I was able to reproduce the NaN occasionally by injecting a gradient check into the training loop and dropping into pdb once there was any NaN in the gradient. But unfortunately, supplying the exact same input to CTCLoss again did not produce NaN, so I am not able to give you a consistent repro yet. I am thinking about other approaches to get one.

Btw, we use 1-d targets, so #20971 won't help. We also notice that once we disable the large-batch handling, the NaN disappears.

Since a lot of people's work is pending on a numerically stable CTC loss, may I suggest we merge this PR as a stopgap, similar in purpose to zero_infinity?

Summary:
When using the ATen version of CTCLoss, we nondeterministically observe NaN in the output gradient. Initial examination reveals that 1) the NaNs happen rarely and only at the BLANK position, and 2) if we disable the large-batch handling (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L552), the NaNs disappear.

Based on these observations, I believe this line (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L560) might nondeterministically produce NaN.

Unfortunately, because of the nondeterministic behavior, I am not able to provide a repro and thus pinpoint the exact root cause. Until we find it, we'd like to place a stopgap here that zeros out any NaN in the gradient.

Differential Revision: D15591296

fbshipit-source-id: 6865e8da58df69a740f649635c7cc49253b24621
@t-vi
Collaborator

t-vi commented Jun 2, 2019

OK, so it's something different. I would indeed want to get us to the same situation as with zero_infinity - where we clearly understand why we're getting NaNs and can therefore target avoiding them.
I would expect that there is something about the inputs when the grads turn NaN, but I'd really like to know what it is.

@t-vi
Collaborator

t-vi commented Jun 2, 2019

The other thing I'm wondering about: Do you get the NaN in the forward or the backward? Because you also delete them in the forward, even though you describe observing them in the gradient only.

@ezyang ezyang requested a review from t-vi June 3, 2019 16:01
t-vi
t-vi previously requested changes Jun 3, 2019
Collaborator

@t-vi t-vi left a comment


Along the lines of the comments above, I think we need some form of repro / example that causes this, and ideally a test.
I have a strong preference for a more proper fix, i.e. finding and fixing the apparent algorithmic defect.

if (zero_infinity) {
  grad = at::where(neg_log_likelihood.view({1, batch_size, 1}) == Scalar(INFINITY), at::zeros({}, grad.options()), grad);
  // grad != grad is true exactly for NaN entries, so this also zeroes out NaNs
  grad = at::where(grad != grad, at::zeros({}, grad.options()), grad);
Collaborator


Now, I would like to see where this is actually happening. You describe (and I can believe) that it's not necessarily reproducible, but as mentioned above, I think we need to understand the causes, and I would half expect that the results are also wrong in the cases where otherwise NaN-producing inputs produce non-NaN results.
Can you capture a sample of inputs for which the gradient turns NaN, please?

Author

@yqwangustc yqwangustc Jun 3, 2019


@t-vi, yes, this is the main fix I want to make -- in the GPU kernel for the backward; all the CPU code changes are just there to make the zero_infinity option behave consistently. As I said, I can occasionally capture a NaN, but it is not reproducible consistently (the input I captured produced NaN once, but the next time I supplied it to CTCLoss the NaN was gone), and I am working on producing a true example (trying different approaches now ...).

I fully agree with you that we need to find the root cause to fix this permanently (and I am eager to find it!). However, since this NaN issue is blocking a lot of other people's work (mine, my colleagues', and that of many people on the PyTorch forum), may I suggest that we patch this for the time being? I will definitely work with you to find the real solution afterwards.

Collaborator

@t-vi t-vi Jun 3, 2019


Could you check whether the gradient is correct when it isn't NaN?

Sorry, but without an example, I'm not enthused about zeroing NaNs. Note that most of the people on the forums ended up having errors in their inputs (non-log-probability inputs, invalid targets, ...). After seeing so many of those, I'll have to see an example where the code itself goes wrong before adding bandaids.

The other concern is that if it produces bad NaN gradients, I would not believe the non-NaN gradients either, and these would need a different fix, too.

@yqwangustc
Author

@t-vi

The other thing I'm wondering about: Do you get the NaN in the forward or the backward? Because you also delete them in the forward, even though you describe observing them in the gradient only.

Sorry, I missed this question. I observe that the NaN appears in the gradient first (there is no NaN in the input), and the loss is a normal number in that case. Once the gradient has NaN, it propagates to the model parameters, then to the input (to CTCLoss) in the next mini-batch, and then everywhere.
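
(A toy illustration of that propagation, with a made-up parameter and a plain SGD optimizer rather than the real model: a single NaN in a gradient poisons the parameter on optimizer.step(), and every later forward pass computed from it.)

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1)
p.grad = torch.tensor([0.0, float('nan'), 0.0])  # one bad gradient entry
opt.step()
print(p)              # the middle parameter is now nan
print((p * 2).sum())  # anything computed from it is nan as well
```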

@ezyang
Contributor

ezyang commented Jun 3, 2019

@t-vi @yqwangustc and I discussed this offline and decided that we should land this patch to stem the bleeding (but we should continue to investigate what the real problem is.)

@ezyang ezyang dismissed t-vi’s stale review June 3, 2019 19:40

see comment

@ezyang
Contributor

ezyang commented Jun 3, 2019

I have one more request: the zero_infinity documentation should be updated to say that it also zeros NaNs, as a quirk of the current implementation (with a link to the bug or PR). (I know we will eventually want to change it back, but if we forget and this ends up in the release, I'd like the docs to be accurate.)

@ezyang
Contributor

ezyang commented Jun 3, 2019

Related: #14335 (no code was provided in the reports but it looks similar)

@yqwangustc
Author

yqwangustc commented Jun 3, 2019

A bit more information on the input causing NaN.

  • lprobs.shape
    torch.Size([984, 16, 42])
  • targets_flat.shape
    torch.Size([1316])
  • input_length
    tensor([984, 956, 943, 924, 921, 904, 901, 897, 864, 824, 717, 710, 695, 666, 578, 552], device='cuda:0')
  • target_lengths
    tensor([ 96, 78, 85, 140, 107, 88, 126, 95, 70, 52, 86, 38, 88, 69, 60, 38], device='cuda:0')
  • blank_id = 41

and the loss is a normal floating-point number, so the forward pass is fine. grad[983, 1:10, 41] holds all the NaN values, i.e., all the NaNs appear at the BLANK posterior, but only at padded positions of the time sequence. Switching to the small-batch code path does not yield any NaN, so I am checking grad before/after this line https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L560

@t-vi will this information help you identify something?
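
(For completeness, a hedged sketch of the padding check implied above; the gradient here is a zero stand-in with the observed NaN pattern written in, and only the shapes, input lengths, and blank index come from the run above:)

```python
import torch

T, N, C, blank = 984, 16, 42, 41
input_lengths = torch.tensor([984, 956, 943, 924, 921, 904, 901, 897,
                              864, 824, 717, 710, 695, 666, 578, 552])
grad = torch.zeros(T, N, C)            # stand-in for the real CTC gradient
grad[983, 1:10, blank] = float('nan')  # the pattern observed in the run

idx = torch.isnan(grad).nonzero()      # (K, 3) indices of NaN entries
t, n = idx[:, 0], idx[:, 1]
print(bool((t >= input_lengths[n]).all()))  # True: every NaN sits in the padding
```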


update:

By inserting the following check, I confirmed that https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/LossCTC.cu#L560 is the offending line.

auto grad_blank = grad.narrow(2, BLANK, 1);
if (at::isnan(grad_blank).any().item().to<uint8_t>() == 1) {
  printf("Detecting NAN in grad_blank before\n");
}

grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1},
                                                  {log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2})
                             + log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1},
                                                   {log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}),
                             2, true)
               .permute({1, 0, 2})
               .add_(neg_log_likelihood.view({1, batch_size, 1}))
               .sub_(log_probs.narrow(2, BLANK, 1))
               .exp_());

if (at::isnan(grad_blank).any().item().to<uint8_t>() == 1) {
  printf("Detecting NAN in grad_blank after\n");
}

@t-vi if I understand it correctly, using as_strided needs extreme care. Would you mind taking a deeper look at this line to see whether there is any chance of an error? In particular, I suspect the stride*2 may cause issues in some edge cases. Maybe we could rewrite it using gather?
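
(For reference, a minimal sketch on a toy tensor, not the real log_alpha, of what the stride*2 view selects: every other entry along the last dimension, i.e. the even s positions of the extended target sequence, which are the blank slots:)

```python
import torch

la = torch.arange(15.0).view(1, 3, 5)   # toy (batch=1, T=3, 2*S+1=5) stand-in
blanks = la.as_strided((1, 3, 3),
                       (la.stride(0), la.stride(1), la.stride(2) * 2))
print(blanks[0])
# tensor([[ 0.,  2.,  4.],
#         [ 5.,  7.,  9.],
#         [10., 12., 14.]])
```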


update again

Additional examination of the data reveals that log_beta has NaN in it. Looking into why it ends up with NaN.

@t-vi
Collaborator

t-vi commented Jun 4, 2019

Thank you for the analysis! This is extremely useful.
So it occurs in the padded part. In this region, log_alpha/log_beta will be -infinity. I don't think the as_strided is particularly problematic, but we should really replace it with .slice.
Now if log_probs is infinite or NaN (which should not have an effect in the padded part), it'll give NaN (either from a NaN input or from inf - inf).
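
(A quick numeric sanity check of that inf - inf failure mode, with scalar stand-ins that only mirror the shape of the offending expression:)

```python
import torch

# In the padded region the logsumexp of log_alpha + log_beta is -inf; subtracting
# a log_prob that is itself -inf then yields inf - inf, i.e. NaN, and exp() keeps it NaN.
lse = torch.tensor(float('-inf'))
nll = torch.tensor(12.3)                     # a finite per-sample loss
log_prob_blank = torch.tensor(float('-inf'))
print((lse + nll - log_prob_blank).exp())    # tensor(nan)
```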

Thanks again for tracking this down!

So I'm happy with your fix; maybe we can add a comment that this happens in the padding, and that when the blank calculation is moved to a kernel, properly ignoring those entries will be better.
I'd prefer not to change the forward (the two other bits), and only zero out the NaNs in the backward.

@yqwangustc
Author

yqwangustc commented Jun 4, 2019

@t-vi

Actually, after many iterations, I think I have found the real root cause!

I started to notice that after log_beta = at::empty_like(log_alpha); there is already some NaN in log_beta. Initially I thought this might be fine, since all the (b, t, s) cases would be covered in the GPU kernel ctc_loss_backward_log_beta_gpu_kernel. However, after carefully reviewing the code, I found two exceptions:

These are consistent with the pattern of either NaN or some wildly large value (e.g., -1.03845089e+34) I saw in log_beta. This also perfectly explains why I could not reliably reproduce this issue before.

So the suggested change:

@t-vi Please kindly correct me if I am wrong. I am running more experiments, but I am quite confident this will permanently fix the problem :-) If this looks good to you, I can abandon this PR and create a new one (which only has a 1-line change!)
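
(A hedged illustration of the failure mode, with a placeholder tensor and shape: at::empty_like, torch.empty_like in Python, returns uninitialized memory, so any slot the backward kernel never writes keeps whatever garbage was in the buffer, possibly NaN or huge values; the follow-up fix in #21392 instead initializes log_beta to -inf so unwritten slots behave like log(0):)

```python
import torch

log_alpha = torch.zeros(16, 984, 281)                      # placeholder shape
log_beta = torch.empty_like(log_alpha)                     # contents are undefined here
log_beta_init = torch.full_like(log_alpha, float('-inf'))  # unwritten slots act like log(0)
```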

@t-vi
Collaborator

t-vi commented Jun 4, 2019

Heh, shaking that function really has the bugs falling out... Awesome work!
I'd probably stick with empty (it saves calling a kernel to initialize the memory) if we can make sure we write all of the entries.

  • In your first observation, > should clearly be >=.
  • The t = max_input_length - 1 case should be handled by extending the s-loop. I was thinking that t = input_length - 1 was already covered and did not consider that we would end up using the uninitialized value.
    Then t >= input_length would cover that (unless input_length == max_input_length).

But I think we might still have to zero "out of sequence" gradients because of the use of log_probs.

@t-vi
Collaborator

t-vi commented Jun 4, 2019

Do you want help to update the PR? I'd be glad to help code something up, but you did all the hard work already...

@yqwangustc
Author

@t-vi the new PR is on the way; I will send it out soon. Sorry, I just woke up :-)

@t-vi
Collaborator

t-vi commented Jun 4, 2019

Supercool! Thanks!

yqwangustc pushed a commit to yqwangustc/pytorch that referenced this pull request Jun 5, 2019
Summary:
as discussed at pytorch#21244, we found that some values in log_beta were not
properly initialized. This diff will 1) initialize all of log_beta to -inf; 2) fix a
tricky compare condition; 3) zero out all the gradient elements corresponding to padding.

Offline experiments show that this diff fixes the previously seen NaN loss.

Differential Revision: D15637977

fbshipit-source-id: 8bc35098aade6aa2035f71f499250883f09b0f25
@yqwangustc
Author

Abandoning this PR in favor of pytorch/pytorch #21392

@yqwangustc yqwangustc closed this Jun 5, 2019
@yqwangustc yqwangustc deleted the export-D15591296 branch June 5, 2019 03:37
facebook-github-bot pushed a commit that referenced this pull request Jun 5, 2019
Summary:
Pull Request resolved: #21392

as discussed at #21244, we found that some values in log_beta were not
properly initialized. This diff will 1) initialize all of log_beta to -inf; 2) fix a
tricky compare condition; 3) zero out all the gradient elements corresponding to padding.

Offline experiments show that this diff fixes the previously seen NaN loss.

Differential Revision: D15637977

fbshipit-source-id: 477008a5e11aae946bd2aa401ab7e0c513421af0
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 5, 2019
Summary:
Pull Request resolved: pytorch/pytorch#21392

as discussed at pytorch/pytorch#21244, we found that some values in log_beta were not
properly initialized. This diff will 1) initialize all of log_beta to -inf; 2) fix a
tricky compare condition; 3) zero out all the gradient elements corresponding to padding.

Offline experiments show that this diff fixes the previously seen NaN loss.

Differential Revision: D15637977

fbshipit-source-id: 477008a5e11aae946bd2aa401ab7e0c513421af0

Labels

module: cuda Related to torch.cuda, and CUDA support in general
