Skip to content

Add CLI flags to test_inference.py without changing default behavior#100

Open
lonexreb wants to merge 1 commit intoNVlabs:mainfrom
lonexreb:feat/test-inference-cli-flags
Open

Add CLI flags to test_inference.py without changing default behavior#100
lonexreb wants to merge 1 commit intoNVlabs:mainfrom
lonexreb:feat/test-inference-cli-flags

Conversation

@lonexreb
Copy link
Copy Markdown
Contributor

@lonexreb lonexreb commented May 4, 2026

Why

src/alpamayo_r1/test_inference.py is the script users run first when trying out the model. Today every parameter — clip id, keyframe, sampling temperature, sample count, model name, seed, device, dtype — is a hardcoded literal at module top level. Anyone who wants to:

  • try a different clip
  • bump num_traj_samples from 1 to 4 to see CoC variation
  • swap nvidia/Alpamayo-R1-10B for a local checkpoint
  • run with --seed for a side-by-side comparison
  • evaluate at fp16/fp32 instead of bf16

…has to copy the file out, edit constants, re-run, repeat. That's the most common iteration loop in the issues thread.

What

Wrap the existing top-level body in parse_args() + main(args) + if __name__ == "__main__" and surface the existing constants as CLI flags. All defaults preserve the current behavior exactly, so python src/alpamayo_r1/test_inference.py produces the same output as before this PR.

New flags (all default to the previous hardcoded values):

Flag Default Purpose
--clip-id 030c760c-... PAI clip id
--t0-us 5_100_000 Keyframe timestamp
--model nvidia/Alpamayo-R1-10B HF model id or local path
--num-traj-samples 1 Samples per CoC rollout
--top-p 0.98 Nucleus sampling
--temperature 0.6 Sampling temperature
--max-generation-length 256 CoC token budget
--seed 42 CUDA RNG seed
--device cuda Torch device (e.g. cuda:1)
--dtype bfloat16 Choices: bfloat16, float16, float32

Examples added to the module docstring:

python src/alpamayo_r1/test_inference.py
python src/alpamayo_r1/test_inference.py --num-traj-samples 8 --temperature 0.9
python src/alpamayo_r1/test_inference.py --t0-us 6000000 --clip-id <other-clip>

The trailing operational note about nondeterministic outputs at num_traj_samples=1 is preserved verbatim.

Verification

Validated locally without GPU/HF (no model load required) by extracting parse_args via AST and asserting:

  1. Defaults match the prior hardcoded values for every parameter (clip_id, t0_us, model, num_traj_samples, top_p, temperature, max_generation_length, seed, device, dtype).
  2. Overrides take effect: passing --num-traj-samples 8 --temperature 0.9 --seed 7 updates exactly those fields and leaves the rest at their defaults.
  3. Invalid --dtype is rejected by argparse choices= (e.g. --dtype invalid exits with invalid choice: 'invalid').

Migration

Zero migration needed. Existing scripts/CI/notebooks calling python src/alpamayo_r1/test_inference.py continue to behave identically.

Wraps the previously top-level test_inference script in a parse_args() +
main(args) + __main__ guard so users can vary the clip, sampling
parameters, model name, seed, device, and dtype without editing the
file. The defaults reproduce the prior hardcoded constants exactly, so
`python src/alpamayo_r1/test_inference.py` is unchanged.

Why this is worth adding:

- Researchers iterating on a single clip want to try several
  --num-traj-samples / --temperature / --top-p settings without
  re-saving the file each time.
- Picking a different --t0-us or --clip-id is the most common
  modification users make and is currently a copy/edit/run loop.
- --seed / --dtype / --device / --model lets the script support quick
  ablations and local checkpoints without changes to source.

New flags (all default to the previous hardcoded values):

    --clip-id                      "030c760c-..." (current example clip)
    --t0-us                        5100000
    --model                        nvidia/Alpamayo-R1-10B
    --num-traj-samples             1
    --top-p                        0.98
    --temperature                  0.6
    --max-generation-length        256
    --seed                         42
    --device                       cuda
    --dtype                        bfloat16   (choices: bfloat16/float16/float32)

Verified locally without GPU/HF by extracting parse_args via AST and
asserting:
- All defaults equal the prior hardcoded values.
- CLI overrides take effect; untouched fields keep defaults.
- argparse rejects invalid --dtype values via choices=.

The trailing "VLA-reasoning models produce nondeterministic outputs..."
note is preserved verbatim so the operational guidance for users with
single-sample runs is unchanged.

Signed-off-by: lonexreb <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant