Skip to content

Commit f8dd09d

Browse files
committed
Add disagg conversation ID header support
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 601d1f0 commit f8dd09d

4 files changed

Lines changed: 110 additions & 7 deletions

File tree

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: U
8080

8181

8282
class OpenAIDisaggServer:
83+
_CONVERSATION_ID_HEADERS = (
84+
"x-session-id",
85+
"x-correlation-id",
86+
"x-session-affinity",
87+
"x-multi-turn-session-id",
88+
)
8389

8490
def __init__(self,
8591
config: DisaggServerConfig,
@@ -164,25 +170,31 @@ def register_routes(self):
164170

165171
@staticmethod
166172
def _extract_conversation_id(req: UCompletionRequest, raw_req: Request):
167-
"""Populate conversation_id from the X-Correlation-ID header.
173+
"""Populate conversation_id from supported session headers.
168174
169175
When not already set in the request body, copies the header value
170176
into ``disaggregated_params.conversation_id``.
171177
172-
aiperf sends multi-turn session IDs via the ``X-Correlation-ID``
173-
header (see aiperf ``base_transports.build_headers``). We mirror
174-
that convention so the ConversationRouter can provide session
175-
affinity without requiring clients to set the body field.
178+
Supported headers are checked in priority order: ``X-Session-ID``,
179+
``X-Correlation-ID``, ``x-session-affinity``, and
180+
``x-multi-turn-session-id``. We mirror these conventions so the
181+
ConversationRouter can provide session affinity without requiring
182+
clients to set the body field.
176183
177184
When ``disaggregated_params`` is ``None`` (standard OpenAI
178185
requests without disagg fields), a minimal instance is created
179186
to carry the conversation_id. The service layer always rebuilds
180187
``disaggregated_params`` in ``_get_ctx_request`` /
181188
``_get_gen_request`` before forwarding to workers.
182189
"""
183-
header_conv_id = raw_req.headers.get("x-correlation-id")
184-
if header_conv_id is None:
190+
header_conv_id = None
191+
for header_name in OpenAIDisaggServer._CONVERSATION_ID_HEADERS:
192+
header_conv_id = raw_req.headers.get(header_name)
193+
if header_conv_id is not None:
194+
break
195+
else:
185196
return
197+
186198
if req.disaggregated_params is None:
187199
req.disaggregated_params = DisaggregatedParams(
188200
request_type="context_only",

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
775775
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL_MOE::test_auto_dtype
776776
accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype
777777
accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray
778+
unittest/disaggregated/test_openai_disagg_server.py
778779
disaggregated/test_auto_scaling.py::test_disagg_server_restart[etcd-round_robin]
779780
disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
780781
disaggregated/test_auto_scaling.py::test_minimal_instances[etcd-round_robin]

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ l0_a10:
4242
- unittest/others/test_tracing.py
4343
- unittest/disaggregated/test_disagg_openai_client.py
4444
- unittest/disaggregated/test_disagg_utils.py
45+
- unittest/disaggregated/test_openai_disagg_server.py
4546
- unittest/disaggregated/test_openai_disagg_service.py
4647
- unittest/disaggregated/test_router.py
4748
- unittest/disaggregated/test_remoteDictionary.py
tests/unittest/disaggregated/test_openai_disagg_server.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from types import SimpleNamespace
15+
16+
from starlette.datastructures import Headers
17+
18+
from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer
19+
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
20+
21+
22+
def _raw_request(headers: dict[str, str]):
    """Return a minimal request stand-in that exposes only ``headers``.

    The mapping is wrapped in a Starlette ``Headers`` object and attached
    to a ``SimpleNamespace`` as its ``headers`` attribute.
    """
    wrapped_headers = Headers(headers=headers)
    return SimpleNamespace(headers=wrapped_headers)
24+
25+
26+
def test_extract_conversation_id_from_headers():
    """Check that each supported header populates ``conversation_id``.

    Every scenario pairs a set of request headers with the conversation ID
    expected afterwards. The two multi-header scenarios pin which header is
    chosen when several are present at once.
    """
    session = {"X-Session-ID": "session-id"}
    correlation = {"X-Correlation-ID": "correlation-id"}
    affinity = {"x-session-affinity": "session-affinity"}
    multi_turn = {"x-multi-turn-session-id": "multi-turn-session-id"}

    scenarios = (
        # Each header on its own.
        (session, "session-id"),
        (correlation, "correlation-id"),
        (affinity, "session-affinity"),
        (multi_turn, "multi-turn-session-id"),
        # All four present: X-Session-ID is expected to be selected.
        ({**correlation, **session, **affinity, **multi_turn}, "session-id"),
        # Only the last two present: x-session-affinity is expected.
        ({**affinity, **multi_turn}, "session-affinity"),
    )

    for header_map, wanted_id in scenarios:
        completion = CompletionRequest(model="test-model", prompt="hello")
        raw = _raw_request(header_map)

        OpenAIDisaggServer._extract_conversation_id(completion, raw)

        params = completion.disaggregated_params
        assert params is not None
        assert completion.disaggregated_params.conversation_id == wanted_id
57+
58+
59+
def test_extract_conversation_id_preserves_body_conversation_id():
    """Check that a ``conversation_id`` set in the body is not overwritten.

    The request already carries ``conversation_id="body-id"``; after the
    call with an ``X-Session-ID`` header the body value must still be there.
    """
    body_params = DisaggregatedParams(
        request_type="context_only",
        conversation_id="body-id",
    )
    completion = CompletionRequest(
        model="test-model",
        prompt="hello",
        disaggregated_params=body_params,
    )
    raw = _raw_request({"X-Session-ID": "header-id"})

    OpenAIDisaggServer._extract_conversation_id(completion, raw)

    assert completion.disaggregated_params.conversation_id == "body-id"
75+
76+
77+
def test_extract_conversation_id_populates_existing_disaggregated_params():
    """Check that existing params without an ID receive the header value.

    The request starts with ``disaggregated_params`` that has no
    ``conversation_id``; the ``x-multi-turn-session-id`` header value is
    expected to be stored on it.
    """
    existing_params = DisaggregatedParams(request_type="context_only")
    completion = CompletionRequest(
        model="test-model",
        prompt="hello",
        disaggregated_params=existing_params,
    )
    raw = _raw_request({"x-multi-turn-session-id": "multi-turn-session-id"})

    OpenAIDisaggServer._extract_conversation_id(completion, raw)

    stored_id = completion.disaggregated_params.conversation_id
    assert stored_id == "multi-turn-session-id"

0 commit comments

Comments
 (0)