Fix numerical stability in binomial.log_prob #15962
Conversation
torch/distributions/binomial.py
Outdated
In PyTorch, logits != probs.log(), so this formula seems wrong to me. Previously the tests passed, which confuses me.
@fehiepsi - That is correct - logits != probs.log() for the binary case.
Does this seem reasonable?
value * probs.log() + (total_count - value) * torch.log1p(-probs)
= value * (probs.log() - torch.log1p(-probs)) + total_count * torch.log1p(-probs)
= value * logits + total_count * (1 - probs).log()
= value * logits - total_count * torch.log1p(logits.exp())
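A minimal numerical check of this identity, assuming PyTorch's convention probs = sigmoid(logits) (the values below are illustrative, not from the PR):

```python
import torch

# Illustrative check of the derivation above; values are arbitrary.
# Assumes the PyTorch convention probs = sigmoid(logits).
probs = torch.tensor([0.3], dtype=torch.float64)
logits = probs.log() - torch.log1p(-probs)  # log-odds
total_count = torch.tensor([10.0], dtype=torch.float64)
value = torch.tensor([4.0], dtype=torch.float64)

lhs = value * probs.log() + (total_count - value) * torch.log1p(-probs)
rhs = value * logits - total_count * torch.log1p(logits.exp())
print(torch.allclose(lhs, rhs))  # True
```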
first equation: correct
second equation: correct
third equation: correct ~~incorrectly assumes probs=logits.exp()~~
Thanks! It seems correct to me now. @fritzo will be confused. :D
Just in case this causes confusion for others, expanding on the last step:
(1 - probs).log()
= (1 - logits.exp() / (1 + logits.exp())).log()
= (1 / (1 + logits.exp())).log()
= -(1 + logits.exp()).log()
= -log1p(logits.exp())
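The same step can be checked numerically, again assuming probs = sigmoid(logits) (a sketch with arbitrary values):

```python
import torch

# Illustrative check of the last step, assuming probs = sigmoid(logits).
logits = torch.tensor([-2.0, 0.5, 3.0], dtype=torch.float64)
probs = torch.sigmoid(logits)
print(torch.allclose((1 - probs).log(), -torch.log1p(logits.exp())))  # True
```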
facebook-github-bot left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This issue was discovered by @fehiepsi in pyro-ppl/pyro#1706 with the `log_prob` computation for Binomial, and can be seen with `torch.float32` when we have a combination of a low probability value and a high `total_count` - a test is added to capture this (since scipy only uses float64, the comparison is done using relative tolerance). The problem is in the code that tries to pull out the minimum values amongst the logits (written by me earlier, presumably to avoid numerical instability issues), but it is not needed.
EDIT: After a few attempts, I have been unable to reliably show that the change is more numerically stable, and have removed my previous test, which fails on linux. The reason is that the issue manifests itself when `total_count` is high and `probs` is very low. However, the precision of `lgamma` when `total_count` is high is bad enough to wash away any benefits. The justification for this still stands though - (a) it simplifies the code (removes the unnecessary bit), (b) it is no worse than the previous implementation, (c) it has better continuity behavior, as observed by @fehiepsi in the issue above.
cc. @fehiepsi, @alicanb, @fritzo
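A sketch of the regime described above - high `total_count` with very low `probs` under `torch.float32`, compared against a float64 evaluation as a reference (the specific numbers are hypothetical, not taken from the PR or its test):

```python
import torch
from torch.distributions import Binomial

# Hypothetical repro of the regime described above (high total_count, very
# low probs); the exact numbers are illustrative, not from the PR.
value = torch.tensor([2.0])
probs32 = torch.tensor([1e-6], dtype=torch.float32)

lp32 = Binomial(total_count=1_000_000, probs=probs32).log_prob(value)
lp64 = Binomial(total_count=1_000_000, probs=probs32.double()).log_prob(value.double())
# The float64 result serves as the reference; the (removed) test compared
# against scipy using relative tolerance, since scipy works in float64.
print(lp32.item(), lp64.item())
```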