Closed
Labels
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Thanks for the amazing software. I stumbled across (what I believe is a) bug when multiplying a torch tensor with a numpy scalar.
Issue description
Multiplication of a torch tensor by a numpy scalar behaves inconsistently depending on the order of the operands and on the datatypes involved. Specifically, multiplying a torch.FloatTensor by np.float32 does not work at all, and multiplying a torch.FloatTensor by np.float64 only works when written as tensor * scalar when tensor.requires_grad=True.
Code example
Trial 1: right multiplication with np.float64: the only setting that works:
tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float64(2.0)
prod = tensor * scalar
print(prod, prod.requires_grad, prod.dtype)

The output is:
tensor([ 2., 2.]) True torch.float32
Trial 2: left multiplication with np.float64: does not work when tensor.requires_grad=True
tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float64(2.0)
prod = scalar * tensor
print(prod, prod.requires_grad, prod.dtype)

The error message is:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-26-58039673906b> in <module>()
1 tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
2 scalar = np.float64(2.0)
----> 3 prod = scalar * tensor
4 print(prod, prod.requires_grad, prod.dtype)
~/software/anaconda3/envs/pyt4/lib/python3.6/site-packages/torch/tensor.py in __array__(self, dtype)
374 def __array__(self, dtype=None):
375 if dtype is None:
--> 376 return self.cpu().numpy()
377 else:
378 return self.cpu().numpy().astype(dtype, copy=False)
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
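The traceback suggests that NumPy, not torch, is handling the left multiplication: np.float64.__mul__ coerces the unfamiliar operand to an ndarray through its __array__ hook, and it is that coercion which trips the requires_grad check. A minimal sketch of this dispatch, using a hypothetical stand-in class rather than a real tensor:

```python
import numpy as np

class FakeTensor:
    """Hypothetical stand-in for a torch tensor: numpy converts
    unfamiliar operands by calling their __array__ method."""
    def __array__(self, dtype=None, copy=None):
        # In the real issue, this is where torch raises the
        # RuntimeError for tensors that require grad.
        return np.array([1.0, 1.0])

# Left multiplication dispatches to np.float64.__mul__, which
# coerces FakeTensor through __array__ and returns an ndarray,
# not a FakeTensor.
prod = np.float64(2.0) * FakeTensor()
print(type(prod))  # <class 'numpy.ndarray'>
```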
Trial 3: right multiplication with np.float32: does not work:
tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float32(2.0)
prod = tensor * scalar
print(prod, prod.requires_grad, prod.dtype)

The error message is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-29-1eb8188d0e28> in <module>()
1 tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
2 scalar = np.float32(2.0)
----> 3 prod = tensor * scalar
4 print(prod, prod.requires_grad, prod.dtype)
TypeError: mul() received an invalid combination of arguments - got (numpy.float32), but expected one of:
* (Tensor other)
didn't match because some of the arguments have invalid types: (!numpy.float32!)
* (float other)
didn't match because some of the arguments have invalid types: (!numpy.float32!)
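A likely reason Trial 3 differs from Trial 1: np.float64 is a subclass of Python's built-in float, so it can satisfy the (float other) overload, whereas np.float32 is an unrelated type. This can be checked without torch:

```python
import numpy as np

# np.float64 subclasses Python's float, so APIs that accept a
# plain float (such as torch's (float other) overload) accept it.
print(isinstance(np.float64(2.0), float))  # True

# np.float32 does not, which matches the TypeError in Trial 3.
print(isinstance(np.float32(2.0), float))  # False
```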
Trial 4: left multiplication with np.float32: does not work when tensor.requires_grad=True
tensor = torch.ones(2, requires_grad=True, dtype=torch.float32)
scalar = np.float32(2.0)
prod = scalar * tensor
print(prod, prod.requires_grad, prod.dtype)

This fails with the same error message as Trial 2.
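A workaround (not a fix, and assuming you control where the scalar enters the computation) that sidesteps all four trials: convert the numpy scalar to a plain Python float first, e.g. with scalar.item() or float(scalar):

```python
import numpy as np

scalar = np.float32(2.0)

# .item() (equivalently float(scalar)) yields a built-in Python
# float, which torch's own overloads handle without involving
# numpy's __array__ coercion.
py_scalar = scalar.item()
print(type(py_scalar))  # <class 'float'>

# With a plain float, tensor * py_scalar and py_scalar * tensor
# should both take torch's own code path.
```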
System Info
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.1.85
OS: Ubuntu 17.10
GCC version: (Ubuntu 6.4.0-8ubuntu1) 6.4.0 20171010
CMake version: version 3.9.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 8.0.61
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
Nvidia driver version: 390.30
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy (1.14.3)
[pip] numpydoc (0.8.0)
[pip] torch (0.4.0)
[pip] torchvision (0.2.1)
[conda] cuda91 1.0 h4c16780_0 pytorch
[conda] pytorch 0.4.0 py36_cuda9.1.85_cudnn7.1.2_1 [cuda91] pytorch
[conda] torchvision 0.2.1 py36_1 pytorch