SvenLigensa/efficient-vision-attention

Efficient Vision Attention Mechanisms for Dense Prediction

This repository accompanies my master's thesis of the same title. Its structure is inspired by the repository of the Swin Transformer.

To compare the models in a fair manner, they were reimplemented based on the original implementations in a new, unified framework. The framework lets the user compose a model from three main components: the encoder, the decoder, and the attention mechanism.
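A minimal sketch of what such composition could look like. All class names, parameters, and the config dictionary below are hypothetical illustrations of the encoder/decoder/attention split, not the repository's actual API:

```python
# Hypothetical sketch of composing a model from the three components.
# None of these names are taken from the repository's actual code.

class Attention:
    """Stand-in for an attention mechanism (e.g. window attention)."""
    def __init__(self, dim):
        self.dim = dim

class Encoder:
    """Stand-in encoder producing a hierarchical feature map."""
    def __init__(self, attention):
        self.attention = attention

class Decoder:
    """Stand-in decoder producing a pixel-wise regression map."""
    def __init__(self, attention):
        self.attention = attention

def build_model(cfg):
    # The attention mechanism is instantiated separately for encoder and
    # decoder, so mechanisms can be swapped independently of the models.
    return Encoder(Attention(dim=cfg["dim"])), Decoder(Attention(dim=cfg["dim"]))

encoder, decoder = build_model({"dim": 96})
```

The point of the indirection is that any attention mechanism satisfying the same interface can be plugged into any encoder or decoder purely through configuration.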

Repository Structure

├── README.md
├── pyproject.toml                  Project configuration
├── .gitignore                      Files to exclude from this repository
├── uv.lock                         Lockfile containing Python dependencies
├── .python-version                 Specifies the Python version of this project
├── .pre-commit-config.yaml         Defines the pre-commit hooks
├── code                            Contains this project's code
│   ├── main.py                     Entry point for building, debugging, training, and evaluating the model
│   ├── config.py                   Handles configuration and sets default values
│   ├── config.pyi                  Defines types of configuration values for autocomplete
│   ├── loss.py                     Defines the loss used for training
│   ├── data                        Contains files to build and inspect the dataset
│   │   ├── build.py                Creates the dataset and dataloader objects for the model
│   │   ├── dataset.py              Defines the `SatelliteDataset` (how it is loaded, etc.)
│   │   ├── create_dataset.py       Creates the dataset based on the preprocessed data provided by Jan Pauls
│   │   ├── split_dataset.py        Splits the dataset into a training and validation set
│   │   └── Show_Dataset.ipynb      For visual inspection of the dataset
│   │
│   ├── models                      Model-related code
│   │   ├── build.py                Instantiates the encoder-decoder model according to the configuration
│   │   ├── transformer_layers.py   Common building blocks shared by multiple models
│   │   ├── encoders                Re-implementation of encoder models, i.e., producing a hierarchical feature map
│   │   ├── decoders                Re-implementation of decoder models, i.e., producing a pixel-wise regression map
│   │   └── attention               Re-implementation of attention mechanisms to be used in the encoders and decoders
│   │
│   ├── utils                       Utility functions
│   │   ├── checkpoint.py           Functions for saving and loading model checkpoints
│   │   ├── my_logging.py           Functions related to logging
│   │   ├── my_tensor.py            Common reshape operations
│   │   ├── wandb.py                Functions for wandb tracking
│   │   └── window.py               Helper functions for windowed attention mechanisms
│   │
│   └── visualizations              Code to prepare data for visualization
│       ├── crop_images.py          Cropping of satellite images / predictions
│       ├── visualize_profiles.py   Visualize the memory profiles of the profiling traces and store them to .dat format
│       └── efficient_attention.py  Generates the tables for the explanation of efficient attention in the appendix
│
├── configs                         YAML configurations to initialize the models and SLURM scripts to run them
│   ├── _base_                      Contains the default configurations of the models
│   └── <i>_<experiment_name>       Configurations and scripts for the i-th experiment
│
└── output                          Output files (excluded via .gitignore, as they are too large)

Logging

  • Logging functions can be found in utils/my_logging.py.
  • Logging is performed via a singleton logger, which is created by calling init_logger() in main.py.
  • To reduce clutter in the code, two decorators are introduced:
    1. log_signature() to log the values passed to a function (usually used for __init__())
    2. log_shapes() to log the shapes of tensors (usually used for forward())
  • Alternatively, other files can use the singleton logger by calling logger = get_logger().
  • config.DEBUG_MODE == True → Logging level set to DEBUG
  • config.DEBUG_MODE == False → Logging level set to INFO
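The two decorators could be sketched roughly as follows. This is a hedged illustration of the idea, not the repository's implementation; the logger name and the exact log format are made up here:

```python
import functools
import logging

# Hypothetical logger name; the repository obtains its singleton logger
# via init_logger() / get_logger() instead.
logger = logging.getLogger("evattn")

def log_signature(func):
    """Log the arguments passed to `func` (typically applied to __init__)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.debug("%s called with args=%s kwargs=%s",
                     func.__qualname__, args, kwargs)
        return func(*args, **kwargs)
    return wrapper

def log_shapes(func):
    """Log the .shape of tensor-like arguments (typically applied to forward)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        shapes = [tuple(a.shape) for a in args if hasattr(a, "shape")]
        logger.debug("%s input shapes: %s", func.__qualname__, shapes)
        return func(*args, **kwargs)
    return wrapper
```

Because the messages are emitted at the DEBUG level, they only appear when config.DEBUG_MODE is enabled.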

Configuration

The configuration is done via YACS. A new experiment is performed by:

  1. Creating a new config file in configs/new_config.yaml. All values for configuring the classes are taken exclusively from this config file.
  2. Running the pipeline via torchrun --nproc_per_node=1 --master_port=$RANDOM -m code.main --debug --cfg configs/new_config.yaml
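Such a config file might look like the following sketch. The keys shown here are hypothetical; the actual keys and their defaults are defined in config.py:

```yaml
# configs/new_config.yaml -- hypothetical keys, for illustration only
MODEL:
  ENCODER: swin
  DECODER: swin_unet
  ATTENTION: sw_msa
TRAIN:
  EPOCHS: 100
```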

No Default Values

The constructors in this repository generally do not have default values. This forces the programmer to be explicit about which parameters are passed (usually only the configuration), which ensures that no class has state that was not declared in the configuration.
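The convention can be illustrated with a small sketch. The class name and parameters below are hypothetical, not taken from the repository:

```python
# Sketch of the "no default values" convention: every parameter is
# required, so a caller cannot silently rely on an implicit default.
class WindowAttention:
    def __init__(self, dim, num_heads, window_size):
        # Each value is expected to come from the experiment's config file.
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size

# Explicit call site -- forgetting any parameter raises a TypeError:
attn = WindowAttention(dim=96, num_heads=3, window_size=7)
```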

PALMA Jobs

The models were trained on PALMA. The scripts for training are placed inside the configs/ folder.

Transformer vs. Convolution View

  • Typically, tensors in
    • Transformers have shape (B, N, C)
    • CNNs have shape (B, C, H, W)
  • utils/my_tensor.py implements the typical conversions between the two views

Convention in this repository: tensors typically have shape (B, N, C), except when a convolution op is performed on them, which leads to the typical pattern:

x = bnc2bchw(x)     # Convert from Transformer-view to Conv-view
x = some_conv_op(x) # Perform some conv op
x = bchw2bnc(x)     # Convert from Conv-view to Transformer view
  • NOTE: This repository only works for square images and window sizes!
    • (This makes the code significantly simpler.)
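The two helper names match those used in the pattern above; the NumPy implementation here is only a dependency-light sketch of the same reshape semantics (the repository operates on torch tensors). It relies on the square-image assumption, so H = W = √N:

```python
import math
import numpy as np

def bnc2bchw(x):
    """(B, N, C) -> (B, C, H, W), assuming a square layout with N = H * W, H == W."""
    b, n, c = x.shape
    h = w = math.isqrt(n)
    assert h * w == n, "square-image assumption violated"
    return x.reshape(b, h, w, c).transpose(0, 3, 1, 2)

def bchw2bnc(x):
    """(B, C, H, W) -> (B, N, C) with N = H * W."""
    b, c, h, w = x.shape
    return x.transpose(0, 2, 3, 1).reshape(b, h * w, c)

x = np.arange(2 * 16 * 4).reshape(2, 16, 4)  # (B=2, N=16, C=4)
y = bchw2bnc(bnc2bchw(x))                    # round trip back to (B, N, C)
assert y.shape == (2, 16, 4)
assert (y == x).all()
```

The round trip is lossless because the reshape and transpose in one direction are exactly inverted in the other.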

Unifying the Transformers

  • While the implementations of the different model variants are based on their original repositories, this framework tries to unify them, which comes with multiple benefits:
    • A better overview of what actually differs between models (not just differences in coding style)
    • Easier "interpolation" between model variants via the config (e.g., combining the Mix-FFN of MiT with the SW-MSA of Swin)

stage_idx Parameter

The stage_idx $\in \{0, \cdots, \mathrm{depths}\}$ parameter indicates how "deep" we are inside the architecture.

For the SwinUnet, for example, it looks like this:

---- Encoder ---- -- Decoder --
0 -> 1 -> 2 -> 3 -> 2 -> 1 -> 0
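The U-shaped sequence above can be generated by a small helper. This is only a sketch of the indexing scheme, not code from the repository; `num_stages = 4` reproduces the SwinUnet example:

```python
def stage_indices(num_stages):
    """Encoder counts up 0..num_stages-1; the decoder mirrors back down,
    skipping the bottleneck stage so it is not repeated."""
    encoder = list(range(num_stages))              # 0, 1, ..., num_stages-1
    decoder = list(range(num_stages - 2, -1, -1))  # num_stages-2, ..., 0
    return encoder + decoder

print(stage_indices(4))  # [0, 1, 2, 3, 2, 1, 0]
```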

Development

This repository uses uv for dependency management and black and isort for formatting, which can be run like this:

  • uvx black .
  • uvx isort .
