Labels
high priority, module: distributions, small, triaged
Description
🐛 Bug
distributions.RelaxedBernoulli with a low temperature produces samples on the boundary (exactly 0.0 or 1.0), for which the distribution's log_prob returns NaN. (The same goes for distributions.RelaxedOneHotCategorical.)
To Reproduce
import torch
from torch.distributions import RelaxedBernoulli
torch.manual_seed(2)
# A low temperature concentrates mass near the {0, 1} boundary.
dist = RelaxedBernoulli(temperature=torch.tensor(0.05), logits=torch.tensor(-5.0))
sample = dist.sample()
print('Sample: {:.8E}'.format(sample))
print('Log prob: {:4.2f}'.format(dist.log_prob(sample)))
gives the following output:
Sample: 0.00000000E+00
Log prob: nan
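Where the NaN comes from: log_prob inverts the SigmoidTransform, and the inverse sigmoid (logit) of 0.0 is -inf; the base log-density term then diverges to -inf while the log-Jacobian correction diverges to +inf, and their sum is NaN. A minimal illustration of the arithmetic (not the exact internal code path):
import torch

x = torch.tensor(0.0)
y = torch.log(x) - torch.log1p(-x)  # inverse sigmoid: logit(0.0) = -inf
base_term = y                       # stands in for the base log-density, which diverges to -inf
jacobian_term = -y                  # stands in for the log-Jacobian correction, which diverges to +inf
print(base_term + jacobian_term)    # tensor(nan)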
Expected behavior
I would imagine that for most use cases, clamping the sample (after the SigmoidTransform) between eps and 1 - eps would be more beneficial than producing NaNs.
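A rough sketch of that suggestion at the call site, reusing dist and sample from the repro above (a real fix would live inside the distribution; eps here is the machine epsilon of the sample's dtype):
eps = torch.finfo(sample.dtype).eps
clamped = sample.clamp(min=eps, max=1 - eps)
print('Log prob: {:4.2f}'.format(dist.log_prob(clamped)))  # finite instead of nan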
Environment
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.3
GCC version: Could not collect
CMake version: version 3.12.3
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy==1.16.1
[pip] torch==1.0.0
[pip] torchfile==0.1.0
[pip] torchvision==0.2.1
[conda] torch 1.0.0 <pip>
[conda] torchfile 0.1.0 <pip>
[conda] torchvision 0.2.1 <pip>
My temporary fix
For the project I'm working on, I use a temporary fix: a ClampedRelaxedBernoulli class that overrides the sample and rsample methods inherited from TransformedDistribution, clamping their output in the same way as the clamp_probs function in torch.distributions.utils.
class ClampedRelaxedBernoulli(RelaxedBernoulli):
    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution and applies `transform()` for every transform in
        the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            # Keep samples strictly inside (0, 1) so log_prob stays finite.
            eps = torch.finfo(self.base_dist.logits.dtype).eps
            return x.clamp(min=eps, max=1 - eps)

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        # Keep samples strictly inside (0, 1) so log_prob stays finite.
        eps = torch.finfo(self.base_dist.logits.dtype).eps
        return x.clamp(min=eps, max=1 - eps)
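A quick check of the workaround, reusing the seed and parameters from the repro above (exact values may differ across PyTorch versions):
torch.manual_seed(2)
dist = ClampedRelaxedBernoulli(temperature=torch.tensor(0.05), logits=torch.tensor(-5.0))
sample = dist.sample()
print('Sample: {:.8E}'.format(sample))                    # clamped to float32 eps instead of 0.0
print('Log prob: {:4.2f}'.format(dist.log_prob(sample)))  # finite instead of nan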