Fix Binomial overflow when logits is large #20679
Conversation
Although this is stable, the implementation will give an incorrect grad when logits is 0.
It turns out that the issue lies in torch.clamp, whose grad at 0 is 1. If it returned 0.5, the grad would be correct. Related: #2195, where the same issue happens for binary cross entropy with logits. If we make the grad of torch.clamp at 0 equal to 0.5, the grad at 0 of the binomial log_prob and of binary cross entropy with logits (pre-#2195) will both be correct. @apaszke What do you think?
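For reference, here is a minimal check of that boundary behavior (the exact subgradient clamp uses at the boundary may differ across PyTorch versions, so treat the printed value as illustrative):

```python
import torch

# Gradient of clamp(min=0) at the boundary point x = 0.
x = torch.zeros((), requires_grad=True)
x.clamp(min=0).backward()
print(x.grad)  # 1.0 per the behavior described above; a subgradient of 0.5
               # would instead give the correct grad at 0 for these log_probs.
```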
Never mind, we have x = x.clamp(min=0) + x.clamp(max=0): the LHS has grad 1 at 0 while the RHS has grad 2. We can use this difference to add a dummy term (x - x.clamp(min=0) - x.clamp(max=0)) / 2 to x.clamp(min=0) to adjust its grad; see the sketch below.
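A sketch of that adjustment (assuming, as above, that clamp propagates a gradient of 1 at the boundary): the dummy term is identically zero in value, so it only shifts the gradient of x.clamp(min=0) at 0 from 1 to 0.5 and leaves everything else unchanged.

```python
import torch

x = torch.zeros((), requires_grad=True)

# The extra term is 0 everywhere in value, but its gradient at x = 0 is
# (1 - 1 - 1) / 2 = -0.5, so the total gradient at 0 becomes 1 - 0.5 = 0.5.
adjusted = x.clamp(min=0) + (x - x.clamp(min=0) - x.clamp(max=0)) / 2
adjusted.backward()
print(x.grad)  # 0.5 instead of 1.0
```

Away from 0 the extra term also has zero gradient (e.g. for x > 0 it is (1 - 1 - 0) / 2 = 0), so ordinary inputs are unaffected.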
This is ready for review. I think I have addressed the problems I could think of:
neerajprad
left a comment
Looks great!
@pytorchbot merge this please
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR fixes #17843. In addition (tested locally), it still maintains the continuity of log_prob that was addressed in #15962.
cc @neerajprad
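For context, overflow fixes of this kind usually replace a direct log(1 + exp(logits)) with an equivalent split that never exponentiates a large positive number. The sketch below shows that standard trick and is only illustrative; it is not necessarily the exact change made in this PR:

```python
import torch

def stable_log1p_exp(logits):
    # log(1 + exp(logits)) == max(logits, 0) + log1p(exp(-|logits|)),
    # which avoids overflow when logits is large and positive.
    return logits.clamp(min=0) + torch.log1p(torch.exp(-logits.abs()))

big = torch.tensor([100.0])
print(torch.log1p(torch.exp(big)))  # inf: exp(100) overflows in float32
print(stable_log1p_exp(big))        # 100.0
```

The clamp at 0 inside this expression is also where the boundary-gradient discussion above comes into play.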