
Upsample with trilinear interpolation works at least 10x slower using Mixed Precision than with FP32 #44206


@mmarcinkiewicz

🐛 Bug

Upsample with trilinear interpolation runs at least 10x slower with Mixed Precision (AMP) than with FP32.

To Reproduce

Steps to reproduce the behavior:

  1. Run python debug.py --input_shape 8 4 4 4 --batch_size 4 --iterations 100
  2. Run python debug.py --input_shape 8 4 4 4 --batch_size 4 --iterations 100 --amp
  3. Compare results

Expected behavior

Upsample with trilinear interpolation runs at least as fast with Mixed Precision as with FP32.

Environment

I tried it in NVIDIA containers going back to 20.01, so the issue seems to be PyTorch-version agnostic.

Additional context

It seems that the main culprit is the backward pass; the forward pass is actually faster in Mixed Precision.
Kernel breakdown:
FP32: [screenshot of FP32 kernel breakdown]

FP16: [screenshot of FP16 kernel breakdown]
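
To isolate the upsample from the conv, a minimal sketch (my own addition, not part of the original report) that times forward + backward of F.interpolate in FP32 vs FP16 could look like the following; the tensor shape is an illustrative assumption, roughly matching the conv output in the repro below:

import torch
import torch.nn.functional as F

def time_interpolate(dtype, iters=100):
    # Illustrative (N, C, D, H, W) shape; any 5D tensor large enough to measure works.
    x = torch.rand(4, 1024, 4, 4, 4, device="cuda", dtype=dtype, requires_grad=True)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        out = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True)
        out.sum().backward()
        x.grad = None  # reset the gradient so backward cost stays comparable per iteration
    end.record()
    torch.cuda.synchronize()
    print(dtype, round(start.elapsed_time(end) / iters, 3), "ms per forward+backward")

time_interpolate(torch.float32)
time_interpolate(torch.float16)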

code:

import argparse
from time import time
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(in_channels=8, out_channels=1024, kernel_size=(3, 3, 3), bias=False, padding=(1, 1, 1))
        self.upsample = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
    def forward(self, x):
        x = self.conv(x)
        x = self.upsample(x)
        return x


def main():
    flags = parse_arguments()
    torch.backends.cudnn.benchmark = True
    nvtx_enabled = False
    if flags.profile:
        import torch.cuda.profiler as profiler
        import pyprof
        pyprof.init(enable_function_stack=True)
        nvtx_enabled = True
        print("Profiling is enabled")

    model = Model()
    device = torch.device("cuda:0")
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), 1e-8)
    scaler = GradScaler()

    data = torch.rand(size=(flags.batch_size, *flags.input_shape)).to(device)
    with torch.autograd.profiler.emit_nvtx(enabled=nvtx_enabled):
        model.train()
        iteration = 0
        start_marker = time()
        for i in range(flags.iterations):
            iteration += 1
            if flags.profile and iteration == 10:
                profiler.start()
            optimizer.zero_grad()

            with autocast(enabled=flags.amp):
                outputs = model(data)
                loss = torch.sum(outputs)

            if flags.amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
        if flags.profile:
            profiler.stop()
        torch.cuda.synchronize()  # wait for queued GPU work to finish before reading the clock
        elapsed = time() - start_marker
        print("Finished in", round(elapsed, 2),
              "s. One iteration took on average", round(elapsed / flags.iterations, 4), "s.")


def parse_arguments():
    PARSER = argparse.ArgumentParser(description="UNet-3D")
    PARSER.add_argument('--amp', dest='amp', action='store_true', default=False)
    PARSER.add_argument('--batch_size', dest='batch_size', type=int, default=1)
    PARSER.add_argument('--input_shape', nargs='+', type=int, default=[32, 64, 64, 64])
    PARSER.add_argument('--profile', dest='profile', action='store_true', default=False)
    PARSER.add_argument('--iterations', dest='iterations', type=int, default=100)
    return PARSER.parse_args()


if __name__ == '__main__':
    main()
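
Appending something like the following to the script above is a possible workaround sketch (my own suggestion, untested here): disable autocast locally so the upsample and its backward run in FP32 while the rest of the model stays under AMP:

class ModelFp32Upsample(Model):
    # Hypothetical variant of the Model above: under AMP the conv output is FP16,
    # so cast it back to FP32 and disable autocast around the upsample only.
    def forward(self, x):
        x = self.conv(x)
        with autocast(enabled=False):
            x = self.upsample(x.float())
        return x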

cc @VitalyFedyunin @ngimel

Labels: module: performance, triaged
