Closed
Issue description
With the scalar (0-dimensional) Tensor support introduced in PyTorch 0.4, torch.distributions.MultivariateNormal crashes when loc (the mean of the distribution) is a scalar, even though such an input is currently accepted. A scalar loc neither triggers a ValueError in torch.distributions.MultivariateNormal.__init__ nor is rejected by the real_vector constraint on the loc argument.
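One way the constructor could fail fast instead of crashing is an explicit shape check that runs before any linear algebra. The sketch below is hypothetical: check_mvn_shapes is not part of torch.distributions, and it works on plain shape tuples (standing in for Tensor.shape) so that it runs without PyTorch installed.

```python
def check_mvn_shapes(loc_shape, cov_shape):
    """Hypothetical pre-check for MultivariateNormal arguments.

    loc_shape and cov_shape are plain tuples standing in for Tensor.shape.
    Raises ValueError with a readable message instead of letting a
    malformed input reach the numerical code.
    """
    # A 0-dim (scalar) loc has no event dimension to match the covariance against.
    if len(loc_shape) < 1:
        raise ValueError("loc must be at least one-dimensional, got a scalar (0-dim) tensor")
    if len(cov_shape) < 2:
        raise ValueError("covariance_matrix must be at least two-dimensional")
    # The trailing covariance dims must be (n, n) where n is loc's last dim.
    n = loc_shape[-1]
    if tuple(cov_shape[-2:]) != (n, n):
        raise ValueError(
            "covariance_matrix must have trailing shape ({0}, {0}), got {1}".format(
                n, tuple(cov_shape[-2:])
            )
        )


if __name__ == "__main__":
    # Mirrors the report: scalar loc vs. one-element loc, both with a 1x1 covariance.
    for loc_shape in [(), (1,)]:
        try:
            check_mvn_shapes(loc_shape, (1, 1))
            print(loc_shape, "accepted")
        except ValueError as err:
            print(loc_shape, "rejected:", err)
```

Checking shapes up front turns the scalar-loc case into an ordinary Python exception the caller can catch, rather than a process-level SIGFPE.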
Below is a minimal script that reproduces the uninformative SIGFPE crash.
Code example
#!/usr/bin/env python
"""
Script to test/reproduce crashes with SIGFPE due to unhandled cases (scalar loc) in distributions.MultivariateNormal
"""
import torch


def test_univariate_scalar_input(loc=0.5, variance=0.1):
    mu = torch.tensor(loc)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma)
    sample = distribution.sample()
    print(sample)


def test_univariate_scalar_input_with_args_validation(loc=0.5, variance=0.1):
    mu = torch.tensor(loc)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma, validate_args=True)
    sample = distribution.sample()
    print(sample)


def test_univariate_input(loc=[0.5], variance=0.1):
    mu = torch.tensor(loc)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma)
    sample = distribution.sample()
    print(sample)


def test_univariate_input_with_args_validation(loc=[0.5], variance=0.1):
    mu = torch.tensor(loc)
    sigma = torch.tensor(variance)
    distribution = torch.distributions.MultivariateNormal(mu, torch.eye(1) * sigma, validate_args=True)
    sample = distribution.sample()
    print(sample)


if __name__ == "__main__":
    test_univariate_scalar_input(loc=0.5, variance=0.1)  # Crashes with Floating point exception (SIGFPE)
    # test_univariate_scalar_input_with_args_validation(loc=0.5, variance=0.1)  # Crashes with Floating point exception (SIGFPE)
    # test_univariate_input(loc=[0.5], variance=0.1)  # Runs without errors. Haven't verified if samples are from the correct normal distribution
    # test_univariate_input_with_args_validation(loc=[0.5], variance=0.1)  # Runs without errors. Haven't verified if samples are from the correct normal distribution

I would be happy to submit a PR if you think this needs a fix.
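The one-element loc cases run, but the comments note the samples were never checked. A quick sanity check is to draw many samples and compare the empirical mean and variance with loc and variance. The sketch below is an assumption-laden illustration, not part of the report: sample_moments and looks_like_normal are hypothetical helpers, and random.gauss stands in for repeated distribution.sample() calls so the block runs without PyTorch. Note that the covariance in the script is torch.eye(1) * variance, so the expected standard deviation is sqrt(variance), not variance itself.

```python
import math
import random


def sample_moments(samples):
    """Return the sample mean and unbiased sample variance of a list of floats."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / (n - 1)
    return mean, var


def looks_like_normal(samples, loc, variance, tol=0.02):
    """Loose check that samples have roughly the requested mean and variance.

    'variance' is the diagonal entry of the 1x1 covariance matrix,
    so it is compared against the sample variance, not the std.
    """
    mean, var = sample_moments(samples)
    return abs(mean - loc) < tol and abs(var - variance) < tol


if __name__ == "__main__":
    rng = random.Random(0)
    # Stand-in for distribution.sample(): draws from N(mean=0.5, variance=0.1).
    draws = [rng.gauss(0.5, math.sqrt(0.1)) for _ in range(20000)]
    print(sample_moments(draws))
    print(looks_like_normal(draws, loc=0.5, variance=0.1))
```

With PyTorch available, the draws could instead come from something like distribution.sample((20000,)).view(-1).tolist(). The tolerance is deliberately loose: it will not detect subtle errors, but it does catch a standard-deviation/variance mix-up (which would give a sample variance near 0.01 instead of 0.1).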
System Info
- PyTorch or Caffe2: PyTorch
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): NA
- OS: Ubuntu 16.04
- PyTorch version: 0.4.0
- Python version: 3.5.5
- CUDA/cuDNN version: NA
- GPU models and configuration: NA
- GCC version (if compiling from source): NA
- CMake version: NA
- Versions of any other relevant libraries: NA