Skip to content

Commit cfa2258

Browse files
committed
[https://nvbugs/6050489][fix] fix agg pp4 hang issue
Signed-off-by: Bo Deng <[email protected]>
1 parent 2dff089 commit cfa2258

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import traceback
88
from contextlib import contextmanager
99
from enum import IntEnum
10-
from queue import Queue
10+
from queue import Empty, Queue
1111
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
1212

1313
import torch
@@ -1572,6 +1572,16 @@ def handle_executed_batches(executed_batch_num: int):
15721572
self._handle_executed_batch(executed_batch)
15731573
self.unhandled_batch_counter -= 1
15741574

1575+
def _get_executed_batch(self):
1576+
while True:
1577+
try:
1578+
return self.executed_batch_queue.get(timeout=0.001)
1579+
except Empty:
1580+
# Calling MPI_Test on pending isend handles while idle to prevent potential hangs.
1581+
for handle in self.send_handles:
1582+
if handle is not None:
1583+
handle.test()
1584+
15751585
def _broadcast_sample_state_loop(self):
15761586
logger.debug(
15771587
f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}"
@@ -1588,17 +1598,10 @@ def _broadcast_sample_state_loop(self):
15881598
new_mpi_comm = mpi_comm().Dup()
15891599
set_thread_local_mpi_comm(new_mpi_comm)
15901600
while True:
1591-
executed_batch = self.executed_batch_queue.get()
1601+
executed_batch = self._get_executed_batch()
15921602
if executed_batch is None:
15931603
break
15941604
self._ring_broadcast_sample_state(executed_batch)
1595-
# Flush the last isend before this thread goes idle on
1596-
# queue.get() — otherwise no MPI call will be made to drive
1597-
# progress and the non-blocking send data will never reach
1598-
# the receiver, causing a deadlock.
1599-
if self.executed_batch_queue.empty():
1600-
self.wait_on_pp_send_handles(self.send_handles,
1601-
executed_batch.microbatch_id)
16021605
set_thread_local_mpi_comm(None)
16031606
new_mpi_comm.Free()
16041607

0 commit comments

Comments
 (0)