Interpolate(antialias=False) is 8X slower than antialias=True depending on the tensor stride #83840
Description
On a float32 tensor of shape (3, 64, 64) with stride (4096, 64, 1), calling
torch.nn.functional.interpolate(x, size=[224, 224], mode="bilinear", align_corners=False, antialias=antialias) takes ~175 μs with antialias=True and ~129 μs with antialias=False. So far so good.
However, if we change the stride to (1, 192, 3), then antialias=False suddenly becomes 8X slower: 1471.8 μs (vs 176.4 μs for antialias=True, which is about the same as before).
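For context (this is not from the report itself, just an observation): the slow stride pattern (1, 192, 3) is exactly what you get when an HWC-contiguous buffer is viewed as CHW, e.g. via permute, which is a common way images end up with these strides:

```python
import torch

# Contiguous CHW tensor: the fast case from the report
x = torch.empty(3, 64, 64)
print(x.stride())  # (4096, 64, 1)

# An HWC-contiguous buffer viewed as CHW yields the slow strides
y = torch.empty(64, 64, 3).permute(2, 0, 1)
print(y.stride())  # (1, 192, 3)
```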
Perhaps this is expected, but I was assuming that antialias=False does less work than antialias=True, so shouldn't it be slightly faster?
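A possible (unverified) workaround, assuming the slowdown comes from the non-contiguous input layout, is to force a contiguous copy before calling interpolate(). The helper below (interpolate_contig is a hypothetical name, not part of the report) sketches this; whether the copy cost outweighs the slow-path penalty would itself need benchmarking:

```python
import torch

def interpolate_contig(x):
    # Assumption: making the input contiguous restores the fast path
    # seen in the first benchmark. The .contiguous() copy has a cost
    # of its own, so this is only a sketch, not a confirmed fix.
    x = x.contiguous()
    return torch.nn.functional.interpolate(
        x, size=[224, 224], mode="bilinear",
        align_corners=False, antialias=False,
    )

# Input with the slow (1, 192, 3) stride pattern, plus a batch dim
img = torch.rand(64, 64, 3).permute(2, 0, 1)[None]
out = interpolate_contig(img)
print(out.shape)  # torch.Size([1, 3, 224, 224])
```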
Reproducing example:
Details
import torch
from time import time


def bench(f, inp, num_exp=1000, num_prime=10):
    # warm up before measuring
    for _ in range(num_prime):
        f(inp)
    times = []
    for _ in range(num_exp):
        start = time()
        f(inp)
        end = time()
        times.append(end - start)
    median = torch.median(torch.tensor(times))
    print(f"Median over {num_exp} exp = {median * 1e6:.1f} μs")
    return median


def interpolate(x, antialias):
    torch.nn.functional.interpolate(x, size=[224, 224], mode="bilinear", align_corners=False, antialias=antialias)


# Contiguous strides: both settings are fast
tensor_img = torch.randint(0, 256, (3, 64, 64), dtype=torch.float32)
assert tensor_img.stride() == (4096, 64, 1)
tensor_img = tensor_img[None, :, :, :]  # add batch dim for call to interpolate()
bench(lambda x: interpolate(x, antialias=True), tensor_img)   # 175.7 μs
bench(lambda x: interpolate(x, antialias=False), tensor_img)  # 129.7 μs

# Channels-last-like strides: antialias=False is suddenly much slower
tensor_img = torch.randint(0, 256, (3, 64, 64), dtype=torch.float32)
assert tensor_img.stride() == (4096, 64, 1)
tensor_img = tensor_img.as_strided(tensor_img.size(), stride=(1, 192, 3))
tensor_img = tensor_img[None, :, :, :]  # add batch dim for call to interpolate()
bench(lambda x: interpolate(x, antialias=True), tensor_img)   # 176.4 μs
bench(lambda x: interpolate(x, antialias=False), tensor_img)  # 1471.8 μs -- 8X slower!!!

System Info
TL;DR: pytorch version == nightly 1.13.0.dev20220712
Details
PyTorch version: 1.13.0.dev20220712
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (conda-forge gcc 9.5.0-16) 9.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.3
Libc version: glibc-2.27
Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:56:21) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] pytorch-pfn-extras==0.5.8
[pip3] torch==1.13.0.dev20220712
[pip3] torchdata==0.5.0a0+f1a128e
[pip3] torchvision==0.14.0a0+a35ef88
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] libblas 3.9.0 14_linux64_mkl conda-forge
[conda] libcblas 3.9.0 14_linux64_mkl conda-forge
[conda] liblapack 3.9.0 14_linux64_mkl conda-forge
[conda] liblapacke 3.9.0 14_linux64_mkl conda-forge
[conda] mkl 2022.0.1 h06a4308_117
[conda] numpy 1.22.3 pypi_0 pypi
[conda] pytorch 1.13.0.dev20220712 py3.9_cuda11.3_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] pytorch-pfn-extras 0.5.8 pypi_0 pypi
[conda] torchdata 0.5.0a0+f1a128e dev_0
[conda] torchvision 0.14.0a0+a35ef88 dev_0