Skip to content

Commit a9125ec

Browse files
committed
refactor: share OpenAI tool schema normalization
1 parent 31016c5 commit a9125ec

5 files changed

Lines changed: 265 additions & 145 deletions

File tree

src/agents/openai-tool-schema.ts

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import { normalizeToolParameterSchema } from "./pi-tools.schema.js";
2+
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
3+
4+
type OpenAITransportKind = "stream" | "websocket";
5+
6+
type OpenAIStrictToolModel = {
7+
provider?: unknown;
8+
api?: unknown;
9+
baseUrl?: unknown;
10+
id?: unknown;
11+
compat?: { supportsStore?: boolean };
12+
};
13+
14+
type ToolWithParameters = {
15+
parameters: unknown;
16+
};
17+
18+
export function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
19+
return normalizeStrictOpenAIJsonSchemaRecursive(normalizeToolParameterSchema(schema ?? {}));
20+
}
21+
22+
function normalizeStrictOpenAIJsonSchemaRecursive(schema: unknown): unknown {
23+
if (Array.isArray(schema)) {
24+
let changed = false;
25+
const normalized = schema.map((entry) => {
26+
const next = normalizeStrictOpenAIJsonSchemaRecursive(entry);
27+
changed ||= next !== entry;
28+
return next;
29+
});
30+
return changed ? normalized : schema;
31+
}
32+
if (!schema || typeof schema !== "object") {
33+
return schema;
34+
}
35+
36+
const record = schema as Record<string, unknown>;
37+
let changed = false;
38+
const normalized: Record<string, unknown> = {};
39+
for (const [key, value] of Object.entries(record)) {
40+
const next = normalizeStrictOpenAIJsonSchemaRecursive(value);
41+
normalized[key] = next;
42+
changed ||= next !== value;
43+
}
44+
45+
if (normalized.type === "object") {
46+
const properties =
47+
normalized.properties &&
48+
typeof normalized.properties === "object" &&
49+
!Array.isArray(normalized.properties)
50+
? (normalized.properties as Record<string, unknown>)
51+
: undefined;
52+
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
53+
normalized.required = [];
54+
changed = true;
55+
}
56+
}
57+
58+
return changed ? normalized : schema;
59+
}
60+
61+
export function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
62+
if (!strict) {
63+
return normalizeToolParameterSchema(schema ?? {}) as T;
64+
}
65+
return normalizeStrictOpenAIJsonSchema(schema) as T;
66+
}
67+
68+
export function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
69+
return isStrictOpenAIJsonSchemaCompatibleRecursive(normalizeStrictOpenAIJsonSchema(schema));
70+
}
71+
72+
function isStrictOpenAIJsonSchemaCompatibleRecursive(schema: unknown): boolean {
73+
if (Array.isArray(schema)) {
74+
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatibleRecursive(entry));
75+
}
76+
if (!schema || typeof schema !== "object") {
77+
return true;
78+
}
79+
80+
const record = schema as Record<string, unknown>;
81+
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
82+
return false;
83+
}
84+
if (Array.isArray(record.type)) {
85+
return false;
86+
}
87+
if (record.type === "object" && record.additionalProperties !== false) {
88+
return false;
89+
}
90+
if (record.type === "object") {
91+
const properties =
92+
record.properties &&
93+
typeof record.properties === "object" &&
94+
!Array.isArray(record.properties)
95+
? (record.properties as Record<string, unknown>)
96+
: {};
97+
const required = Array.isArray(record.required)
98+
? record.required.filter((entry): entry is string => typeof entry === "string")
99+
: undefined;
100+
if (!required) {
101+
return false;
102+
}
103+
const requiredSet = new Set(required);
104+
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
105+
return false;
106+
}
107+
}
108+
109+
return Object.entries(record).every(([key, entry]) => {
110+
if (key === "properties" && entry && typeof entry === "object" && !Array.isArray(entry)) {
111+
return Object.values(entry as Record<string, unknown>).every((value) =>
112+
isStrictOpenAIJsonSchemaCompatibleRecursive(value),
113+
);
114+
}
115+
return isStrictOpenAIJsonSchemaCompatibleRecursive(entry);
116+
});
117+
}
118+
119+
export function resolveOpenAIStrictToolFlagForInventory<T extends ToolWithParameters>(
120+
tools: readonly T[],
121+
strict: boolean | null | undefined,
122+
): boolean | undefined {
123+
if (strict !== true) {
124+
return strict === false ? false : undefined;
125+
}
126+
return tools.every((tool) => isStrictOpenAIJsonSchemaCompatible(tool.parameters));
127+
}
128+
129+
export function resolvesToNativeOpenAIStrictTools(
130+
model: OpenAIStrictToolModel,
131+
transport: OpenAITransportKind,
132+
): boolean {
133+
const capabilities = resolveProviderRequestCapabilities({
134+
provider: model.provider,
135+
api: model.api,
136+
baseUrl: model.baseUrl,
137+
capability: "llm",
138+
transport,
139+
modelId: model.id,
140+
compat:
141+
model.compat && typeof model.compat === "object"
142+
? (model.compat as { supportsStore?: boolean })
143+
: undefined,
144+
});
145+
if (!capabilities.usesKnownNativeOpenAIRoute) {
146+
return false;
147+
}
148+
return (
149+
capabilities.provider === "openai" ||
150+
capabilities.provider === "openai-codex" ||
151+
capabilities.provider === "azure-openai" ||
152+
capabilities.provider === "azure-openai-responses"
153+
);
154+
}
155+
156+
export function resolveOpenAIStrictToolSetting(
157+
model: OpenAIStrictToolModel,
158+
options?: { transport?: OpenAITransportKind; supportsStrictMode?: boolean },
159+
): boolean | undefined {
160+
if (resolvesToNativeOpenAIStrictTools(model, options?.transport ?? "stream")) {
161+
return true;
162+
}
163+
if (options?.supportsStrictMode) {
164+
return false;
165+
}
166+
return undefined;
167+
}

