
[bug] Multiplication of tensor with numpy scalar does not always work #9468

@krishnap25

Description

Thanks for the amazing software. I stumbled across what I believe is a bug when multiplying a torch tensor by a NumPy scalar.

Issue description

Multiplying a torch tensor by a NumPy scalar behaves unexpectedly depending on the operand order and the scalar's dtype. In all trials below, tensor.requires_grad = True. Under that setting, multiplying a torch.FloatTensor by np.float32 fails in both orders, and multiplying a torch.FloatTensor by np.float64 works only when written as tensor * scalar.
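The order dependence follows from how Python dispatches binary operators: for `a * b`, Python calls `type(a).__mul__(a, b)` first and falls back to `type(b).__rmul__(b, a)` only if that returns `NotImplemented` (the one exception is when `type(b)` is a subclass of `type(a)` that overrides `__rmul__`). So `scalar * tensor` is handled by NumPy's scalar code first, while `tensor * scalar` goes straight to PyTorch. The minimal pure-Python sketch below uses hypothetical stand-in classes (`Left`, `Right`, `Deferring`), not the real NumPy or PyTorch types.

```python
class Left:
    """Hypothetical stand-in for the left operand (e.g. a NumPy scalar)."""
    def __mul__(self, other):
        return "Left.__mul__ handled it"


class Right:
    """Hypothetical stand-in for the right operand (e.g. a tensor)."""
    def __mul__(self, other):
        return "Right.__mul__ handled it"

    def __rmul__(self, other):
        return "Right.__rmul__ handled it"


class Deferring:
    """Hypothetical left operand that declines to handle the operation."""
    def __mul__(self, other):
        # NotImplemented tells Python to try the right operand's __rmul__.
        return NotImplemented


print(Left() * Right())       # Left.__mul__ handled it
print(Right() * Left())       # Right.__mul__ handled it
print(Deferring() * Right())  # Right.__rmul__ handled it
```

The right operand only gets a say when the left operand explicitly declines, which is why the same two values can take completely different code paths depending on their order.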

Code example

Trial 1: right multiplication with np.float64: the only setting that works:

import numpy as np
import torch

tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float64(2.0)
prod = tensor * scalar
print(prod, prod.requires_grad, prod.dtype)

The output is:

tensor([ 2.,  2.]) True torch.float32

Trial 2: left multiplication with np.float64: does not work when tensor.requires_grad=True

tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float64(2.0)
prod = scalar * tensor
print(prod, prod.requires_grad, prod.dtype)

The error message is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-26-58039673906b> in <module>()
      1 tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
      2 scalar = np.float64(2.0)
----> 3 prod = scalar * tensor
      4 print(prod, prod.requires_grad, prod.dtype)

~/software/anaconda3/envs/pyt4/lib/python3.6/site-packages/torch/tensor.py in __array__(self, dtype)
    374     def __array__(self, dtype=None):
    375         if dtype is None:
--> 376             return self.cpu().numpy()
    377         else:
    378             return self.cpu().numpy().astype(dtype, copy=False)

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
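The traceback shows NumPy calling `Tensor.__array__`: the NumPy scalar on the left tried to convert the tensor into an array instead of deferring to the tensor's `__rmul__`. The NumPy-only sketch below reproduces that pattern with a hypothetical `TensorLike` class, and shows that setting `__array_ufunc__ = None` (a documented NumPy opt-out) makes NumPy's binary operators return `NotImplemented`, so Python falls back to the right operand's `__rmul__`. Whether and how PyTorch adopted such a mechanism is not covered by this issue.

```python
import numpy as np


class TensorLike:
    """Hypothetical tensor-like class that exposes __array__."""

    def __init__(self, values):
        self.values = list(values)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this when it decides to treat the object as an array.
        # The copy argument is accepted for NumPy 2 compatibility and ignored.
        return np.asarray(self.values, dtype=dtype)

    def __mul__(self, other):
        # Elementwise multiply by a scalar, returning the same class.
        return type(self)(v * float(other) for v in self.values)

    __rmul__ = __mul__


class DeferringTensorLike(TensorLike):
    # __array_ufunc__ = None tells NumPy's binary operators to return
    # NotImplemented, so Python falls back to DeferringTensorLike.__rmul__.
    __array_ufunc__ = None


# NumPy handles this itself by converting the object via __array__.
left = np.float64(2.0) * TensorLike([1.0, 1.0])
print(type(left))  # <class 'numpy.ndarray'>

# NumPy defers, so the right operand's __rmul__ produces the result.
deferred = np.float64(2.0) * DeferringTensorLike([1.0, 1.0])
print(type(deferred).__name__, deferred.values)  # DeferringTensorLike [2.0, 2.0]
```

NumPy also consults the older `__array_priority__` attribute when making this decision, but `__array_ufunc__ = None` is the more explicit way to opt out.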

Trial 3: right multiplication with np.float32: does not work:

tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float32(2.0)
prod = tensor * scalar
print(prod, prod.requires_grad, prod.dtype)

The error message is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-29-1eb8188d0e28> in <module>()
      1 tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
      2 scalar = np.float32(2.0)
----> 3 prod = tensor * scalar
      4 print(prod, prod.requires_grad, prod.dtype)

TypeError: mul() received an invalid combination of arguments - got (numpy.float32), but expected one of:
 * (Tensor other)
      didn't match because some of the arguments have invalid types: (!numpy.float32!)
 * (float other)
      didn't match because some of the arguments have invalid types: (!numpy.float32!)
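The overloads listed in this error accept only a `Tensor` or a `float`. One plausible explanation (an inference from the message, not something confirmed in this issue) is that `np.float64` matched the `float` overload in Trial 1 because it subclasses Python's built-in `float`, whereas `np.float32` does not. The NumPy-only sketch below checks that subclass relationship and shows two standard ways to obtain a plain Python float from an `np.float32`.

```python
import numpy as np

# np.float64 subclasses Python's built-in float; np.float32 does not.
print(issubclass(np.float64, float))  # True
print(issubclass(np.float32, float))  # False

scalar = np.float32(2.0)
# Both conversions yield a plain Python float, which a "(float other)"
# overload can accept.
as_float = float(scalar)
as_item = scalar.item()
print(type(as_float).__name__, type(as_item).__name__)  # float float
```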

Trial 4: left multiplication with np.float32: does not work when tensor.requires_grad=True

tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float32(2.0)
prod = scalar * tensor
print(prod, prod.requires_grad, prod.dtype)

Same error message as Trial 2.
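A workaround consistent with the behavior reported above is to unwrap the NumPy scalar into a plain Python number before multiplying, so that neither NumPy's scalar code nor the `(float other)` overload check ever sees a NumPy type. The helper `as_python_scalar` below is hypothetical, and the sketch assumes a recent PyTorch and NumPy; it has not been checked against PyTorch 0.4.0.

```python
import numpy as np
import torch


def as_python_scalar(x):
    """Hypothetical helper: unwrap NumPy scalars into built-in Python numbers."""
    # np.generic is the base class of all NumPy scalar types; .item() returns
    # the equivalent built-in int/float/complex/bool.
    return x.item() if isinstance(x, np.generic) else x


tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
for scalar in (np.float64(2.0), np.float32(2.0)):
    s = as_python_scalar(scalar)
    left = s * tensor   # left multiplication, as in Trials 2 and 4
    right = tensor * s  # right multiplication, as in Trials 1 and 3
    print(left, right, left.requires_grad, left.dtype)
```

Because the scalar is a Python float in both orders, the result keeps the tensor's dtype and stays attached to the autograd graph.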

System Info

PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.1.85

OS: Ubuntu 17.10
GCC version: (Ubuntu 6.4.0-8ubuntu1) 6.4.0 20171010
CMake version: version 3.9.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 8.0.61
GPU models and configuration: 
GPU 0: TITAN Xp
GPU 1: TITAN Xp

Nvidia driver version: 390.30
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy (1.14.3)
[pip] numpydoc (0.8.0)
[pip] torch (0.4.0)
[pip] torchvision (0.2.1)
[conda] cuda91                    1.0                  h4c16780_0    pytorch
[conda] pytorch                   0.4.0           py36_cuda9.1.85_cudnn7.1.2_1  [cuda91]  pytorch
[conda] torchvision               0.2.1                    py36_1    pytorch

cc @mruberry @rgommers @heitorschueroff

Labels: module: numpy, triaged