A reproducible training pipeline for 16-class image classification on small datasets without any external pretraining.
It uses self-supervised learning (MoCo-style contrastive learning) as an auxiliary objective and jointly optimizes SSL + classification to improve generalization.
Best run (reference): 86.50% test accuracy using:
- ResNet-18 from scratch (no ImageNet, no foundation models)
- MoCo-style contrastive learning (queue + momentum encoder)
- Joint loss:
L = L_cls + λ * L_ssl
Small datasets love to punish overconfident models. This repo leans on self-supervised representation learning (MoCo) to learn more robust features using only the provided data, then uses a supervised head to do the actual 16-class classification.
No pretraining. No shortcuts.
```
ssl_moco_from_scratch_max_accuracy/
├── main.py
├── inference.py
├── generate_readme_images.py
├── plot.py
├── requirements.txt
├── ssl/
│   ├── moco.py
│   ├── data.py
│   └── utils.py
├── exp/      (created automatically)
├── plots/    (generated by scripts)
└── data/
    ├── train/  (ImageFolder)
    └── test/   (ImageFolder)
```
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
Recommended (reference environment):
- torch 2.4.1
- torchvision 0.19.1
- numpy 1.26.4
- matplotlib 3.9.2
- tqdm 4.65.0
- scikit-learn (for confusion matrix plots)
Standard torchvision ImageFolder layout:

```
data/
├── train/
│   ├── class_0/
│   │   ├── img1.jpg
│   │   └── ...
│   └── class_1/
│       └── ...
└── test/
    ├── class_0/
    ├── class_1/
    └── ...
```
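For reference, ImageFolder derives class indices from this layout by sorting the subfolder names alphabetically and mapping them to 0..N-1. A stdlib-only sketch of that mapping (no torchvision needed to check it):

```python
# Sketch: how torchvision's ImageFolder maps class folders to indices --
# subfolder names, sorted alphabetically, numbered from 0.
from pathlib import Path

def class_to_idx(split_dir: str) -> dict:
    classes = sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())
    return {name: i for i, name in enumerate(classes)}
```

Note that the sort is lexicographic: with folders named `class_0` … `class_15`, `class_10` sorts before `class_2`, so a folder's numeric suffix is not necessarily its class index.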
Defaults match the strongest configuration from the reference run:
- epochs=150, batch=8, lr=0.03, warmup+cosine, momentum=0.9, wd=1e-3
- MoCo dim=128, queue K=8192, m=0.999, T=0.07
- label smoothing=0.1, dropout=0.2
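The warmup+cosine schedule above can be sketched as a plain function of the epoch. The warmup length (5 epochs here) is an assumption for illustration; check `main.py` for the actual value:

```python
# Sketch of warmup + cosine decay for the defaults above (lr=0.03, 150 epochs).
# warmup_epochs=5 is an assumed value, not taken from the repo.
import math

def lr_at(epoch, base_lr=0.03, total_epochs=150, warmup_epochs=5):
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs      # linear warmup
    # cosine decay from base_lr toward 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```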
```bash
python main.py --data-dir data --num-classes 16 \
    --arch resnet18 --epochs 150 --batch-size 8 --lr 0.03 \
    --ssl-weight 1.0 --queue-size 8192 --m 0.999 --t 0.07
```
Outputs (in a timestamped folder under exp/):
- `best_model.pth`
- `training_log.txt`
- `accuracy_plot.png`
```bash
python inference.py --data-dir data --checkpoint exp/<RUN_DIR>/best_model.pth --num-classes 16
```
- Two augmented views per image (query/key)
- Momentum encoder update: θ_k ← m·θ_k + (1 − m)·θ_q
- Queue of negatives to stabilize contrastive learning on small batches
- Cross-entropy (optional label smoothing)
- Dropout before classifier for regularization
L_total = L_cls + λ * L_ssl
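The pieces above can be sketched in plain Python (the repo's `ssl/moco.py` implements them with torch tensors; the names below are illustrative, not the repo's API):

```python
# Minimal sketch of MoCo's momentum update, the fixed-size negative queue,
# and the joint loss. Plain Python scalars stand in for tensors.
from collections import deque

def momentum_update(theta_k, theta_q, m=0.999):
    # θ_k ← m·θ_k + (1 − m)·θ_q, applied parameter-wise
    return [m * k + (1 - m) * q for k, q in zip(theta_k, theta_q)]

class NegativeQueue:
    def __init__(self, K=8192):
        self.buf = deque(maxlen=K)  # FIFO: oldest keys drop out automatically

    def enqueue(self, keys):
        self.buf.extend(keys)

def total_loss(l_cls, l_ssl, lam=1.0):
    # L_total = L_cls + λ·L_ssl
    return l_cls + lam * l_ssl
```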
Overfitting risk: high (tiny dataset).
Mitigations included:
- dropout (default 0.2)
- label smoothing (default 0.1)
- strong augmentation
- weight decay (1e-3)
- SSL auxiliary objective for representation robustness
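Label smoothing in particular can be made concrete. A stdlib sketch mirroring the semantics of `torch.nn.CrossEntropyLoss(label_smoothing=0.1)`, where the one-hot target is mixed with a uniform distribution over the classes:

```python
# Label-smoothed cross-entropy: the true class gets (1 - ε) + ε/N target
# mass, every other class gets ε/N, so confident predictions are penalized.
import math

def smoothed_ce(log_probs, target, smoothing=0.1):
    n = len(log_probs)
    uniform = smoothing / n
    return -sum(
        ((1 - smoothing) + uniform if i == target else uniform) * lp
        for i, lp in enumerate(log_probs)
    )
```

With `smoothing=0.0` this reduces to ordinary cross-entropy; with `smoothing=0.1` the loss floor is strictly positive, which discourages the overconfidence the mitigations above target.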