Pin expert decoder to SDPA so model loads with FA2 globally enabled (fixes #52)#75
Open
lonexreb wants to merge 1 commit intoNVlabs:mainfrom
Open
Pin expert decoder to SDPA so model loads with FA2 globally enabled (fixes #52)#75lonexreb wants to merge 1 commit intoNVlabs:mainfrom
lonexreb wants to merge 1 commit intoNVlabs:mainfrom
Conversation
The diffusion expert step in sample_trajectories_from_data_with_vlm_rollout builds a per-row 4D float attention mask to handle variable-length VLM rollouts: each batch row has a different valid prefix length (the <traj_future_start> token shows up at a different position per sample), so a per-row prefix is masked out with -inf in a (B*, 1, T_diff, T_kv) float mask. FlashAttention-2's transformers integration cannot consume arbitrary 4D float masks. It only supports causal attention or padded sequences via cu_seqlens_q/cu_seqlens_k (varlen). When users force attn_implementation to flash_attention_2 globally, the deep-copy of vlm.config.text_config into expert_config carries FA2 down into the expert, which then crashes during the diffusion step with: RuntimeError: cu_seqlens_q must have shape (batch_size + 1) Fix: explicitly pin the expert to SDPA. The expert step is short (n_diffusion_tokens is small) so the SDPA path is cheap; the VLM is unaffected and keeps using whatever attn_implementation the config specifies (FA2 by default), where the long prefill actually pays off. Fixes NVlabs#52. Signed-off-by: lonexreb <[email protected]>
This was referenced May 2, 2026
Open
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Issue #52 reports that enabling
flash_attention_2(the default inReasoningVLAConfig.attn_implementation) crashes during diffusion sampling with:The crash isn't in the VLM prefill — it's in the diffusion expert step.
Root cause
AlpamayoR1.sample_trajectories_from_data_with_vlm_rolloutbuilds a custom 4D float attention mask to handle variable-length VLM rollouts (src/alpamayo_r1/models/alpamayo_r1.py:242-250):Each rollout has a different valid prefix length (
<traj_future_start>at a different position per sample), so a per-row prefix is masked with-inf. SDPA handles this fine; FA2's transformers integration only supports causal attention or padded varlen viacu_seqlens_q/cu_seqlens_kand cannot consume arbitrary 4D float masks.The expert inherits
_attn_implementationfrom a deep-copy ofvlm.config.text_config(alpamayo_r1.py:90), so flipping the VLM to FA2 silently flips the expert too — and the expert is where the 4D mask lives.Fix
Explicitly pin the expert to SDPA after the deep-copy. The expert step is short (
n_diffusion_tokensis small, batch isB * num_traj_samples), so SDPA is cheap; the VLM is unaffected and keeps using whateverattn_implementationthe config specifies (FA2 by default), where the long prefill actually benefits.expert_cfg.attn_implementation(passed via Hydra) still wins because the override loop runs after the explicit pin — so anyone who has rewritten the expert step to use a 2D padding mask can flip back to FA2 without a code change.Alternative considered
Right-pad all rollouts to the same length, drop the per-row 4D mask, and use a 2D padding mask — would make the expert FA2-compatible but seems unnecessary given how short the expert forward is.
Test plan
python -c \"import ast; ast.parse(open('src/alpamayo_r1/models/alpamayo_r1.py').read())\"— syntax OK.expert_cfg.attn_implementationstill takes precedence.python src/alpamayo_r1/test_inference.pywith the FA2 forcing patch from Enabling flash_attention_2 causes RuntimeError: cu_seqlens_q must have shape (batch_size + 1) #52 and confirm the crash is gone.Related