@@ -596,14 +596,16 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
       currentTimepoint - workStartTime_);
   auto workTimeout = timeout ? *timeout : opTimeout_;

-  if (timeElapsed < workTimeout)
+  if (timeElapsed < workTimeout) {
     return false;
+  }

   // Timed out

   // There is already an error, we don't override it
-  if (exception())
+  if (exception()) {
     return true;
+  }

   std::string exceptionMsg = c10::str(
       logPrefix(),
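For context on the hunk above: checkTimeout() measures how long the work has been running and compares it against the effective timeout, which is the caller-supplied value when present and opTimeout_ otherwise. A minimal sketch of that comparison in isolation, using hypothetical free-standing parameters (workStartTime, defaultTimeout) rather than the WorkNCCL members:

```cpp
#include <chrono>
#include <optional>

// Sketch only: mirrors the elapsed-time-vs-timeout check above, with
// hypothetical parameters instead of WorkNCCL members.
bool hasTimedOut(
    std::chrono::steady_clock::time_point workStartTime,
    std::chrono::milliseconds defaultTimeout,
    std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
  auto currentTimepoint = std::chrono::steady_clock::now();
  auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
      currentTimepoint - workStartTime);
  auto workTimeout = timeout ? *timeout : defaultTimeout;
  // Below the threshold means "not timed out yet"; the real method returns
  // false here and leaves any previously recorded exception untouched.
  return timeElapsed >= workTimeout;
}
```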
@@ -640,9 +642,28 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
 }

 void ProcessGroupNCCL::WorkNCCL::synchronize() {
-  // Call Synchronize without a timeout. We use this method to avoid adding a
-  // timeout argument to the public synchronize API.
-  synchronizeInternal(kNoTimeout);
+  synchronizeStream();
+
+  // Device synchronize only after we've completed timeout checks.
+  // TODO: Is this necessary for barrier if we block the cpu thread till
+  // the completion of the work?
+  if (barrierTensor_.defined()) {
+    // If we use the work to do barrier, we should block here
+    // `dist.barrier()` only requires all CPU processes to enter this
+    // function, hence we only need to make sure the dummy all-reduce has
+    // completed. So we would only need to sync the **current stream** back to
+    // host, and do not need to synchronize the entire device (which may have
+    // kernels running on other streams).
+    // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
+    // - lower chance of hang;
+    // - CurrentCUDAStream is usually the context of the next operation in
+    // Python, thus blocking current stream would already block the next
+    // compute kernel;
+    // - achieve better barrier performance.
+    auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
+    // CUDAStream wrapper will correctly use a DeviceGuard here
+    currentStream.synchronize();
+  }
 }

 void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
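The comment block moved into synchronize() above argues that a barrier only needs the current stream synchronized back to the host, not the entire device. A minimal sketch contrasting the two options it mentions, assuming a CUDA build; deviceIndex is a hypothetical stand-in for device_.index():

```cpp
#include <ATen/cuda/CUDAContext.h>

// Sketch only: stream-level sync (what the code above does) versus a full
// device sync (what the comment advises against).
void barrierStyleSync(c10::DeviceIndex deviceIndex) {
  // Block the host only until work queued on the current stream, including
  // the dummy all-reduce used for the barrier, has finished.
  at::cuda::getCurrentCUDAStream(deviceIndex).synchronize();

  // Heavier alternative: wait for every kernel on every stream of the
  // device. Deliberately left commented out.
  // C10_CUDA_CHECK(cudaDeviceSynchronize());
}
```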
@@ -655,13 +676,25 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
   }
 }

-// Waiting on the work's corresponding CUDA events
-void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
-    std::chrono::milliseconds timeout) {
-  synchronizeStream();
+// Same as calling synchronize() when blockingWait_ is false
+bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
+  RECORD_PARAM_COMMS(
+      static_cast<int>(this->seq_), // seq
+      std::make_tuple(pgUID_, pgDesc_), // PG name tuple
+      rank_, // rank
+      "wait", // collective name
+      0, // inNelems
+      0, // outNelems
+      at::kByte, // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1,
+      -1,
+      static_cast<int>(1)); // number of device?

-  // In case of blocking, wait for the operation to complete.
-  if (blockingWait_) {
+  // If blockingWait is enabled or the user specified a timeout, we block the
+  // CPU thread until the work is completed or timed out.
+  if (blockingWait_ || timeout != kNoTimeout) {
     while (!isCompleted()) {
       bool timedOut = checkTimeout(
           timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
@@ -672,18 +705,15 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
       // can not run new events successfully.
       if (timedOut) {
         std::string exceptionMsg = c10::str(
-            logPrefix(),
-            "Work ",
-            (*this),
-            " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
+            logPrefix(), "Work ", (*this), " timed out in blocking wait.");
         LOG(ERROR) << exceptionMsg;
         break;
       }
       // Yield
       std::this_thread::sleep_for(
           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
     }
-    // exception() includes timeout and error during blocking wait
+
     if (exception()) {
       // Abort NCCL communicators
       abort();
@@ -692,42 +722,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
     }
   }

-  // Device synchronize only after we've completed timeout checks.
-  if (barrierTensor_.defined()) {
-    // If we use the work to do barrier, we should block here
-    // `dist.barrier()` only requires all CPU processes to enter this
-    // function, hence we only need to make sure the dummy all-reduce has
-    // completed. So we would only need to sync the **current stream** back to
-    // host, and do not need to synchronize the entire device (which may have
-    // kernels running on other streams).
-    // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
-    // - lower chance of hang;
-    // - CurrentCUDAStream is usually the context of the next operation in
-    // Python, thus blocking current stream would already block the next
-    // compute kernel;
-    // - achieve better barrier performance.
-    auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
-    // CUDAStream wrapper will correctly use a DeviceGuard here
-    currentStream.synchronize();
-  }
-}
+  // synchronize() will block the current stream on the NCCL stream
+  synchronize();

-// Same as calling synchronize().
-bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
-  RECORD_PARAM_COMMS(
-      static_cast<int>(this->seq_), // seq
-      std::make_tuple(pgUID_, pgDesc_), // PG name tuple
-      rank_, // rank
-      "wait", // collective name
-      0, // inNelems
-      0, // outNelems
-      at::kByte, // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1,
-      -1,
-      static_cast<int>(1)); // number of device?
-  synchronizeInternal(timeout);
   // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
   // upgrade. Once a NCCL version is qualified, this code should not be needed
   // at runtime.
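Net effect of the refactor, as far as this diff shows: wait() only spins on the CPU when blockingWait_ is set or the caller passed an explicit timeout, and in every case it ends by calling synchronize(), which blocks the current stream on the NCCL stream (plus the host-side stream sync for barriers). A minimal sketch of that polling shape, with hypothetical std::function callbacks standing in for isCompleted() and checkTimeout():

```cpp
#include <chrono>
#include <functional>
#include <thread>

// Sketch only: the busy-wait loop shape used by wait() above, with
// hypothetical callbacks instead of WorkNCCL members.
bool blockingPoll(
    const std::function<bool()>& isCompleted,
    const std::function<bool()>& hasTimedOut,
    std::chrono::milliseconds pollInterval = std::chrono::milliseconds(10)) {
  while (!isCompleted()) {
    if (hasTimedOut()) {
      // The real code logs the timeout, breaks out of the loop, and then
      // aborts the NCCL communicators via the exception()/abort() path.
      return false;
    }
    // Yield between checks, as the loop does with kSynchronizeBusyWaitMillis.
    std::this_thread::sleep_for(pollInterval);
  }
  return true;
}
```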