|
44 | 44 | from tensorrt_llm.llmapi.llm import RequestOutput |
45 | 45 | from tensorrt_llm.logger import logger |
46 | 46 | from tensorrt_llm.metrics.collector import MetricsCollector |
| 47 | +from tensorrt_llm.sampling_params import GuidedDecodingParams |
47 | 48 | from tensorrt_llm.serve.chat_utils import (load_chat_template, |
48 | 49 | parse_chat_messages_coroutines) |
49 | 50 | from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client |
|
62 | 63 | ImageObject, |
63 | 64 | MemoryUpdateRequest, ModelCard, |
64 | 65 | ModelList, PromptTokensDetails, |
| 66 | + ResponseFormat, |
65 | 67 | ResponsesRequest, |
66 | 68 | ResponsesResponse, |
67 | 69 | UpdateWeightsRequest, UsageInfo, |
|
96 | 98 | TIMEOUT_KEEP_ALIVE = 5 # seconds. |
97 | 99 |
|
98 | 100 |
|
def _build_tool_strict_guided_decoding_params(tools, tool_parser_name):
    """Build GuidedDecodingParams with structural tags for tools with strict=True.

    When a tool has ``strict=True`` in its function definition, the server
    should use constrained decoding to guarantee that the generated tool call
    arguments exactly match the function's ``parameters`` JSON Schema.

    This function builds one structural tag item per tool from the tool
    parser's ``structure_info()`` and the tool's ``parameters`` schema, then
    returns a ``GuidedDecodingParams`` carrying the serialized
    ``triggered_tags`` structural tag format.

    Args:
        tools: Iterable of tool definitions; each item must expose
            ``function.name``, ``function.strict`` and
            ``function.parameters``. May be ``None`` or empty.
        tool_parser_name: Name of the registered tool parser. It is
            lower-cased before the lookup in ``ToolParserFactory.parsers``.
            May be ``None`` or empty.

    Returns:
        A ``GuidedDecodingParams`` whose ``structural_tag`` is the JSON dump
        of the response format, or ``None`` when there are no tools, no
        parser name, no tool with a truthy ``strict``, the parser is not
        registered, or the parser does not support structural tags.
    """
    if not tools or not tool_parser_name:
        return None

    # Nothing to enforce unless at least one tool opts into strict mode.
    if not any(tool.function.strict for tool in tools):
        return None

    tool_parser_cls = ToolParserFactory.parsers.get(tool_parser_name.lower())
    if tool_parser_cls is None:
        logger.warning(
            "Tool parser '%s' not found, cannot enforce strict mode for tools.",
            tool_parser_name)
        return None

    parser = tool_parser_cls()
    if not parser.supports_structural_tag():
        logger.warning(
            "Tool parser '%s' does not support structural tags, "
            "cannot enforce strict mode for tools.", tool_parser_name)
        return None

    # structure_info() returns a callable mapping a function name to an
    # object exposing the ``trigger``, ``begin`` and ``end`` strings.
    get_info = parser.structure_info()

    tags = []
    triggers = set()
    # Every tool gets a tag (not only the strict ones) so that non-strict
    # tool calls remain expressible once the structural tag is active.
    for tool in tools:
        info = get_info(tool.function.name)
        triggers.add(info.trigger)

        if tool.function.strict and tool.function.parameters:
            # Strict tool: constrain arguments to match the JSON Schema
            content = {
                "type": "json_schema",
                "json_schema": tool.function.parameters,
            }
        else:
            # Non-strict tool or no parameters: allow any text
            content = {"type": "any_text"}

        tags.append({
            "begin": info.begin,
            "content": content,
            "end": info.end,
        })

    stag_format = {
        "type": "triggered_tags",
        # Sorted so the serialized format is deterministic across requests.
        "triggers": sorted(triggers),
        "tags": tags,
    }

    resp_format = ResponseFormat(type="structural_tag", format=stag_format)
    return GuidedDecodingParams(structural_tag=resp_format.model_dump_json(
        by_alias=True, exclude_none=True))
| 171 | + |
| 172 | + |
99 | 173 | class OpenAIServer: |
100 | 174 |
|
101 | 175 | def __init__( |
@@ -856,6 +930,14 @@ async def chat_stream_generator( |
856 | 930 | if tool_parser_cls and getattr( |
857 | 931 | tool_parser_cls, 'needs_raw_special_tokens', False): |
858 | 932 | sampling_params.skip_special_tokens = False |
| 933 | + # When strict=True on any tool, apply constrained decoding |
| 934 | + # via structural tags (only if response_format doesn't already |
| 935 | + # set guided decoding). |
| 936 | + if sampling_params.guided_decoding is None: |
| 937 | + strict_guided = _build_tool_strict_guided_decoding_params( |
| 938 | + request.tools, self.tool_parser) |
| 939 | + if strict_guided is not None: |
| 940 | + sampling_params.guided_decoding = strict_guided |
859 | 941 | postproc_args = ChatPostprocArgs.from_request(request) |
860 | 942 | disaggregated_params = to_llm_disaggregated_params( |
861 | 943 | request.disaggregated_params) |
|
0 commit comments