Commit 91ea2cd

fehiepsi authored and facebook-github-bot committed
clip sigmoid to prevent transforms return inf/nan values (#20288)
Summary: This PR addresses some numerical issues in SigmoidTransform/StickBreakingTransform, where these transforms return ±inf when the unconstrained values move into the ±20 range. For example, with

```
t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```

the current behaviour is that the inverse returns `inf` and the logdet returns `-inf`, while with this PR they become `15.9424` and `-15.9424`. And for

```
t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```

the current values are `(inf, nan)` for the inverse and `-inf` for the logdet, while with this PR they become `[16.6355, 71.3942]` and `-47.8272`. Although these finite values are wrong, and that seems unavoidable, returning them is in my opinion better than returning `inf` or `nan`. This is useful in HMC: even though the gradient is zero when the unconstrained parameter moves into the unstable area (due to clipping), the velocity variable will still force the parameter to move, which by chance can carry it out of the unstable area. On the other hand, inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate.

I also fix some small issues in the `_Simplex` and `_RealVector` constraints, where the batch shape of the input was not respected when checking validity.

Pull Request resolved: #20288

Differential Revision: D15742047

Pulled By: ezyang

fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
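For reference, the fix boils down to the `_clipped_sigmoid` helper added in the transforms.py diff below. Here is a standalone sketch of the same idea (an editorial illustration, not part of the commit) showing why clamping the sigmoid output into the open interval (0, 1) keeps the logit finite:

```
import torch

def clipped_sigmoid(x):
    # Same idea as the _clipped_sigmoid helper this PR adds: clamp the
    # sigmoid output away from exactly 0 and 1 so that logit(y) stays finite.
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)

x = torch.tensor(20.)
y = torch.sigmoid(x)           # rounds to exactly 1.0 in float32
print(y.log() - (-y).log1p())  # logit(1.0) -> inf

y = clipped_sigmoid(x)         # 1 - eps instead of 1.0
print(y.log() - (-y).log1p())  # -> approx 15.9424, the finite value quoted above
```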
1 parent 4bdbd30 commit 91ea2cd

3 files changed: 32 additions & 18 deletions

test/test_distributions.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3469,7 +3469,7 @@ def test_entropy_exponential_family(self):
 
 
 class TestConstraints(TestCase):
-    def test_params_contains(self):
+    def test_params_constraints(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
@@ -3492,7 +3492,7 @@ def test_params_contains(self):
                         Dist.__name__, i + 1, len(params), name, value)
                     self.assertTrue(constraint.check(value).all(), msg=message)
 
-    def test_support_contains(self):
+    def test_support_constraints(self):
         for Dist, params in EXAMPLES:
             self.assertIsInstance(Dist.support, Constraint)
             for i, param in enumerate(params):
```

torch/distributions/constraints.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -251,7 +251,7 @@ class _Simplex(Constraint):
     Specifically: `x >= 0` and `x.sum(-1) == 1`.
     """
     def check(self, value):
-        return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
+        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
 
 
 class _LowerTriangular(Constraint):
@@ -295,7 +295,7 @@ class _RealVector(Constraint):
     but additionally reduces across the `event_shape` dimension.
     """
     def check(self, value):
-        return (value == value).all()  # False for NANs.
+        return torch.all(value == value, dim=-1)  # False for NANs.
 
 
 class _Cat(Constraint):
```

torch/distributions/transforms.py

Lines changed: 28 additions & 14 deletions
```diff
@@ -274,11 +274,14 @@ def log_abs_det_jacobian(self, x, y):
         if not self.parts:
             return torch.zeros_like(x)
         result = 0
-        for part in self.parts:
-            y = part(x)
-            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+        for part in self.parts[:-1]:
+            y_tmp = part(x)
+            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp),
                                              self.event_dim - part.event_dim)
-            x = y
+            x = y_tmp
+        part = self.parts[-1]
+        result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+                                         self.event_dim - part.event_dim)
         return result
 
     def __repr__(self):
@@ -341,6 +344,11 @@ def log_abs_det_jacobian(self, x, y):
         return (self.exponent * y / x).abs().log()
 
 
+def _clipped_sigmoid(x):
+    finfo = torch.finfo(x.dtype)
+    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
+
+
 class SigmoidTransform(Transform):
     r"""
     Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
@@ -354,9 +362,11 @@ def __eq__(self, other):
         return isinstance(other, SigmoidTransform)
 
     def _call(self, x):
-        return torch.sigmoid(x)
+        return _clipped_sigmoid(x)
 
     def _inverse(self, y):
+        finfo = torch.finfo(y.dtype)
+        y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
         return y.log() - (-y).log1p()
 
     def log_abs_det_jacobian(self, x, y):
@@ -495,23 +505,27 @@ def __eq__(self, other):
         return isinstance(other, StickBreakingTransform)
 
     def _call(self, x):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        z = _clipped_sigmoid(x - offset.log())
         z_cumprod = (1 - z).cumprod(-1)
         y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
         return y
 
     def _inverse(self, y):
-        shape = y.shape[:-1] + (y.shape[-1] - 1,)
-        offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
-        sf = (1 - y.cumsum(-1))[..., :-1]
-        x = y[..., :-1].log() - sf.log() + offset.log()
+        y_crop = y[..., :-1]
+        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
+        sf = 1 - y_crop.cumsum(-1)
+        # we clamp to make sure that sf is positive which sometimes does not
+        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
+        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
+        x = y_crop.log() - sf.log() + offset.log()
         return x
 
     def log_abs_det_jacobian(self, x, y):
-        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
-        z = torch.sigmoid(x - offset.log())
-        detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        x = x - offset.log()
+        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
+        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
         return detJ
```
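Regarding the identity comment in `log_abs_det_jacobian` above: from `1 - sigmoid(x) = exp(-x) * sigmoid(x)` it follows that `log(1 - sigmoid(x)) = -x + logsigmoid(x)`, which avoids computing `1 - sigmoid(x)` directly (that difference evaluates to exactly 0 in float32 once `x` grows past roughly 17). A quick numerical check (an editorial sketch, not part of the commit):

```
import torch
import torch.nn.functional as F

x = torch.tensor([0., 10., 30.])
naive = torch.log1p(-torch.sigmoid(x))  # sigmoid(30.) rounds to 1.0, so log1p(-1.) = -inf
stable = -x + F.logsigmoid(x)           # algebraically identical, stays finite
print(naive)   # approx [-0.6931, -10.0000, -inf]
print(stable)  # approx [-0.6931, -10.0000, -30.0000]
```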
