We have noticed some differences in how the number of threads to spawn is calculated in the different ATen parallel backends, which, as far as we have observed, may lead to degraded performance of the OpenMP backend under some circumstances.
In the OpenMP version:
pytorch/aten/src/ATen/ParallelOpenMP.h (line 28 in e327df3):

#pragma omp parallel if (!omp_in_parallel() && ((end - begin) >= grain_size))
The logic seems to be: either do not parallelize at all (if end - begin < grain_size), or spawn as many as omp_get_max_threads() threads.
In other backends, e.g. Native:
pytorch/aten/src/ATen/ParallelNative.h (lines 41 to 44 in e327df3):

size_t chunk_size = divup((end - begin), get_num_threads());
// make sure each task is at least grain_size size
chunk_size = std::max((size_t)grain_size, chunk_size);
size_t num_tasks = divup((end - begin), chunk_size);
The logic here is roughly num_tasks = min(get_num_threads(), divup(end - begin, grain_size)), i.e. the number of tasks grows with the range size and is capped by get_num_threads().
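To make the difference concrete, here is a minimal Python sketch that merely models the two formulas above; the grain_size of 32768 (matching the boundary observed below) and the thread count of 72 (our 36C/72T machine) are our assumptions:

```python
# Minimal model of the two backends' decisions; not PyTorch code.
GRAIN_SIZE = 32768   # assumed grain size, matching the boundary observed below
MAX_THREADS = 72     # assumed omp_get_max_threads() / get_num_threads()

def divup(x, y):
    return (x + y - 1) // y

def openmp_threads(n):
    # OpenMP backend: all-or-nothing -- 1 thread below grain_size,
    # omp_get_max_threads() threads at or above it.
    return MAX_THREADS if n >= GRAIN_SIZE else 1

def native_tasks(n):
    # Native backend: each task is at least grain_size elements, so the task
    # count is effectively min(MAX_THREADS, divup(n, GRAIN_SIZE)).
    chunk_size = max(GRAIN_SIZE, divup(n, MAX_THREADS))
    return divup(n, chunk_size)

for n in (32767, 32768, 65536, 10000000):
    print(n, openmp_threads(n), native_tasks(n))
# 32767 1 1
# 32768 72 1
# 65536 72 2
# 10000000 72 72
```

Right at the 32768 boundary the OpenMP backend jumps from 1 thread to all 72, while the native backend would still use a single task.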
We noticed this problem because, after upgrading our PyTorch (to >= 1.1.0 in our experience), the DataLoader with pin_memory=True saturates our 36C/72T CPUs even with very small datasets (CIFAR10, resize=32, batch_size=40) and with num_threads=1, yet without any significant training speed boost. Further investigation shows that the copy operation to pinned memory behaves very differently between the two releases. After looking at the code, we conjecture that this is due to the overhead of OMP spawning an unnecessarily large number of threads for small CPU Tensors.
We installed both PyTorch versions from conda, so we are not absolutely certain about the build options. However, we provide some code snippets below that may reveal the aforementioned problem:
- A simple comparison of the copy operation:
from sys import argv
import torch
print(torch.__version__)
from tqdm import tqdm
N = int(argv[1])
a, b = [torch.randn([N]) for _ in (0, 1)]
for _, i in tqdm(enumerate(range(1000000))):
    b.copy_(a)

time python test.py 32767
1.1.0
1000000it [00:12, 83148.11it/s]
real 0m12.866s
user 0m12.686s
sys 0m0.154s

time python test.py 32768
1.1.0
1000000it [00:07, 128799.44it/s]
real 0m8.696s
user 4m43.088s
sys 0m0.432s

time OMP_NUM_THREADS=4 python test.py 32768
1.1.0
1000000it [00:08, 123182.80it/s]
real 0m9.168s
user 0m33.244s
sys 0m0.194s

time python test.py 32768  # pytorch 1.0.1
1.0.1.post2
1000000it [00:53, 18525.88it/s]
real 0m55.207s
user 0m54.584s
sys 0m0.182s

time python test.py 10000000  # pytorch 1.0.1, reduced iter number to 1000 to avoid insane running time
1.0.1.post2
1000it [00:18, 54.57it/s]
real 0m19.421s
user 0m19.192s
sys 0m0.195s
This shows that the CPU usage of copying a 32768-sized Tensor drastically increases to approx. 3600%, compared to approx. 100% for a 32767-sized Tensor. Although there is indeed some performance gain, it is only comparable to that obtained when OMP_NUM_THREADS is limited to 4.
In 1.0.1 there is no such problem even for a much larger Tensor (of size 10M), although the performance is worse than 1.1.0 for the same size.
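For what it is worth, the discontinuity can also be exposed with a quick sweep around the suspected boundary. The snippet below is only a sketch along the lines of the test above; the particular sizes and the use of time.perf_counter are our own choices:

```python
import time
import torch

# Sweep tensor sizes around the suspected grain_size boundary (32768)
# and report copy throughput; the jump in CPU usage and the change in
# throughput should appear right at 32768.
for n in (32766, 32767, 32768, 32769, 65536):
    a, b = torch.randn([n]), torch.randn([n])
    iters = 100000
    start = time.perf_counter()
    for _ in range(iters):
        b.copy_(a)
    elapsed = time.perf_counter() - start
    print("N=%d: %.0f copies/s" % (n, iters / elapsed))
```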
- Data Loading test
from sys import argv
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
class Dummy(Dataset):
    def __len__(self):
        return 1000000
    def __getitem__(self, idx):
        return torch.zeros([3, 32, 32])
loader = DataLoader(Dummy(), num_workers=int(argv[1]), pin_memory=int(argv[2]), batch_size=int(argv[3]))
loader = iter(loader)
next(loader)
print(torch.__version__)
for _, im in tqdm(enumerate(loader)):
    pass

time python test2.py 1 1 11
1.1.0
90909it [02:01, 748.58it/s]
real 2m7.478s
user 22m20.561s
sys 121m57.231s

time python test2.py 1 0 11
1.1.0
90909it [00:47, 1894.42it/s]
real 0m49.030s
user 1m5.537s
sys 0m24.925s

time python test2.py 1 1 10
1.1.0
99999it [01:09, 1439.45it/s]
real 1m14.376s
user 1m37.910s
sys 0m39.113s

time OMP_NUM_THREADS=1 python test2.py 1 1 11
1.1.0
90909it [00:58, 1551.88it/s]
real 1m3.575s
user 1m26.514s
sys 0m34.365s
We can observe that: 1) the DataLoader with pin_memory=True and num_workers=1 leads to approx. 7200% CPU usage, with most of the time spent in sys, and the real speed is actually slower; 2) anything that avoids the multi-threaded copy (disabling pin_memory, reducing the batch_size, or setting OMP_NUM_THREADS=1) reduces the CPU usage, indicating that the problem is very likely related to the copy operation. (Note that a batch of 11 images of shape [3, 32, 32] has 11 × 3 × 32 × 32 = 33792 elements, just above 32768, whereas a batch of 10 has 30720 elements, just below it.)
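Until this is resolved, the only mitigation we have verified is OMP_NUM_THREADS=1 itself. For completeness, here is a sketch of setting it from inside the script rather than on the command line; the assumption is that the variable must be set before torch (and hence the OpenMP runtime) is imported:

```python
import os
# Must happen before importing torch, so the OpenMP runtime picks it up.
os.environ["OMP_NUM_THREADS"] = "1"

import torch
from torch.utils.data import Dataset, DataLoader

class Dummy(Dataset):
    # Same dummy dataset as in the test above.
    def __len__(self):
        return 1000000
    def __getitem__(self, idx):
        return torch.zeros([3, 32, 32])

loader = DataLoader(Dummy(), num_workers=1, pin_memory=True, batch_size=11)
for _, im in enumerate(loader):
    pass
```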
However, we have not yet submitted a PR for this, because we are not sure, for example, whether this is a side effect of more important 'features', or whether the team is planning a refactoring of the relevant code. If it does need a fix, we are happy to provide one.