Skip to content

Commit adbcaee

Browse files
janeyx99pytorchmergebot
authored andcommitted
Init threadpool with user defined num_threads before default (#136793)
Fixes #134714 (or attempts to, idk how to test yet) For posterity, how one can test: 1. make sure you have USE_PTHREADPOOL=1 or pull a packaged binary 2. run gdb --args python, with `r` to enter, `Ctrl-C` to pause, and `c` to get back into Python 3. import torch 4. torch.set_num_threads(1), make sure this does not trigger any additional threads getting created. Pull Request resolved: #136793 Approved by: https://github.com/albanD
1 parent bc21689 commit adbcaee

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

aten/src/ATen/ParallelNative.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ void init_num_threads() {
209209
}
210210

211211
void set_num_threads(int nthreads) {
212-
#ifndef C10_MOBILE
213212
TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
213+
#ifndef C10_MOBILE
214214
int no_value = NOT_SET;
215215
if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
216216
// num_intraop_threads either stores a positive integer or CONSUMED,
@@ -229,9 +229,8 @@ void set_num_threads(int nthreads) {
229229
}
230230
}
231231
#else
232-
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
232+
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
233233
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
234-
pool->set_thread_count(nthreads);
235234
#endif // C10_MOBILE
236235
}
237236

aten/src/ATen/ParallelOpenMP.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ void set_num_threads(int nthreads) {
6161
#endif
6262
#ifdef USE_PTHREADPOOL
6363
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
64-
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
64+
caffe2::PThreadPool* const pool = caffe2::pthreadpool(nthreads);
6565
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
66-
pool->set_thread_count(nthreads);
6766
#endif
6867
#if AT_MKLDNN_ENABLED()
6968
at::native::mkldnn::clear_computation_cache();

caffe2/utils/threadpool/pthreadpool-cpp.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,9 @@ void PThreadPool::run(
8282
0u);
8383
}
8484

85-
// Forward declaration
86-
size_t getDefaultNumThreads();
87-
88-
PThreadPool* pthreadpool() {
85+
PThreadPool* pthreadpool(size_t thread_count) {
8986
static auto threadpool =
90-
std::make_unique<PThreadPool>(getDefaultNumThreads());
87+
std::make_unique<PThreadPool>(thread_count);
9188
#if !(defined(WIN32))
9289
static std::once_flag flag;
9390
std::call_once(flag, []() {
@@ -105,6 +102,13 @@ PThreadPool* pthreadpool() {
105102
return threadpool.get();
106103
}
107104

105+
// Forward declaration
106+
size_t getDefaultNumThreads();
107+
108+
PThreadPool* pthreadpool() {
109+
return pthreadpool(getDefaultNumThreads());
110+
}
111+
108112
pthreadpool_t pthreadpool_() {
109113
if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
110114
return nullptr;

caffe2/utils/threadpool/pthreadpool-cpp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class PThreadPool final {
4242

4343
// Return a singleton instance of PThreadPool for ATen/TH multithreading.
4444
PThreadPool* pthreadpool();
45+
PThreadPool* pthreadpool(size_t thread_count);
4546

4647
// Exposes the underlying implementation of PThreadPool.
4748
// Only for use in external libraries so as to unify threading across

0 commit comments

Comments
 (0)