
Conversation

Contributor

@neerajprad neerajprad commented Jan 11, 2019

This issue was discovered by @fehiepsi in pyro-ppl/pyro#1706 in the log_prob computation for Binomial. It shows up with torch.float32 when a low probability value is combined with a high total_count; a test is added to capture this (since scipy only uses float64, the comparison is done with a relative tolerance).

The problem is in the code that tries to pull out the minimum values amongst the logits (written by me earlier, presumably to avoid numerical instability issues), which turns out to be unnecessary.

EDIT: After a few attempts, I have been unable to reliably show that the change is more numerically stable, and have removed my previous test, which fails on Linux. The issue manifests itself when total_count is high and probs is very low, but the precision of lgamma at high total_count is bad enough to wash away any benefit. The justification for the change still stands though: (a) it simplifies the code (removes the unnecessary bit), (b) it is no worse than the previous implementation, and (c) it has better continuity behavior, as observed by @fehiepsi in the issue above.

cc. @fehiepsi, @alicanb, @fritzo
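
To make the precision gap concrete, here is a minimal illustrative sketch (the numbers are made up for illustration and are not the values from the removed test):

import torch
from scipy import stats

# Illustrative values: a low probability combined with a high total_count.
total_count, prob = 8000, 1e-4
value = torch.arange(0., 11.)

d = torch.distributions.Binomial(total_count, probs=torch.tensor(prob, dtype=torch.float32))
logp32 = d.log_prob(value)

# scipy works in float64, so the comparison uses a relative tolerance.
logp64 = torch.tensor(stats.binom.logpmf(value.numpy(), total_count, prob))
print(((logp32.double() - logp64) / logp64).abs().max())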

Contributor

@fehiepsi fehiepsi Jan 11, 2019

In PyTorch, logits != probs.log(), so this formula seems wrong to me. The tests passed previously, which confuses me.
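
(A quick reference sketch, not part of the PR: for the binary parameterization, torch.distributions stores logits as the log-odds log(p / (1 - p)), not log(p).)

import torch
from torch.distributions import Binomial

probs = torch.tensor(0.3)
d = Binomial(total_count=10, probs=probs)
print(d.logits)                           # log(0.3 / 0.7), roughly -0.847
print(probs.log() - torch.log1p(-probs))  # same value
print(probs.log())                        # roughly -1.204, not the same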

Contributor Author

@neerajprad neerajprad Jan 11, 2019

@fehiepsi - That is correct - logits != probs.log() for the binary case.

Does this seem reasonable?

value * probs.log() + (total_count - value) * torch.log1p(-probs)
 = value * (probs.log() - torch.log1p(-probs)) + total_count * torch.log1p(-probs)
 = value * logits + total_count * (1 - probs).log()
 = value * logits - total_count * torch.log1p(logits.exp())
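
A quick numerical check of the rearrangement above (a sketch with arbitrary values, not the PR's test):

import torch

total_count = 20.
value = torch.arange(0., 21., dtype=torch.float64)
probs = torch.tensor(0.1, dtype=torch.float64)
logits = probs.log() - torch.log1p(-probs)

# Unnormalized log-prob in the probs form vs. the rearranged logits form.
lhs = value * probs.log() + (total_count - value) * torch.log1p(-probs)
rhs = value * logits - total_count * torch.log1p(logits.exp())
print(torch.allclose(lhs, rhs))  # True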

Collaborator

@fritzo fritzo Jan 11, 2019

first equation: correct
second equation: correct
third equation: incorrectly assumes probs = logits.exp()

Contributor

@fehiepsi fehiepsi Jan 11, 2019

Thanks! It seems correct to me now. @fritzo will be confused. :D

Contributor Author

Just in case this causes confusion for others, expanding on the last step:

(1 - probs).log()
 = (1 - logits.exp() / (1 + logits.exp())).log()
 = (1 / (1 + logits.exp())).log()
 = -(1 + logits.exp()).log()
 = -log1p(logits.exp())
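
(A one-line numerical sanity check of that identity, with probs = torch.sigmoid(logits); a sketch, not from the PR:)

import torch

logits = torch.tensor([-3.0, 0.0, 2.5], dtype=torch.float64)
probs = torch.sigmoid(logits)
print(torch.allclose((1 - probs).log(), -torch.log1p(logits.exp())))  # True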

@neerajprad
Contributor Author

torch.lgamma seems to have precision issues with FloatTensor on Linux (much more so than on OSX), which causes the test that I added to fail. I will modify the test accordingly.
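
(For context, a rough way to see the float32 lgamma error; illustrative only, not the removed test:)

import torch

n = torch.tensor(1e5)
# lgamma(1e5) is roughly 1.05e6; float32 carries only about 7 significant digits,
# so the absolute error can be comparable to the small log-prob differences under test.
print(torch.lgamma(n).double() - torch.lgamma(n.double()))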

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
