Add CLI flags to test_inference.py without changing default behavior #100
Open
lonexreb wants to merge 1 commit into NVlabs:main
Wraps the previously top-level test_inference script in a parse_args() +
main(args) + __main__ guard so users can vary the clip, sampling
parameters, model name, seed, device, and dtype without editing the
file. The defaults reproduce the prior hardcoded constants exactly, so
`python src/alpamayo_r1/test_inference.py` is unchanged.
Why this is worth adding:
- Researchers iterating on a single clip want to try several
--num-traj-samples / --temperature / --top-p settings without
re-saving the file each time.
- Picking a different --t0-us or --clip-id is the most common
modification users make and is currently a copy/edit/run loop.
- --seed / --dtype / --device / --model let the script support quick
ablations and local checkpoints without changes to source.
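As a rough illustration of the iteration loop described above, a sweep over sampling settings could be driven from a small helper that only builds command lines. This is a sketch, not part of the PR: `sweep_commands` is a hypothetical name, and the flag names are the ones listed in this commit message.

```python
import itertools
import sys

# Path taken from the PR text; adjust if the script lives elsewhere.
SCRIPT = "src/alpamayo_r1/test_inference.py"


def sweep_commands(num_traj_samples, temperatures, top_ps, seed=42):
    """Yield one argv list per combination of sampling settings (hypothetical helper)."""
    for n, temp, top_p in itertools.product(num_traj_samples, temperatures, top_ps):
        yield [
            sys.executable, SCRIPT,
            "--num-traj-samples", str(n),
            "--temperature", str(temp),
            "--top-p", str(top_p),
            "--seed", str(seed),
        ]


# Print the commands instead of running them; nothing here needs a GPU.
for cmd in sweep_commands([1, 4], [0.6, 0.9], [0.98]):
    print(" ".join(cmd))
```

Each yielded list can be passed to `subprocess.run(cmd, check=True)` to actually launch a run.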
New flags (all default to the previous hardcoded values):
--clip-id "030c760c-..." (current example clip)
--t0-us 5100000
--model nvidia/Alpamayo-R1-10B
--num-traj-samples 1
--top-p 0.98
--temperature 0.6
--max-generation-length 256
--seed 42
--device cuda
--dtype bfloat16 (choices: bfloat16/float16/float32)
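The flag list above maps onto `argparse` fairly directly. The following is a minimal sketch of what such a `parse_args()` could look like, not the PR's actual code: the help text and the body of `main` are placeholders, and the clip id is kept in the same truncated form as in this message.

```python
import argparse


def parse_args(argv=None):
    """Parse CLI flags; every default mirrors a value listed in the commit message."""
    parser = argparse.ArgumentParser(description="Inference smoke test (sketch).")
    # The clip id is elided in the PR text, so the placeholder is kept as-is.
    parser.add_argument("--clip-id", default="030c760c-...")
    parser.add_argument("--t0-us", type=int, default=5_100_000)
    parser.add_argument("--model", default="nvidia/Alpamayo-R1-10B")
    parser.add_argument("--num-traj-samples", type=int, default=1)
    parser.add_argument("--top-p", type=float, default=0.98)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--max-generation-length", type=int, default=256)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    return parser.parse_args(argv)


def main(args):
    # Placeholder: the real script would load the model and run inference here.
    print(vars(args))
```

In the script itself, the file would end with `if __name__ == "__main__": main(parse_args())`, so that importing the module has no side effects.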
Verified locally without GPU/HF by extracting parse_args via AST and
asserting:
- All defaults equal the prior hardcoded values.
- CLI overrides take effect; untouched fields keep defaults.
- argparse rejects invalid --dtype values via choices=.
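One way the AST-based check above might be done is sketched below. It assumes the script imports heavy dependencies at module level, which is why only `parse_args` is compiled rather than importing the whole module. `SOURCE` is a trimmed, hypothetical stand-in for the real file, and `extract_function` is an illustrative helper name.

```python
import argparse
import ast

# Hypothetical stand-in for the script source. The module-level torch import
# is never executed because only the parse_args node is compiled.
SOURCE = '''
import argparse
import torch

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--dtype", default="bfloat16",
                        choices=["bfloat16", "float16", "float32"])
    return parser.parse_args(argv)

def main(args):
    raise RuntimeError("needs a GPU and model weights")
'''


def extract_function(source, name):
    """Compile a single top-level function from source and return it as a callable."""
    tree = ast.parse(source)
    func = next(
        node for node in tree.body
        if isinstance(node, ast.FunctionDef) and node.name == name
    )
    module = ast.Module(body=[func], type_ignores=[])
    namespace = {"argparse": argparse}  # the only name parse_args needs
    exec(compile(module, filename="<extracted>", mode="exec"), namespace)
    return namespace[name]


parse_args = extract_function(SOURCE, "parse_args")

assert parse_args([]).seed == 42             # default preserved
assert parse_args(["--seed", "7"]).seed == 7  # override takes effect
try:
    parse_args(["--dtype", "invalid"])       # argparse exits on a bad choice
except SystemExit as exc:
    assert exc.code == 2
```

To check the real script, `SOURCE` would be replaced with the file's contents, e.g. read via `pathlib.Path(...).read_text()`.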
The trailing "VLA-reasoning models produce nondeterministic outputs..."
note is preserved verbatim so the operational guidance for users with
single-sample runs is unchanged.
Signed-off-by: lonexreb <[email protected]>
Why
`src/alpamayo_r1/test_inference.py` is the script users run first when trying out the model. Today every parameter (clip id, keyframe, sampling temperature, sample count, model name, seed, device, dtype) is a hardcoded literal at module top level. Anyone who wants to:
- change `num_traj_samples` from 1 to 4 to see CoC variation
- swap `nvidia/Alpamayo-R1-10B` for a local checkpoint
- set `--seed` for a side-by-side comparison

...has to copy the file out, edit constants, re-run, repeat. That's the most common iteration loop in the issues thread.
What
Wrap the existing top-level body in `parse_args()` + `main(args)` + `if __name__ == "__main__"` and surface the existing constants as CLI flags. All defaults preserve the current behavior exactly, so `python src/alpamayo_r1/test_inference.py` produces the same output as before this PR.
New flags (all default to the previous hardcoded values):
--clip-id 030c760c-...
--t0-us 5_100_000
--model nvidia/Alpamayo-R1-10B
--num-traj-samples 1
--top-p 0.98
--temperature 0.6
--max-generation-length 256
--seed 42
--device cuda (e.g. cuda:1)
--dtype bfloat16 (choices: bfloat16, float16, float32)
Usage examples are also added to the module docstring.
The trailing operational note about nondeterministic outputs at `num_traj_samples=1` is preserved verbatim.
Verification
Validated locally without GPU/HF (no model load required) by extracting `parse_args` via AST and asserting:
- All defaults equal the prior hardcoded values.
- `--num-traj-samples 8 --temperature 0.9 --seed 7` updates exactly those fields and leaves the rest at their defaults.
- An invalid `--dtype` is rejected by `argparse` `choices=` (e.g. `--dtype invalid` exits with `invalid choice: 'invalid'`).
Migration
Zero migration needed. Existing scripts/CI/notebooks calling `python src/alpamayo_r1/test_inference.py` continue to behave identically.