Thank you for your excellent open source work. I have a question about using flash_attention_2.

Environment:
Hardware: NVIDIA A100
Python version: 3.12
Package versions:
torch==2.8.0
torchvision==0.23.0
transformers==4.57.1
flash-attn==2.8.3
Model: Official Alpamayo model from Hugging Face

Description:
When loading the official Alpamayo model, I found that the default attention implementation is sdpa, even though flash_attention_2 is specified in config.json.
To enable faster inference, I manually forced FlashAttention-2 by adding a line in alpamayo/src/alpamayo_r1/models/base_model.py; the sketch below shows roughly what it does.
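Concretely, the change amounts to requesting FlashAttention-2 when the VLM backbone is loaded. A minimal sketch of the idea (the model class, checkpoint path, and exact call site in base_model.py are placeholders, not the actual Alpamayo code):

```python
# Force FlashAttention-2 instead of the default "sdpa" implementation.
# Hypothetical load call; the real base_model.py may construct the model differently.
import torch
from transformers import AutoModel

backbone = AutoModel.from_pretrained(
    "path/to/alpamayo-vlm-backbone",          # placeholder checkpoint path
    torch_dtype=torch.bfloat16,               # FA2 requires fp16/bf16 inputs
    attn_implementation="flash_attention_2",  # was "sdpa" by default
)
```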
However, running the inference script (alpamayo/src/alpamayo_r1/test_inference.py) results in the following error:
Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.
Traceback (most recent call last):
File "/data_b/hyc/alpamayo/src/alpamayo_r1/test_inference_ori.py", line 52, in <module>
pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 291, in sample_trajectories_from_data_with_vlm_rollout
sampled_action = self.diffusion.sample(
^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 79, in sample
return self._euler(
^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/src/alpamayo_r1/diffusion/flow_matching.py", line 131, in _euler
v = step_fn(x=x, t=t_start)
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/src/alpamayo_r1/models/alpamayo_r1.py", line 269, in step_fn
expert_out_base = self.expert(
^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 850, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 502, in forward
hidden_states, _ = self.self_attn(
^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 444, in forward
attn_output, attn_weights = attention_interface(
^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
attn_output = _flash_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 616, in _flash_attention_forward
out_unpad = flash_varlen_fn(
^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
return FlashAttnVarlenFunc.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_ops.py", line 836, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_b/hyc/alpamayo/ar1_venv_uv2/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
This suggests a mismatch between the expected input format for FlashAttention-2 (which, in varlen mode, requires packed sequences and valid cu_seqlens_q) and the actual padded input used during inference.
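For reference, the varlen kernel packs all sequences of the batch into a single tensor and expects cumulative sequence lengths of shape (batch_size + 1), which is exactly what the error message checks. A small illustration of the expected cu_seqlens_q format (not Alpamayo code):

```python
# Example: three sequences of lengths 3, 5 and 2 packed into one tensor.
# flash_attn_varlen_func expects cu_seqlens_q = [0, 3, 8, 10], i.e. shape (batch_size + 1,).
import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens_q = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```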
Current Workaround:
Keep attn_implementation as the default (sdpa) to avoid the crash.
Questions:
Does the official Alpamayo release support flash_attention_2?
If so, does it require special input preprocessing (e.g., sequence packing)?
Could this be a compatibility issue between Qwen3-VL (used as the VLM backbone) and FlashAttention-2 in transformers==4.57.1?
Any guidance on enabling FA2 safely would be greatly appreciated!