Conversation

@pietern pietern commented Apr 5, 2019

Stack:
    • #18953 Make DistributedDataParallel use new reducer

This removes Python side bucketing code from DistributedDataParallel
and replaces it with calls to the new C++ based bucketing and reducing
code. To confirm this is working well, we ran a test with both the
previous implementation and the new implementation, and confirmed they
are numerically equivalent.

Performance improves by a couple of percent or more, including for
single-machine, multi-GPU runs.

Closes #13273.

Differential Revision: D14580911
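
For reference, the user-facing DistributedDataParallel API is unchanged by this PR; below is a minimal usage sketch (the model, shapes, and rank-to-device mapping are illustrative, and env:// initialization via something like torch.distributed.launch is assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Illustrative setup; assumes RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT are set
# in the environment (e.g. by torch.distributed.launch).
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = nn.Linear(1024, 10).cuda()
# Gradient bucketing and reduction now happen in the C++ reducer behind this wrapper.
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(10):
    inputs = torch.randn(32, 1024, device="cuda")
    loss = model(inputs).sum()
    optimizer.zero_grad()
    loss.backward()  # gradients are bucketed and allreduced as they become ready
    optimizer.step()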

pietern added 3 commits April 3, 2019 23:31
@pietern pietern force-pushed the export-D14580911 branch from 364a99a to 5aecb87 on April 5, 2019 at 15:56
pietern commented Apr 5, 2019

@mcarilli This commit results in a speedup of a couple of percent on single-machine runs and more on multi-machine runs (measured running ResNet-50 over 100Gb Ethernet -- the gain is probably smaller for bigger models or on InfiniBand).

Also note that the bucket assignment can be changed on the fly, so this is probably useful for Apex.

mcarilli commented Apr 5, 2019

Sorry, I haven't had a chance to look at DDP in a while (I've been heads down on core Amp functionality, and probably will be for the next month at least) but this looks seriously awesome. If the new DDP accepts an arbitrary bucket structure, I can try making it work with the on-the-fly bucket construction as used in Apex...once I get a chance to look at it again. Again, I'm pretty slammed with other stuff at the moment :(

pietern added 3 commits April 8, 2019 15:10
pietern commented Apr 11, 2019

@mcarilli No worries and thanks for taking a peek. This does in fact take an arbitrary bucket structure. You can do anything from one massive bucket for the equivalent of delay_allreduce to a bucket per gradient (as long as the device and dtype match of course). This new class also does a sweep of the autograd graph to find unused parameters and preemptively marks them as ready. I think this should eliminate the mysterious hangs for which in the past delay_allreduce was the solution.
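
As a concrete (hedged) illustration of that flexibility, assuming the bucket_cap_mb argument exposed by the new DDP constructor and a model and local_rank set up as in a normal DDP script: a very large cap approximates one massive bucket, while a very small cap approaches a bucket per gradient (still split by device and dtype).

from torch.nn.parallel import DistributedDataParallel as DDP

# bucket_cap_mb (in megabytes) controls how gradients are grouped for allreduce.
# A very large cap is roughly "one massive bucket" (delay_allreduce-like):
ddp_model = DDP(model, device_ids=[local_rank], bucket_cap_mb=10000)
# ...while a very small cap approaches one bucket per gradient (per device/dtype):
# ddp_model = DDP(model, device_ids=[local_rank], bucket_cap_mb=1)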

caiqi commented Apr 11, 2019

Hi @pietern, I tried the commit locally on 4 V100 GPUs with CUDA 9.0 and the NCCL backend. I found that training hangs randomly when some parameters are unused; the more parameters are unused, the more likely it is to hang. When all parameters are used, it runs smoothly. Do you have any suggestions for debugging this?

pietern commented Apr 11, 2019

@caiqi Thanks for giving it a try. Can you share an example that reproduces the issue?

caiqi commented Apr 11, 2019

@pietern After more experimenting, I found that the hang only happens when I have more than one DistributedDataParallel instance. The following code reproduces the hang:

  • The code below hangs after 2 iterations on my server.
  • Changing random_weight_max=1 to random_weight_max=3 makes it run smoothly.
  • Removing loss_func and using a simple loss function such as torch.sum(...) does not hang for any random_weight_max.
  • One more note: random_weight_max controls how many branches are executed, and the dynamic execution is implemented with if torch.all(torch.eq(w, 0)):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
import torch
import torch.nn as nn

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
        nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine)
    ),
}


class ReLUConvBN(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Zero(nn.Module):

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
]
import torch.nn.functional as F


class MixedOp(nn.Module):
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride, False)
            if 'pool' in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            self._ops.append(op)

    def forward(self, x, weights):
        # weights: NUM_OPS * batch_size or NUM_OPS
        out_list = []
        for w, op in zip(weights, self._ops):
            if torch.all(torch.eq(w, 0)):
                continue
            o = op(x)
            dims = len(o.shape)
            # tile function based on repeat function
            w = w.view(-1, *[1 for _ in range(dims)])
            div = x.shape[0] // w.shape[0]
            w = w.repeat(1, div, *[1 for _ in range(dims - 1)])
            w = w.view(-1, *[1 for _ in range(dims - 1)])
            o_o = w * o
            out_list.append(o_o)
        stacked_tensor = torch.stack(out_list, dim=1)
        return sum(out_list), stacked_tensor


class Cell(nn.Module):
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier
        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)

    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        offset = 0
        hidden_tensors = []

        for i in range(self._steps):
            out_list = [self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states)]
            s = sum([m[0] for m in out_list])
            hidden_tensors.extend([m[1] for m in out_list])
            offset += len(states)
            states.append(s)
        hidden_tensors = torch.stack(hidden_tensors, dim=1)
        return torch.cat(states[-self._multiplier:], dim=1), hidden_tensors


class spatial_softmax(nn.Module):
    def __init__(self):
        super(spatial_softmax, self).__init__()

    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        input = input.view(input_shape[0], input_shape[1], input_shape[2] * input_shape[3])
        input = F.softmax(input, dim=2)
        input = input.view(input_shape)
        return input


class sigmoid_branch(nn.Module):
    def __init__(self, k, num_ops):
        super(sigmoid_branch, self).__init__()
        self.k = k
        self.num_ops = num_ops

    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        input_reshape = input.view(input_shape[0], self.k, self.num_ops)
        input_reshape = torch.sigmoid(input_reshape)
        input_reshape = input_reshape.permute(1, 2, 0)
        input_reshape = input_reshape.contiguous()
        return input_reshape


class softmax_branch(nn.Module):
    def __init__(self, k, num_ops):
        super(softmax_branch, self).__init__()
        self.k = k
        self.num_ops = num_ops

    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        input_reshape = input.view(input_shape[0], self.k, self.num_ops)
        input_reshape = F.softmax(input_reshape, dim=-1)
        input_reshape = input_reshape.permute(1, 2, 0)
        input_reshape = input_reshape.contiguous()
        return input_reshape


class Network(nn.Module):

    def __init__(self, C, layers, steps=2, multiplier=2, stem_multiplier=2, userelu=False, global_pool=True,
                 random_weight_max=1):
        super(Network, self).__init__()
        self._C = C
        self.random_seed = 0
        self._layers = layers
        self.random_weight_max = random_weight_max
        self.cnt = 0
        self.check_assert = True
        self._steps = steps
        self._multiplier = multiplier
        assert self._multiplier <= self._steps, "multiplier should be no larger than steps"
        self.global_pool = global_pool
        self.userelu = userelu
        if self.userelu:
            self.relu = nn.ReLU(inplace=False)

        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # todo change for few-shot learning, decrease by four times
            # if i in [layers // 3, 2 * layers // 3]:
            reduction_layers = [layers // 4, 2 * layers // 4, 3 * layers // 4]
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
        self.global_pooling = nn.AdaptiveAvgPool2d(1)

    def complex_inference(self, input, test_input=None):
        s0 = s1 = self.stem(input)
        test_s0 = test_s1 = self.stem(test_input)
        for i, cell in enumerate(self.cells):
            k = sum(1 for i in range(self._steps) for _ in range(2 + i))
            num_ops = len(PRIMITIVES)
            weights = torch.ones(k, num_ops, 1, requires_grad=False).cuda() / num_ops
            weights_test = torch.ones(k, num_ops, 1, requires_grad=False).cuda() / num_ops
            if self.random_weight_max > 0 and self.training:
                weights = self.random_masked_weight(weights, keep_number=self.random_weight_max)
                weights_test = self.random_masked_weight(weights_test, keep_number=self.random_weight_max)
            s0, (s1, hidden_tensors) = s1, cell(s0, s1, weights)
            test_s0, (test_s1, hidden_tensors) = test_s1, cell(test_s0, test_s1, weights_test)
        if self.global_pool:
            out = self.global_pooling(s1)
            out_test = self.global_pooling(test_s1)
        else:
            out = s1
            out_test = test_s1
        if self.userelu:
            out = self.relu(out)
            out_test = self.relu(out_test)
        return out.view(out.size(0), -1), out_test.view(out_test.size(0), -1)

    def forward(self, input, test_input=None):
        return self.complex_inference(input, test_input)

    def random_masked_weight(self, weights, keep_number):
        node_num, branch_num, batch_num = weights.shape
        weights = weights.permute(0, 2, 1)
        weights = weights.contiguous()
        weights = weights.view(node_num * batch_num, branch_num)
        weights_rand = torch.rand_like(weights)
        weights_sort, indices = weights_rand.sort(dim=1, descending=True)
        weights_select = weights_sort[:,
                         int(keep_number):int(keep_number + 1)]
        weights_signes = torch.gt(weights_rand, weights_select)
        weights_signes = weights_signes.type(weights.type())
        weights = weights * (weights_signes > 0).type(weights.type())
        weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-6)
        weights = weights.view(node_num, batch_num, branch_num)
        weights = weights.permute(0, 2, 1)
        weights = weights.contiguous()
        return weights


args = parser.parse_args()
from torch.distributed import init_process_group
from torch.cuda import set_device
from torch.nn.parallel.distributed import DistributedDataParallel


class Network2(nn.Module):
    def __init__(self):
        super(Network2, self).__init__()
        self.linear = nn.Linear(1024, 1)

    def forward(self, input, label):
        return torch.pow(self.linear(input) - label, 2).sum()


def main():
    device = torch.device("cuda:{}".format(args.local_rank))
    set_device(device)
    init_process_group(backend="nccl", init_method="env://")

    network = Network(C=64, layers=4)
    network.cuda()
    network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
                                      broadcast_buffers=True)
    loss_func = Network2()
    loss_func.cuda()
    loss_func = DistributedDataParallel(loss_func, device_ids=[args.local_rank], output_device=args.local_rank,
                                        broadcast_buffers=True)
    optim = torch.optim.SGD(lr=0.001, params=network.parameters())

    for k in range(10):
        print(k)
        data1 = torch.rand(12, 3, 80, 80)
        data1.cuda()
        data2 = torch.rand(12, 3, 80, 80)
        data2.cuda()
        label = torch.rand(1)
        label.cuda()
        out1, out2 = network(data1, data2)
        sm_results = out1 + out2
        loss = loss_func(sm_results, label)
        loss.backward()
        optim.step()


if __name__ == '__main__':
    main()

pietern commented Apr 11, 2019

@pytorchbot retest this please

pietern commented Apr 11, 2019

@caiqi Thanks for the repro. Two instances of DDP should not be a problem as they don't share global state, but both will do a mark and sweep on the autograd graph starting from their outputs. It is possible that the second one is interfering with the first one that way.

Can you try combining both models into a single one? Looking at your code, that should be possible. If you do and it doesn't repro, then at least we will have isolated the issue to the use of multiple DDP models in a single iteration.
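
For example, something along these lines (an untested sketch built on the definitions in your script; CombinedModel is just an illustrative name) would let a single DDP instance own all parameters:

class CombinedModel(nn.Module):
    """Wraps the feature network and the loss module so one DDP instance sees every parameter."""

    def __init__(self, network, loss_func):
        super(CombinedModel, self).__init__()
        self.network = network
        self.loss_func = loss_func

    def forward(self, data1, data2, label):
        out1, out2 = self.network(data1, data2)
        return self.loss_func(out1 + out2, label)


combined = CombinedModel(Network(C=64, layers=4), Network2()).cuda()
combined = DistributedDataParallel(combined, device_ids=[args.local_rank],
                                   output_device=args.local_rank, broadcast_buffers=True)
optim = torch.optim.SGD(lr=0.001, params=combined.parameters())
# In the training loop: loss = combined(data1, data2, label); loss.backward(); optim.step()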

caiqi commented Apr 11, 2019

Yes, I have tested it, and combining the two DDP modules into a single one doesn't hang.

pietern commented Apr 11, 2019

Both DDP models try to use the same distributed process group. All collective calls across processes are expected to be executed in the same order. With different processes taking a different path through the model, it is possible some reordering happens, where for example process A kicks off a reduction for model 1, and process B kicks off a reduction for model 2. These calls don't match, so they result in a hang. I confirmed that you can fix this problem by forcing the different DDP models to use a different process group instance. With this fix, there is no longer a risk of interference, as they both have their own process group.

This is the diff:

$ diff main.py main_new.py
--- main.py     2019-04-11 11:32:37.754243099 -0700
+++ main_new.py 2019-04-11 11:33:02.242294191 -0700
@@ -345,14 +345,17 @@
     set_device(device)
     init_process_group(backend="nccl", init_method="env://")
 
+    pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
+    pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
+
     network = Network(C=64, layers=4)
     network.cuda()
     network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
-                                      broadcast_buffers=True)
+                                      broadcast_buffers=True, process_group=pg1)
     loss_func = Network2()
     loss_func.cuda()
     loss_func = DistributedDataParallel(loss_func, device_ids=[args.local_rank], output_device=args.local_rank,
-                                        broadcast_buffers=True)
+                                        broadcast_buffers=True, process_group=pg2)
     optim = torch.optim.SGD(lr=0.001, params=network.parameters())
 
     for k in range(10):

caiqi commented Apr 11, 2019

@pietern It works, thanks for the solution!

pietern commented Apr 11, 2019

@caiqi Glad to hear. Thanks for confirming this PR works for you.

pietern added 2 commits April 11, 2019 14:16
@facebook-github-bot
This pull request has been merged in a0263ec.

caiqi commented Apr 17, 2019

@pietern, I found that using this commit with a module that returns a dict causes the following error:

RuntimeError: next_bucket_ == buckets_.size() ASSERT FAILED at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:419, please report a bug to PyTorch. Expected all buckets to be ready at the end of the backward pass.

This is the code to reproduce it:

import argparse
import math

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, userelu=True):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes,
                                                 kernel_size=3, stride=1, padding=1, bias=False))
        self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes))

        if userelu:
            self.layers.add_module('ReLU', nn.ReLU(inplace=True))

        self.layers.add_module(
            'MaxPool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

    def forward(self, x):
        out = self.layers(x)
        return out


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv_blocks = nn.Sequential(
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
        )

    def forward(self, x):
        blocks = [m for m in self.conv_blocks]
        index = [i for i in range(2)]
        blocks = [blocks[k] for k in index]
        for m in blocks:
            x = m(x)
        out = x
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        label = torch.ones_like(out)
        loss = torch.sum(out - label)
        return {"loss": loss}


args = parser.parse_args()
from torch.distributed import init_process_group
from torch.cuda import set_device
from torch.nn.parallel.distributed import DistributedDataParallel


def main():
    device = torch.device("cuda:{}".format(args.local_rank))
    set_device(device)
    init_process_group(backend="nccl", init_method="env://")
    network = ConvNet()
    network.cuda()
    network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
                                      broadcast_buffers=True)
    optim = torch.optim.SGD(lr=0.001, params=network.parameters())

    for k in range(10):
        print(k)
        data1 = torch.rand(12, 3, 80, 80)
        data1.cuda()
        out1 = network(data1)
        loss = out1["loss"]
        optim.zero_grad()
        loss.backward()
        optim.step()


if __name__ == '__main__':
    main()

On my server, it gives the error:

Expected all buckets to be ready at the end of the backward pass.

Commenting out these two lines, or returning the tensor rather than the dict, runs without any error.

        index = [i for i in range(2)]
        blocks = [blocks[k] for k in index]

It seems strange that the return type of the module influences the behavior of the reducer. Do you have any suggestions on this error?

pietern commented Apr 17, 2019

@caiqi Thanks for reporting the issue. This is happening because we don't deal with a dict return value at the moment and only scan a tuple or list for returned tensors. If the reducer doesn't know about the output tensor then it can't mark unused parameters as skipped. This means this problem only shows up when you use a model with unused parameters and return a dict. I added #19354 to track this.
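
In the meantime, a possible workaround (a sketch based on the repro above, with an illustrative subclass name) is to return the loss as a tensor, or inside a tuple/list, so the reducer can see it in the output:

class ConvNetTupleOut(ConvNet):
    # Same model as in the repro, but returning a tuple instead of a dict so that
    # DDP can find the output tensor and mark unused parameters as ready.
    def forward(self, x):
        out = super(ConvNetTupleOut, self).forward(x)
        return (out["loss"],)


# In main(): wrap ConvNetTupleOut() instead of ConvNet(), then
#   (loss,) = network(data1)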

pietern commented Apr 17, 2019

Commenting out those lines means all parameters are used, so none of them have to be pre-emptively marked as ready. Stay tuned for a fix, I'll work on this this morning.

facebook-github-bot pushed a commit that referenced this pull request Apr 26, 2019
Summary:
Pull Request resolved: #19799

A module that returns multiple outputs, where the caller may end up
making multiple calls to torch.autograd.backward, did not work with
DistributedDataParallel. It expected the first call to
torch.autograd.backward to provide gradients for ALL parameters that
expect gradients and were used in computing the module output. If you
have outputs with disjoint autograd graphs it is fine to call
torch.autograd.backward on both and fill in the module's parameter
gradients in separate chunks.

With this change we delay queuing the finalizer callback until we have
marked all buckets as ready, instead of queueing it the first time we
receive an autograd hook. This returns the current implementation to
be functionally equivalent to the DistributedDataParallel
implementation before #18953 was merged.

Reviewed By: mrshenli

Differential Revision: D15097045

fbshipit-source-id: 2df023319713bc31e29a8b45108c78e6593fccd4
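
For context, here is a minimal sketch of the pattern the commit above fixes (TwoHeads is an illustrative name; a process group and local_rank are assumed to be set up already): two outputs with disjoint autograd graphs and a separate backward call per output.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class TwoHeads(nn.Module):
    # Two independent heads, so the two outputs have disjoint autograd graphs.
    def __init__(self):
        super(TwoHeads, self).__init__()
        self.head_a = nn.Linear(16, 1)
        self.head_b = nn.Linear(16, 1)

    def forward(self, x):
        return self.head_a(x), self.head_b(x)


model = DistributedDataParallel(TwoHeads().cuda(), device_ids=[local_rank])
out_a, out_b = model(torch.randn(8, 16, device="cuda"))

# With the fix, the reducer waits until all buckets are ready before finalizing,
# so filling in parameter gradients across two separate backward calls works.
out_a.sum().backward()
out_b.sum().backward()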
@ezyang ezyang deleted the export-D14580911 branch May 30, 2019 15:52