
Checkpointing is slow on nn.DataParallel models #7801

@gpleiss

Description


@wandering007 pointed out this issue in https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py

I have a model that uses checkpointing on several layers. On a single GPU, the model runs fairly fast (e.g. only a 15-20% overhead from checkpointing). On multiple GPUs, wrapped in nn.DataParallel, @wandering007 reports that the model runs up to 100x slower.

Here are the important snippets of the model:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output
    return bn_function

class _DenseLayer(nn.Module):
    # ...

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        # ...

There are several _DenseLayers throughout the model.
@wandering007 suspects the issue has to do with GPU synchronization: that is, the replicas must synchronize at every checkpoint during the backward pass.
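For reference, the checkpointed bottleneck can be exercised standalone. Below is a minimal sketch with hypothetical layer sizes and tensor shapes (not taken from the actual DenseNet config); it runs on CPU, where nn.DataParallel degenerates to a plain module call, so the reported slowdown itself only reproduces with multiple GPUs.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

def _bn_function_factory(norm, relu, conv):
    # Concatenate all previous feature maps, then run BN -> ReLU -> 1x1 conv.
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        return conv(relu(norm(concated_features)))
    return bn_function

# Hypothetical sizes, just to make the sketch runnable.
norm = nn.BatchNorm2d(8)
relu = nn.ReLU()
conv = nn.Conv2d(8, 4, kernel_size=1, bias=False)
bn_function = _bn_function_factory(norm, relu, conv)

# Two "previous feature" tensors of 4 channels each; requires_grad so that
# checkpoint actually recomputes the forward during the backward pass.
x1 = torch.randn(2, 4, 8, 8, requires_grad=True)
x2 = torch.randn(2, 4, 8, 8, requires_grad=True)

out = cp.checkpoint(bn_function, x1, x2)
out.sum().backward()  # triggers the recomputation inside the checkpoint
print(out.shape)      # torch.Size([2, 4, 8, 8])
```

To see the multi-GPU behavior, one would wrap the full model (not the bn_function) in nn.DataParallel(model); each replica then hits its own checkpoint recomputation during backward, which is where the suspected synchronization overhead would come in.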

Original issue is here: gpleiss/efficient_densenet_pytorch#36
Full code of the model is here: https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py

Metadata


Labels

module: activation checkpointing (Related to activation checkpointing)
module: performance (Issues related to performance, either of kernel code or framework glue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
