📎 [Arxiv](https://arxiv.org/abs/2409.17481), 🚀 [NVlabs/MaskLLM (Official)](https://github.com/NVlabs/MaskLLM)
This repo contains a minimal re-implementation of the paper "MaskLLM: Learnable Semi-structured Sparsity for Large Language Models" for vision tasks.
- ViTs on ImageNet-1k (Classification)
- DiTs on ImageNet-1k (Generation)
- Multi-modal LLMs
- Other vision tasks (e.g., object detection, segmentation)
- TensorRT examples
To enable MaskLLM on your model, replace `nn.Linear` with `sparsity.maskllm.MaskedLinear`. This can be done with the following snippet, where `model` is any model with `nn.Linear` layers (e.g., a timm ViT):
```python
import sparsity

# Replace every nn.Linear (except the classification head) with a 2:4 MaskedLinear.
model = sparsity.utils.replace_linear_with_(
    model, sparsity.maskllm.MaskedLinear, exclude=[model.head], N=2, M=4, hard=False)
print(model)
```

Results for ViT-B/16 on ImageNet-1k with 2:4 sparsity:

| Method | Sparsity Pattern | Weight Update | Mask Prior | Top-1 Acc. (%) |
|---|---|---|---|---|
| ViT-B/16 (in1k) | Dense | - | - | 79.15 |
| Magnitude | 2:4 | - | - | 65.92 |
| Wanda | 2:4 | - | - | 63.28 |
| SparseGPT | 2:4 | ✔️ | - | 71.52 |
| SparseGPT w/o Update | 2:4 | - | - | 59.72 |
| MaskLLM-4V (1 Epoch) | 2:4 | - | SparseGPT | 76.23 |
| MaskLLM-4V (1 Epoch) | 2:4 | - | Magnitude | 76.18 |
| MaskLLM-4V (20 Epochs) | 2:4 | - | SparseGPT | 79.46 |
| MaskLLM-4V (20 Epochs) | 2:4 | - | Magnitude | 79.28 |
Note: MaskLLM learns the sparsity mask end-to-end while keeping the network parameters frozen. For ViT-B/16, this finds a lossless 2:4 mask (79.46% vs. 79.15% for the dense model). A conceptual sketch of the idea is given below.
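Conceptually, MaskLLM treats the choice of 2:4 pattern in every group of four weights as a categorical variable and relaxes it with Gumbel-Softmax, so the mask can be learned by gradient descent while the weights stay frozen. The sketch below only illustrates this idea; names such as `Learnable24Mask` are illustrative and are not the API of this repo.

```python
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

# All 6 binary patterns with exactly 2 nonzeros in a group of 4 weights.
CANDIDATES = torch.tensor(
    [[1.0 if i in combo else 0.0 for i in range(4)]
     for combo in itertools.combinations(range(4), 2)])  # (6, 4)

class Learnable24Mask(nn.Module):
    """Illustrative sketch: learn a 2:4 mask for a frozen weight via Gumbel-Softmax."""
    def __init__(self, weight: torch.Tensor, tau: float = 4.0):
        super().__init__()
        assert weight.numel() % 4 == 0
        self.register_buffer("weight", weight.detach())   # frozen weights
        self.register_buffer("candidates", CANDIDATES)
        # One logit vector (over the 6 candidate patterns) per group of 4 weights.
        self.logits = nn.Parameter(torch.zeros(weight.numel() // 4, 6))
        self.tau = tau

    def masked_weight(self, hard: bool = False) -> torch.Tensor:
        # Soft (or straight-through hard) selection of one pattern per group.
        probs = F.gumbel_softmax(self.logits, tau=self.tau, hard=hard)  # (G, 6)
        mask = probs @ self.candidates                                   # (G, 4)
        return (self.weight.reshape(-1, 4) * mask).reshape(self.weight.shape)

# Only the mask logits receive gradients; the weight tensor itself stays frozen.
m = Learnable24Mask(torch.randn(8, 8))
loss = m.masked_weight().square().sum()
loss.backward()
print(m.logits.grad.shape)  # torch.Size([16, 6])
```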
Please prepare the ImageNet-1k dataset under the `./data/imagenet` directory. The directory structure should be as follows:
```
data
├── imagenet
│   ├── train
│   │   ├── n01440764
│   │   ├── n01443537
│   │   ├── n01484850
│   │   ├── n01491361
│   └── val
│       ├── n01440764
│       ├── n01443537
│       ├── n01484850
│       ├── n01491361
```

### 1.1. MaskLLM for Vision Transformers
We trained MaskLLM on ViT-B/16 with 4×24GB GPUs; with a batch size of 128, it requires about 17 GB of memory per GPU.
We first generate a prior mask with one-shot pruning; this prior significantly accelerates the convergence of MaskLLM. Replace the `--pruner` argument with `magnitude`, `wanda`, or `sparsegpt` to generate different prior masks (a conceptual sketch of how such a prior can initialize the learnable mask is given after this step).

```bash
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt
```

We took the training hyperparameters from this timm issue. By default, we train the model with EMA for 20 epochs. For one-epoch training, please disable EMA as in this script.
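The prior mask only serves as an initialization of the learnable mask, so that training starts from a reasonable 2:4 pattern instead of a random one. Below is a minimal sketch of one way such an initialization could work; this is an illustrative assumption, not necessarily the exact scheme used in this repo.

```python
import itertools
import torch

# The 6 possible 2:4 patterns for a group of 4 weights.
CANDIDATES = torch.tensor(
    [[1.0 if i in combo else 0.0 for i in range(4)]
     for combo in itertools.combinations(range(4), 2)])  # (6, 4)

def logits_from_prior(prior_mask: torch.Tensor, scale: float = 3.0) -> torch.Tensor:
    """prior_mask: binary 2:4 mask with the same shape as the weight tensor."""
    groups = prior_mask.reshape(-1, 4).float()             # (G, 4)
    # Score each candidate by its overlap with the prior pattern; the matching
    # candidate gets the largest (zero) logit, all others get negative logits.
    scores = groups @ CANDIDATES.t()                        # (G, 6)
    return scale * (scores - scores.max(dim=1, keepdim=True).values)

# Example: a prior produced by magnitude/wanda/sparsegpt one-shot pruning.
prior = torch.tensor([1, 0, 1, 0, 0, 1, 1, 0])
print(logits_from_prior(prior))
```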
```bash
# Training
bash scripts/maskllm_vit_base_patch16_224.augreg_in1k.sparsegpt24.sh

# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/maskllm_vit_base_patch16_224.augreg_in1k.sparsegpt24/MaskLLM-4V/model_best.pth.tar --sparsity-mode maskllm
```

```json
{
    "model": "vit_base_patch16_224",
    "top1": 79.456,
    "top1_err": 20.544,
    "top5": 94.548,
    "top5_err": 5.452,
    "param_count": 213.97,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```

To perform MaskLLM on other models or prior types, please change the `--model` and `--checkpoint` arguments accordingly.
The following blocks reproduce the dense baseline and the one-shot pruning baselines from the table above.

**Detailed Instructions: Dense ViT-B/16 (baseline)**

```bash
# Eval
python timm_validate.py --model vit_base_patch16_224.augreg_in1k --pretrained
```

```json
{
    "model": "vit_base_patch16_224.augreg_in1k",
    "top1": 79.158,
    "top1_err": 20.842,
    "top5": 94.088,
    "top5_err": 5.912,
    "param_count": 86.57,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```

**Detailed Instructions: Magnitude 2:4**
```bash
# Magnitude pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner magnitude --save-model output/pruned/vit_base_patch16_224.augreg_in1k.magnitude24.pt

# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.magnitude24.pt --sparsity-mode sparse
```

```json
{
    "model": "vit_base_patch16_224",
    "top1": 65.92,
    "top1_err": 34.08,
    "top5": 86.058,
    "top5_err": 13.942,
    "param_count": 86.57,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```

**Detailed Instructions: Wanda 2:4**
```bash
# Wanda pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner wanda --save-model output/pruned/vit_base_patch16_224.augreg_in1k.wanda24.pt

# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.wanda24.pt --sparsity-mode sparse
```

```json
{
    "model": "vit_base_patch16_224",
    "top1": 63.282,
    "top1_err": 36.718,
    "top5": 85.574,
    "top5_err": 14.426,
    "param_count": 86.57,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```

**Detailed Instructions: SparseGPT 2:4 (with and without weight update)**
```bash
# SparseGPT pruning
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt

# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24.pt --sparsity-mode sparse
```

```json
{
    "model": "vit_base_patch16_224",
    "top1": 59.728,
    "top1_err": 40.272,
    "top5": 82.326,
    "top5_err": 17.674,
    "param_count": 86.57,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```

```bash
# SparseGPT pruning with weight update
python oneshot_pruning_timm.py --model vit_base_patch16_224.augreg_in1k --pruner sparsegpt --save-model output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24_updated.pt --enable-update

# Eval
python timm_validate.py --model vit_base_patch16_224 --checkpoint output/pruned/vit_base_patch16_224.augreg_in1k.sparsegpt24_updated.pt --sparsity-mode sparse
```

```json
{
    "model": "vit_base_patch16_224",
    "top1": 71.52,
    "top1_err": 28.48,
    "top5": 90.246,
    "top5_err": 9.754,
    "param_count": 86.57,
    "img_size": 224,
    "crop_pct": 0.9,
    "interpolation": "bicubic"
}
```
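To sanity-check that a pruned checkpoint really follows the 2:4 pattern, you can count the nonzeros in each group of four weights. Below is a minimal sketch, assuming the checkpoint stores a plain `state_dict` of tensors; adjust the loading code if the checkpoint is wrapped differently.

```python
import torch

def check_24_sparsity(state_dict, m: int = 4, n: int = 2) -> None:
    """Report, per 2-D weight, the fraction of length-m groups with at most n nonzeros."""
    for name, w in state_dict.items():
        if not torch.is_tensor(w) or w.dim() != 2 or w.numel() % m != 0:
            continue
        groups = w.reshape(-1, m)                   # groups of m consecutive weights
        ok = (groups != 0).sum(dim=1) <= n
        print(f"{name}: {ok.float().mean().item() * 100:.2f}% of groups are {n}:{m}")

# Example with a checkpoint path from the commands above (assumes a plain state_dict).
state_dict = torch.load(
    "output/pruned/vit_base_patch16_224.augreg_in1k.magnitude24.pt",
    map_location="cpu")
check_24_sparsity(state_dict)
```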
The following part is still in progress. Please stay tuned.

Please prepare the ImageNet-1k dataset under the `./data/imagenet` directory. The directory structure should be as follows:
```
data
├── imagenet
│   ├── train
│   │   ├── n01440764
│   │   ├── n01443537
│   │   ├── n01484850
│   │   ├── n01491361
│   └── val
│       ├── n01440764
│       ├── n01443537
│       ├── n01484850
│       ├── n01491361
```

### 2.1 MaskLLM for Diffusion Transformers
TODO
The dense sampling and one-shot pruning baselines are available:

```bash
# Sample images with the dense DiT-XL/2
python sample.py --model DiT-XL/2

# One-shot 2:4 pruning with different pruners
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner magnitude
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner wanda
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner sparsegpt
python oneshot_pruning_dit.py --model DiT-XL/2 --pruner sparsegpt --enable-update
```

This project is based on the following repositories:
- [NVlabs/MaskLLM](https://github.com/NVlabs/MaskLLM)
- [huggingface/pytorch-image-models (timm)](https://github.com/huggingface/pytorch-image-models)
- [facebookresearch/DiT](https://github.com/facebookresearch/DiT)
If you find this repository helpful, please consider citing the following paper.
```bibtex
@article{fang2024maskllm,
  title={MaskLLM: Learnable Semi-structured Sparsity for Large Language Models},
  author={Fang, Gongfan and Yin, Hongxu and Muralidharan, Saurav and Heinrich, Greg and Pool, Jeff and Kautz, Jan and Molchanov, Pavlo and Wang, Xinchao},
  journal={arXiv preprint arXiv:2409.17481},
  year={2024}
}
```





