Checkpointing is slow on nn.DataParallel models #7801
Description
@wandering007 pointed out this issue in https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py
I have a model that uses checkpointing on several layers. On a single GPU, the model runs fairly fast (e.g. only a 15-20% overhead). On multiple GPUs, wrapped in nn.DataParallel, @wandering007 reports that the model runs up to 100x slower.
Here are the important snippets of the model:
```python
def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output
    return bn_function


class _DenseLayer(nn.Module):
    # ...
    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        # ...
```

There are several `_DenseLayer`s throughout the model.
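For context, the pattern above can be reduced to a minimal self-contained sketch (the layer shapes and module instances here are hypothetical stand-ins, not taken from the DenseNet code). It also checks that checkpointing reproduces the same input gradients as a plain forward pass:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        # Concatenate all previous feature maps along the channel dim,
        # then apply BN -> ReLU -> 1x1 conv (the DenseNet bottleneck).
        concated_features = torch.cat(inputs, 1)
        return conv(relu(norm(concated_features)))
    return bn_function

# Hypothetical toy bottleneck: two input feature maps of 4 channels each.
norm = nn.BatchNorm2d(8)
relu = nn.ReLU(inplace=False)
conv = nn.Conv2d(8, 4, kernel_size=1, bias=False)
bn_function = _bn_function_factory(norm, relu, conv)

f1 = torch.randn(2, 4, 8, 8, requires_grad=True)
f2 = torch.randn(2, 4, 8, 8, requires_grad=True)

# Checkpointed forward: activations inside bn_function are discarded
# and recomputed during backward, trading compute for memory.
out_cp = cp.checkpoint(bn_function, f1, f2)
out_cp.sum().backward()
grad_cp = f1.grad.clone()

# Plain (non-checkpointed) forward for comparison.
f1.grad = None
f2.grad = None
out_plain = bn_function(f1, f2)
out_plain.sum().backward()

print(torch.allclose(grad_cp, f1.grad, atol=1e-6))
```

The recomputation in backward is exactly where any per-checkpoint synchronization cost would be paid on every `_DenseLayer`.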
@wandering007 suspects that the issue has to do with GPU synchronization: the model replicas must synchronize at every checkpoint during the backward pass.
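One way to probe this hypothesis is to time the forward and backward passes separately. This is a rough sketch with a hypothetical toy model, not the DenseNet itself; `torch.cuda.synchronize()` is only called when CUDA is available, so it also runs on CPU:

```python
import time
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

def timed_step(model, x):
    """Run one forward+backward and return (fwd_seconds, bwd_seconds)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make wall-clock timings meaningful on GPU
    t0 = time.perf_counter()
    out = model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out.sum().backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1

class CheckpointedMLP(nn.Module):
    # Hypothetical stand-in for the DenseNet blocks: every layer is
    # checkpointed, so each one is recomputed during backward.
    def __init__(self, width=64, depth=4):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU())
            for _ in range(depth)
        )

    def forward(self, x):
        for layer in self.layers:
            x = cp.checkpoint(layer, x)
        return x

model = CheckpointedMLP()
x = torch.randn(32, 64, requires_grad=True)
fwd, bwd = timed_step(model, x)
print(f"forward: {fwd * 1e3:.2f} ms, backward: {bwd * 1e3:.2f} ms")
```

Running the same measurement with the model wrapped in `nn.DataParallel(model)` and comparing against the single-GPU numbers would show whether the slowdown is concentrated in the backward pass, as the synchronization hypothesis predicts.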
Original issue is here: gpleiss/efficient_densenet_pytorch#36
Full code of the model is here: https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py