[Update transforms.py] Fix numerical instability of SigmoidTransform
#19802
Conversation
Needs a test
neerajprad
left a comment
We should also add a test to test_distributions.py::TestTransforms to verify this behavior around the boundary.
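For concreteness, a rough sketch of such a boundary test (the function name, the exact test layout, and the extreme inputs used here are assumptions, not the final test):

import torch
from torch.distributions.transforms import SigmoidTransform

def test_sigmoid_transform_boundary():
    # Inputs large enough that sigmoid(x) saturates to exactly 0. or 1. in float32.
    t = SigmoidTransform()
    x = torch.tensor([-1e6, -100.0, 0.0, 100.0, 1e6])
    y = t(x)
    # With the fix, the log-abs-det-Jacobian should stay finite at the boundary.
    assert torch.isfinite(t.log_abs_det_jacobian(x, y)).all()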
torch/distributions/transforms.py
Outdated
def log_abs_det_jacobian(self, x, y):
-    return -(y.reciprocal() + (1 - y).reciprocal()).log()
+    return -torch.log(1e-6 + y.reciprocal() + (1 - y).reciprocal())
This will not address the case when y=1., where we will still see infinity. How about we clamp this using
clamped = torch.clamp(y.reciprocal() + (1 - y).reciprocal(), max=torch.finfo(y.dtype).max)
return -clamped.log()
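For concreteness, a quick illustration of the boundary case under discussion (illustrative only; float32 assumed):

import torch

y = torch.tensor(1.0)  # sigmoid saturates to exactly 1.0 for large inputs
raw = y.reciprocal() + (1 - y).reciprocal()               # inf, so -raw.log() would be -inf
clamped = torch.clamp(raw, max=torch.finfo(y.dtype).max)  # cap at the largest finite float
print(-clamped.log())  # roughly -88.7 in float32 instead of -inf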
@neerajprad It looks good! I tried a test, and it seems that when the std of the Normal distribution is large, there is still a possibility of getting inf, e.g.
d = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(100.0)), [StableSigmoidTransform()])
x = d.sample()
print(x)
d.log_prob(x)
>>> tensor(1.)
>>> tensor(-inf)
Yes, that would just fix the Jacobian term, but in TransformedDistribution.log_prob we still evaluate base_dist.log_prob on the inverted sample, which in this case is inf (the logit of 1.). I think what you really want is something like a ClippedSigmoidTransform, which differs from SigmoidTransform in that it clips the output of _call to lie in (0, 1):
def _call(self, x):
    finfo = torch.finfo(x.dtype)
    clamped = torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
    return clamped

Then you shouldn't see any infs in your example above.
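A fuller sketch of that idea (the class name and the exact clipping bounds are assumptions, not a settled API):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import SigmoidTransform

# Hypothetical ClippedSigmoidTransform: identical to SigmoidTransform except that
# the forward output is clamped strictly inside (0, 1), so the inverse (logit)
# and base_dist.log_prob stay finite even when sigmoid saturates.
class ClippedSigmoidTransform(SigmoidTransform):
    def _call(self, x):
        finfo = torch.finfo(x.dtype)
        return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)

# The failing example above should then produce a finite log_prob:
d = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(100.0)),
                            [ClippedSigmoidTransform()])
x = d.sample()
print(d.log_prob(x))  # finite even when the sample is pushed to the boundary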
@neerajprad @alicanb @fritzo Thanks a lot for the proposal. I think this is rather important for applications where we evaluate log_prob of a sample drawn from a transformed distribution that is conditioned on the input of a neural network. Without such fixes, it is easy to end up with NaN/Inf. An example use case is a tanh-transformed policy network in reinforcement learning, e.g. the SAC algorithm.
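For reference, a minimal sketch of the kind of tanh-squashed Gaussian policy this matters for, expressed through SigmoidTransform via tanh(x) = 2*sigmoid(2x) - 1 (the fixed mean/std here stand in for policy-network outputs):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, SigmoidTransform

mean, std = torch.tensor(0.0), torch.tensor(5.0)  # normally produced by the policy network
policy = TransformedDistribution(
    Normal(mean, std),
    [AffineTransform(loc=0., scale=2.),    # x -> 2x
     SigmoidTransform(),                   # -> sigmoid(2x)
     AffineTransform(loc=-1., scale=2.)])  # -> 2*sigmoid(2x) - 1 = tanh(x)
action = policy.sample()
log_prob = policy.log_prob(action)  # this is where saturation can yield -inf/NaN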
Or this could be provided via an optional flag in SigmoidTransform, without creating a new class.
I think that’s a good idea - we can have an optional keyword arg, clip=False.
Sorry, I'm late to the party. @neerajprad your suggestion sounds great. Do you think users would want to change the clip limits (@zuoxingdong, have you ever had the need)? If so, instead of a bool arg, we could keep the clamping symmetric (min=lim, max=1-lim) and take lim as the argument. If not, then a bool arg sounds great.
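To illustrate that option, a hypothetical variant with a symmetric limit (the class name, the default value, and the placement of lim are assumptions, not an agreed API):

import torch
from torch.distributions.transforms import SigmoidTransform

# Hypothetical variant: the output is confined to [lim, 1 - lim]; lim=0 would
# recover the current (unclipped) behavior.
class ClippedSigmoid(SigmoidTransform):
    def __init__(self, lim=1e-6, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.lim = lim

    def _call(self, x):
        return torch.clamp(torch.sigmoid(x), min=self.lim, max=1. - self.lim)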
torch/distributions/transforms.py
def log_abs_det_jacobian(self, x, y):
-    return -(y.reciprocal() + (1 - y).reciprocal()).log()
+    return -F.softplus(-x) - F.softplus(x)
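For context on why this form is stable: with y = sigmoid(x), log|dy/dx| = log y + log(1 - y) = -softplus(-x) - softplus(x), which is computed from x directly and never saturates the way y and 1 - y do. A quick illustrative comparison:

import torch
import torch.nn.functional as F

x = torch.tensor([0.0, 20.0, 100.0])
y = torch.sigmoid(x)

old = -(y.reciprocal() + (1 - y).reciprocal()).log()  # -inf once y rounds to exactly 1.
new = -F.softplus(-x) - F.softplus(x)                 # stays finite: roughly [-1.39, -20., -100.]
print(old, new)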
Nice! I think this is the best way to handle the numerical stability issue.
This change looks great, but just to confirm, this still doesn't handle the TransformedDistribution case you mentioned?
This is so nice! A great way to address the overflow/underflow while computing y.
@zuoxingdong - @fehiepsi has a PR out, #20288, which among other things also tackles the issue you attempted in this PR. We can get this PR merged first, or otherwise move the discussion to #20288. Feel free to review that PR too.
This PR looks great to me! The precision of …
Hi @neerajprad @fehiepsi, thank you for your feedback. #20288 looks nice to me; in that case, I'm proposing that this PR only contain the log_abs_det_jacobian fix.
I've removed the clipping part from this PR as it is already included in #20288, so it now only contains the fix for log_abs_det_jacobian.
neerajprad
left a comment
LGTM. Let's discuss the other changes in #20288.
The build failures are unrelated. @pytorchbot merge this please
@pytorchbot retest this please
I think this might need a rebase so that the CI can run.
@pytorchbot rebase this please
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix #18254: numerical instability of SigmoidTransform