77import traceback
88from contextlib import contextmanager
99from enum import IntEnum
10- from queue import Queue
10+ from queue import Empty , Queue
1111from typing import Callable , Dict , Iterable , List , Optional , Tuple , Union
1212
1313import torch
@@ -1572,6 +1572,16 @@ def handle_executed_batches(executed_batch_num: int):
15721572 self ._handle_executed_batch (executed_batch )
15731573 self .unhandled_batch_counter -= 1
15741574
1575+ def _get_executed_batch (self ):
1576+ while True :
1577+ try :
1578+ return self .executed_batch_queue .get (timeout = 0.001 )
1579+ except Empty :
1580+ # Calling MPI_Test on pending isend handles while idle to prevent potential hangs.
1581+ for handle in self .send_handles :
1582+ if handle is not None :
1583+ handle .test ()
1584+
15751585 def _broadcast_sample_state_loop (self ):
15761586 logger .debug (
15771587 f"Starting broadcast sample state loop for pp_rank { self .dist .pp_rank } "
@@ -1588,17 +1598,10 @@ def _broadcast_sample_state_loop(self):
15881598 new_mpi_comm = mpi_comm ().Dup ()
15891599 set_thread_local_mpi_comm (new_mpi_comm )
15901600 while True :
1591- executed_batch = self .executed_batch_queue . get ()
1601+ executed_batch = self ._get_executed_batch ()
15921602 if executed_batch is None :
15931603 break
15941604 self ._ring_broadcast_sample_state (executed_batch )
1595- # Flush the last isend before this thread goes idle on
1596- # queue.get() — otherwise no MPI call will be made to drive
1597- # progress and the non-blocking send data will never reach
1598- # the receiver, causing a deadlock.
1599- if self .executed_batch_queue .empty ():
1600- self .wait_on_pp_send_handles (self .send_handles ,
1601- executed_batch .microbatch_id )
16021605 set_thread_local_mpi_comm (None )
16031606 new_mpi_comm .Free ()
16041607
0 commit comments