Make DistributedDataParallel use new reducer #18953
Conversation
Differential Revision: D14768432 Differential Version: 78072755
Differential Revision: D14768432 Differential Version: 78072983
Differential Revision: D14768432 Differential Version: 78248526
Differential Revision: D14580911 Differential Version: 78287520
364a99a to 5aecb87
Differential Revision: D14580911 Differential Version: 78287661
Differential Revision: D14580911 Differential Version: 78332640
@mcarilli This commit results in a speedup of a couple percent on single-machine runs and more on multi-machine runs (when running ResNet50 on 100Gb Ethernet -- it's probably less for bigger models or on InfiniBand). Also note that the bucket assignment can be changed on the fly, so this is probably useful for Apex.
Sorry, I haven't had a chance to look at DDP in a while (I've been heads down on core Amp functionality, and probably will be for the next month at least) but this looks seriously awesome. If the new DDP accepts an arbitrary bucket structure, I can try making it work with the on-the-fly bucket construction as used in Apex... once I get a chance to look at it again. Again, I'm pretty slammed with other stuff at the moment :(
Differential Revision: D14580911 Differential Version: 78602211
Differential Revision: D14580911 Differential Version: 78604488
Differential Revision: D14580911 Differential Version: 78976056
@mcarilli No worries and thanks for taking a peek. This does in fact take an arbitrary bucket structure. You can do anything from one massive bucket for the equivalent of
Hi @pietern, I tried the commit locally on 4 V100 GPUs with CUDA 9.0 and the NCCL backend. I found that training with some parameters unused hangs randomly. It seems that the more parameters are unused, the more likely it is to hang. When all parameters are used, it runs smoothly. Do you have any suggestions on debugging such an error?
@caiqi Thanks for giving it a try. Can you share an example that reproduces the issue?
@pietern After more testing, I found that the hang only happens when I have more than one DistributedDataParallel instance. The following code reproduces the hang:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
import torch
import torch.nn as nn

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
        nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine)
    ),
}

class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )
    def forward(self, x):
        return self.op(x)

class DilConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
    def forward(self, x):
        return self.op(x)

class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )
    def forward(self, x):
        return self.op(x)

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class Zero(nn.Module):
    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride
    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)

class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out

from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
]

import torch.nn.functional as F

class MixedOp(nn.Module):
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride, False)
            if 'pool' in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            self._ops.append(op)
    def forward(self, x, weights):
        # weights: NUM_OPS * batch_size or NUM_OPS
        out_list = []
        for w, op in zip(weights, self._ops):
            if torch.all(torch.eq(w, 0)):
                continue
            o = op(x)
            dims = len(o.shape)
            # tile function based on repeat function
            w = w.view(-1, *[1 for _ in range(dims)])
            div = x.shape[0] // w.shape[0]
            w = w.repeat(1, div, *[1 for _ in range(dims - 1)])
            w = w.view(-1, *[1 for _ in range(dims - 1)])
            o_o = w * o
            out_list.append(o_o)
        stacked_tensor = torch.stack(out_list, dim=1)
        return sum(out_list), stacked_tensor

class Cell(nn.Module):
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier
        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)
    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0
        hidden_tensors = []
        for i in range(self._steps):
            out_list = [self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states)]
            s = sum([m[0] for m in out_list])
            hidden_tensors.extend([m[1] for m in out_list])
            offset += len(states)
            states.append(s)
        hidden_tensors = torch.stack(hidden_tensors, dim=1)
        return torch.cat(states[-self._multiplier:], dim=1), hidden_tensors

class spatial_softmax(nn.Module):
    def __init__(self):
        super(spatial_softmax, self).__init__()
    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        input = input.view(input_shape[0], input_shape[1], input_shape[2] * input_shape[3])
        input = F.softmax(input, dim=2)
        input = input.view(input_shape)
        return input

class sigmoid_branch(nn.Module):
    def __init__(self, k, num_ops):
        super(sigmoid_branch, self).__init__()
        self.k = k
        self.num_ops = num_ops
    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        intput_reshape = input.view(input_shape[0], self.k, self.num_ops)
        intput_reshape = torch.sigmoid(intput_reshape)
        intput_reshape = intput_reshape.permute(1, 2, 0)
        intput_reshape = intput_reshape.contiguous()
        return intput_reshape

class softmax_branch(nn.Module):
    def __init__(self, k, num_ops):
        super(softmax_branch, self).__init__()
        self.k = k
        self.num_ops = num_ops
    def forward(self, input):
        input_shape = input.shape
        assert len(input_shape) == 4
        intput_reshape = input.view(input_shape[0], self.k, self.num_ops)
        intput_reshape = F.softmax(intput_reshape, dim=-1)
        intput_reshape = intput_reshape.permute(1, 2, 0)
        intput_reshape = intput_reshape.contiguous()
        return intput_reshape

class Network(nn.Module):
    def __init__(self, C, layers, steps=2, multiplier=2, stem_multiplier=2, userelu=False, global_pool=True,
                 random_weight_max=1):
        super(Network, self).__init__()
        self._C = C
        self.random_seed = 0
        self._layers = layers
        self.random_weight_max = random_weight_max
        self.cnt = 0
        self.check_assert = True
        self._steps = steps
        self._multiplier = multiplier
        assert self._multiplier <= self._steps, "multiplier should be no larger than steps"
        self.global_pool = global_pool
        self.userelu = userelu
        if self.userelu:
            self.relu = nn.ReLU(inplace=False)
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # todo change for few-shot learning, decrease by four times
            # if i in [layers // 3, 2 * layers // 3]:
            reduction_layers = [layers // 4, 2 * layers // 4, 3 * layers // 4]
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
    def complex_inference(self, input, test_input=None):
        s0 = s1 = self.stem(input)
        test_s0 = test_s1 = self.stem(test_input)
        for i, cell in enumerate(self.cells):
            k = sum(1 for i in range(self._steps) for _ in range(2 + i))
            num_ops = len(PRIMITIVES)
            weights = torch.ones(k, num_ops, 1, requires_grad=False).cuda() / num_ops
            weights_test = torch.ones(k, num_ops, 1, requires_grad=False).cuda() / num_ops
            if self.random_weight_max > 0 and self.training:
                weights = self.random_masked_weight(weights, keep_number=self.random_weight_max)
                weights_test = self.random_masked_weight(weights_test, keep_number=self.random_weight_max)
            s0, (s1, hiden_tensors) = s1, cell(s0, s1, weights)
            test_s0, (test_s1, hiden_tensors) = test_s1, cell(test_s0, test_s1, weights_test)
        if self.global_pool:
            out = self.global_pooling(s1)
            out_test = self.global_pooling(test_s1)
        else:
            out = s1
            out_test = test_s1
        if self.userelu:
            out = self.relu(out)
            out_test = self.relu(out_test)
        return out.view(out.size(0), -1), out_test.view(out_test.size(0), -1)
    def forward(self, input, test_input=None):
        return self.complex_inference(input, test_input)
    def random_masked_weight(self, weights, keep_number):
        node_num, branch_num, batch_num = weights.shape
        weights = weights.permute(0, 2, 1)
        weights = weights.contiguous()
        weights = weights.view(node_num * batch_num, branch_num)
        weights_rand = torch.rand_like(weights)
        weights_sort, indices = weights_rand.sort(dim=1, descending=True)
        weights_select = weights_sort[:, int(keep_number):int(keep_number + 1)]
        weights_signes = torch.gt(weights_rand, weights_select)
        weights_signes = weights_signes.type(weights.type())
        weights = weights * (weights_signes > 0).type(weights.type())
        weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-6)
        weights = weights.view(node_num, batch_num, branch_num)
        weights = weights.permute(0, 2, 1)
        weights = weights.contiguous()
        return weights

args = parser.parse_args()

from torch.distributed import init_process_group
from torch.cuda import set_device
from torch.nn.parallel.distributed import DistributedDataParallel

class Network2(nn.Module):
    def __init__(self):
        super(Network2, self).__init__()
        self.linear = nn.Linear(1024, 1)
    def forward(self, input, label):
        return torch.pow(self.linear(input) - label, 2).sum()

def main():
    device = torch.device("cuda:{}".format(args.local_rank))
    set_device(device)
    init_process_group(backend="nccl", init_method="env://")
    network = Network(C=64, layers=4)
    network.cuda()
    network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
                                      broadcast_buffers=True)
    loss_func = Network2()
    loss_func.cuda()
    loss_func = DistributedDataParallel(loss_func, device_ids=[args.local_rank], output_device=args.local_rank,
                                        broadcast_buffers=True)
    optim = torch.optim.SGD(lr=0.001, params=network.parameters())
    for k in range(10):
        print(k)
        data1 = torch.rand(12, 3, 80, 80)
        data1.cuda()
        data2 = torch.rand(12, 3, 80, 80)
        data2.cuda()
        label = torch.rand(1)
        label.cuda()
        out1, out2 = network(data1, data2)
        sm_results = out1 + out2
        loss = loss_func(sm_results, label)
        loss.backward()
        optim.step()

if __name__ == '__main__':
    main()
@pytorchbot retest this please
@caiqi Thanks for the repro. Two instances of DDP should not be a problem as they don't share global state, but both will do a mark and sweep on the autograd graph starting from its outputs. It is possible that the second one is interfering with the first one that way. Can you try combining both models into a single one? Looking at your code, that should be possible. If you do, and it doesn't repro, then at least we will have isolated the issue to the use of multiple DDP models in a single iteration.
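For reference, a minimal sketch of what such a combined module could look like (hypothetical CombinedModel, not from this thread), reusing the Network and Network2 classes from the repro above so that a single DistributedDataParallel instance owns all parameters and all reductions go through one reducer:

import torch.nn as nn

class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        self.network = Network(C=64, layers=4)
        self.loss_func = Network2()
    def forward(self, input, test_input, label):
        # One forward pass produces the loss directly, so a single reducer
        # sees every parameter that participated in the iteration.
        out1, out2 = self.network(input, test_input)
        return self.loss_func(out1 + out2, label)

# Wrapped once, e.g.:
# combined = DistributedDataParallel(CombinedModel().cuda(), device_ids=[args.local_rank],
#                                    output_device=args.local_rank, broadcast_buffers=True)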
Yes, I have tested it, and combining the two DDP modules into a single one doesn't hang.
Both DDP models try to use the same distributed process group. All collective calls across processes are expected to be executed in the same order. With different processes taking a different path through the model, it is possible some reordering happens, where for example process A kicks off a reduction for model 1, and process B kicks off a reduction for model 2. These calls don't match, so they result in a hang. I confirmed that you can fix this problem by forcing the different DDP models to use a different process group instance. With this fix, there is no longer a risk of interference, as they both have their own process group. This is the diff:

$ diff main.py main_new.py
--- main.py     2019-04-11 11:32:37.754243099 -0700
+++ main_new.py 2019-04-11 11:33:02.242294191 -0700
@@ -345,14 +345,17 @@
     set_device(device)
     init_process_group(backend="nccl", init_method="env://")
+    pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
+    pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
+
     network = Network(C=64, layers=4)
     network.cuda()
     network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
-                                      broadcast_buffers=True)
+                                      broadcast_buffers=True, process_group=pg1)
     loss_func = Network2()
     loss_func.cuda()
     loss_func = DistributedDataParallel(loss_func, device_ids=[args.local_rank], output_device=args.local_rank,
-                                        broadcast_buffers=True)
+                                        broadcast_buffers=True, process_group=pg2)
     optim = torch.optim.SGD(lr=0.001, params=network.parameters())
     for k in range(10):
@pietern It works, thanks for the solution!
@caiqi Glad to hear. Thanks for confirming this PR works for you.
Differential Revision: D14580911 Differential Version: 79455513
This pull request has been merged in a0263ec.
@pietern I found that using this commit with a module that returns a dict causes the following error:

RuntimeError: next_bucket_ == buckets_.size() ASSERT FAILED at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:419, please report a bug to PyTorch. Expected all buckets to be ready at the end of the backward pass.

This is the reproducing code:

import argparse
import math
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, userelu=True):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes,
                                                 kernel_size=3, stride=1, padding=1, bias=False))
        self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes))
        if userelu:
            self.layers.add_module('ReLU', nn.ReLU(inplace=True))
        self.layers.add_module(
            'MaxPool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
    def forward(self, x):
        out = self.layers(x)
        return out

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv_blocks = nn.Sequential(
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
            ConvBlock(3, 3),
            nn.ReLU(),
        )
    def forward(self, x):
        blocks = [m for m in self.conv_blocks]
        index = [i for i in range(2)]
        blocks = [blocks[k] for k in index]
        for m in blocks:
            x = m(x)
        out = x
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        label = torch.ones_like(out)
        loss = torch.sum(out - label)
        return {"loss": loss}

args = parser.parse_args()

from torch.distributed import init_process_group
from torch.cuda import set_device
from torch.nn.parallel.distributed import DistributedDataParallel

def main():
    device = torch.device("cuda:{}".format(args.local_rank))
    set_device(device)
    init_process_group(backend="nccl", init_method="env://")
    network = ConvNet()
    network.cuda()
    network = DistributedDataParallel(network, device_ids=[args.local_rank], output_device=args.local_rank,
                                      broadcast_buffers=True)
    optim = torch.optim.SGD(lr=0.001, params=network.parameters())
    for k in range(10):
        print(k)
        data1 = torch.rand(12, 3, 80, 80)
        data1.cuda()
        out1 = network(data1)
        loss = out1["loss"]
        optim.zero_grad()
        loss.backward()
        optim.step()

if __name__ == '__main__':
    main()

On my server, this gives the error above. Commenting out these two lines, or returning the tensor rather than the dict, runs without any error:

    index = [i for i in range(2)]
    blocks = [blocks[k] for k in index]

It seems strange that the return type of the module influences the behavior of the reducer. Do you have any suggestions on this error?
@caiqi Thanks for reporting the issue. This is happening because we don't deal with a dict return value at the moment and only scan a tuple or list for returned tensors. If the reducer doesn't know about the output tensor then it can't mark unused parameters as skipped. This means this problem only shows up when you use a model with unused parameters and return a dict. I added #19354 to track this.
Commenting out those lines means all parameters are used, so none of them have to be pre-emptively marked as ready. Stay tuned for a fix, I'll work on this this morning.
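As an interim illustration of the workaround mentioned above (hypothetical DictFreeNet, not from this thread): until dict outputs are traversed, returning the tensor itself, or a tuple/list of tensors, lets the reducer find the outputs and pre-emptively mark unused parameters as ready.

import torch
import torch.nn as nn

class DictFreeNet(nn.Module):
    def __init__(self):
        super(DictFreeNet, self).__init__()
        self.used = nn.Linear(8, 8)
        self.unused = nn.Linear(8, 8)  # parameters intentionally never used in forward
    def forward(self, x):
        loss = self.used(x).sum()
        # Return a tensor (or a tuple/list of tensors) instead of {"loss": loss},
        # so the reducer can scan the output for tensors and mark the parameters
        # of self.unused as skipped.
        return loss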
Summary: Pull Request resolved: #19799

A module that returns multiple outputs, and where the caller may end up doing multiple calls to torch.autograd.backward, did not work with DistributedDataParallel. It expected the first call to torch.autograd.backward to provide gradients for ALL parameters that expect gradients and were used in computing the module output. If you have outputs with disjoint autograd graphs, it is fine to call torch.autograd.backward on both and fill in the module's parameter gradients in separate chunks.

With this change we delay queueing the finalizer callback until we have marked all buckets as ready, instead of queueing it the first time we receive an autograd hook. This returns the current implementation to be functionally equivalent to the DistributedDataParallel implementation before #18953 was merged.

Reviewed By: mrshenli

Differential Revision: D15097045

fbshipit-source-id: 2df023319713bc31e29a8b45108c78e6593fccd4
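A minimal sketch of the pattern this fix describes (hypothetical TwoHeads module, not part of the patch): a module whose outputs have disjoint autograd graphs, driven by a separate backward call per output.

import torch
import torch.nn as nn

class TwoHeads(nn.Module):
    def __init__(self):
        super(TwoHeads, self).__init__()
        self.head_a = nn.Linear(10, 1)
        self.head_b = nn.Linear(10, 1)
    def forward(self, x):
        # The two outputs touch disjoint sets of parameters, so their autograd
        # graphs can be backpropagated independently.
        return self.head_a(x), self.head_b(x)

# After wrapping in DistributedDataParallel, the parameter gradients may arrive in chunks:
#   out_a, out_b = ddp_model(torch.randn(4, 10).cuda())
#   out_a.sum().backward()  # marks the buckets holding head_a's gradients as ready
#   out_b.sum().backward()  # marks the rest; only then is the finalizer queued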
Summary: Pull Request resolved: pytorch#18953

This removes Python side bucketing code from DistributedDataParallel and replaces it with calls to the new C++ based bucketing and reducing code. To confirm this is working well, we ran a test with both the previous implementation and the new implementation, and confirmed they are numerically equivalent. Performance is improved by a couple percent or more, including the single machine multiple GPU runs.

Closes pytorch#13273.

Reviewed By: mrshenli

Differential Revision: D14580911

fbshipit-source-id: 44e76f8b0b7e58dd6c91644e3df4660ca2ee4ae2
Stack:
#18953 Make DistributedDataParallel use new reducer
This removes Python side bucketing code from DistributedDataParallel
and replaces it with calls to the new C++ based bucketing and reducing
code. To confirm this is working well, we ran a test with both the
previous implementation and the new implementation, and confirmed they
are numerically equivalent.
Performance is improved by a couple percent or more, including the
single machine multiple GPU runs.
Closes #13273.
Differential Revision: D14580911