Conversation

@tohtana (Collaborator) commented Jun 25, 2025

This PR improves the coverage of DeepCompile.

  • Use real parameters when recompilation happens
  • Handle overflow errors in profiling

This PR should be merged after #7366.

ZeRO1 and ZeRO3 both worked with OpenRLHF. See the Wiki page for more details.
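
For context, here is a minimal configuration sketch of the setup described above. The `compile`/`deepcompile` keys and the `engine.compile()` call are assumptions based on the public DeepCompile docs, not something specified in this PR; adjust for your DeepSpeed version.

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in module for illustration

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},                       # ZeRO-3, as tested with OpenRLHF
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
    "compile": {"deepcompile": True},                        # assumed flag to enable DeepCompile
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
engine.compile()  # assumed entry point that applies DeepCompile's compilation passes
```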

Masahiro Tanaka and others added 21 commits June 14, 2025 01:39
@loadams (Collaborator) commented Jun 25, 2025

@tohtana - the HPU is currently down, so I'll remove this test for now.

tohtana merged commit 6594c26 into master on Jun 27, 2025 (9 checks passed)
tohtana deleted the tohtana/dc_improve_z3_coverage branch on June 27, 2025 at 23:28
@hijkzzz commented Jun 30, 2025

When will support for flashattn 2.8.0 be available?

@tohtana (Collaborator, Author) commented Jun 30, 2025

Hi @hijkzzz,

It is more about how PyTorch supports flash-attention. With the packing option enabled in OpenRLHF, transformers tries to use flash_attn_varlen_forward, but the PyTorch compiler can't compile it. I tried flash-attn v2.8 and the result was the same.

Here is the error I got from flash-attention+packing+compile. This error happens even without DeepSpeed.

Dynamo failed to run FX node with fake tensors: call_function flash_attn._flash_attn_varlen_forward(*(FakeTensor(..., device='cuda:0', size=(2328, 32, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(2328, 8, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(2328, 8, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), 0.0, 0.08838834764831845), **{'causal': True, 'window_size_left': -1, 'window_size_right': -1, 'softcap': 0.0, 'alibi_slopes': None, 'return_softmax': False, 'block_table': None}): got RuntimeError("flash_attn::_flash_attn_varlen_forward() Expected a value of type 'int' for argument 'max_seqlen_q' but instead found type 'FakeTensor'.\nPosition: 5\nValue: FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)\nDeclaration: flash_attn::_flash_attn_varlen_forward(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left=-1, SymInt window_size_right=-1, float softcap=0., Tensor? alibi_slopes=None, bool return_softmax=False, Tensor? block_table=None, Tensor? leftpad_k=None, Tensor? seqused_k=None, bool zero_tensors=False) -> (Tensor, Tensor, Tensor, Tensor)\nCast error details: Unable to cast Python instance of type <class 'torch._subclasses.fake_tensor.FakeTensor'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")

from user code:
   File "/home/mtanaka/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 279, in torch_dynamo_resume_in__flash_attention_forward_at_272
    attn_output = flash_attn_varlen_func(
  File "/home/mtanaka/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/home/mtanaka/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
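
For what it's worth, one workaround sketch (not verified in this thread, and `varlen_attention` is just a hypothetical wrapper name) is to force a graph break around the varlen call so Dynamo never traces the op that expects plain ints for `max_seqlen_q`/`max_seqlen_k`:

```python
import torch
from flash_attn import flash_attn_varlen_func  # requires flash-attn

@torch.compiler.disable  # keep this call out of the compiled graph
def varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, softmax_scale):
    # flash-attn expects Python ints here; materialize them on the host
    # if the caller passed 0-dim tensors (the source of the error above).
    if torch.is_tensor(max_seqlen_q):
        max_seqlen_q = int(max_seqlen_q.item())
    if torch.is_tensor(max_seqlen_k):
        max_seqlen_k = int(max_seqlen_k.item())
    return flash_attn_varlen_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        softmax_scale=softmax_scale, causal=True,
    )
```

The graph break gives up some fusion opportunities around attention, but it keeps the rest of the model compilable.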

@hijkzzz commented Jul 8, 2025

> Hi @hijkzzz,
>
> It is more about how PyTorch supports flash-attention. With the packing option enabled in OpenRLHF, transformers tries to use flash_attn_varlen_forward, but the PyTorch compiler can't compile it. I tried flash-attn v2.8 and the result was the same.
>
> Here is the error I got from flash-attention+packing+compile. This error happens even without DeepSpeed.
>
> […]

@tohtana Flash-attn is currently a hard dependency for frameworks like OpenRLHF, veRL, and Alibaba ROLL. Is there any way to work around this issue?

@tohtana (Collaborator, Author) commented Jul 8, 2025

@hijkzzz Yeah, I understand the eager implementation is not an option. What do you think about SDPF or disabling the "packing sample" option? In my environment, DeepCompile+SDPF showed better performance with OpenRLHF than using flash attention.

@hijkzzz commented Jul 8, 2025

> @hijkzzz Yeah, I understand the eager implementation is not an option. What do you think about SDPF or disabling the "packing sample" option? In my environment, DeepCompile+SDPF showed better performance with OpenRLHF than using flash attention.

@tohtana flash-attn and packing_samples are extremely important for RLHF, and even for SFT... they're not something we can afford to drop for now.

@tohtana (Collaborator, Author) commented Jul 9, 2025

@hijkzzz Sure, I understand the importance. I'm just wondering how big the performance difference with SDPA is (sorry for the typo in my last message). From my understanding, SDPA calls the flash-attention kernels (or nearly equivalent ones) internally, and you can enable it just by setting attn_implementation to sdpa if you use HF models. At least, I didn't see much performance difference when I tried it.
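
As a concrete sketch of that suggestion (the model id below is just a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",       # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",      # PyTorch SDPA instead of flash-attn
)
model = torch.compile(model)         # SDPA is traceable by torch.compile
```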

@hijkzzz commented Jul 14, 2025

> @hijkzzz Sure, I understand the importance. I'm just wondering how big the performance difference with SDPA is (sorry for the typo in my last message). From my understanding, SDPA calls the flash-attention kernels (or nearly equivalent ones) internally, and you can enable it just by setting attn_implementation to sdpa if you use HF models. At least, I didn't see much performance difference when I tried it.

Packing samples / ring attention depends directly on the flash-attn library's API, so for now we can't switch to this approach.

lpnpcs pushed a commit to lpnpcs/DeepSpeed that referenced this pull request Jul 30, 2025
This PR improves the coverage of DeepCompile.

- Use real parameters when recompilation happens
- Handle overflow errors in profiling

This PR should be merged after deepspeedai#7366.

ZeRO1 and ZeRO3 both worked with OpenRLHF. See [Wiki
page](https://github.com/tohtana/DeepCompile_docs/wiki/Debug-with-OpenRLHF-(%237243))
for more details.

---------

Signed-off-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025