
Pin expert decoder to SDPA so model loads with FA2 globally enabled (fixes #52) #75

Open
lonexreb wants to merge 1 commit into NVlabs:main from lonexreb:fix/expert-attn-impl-sdpa-for-fa2-compat

Conversation

lonexreb (Contributor) commented May 2, 2026

Summary

Issue #52 reports that enabling flash_attention_2 (the default in ReasoningVLAConfig.attn_implementation) crashes during diffusion sampling with:

RuntimeError: cu_seqlens_q must have shape (batch_size + 1)

The crash isn't in the VLM prefill — it's in the diffusion expert step.

Root cause

AlpamayoR1.sample_trajectories_from_data_with_vlm_rollout builds a custom 4D float attention mask to handle variable-length VLM rollouts (src/alpamayo_r1/models/alpamayo_r1.py:242-250):

attention_mask = torch.zeros(
    (b_star, 1, n_diffusion_tokens, prompt_cache.get_seq_length() + n_diffusion_tokens),
    dtype=torch.float32,
    device=device,
)
for i in range(b_star):
    attention_mask[i, :, :, offset[i] : -n_diffusion_tokens] = torch.finfo(
        attention_mask.dtype
    ).min

Each rollout has a different valid prefix length (<traj_future_start> at a different position per sample), so a per-row prefix is masked with -inf. SDPA handles this fine; FA2's transformers integration only supports causal attention or padded varlen via cu_seqlens_q/cu_seqlens_k and cannot consume arbitrary 4D float masks.
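For reference, a minimal standalone sketch (not the repo's code; shapes and the masked span are made up) of the additive 4D float mask that SDPA consumes directly:

import torch
import torch.nn.functional as F

# Additive float mask of shape (B, 1, T_q, T_kv): 0 keeps a position,
# finfo.min drops it. This is the kind of mask the expert step builds.
B, H, T_q, T_kv, D = 2, 4, 8, 32, 16
q = torch.randn(B, H, T_q, D)
k = torch.randn(B, H, T_kv, D)
v = torch.randn(B, H, T_kv, D)

mask = torch.zeros(B, 1, T_q, T_kv)
mask[0, :, :, 10:24] = torch.finfo(mask.dtype).min  # hypothetical per-row span

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (B, H, T_q, D)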

The expert inherits _attn_implementation from a deep-copy of vlm.config.text_config (alpamayo_r1.py:90), so flipping the VLM to FA2 silently flips the expert too — and the expert is where the 4D mask lives.
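A rough sketch of that inheritance, using LlamaConfig as a stand-in for the VLM's text config and assuming a recent transformers version (the attribute name matches the one the patch below pins):

import copy
from transformers import LlamaConfig  # stand-in; the real text config comes from vlm.config

text_config = LlamaConfig()
text_config._attn_implementation = "flash_attention_2"  # the global FA2 setting

expert_config = copy.deepcopy(text_config)
# The copy carries the setting, so the expert silently runs FA2 as well
# unless it is pinned back to SDPA afterwards.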

Fix

Explicitly pin the expert to SDPA after the deep-copy. The expert step is short (n_diffusion_tokens is small, batch is B * num_traj_samples), so SDPA is cheap; the VLM is unaffected and keeps using whatever attn_implementation the config specifies (FA2 by default), where the long prefill actually benefits.

 # we only need the text config for the expert model
 expert_config = copy.deepcopy(self.vlm.config.text_config)
+# The diffusion expert step builds a per-row 4D float attention mask
+# in sample_trajectories_from_data_with_vlm_rollout (see lines below)
+# to handle variable-length VLM rollouts. FlashAttention-2's transformers
+# integration cannot consume arbitrary 4D float masks (it only supports
+# causal or padded-varlen via cu_seqlens) and crashes with
+# "cu_seqlens_q must have shape (batch_size + 1)". Pin the expert to
+# SDPA so the model is usable when the VLM is loaded with FA2.
+# See https://github.com/NVlabs/alpamayo/issues/52.
+expert_config._attn_implementation = "sdpa"
 if config.expert_cfg is not None:
     for key, value in config.expert_cfg.items():
         setattr(expert_config, key, value)
 self.expert = AutoModel.from_config(expert_config)

expert_cfg.attn_implementation (passed via Hydra) still wins because the override loop runs after the explicit pin — so anyone who has rewritten the expert step to use a 2D padding mask can flip back to FA2 without a code change.

Alternative considered

Right-padding all rollouts to the same length, dropping the per-row 4D mask, and using a 2D padding mask would make the expert FA2-compatible, but it seems unnecessary given how short the expert forward is. A sketch of this option follows.
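For illustration only, a sketch of what that alternative could look like; the helper name and shapes are hypothetical, not code from this repo:

import torch

def build_padding_mask(prefix_lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """2D mask FA2's varlen path understands: 1 for real tokens, 0 for right-padding."""
    positions = torch.arange(max_len, device=prefix_lengths.device)
    return (positions[None, :] < prefix_lengths[:, None]).long()

# e.g. three rollouts whose valid prefixes end at different positions
mask = build_padding_mask(torch.tensor([5, 9, 7]), max_len=9)  # shape (3, 9)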

Test plan

Related

The diffusion expert step in sample_trajectories_from_data_with_vlm_rollout
builds a per-row 4D float attention mask to handle variable-length VLM
rollouts: each batch row has a different valid prefix length (the
<traj_future_start> token shows up at a different position per sample),
so a per-row prefix is masked out with -inf in a (B*, 1, T_diff, T_kv)
float mask.

FlashAttention-2's transformers integration cannot consume arbitrary 4D
float masks. It only supports causal attention or padded sequences via
cu_seqlens_q/cu_seqlens_k (varlen). When users force attn_implementation
to flash_attention_2 globally, the deep-copy of vlm.config.text_config
into expert_config carries FA2 down into the expert, which then crashes
during the diffusion step with:

  RuntimeError: cu_seqlens_q must have shape (batch_size + 1)

Fix: explicitly pin the expert to SDPA. The expert step is short
(n_diffusion_tokens is small) so the SDPA path is cheap; the VLM is
unaffected and keeps using whatever attn_implementation the config
specifies (FA2 by default), where the long prefill actually pays off.

Fixes NVlabs#52.

Signed-off-by: lonexreb <[email protected]>