Flash Bidirectional Linear Attention

The aim of this repository is to implement bidirectional linear attention for non-causal modeling using Triton. Contributions and suggestions are welcome!


Updates

  • [2026/02] Add PISA.
  • [2025/02] Add PolaFormer.
  • [2024/12] Add simple_la, a simplified form of linear_attn without the normalization term.
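As a hypothetical illustration (plain NumPy reference code, not the repository's Triton kernels), the normalized form and the simple_la form without the normalization term can be sketched as:

```python
# Reference sketch of bidirectional (non-causal) linear attention.
# Assumes q and k have already passed through a positive feature map.
import numpy as np

def linear_attn(q, k, v, eps=1e-6):
    """Normalized bidirectional linear attention. q, k: (T, D); v: (T, Dv)."""
    kv = k.T @ v                    # (D, Dv) global key-value state
    z = q @ k.sum(axis=0)           # (T,) normalization term
    return (q @ kv) / (z[:, None] + eps)

def simple_la(q, k, v):
    """Same computation with the normalization term dropped."""
    return q @ (k.T @ v)            # (T, Dv)
```

Because there is no causal mask, the (D, Dv) key-value state is shared by every query position, so the cost is O(T·D·Dv) rather than the O(T²·D) of softmax attention.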

Models

Roughly sorted by when support was added in flash_bla:

| Year | Model | Title | Paper | Code | fla impl |
|------|-------|-------|-------|------|----------|
| 2024 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arxiv | official | code |
| 2024 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arxiv | official | code |
| 2025 | PolaFormer | PolaFormer: Polarity-aware Linear Attention for Vision Transformers | arxiv | official | code |
| 2025 | RALA | Breaking the Low-Rank Dilemma of Linear Attention | arxiv | official | code |
| 2026 | PISA | PISA: Piecewise Sparse Attention Is Wiser for Efficient Diffusion Transformers | arxiv | official | code |

Usage

Installation

```shell
git clone https://github.com/fla-org/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```

Integrated Models

This library integrates several models that can be called directly. Taking LinFusion as an example:

```python
import torch
from diffusers import AutoPipelineForText2Image
from flash_bla.models import LinFusion

sd_repo = "stabilityai/stable-diffusion-xl-base-1.0"

pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Patch the pipeline's attention with LinFusion's linear attention modules
linfusion = LinFusion.construct_for(pipeline, pretrained_model_name_or_path="Yuanshi/LinFusion-XL")

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]
```

Benchmarks

Profiled on an A800-80G GPU.

B8-H16-D64 (batch size 8, 16 heads, head dim 64):

| T | torch_la_fwd | flash_bla_fwd | torch_sdpa_fwd | torch_la_bwd | flash_bla_bwd | torch_sdpa_bwd |
|---|---|---|---|---|---|---|
| 1024 | 0.083968 | 0.068608 | 0.073728 | 0.476160 | 0.378880 | 0.405504 |
| 4096 | 0.178176 | 0.083968 | 0.784384 | 1.018880 | 0.444416 | 3.175424 |
| 16384 | 0.549888 | 0.283648 | 11.750400 | 3.556352 | 1.566720 | 44.189184 |
| 32768 | 1.034240 | 0.550912 | 47.788033 | 6.864896 | 3.040256 | 175.127548 |
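The widening gap between flash_bla and torch_sdpa reflects asymptotic cost: softmax attention materializes a T×T score matrix, while non-causal linear attention reduces keys and values to a D×Dv state whose size is independent of T. A minimal sketch (hypothetical NumPy, not the benchmarked kernels):

```python
# Sketch: linear attention's intermediate state is (D, Dv) regardless of
# sequence length T, so its cost grows linearly in T; softmax attention's
# (T, T) score matrix grows quadratically.
import numpy as np

def kv_state(k, v):
    # Global key-value state shared by every query position (no causal mask).
    return k.T @ v

for T in (1024, 4096):
    k = np.random.rand(T, 64)
    v = np.random.rand(T, 64)
    print(T, kv_state(k, v).shape)  # state shape stays (64, 64) for every T
```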

Acknowledgments

Thanks to the following repositories for their inspiration:
