
Missing gradient when autograd called inside a function on Multi-GPU (eg gradient penalty) #16532

@tebesu

Description


🐛 Bug

The gradient is missing when torch.autograd.grad is called from inside a function on multiple GPUs (e.g. when computing the WGAN gradient penalty). Calling torch.autograd.grad inline (not wrapped in a function) on multiple GPUs behaves as expected.
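
For context, the gradient penalty in WGAN-GP is typically computed along these lines (a minimal sketch of the standard formulation, not the author's code; netD, real and fake are illustrative names, and the reproduction below strips this down to the bare torch.autograd.grad call):

import torch

def wgan_gradient_penalty(netD, real, fake):
    # Interpolate between real and generated samples (standard WGAN-GP construction).
    # Assumes 2-D inputs (batch, features), as in the reproduction below.
    alpha = torch.rand(real.size(0), 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    out = netD(interp)
    # Gradient of the critic output w.r.t. the interpolated input.
    grads = torch.autograd.grad(outputs=out, inputs=interp,
                                grad_outputs=torch.ones_like(out),
                                create_graph=True, retain_graph=True)[0]
    # Penalize deviation of the per-sample gradient norm from 1.
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()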

To Reproduce

Code below:

import torch
import torch.nn as nn

torch.cuda.manual_seed_all(0)
torch.manual_seed(0)


def gradient_penalty(netD, x):
    """Functional Gradient Calculation"""
    output = netD(x)
    gradients = torch.autograd.grad(outputs=output, inputs=x,
                                    grad_outputs=x.new_ones(output.size()),
                                    create_graph=True, retain_graph=True)[0].mean()
    return gradients


net = nn.Linear(4, 1).cuda()
multigpu_net = nn.DataParallel(net, [0, 1])

x = torch.ones(2, 4, requires_grad=True).cuda()

print("Single GPU Functional")
net.zero_grad()
loss = gradient_penalty(net, x)
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])


print("\nMulti-GPU Functional")
multigpu_net.zero_grad()
loss = gradient_penalty(multigpu_net, x)
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])

print("\nMulti-GPU Inline")
multigpu_net.zero_grad()
output = multigpu_net(x)

# Compute grad inline
loss = torch.autograd.grad(outputs=output, inputs=x,
                            grad_outputs=x.new_ones(output.size()),
                            create_graph=True, retain_graph=True)[0].mean()
loss.backward()
print("Loss:", loss.item())
print("Grad:", [p.grad for p in net.parameters() if p.grad is not None])

Output for a single GPU calling the gradient_penalty function

Single GPU Functional
Loss: -0.1287534236907959
Grad: [tensor([[0.2500, 0.2500, 0.2500, 0.2500]], device='cuda:0')]

Output for multi-GPU calling the gradient_penalty function

Multi-GPU Functional
Loss: -0.1287534236907959
Grad: [tensor([[0., 0., 0., 0.]], device='cuda:0')]

Output for multi-GPU calling autograd.grad inline (not inside a function)

Multi-GPU Inline
Loss: -0.1287534236907959
Grad: [tensor([[0.2500, 0.2500, 0.2500, 0.2500]], device='cuda:0'), tensor([0.], device='cuda:0')]

Expected behavior

The gradient should be accumulated in the model's parameters when autograd.grad is called from inside another function. All three gradient outputs from the script should be the same.
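
For now, the only pattern in the reproduction above that produces the expected parameter gradients on multiple GPUs is keeping the torch.autograd.grad call inline, in the same scope as the backward() call, instead of wrapping it in a helper function. A minimal workaround sketch reusing the names from the script above:

# Workaround sketch based on the inline case above: avoid wrapping the
# autograd.grad call in a function when the model is wrapped in DataParallel.
output = multigpu_net(x)
penalty = torch.autograd.grad(outputs=output, inputs=x,
                              grad_outputs=x.new_ones(output.size()),
                              create_graph=True, retain_graph=True)[0].mean()
penalty.backward()  # parameter gradients now reach net.parameters(), as in the "Multi-GPU Inline" output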

Environment

  • PyTorch Version (e.g., 1.0): 1.0.9
  • OS (e.g., Linux): Ubuntu 16.04.3 LTS
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.6
  • CUDA/cuDNN version: 9.0/6
  • GPU models and configuration:
    GPU 0: GeForce GTX 1080
    GPU 1: GeForce GTX 1080
    GPU 2: GeForce GTX 1080
    GPU 3: GeForce GTX 1080


Labels

high priority
module: autograd (related to torch.autograd, and the autograd engine in general)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
