
Crash with SIGFPE due to unhandled cases in distributions.MultivariateNormal #8508

@praveen-palanisamy


Issue description

With the scalar (0-dimensional) Tensor support added in PyTorch 0.4, torch.distributions.MultivariateNormal crashes if loc (the mean of the distribution) is a scalar, even though such an input is currently accepted. A scalar loc neither raises a ValueError in torch.distributions.MultivariateNormal.__init__ nor is caught by the real_vector constraint on the loc argument.
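The kind of up-front check this report asks for can be sketched without PyTorch. The following is a minimal sketch, not PyTorch's actual validation code: `check_mvn_shapes` is a hypothetical helper that works on plain shape tuples (for example `mu.shape` and `cov.shape`, since `torch.Size` is a tuple subclass) and raises a ValueError for a 0-dimensional loc instead of letting that shape reach lower-level code. It only checks the event size and ignores batch broadcasting.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_mvn_shapes(loc_shape: Shape, cov_shape: Shape) -> Shape:
    """Hypothetical shape check for multivariate normal parameters.

    Returns the event shape (k,) when loc has at least one dimension and the
    covariance's last two dimensions form a k x k matrix matching loc's last
    dimension. Raises ValueError otherwise.
    """
    # A scalar (0-dimensional) loc has an empty shape tuple.
    if len(loc_shape) < 1:
        raise ValueError("loc must be at least one-dimensional, "
                         "got a scalar (0-dimensional) shape")
    # The covariance must be square in its last two dimensions.
    if len(cov_shape) < 2 or cov_shape[-1] != cov_shape[-2]:
        raise ValueError("covariance_matrix must be square in its last two "
                         "dimensions, got shape {}".format(cov_shape))
    k = loc_shape[-1]
    if cov_shape[-1] != k:
        raise ValueError("loc has event size {} but covariance_matrix is "
                         "{} x {}".format(k, cov_shape[-2], cov_shape[-1]))
    return (k,)
```

With this check, the scalar case from this report, `check_mvn_shapes((), (1, 1))` (the shapes of `torch.tensor(0.5)` and `torch.eye(1)`), raises a ValueError, while the working 1-dimensional case `check_mvn_shapes((1,), (1, 1))` returns `(1,)`.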

Minimal test code to reproduce the SIGFPE crash, which gives no indication of its cause, is below.

Code example

```python
#!/usr/bin/env python
"""
Script to test/reproduce crashes with SIGFPE due to unhandled cases (scalar loc) in distributions.MultivariateNormal
"""
import torch

def test_univariate_scalar_input(loc=0.5, variance=0.1):
    mu = torch.tensor(loc)  # 0-dimensional (scalar) tensor
    sigma = torch.tensor(variance)
    # Covariance matrix is variance * I, a 1 x 1 matrix
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma)
    sample = distribution.sample()
    print(sample)

def test_univariate_scalar_input_with_args_validation(loc=0.5, variance=0.1):
    mu = torch.tensor(loc)  # 0-dimensional (scalar) tensor
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma, validate_args=True)
    sample = distribution.sample()
    print(sample)

def test_univariate_input(loc=[0.5], variance=0.1):
    mu = torch.tensor(loc)  # 1-dimensional tensor of shape (1,)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma)
    sample = distribution.sample()
    print(sample)

def test_univariate_input_with_args_validation(loc=[0.5], variance=0.1):
    mu = torch.tensor(loc)  # 1-dimensional tensor of shape (1,)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma, validate_args=True)
    sample = distribution.sample()
    print(sample)


if __name__ == "__main__":
    test_univariate_scalar_input(loc=0.5, variance=0.1)  # Crashes with Floating point exception (SIGFPE)
    # test_univariate_scalar_input_with_args_validation(loc=0.5, variance=0.1)  # Crashes with Floating point exception (SIGFPE)
    # test_univariate_input(loc=[0.5], variance=0.1)  # Runs without errors. Haven't verified that the samples come from the correct normal distribution
    # test_univariate_input_with_args_validation(loc=[0.5], variance=0.1)  # Runs without errors. Haven't verified that the samples come from the correct normal distribution
```
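The comments above note that the 1-dimensional cases have not been checked against the expected distribution. One rough way to do that is to compare the sample mean and variance with loc and variance. The following is a minimal sketch in pure Python, not a formal goodness-of-fit test: `matches_normal` is a hypothetical helper, and the default tolerances are assumptions sized for roughly 20,000 draws with variance 0.1. With PyTorch, the samples could be collected as, for example, `[distribution.sample().item() for _ in range(20000)]`; here `random.gauss` stands in for the distribution under test.

```python
import math
import random
from typing import Sequence


def matches_normal(samples: Sequence[float], mean: float, variance: float,
                   mean_tol: float = 0.02, var_tol: float = 0.01) -> bool:
    """Rough moment check: are the sample mean and variance near the targets?"""
    n = len(samples)
    sample_mean = sum(samples) / n
    # Unbiased sample variance (divide by n - 1).
    sample_var = sum((x - sample_mean) ** 2 for x in samples) / (n - 1)
    return (abs(sample_mean - mean) <= mean_tol
            and abs(sample_var - variance) <= var_tol)


if __name__ == "__main__":
    random.seed(0)
    # Stand-in draws; random.gauss takes a standard deviation, so pass
    # sqrt(variance) rather than the variance itself.
    draws = [random.gauss(0.5, math.sqrt(0.1)) for _ in range(20000)]
    print(matches_normal(draws, mean=0.5, variance=0.1))
```

A moment check like this only catches gross errors such as a wrong mean or a variance/standard-deviation mix-up; a formal test (for example Kolmogorov-Smirnov) would be needed to check the shape of the distribution.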

I will be happy to submit a PR if you think this needs a fix.

System Info

  • PyTorch or Caffe2: PyTorch
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): NA
  • OS: Ubuntu 16.04
  • PyTorch version: 0.4.0
  • Python version: 3.5.5
  • CUDA/cuDNN version: NA
  • GPU models and configuration: NA
  • GCC version (if compiling from source): NA
  • CMake version: NA
  • Versions of any other relevant libraries: NA
