A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training
- [2025/11] We release MagiAttention-v1.0.5 with native support for the (distributed) learnable attention sink mechanism in both Flex-Flash-Attention and MagiAttention, plus a drop-in integration for Flash-Attention via our Extensions, along with a blog post that shares our design insights and implementation details. Furthermore, we support native group collective kernels for intra-node communication based on DeepEP as an experimental feature.
- [2025/09] We release MagiAttention-v1.0.4 to update the API, support compilable and JIT-built FFA, optimize performance for sparse scenarios, reduce workspace memory usage, and introduce some experimental features in progress.
- [2025/07] We release MagiAttention-v1.0.3 with improvements including documentation, support for all four mask types with arbitrary overlapping, a deterministic mode, API updates, FFA performance enhancements with bug fixes, optimized dispatch solvers, hierarchical-comm support, and example code to train a Llama-3 1B model with MagiAttention + FSDP / Transformers.
- [2025/06] We release MagiAttention-v1.0.2 to provide example code for integrating Megatron-LM with MagiAttention, along with several training convergence experiments (see here for more details), some bug fixes, and a roadmap.
- [2025/05] We release MagiAttention-v1.0.1 to support overlapped q_ranges when all mask types are FULL, with some code cleanup and bug fixes.
- [2025/04] We release MagiAttention-v1.0.0 with its blog: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training.
MagiAttention is a distributed attention mechanism, or context-parallel (CP) strategy, which aims to support a wide variety of attention mask types with kernel-level flexibility while achieving linear scalability with respect to CP size across a broad range of scenarios. It is particularly suitable for training tasks with ultra-long contexts and heterogeneous masks, such as video generation for Magi-1.
Additionally, it can be easily integrated into prevalent training frameworks such as Megatron-LM, PyTorch's native FSDP, and Hugging Face Transformers, as illustrated in QuickStart.
We are committed to continually improving the performance and generality of MagiAttention for the broader research community. Stay tuned for exciting enhancements and new features on the horizon!
To realize linear scalability for distributed attention, we implement the following key designs.
For implementation details, more experimental results, and future work, please visit our blog.
- Flexible Flash Attention Kernel. We introduce a generalized formulation for irregular attention mask patterns and implement a flexible flash attention kernel (FFA). It is natively designed for distributed scenarios and provides greater flexibility in handling diverse attention mask types, with performance comparable to Flash-Attention 3 on Hopper GPUs.
- Computation Load-Balance. With a fine-grained sharding strategy, we design an efficient dispatch solver that ensures balanced attention computational loads across each CP rank in every training iteration (a toy sketch follows this list).
- Zero-Redundant Communication. Instead of adopting the common Ring-style P2P communication pattern in CP, we propose two novel communication primitives, GroupCast and GroupReduce, built upon All-to-All-v as a prototype implementation, enabling zero-redundant communication volume for both forward and backward passes.
- Adaptive Multi-Stage Overlap. Leveraging the above enhancements, we further implement a multi-stage compute-communication overlap strategy that effectively hides communication latency and adaptively optimizes overlap through manual or automatic tuning.
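To make the load-balancing idea concrete, here is a minimal, illustrative sketch, not MagiAttention's actual dispatch solver: it scores each query chunk by its attention "area" (the number of unmasked query-key pairs) and greedily assigns chunks to the currently least-loaded CP rank. All names here (`chunk_area`, `greedy_dispatch`, etc.) are hypothetical.

```python
# A toy greedy dispatch sketch (hypothetical; NOT MagiAttention's real dispatch solver):
# balance attention "area" (the number of unmasked q-k pairs) across CP ranks.
import heapq
from dataclasses import dataclass, field

@dataclass
class RankLoad:
    load: int = 0
    chunk_ids: list = field(default_factory=list)

def chunk_area(q_start: int, q_end: int, k_ranges: list) -> int:
    """Cost of a query chunk = total number of keys its queries attend to."""
    q_len = q_end - q_start
    return sum(q_len * (k_end - k_start) for k_start, k_end in k_ranges)

def greedy_dispatch(chunks: list, cp_size: int) -> list:
    """Assign chunks (heaviest first) to the currently least-loaded CP rank."""
    heap = [(0, rank) for rank in range(cp_size)]
    heapq.heapify(heap)
    ranks = [RankLoad() for _ in range(cp_size)]
    for chunk in sorted(chunks, key=lambda c: -c["area"]):
        _, rank = heapq.heappop(heap)
        ranks[rank].load += chunk["area"]
        ranks[rank].chunk_ids.append(chunk["id"])
        heapq.heappush(heap, (ranks[rank].load, rank))
    return ranks

# example: 8 query chunks of a block-causal mask, where later chunks attend to more keys
chunks = [
    {"id": i, "area": chunk_area(i * 512, (i + 1) * 512, [(0, (i + 1) * 512)])}
    for i in range(8)
]
for rank_id, r in enumerate(greedy_dispatch(chunks, cp_size=4)):
    print(f"rank {rank_id}: chunks={r.chunk_ids}, load={r.load}")
```

The actual solver in MagiAttention works on the fine-grained sharded ranges described above and is more sophisticated (see the blog); this sketch only conveys the balancing objective.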
Please check here.
- NGC PyTorch docker release note: here

- docker run command:

```bash
# choose one compatible version
MAJOR_VERSION=25
MINOR_VERSION=10 # choose from {05, 06, 08, 09, 10}

# specify your own names and paths
CONTAINER_NAME=...
HOST_MNT_ROOT=...
CONTAINER_MNT_ROOT=...

docker run --name ${CONTAINER_NAME} \
    -v ${HOST_MNT_ROOT}:${CONTAINER_MNT_ROOT} \
    -it -d --privileged --gpus all \
    --network host --ipc host \
    --ulimit memlock=-1 --ulimit stack=67108864 \
    nvcr.io/nvidia/pytorch:${MAJOR_VERSION}.${MINOR_VERSION}-py3 /bin/bash
```

- docker exec command:

```bash
docker exec -it ${CONTAINER_NAME} /bin/bash
```

- command:

```bash
pip install -r requirements.txt
```

- command:

```bash
git clone https://github.com/SandAI-org/MagiAttention.git
cd MagiAttention
git submodule update --init --recursive

# NOTE: this step may take around 20~30 minutes and occupy a lot of CPU resources the first time
pip install --no-build-isolation .
```
Warning
MagiAttention currently only supports Hopper GPUs. We intend to broaden this support in upcoming updates.
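Since Hopper is currently the only supported architecture, it may help to verify your GPU's compute capability before building. The snippet below is our own convenience check, not part of MagiAttention's API; Hopper GPUs (e.g. H100 / H800) report compute capability sm_90:

```python
import torch

# Hopper GPUs (e.g. H100 / H800) report compute capability (9, 0), i.e. sm_90.
major, minor = torch.cuda.get_device_capability()
print(f"detected compute capability: sm_{major}{minor}")
assert (major, minor) == (9, 0), "MagiAttention currently requires a Hopper (sm_90) GPU"
```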
We provide basic example code below showing how to use flex_flash_attention (the non-distributed attention function) and magi_attention (the distributed attention mechanism), respectively.
For more usage instructions, you can refer to our docs.
Basic Usage
- `flex_flash_attention`:

```python
import torch
from magi_attention.api import flex_flash_attn_func

# --- Define attention config --- #

total_seqlen = 2048  # 2k tokens
seqlen_sink = 4      # 4 sink tokens
num_heads_q = 8      # number of attention (query) heads
num_heads_kv = 2     # number of key/value heads (GQA)
head_dim = 128       # dimension of each attention head
dtype = torch.bfloat16  # attention activation / computation dtype (while the reduction dtype is always fp32 for ffa right now)
device = "cuda"
has_sink = True      # whether to apply attention sink

# --- Initialize q,k,v,do tensors --- #

q = torch.randn(total_seqlen, num_heads_q, head_dim, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device, requires_grad=True)
do = torch.randn_like(q)

# --- Initialize optional sink tensor --- #

sink = torch.randn(seqlen_sink, num_heads_q, dtype=torch.float32, device=device, requires_grad=True) if has_sink else None

# --- Initialize FFA meta args for customized attention mask --- #
# the following customized attention mask looks like (`*` for unmasked, `0` for masked):
#    - - - - - - - - -> (k)
#  | * * * * 0 0 0 0
#  | * * * * 0 0 0 0
#  | * * * * 0 0 0 0
#  | * * * * 0 0 0 0
#  | * * * * * 0 0 0
#  | * * * * * * 0 0
#  | * * * * * * * 0
#  | * * * * * * * *
#  V
# (q)

q_ranges_tensor = torch.tensor([[0, 1024], [1024, 2048]], dtype=torch.int32, device=device)
k_ranges_tensor = torch.tensor([[0, 1024], [0, 2048]], dtype=torch.int32, device=device)
attn_type_map_tensor = torch.tensor([0, 1], dtype=torch.int32, device=device)  # full mask for 1st slice, causal mask for 2nd

# --- Attention computation --- #

out, lse = flex_flash_attn_func(
    q=q,
    k=k,
    v=v,
    q_ranges=q_ranges_tensor,
    k_ranges=k_ranges_tensor,
    attn_type_map=attn_type_map_tensor,
    sink=sink,           # Defaults to None to not apply attention sink
    softmax_scale=None,  # Defaults to 1/sqrt(head_dim)
    softcap=0,           # Defaults to 0
)

out.backward(do)

dq, dk, dv = q.grad, k.grad, v.grad
dsink = sink.grad if has_sink else None
```
- `magi_attention`: (NOTE: you need to run the following example in a distributed environment, e.g. using the common `torchrun` script)

```python
# run this python script with a command like:
# torchrun --standalone --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 ${SCRIPT_PATH}

import torch
import torch.nn as nn
import torch.distributed as dist

import magi_attention
from magi_attention.api import (
    magi_attn_flex_dispatch, calc_attn, undispatch,  # interface functions
    compute_pad_size,  # helper functions
)
from magi_attention.common import AttnRanges
from magi_attention.common.enum import AttnMaskType
from magi_attention.utils import setup_dist_env, clearup_dist_env

# --- Set up distributed environment --- #

rank, local_rank, world_size, num_nodes, num_local_ranks, world_group, device, seed = setup_dist_env()

# --- Define attention config --- #

total_seqlen = 32 * 1024  # 32k tokens, if we dispatch it to 8 GPUs, then each GPU holds 4k tokens
seqlen_sink = 4    # 4 sink tokens
num_heads_q = 48   # number of attention (query) heads
num_heads_kv = 8   # number of key/value heads (GQA)
head_dim = 128     # dimension of each attention head
chunk_size = 512   # chunk size to chunk the input tensor x along the seqlen dim for dispatch, controlling the granularity of computation load-balance
dtype = torch.bfloat16  # attention activation / computation dtype (while the reduction dtype for partial attention outputs is always fp32 for magi_attention right now)
has_sink = True    # whether to apply attention sink

# --- Initialize token embedding tensor --- #

embed_dim = 4096
x = torch.randn(total_seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True)

# --- Initialize MagiAttention meta configs for customized attention mask --- #
# the following customized attention mask is known as a `block-causal` mask where `block_size` = 4096 (4k),
# which looks like (`*` for unmasked, `0` for masked):
#    - - - - - - - - -> (k)
#  | * * 0 0 0 0 0 0
#  | * * 0 0 0 0 0 0
#  | * * * * 0 0 0 0
#  | * * * * 0 0 0 0
#  | * * * * * * 0 0
#  | * * * * * * 0 0
#  | * * * * * * * *
#  | * * * * * * * *
#  V
# (q)

q_ranges = AttnRanges.from_ranges(
    [
        [0, 4096],       # 0~4k
        [4096, 8192],    # 4k~8k
        [8192, 12288],   # 8k~12k
        [12288, 16384],  # 12k~16k
        [16384, 20480],  # 16k~20k
        [20480, 24576],  # 20k~24k
        [24576, 28672],  # 24k~28k
        [28672, 32768],  # 28k~32k
    ]
)

k_ranges = AttnRanges.from_ranges(
    [
        [0, 4096],   # 0~4k
        [0, 8192],   # 0~8k
        [0, 12288],  # 0~12k
        [0, 16384],  # 0~16k
        [0, 20480],  # 0~20k
        [0, 24576],  # 0~24k
        [0, 28672],  # 0~28k
        [0, 32768],  # 0~32k
    ]
)

attn_mask_type = [AttnMaskType.FULL] * len(q_ranges)

total_seqlen_q = total_seqlen_k = total_seqlen

pad_size = compute_pad_size(  # pad embeds along seqlen dim for better performance
    total_seqlen_q=total_seqlen_q,
    cp_size=world_size,  # assuming we only have 1-dim context parallelism (cp)
    chunk_size=chunk_size,
)

# --- Dispatch token embedding tensor along seqlen dim to multiple ranks --- #
# NOTE:
# 1. the dispatched local token embedding may be shuffled along seqlen dim,
#    so it's safe for token-wise operations such as matmul, layer-norm, etc.,
#    while for sample-wise operations like RoPE, you might need to be more careful
# 2. the `magi_attn_runtime_key` holds some inner meta data as one argument for many other magi_attention APIs,
#    which users don't have to bother with

local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
    x,
    q_ranges=q_ranges,
    k_ranges=k_ranges,
    attn_mask_type=attn_mask_type,
    total_seqlen_q=total_seqlen_q,
    total_seqlen_k=total_seqlen_k,
    pad_size=pad_size,
    chunk_size=chunk_size,
    cp_group_or_mesh=world_group,  # assuming we only have 1-dim context parallelism (cp)
)

# --- Simulate QKV projection --- #

q_proj = nn.Linear(embed_dim, num_heads_q * head_dim, dtype=dtype, device=device)
k_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
v_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)

local_q = q_proj(local_x).view(-1, num_heads_q, head_dim)
local_k = k_proj(local_x).view(-1, num_heads_kv, head_dim)
local_v = v_proj(local_x).view(-1, num_heads_kv, head_dim)

# --- Simulate attention sink parameter --- #

global_sink = nn.Parameter(torch.randn(seqlen_sink, num_heads_q, dtype=torch.float32, device=device)) if has_sink else None

# --- Distributed attention computation --- #

local_out, local_lse = calc_attn(
    q=local_q,
    k=local_k,
    v=local_v,
    key=magi_attn_runtime_key,
    sink=global_sink,  # Defaults to None to not apply attention sink
)

# --- Undispatch the output tensor along seqlen dim from multiple ranks and unpad --- #
# NOTE: the undispatch API may not be used until the moment you need the seqlen dimension to be complete and ordered,
# e.g. for either aforementioned sample-wise operations, or loss computation

total_out = undispatch(
    x=local_out,
    key=magi_attn_runtime_key,
)

# --- Simulate loss computation --- #

loss = total_out.sum()

# --- Simulate backward pass --- #

loss.backward()

dx = x.grad
dq_proj, dk_proj, dv_proj = q_proj.weight.grad, k_proj.weight.grad, v_proj.weight.grad

if has_sink:
    dsink = global_sink.grad

    # NOTE: since usually the training framework such as Megatron-LM or FSDP
    # will handle the reduction of parameters' gradients across the whole dp x cp group,
    # by default MagiAttention will skip the reduction of sink's gradients,
    # unless the users specify the environment variable `MAGI_ATTENTION_DSINK_ALL_REDUCE_OP` (see our docs for more details)
    if (op := magi_attention.comm.dsink_all_reduce_op()) != "none":
        match op:
            case "sum":
                dist.all_reduce(dsink, op=dist.ReduceOp.SUM, group=world_group)
            case "avg":
                dist.all_reduce(dsink, op=dist.ReduceOp.AVG, group=world_group)
            case _:
                raise ValueError(f"Unknown all_reduce_op: {op}")

# --- Clear up distributed environment --- #

clearup_dist_env()
```
We provide an example of how to integrate magi_attention with fsdp2 in examples/torch_native. You can use bash run.sh to run the example.
In this example, we build a llama-1b model and apply fsdp2 with magi_attention as the parallelism strategy.
- examples/torch_native/modeling_llama.py: builds the llama model and integrates it with magi_attention.
- examples/torch_native/main.py: main training loop.
We create a new repository Megatron-LM-MagiAttention, forked from Megatron-LM v0.11.0, to provide an example of training the llama-1B model with Megatron-LM + MagiAttention. Furthermore, we conducted an experiment training the llama-3-1B model from scratch to verify the convergence of MagiAttention.
For more information, you can refer to examples/megatron/README.md.
We provide an example of how to integrate magi_attention with transformers in examples/transformers. Furthermore, we conducted a continued-training experiment on the llama-3-1B model to verify the convergence of MagiAttention.
For more information, you can refer to examples/transformers/README.md.
We provide additional magi_attn_extensions to offer supplementary utilities based on magi_attention, such as FlashAttention with Attention Sink.
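For intuition about what the learnable attention sink does mathematically (the `sink` tensor of shape `[seqlen_sink, num_heads_q]` used in the examples above), here is a naive single-device reference in plain PyTorch, written for the MHA case (num_heads_q == num_heads_kv) and without a mask for brevity: the sink logits join the softmax and absorb probability mass, but contribute no value to the output. This is an illustrative sketch of the mechanism only, not MagiAttention's kernel or its extension API, and the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def attention_with_sink_reference(q, k, v, sink, softmax_scale=None):
    """q, k, v: [seqlen, num_heads, head_dim]; sink: [seqlen_sink, num_heads] (fp32 logits)."""
    seqlen_q, num_heads, head_dim = q.shape
    scale = softmax_scale if softmax_scale is not None else head_dim ** -0.5

    # attention logits: [num_heads, seqlen_q, seqlen_k]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale

    # broadcast sink logits over query positions: [num_heads, seqlen_q, seqlen_sink]
    sink_logits = sink.float().t()[:, None, :].expand(num_heads, seqlen_q, -1)

    # sinks take part in the softmax normalization ...
    probs = F.softmax(torch.cat([scores, sink_logits], dim=-1), dim=-1)

    # ... but are dropped afterwards, since they carry no value
    attn = probs[..., : k.shape[0]]
    return torch.einsum("hqk,khd->qhd", attn, v.float()).to(q.dtype)

# toy shapes mirroring the non-distributed example above (but with MHA instead of GQA)
q = torch.randn(2048, 8, 128, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
sink = torch.randn(4, 8, dtype=torch.float32)
out = attention_with_sink_reference(q, k, v, sink)
```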
- [WIP] Optimize `Flex-Flash-Attention` kernels to improve performance and better support sparse attention (such as NSA).
- [WIP] Optimize `DistAttnSolver` to reduce CPU overhead for meta-info calculation and support better comp-/comm- overlapping.
- [WIP] Support a `Dynamic DistAttnSolver` with query/output communication patterns: one for hybrid attention models or dynamic mask scenarios like sparse attention, the other for reducing communication overhead in the many cases where communicating only key/value is not the best choice.
- Support other attention patterns including cross-attention, and inference scenarios involving KV cache (w.r.t. Paged Attention).
- Support Ampere, Blackwell, and other GPU architectures.
- Provide more comprehensive documentation with tutorials, and a more detailed technical blog.
- Provide more example code and recipes for various training scenarios.
- Upgrade `MagiAttention` to a distributed native `Flex-Flash-Attention` kernel (as a major version update).
- Support native `GroupCast` and `GroupReduce` communication kernels with inter-/intra-node hierarchical optimization (similar to DeepEP).
- Support learnable attention sink (w.r.t. StreamingLLM).
- Refactor `Distributed Attention Solver` to support all mask types with all kinds of overlap.
- Improve `Dispatch Solver` to reduce necessary communication volume while remaining balanced in computation (especially for varlen mask patterns).
- Build a comprehensive `CP Benchmark` to better compare the performance of different context-parallel strategies under various mask patterns and other training configurations.
- Provide `Documentation` including `Installation`, `QuickStart` and `API reference`.
To demonstrate FFA kernels' state-of-the-art performance and flexibility in handling ultra-long, heterogeneous mask training, we measure their computing power (in TFLOPS/s) across various mask patterns, with the settings listed below.
| settings | value |
|---|---|
| batch size (b) | 1 |
| number of heads (nh) | nhq:nhk:nhv = 64:8:8 (GQA) |
| head dimension (hd) | 128 |
| dtype | torch.bfloat16 |
| dropout probability | 0.0 |
| window size | 1024 (for sliding window masks only) |
Benchmark settings: for each mask pattern, we vary the sequence length seqlen (with seqlen_q = seqlen_k = seqlen) while measuring the computation power (in TFLOPS/s) as a function of seqlen.
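For readers reproducing such measurements, TFLOPS/s numbers in flash-attention-style benchmarks are typically derived from the unmasked attention area rather than from seqlen^2 alone. Below is a hedged sketch assuming that common convention (roughly 4 FLOPs per unmasked query-key pair per head per head-dim element for the forward pass, and about 2.5x that for the backward pass); the function name and the timing number are hypothetical, not taken from our results.

```python
def attn_flops(area: int, num_heads_q: int, head_dim: int, mode: str = "fwd") -> float:
    """Rough FLOP count for attention covering `area` unmasked (q, k) pairs.

    The forward pass performs two matmuls (Q @ K^T and P @ V),
    i.e. ~4 * area * num_heads * head_dim FLOPs; the backward pass is
    commonly counted as ~2.5x the forward (benchmark convention, not exact).
    """
    fwd = 4.0 * area * num_heads_q * head_dim
    return fwd if mode == "fwd" else 2.5 * fwd

# hypothetical example: a full 64k x 64k mask, 64 query heads, head_dim 128,
# with an assumed forward runtime of 250 ms
seqlen, fwd_ms = 64 * 1024, 250.0
tflops_per_s = attn_flops(seqlen * seqlen, num_heads_q=64, head_dim=128) / (fwd_ms * 1e-3) / 1e12
print(f"~{tflops_per_s:.0f} TFLOPS/s")
```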
Some results are reported in the following figures; see more in our blog.
To validate the scalability of MagiAttention, we assess the per-GPU computing power (in TFLOPS/s) as the context length and context-parallel size scale up.
The experiments are conducted on a large-scale production GPU cluster (due to business and confidentiality reasons, specific details about the cluster, such as the number and type of GPUs, are withheld). We scale the total sequence length seqlen, the context-parallel size cp_size, and the node count nnodes together, from (seqlen: 64k, cp_size: 1, nnodes: 1), (seqlen: 128k, cp_size: 2, nnodes: 2), ..., up to (seqlen: 3072k (3M), cp_size: 48, nnodes: 48).
The tensor-parallel size tp_size is fixed at 8, with sequence-parallel enabled. Other data and model configurations for different mask types are the same as in the table in Kernel-Level Experiments.
Therefore, in every training setting, each rank is consistently assigned seqlen=64k, num_heads_q = 8 and num_heads_k = 1 for attention propagation, while the remaining activations stay at seqlen=8k, num_heads_q = 64 and num_heads_k = 8 with SP enabled. This setup simulates a common training configuration.
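To make that per-rank arithmetic explicit, here is our own back-of-the-envelope calculation based on the setup above (illustrative only, not an excerpt from the experiments):

```python
# Per-rank shapes implied by the scaling setup (illustrative arithmetic only).
num_heads_q, num_heads_k, tp_size = 64, 8, 8

for cp_size in (1, 2, 48):                      # nnodes scales together with cp_size
    total_seqlen = cp_size * 64 * 1024          # 64k, 128k, ..., 3072k (3M) tokens
    attn_seqlen = total_seqlen // cp_size       # CP shards the sequence for attention -> always 64k
    attn_heads_q = num_heads_q // tp_size       # TP shards the query heads -> 8
    attn_heads_k = num_heads_k // tp_size       # TP shards the key/value heads -> 1
    sp_seqlen = attn_seqlen // tp_size          # SP shards the remaining activations -> 8k
    print(
        f"cp_size={cp_size:2d}: attention per rank = (seqlen={attn_seqlen}, "
        f"nh_q={attn_heads_q}, nh_k={attn_heads_k}); "
        f"other activations per rank = (seqlen={sp_seqlen}, nh_q={num_heads_q}, nh_k={num_heads_k})"
    )
```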
Some of the results are presented in the following figures; see more in our blog.
As demonstrated, MagiAttention exhibits linear scalability as the context length and CP size increase, in both full mask and varlen full mask configurations, for both forward and backward passes. In contrast, baseline methods either face strict limitations in scaling up or experience performance degradation with ultra-long contexts, which worsens with varlen mask patterns.
We welcome and value any contributions and collaborations. Please check out CONTRIBUTING.md for how to get involved.
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you find MagiAttention useful in your research, please cite:
```bibtex
@misc{magiattention2025,
  title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
  author={Zewei, Tao and Yunpeng, Huang},
  year={2025},
  howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}
```

We are grateful to the contributors listed below for their valuable contributions during the early stages of MagiAttention.
| Member | Affiliations | Email | GitHub Account |
|---|---|---|---|
| Zewei Tao | SandAI | [email protected] | littsk |
| Yunpeng Huang | SandAI | [email protected] | Strivin0311 |
| Qiangang Wang | SandAI, Nanjing University | [email protected] | WT1W |
| Hanwen Sun | SandAI, Peking University | [email protected] | hanwen-sun |
| Jin Li | SandAI, Tsinghua University | [email protected] | lijinnn |
| Tao Bu | Nanjing University | [email protected] | Big-TRex |
| WenYang Fang | Nanjing University | [email protected] | kagami4243 |
| Siyuang Yan | Nanjing University | [email protected] | FibonaccciYan |
| Zixu Jiang | Nanjing University | [email protected] | 191220042 |
| Dingkun Xu | Nanjing University | [email protected] | PureDimension |
| Mingyu Liang | Nanjing University | [email protected] | gaomusiki |
| Jingwei Xu | Nanjing University | [email protected] | paragonlight |