src/agents/openai-transport-stream.ts

Lines changed: 14 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ import {
2727
applyOpenAIResponsesPayloadPolicy,
2828
resolveOpenAIResponsesPayloadPolicy,
2929
} from "./openai-responses-payload-policy.js";
30-
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
30+
import {
31+
normalizeOpenAIStrictToolParameters,
32+
resolveOpenAIStrictToolFlagForInventory,
33+
resolveOpenAIStrictToolSetting,
34+
} from "./openai-tool-schema.js";
3135
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
3236
import { stripSystemPromptCacheBoundary } from "./system-prompt-cache-boundary.js";
3337
import { transformTransportMessages } from "./transport-message-transform.js";
@@ -332,7 +336,7 @@ function convertResponsesTools(
332336
tools: NonNullable<Context["tools"]>,
333337
options?: { strict?: boolean | null },
334338
): FunctionTool[] {
335-
const strict = resolveStrictToolFlagForInventory(tools, options?.strict);
339+
const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict);
336340
if (strict === undefined) {
337341
return tools.map((tool) => ({
338342
type: "function",
@@ -350,104 +354,6 @@ function convertResponsesTools(
350354
}));
351355
}
352356

353-
function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
354-
if (!strict) {
355-
return schema;
356-
}
357-
return normalizeStrictOpenAIJsonSchema(schema) as T;
358-
}
359-
360-
function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
361-
if (Array.isArray(schema)) {
362-
let changed = false;
363-
const normalized = schema.map((entry) => {
364-
const next = normalizeStrictOpenAIJsonSchema(entry);
365-
changed ||= next !== entry;
366-
return next;
367-
});
368-
return changed ? normalized : schema;
369-
}
370-
if (!schema || typeof schema !== "object") {
371-
return schema;
372-
}
373-
374-
const record = schema as Record<string, unknown>;
375-
let changed = false;
376-
const normalized: Record<string, unknown> = {};
377-
for (const [key, value] of Object.entries(record)) {
378-
const next = normalizeStrictOpenAIJsonSchema(value);
379-
normalized[key] = next;
380-
changed ||= next !== value;
381-
}
382-
383-
if (normalized.type === "object") {
384-
const properties =
385-
normalized.properties &&
386-
typeof normalized.properties === "object" &&
387-
!Array.isArray(normalized.properties)
388-
? (normalized.properties as Record<string, unknown>)
389-
: undefined;
390-
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
391-
normalized.required = [];
392-
changed = true;
393-
}
394-
}
395-
396-
return changed ? normalized : schema;
397-
}
398-
399-
function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
400-
if (Array.isArray(schema)) {
401-
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
402-
}
403-
if (!schema || typeof schema !== "object") {
404-
return true;
405-
}
406-
407-
const record = schema as Record<string, unknown>;
408-
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
409-
return false;
410-
}
411-
if (Array.isArray(record.type)) {
412-
return false;
413-
}
414-
if (record.type === "object" && record.additionalProperties !== false) {
415-
return false;
416-
}
417-
if (record.type === "object") {
418-
const properties =
419-
record.properties &&
420-
typeof record.properties === "object" &&
421-
!Array.isArray(record.properties)
422-
? (record.properties as Record<string, unknown>)
423-
: {};
424-
const required = Array.isArray(record.required)
425-
? record.required.filter((entry): entry is string => typeof entry === "string")
426-
: undefined;
427-
if (!required) {
428-
return false;
429-
}
430-
const requiredSet = new Set(required);
431-
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
432-
return false;
433-
}
434-
}
435-
436-
return Object.values(record).every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
437-
}
438-
439-
function resolveStrictToolFlagForInventory(
440-
tools: NonNullable<Context["tools"]>,
441-
strict: boolean | null | undefined,
442-
): boolean | undefined {
443-
if (strict !== true) {
444-
return strict === false ? false : undefined;
445-
}
446-
return tools.every((tool) =>
447-
isStrictOpenAIJsonSchemaCompatible(normalizeStrictOpenAIJsonSchema(tool.parameters)),
448-
);
449-
}
450-
451357
async function processResponsesStream(
452358
openaiStream: AsyncIterable<unknown>,
453359
output: MutableAssistantOutput,
@@ -857,7 +763,9 @@ export function buildOpenAIResponsesParams(
857763
}
858764
if (context.tools) {
859765
params.tools = convertResponsesTools(context.tools, {
860-
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel),
766+
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel, {
767+
transport: "stream",
768+
}),
861769
});
862770
}
863771
if (model.reasoning) {
@@ -1318,51 +1226,17 @@ function mapReasoningEffort(effort: string, reasoningEffortMap: Record<string, s
13181226
return reasoningEffortMap[effort] ?? effort;
13191227
}
13201228

1321-
function resolvesToNativeOpenAIStrictTools(model: OpenAIModeModel): boolean {
1322-
const capabilities = resolveProviderRequestCapabilities({
1323-
provider: model.provider,
1324-
api: model.api,
1325-
baseUrl: model.baseUrl,
1326-
capability: "llm",
1327-
transport: "stream",
1328-
modelId: model.id,
1329-
compat:
1330-
model.compat && typeof model.compat === "object"
1331-
? (model.compat as { supportsStore?: boolean })
1332-
: undefined,
1333-
});
1334-
if (!capabilities.usesKnownNativeOpenAIRoute) {
1335-
return false;
1336-
}
1337-
return (
1338-
capabilities.provider === "openai" ||
1339-
capabilities.provider === "openai-codex" ||
1340-
capabilities.provider === "azure-openai" ||
1341-
capabilities.provider === "azure-openai-responses"
1342-
);
1343-
}
1344-
1345-
function resolveOpenAIStrictToolSetting(
1346-
model: OpenAIModeModel,
1347-
compat?: ReturnType<typeof getCompat>,
1348-
): boolean | undefined {
1349-
if (resolvesToNativeOpenAIStrictTools(model)) {
1350-
return true;
1351-
}
1352-
if (compat?.supportsStrictMode) {
1353-
return false;
1354-
}
1355-
return undefined;
1356-
}
1357-
13581229
function convertTools(
13591230
tools: NonNullable<Context["tools"]>,
13601231
compat: ReturnType<typeof getCompat>,
13611232
model: OpenAIModeModel,
13621233
) {
1363-
const strict = resolveStrictToolFlagForInventory(
1234+
const strict = resolveOpenAIStrictToolFlagForInventory(
13641235
tools,
1365-
resolveOpenAIStrictToolSetting(model, compat),
1236+
resolveOpenAIStrictToolSetting(model, {
1237+
transport: "stream",
1238+
supportsStrictMode: compat?.supportsStrictMode,
1239+
}),
13661240
);
13671241
return tools.map((tool) => ({
13681242
type: "function",

0 commit comments

Comments
 (0)