
Conversation

@fehiepsi (Contributor) commented May 8, 2019

This PR addresses some numerical issues in SigmoidTransform and StickBreakingTransform, where these transforms give ±inf when the unconstrained values move into the ±20 range.

For example, with

t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))

the current behaviour is that the inverse returns inf and the logdet returns -inf, while with this PR they become 15.9424 and -15.9424.

And for

t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))

the current inverse is (inf, nan) and the logdet is -inf, while with this PR the inverse becomes [16.6355, 71.3942] and the logdet -47.8272.

Although these finite values are wrong and the error seems unavoidable, in my opinion returning them is still better than returning inf or nan. This is useful in HMC: even though the gradient is zero when the unconstrained parameter moves into the unstable region (due to clipping), the velocity variable keeps pushing the parameter around and can, by chance, move it back out of the unstable region. On the other hand, inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate.
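
To make the failure mode concrete, here is a minimal sketch of what happens at x = 20 in float32 (the exact clamping constants below are illustrative assumptions, not necessarily what this PR uses):

import torch

x = torch.tensor(20.)
y = torch.sigmoid(x)                  # rounds to exactly 1.0 in float32
print(torch.log(y / (1 - y)))         # inf: the logit of 1.0 blows up

# Clamping y away from the boundaries keeps the inverse finite:
eps = torch.finfo(y.dtype).eps
y_clamped = y.clamp(min=eps, max=1 - eps)
print(torch.log(y_clamped / (1 - y_clamped)))   # ~15.9424, the value quoted above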

I also fix some small issues in the _Simplex and _RealVector constraints, where the batch shape of the input was not respected during validation checks.
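
For illustration, a sketch of the batch-shape point (an assumed shape of the fix, not the exact PR code): the check should reduce only over the event dimension, so that a batch of simplices yields one boolean per batch element rather than a single scalar.

import torch

def simplex_check(value, tol=1e-6):
    # reduce only over the last (event) dimension, preserving the batch shape
    nonneg = (value >= 0).all(dim=-1)
    sums_to_one = (value.sum(dim=-1) - 1).abs() < tol
    return nonneg & sums_to_one

print(simplex_check(torch.tensor([[0.3, 0.7], [0.5, 0.6]])))  # tensor([ True, False])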

@pytorchbot pytorchbot added the module: distributions Related to torch.distributions label May 8, 2019
@ezyang ezyang requested review from alicanb and fritzo May 17, 2019 15:27
@ezyang (Contributor) commented May 17, 2019

Once you get a review from the other distributions maintainers, use pytorchbot to request a merge.

@fehiepsi (Contributor Author)

Oops, I forgot to cc reviewers. @neerajprad could you please help review this?

@neerajprad (Contributor) left a comment

The math looks fine to me. Since other users have also brought up numerical issues with transforms like SigmoidTransform, this seems like a reasonable change to me. I think we should definitely not return NaNs from log_abs_det_jacobian, but I would like to get other opinions on whether clamping so that transforms never return inf/-inf might have unintended effects.

        x = y
        x = y_tmp
    part = self.parts[-1]
    result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
Contributor

Is this for numerical stability reasons?

Contributor Author

This is mainly to save the last transform computation: we don't need to recompute y. It can also help numerical stability a bit, by using the exact y instead of an approximate y obtained by passing x back through the transforms.
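
A simplified sketch of the idea (the event-dim bookkeeping via _sum_rightmost is omitted here): accumulate the log det part by part, but feed the caller-supplied y to the last part instead of re-deriving it through the whole chain.

def compose_log_abs_det_jacobian(parts, x, y):
    result = 0.
    for part in parts[:-1]:
        y_tmp = part(x)
        result = result + part.log_abs_det_jacobian(x, y_tmp)
        x = y_tmp
    # the final part reuses the exact y passed in, so it is never recomputed
    return result + parts[-1].log_abs_det_jacobian(x, y)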


def log_abs_det_jacobian(self, x, y):
    return -(y.reciprocal() + (1 - y).reciprocal()).log()
    return (y * (1 - y)).log()
Contributor

I think another approach for y close to 0 or 1 would be the one taken in #19802 of operating on x instead of y: -F.softplus(-x) - F.softplus(x). I'm not sure how relevant it is with clipping, though.
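
As a quick sanity check of that identity (a sketch, assuming y = sigmoid(x)): log(y * (1 - y)) = log sigmoid(x) + log sigmoid(-x) = -F.softplus(-x) - F.softplus(x), and the x-based form stays finite where the y-based form underflows.

import torch
import torch.nn.functional as F

x = torch.tensor([-3., 0., 3., 20.])
y = torch.sigmoid(x)
print((y * (1 - y)).log())               # -inf at x = 20, since y rounds to 1.0
print(-F.softplus(-x) - F.softplus(x))   # matches elsewhere and stays finite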

Contributor Author

Yup, when we use transforms it is better to compute the log det in terms of x (to avoid the loss of numerical precision incurred when computing y from x). Thanks for providing a reference! I'll see if it really helps and get back on this.

Contributor Author

Maybe we can discuss this in #19802 first and let that PR merge; then I'll make the corresponding changes in this PR.

Contributor

Sure. Now that I am looking at that PR again, I think we should probably just define a clamping function that clamps to (0, 1) and call it in both _call and _inverse, because users might be using the inverse transform with values that can be 0 or 1.
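
A sketch of what such a helper could look like (the names and exact bounds are assumptions, not necessarily the merged code):

import torch

def _clipped_sigmoid(x):
    # keep the output strictly inside (0, 1) so the inverse and log det stay finite
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)

def _clipped_logit(y):
    # clamp the input too, in case a user calls the inverse with y equal to 0 or 1
    finfo = torch.finfo(y.dtype)
    y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
    return y.log() - (-y).log1p()   # log(y) - log(1 - y)

_call would then go through _clipped_sigmoid and _inverse through _clipped_logit.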

Contributor Author

Yeah, agree.

def _call(self, x):
    offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
    z = torch.sigmoid(x - offset.log())
    offset = x.shape[-1] - torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
Contributor

I think this was done for compatibility with the JIT. See #11829. Unless that issue is resolved, we should revert this.

Contributor Author

Thanks! I'll revert this, but drop the unnecessary expand operator.


def _call(self, x):
    return torch.sigmoid(x)
    return _clipped_sigmoid(x)
Contributor

Just to confirm - this change is independent of changes to the StickBreakingTransform?

Contributor Author

Yes.

@fehiepsi (Contributor Author) commented Jun 7, 2019

@neerajprad #19802 is merged, so we can continue discussing the changes in this PR.

@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 7, 2019
z = torch.clamp(z, min=torch.finfo(x.dtype).tiny)
# use the identity 1 - z = z * exp(-x), we don't have to worry about
# the case z ~ 1
detJ = (-x + z.log() + y[..., :-1].log()).sum(-1)
Contributor

I think we can use this identity but replace the sigmoid and clamping by using F.logsigmoid instead:

detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
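
For reference, a quick check of the underlying identity (a sketch): with z = sigmoid(x), 1 - z = z * exp(-x), so log(1 - z) = F.logsigmoid(x) - x, which needs neither the clamp nor z itself.

import torch
import torch.nn.functional as F

x = torch.tensor([-20., 0., 20.])
z = torch.sigmoid(x)
print((1 - z).log())          # -inf at x = 20, because z rounds to 1.0
print(F.logsigmoid(x) - x)    # finite: roughly [0., -0.6931, -20.]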

Contributor Author

Thanks @neerajprad! I just addressed it. :)

@neerajprad (Contributor) left a comment

Looks great!

@neerajprad (Contributor)

@pytorchbot merge this please

@pytorchbot pytorchbot added the merge-this-please Was marked for merge with @pytorchbot merge this please label Jun 10, 2019
@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ezyang merged this pull request in 91ea2cd.

