Skip to content

Commit 1085480

Browse files
committed
[TRTLLM-11616][fix] Accept strict field in tools and store field in chat requests
- Add `strict: Optional[bool]` to `FunctionDefinition` in openai_protocol.py so OpenAI-compatible clients can pass `tools[].function.strict` without getting an HTTP 400 "extra_forbidden" error.
- Add `store: Optional[bool]` to `ChatCompletionRequest` so the `store` field is accepted (but not acted on), matching the OpenAI API spec.
- When `strict=True`, build structural tag guided decoding params to constrain tool call arguments to the function's JSON Schema via the existing `structural_tag` machinery.
- Add 14 unit tests covering the new fields and guided decoding logic.

Signed-off-by: JunyiXu-nv <[email protected]>
1 parent d6b8e6f commit 1085480

3 files changed

Lines changed: 369 additions & 0 deletions

File tree

tensorrt_llm/serve/openai_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class FunctionDefinition(OpenAIBaseModel):
618618
name: str
619619
description: Optional[str] = None
620620
parameters: Optional[Dict[str, Any]] = None
621+
strict: Optional[bool] = None
621622

622623

623624
class ChatCompletionToolsParam(OpenAIBaseModel):

tensorrt_llm/serve/openai_server.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from tensorrt_llm.llmapi.llm import RequestOutput
4545
from tensorrt_llm.logger import logger
4646
from tensorrt_llm.metrics.collector import MetricsCollector
47+
from tensorrt_llm.sampling_params import GuidedDecodingParams
4748
from tensorrt_llm.serve.chat_utils import (load_chat_template,
4849
parse_chat_messages_coroutines)
4950
from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
@@ -62,6 +63,7 @@
6263
ImageObject,
6364
MemoryUpdateRequest, ModelCard,
6465
ModelList, PromptTokensDetails,
66+
ResponseFormat,
6567
ResponsesRequest,
6668
ResponsesResponse,
6769
UpdateWeightsRequest, UsageInfo,
@@ -96,6 +98,78 @@
9698
TIMEOUT_KEEP_ALIVE = 5 # seconds.
9799

98100

101+
def _build_tool_strict_guided_decoding_params(tools, tool_parser_name):
    """Build GuidedDecodingParams with structural tags for tools with strict=True.

    When a tool has ``strict=True`` in its function definition, the server
    should use constrained decoding to guarantee that the generated tool call
    arguments exactly match the function's ``parameters`` JSON Schema.

    This function builds one structural tag item per tool from the tool
    parser's ``structure_info()`` and the tool's ``parameters`` schema, then
    returns a ``GuidedDecodingParams`` with the structural tag format.

    Args:
        tools: The request's tool definitions. Each item is read through
            ``tool.function.name``, ``tool.function.strict`` and
            ``tool.function.parameters``. May be ``None`` or empty.
        tool_parser_name: Name of the configured tool parser; looked up
            case-insensitively in ``ToolParserFactory.parsers``. May be
            ``None`` or empty.

    Returns:
        A ``GuidedDecodingParams`` whose ``structural_tag`` is the JSON
        serialization of a ``triggered_tags`` structural tag format, or
        ``None`` when any of the following holds:

        * ``tools`` or ``tool_parser_name`` is empty / ``None``;
        * no tool has a truthy ``strict`` flag;
        * the named parser is not registered (a warning is logged);
        * the parser reports no structural tag support (a warning is logged).
    """
    if not tools or not tool_parser_name:
        return None

    # Nothing to enforce unless at least one tool opts in to strict mode.
    if not any(tool.function.strict for tool in tools):
        return None

    tool_parser_cls = ToolParserFactory.parsers.get(tool_parser_name.lower())
    if tool_parser_cls is None:
        logger.warning(
            "Tool parser '%s' not found, cannot enforce strict mode for tools.",
            tool_parser_name)
        return None

    parser = tool_parser_cls()
    if not parser.supports_structural_tag():
        logger.warning(
            "Tool parser '%s' does not support structural tags, "
            "cannot enforce strict mode for tools.", tool_parser_name)
        return None

    # ``structure_info()`` returns a callable mapping a tool name to an object
    # exposing the ``begin`` / ``end`` / ``trigger`` strings for that tool.
    get_info = parser.structure_info()

    tags = []
    triggers = set()
    # Every tool gets a tag (not only the strict ones) so that non-strict
    # tools are still describable under the same triggered-tags format.
    for tool in tools:
        info = get_info(tool.function.name)
        triggers.add(info.trigger)

        if tool.function.strict and tool.function.parameters:
            # Strict tool: constrain arguments to match the JSON Schema
            content = {
                "type": "json_schema",
                "json_schema": tool.function.parameters,
            }
        else:
            # Non-strict tool, or strict tool without a (non-empty)
            # parameters schema: allow any text
            content = {"type": "any_text"}

        tags.append({
            "begin": info.begin,
            "content": content,
            "end": info.end,
        })

    stag_format = {
        "type": "triggered_tags",
        # Sorted so the serialized format does not depend on set ordering.
        "triggers": sorted(triggers),
        "tags": tags,
    }

    resp_format = ResponseFormat(type="structural_tag", format=stag_format)
    return GuidedDecodingParams(structural_tag=resp_format.model_dump_json(
        by_alias=True, exclude_none=True))
171+
172+
99173
class OpenAIServer:
100174

101175
def __init__(
@@ -856,6 +930,14 @@ async def chat_stream_generator(
856930
if tool_parser_cls and getattr(
857931
tool_parser_cls, 'needs_raw_special_tokens', False):
858932
sampling_params.skip_special_tokens = False
933+
# When strict=True on any tool, apply constrained decoding
934+
# via structural tags (only if response_format doesn't already
935+
# set guided decoding).
936+
if sampling_params.guided_decoding is None:
937+
strict_guided = _build_tool_strict_guided_decoding_params(
938+
request.tools, self.tool_parser)
939+
if strict_guided is not None:
940+
sampling_params.guided_decoding = strict_guided
859941
postproc_args = ChatPostprocArgs.from_request(request)
860942
disaggregated_params = to_llm_disaggregated_params(
861943
request.disaggregated_params)

0 commit comments

Comments
 (0)