Fine-grained Pruning: PyTorch Implementation

This repo implements magnitude-based fine-grained (unstructured) pruning on a VGG-style CNN trained on CIFAR-10.

It includes:

  • A VGG model (conv-bn-relu blocks + classifier)
  • CIFAR-10 dataloaders + transforms
  • Helpers for accuracy, sparsity, parameter counts, model size (bits/MiB)
  • Magnitude-based fine-grained (unstructured) pruning (binary mask + in-place pruning)
  • Scripts that reproduce the key lab steps and generate figures for reports
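As a sketch of what the size/parameter helpers compute (function names here are illustrative, not necessarily the repo's actual API in `utils/metrics.py`):

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    # Total number of dense parameter elements in the model.
    return sum(p.numel() for p in model.parameters())

def model_size_mib(model: nn.Module, bits_per_param: int = 32) -> float:
    # Model size assuming every parameter is stored at the given bit width.
    return count_parameters(model) * bits_per_param / 8 / 2**20
```

For example, an `nn.Linear(10, 10)` has 100 weights plus 10 biases, so `count_parameters` returns 110.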

Repo Structure

fine_grained_pruning_lab/
  assets/                      # put images here
  data/
    cifar10.py                 # dataloaders/transforms
  models/
    vgg.py                     # VGG used in the lab
  pruning/
    fine_grained.py            # magnitude pruning + model helpers
  scripts/
    evaluate_dense.py          # eval dense checkpoint
    plot_weight_distribution.py # histogram plots per layer
    demo_prune_tensor.py       # visualize pruning on a toy tensor
    prune_one_layer.py         # prune a single conv/fc layer by index
    prune_sweep.py             # sweep sparsity and save curve (no README image)
  tests/
    test_fine_grained.py
    test_vgg_shapes.py
  utils/
    metrics.py                 # accuracy/sparsity/model size
    profiling.py               # MACs (torchprofile)
    seed.py                    # deterministic seeds
    viz.py                     # plotting helpers
  requirements.txt

Installation

python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

Dataset

By default, scripts download CIFAR-10 into ./data/cifar10.
Override with --data-dir.


Pretrained Checkpoint

The lab expects a pretrained checkpoint, for example:

  • model_199-1.tar

Put it somewhere (example: ./checkpoints/model_199-1.tar) and pass via --ckpt.


1) Evaluate Dense Model

python -m scripts.evaluate_dense --ckpt checkpoints/model_199-1.tar

Example output:

dense accuracy: 88.42%
dense model size: 35.19 MiB (assuming fp32)
MACs (1x3x32x32): 314,572,800

Accuracy depends on the checkpoint you use.
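A minimal sketch of the accuracy evaluation the script performs (the loop below is illustrative; the repo's actual helper may differ in signature):

```python
import torch

@torch.inference_mode()
def evaluate(model, dataloader, device="cpu") -> float:
    # Top-1 accuracy in percent over a dataloader of (inputs, targets) batches.
    model.eval()
    correct = total = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    return 100.0 * correct / total
```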


2) Weight Distribution (Dense Model)

Generate a histogram grid across weight tensors (Conv/Linear weights).

python -m scripts.plot_weight_distribution \
  --ckpt checkpoints/model_199-1.tar \
  --out assets/weight_histogram.png

Example output:

Saved: assets/weight_histogram.png

Weight Histogram


Magnitude-based Fine-grained Pruning

The pruning function is magnitude-based:

  • Step 1: compute num_zeros = round(sparsity * num_elements)
  • Step 2: importance = abs(weight)
  • Step 3: threshold via kthvalue(num_zeros)
  • Step 4: mask importance > threshold
  • Step 5: apply mask in-place
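The steps above can be sketched as follows (a simplified version; the repo's actual `fine_grained.py` may handle edge cases differently):

```python
import torch

def fine_grained_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero out the smallest-magnitude elements in place; return the binary mask."""
    sparsity = min(max(sparsity, 0.0), 1.0)
    num_elements = tensor.numel()
    num_zeros = round(sparsity * num_elements)       # Step 1
    if num_zeros == 0:
        return torch.ones_like(tensor)
    importance = tensor.abs()                        # Step 2
    # Step 3: the k-th smallest magnitude becomes the pruning threshold
    threshold, _ = importance.flatten().kthvalue(num_zeros)
    mask = importance > threshold                    # Step 4
    tensor.mul_(mask)                                # Step 5: apply in place
    return mask
```

Note that for a 5x5 tensor at sparsity 0.75, `num_zeros = round(18.75) = 19`, so the achieved sparsity is 19/25 = 0.76, which matches the demo output below.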

Verify pruning on a dummy tensor (sparsity = 0.75)

Generate the “dense vs sparse tensor” plot used for the verification demo:

python -m scripts.demo_prune_tensor \
  --sparsity 0.75 \
  --out assets/pruning_q2_tensor.png

Expected output (matches the lab demo):

* Test fine_grained_prune()
    target sparsity: 0.75
        sparsity before pruning: 0.04
        sparsity after pruning: 0.76
        sparsity of pruning mask: 0.76
* Test passed.

Pruning Tensor (0.75 sparsity)


Set Sparsity So Only 10 Nonzeros Remain

The tensor is 5×5 → 25 total elements.
To keep 10 nonzeros, prune 15 weights:

  • target sparsity = (25 - 10) / 25 = 0.60
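The same arithmetic in code (a trivial helper, named here only for illustration):

```python
def sparsity_for_nonzeros(num_elements: int, num_nonzeros: int) -> float:
    # Fraction of elements to zero out so exactly num_nonzeros survive.
    return (num_elements - num_nonzeros) / num_elements
```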

Generate the plot for this case:

python -m scripts.demo_prune_tensor \
  --sparsity 0.60 \
  --out assets/pruning_q3_tensor.png

Expected output (matches the lab target):

* Test fine_grained_prune()
    target sparsity: 0.60
        sparsity before pruning: 0.04
        sparsity after pruning: 0.60
        sparsity of pruning mask: 0.60
* Test passed.

Pruning Tensor (10 nonzeros)


Prune One Layer of VGG (Layer Index)

Prune a single prune-eligible layer by index (parameters with dim() > 1, i.e. conv/linear weights) and evaluate accuracy.
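Selecting prune-eligible parameters by the dim() > 1 rule can be sketched like this (the toy model below is illustrative, not the repo's VGG):

```python
import torch.nn as nn

def prunable_weights(model: nn.Module):
    # Conv/Linear weights have dim() > 1; biases and BatchNorm params do not.
    return [(name, p) for name, p in model.named_parameters() if p.dim() > 1]

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),          # weight dim = 4 -> prunable
    nn.BatchNorm2d(8),           # weight/bias dim = 1 -> skipped
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),  # weight dim = 2 -> prunable
)
print([name for name, _ in prunable_weights(model)])
# -> ['0.weight', '3.weight']
```

`--layer-index` then indexes into this filtered list rather than into all parameters.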

python -m scripts.prune_one_layer \
  --ckpt checkpoints/model_199-1.tar \
  --layer-index 1 \
  --sparsity 0.75

Example output:

baseline accuracy: 88.42%
pruned layer: 1 | backbone.conv1.weight
layer sparsity: before=0.000 after=0.752 (target=0.750)
after-prune accuracy: 72.18%

Tests

pytest -q

Example output:

2 passed in 1.02s

Notes

  • Fine-grained pruning is unstructured sparsity: it reduces parameter count, but speedups usually require sparse kernels/runtimes.
  • Magnitude pruning is a strong baseline because many trained weights cluster near zero and can be removed with little accuracy loss.

References

  • Han et al., 2015: Learning both Weights and Connections for Efficient Neural Networks
  • CIFAR-10 dataset (Krizhevsky)
  • PyTorch docs
