[distributions] clip sigmoid to prevent transforms return inf/nan values #20288
Conversation
Once you get review from other distributions maintainers, use pytorchbot to request a merge.
Oops, I forgot to cc reviewers. @neerajprad Could you please help me review this?
neerajprad
left a comment
The math looks fine to me. Since other users have also brought up numerical issues with transforms like SigmoidTransform, this seems like a reasonable change to me. I think that we should surely not return NaNs from log_abs_det_jacobian, but I would like to get other opinions on whether there might be unintended effects when clamping to not return inf/-inf from transforms.
-        x = y
+        x = y_tmp
+        part = self.parts[-1]
+        result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
Is this for numerical stability reasons?
This is mainly to save the last transform computation, since we don't need to recompute y there. It also helps numerical stability somewhat, because we use the exact y instead of a y approximated through the transforms.
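For context, here is a minimal sketch of the loop structure being discussed, paraphrased from the diff above; the _sum_rightmost/event-dim bookkeeping of the real ComposeTransform.log_abs_det_jacobian is omitted, so treat the names as illustrative:

```python
def compose_log_abs_det_jacobian(parts, x, y):
    """Sketch only: accumulate log|det J| over a chain of transforms."""
    result = 0.
    for part in parts[:-1]:
        y_tmp = part(x)
        result = result + part.log_abs_det_jacobian(x, y_tmp)
        x = y_tmp
    # The last part reuses the exact y supplied by the caller instead of
    # recomputing it, saving one transform evaluation and avoiding feeding
    # an approximated y into the final Jacobian term.
    return result + parts[-1].log_abs_det_jacobian(x, y)
```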
torch/distributions/transforms.py (Outdated)

     def log_abs_det_jacobian(self, x, y):
-        return -(y.reciprocal() + (1 - y).reciprocal()).log()
+        return (y * (1 - y)).log()
I think another approach for y closer to 0 or 1 would be the one taken in #19802 of operating on x instead of y: -F.softplus(-x) - F.softplus(x). I'm not sure how relevant it is with clipping though.
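For illustration (not code from this PR), the two formulations behave quite differently once the sigmoid saturates; the values below assume float32:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-20.0, 0.0, 20.0])
y = torch.sigmoid(x)

# Computed from y: 1 - y underflows to 0 for large x, giving -inf.
log_det_from_y = (y * (1 - y)).log()              # ~[-20.0, -1.3863, -inf]
# Computed from x, as in #19802: stays finite for all x.
log_det_from_x = -F.softplus(-x) - F.softplus(x)  # ~[-20.0, -1.3863, -20.0]
```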
Yup, when we use transforms, it is better to compute log_det in terms of x (to avoid the precision loss from computing y from x). Thanks for providing a reference! I'll see if it really helps and get back on this.
Maybe we can discuss this in #19802 and let it merge first; then I'll make the corresponding changes in this PR later.
Sure. Now that I am looking at that PR again, I think we should probably just define a clamping function to clamp to (0, 1) and call that in both _call and _inverse, because users might be using the inverse transform with values that can be 0 or 1.
Yeah, agree.
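A minimal sketch of such a helper, assuming clamping by the dtype's smallest representable values; the bounds actually used in the PR may differ:

```python
import torch

def _clipped_sigmoid(x):
    # Squash x into (0, 1) and keep the result strictly away from the
    # boundaries so that downstream logs/reciprocals and the inverse
    # (logit) stay finite.
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
```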
torch/distributions/transforms.py (Outdated)

     def _call(self, x):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
+        offset = x.shape[-1] - torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
         z = torch.sigmoid(x - offset.log())
I think this was done for compatibility with JIT; see #11829. Unless that issue is resolved, we should revert this.
Thanks! I'll revert this, while dropping the unnecessary expand operator.
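As a quick illustrative check (the exact reverted form is assumed here), a cumsum-based offset without the expand and the arange-based offset agree, both producing [n, n-1, ..., 1] along the last dimension:

```python
import torch

x = torch.randn(2, 5)
n = x.shape[-1]

offset_cumsum = (n + 1) - x.new_ones(n).cumsum(-1)                   # no expand needed
offset_arange = n - torch.arange(n, dtype=x.dtype, device=x.device)
assert torch.equal(offset_cumsum, offset_arange)                     # both are [5., 4., 3., 2., 1.]
```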
     def _call(self, x):
-        return torch.sigmoid(x)
+        return _clipped_sigmoid(x)
Just to confirm - this change is independent of changes to the StickBreakingTransform?
Yes.
@neerajprad #19802 is merged, so we can keep discussing the changes in this PR.
torch/distributions/transforms.py (Outdated)

         z = torch.clamp(z, min=torch.finfo(x.dtype).tiny)
         # use the identity 1 - z = z * exp(-x), we don't have to worry about
         # the case z ~ 1
         detJ = (-x + z.log() + y[..., :-1].log()).sum(-1)
I think we can use this identity but replace the sigmoid and clamping by using F.logsigmoid instead:
    detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
Thanks @neerajprad! I just addressed it. :)
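For reference, a small numerical check (illustrative only) of the identity behind the suggestion: 1 - sigmoid(x) = exp(-x) * sigmoid(x), so log(1 - z) can be evaluated as -x + F.logsigmoid(x) without ever forming 1 - z:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-30.0, 0.0, 30.0])
z = torch.sigmoid(x)

naive = (1 - z).log()          # -inf at x = 30, because 1 - z underflows to 0 in float32
stable = -x + F.logsigmoid(x)  # ~[0.0, -0.6931, -30.0], finite everywhere
```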
neerajprad
left a comment
Looks great!
@pytorchbot merge this please
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR addresses some numerical issues of Sigmoid/StickBreakingTransform, where these transforms give +/-inf when the unconstrained values move into the +/-20 range.

For example, in one case the current behaviour is that the inverse returns inf and the logdet returns -inf, while this PR makes them 15.9424 and -15.9424. In another case, the current values are (inf, nan) for the inverse and -inf for the logdet, while this PR makes them [16.6355, 71.3942] and -47.8272.

Although these finite values are wrong and seem unavoidable, in my opinion it is better than returning inf or nan. This is useful in HMC: even though the grad will be zero when the unconstrained parameter moves into an unstable area (due to clipping), the velocity variable will still push the parameter elsewhere, which by chance can move it out of the unstable area. On the other hand, inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate.

I also fix some small issues of the _Simplex and _RealVector constraints, where the batch shape of the input is not respected when checking validation.
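Since the concrete example inputs are not reproduced above, the following is a hedged reproduction sketch using an assumed stand-in input (x = 20) that triggers the same sigmoid saturation in float32:

```python
import torch
from torch.distributions.transforms import SigmoidTransform

t = SigmoidTransform()
x = torch.tensor(20.)
y = t(x)  # sigmoid(20) rounds to exactly 1.0 in float32

# Before this PR: the inverse returns inf and the logdet returns -inf.
# With the clipping introduced here, both stay finite (roughly +/-15.94,
# matching the numbers quoted above).
print(t.inv(y))
print(t.log_abs_det_jacobian(x, y))
```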