
Conversation

@chenyangyu1988
Contributor

Summary:
Original commit changeset: 398d5f48826a

f123389994, P68573646
During FP16 training, we found that char_embeddings.weight gets NaN or INF grads

Differential Revision: D16067446

@pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general) and module: operators labels on Jun 30, 2019
@mrshenli
Contributor

CC @madsbk

Looks like #22016 sometimes leads to NaN or INF in embedding grads for FP16 training. @chenyangyu1988 Could you please share some context here?

@ngimel
Collaborator

ngimel commented Jun 30, 2019

It's possible that truncating `grad_weight_per_segment` to fp16 is causing this.
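
(For illustration only, not from the thread: a minimal sketch of how a running fp16 sum can overflow to inf while the same sum in fp32 stays finite; the values and loop are made up.)

```python
import torch

# Hypothetical gradient contributions; each one is modest, but there are many.
vals = torch.full((1000,), 100.0)

acc16 = torch.zeros((), dtype=torch.float16)  # fp16 accumulator, max finite value ~65504
acc32 = torch.zeros((), dtype=torch.float32)  # fp32 accumulator
for v in vals:
    acc16 += v.half()   # overflows to inf once the running sum passes ~65504
    acc32 += v          # stays finite

print(acc16.item(), acc32.item())  # inf vs. 100000.0
```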

@madsbk
Contributor

madsbk commented Jun 30, 2019

@ngimel, yes, that could be the problem, and it is easily fixable. I will look at it tomorrow.

@chenyangyu1988, do we have example code that triggers the error?

@chenyangyu1988
Contributor Author

@madsbk I don't have an OSS example here, but if you have an FP16 training example, you can check the loss_scale after 1k iterations; things are correct if the loss_scale is much larger than 1. In our case, we still get NaN or INF in the embedding layer grad even when the loss_scale is around 2.2e-16.
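
(Not an OSS repro from the thread; just a minimal sketch, with made-up sizes, of the kind of check described above: run an FP16 EmbeddingBag backward and look for NaN/Inf in the weight gradient.)

```python
import torch
import torch.nn as nn

# Hypothetical FP16 embedding bag; vocabulary and bag sizes are arbitrary.
emb = nn.EmbeddingBag(10_000, 64).cuda().half()
idx = torch.randint(0, 10_000, (512, 20), device="cuda")

loss = emb(idx).float().sum()
loss.backward()

grad = emb.weight.grad
print("NaN in grad:", torch.isnan(grad).any().item())
print("Inf in grad:", torch.isinf(grad).any().item())
```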

@chenyangyu1988
Contributor Author

Do we have a target time for this? It is currently blocking our training. Or could we revert first?

@ngimel
Collaborator

ngimel commented Jul 1, 2019

Yes, it should be reverted; #22016 has other problems that need to be fixed first.

@chenyangyu1988
Contributor Author

@mrshenli so let's revert it for now?

…ag CUDA Kernel" (pytorch#22377)

Summary:
Pull Request resolved: pytorch#22377

Original commit changeset: 398d5f48826a D15944339

f123389994, P68573646
During FP16 training, we found that char_embeddings.weight gets NaN or INF grads

Differential Revision: D16067446

fbshipit-source-id: 9ad54f67f73576a2c663cde9c9ee00a9aa669879
@madsbk
Contributor

madsbk commented Jul 1, 2019

@chenyangyu1988, can I get you to test #22401?

@chenyangyu1988
Contributor Author

chenyangyu1988 commented Jul 1, 2019

@madsbk There are 3 commits in #22401; I think I just need to apply the latest one?

@madsbk
Contributor

madsbk commented Jul 1, 2019

> @madsbk There are 3 commits in #22401, should I patch all of them?

Yes

facebook-github-bot pushed a commit that referenced this pull request Jul 2, 2019
Summary:
Address the issue raised in #22377.

The PR #22016 introduces a temporary tensor of weights `grad_weight_per_segment` of the same dtype as the end result, which can be a problem when using `float16`.
In this PR, it now uses a `float32` temporary tensor when the input is `float16`.

ngimel, can I get you to review? I think I have fixed the issues you have pointed out.
Pull Request resolved: #22401

Differential Revision: D16077319

Pulled By: mrshenli

fbshipit-source-id: 7cfad7f40b4d41a244052baa2982ab51bbbd7309
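
(Paraphrasing the change above in plain PyTorch for illustration; this is only a sketch of the idea, not the actual CUDA kernel: accumulate the per-segment partial gradients in float32 and cast back to the input dtype once at the end. The function and argument names below are hypothetical.)

```python
import torch

def segment_sum(grad_output, segment_ids, num_segments):
    # grad_output may be float16; accumulate in float32 to avoid overflow
    # and precision loss, then cast back to the original dtype at the end.
    acc_dtype = torch.float32 if grad_output.dtype == torch.float16 else grad_output.dtype
    grad_weight_per_segment = torch.zeros(
        num_segments, grad_output.size(-1),
        dtype=acc_dtype, device=grad_output.device,
    )
    # segment_ids is assumed to be a LongTensor mapping rows to segments.
    grad_weight_per_segment.index_add_(0, segment_ids, grad_output.to(acc_dtype))
    return grad_weight_per_segment.to(grad_output.dtype)
```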
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 2, 2019
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
@facebook-github-bot
Contributor

Hi @chenyangyu1988!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@jbschlosser
Contributor

Closing as out-of-date; feel free to reopen if still relevant.
