
Conversation

@zuoxingdong (Contributor) commented Apr 26, 2019

Fix #18254: numerical instability of SigmoidTransform

@pytorchbot added the module: distributions (Related to torch.distributions) label on Apr 26, 2019
@zou3519 (Contributor) commented Apr 26, 2019

Needs a test

@zou3519 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label on Apr 26, 2019
@zou3519 (Contributor) commented Apr 26, 2019

cc @fritzo @neerajprad

@neerajprad (Contributor) left a comment:

We should also add a test to test_distributions.py::TestTransforms to verify this behavior around the boundary.
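(For reference, a minimal sketch of what such a boundary test might look like; the test name and exact assertions here are hypothetical and not the test eventually added to test_distributions.py:)

    import torch
    from torch.distributions.transforms import SigmoidTransform

    def test_sigmoid_transform_boundary():
        # Large-magnitude inputs push sigmoid(x) to the 0/1 boundary in float32.
        t = SigmoidTransform()
        x = torch.tensor([-100.0, -20.0, 0.0, 20.0, 100.0])
        y = t(x)
        # With a stable formulation, the log-det-Jacobian should stay finite
        # even where sigmoid(x) has saturated to exactly 0 or 1.
        ldj = t.log_abs_det_jacobian(x, y)
        assert torch.isfinite(ldj).all()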

    def log_abs_det_jacobian(self, x, y):
-       return -(y.reciprocal() + (1 - y).reciprocal()).log()
+       return -torch.log(1e-6 + y.reciprocal() + (1 - y).reciprocal())
Contributor:

This will not address the case when y=1., where we will still see infinity. How about we clamp this using

clamped = torch.clamp(y.reciprocal() + (1 - y).reciprocal(), max=torch.finfo(y.dtype).max)
return -clamped.log()

@zuoxingdong (Contributor, Author) commented Apr 27, 2019:

@neerajprad It looks good! I've tried testing it, and it seems that when the std of the Normal distribution is large, there is still a possibility of getting inf, e.g.

d = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(100.0)), [StableSigmoidTransform()])
x = d.sample()
print(x)
d.log_prob(x)
>>> tensor(1.)
>>> tensor(-inf)

@neerajprad (Contributor) commented Apr 27, 2019:

Yes, that would just fix the jacobian term, but in TransformedDistribution.log_prob we still evaluate base_dist.log_prob on the inverse-transformed sample, which is inf when the sample is exactly 1. I think what you really want is something like a ClippedSigmoidTransform which differs from SigmoidTransform in that it clips the output of _call to lie in (0, 1):

    def _call(self, x):
        finfo = torch.finfo(x.dtype)
        clamped = torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
        return clamped

Then you shouldn't see any infs in your example above.
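(For concreteness, a rough sketch of such a subclass — hypothetical, not part of this PR; it only overrides _call to clamp the forward output:)

    import torch
    from torch.distributions import Normal, TransformedDistribution
    from torch.distributions.transforms import SigmoidTransform

    class ClippedSigmoidTransform(SigmoidTransform):
        """SigmoidTransform whose forward output is clamped strictly inside (0, 1)."""

        def _call(self, x):
            finfo = torch.finfo(x.dtype)
            return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)

    # With the clamp, the sample can no longer be exactly 0 or 1, so neither the
    # jacobian term nor base_dist.log_prob of the inverted sample blows up.
    d = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(100.0)),
                                [ClippedSigmoidTransform()])
    x = d.sample()
    print(x, d.log_prob(x))  # both should be finite now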

Contributor:

@alicanb, @fritzo - Do you think this is something we should do by default (tensorflow doesn't seem to do this either), or have a separate transform class ClippedSigmoidTransform?

@zuoxingdong (Contributor, Author) commented Apr 27, 2019:

@neerajprad @alicanb @fritzo Thanks a lot for the proposal. I think this is rather important for applications where we evaluate log_prob of a sample drawn from a transformed distribution that is conditioned on the input of a neural network. Without such a fix, it is easy to get NaN/Inf. An example use case is a tanh-transformed policy network in reinforcement learning, e.g. the SAC algorithm.

@zuoxingdong (Contributor, Author):

Or this could be provided via an optional flag in SigmoidTransform, without creating a new class.

Contributor:

I think that’s a good idea - we can have an optional keyword arg, clip=False.

Collaborator:

Sorry, I'm late to the party. @neerajprad, your suggestion sounds great. Do you think users would want to change the clip limits (@zuoxingdong, have you ever had the need)? If so, instead of a bool arg, we could keep the clamping symmetric (min=lim, max=1-lim) and have lim as the argument. If not, then a bool arg sounds great.

    def log_abs_det_jacobian(self, x, y):
-       return -(y.reciprocal() + (1 - y).reciprocal()).log()
+       return -F.softplus(-x) - F.softplus(x)
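(For reference: the two expressions agree mathematically. With y = sigmoid(x), 1/y + 1/(1-y) = 1/(y(1-y)), so the old return value equals log(y(1-y)) = log sigmoid(x) + log sigmoid(-x) = -softplus(-x) - softplus(x). The softplus form works directly on x and never materializes y or 1-y, so it stays finite when sigmoid saturates in floating point. A quick sanity check:)

    import torch
    import torch.nn.functional as F

    x = torch.tensor([-100., 0., 100.])
    y = torch.sigmoid(x)

    old = -(y.reciprocal() + (1 - y).reciprocal()).log()  # -inf at the saturated endpoints
    new = -F.softplus(-x) - F.softplus(x)                 # finite everywhere

    print(old)  # roughly tensor([-inf, -1.3863, -inf])
    print(new)  # roughly tensor([-100.0000, -1.3863, -100.0000])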
Contributor:

Nice! I think this is the best way to handle the numerical stability issue.

Contributor:

This change looks great, but just to confirm, this still doesn't handle the TransformedDistribution case you mentioned?

Contributor:

This is so nice! A great way to address the overflow/underflow while computing y.

@neerajprad (Contributor) commented:

@zuoxingdong - @fehiepsi has a PR out, #20288, which among other things also tackles the issue that you are addressing in this PR. We can get this PR merged first, or otherwise move the discussion to #20288. Feel free to review that PR too.

@fehiepsi (Contributor) commented:

This PR looks great to me! The improved precision of log_abs_det_jacobian is really helpful in preventing the unconstrained parameter x from drifting into very large or very small regions during inference (in those regions, a large-magnitude logdet implies a smaller log-likelihood). Otherwise, the algorithm has to rely on other factors to recognize that moving x into such regions is bad.

@zuoxingdong (Contributor, Author) commented:

Hi @neerajprad @fehiepsi, thank you for your feedback. #20288 looks nice to me. In that case, I propose that this PR contain only the log_abs_det_jacobian numerical-stability fix so it can be merged, and we can discuss the other issues together in #20288. What do you think?

@zuoxingdong (Contributor, Author) commented:

I've removed the clipping part from this PR, as it is already included in #20288; this PR now only contains the fix for log_abs_det_jacobian.

@neerajprad (Contributor) left a comment:

LGTM. Let's discuss the other changes in #20288.

@neerajprad (Contributor) commented:

The build failures are unrelated.

@pytorchbot merge this please

@pytorchbot added the merge-this-please label on May 20, 2019
@fehiepsi (Contributor) commented Jun 4, 2019:

@pytorchbot retest this please

@zou3519 (Contributor) commented Jun 4, 2019:

I think this might need a rebase so that the CI can run

@ezyang (Contributor) commented Jun 6, 2019:

@pytorchbot rebase this please

@facebook-github-bot (Contributor) left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@ezyang merged this pull request in c5d5d45.


Labels: merge-this-please, module: distributions, open source, triaged

Successfully merging this pull request may close this issue: RelaxedBernoulli produces samples on the boundary with NaN log_prob

8 participants