@@ -2554,6 +2554,24 @@ def _test_nccl_errors_blocking(self, func):
25542554 del process_group
25552555 func ()
25562556
def _test_barrier_error(self):
    """Check that a barrier issued by rank 0 alone raises ``DistBackendError``.

    Builds a ``ProcessGroupNCCL`` on top of a ``FileStore`` (10 second
    process-group timeout) and runs one barrier that every rank takes
    part in. After that, only rank 0 issues a second barrier, waiting on
    it with ``self.op_timeout_sec`` as the timeout; the remaining ranks
    return without calling it, and the wait on rank 0 is expected to
    raise ``dist.DistBackendError``.
    """
    file_store = c10d.FileStore(self.file_name, self.world_size)
    pg = c10d.ProcessGroupNCCL(
        file_store,
        self.rank,
        self.world_size,
        timeout=timedelta(seconds=10),
    )

    # First barrier: all ranks participate, so it should complete.
    pg.barrier().wait()

    # Only rank 0 goes on to issue the second barrier.
    if self.rank != 0:
        return

    wait_timeout = timedelta(seconds=self.op_timeout_sec)
    # The error text appears to differ between CI machines and devGPU
    # hosts, so the pattern is left empty: only the exception type is
    # checked, which keeps the test passing in both environments.
    with self.assertRaisesRegex(dist.DistBackendError, ""):
        pg.barrier().wait(timeout=wait_timeout)
2574+
25572575 @with_nccl_blocking_wait
25582576 @requires_nccl ()
25592577 @requires_nccl_version ((2 , 4 , 0 ), "Need NCCL 2.4+ for error checking" )
@@ -2602,22 +2620,23 @@ def test_nccl_errors_blocking_sigterm(self):
26022620 @requires_nccl_version ((2 , 4 , 0 ), "Need NCCL 2.4+ for error checking" )
26032621 @skip_if_lt_x_gpu (3 )
26042622 def test_nccl_blocking_wait_with_barrier (self ):
2605- store = c10d .FileStore (self .file_name , self .world_size )
2606- process_group = c10d .ProcessGroupNCCL (
2607- store ,
2608- self .rank ,
2609- self .world_size ,
2610- timeout = timedelta (seconds = 10 ),
2623+ self ._test_barrier_error ()
2624+
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_non_blocking_wait_with_barrier(self):
    """Test the barrier error behavior in the non-blocking wait setting.

    Runs ``_test_barrier_error`` with ``TORCH_NCCL_ASYNC_ERROR_HANDLING``
    forced to ``"0"`` and puts the environment back exactly as it was
    afterwards, whether or not the helper raises.
    """
    env_key = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
    previous_value = os.environ.get(env_key)
    # Avoid watchdog thread interference.
    os.environ[env_key] = "0"
    try:
        self._test_barrier_error()
    finally:
        # Restore in ``finally`` so a failing assertion inside the helper
        # cannot leak the override into whatever runs next in this process.
        if previous_value is None:
            # The variable was unset before the test; remove the override
            # rather than leaving it pinned to "0".
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value
26212640
26222641 def _run_invalid_nccl_blocking_wait_env (self , val ):
26232642 os .environ ["TORCH_NCCL_BLOCKING_WAIT" ] = val
0 commit comments