PyTorch unfold could be faster #60466
Open
Labels
module: performance — Issues related to performance, either of kernel code or framework glue; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
🐛 Bug
The Torch implementation of Unfold is slower than it could be. I provide a comparison with a simple implementation based on as_strided, which is both faster and more memory efficient.
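As background for the comparison, here is a minimal NumPy sketch (not the PyTorch benchmark from this issue) of the as_strided idea: a sliding window over a 1-D signal is just a view with an extra dimension whose stride is the hop size, so no data is copied.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(10, dtype=np.float32)
kernel, stride = 4, 2
n_frames = (len(x) - kernel) // stride + 1
itemsize = x.strides[0]
# Each row of `frames` is one window; rows are `stride` elements apart.
frames = as_strided(x, shape=(n_frames, kernel),
                    strides=(stride * itemsize, itemsize))
print(frames)
assert frames.shape == (4, 4)
assert frames[1].tolist() == [2.0, 3.0, 4.0, 5.0]
```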
To Reproduce
Run the following code on a machine with a CUDA GPU, or alternatively check out this colab.
```python
from contextlib import contextmanager
import gc
import time

import torch
from torch.nn import functional as F


class Result:
    pass


def unfold1d(input, kernel_size: int, stride: int):
    *shape, length = input.shape
    n_frames = (max(length, kernel_size) - kernel_size) // stride + 1
    tgt_length = (n_frames - 1) * stride + kernel_size
    input = input[..., :tgt_length].contiguous()
    strides = list(input.stride())
    strides = strides[:-1] + [stride, 1]
    out = input.as_strided(shape + [n_frames, kernel_size], strides)
    return out.transpose(-1, -2)


def torch_unfold(x, kernel, stride):
    B, C, T = x.shape
    frames = F.unfold(x[:, :, None], kernel_size=[1, kernel], stride=[1, stride])
    frames = frames.reshape(B, C, kernel, -1)
    return frames


@contextmanager
def measure():
    gc.collect()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.empty_cache()
    result = Result()
    begin = time.time()
    try:
        yield result
    finally:
        torch.cuda.synchronize()
        result.duration = time.time() - begin
        result.mem = torch.cuda.max_memory_allocated() / 2**20


def compare(kernel, stride):
    print("For", kernel, stride)
    x = torch.randn(1, 1, 160000, device='cuda')
    with measure() as r1:
        frames = unfold1d(x, kernel, stride)
    with measure() as r2:
        frames2 = torch_unfold(x, kernel, stride)
    print(f'time unfold1d / time torch unfold {r1.duration / r2.duration:.4f}')
    print(f'memory unfold1d / memory torch unfold {r1.mem / r2.mem:.4f}')
    assert frames.shape == frames2.shape and (frames == frames2).all()


compare(64, 8)
compare(1024, 256)
compare(1024, 167)
compare(2048, 190)
```

Edit: here is the output from running the Google Colab (the first ratio is time and the second is memory).
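For readers without a CUDA GPU, the correctness of the strided-view approach can be checked on CPU. The sketch below is a NumPy analogue of the `unfold1d` above (the helper names are mine, not part of the issue): it compares the strided view against a reference that builds each frame with an explicit copy, for the same parameter pairs used in the benchmark.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def unfold1d_np(x, kernel, stride):
    # NumPy analogue of unfold1d: trim the tail so every frame is full,
    # then expose the frames as a zero-copy strided view.
    n_frames = (max(len(x), kernel) - kernel) // stride + 1
    tgt = (n_frames - 1) * stride + kernel
    x = np.ascontiguousarray(x[:tgt])
    it = x.strides[0]
    return as_strided(x, (n_frames, kernel), (stride * it, it))

def unfold1d_copy(x, kernel, stride):
    # Reference implementation with explicit per-frame copies.
    n_frames = (max(len(x), kernel) - kernel) // stride + 1
    return np.stack([x[i * stride : i * stride + kernel]
                     for i in range(n_frames)])

x = np.random.randn(160000).astype(np.float32)
for kernel, stride in [(64, 8), (1024, 256), (1024, 167), (2048, 190)]:
    a = unfold1d_np(x, kernel, stride)
    b = unfold1d_copy(x, kernel, stride)
    assert a.shape == b.shape and np.array_equal(a, b)
print("all parameter pairs match")
```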
```
For 64 8
time unfold1d / time torch unfold 0.4795
memory unfold1d / memory torch unfold 0.5263
For 1024 256
time unfold1d / time torch unfold 0.3702
memory unfold1d / memory torch unfold 0.6675
For 1024 167
time unfold1d / time torch unfold 0.1427
memory unfold1d / memory torch unfold 0.5850
For 2048 190
time unfold1d / time torch unfold 0.0357
memory unfold1d / memory torch unfold 0.4618
```
The proposed implementation uses roughly half the memory and can be more than 20x faster (see the kernel=2048, stride=190 case).
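One caveat worth noting with the view-based approach: as_strided returns a view, so overlapping frames alias the same underlying buffer, and writing through one frame silently changes its neighbors (this holds for torch.Tensor.as_strided as well). A minimal NumPy demonstration:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(6, dtype=np.float32)
it = x.strides[0]
# Two windows of length 4 with hop 2: x[0:4] and x[2:6] share x[2] and x[3].
frames = as_strided(x, (2, 4), (2 * it, it))
frames[0, 2] = 99.0        # element x[2] is shared by both windows
assert frames[1, 0] == 99.0
assert x[2] == 99.0
```

So callers that mutate the frames in place should copy first (e.g. `out.clone()` in PyTorch), trading back some of the memory savings.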
Expected behavior
I expected PyTorch unfold to be more efficient in terms of memory and speed.
Environment
```
PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26
Python version: 3.7 (64-bit runtime)
Python platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.9.0+cu102
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.10.0
[pip3] torchvision==0.10.0+cu102
[conda] Could not collect
```