Stripe Observation Guided Inference Cost-Free Attention Mechanism

This repository provides the official implementation of "Stripe Observation Guided Inference Cost-Free Attention Mechanism" [paper], accepted at ECCV 2024.

🚀 Introduction

Structural Re-parameterization (SRP) enhances a neural network by adding extra structure during training and then merging that structure into the backbone, so the deployed model incurs no extra inference cost. Existing SRP methods handle normalization layers, convolutions, and similar components well, but they struggle with attention modules, whose output is both multiplicative and input-dependent.
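
For context, classical SRP exploits the fact that layers with fixed, linear parameters can be folded together after training. The sketch below shows the standard Conv-BatchNorm fusion that underlies most SRP methods; it is illustrative and not code from this repository.

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold the BatchNorm statistics and affine parameters into the preceding
    # convolution so that fused(x) == bn(conv(x)) at inference time.
    # Assumes groups=1 and dilation=1 for brevity.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std                                   # per-output-channel scale
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused

The fold is exact because both layers apply fixed parameters once training ends. An attention gate, by contrast, is computed from the input itself, so there is no fixed tensor to fold in; this is the gap ASR addresses.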

🔍 Our Key Insight: Stripe Observation

We observe a phenomenon, which we call the Stripe Observation: channel attention values tend to stabilize to nearly constant vectors during training. Inspired by this, we propose ASR (Attention-alike Structural Re-parameterization), which lets attention mechanisms be re-parameterized in the same spirit as SRP, retaining the benefits of attention without any added inference cost.


Figure: Left: the phenomenon of Stripe Observation. Right: the architecture of ASR. During training, a learnable vector is used as input to the attention module; at inference time, the attention module is merged into the backbone network.
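
The merge rests on a simple identity: if the per-channel attention vector a is constant, then a ⊙ BN(x) is itself a BatchNorm whose affine parameters have been rescaled by a. The following self-contained check illustrates this identity; it is a toy example, not the repository's implementation.

import torch
import torch.nn as nn

torch.manual_seed(0)
C = 8
x = torch.randn(4, C, 16, 16)
a = torch.rand(C)                                  # a constant per-channel attention vector

bn = nn.BatchNorm2d(C).eval()
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

# Training-time view: attention rescales the BatchNorm output channel-wise.
y = bn(x) * a.reshape(1, -1, 1, 1)

# Inference-time view: the same attention folded into the BatchNorm affine parameters.
merged = nn.BatchNorm2d(C).eval()
merged.running_mean.copy_(bn.running_mean)
merged.running_var.copy_(bn.running_var)
merged.weight.data.copy_(bn.weight.data * a)
merged.bias.data.copy_(bn.bias.data * a)

print(torch.allclose(y, merged(x), atol=1e-5))     # True: the two views coincide

This is exactly what happens at deployment time: the learned constant attention is absorbed into BatchNorm2d, so the deployed network contains no attention module at all.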

🛠️ Getting Started

We provide an example of training ResNet164 + ASR (SE) on CIFAR-100, where SE (Squeeze-and-Excitation) is a typical channel attention module. In this example, ASR is merged into the BatchNorm layers.

1️⃣ Install Dependencies

Ensure you have Python and PyTorch installed. Then, install the required packages:

pip install -r requirements.txt

2️⃣ Train a Model with ASR

python run.py --arch resnet --module senet --use-asr --dataset cifar100 --use-timestamp --gpu-id 0 

🔹 Explanation of Key Arguments

  • --module: Specifies the attention module. New modules can be added under models/modules; a hypothetical skeleton of such a module is sketched after this list.
  • --use-asr: Enables ASR to structurally re-parameterize the attention module.
  • --use-both: (Optional) Combines ASR (SE) and SE for further improvement.
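
For reference, an SE-style module of the kind this example uses consists of global average pooling followed by a small bottleneck MLP and a sigmoid gate. The skeleton below is a hypothetical illustration of what a module under models/modules could look like; the class name and the exact interface expected by run.py are assumptions and may differ from the repository's code.

import torch.nn as nn

class SEModule(nn.Module):
    """Squeeze-and-Excitation style channel attention (illustrative skeleton)."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)            # squeeze: global spatial average
        self.fc = nn.Sequential(                       # excitation: bottleneck MLP + gate
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        a = self.fc(self.pool(x).view(b, c))           # per-channel attention in (0, 1)
        return x * a.view(b, c, 1, 1)                  # rescale the feature maps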
🔹 Results on CIFAR-100

Models            Args                                  Top-1 acc. ↑         Speed (FPS) ↑
ResNet164         -                                     74.10 [checkpoint]   438
+ SE              --module senet                        75.03 [checkpoint]   286
+ ASR (SE)        --module senet --use-asr              75.21 [checkpoint]   442
+ ASR (SE) + SE   --module senet --use-asr --use-both   76.23 [checkpoint]   289
  • (Cost-free) ASR enhances performance without adding inference cost.
    • Inference speed: ResNet164 + ASR (SE) = ResNet164.
    • Performance: ResNet164 + ASR (SE) ≈ ResNet164 + SE.
  • (Compatibility) ASR seamlessly integrates with existing attention modules, such as ResNet164 + ASR (SE) + SE.

🔧 How ASR Works

The core implementation is in models/asr.py. We integrate attention into BatchNorm2d.

🔹 Switching Between Training & Inference

This allows the attention module to be removed at inference time, making the model as efficient as the original backbone.

Before training, enable computation of the attention module:

for m in model.modules():
    if hasattr(m, 'switch_to_train'):
        m.switch_to_train()

Before testing, merge the attention module into BatchNorm2d:

for m in model.modules():
    if hasattr(m, 'switch_to_deploy'):
        m.switch_to_deploy()

If the model will only be used for inference, the now-unnecessary modules can additionally be removed with switch_to_deploy(delete=True) to further accelerate inference.
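
To make the switching concrete, the sketch below shows one way an ASR-style BatchNorm could be written: during training a learnable anchor vector is fed to the attention module (as in the figure above), and switch_to_deploy folds the resulting constant gate into the BatchNorm affine parameters. This is a simplified illustration under assumed names (ASRBatchNorm2d, and an attention module that maps a (1, C, 1, 1) tensor to a per-channel gate of the same shape); it is not the actual code in models/asr.py.

import torch
import torch.nn as nn

class ASRBatchNorm2d(nn.BatchNorm2d):
    """Illustrative ASR-style BatchNorm (not the repository's implementation)."""

    def __init__(self, num_features, attention: nn.Module):
        super().__init__(num_features)
        self.attention = attention                       # maps (1, C, 1, 1) -> gate (1, C, 1, 1)
        self.anchor = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.deployed = False

    def switch_to_train(self):
        self.deployed = False

    @torch.no_grad()
    def switch_to_deploy(self, delete: bool = False):
        if not self.deployed:
            # The gate computed from the learnable anchor is a constant vector,
            # so it can be absorbed into the BatchNorm scale and shift once.
            a = self.attention(self.anchor).reshape(-1)
            self.weight.mul_(a)
            self.bias.mul_(a)
            self.deployed = True
        if delete:
            self.attention = None                        # drop the attention branch entirely

    def forward(self, x):
        out = super().forward(x)
        if self.deployed:
            return out                                   # attention already merged
        return out * self.attention(self.anchor)         # training-time gating

With this structure, the loops shown above over model.modules() toggle every such layer at once, and switch_to_deploy(delete=True) leaves a plain BatchNorm2d behind for inference.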

📜 Citation

If you find this work useful, please cite:

@inproceedings{huang2024stripe,
  title={Stripe Observation Guided Inference Cost-Free Attention Mechanism},
  author={Huang, Zhongzhan and Zhong, Shanshan and Wen, Wushao and Qin, Jinghui and Lin, Liang},
  booktitle={European Conference on Computer Vision},
  pages={90--107},
  year={2024},
  organization={Springer}
}

🙌 Acknowledgments

This project is based on bearpaw's PyTorch classification framework. Many thanks for their clean and simple implementation!
