Skip to content

Commit 4fca6c1

Browse files
committed
[None][fix] Handle unset attention_dp_relax in ADP routers
OpenAI requests can carry a SchedulingParams object while leaving attention_dp_relax unset (None). The ADP routers used that field directly as the sort key, so sorting a batch that mixes None with bool values raised a TypeError, because Python cannot order None against bool. Treat only an explicit False as strict and treat None as relaxed, matching the existing behavior when scheduling params are missing entirely. Apply the same logic to both DefaultADPRouter and KVCacheAwareADPRouter, with regression coverage for both paths. Signed-off-by: peihengh <[email protected]>
1 parent 355ba94 commit 4fca6c1

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def get_relax_value(req_item):
304304
scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
305305
if scheduling_params is None:
306306
return True
307-
return scheduling_params.attention_dp_relax
307+
return scheduling_params.attention_dp_relax is not False
308308

309309
sorted_requests = sorted(new_requests, key=get_relax_value)
310310

@@ -576,7 +576,7 @@ def get_relax_value(req_item):
576576
scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
577577
if scheduling_params is None:
578578
return True
579-
return scheduling_params.attention_dp_relax
579+
return scheduling_params.attention_dp_relax is not False
580580

581581
sorted_requests = sorted(new_requests, key=get_relax_value)
582582

tests/unittest/_torch/executor/test_adp_router.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,22 @@ def test_target_dp_rank_at_capacity_falls_through(self):
203203
assert len(result[0]) == 1
204204
assert len(result[1]) == 0
205205

206+
def test_none_attention_dp_relax_is_relaxed(self):
207+
router = DefaultADPRouter(dist=_mock_dist())
208+
states = [
209+
RankState(rank=0, num_active_requests=0, num_active_tokens=0),
210+
RankState(rank=1, num_active_requests=0, num_active_tokens=0),
211+
]
212+
req_relax = _make_request_item(1, target_dp_rank=0, attention_dp_relax=None)
213+
req_strict = _make_request_item(2, target_dp_rank=0, attention_dp_relax=False)
214+
215+
result, _ = router.route_requests(
216+
states, [req_relax, req_strict], max_num_active_requests=1
217+
)
218+
219+
assert result[0] == [req_strict]
220+
assert req_relax in result[1]
221+
206222
def test_favors_less_loaded_rank(self):
207223
router = DefaultADPRouter(dist=_mock_dist())
208224
states = [
@@ -926,6 +942,18 @@ def test_cache_affinity_wins(self):
926942
assert result[0] == []
927943
assert result[1] == [req]
928944

945+
def test_none_attention_dp_relax_is_relaxed(self):
946+
router = self._make_router(tp_size=2)
947+
req_relax = _make_request_item(1, target_dp_rank=0, attention_dp_relax=None)
948+
req_strict = _make_request_item(2, target_dp_rank=0, attention_dp_relax=False)
949+
950+
result, _ = router.route_requests(
951+
self._rank_states(2), [req_relax, req_strict], max_num_active_requests=1
952+
)
953+
954+
assert result[0] == [req_strict]
955+
assert req_relax in result[1]
956+
929957
def test_match_rate_threshold_gates_cache_affinity(self):
930958
"""With rank 0 loaded but holding cache, and rank 1 idle with no
931959
cache:

0 commit comments

Comments
 (0)