
Enabling flash_attention_2 causes RuntimeError: cu_seqlens_q must have shape (batch_size + 1) #52


Description

@hyc-hw

Thank you for your excellent open-source work. I have a question about using flash_attention_2.

Environment:
Hardware: NVIDIA A100
Python version: 3.12
Package versions:
torch==2.8.0
torchvision==0.23.0
transformers==4.57.1
flash-attn==2.8.3
Model: Official Alpamayo model from Hugging Face
Description:
When loading the official Alpamayo model, I found that the default attention implementation is sdpa, even though flash_attention_2 is specified in config.json.
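
For context, the backend that transformers actually resolved can be confirmed after loading; a minimal check (assuming the model has already been loaded as model):

# Minimal check: the resolved attention backend is stored on the config
# after loading ("model" is the already-loaded Alpamayo model).
print(model.config._attn_implementation)  # prints "sdpa" in my setup, not "flash_attention_2"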

To enable faster inference, I manually forced FlashAttention-2 by adding the following line at the end of ReasoningVLA.__init__ in alpamayo/src/alpamayo_r1/models/base_model.py:

class ReasoningVLA(PreTrainedModel, TrajectoryFusionMixin):
    """Reasoning Vision-Language-Action model."""

    config_class: type[ReasoningVLAConfig] = ReasoningVLAConfig
    base_model_prefix: str = "vlm"

    def __init__(
        self,
        config: ReasoningVLAConfig,
        pretrained_modules: dict[str, torch.nn.Module] | None = None,
        original_vocab_size: int | None = None,
        print_param_count: bool = True,
    ) -> None:
        super().__init__(config)

        if pretrained_modules is not None:
            for module in pretrained_modules.values():
                if not isinstance(module, torch.nn.Module):
                    continue
                _recursive_setattr(module, "_is_hf_initialized", True)
        else:
            pretrained_modules = {}
        config.attn_implementation = 'flash_attention_2'  # added: force FlashAttention-2
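
As a side note, I also considered the standard transformers route of passing attn_implementation at load time instead of patching __init__; a rough sketch of what I mean (the checkpoint id is a placeholder, and I am not sure whether ReasoningVLA.from_pretrained forwards this kwarg):

import torch
from alpamayo_r1.models.base_model import ReasoningVLA

# Sketch: request FlashAttention-2 via the standard HF from_pretrained kwarg
# instead of patching the config inside __init__ (repo id is a placeholder).
model = ReasoningVLA.from_pretrained(
    "nvidia/Alpamayo-R1",                     # placeholder repo id / local path
    torch_dtype=torch.bfloat16,               # FA2 only supports fp16/bf16
    attn_implementation="flash_attention_2",
)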

However, running the inference script (alpamayo/src/alpamayo_r1/test_inference.py) results in the following error:

Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.
Traceback (most recent call last):
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/test_inference_ori.py", line 52, in <module>
    pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 291, in sample_trajectories_from_data_with_vlm_rollout
    sampled_action = self.diffusion.sample(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 79, in sample
    return self._euler(
           ^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 131, in _euler
    v = step_fn(x=x, t=t_start)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 269, in step_fn
    expert_out_base = self.expert(
                      ^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 850, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 502, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 444, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
    attn_output = _flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 616, in _flash_attention_forward
    out_unpad = flash_varlen_fn(
                ^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 836, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)

This suggests a mismatch between the expected input format for FlashAttention-2 (which, in varlen mode, requires packed sequences and valid cu_seqlens_q) and the actual padded input used during inference.
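
For context on the shape the kernel expects: in the varlen path, transformers unpads the batch and builds cu_seqlens from the attention mask; a minimal sketch of that bookkeeping (the helper name is mine, not the library's):

import torch
import torch.nn.functional as F

# Sketch of the cu_seqlens bookkeeping the varlen path relies on
# (illustrative helper, not the transformers-internal function).
def build_cu_seqlens(attention_mask: torch.Tensor) -> torch.Tensor:
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)  # (batch,)
    # Prefix sums with a leading 0 -> shape (batch + 1,), which is what
    # flash_attn_varlen_func expects for cu_seqlens_q / cu_seqlens_k.
    return F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
print(build_cu_seqlens(mask))  # tensor([0, 3, 5], dtype=torch.int32)

My guess is that the expert/diffusion branch feeds inputs whose mask (or lack of one) makes this tensor come out with the wrong batch dimension, but I have not confirmed that.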
Current Workaround:
Keep attn_implementation as the default (sdpa) to avoid the crash.
Questions:
1. Does the Alpamayo model officially support flash_attention_2?
2. If so, does it require special input preprocessing (e.g., sequence packing)?
3. Could this be a compatibility issue between Qwen3-VL (the VLM backbone) and FlashAttention-2 in transformers==4.57.1?

Any guidance on enabling FA2 safely would be greatly appreciated!
