
rccchoudhury/apt


Rohan Choudhury*¹, JungEun Kim*²,³, Jinhyung Park¹, Eunho Yang², László A. Jeni¹, Kris M. Kitani¹

¹Carnegie Mellon University, ²KAIST, ³General Robotics


Figure: Standard vs. APT (ours)

TL;DR

We accelerate Vision Transformers by choosing patch sizes adaptively based on image content. Instead of using fixed-size patches everywhere, our method selects smaller patches for detailed regions and larger patches for uniform ones. This reduces the total number of patches, and therefore the computational cost, while maintaining accuracy.
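To make the savings concrete, the arithmetic below counts tokens for a 224×224 input with 16×16 base patches (196 tokens with fixed patching) when some 32×32 regions are judged uniform and each is encoded as a single token. This is a minimal sketch with a hypothetical helper name (count_tokens) and a two-scale setup chosen for illustration; it is not taken from the APT code.

```python
def count_tokens(image_size: int, base_patch: int, num_merged_regions: int,
                 merge_factor: int = 2) -> int:
    """Count tokens when some square regions use one large patch.

    Hypothetical helper: a merged region covers (merge_factor * base_patch)
    pixels per side and becomes 1 token instead of merge_factor**2 tokens.
    """
    large_patch = base_patch * merge_factor
    if image_size % large_patch != 0:
        raise ValueError("image_size must be divisible by the large patch size")
    fixed_tokens = (image_size // base_patch) ** 2
    max_regions = (image_size // large_patch) ** 2
    if not 0 <= num_merged_regions <= max_regions:
        raise ValueError(f"num_merged_regions must be in [0, {max_regions}]")
    # Each merged region replaces merge_factor**2 small tokens with a single one.
    saved = num_merged_regions * (merge_factor ** 2 - 1)
    return fixed_tokens - saved


if __name__ == "__main__":
    print(count_tokens(224, 16, 0))   # 196: fixed 16x16 patching
    print(count_tokens(224, 16, 20))  # 136: 20 uniform 32x32 regions merged
```

Because self-attention cost grows quadratically with the number of tokens, even moderate token reductions make the attention layers noticeably cheaper.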

We are releasing an initial version of our code and will add detailed training code and pretrained checkpoints over the next few days. Please bear with us as we continue to clean up the code and add more capabilities. If there is anything in particular you would like to see, don't hesitate to file a GitHub issue!

Setup

We use mamba and uv for installing packages, but conda and pip should work too.

# Create a new environment with mamba
mamba create -n apt python=3.10 -y
mamba activate apt

# Install PyTorch (adjust for your CUDA version)
mamba install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -y

# Install remaining dependencies with uv pip
uv pip install -r requirements.txt
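Optionally, you can confirm the installation before moving on. The script below is a small, hypothetical helper (not part of the repository): it only checks that torch and torchvision are importable and whether torch.cuda.is_available() reports a usable GPU.

```python
import importlib.util


def check_environment() -> dict:
    """Report whether core packages are importable and whether CUDA is visible."""
    status = {
        "torch": importlib.util.find_spec("torch") is not None,
        "torchvision": importlib.util.find_spec("torchvision") is not None,
        "cuda": False,
    }
    if status["torch"]:
        import torch  # imported lazily so the check runs even without PyTorch
        status["cuda"] = bool(torch.cuda.is_available())
    return status


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'missing'}")
```

If cuda reports missing, the pytorch-cuda version in the install command likely does not match your driver.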

Dataset Setup

Datasets require configuration through YAML files located in configs/data/. To set up ImageNet:

  1. Edit configs/data/imagenet.yaml and update the data_dir field to point to your ImageNet dataset:
data_dir: /path/to/ILSVRC2012
  2. Ensure your ImageNet directory has the following structure:
ILSVRC2012/
├── train/
│   ├── n01440764/
│   ├── n01443537/
│   └── ...
└── val/
    ├── n01440764/
    ├── n01443537/
    └── ...

The data_dir should point to the ILSVRC2012 directory containing the train and val folders.
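Before launching a long run, it can help to confirm the layout. The sketch below is a hypothetical helper (not part of the repository) that counts class folders under train/ and val/; for ILSVRC2012 both counts should be 1000. The layout above assumes validation images are already sorted into per-class folders, since the official validation archive ships as a single flat directory of images.

```python
from pathlib import Path


def check_imagenet_layout(data_dir: str) -> dict:
    """Count class folders in train/ and val/ and list classes missing from val/.

    Hypothetical helper: it checks only the folder layout, not image contents.
    """
    root = Path(data_dir)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        classes[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    return {
        "num_train_classes": len(classes["train"]),
        "num_val_classes": len(classes["val"]),
        "missing_in_val": sorted(set(classes["train"]) - set(classes["val"])),
    }


if __name__ == "__main__":
    print(check_imagenet_layout("/path/to/ILSVRC2012"))
```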

Model Support

We support a wide range of timm models. The models we've tested are covered by the configs in configs/model_variants/, and our code automatically downloads the required checkpoint as needed. We have primarily tested ViT variants, so we cannot guarantee support for every timm model; if you encounter any issues, please file a GitHub issue.

Running A Model

Validate a Vision Transformer on ImageNet

# Run validation with default settings (8 GPUs, batch size 64)
python src/eval.py experiment=validate_vit

# Override specific parameters
python src/eval.py experiment=validate_vit \
  trainer.devices=4 \
  data.batch_size=128

# Adjust thresholds and scales
python src/eval.py experiment=validate_vit \
  num_scales=3 \
  'thresholds=[6.0,4.0]'
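The num_scales and thresholds overrides control how many patch sizes are considered and how readily larger ones are chosen. The sketch below is a simplified, hypothetical reading of that idea, not the repository's implementation. It assumes patch sizes double at each scale (16, 32, 64 for num_scales=3), measures detail as the Shannon entropy (in bits) of 8-bit pixel intensities, and keeps a large patch only when its entropy is below that scale's threshold, otherwise splitting it into four. Which threshold pairs with which scale in the real config is also an assumption here.

```python
import math
from collections import Counter


def patch_entropy(img, top, left, size):
    """Shannon entropy (bits) of the intensity histogram inside one square patch."""
    counts = Counter(img[r][c] for r in range(top, top + size)
                     for c in range(left, left + size))
    n = size * size
    return -sum((k / n) * math.log2(k / n) for k in counts.values())


def select_patches(img, base_patch=16, num_scales=3, thresholds=(6.0, 4.0)):
    """Return (top, left, size) patches chosen by a coarse-to-fine entropy test.

    Hypothetical mapping: thresholds[0] applies to 2*base_patch,
    thresholds[1] to 4*base_patch, and so on.
    """
    if len(thresholds) != num_scales - 1:
        raise ValueError("need one threshold per scale above the base patch size")
    largest = base_patch * 2 ** (num_scales - 1)
    h, w = len(img), len(img[0])
    if h % largest or w % largest:
        raise ValueError("image sides must be multiples of the largest patch size")

    def visit(top, left, level):
        size = base_patch * 2 ** level
        # Base-size patches are always kept; larger ones only if "simple" enough.
        if level == 0 or patch_entropy(img, top, left, size) < thresholds[level - 1]:
            return [(top, left, size)]
        half = size // 2
        out = []
        for dr in (0, half):
            for dc in (0, half):
                out.extend(visit(top + dr, left + dc, level - 1))
        return out

    patches = []
    for top in range(0, h, largest):
        for left in range(0, w, largest):
            patches.extend(visit(top, left, num_scales - 1))
    return patches


if __name__ == "__main__":
    # 64x128 toy image: flat left half, textured right half.
    img = [[0 if c < 64 else (r * 64 + c) % 256 for c in range(128)] for r in range(64)]
    patches = select_patches(img)
    print(len(patches))  # 17, versus 32 with fixed 16x16 patches
```

The flat half collapses into one 64×64 patch, while the textured half exceeds both thresholds and is tiled with sixteen 16×16 patches.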

Training

The command below fine-tunes a ViT with APT from a pretrained checkpoint for 50 epochs; you can also modify the config to run a shorter fine-tune starting from an already fine-tuned checkpoint.

# Fine-tune ViT-B/16 on ImageNet with APT (8 GPUs, 5 epochs)
python src/train.py experiment=train_vit_finetune 

Visualizing APT Patches

Generate a visualization showing how APT selects patches for an image with the following command:

python scripts/gen_visualization_single.py \
  --input /your/input/image.jpg \
  --output /your/output/visualization.jpg \
  --method entropy \
  --vis_type grid

Main parameters: --input (required), --output, --method (entropy, laplacian, or upsample_mse), --vis_type (entropy, grid, or none), --patch_size, --num_scales, and --thresholds.
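Besides entropy, --method accepts laplacian and upsample_mse. This README does not define how the repository computes these scores, so the functions below are hypothetical readings of the names: a mean absolute 4-neighbour Laplacian, and the mean squared error between a patch and a 2x-downsampled (average-pooled), nearest-neighbour-upsampled copy. Higher scores indicate more detail, which would favour smaller patches.

```python
def laplacian_score(img, top, left, size):
    """Mean |4-neighbour Laplacian| over the patch interior (hypothetical detail score)."""
    total, count = 0.0, 0
    for r in range(top + 1, top + size - 1):
        for c in range(left + 1, left + size - 1):
            lap = (img[r - 1][c] + img[r + 1][c] + img[r][c - 1] + img[r][c + 1]
                   - 4 * img[r][c])
            total += abs(lap)
            count += 1
    return total / count if count else 0.0


def upsample_mse_score(img, top, left, size):
    """MSE between a patch and its 2x down-then-up-sampled copy (hypothetical)."""
    if size % 2:
        raise ValueError("size must be even")
    err = 0.0
    for r in range(top, top + size, 2):
        for c in range(left, left + size, 2):
            block = [img[r][c], img[r][c + 1], img[r + 1][c], img[r + 1][c + 1]]
            mean = sum(block) / 4  # 2x2 average pooling
            # Nearest-neighbour upsampling copies the mean back to all 4 pixels.
            err += sum((v - mean) ** 2 for v in block)
    return err / (size * size)


if __name__ == "__main__":
    flat = [[7] * 4 for _ in range(4)]
    checker = [[255 * ((r + c) % 2) for c in range(4)] for r in range(4)]
    print(laplacian_score(flat, 0, 0, 4), upsample_mse_score(flat, 0, 0, 4))        # 0.0 0.0
    print(laplacian_score(checker, 0, 0, 4), upsample_mse_score(checker, 0, 0, 4))  # 1020.0 16256.25
```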

About

Public release of the code for "Accelerating Vision Transformers with Adaptive Patches"
