🐛 Bug
Gradients are missing when torch.autograd.grad is called from inside a function on multiple GPUs (e.g. when computing the WGAN gradient penalty). Calling torch.autograd.grad inline (not wrapped in a function) on multiple GPUs produces the expected behavior.
To Reproduce
Code below:
import torch
import torch.nn as nn

torch.cuda.manual_seed_all(0)
torch.manual_seed(0)

def gradient_penalty(netD, x):
    """Functional gradient calculation"""
    output = netD(x)
    gradients = torch.autograd.grad(outputs=output, inputs=x,
                                    grad_outputs=x.new_ones(output.size()),
                                    create_graph=True, retain_graph=True)[0].mean()
    return gradients

net = nn.Linear(4, 1).cuda()
multigpu_net = nn.DataParallel(net, [0, 1])
x = torch.ones(2, 4, requires_grad=True).cuda()

print("Single GPU Functional")
net.zero_grad()
loss = gradient_penalty(net, x)
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])

print("\nMulti-GPU Functional")
multigpu_net.zero_grad()
loss = gradient_penalty(multigpu_net, x)
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])

print("\nMulti-GPU Inline")
multigpu_net.zero_grad()
output = multigpu_net(x)
# Compute grad inline instead of inside gradient_penalty()
loss = torch.autograd.grad(outputs=output, inputs=x,
                           grad_outputs=x.new_ones(output.size()),
                           create_graph=True, retain_graph=True)[0].mean()
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])

Output for a single GPU calling the gradient_penalty function:
Single GPU Functional
Loss: -0.1287534236907959
Grad: [tensor([[0.2500, 0.2500, 0.2500, 0.2500]], device='cuda:0')]
Output for multiple GPUs calling the gradient_penalty function:
Multi-GPU Functional
Loss: -0.1287534236907959
Grad: [tensor([[0., 0., 0., 0.]], device='cuda:0')]
Output for multiple GPUs calling autograd.grad inline (not inside a function):
Multi-GPU Inline
Loss: -0.1287534236907959
Grad: [tensor([[0.2500, 0.2500, 0.2500, 0.2500]], device='cuda:0'), tensor([0.], device='cuda:0')]
Expected behavior
Parameter gradients should be accumulated even when autograd.grad is called from inside another function. All gradient outputs from the script should be the same.
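For what it's worth, the "Multi-GPU Inline" result above suggests a possible interim workaround until this is fixed: run the DataParallel forward pass and the autograd.grad call at the top level rather than inside a helper function. The sketch below is based only on the reproduction script (it reuses multigpu_net and x from there) and has not been verified beyond that script:

# Sketch of a possible workaround, based on the "Multi-GPU Inline" case above.
# Assumes `multigpu_net` (an nn.DataParallel-wrapped module) and `x` (a CUDA
# tensor with requires_grad=True) are defined as in the reproduction script.
multigpu_net.zero_grad()
output = multigpu_net(x)                        # forward pass at the call site
grad_x = torch.autograd.grad(outputs=output, inputs=x,
                             grad_outputs=x.new_ones(output.size()),
                             create_graph=True, retain_graph=True)[0]
penalty = grad_x.mean()                         # same reduction as gradient_penalty()
penalty.backward()                              # parameter grads should be populated here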
Environment
- PyTorch Version (e.g., 1.0): 1.0.9
- OS (e.g., Linux): Ubuntu 16.04.3 LTS
- How you installed PyTorch (conda, pip, source): pip
- Python version: 3.6
- CUDA/cuDNN version: 9.0/6
- GPU models and configuration:
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080
GPU 3: GeForce GTX 1080