RelaxedBernoulli produces samples on the boundary with NaN log_prob #18254

@mzperix

Description

🐛 Bug

With a low temperature, distributions.RelaxedBernoulli produces samples on the boundary (exactly 0.0 or 1.0) for which log_prob returns NaN. (The same goes for distributions.RelaxedOneHotCategorical.)

To Reproduce

import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(2)
dist = RelaxedBernoulli(temperature=torch.tensor(0.05), logits=torch.tensor(-5.0))
sample = dist.sample()
print('Sample: {:.8E}'.format(sample))
print('Log prob: {:4.2f}'.format(dist.log_prob(sample)))

gives the following output:

Sample: 0.00000000E+00
Log prob:  nan
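
The NaN seems to arise when the boundary sample is pushed back through the inverse sigmoid: logit(0.0) is -inf, and the log-density terms then mix infinities. A rough illustration of that failure mode (not the exact internal computation of log_prob):

import torch

y = torch.tensor(0.0)                     # boundary sample
logit_y = torch.log(y) - torch.log1p(-y)  # inverse sigmoid: -inf
print(logit_y)                            # tensor(-inf)
print(logit_y - logit_y)                  # tensor(nan), the inf - inf pattern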

Expected behavior

For most use cases, clamping the sample (after the SigmoidTransform) to [eps, 1 - eps] would be more useful than producing NaNs.
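
As a caller-side workaround (a sketch, not part of the report), the drawn sample can be clamped with torch.distributions.utils.clamp_probs before evaluating log_prob:

import torch
from torch.distributions import RelaxedBernoulli
from torch.distributions.utils import clamp_probs

torch.manual_seed(2)
dist = RelaxedBernoulli(temperature=torch.tensor(0.05), logits=torch.tensor(-5.0))
sample = clamp_probs(dist.sample())   # pushes 0.0 / 1.0 into [eps, 1 - eps]
print(dist.log_prob(sample))          # finite instead of nan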

Environment

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.3
GCC version: Could not collect
CMake version: version 3.12.3

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy==1.16.1
[pip] torch==1.0.0
[pip] torchfile==0.1.0
[pip] torchvision==0.2.1
[conda] torch                     1.0.0                     <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>

My temporary fix

For the project I'm working on, my temporary fix is a ClampedRelaxedBernoulli class that overrides the sample and rsample methods inherited from TransformedDistribution, clamping the result in the same way as the clamp_probs function in torch.distributions.utils.

import torch
from torch.distributions import RelaxedBernoulli


class ClampedRelaxedBernoulli(RelaxedBernoulli):

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution, applies `transform()` for every transform in the
        list, and clamps the result away from the boundary.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            # Clamp to [eps, 1 - eps], mirroring torch.distributions.utils.clamp_probs
            eps = torch.finfo(self.base_dist.logits.dtype).eps
            return x.clamp(min=eps, max=1 - eps)

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution, applies
        `transform()` for every transform in the list, and clamps the result
        away from the boundary.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        # Clamp to [eps, 1 - eps], mirroring torch.distributions.utils.clamp_probs
        eps = torch.finfo(self.base_dist.logits.dtype).eps
        return x.clamp(min=eps, max=1 - eps)
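
Usage mirrors RelaxedBernoulli; since the returned sample is guaranteed to lie in [eps, 1 - eps], log_prob stays finite:

torch.manual_seed(2)
dist = ClampedRelaxedBernoulli(temperature=torch.tensor(0.05), logits=torch.tensor(-5.0))
sample = dist.sample()  # clamped away from the boundary
print('Sample: {:.8E}'.format(sample))
print('Log prob: {:4.2f}'.format(dist.log_prob(sample)))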


Labels

high priority, module: distributions, small, triaged
