v1.0.0 nn.utils.weight_norm seems to nullify gradients of unrelated parameters if wrapped in DataParallel #14960

@L0SG

🐛 Bug

nn.utils.weight_norm seems to nullify gradients of unrelated parameters if wrapped in DataParallel in v1.0.0.

In the code below, the DebugNet model performs two disjoint operations:

  1. upsampling the input c with nn.ConvTranspose2d wrapped in nn.utils.weight_norm.
  2. applying an independent module named ActNorm to the input x.

In the single-GPU setup this works as intended: the ActNorm parameters receive gradients during training.
In the multi-GPU setup, if nn.utils.weight_norm is used in operation 1, the gradients of the (completely unrelated) ActNorm parameters become None.

To Reproduce

Steps to reproduce the behavior:

  1. run the sample code with PyTorch v1.0.0 on 2 or more GPUs: CUDA_VISIBLE_DEVICES=0,1 python code.py
  2. with use_weightnorm = True, the gradient of the unrelated ActNorm parameter is None
  3. with use_weightnorm = False, the gradient is present as expected
  4. this only happens in the multi-GPU setup: when running CUDA_VISIBLE_DEVICES=0 python code.py, both flags produce the gradient without any other code changes.

import torch
from torch import nn
from math import log, pi

logabs = lambda x: torch.log(torch.abs(x))

use_weightnorm = True


class ActNorm(nn.Module):
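    # activation normalization: learnable per-channel loc/scale with data-dependent initialization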
    def __init__(self, in_channel, pretrained=False):
        super().__init__()

        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1))
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1))

        self.initialized = pretrained

    def initialize(self, x):
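        # data-dependent initialization: set loc/scale from the statistics of the first batch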
        with torch.no_grad():
            flatten = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
            mean = (
                flatten.mean(1)
                    .unsqueeze(1)
                    .unsqueeze(2)
                    .permute(1, 0, 2)
            )
            std = (
                flatten.std(1)
                    .unsqueeze(1)
                    .unsqueeze(2)
                    .permute(1, 0, 2)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, x):
        _, _, T = x.size()

        if not self.initialized:
            self.initialize(x)
            self.initialized = True

        log_abs = logabs(self.scale)
        logdet = torch.mean(log_abs)

        return self.scale * (x + self.loc), logdet


class DebugNet(nn.Module):
    def __init__(self, in_channel, pretrained=False):
        super().__init__()

        self.actnorm = ActNorm(in_channel, pretrained=pretrained)

        self.upsample_conv = nn.ModuleList()
        # 16x upsampling of c
        for s in [16]:
            convt = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            # weight_norm seems to nullify the (totally unrelated) gradient of ActNorm parameters if wrapped in DataParallel
            if use_weightnorm:
                convt = nn.utils.weight_norm(convt)
            self.upsample_conv.append(convt)

    def forward(self, x, c):
        # upsample c by 16x
        c = self.upsample(c)

        # actnorm is not dependent on c at all
        out, logdet = self.actnorm(x)

        # maximum likelihood loss
        log_p = 0.5 * (- log(2.0 * pi) - out.pow(2)).mean()
        return log_p, logdet

    def upsample(self, c):
        c = c.unsqueeze(1)
        for f in self.upsample_conv:
            c = f(c)
        c = c.squeeze(1)
        return c


use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")

batch_size = 4
# random data (1 for waveform & 40-band spectrogram for c)
random_data = torch.randn(batch_size, 1, 512).clone().detach().to(device)
random_data_c = torch.randn(batch_size, 40, 32).clone().detach().to(device)

net = DebugNet(in_channel=1, pretrained=False).to(device)

# cast the net into DataParallel
net = nn.DataParallel(net)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# train
for i in range(1):
    optimizer.zero_grad()
    log_p, logdet = net(random_data, random_data_c)
    log_p, logdet = torch.mean(log_p), torch.mean(logdet)
    loss = -(log_p + logdet)
    loss.backward()
    optimizer.step()
    # is there grad at actnorm?
    hmm = net.module.actnorm.scale.grad
    if hmm is None:
        print("no grad at actnorm parameters")
    else:
        print("found grad at actnorm parameters")
        print(hmm)
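
For a broader check, a small helper (hypothetical; not part of the original snippet) can list every parameter whose .grad is still None after backward(), which makes the scope of the problem easier to see:

def report_missing_grads(model):
    # print the names of all parameters that did not receive a gradient
    missing = [name for name, p in model.named_parameters() if p.grad is None]
    if missing:
        print("parameters with no grad:", missing)
    else:
        print("all parameters received gradients")

# usage, after loss.backward():
# report_missing_grads(net.module)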

Expected behavior

In PyTorch v0.4.1, the ActNorm parameters received the gradient with both use_weightnorm = True and use_weightnorm = False.
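
For reference, the following sketch (assuming the DebugNet class and the random_data / random_data_c tensors from the snippet above) runs the same training step without the DataParallel wrapper; in this configuration the ActNorm gradient is present regardless of the use_weightnorm flag:

# single-GPU control run: same model and data, but no DataParallel wrapper
single_net = DebugNet(in_channel=1, pretrained=False).to(device)
single_opt = torch.optim.Adam(single_net.parameters(), lr=1e-3)

single_opt.zero_grad()
log_p, logdet = single_net(random_data, random_data_c)
loss = -(torch.mean(log_p) + torch.mean(logdet))
loss.backward()
single_opt.step()

# expected: a tensor, not None, for both use_weightnorm = True and False
print(single_net.actnorm.scale.grad)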

Environment

Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN V
GPU 1: TITAN V
GPU 2: TITAN V
GPU 3: TITAN V
GPU 4: TITAN V
GPU 5: TITAN V
GPU 6: TITAN V
GPU 7: TITAN V

Nvidia driver version: 410.78
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5

Versions of relevant libraries:
[pip] msgpack-numpy (0.4.1)
[pip] numpy (1.14.3)
[pip] torch (1.0.0)
[pip] torchfile (0.1.0)
[pip] torchtext (0.2.3)
[pip] torchvision (0.2.1)
[conda] blas 1.0 mkl
[conda] cuda92 1.0 0 pytorch
[conda] mkl 2018.0.2 1
[conda] mkl_fft 1.0.1 py36h3010b51_0
[conda] mkl_random 1.0.1 py36h629b387_0
[conda] pytorch 1.0.0 py3.6_cuda9.0.176_cudnn7.4.1_1 pytorch
[conda] torchfile 0.1.0
[conda] torchtext 0.2.3
[conda] torchvision 0.2.1 py_2 pytorch

  • PyTorch Version (e.g., 1.0): 1.0
  • OS (e.g., Linux): Ubuntu 16.04
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): NA
  • Python version: 3.6.7
  • CUDA/cuDNN version: conda packages; tried both the CUDA 9.0 and CUDA 10.0 builds of v1.0.0
  • GPU models and configuration: 8x TITAN V
  • Any other relevant information: NA

Additional context

The snippet is a minimal excerpt from the FloWaveNet repo.

Metadata
Labels

high priority, module: data parallel, module: nn, triaged
