The aim of this repository is to implement bidirectional linear attention for non-causal modeling using Triton. Contributions and suggestions are welcome!
- [2026/02] Update PISA
- [2025/02] Update PolaFormer
- [2024/12] Update `simple_la`, a simple form of `linear_attn` without the norm term
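As a point of reference for what these kernels compute, non-causal linear attention and the norm-free `simple_la` variant can be written in a few lines of NumPy. This is a sketch of the math only, not the Triton implementation, and the `elu+1` feature map is an assumption:

```python
import numpy as np

def phi(x):
    # Assumed feature map: elu(x) + 1, which keeps features positive
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attn(q, k, v, eps=1e-6):
    # Non-causal linear attention with the norm term:
    # O = phi(Q) (phi(K)^T V) / (phi(Q) sum_t phi(K)_t)
    q, k = phi(q), phi(k)
    kv = k.T @ v                    # [D, Dv], contracted once over all T positions
    z = q @ k.sum(axis=0)           # [T] per-row normalizer
    return (q @ kv) / (z[:, None] + eps)

def simple_la(q, k, v):
    # simple_la: same contraction, but without the norm term
    return phi(q) @ (phi(k).T @ v)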
Models supported in `flash_bla`, roughly sorted by timeline:
| Year | Model | Title | Paper | Code | fla impl |
|---|---|---|---|---|---|
| 2024 | Linfusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arxiv | official | code |
| 2024 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arxiv | official | code |
| 2025 | PolaFormer | PolaFormer: Polarity-aware Linear Attention for Vision Transformers | arxiv | official | code |
| 2025 | RALA | Breaking the Low-Rank Dilemma of Linear Attention | arxiv | official | code |
| 2026 | PISA | PISA: Piecewise Sparse Attention Is Wiser for Efficient Diffusion Transformers | arxiv | official | code |
```shell
git clone https://github.com/fla-org/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```

Some models are already integrated in this library and can be called directly. Taking LinFusion as an example:
```python
import torch
from diffusers import AutoPipelineForText2Image
from flash_bla.models import LinFusion

sd_repo = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

linfusion = LinFusion.construct_for(pipeline, pretrained_model_name_or_path="Yuanshi/LinFusion-XL")

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]
```

Profiled on the A800-80G GPU.
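The gap in the numbers below comes from associativity: linear attention contracts `K^T V` first and never builds the `T x T` score matrix, so its cost grows linearly in `T` while SDPA-style attention grows quadratically. A minimal CPU-only sketch of that scaling (plain NumPy timing, not the Triton kernels profiled here):

```python
import time
import numpy as np

def quad_attn(q, k, v):
    # O(T^2): materializes the full T x T score matrix
    return (q @ k.T) @ v

def lin_attn(q, k, v):
    # O(T): contracts K^T V first; mathematically identical by associativity
    return q @ (k.T @ v)

def bench(fn, *args, reps=5):
    fn(*args)  # warm-up
    t0 = time.perf_counter()
    for _ in range(reps):
        fn(*args)
    return (time.perf_counter() - t0) / reps

rng = np.random.default_rng(0)
for T in (1024, 4096):
    q, k, v = (rng.standard_normal((T, 64)).astype(np.float32) for _ in range(3))
    print(f"T={T}: quad={bench(quad_attn, q, k, v)*1e3:.2f} ms, "
          f"lin={bench(lin_attn, q, k, v)*1e3:.2f} ms")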
`B8-H16-D64` (batch size 8, 16 heads, head dim 64):
```
    T   torch_la_fwd  flash_bla_fwd  torch_sdpa_fwd  torch_la_bwd  flash_bla_bwd  torch_sdpa_bwd
 1024       0.083968       0.068608        0.073728      0.476160       0.378880        0.405504
 4096       0.178176       0.083968        0.784384      1.018880       0.444416        3.175424
16384       0.549888       0.283648       11.750400      3.556352       1.566720       44.189184
32768       1.034240       0.550912       47.788033      6.864896       3.040256      175.127548
```

Thanks to the following repositories for their inspiration:
