Issue description
Since v0.4, loss functions return a scalar (0-dim tensor) loss, and gathering these scalar losses manually with nn.parallel.gather raises an error, as shown in the example below.
Unsqueezing the scalar losses back to 1-dim vectors, as in previous versions, works. Is this the intended behavior of nn.parallel.gather?
This multi-GPU code pattern is used in the Annotated Transformer implementation.
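The root cause can be reproduced without GPUs. The sketch below is a minimal CPU-only illustration that assumes a recent PyTorch release, and its tensor shapes are illustrative. It shows three things: a reduced MSELoss is 0-dim; `size(0)`, which `Gather.forward` calls on every input when gathering along dim 0, fails on that loss; and `unsqueeze(0)` restores a dimension 0.

```python
import torch
import torch.nn as nn

# Single-device stand-in for one replica's loss computation (CPU only).
criterion = nn.MSELoss()
output = torch.randn(1, 1)
target = torch.randn(1, 1)
loss = criterion(output, target)

# Since v0.4 the reduced loss is a 0-dim tensor.
print(loss.dim())  # 0

# Gather.forward records i.size(dim) for each input; with dim=0 this fails
# on a 0-dim tensor because it has no dimension 0.
try:
    loss.size(0)
    size_failed = False
except (IndexError, RuntimeError):  # exception type depends on the PyTorch version
    size_failed = True

# Unsqueezing gives a 1-dim vector of length 1, which does have a dimension 0.
loss_vec = loss.unsqueeze(0)
print(loss_vec.shape)  # torch.Size([1])
```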
Code example
import torch.nn as nn
import torch

# GPUs to use
devices = [0, 1, 2, 3]

# toy feed-forward net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# define random data
random_input = torch.randn((4, 10))
random_target = torch.randn((4))
net = Net().cuda()

# replicate nets, scatter inputs and parallel_apply
replicas = nn.parallel.replicate(net, devices)
random_input_scatter = nn.parallel.scatter(random_input, devices)
replicas = replicas[:len(random_input_scatter)]
outputs = nn.parallel.parallel_apply(replicas, random_input_scatter)

# replicate losses, scatter targets, zip output-target pairs
criterion = nn.MSELoss()
criterion = nn.parallel.replicate(criterion, devices)
random_target_scatter = nn.parallel.scatter(random_target, devices)
output_target_pairs = [(output, target) for output, target in zip(outputs, random_target_scatter)]

# this results in scalar (0-dim tensor) losses with v0.4
loss = nn.parallel.parallel_apply(criterion, output_target_pairs)

# gathering 0-dim tensors raises an error at line 54 of torch/nn/parallel/_functions.py (Gather.forward):
# ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
loss_gather = nn.parallel.gather(loss, target_device=devices[0])

# unsqueezing the scalar losses to 1-dim vectors (as in previous versions) works as intended
"""
for idx in range(len(loss)):
    loss[idx] = loss[idx].unsqueeze(0)
loss_gather = nn.parallel.gather(loss, target_device=devices[0])
"""
Traceback (most recent call last):
  File "gather_bug.py", line 43, in <module>
    loss_gather = nn.parallel.gather(loss, target_device=devices[0])
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    return gather_map(outputs)
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in <lambda>
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
RuntimeError: dimension specified as 0 but tensor has no dimensions
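The unsqueeze workaround in the code example can be wrapped so that every per-replica loss list is normalized before gathering. The sketch below makes two assumptions. First, `gather_scalar_losses` is a hypothetical helper name and is not part of PyTorch. Second, the CPU branch exists only so the sketch runs without GPUs, because `nn.parallel.gather` is meant for CUDA tensors.

```python
import torch
import torch.nn as nn


def gather_scalar_losses(losses, target_device):
    """Hypothetical helper: gather per-replica losses into one 1-dim tensor.

    0-dim losses are unsqueezed to shape (1,) first, because gathering along
    dim 0 reads each input's size along that dimension.
    """
    vectors = [t.unsqueeze(0) if t.dim() == 0 else t for t in losses]
    if all(v.is_cuda for v in vectors):
        # Multi-GPU path: the same call as in the report, on 1-dim inputs.
        return nn.parallel.gather(vectors, target_device=target_device)
    # CPU fallback (assumption, only so the sketch runs without GPUs):
    # concatenate along dim 0, which is what gather does for 1-dim inputs.
    return torch.cat([v.to(target_device) for v in vectors], dim=0)


# Usage on CPU; on GPUs the call would be gather_scalar_losses(loss, devices[0]).
losses = [torch.tensor(0.5, requires_grad=True), torch.tensor(1.5, requires_grad=True)]
gathered = gather_scalar_losses(losses, "cpu")
total = gathered.mean()
total.backward()
print(gathered.shape)  # torch.Size([2])
```

The result keeps one entry per replica, so the caller still chooses how to reduce it (here `mean()`), and gradients flow back to each replica's loss.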
System Info
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp
Nvidia driver version: 390.30
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_static.a
/usr/local/MATLAB/R2017b/bin/glnxa64/libcudnn.so.5.1.5
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6.0.21
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn_static.a
Versions of relevant libraries:
[pip] msgpack-numpy (0.4.1)
[pip] numpy (1.13.3)
[pip] torch (0.4.0)
[pip] torchfile (0.1.0)
[pip] torchnet (0.0.1)
[pip] torchtext (0.2.3)
[pip] torchvision (0.2.1)
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] pytorch 0.4.0 py36_cuda9.0.176_cudnn7.1.2_1 [cuda90] pytorch
[conda] torchfile 0.1.0
[conda] torchnet 0.0.1
[conda] torchtext 0.2.3
[conda] torchvision 0.2.1 py36_1 pytorch
- PyTorch or Caffe2: PyTorch
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- OS: Ubuntu 16.04
- PyTorch version: 0.4
- Python version: 3.6
- CUDA/cuDNN version: 9.0.176/7.0.5
- GPU models and configuration: 4x Titan Xp
- GCC version (if compiling from source): N/A
- CMake version: 3.5.1
- Versions of any other relevant libraries: N/A