Skip to content

Commit 38aca38

Browse files
committed
fix(cli): canonicalize infer model refs
1 parent 5cf55ed commit 38aca38

2 files changed

Lines changed: 84 additions & 4 deletions

File tree

src/cli/capability-cli.test.ts

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,42 @@ describe("capability cli", () => {
363363
}) as never);
364364
});
365365

366+
async function runModelRunWithModel(model: string, transport: "local" | "gateway") {
367+
await runRegisteredCli({
368+
register: registerCapabilityCli as (program: Command) => void,
369+
argv: [
370+
"capability",
371+
"model",
372+
"run",
373+
"--model",
374+
model,
375+
"--prompt",
376+
"hello",
377+
...(transport === "gateway" ? ["--gateway"] : []),
378+
"--json",
379+
],
380+
});
381+
}
382+
383+
function expectModelRunDispatch(transport: "local" | "gateway", modelRef: string) {
384+
if (transport === "gateway") {
385+
const slash = modelRef.indexOf("/");
386+
expect(mocks.callGateway).toHaveBeenCalledWith(
387+
expect.objectContaining({
388+
method: "agent",
389+
params: expect.objectContaining({
390+
provider: modelRef.slice(0, slash),
391+
model: modelRef.slice(slash + 1),
392+
}),
393+
}),
394+
);
395+
return;
396+
}
397+
expect(mocks.prepareSimpleCompletionModelForAgent).toHaveBeenCalledWith(
398+
expect.objectContaining({ modelRef }),
399+
);
400+
}
401+
366402
it("lists canonical capabilities", async () => {
367403
await runRegisteredCli({
368404
register: registerCapabilityCli as (program: Command) => void,
@@ -779,6 +815,30 @@ describe("capability cli", () => {
779815
);
780816
});
781817

818+
it.each(["local", "gateway"] as const)(
819+
"canonicalizes case-mismatched model refs before %s dispatch",
820+
async (transport) => {
821+
mocks.loadModelCatalog.mockResolvedValueOnce([
822+
{ id: "claude-opus-4-7", provider: "anthropic", name: "Claude Opus 4.7" },
823+
] as never);
824+
825+
await runModelRunWithModel("Anthropic/CLAUDE-OPUS-4-7", transport);
826+
827+
expectModelRunDispatch(transport, "anthropic/claude-opus-4-7");
828+
},
829+
);
830+
831+
it.each(["local", "gateway"] as const)(
832+
"keeps custom mixed-case model refs before %s dispatch when the catalog has no match",
833+
async (transport) => {
834+
mocks.loadModelCatalog.mockResolvedValueOnce([] as never);
835+
836+
await runModelRunWithModel("custom/MyModel", transport);
837+
838+
expectModelRunDispatch(transport, "custom/MyModel");
839+
},
840+
);
841+
782842
it("rejects empty model run prompts before gateway dispatch", async () => {
783843
await expect(
784844
runRegisteredCli({

src/cli/capability-cli.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
} from "../agents/auth-profiles.js";
1212
import { updateAuthProfileStoreWithLock } from "../agents/auth-profiles/store.js";
1313
import { resolveMemorySearchConfig } from "../agents/memory-search.js";
14+
import { findModelInCatalog } from "../agents/model-catalog-lookup.js";
1415
import { loadModelCatalog } from "../agents/model-catalog.js";
1516
import {
1617
completeWithPreparedSimpleCompletionModel,
@@ -558,6 +559,24 @@ function resolveModelRefOverride(raw: string | undefined): { provider?: string;
558559
};
559560
}
560561

562+
function shouldCanonicalizeModelRunRef(
563+
model: string | undefined,
564+
ref: { provider?: string; model?: string },
565+
): ref is { provider: string; model: string } {
566+
return Boolean(model && model !== model.toLowerCase() && ref.provider && ref.model);
567+
}
568+
569+
async function canonicalizeModelRunRef(raw: string | undefined, cfg: OpenClawConfig) {
570+
const model = normalizeStringifiedOptionalString(raw);
571+
const ref = resolveModelRefOverride(model);
572+
if (!shouldCanonicalizeModelRunRef(model, ref)) {
573+
return model;
574+
}
575+
const catalog = await loadModelCatalog({ config: cfg });
576+
const entry = findModelInCatalog(catalog, ref.provider, ref.model);
577+
return entry ? `${entry.provider}/${entry.id}` : model;
578+
}
579+
561580
function requireProviderModelOverride(
562581
raw: string | undefined,
563582
): { provider: string; model: string } | undefined {
@@ -644,11 +663,12 @@ async function runModelRun(params: {
644663
})),
645664
]
646665
: params.prompt;
666+
const model = await canonicalizeModelRunRef(params.model, cfg);
647667
if (params.transport === "local") {
648668
const prepared = await prepareSimpleCompletionModelForAgent({
649669
cfg,
650670
agentId,
651-
modelRef: params.model,
671+
modelRef: model,
652672
allowMissingApiKeyModes: ["aws-sdk"],
653673
skipPiDiscovery: true,
654674
});
@@ -721,10 +741,10 @@ async function runModelRun(params: {
721741
} satisfies CapabilityEnvelope;
722742
}
723743

724-
const { provider, model } = resolveModelRefOverride(params.model);
744+
const { provider, model: modelId } = resolveModelRefOverride(model);
725745
// Provider/model overrides require trusted-operator scope. Use the backend
726746
// shared-secret lane so local gateway smokes do not depend on paired CLI device scopes.
727-
const hasModelOverride = Boolean(provider || model);
747+
const hasModelOverride = Boolean(provider || modelId);
728748
const response: {
729749
result?: {
730750
payloads?: Array<{ text?: string; mediaUrl?: string | null; mediaUrls?: string[] }>;
@@ -751,7 +771,7 @@ async function runModelRun(params: {
751771
}))
752772
: undefined,
753773
provider,
754-
model,
774+
model: modelId,
755775
modelRun: true,
756776
promptMode: "none",
757777
cleanupBundleMcpOnRunEnd: true,

0 commit comments

Comments
 (0)