
Commit 76dca1f

shuqiangzhang authored and pytorchmergebot committed
[c10d] separate the codes for GPU stream synchronization and CPU thread synchronization (#137295)
Summary: This PR should not change the existing behavior of work.wait(); it only separates the GPU stream synchronization code from the CPU busy-wait code and removes the need for a private synchronization function. Longer term, we would like to give users the flexibility to bypass the watchdog thread and handle collective errors themselves.

Test Plan: python test/distributed/test_c10d_nccl.py NcclErrorHandlingTest

Pull Request resolved: #137295
Approved by: https://github.com/kwen2501
1 parent 9f9d252 commit 76dca1f
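
The summary above describes the split this commit makes: a plain work.wait() only makes the current CUDA stream wait on the NCCL stream, while a blocking wait (or an explicit timeout) busy-waits on the CPU thread. The sketch below is illustrative only and not part of the diff; it assumes a NCCL process group launched via torchrun on a single node and uses arbitrary tensor sizes.

import torch
import torch.distributed as dist
from datetime import timedelta

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

t = torch.ones(1024, device="cuda")

# Default path: wait() only blocks the current CUDA stream on the NCCL
# stream; the CPU thread returns immediately.
work = dist.all_reduce(t, async_op=True)
work.wait()

# With an explicit timeout (or TORCH_NCCL_BLOCKING_WAIT=1), the CPU thread
# busy-waits until the work completes or times out, raising on failure.
work = dist.all_reduce(t, async_op=True)
work.wait(timeout=timedelta(seconds=30))

dist.destroy_process_group()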

File tree

4 files changed (+105, -74 lines):
test/distributed/test_c10d_nccl.py
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
torch/csrc/distributed/c10d/init.cpp


test/distributed/test_c10d_nccl.py

Lines changed: 34 additions & 15 deletions
@@ -2554,6 +2554,24 @@ def _test_nccl_errors_blocking(self, func):
         del process_group
         func()
 
+    def _test_barrier_error(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(
+            store,
+            self.rank,
+            self.world_size,
+            timeout=timedelta(seconds=10),
+        )
+        process_group.barrier().wait()
+        if self.rank == 0:
+            with self.assertRaisesRegex(dist.DistBackendError, ""):
+                # It seems the error message would be different depending on
+                # whether the test is run on CI machine and devGPU. Skipping
+                # the error message check to make both sides happy.
+                process_group.barrier().wait(
+                    timeout=timedelta(seconds=self.op_timeout_sec)
+                )
+
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@@ -2602,22 +2620,23 @@ def test_nccl_errors_blocking_sigterm(self):
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        process_group = c10d.ProcessGroupNCCL(
-            store,
-            self.rank,
-            self.world_size,
-            timeout=timedelta(seconds=10),
+        self._test_barrier_error()
+
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    def test_nccl_non_blocking_wait_with_barrier(self):
+        # test the barrier behavior in the non blocking wait setting
+        prev_nccl_async_error_handling = os.environ.get(
+            "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
         )
-        process_group.barrier().wait()
-        if self.rank == 0:
-            with self.assertRaisesRegex(dist.DistBackendError, ""):
-                # It seems the error message would be different depending on
-                # whether the test is run on CI machine and devGPU. Skipping
-                # the error message check to make both sides happy.
-                process_group.barrier().wait(
-                    timeout=timedelta(seconds=self.op_timeout_sec)
-                )
+        # avoid watchdog thread interference
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
+        self._test_barrier_error()
+        if prev_nccl_async_error_handling is not None:
+            os.environ[
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+            ] = prev_nccl_async_error_handling
 
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
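
The two tests above run the same barrier-error helper under blocking and non-blocking wait. As a rough sketch (the helper below is hypothetical; the environment variables are the ones the tests use), selecting a wait mode has to happen before the process group is created:

import os

def configure_wait_mode(blocking: bool) -> None:
    # Hypothetical helper; must run before ProcessGroupNCCL is constructed.
    if blocking:
        # Blocking wait: work.wait() busy-waits on the CPU thread and raises
        # on error or timeout (what the with_nccl_blocking_wait decorator
        # enables for test_nccl_blocking_wait_with_barrier).
        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
    else:
        # Non-blocking wait: keep the watchdog out of the way so the caller,
        # not the watchdog thread, observes the collective error, as in
        # test_nccl_non_blocking_wait_with_barrier.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"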

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 48 additions & 51 deletions
@@ -596,14 +596,16 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
       currentTimepoint - workStartTime_);
   auto workTimeout = timeout ? *timeout : opTimeout_;
 
-  if (timeElapsed < workTimeout)
+  if (timeElapsed < workTimeout) {
     return false;
+  }
 
   // Timed out
 
   // There is already an error, we don't override it
-  if (exception())
+  if (exception()) {
     return true;
+  }
 
   std::string exceptionMsg = c10::str(
       logPrefix(),
@@ -640,9 +642,28 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
 }
 
 void ProcessGroupNCCL::WorkNCCL::synchronize() {
-  // Call Synchronize without a timeout. We use this method to avoid adding a
-  // timeout argument to the public synchronize API.
-  synchronizeInternal(kNoTimeout);
+  synchronizeStream();
+
+  // Device synchronize only after we've completed timeout checks.
+  // TODO: Is this necessary for barrier if we block the CPU thread till
+  // the completion of the work?
+  if (barrierTensor_.defined()) {
+    // If we use the work to do barrier, we should block here.
+    // `dist.barrier()` only requires all CPU processes to enter this
+    // function, hence we only need to make sure the dummy all-reduce has
+    // completed. So we would only need to sync the **current stream** back to
+    // host, and do not need to synchronize the entire device (which may have
+    // kernels running on other streams).
+    // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
+    // - lower chance of hang;
+    // - CurrentCUDAStream is usually the context of the next operation in
+    //   Python, thus blocking the current stream would already block the next
+    //   compute kernel;
+    // - achieve better barrier performance.
+    auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
+    // CUDAStream wrapper will correctly use a DeviceGuard here
+    currentStream.synchronize();
+  }
 }
 
 void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
@@ -655,13 +676,25 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
   }
 }
 
-// Waiting on the work's corresponding CUDA events
-void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
-    std::chrono::milliseconds timeout) {
-  synchronizeStream();
+// Same as calling synchronize() when blockingWait_ is false
+bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
+  RECORD_PARAM_COMMS(
+      static_cast<int>(this->seq_), // seq
+      std::make_tuple(pgUID_, pgDesc_), // PG name tuple
+      rank_, // rank
+      "wait", // collective name
+      0, // inNelems
+      0, // outNelems
+      at::kByte, // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1,
+      -1,
+      static_cast<int>(1)); // number of device?
 
-  // In case of blocking, wait for the operation to complete.
-  if (blockingWait_) {
+  // In case of blockingWait or a timeout value specified by the user, we
+  // block the CPU thread until the work is completed or timed out.
+  if (blockingWait_ || timeout != kNoTimeout) {
     while (!isCompleted()) {
       bool timedOut = checkTimeout(
           timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
@@ -672,18 +705,15 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
       // can not run new events successfully.
       if (timedOut) {
         std::string exceptionMsg = c10::str(
-            logPrefix(),
-            "Work ",
-            (*this),
-            " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
+            logPrefix(), "Work ", (*this), " timed out in blocking wait.");
         LOG(ERROR) << exceptionMsg;
         break;
       }
       // Yield
       std::this_thread::sleep_for(
           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
     }
-    // exception() includes timeout and error during blocking wait
+
     if (exception()) {
       // Abort NCCL communicators
       abort();
@@ -692,42 +722,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
     }
   }
 
-  // Device synchronize only after we've completed timeout checks.
-  if (barrierTensor_.defined()) {
-    // If we use the work to do barrier, we should block here
-    // `dist.barrier()` only requires all CPU processes to enter this
-    // function, hence we only need to make sure the dummy all-reduce has
-    // completed. So we would only need to sync the **current stream** back to
-    // host, and do not need to synchronize the entire device (which may have
-    // kernels running on other streams).
-    // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
-    // - lower chance of hang;
-    // - CurrentCUDAStream is usually the context of the next operation in
-    //   Python, thus blocking current stream would already block the next
-    //   compute kernel;
-    // - achieve better barrier performance.
-    auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
-    // CUDAStream wrapper will correctly use a DeviceGuard here
-    currentStream.synchronize();
-  }
-}
+  // synchronize() will block the current stream on the NCCL stream
+  synchronize();
 
-// Same as calling synchronize().
-bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
-  RECORD_PARAM_COMMS(
-      static_cast<int>(this->seq_), // seq
-      std::make_tuple(pgUID_, pgDesc_), // PG name tuple
-      rank_, // rank
-      "wait", // collective name
-      0, // inNelems
-      0, // outNelems
-      at::kByte, // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1,
-      -1,
-      static_cast<int>(1)); // number of device?
-  synchronizeInternal(timeout);
   // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
   // upgrade. Once a NCCL version is qualified, this code should not be needed
   // at runtime.
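
As the comments in the new synchronize() note, the default wait() path only inserts a stream dependency: later GPU work is ordered after the collective, but the host keeps running. A hedged illustration of that ordering (not part of the diff; assumes an initialized NCCL process group bound to a CUDA device):

import torch
import torch.distributed as dist

x = torch.randn(4096, 4096, device="cuda")
work = dist.all_reduce(x, async_op=True)

# Blocks the current stream on the NCCL stream only: the matmul below is
# ordered after the all-reduce on the GPU, yet the CPU thread does not stall.
work.wait()
y = x @ x

# Only an explicit host-side sync (or a blocking wait / timeout) stalls the CPU.
torch.cuda.current_stream().synchronize()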

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 5 additions & 7 deletions
@@ -301,14 +301,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   bool isSuccess() const override;
 
-  // Same as calling synchronize() for NCCL work.
+  // Same as calling synchronize() for NCCL work if timeout is not set.
+  // Otherwise, it will block the CPU thread until the NCCL work is completed
+  // or timed out. On timeout, an exception will be thrown.
   bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
 
   void abort() override;
 
-  // Let current stream wait on the completing of the NCCL work
-  // Throws on exceptions. Blocking operation, which will wait for work
-  // completion.
+  // Let current stream wait on the completion of the NCCL work
+  // Throws on exceptions.
   void synchronize() override;
 
   // Synchronize streams by blocking each on the NCCL stream
@@ -404,9 +405,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       const WorkNCCL& workNCCL);
 
  private:
-  // Helper function for synchronize
-  void synchronizeInternal(std::chrono::milliseconds timeout);
-
   // Checks for NCCL errors and sets an appropriate exception_ptr.
   void checkAndSetException();

torch/csrc/distributed/c10d/init.cpp

Lines changed: 18 additions & 1 deletion
@@ -2988,7 +2988,24 @@ such as `dist.all_reduce(tensor, async_op=True)`.
           "wait",
           &::c10d::Work::wait,
           py::arg("timeout") = kNoTimeout,
-          py::call_guard<py::gil_scoped_release>())
+          py::call_guard<py::gil_scoped_release>(),
+          R"(
+            Returns:
+                true/false.
+
+            Example::
+                try:
+                    work.wait(timeout)
+                except:
+                    # some handling
+
+            .. warning ::
+                In normal cases, users do not need to set the timeout.
+                Calling wait() is the same as calling synchronize():
+                letting the current stream block on the completion of the NCCL work.
+                However, if a timeout is set, it will block the CPU thread until the NCCL
+                work is completed or timed out. On timeout, an exception will be thrown.
+          )")
       .def(
           "get_future",
           [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
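
The docstring example above is schematic. A hedged, concrete version (assuming an initialized NCCL process group; DistBackendError is the exception type the accompanying tests expect on timeout):

from datetime import timedelta

import torch.distributed as dist

work = dist.barrier(async_op=True)
try:
    work.wait(timeout=timedelta(seconds=10))
except dist.DistBackendError as e:
    # The NCCL communicator has been aborted at this point; tear down or
    # restart the process group rather than issuing further collectives.
    print(f"collective timed out or failed: {e}")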
