ak811/moco-joint-ssl-training
SSL-First Image Classification (MoCo + Joint Optimization)

A reproducible training pipeline for 16-class image classification on small datasets without any external pretraining.
It uses self-supervised learning (MoCo-style contrastive learning) as an auxiliary objective and jointly optimizes SSL + classification to improve generalization.

Best run (reference): 86.50% test accuracy using:

  • ResNet-18 from scratch (no ImageNet, no foundation models)
  • MoCo-style contrastive learning (queue + momentum encoder)
  • Joint loss: L = L_cls + λ * L_ssl

Why this exists

Small datasets love to punish overconfident models. This repo leans on self-supervised representation learning (MoCo) to learn more robust features using only the provided data, then uses a supervised head to do the actual 16-class classification.

No pretraining. No shortcuts.


Project layout

ssl_moco_from_scratch_max_accuracy/
  main.py
  inference.py
  generate_readme_images.py
  plot.py
  requirements.txt
  ssl/
    moco.py
    data.py
    utils.py
  exp/
    (created automatically)
  plots/
    (generated by scripts)
  data/
    train/   (ImageFolder)
    test/    (ImageFolder)

Setup

python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

Recommended (reference environment):

  • torch 2.4.1
  • torchvision 0.19.1
  • numpy 1.26.4
  • matplotlib 3.9.2
  • tqdm 4.65.0
  • scikit-learn (for confusion matrix plots)

Data format

Standard torchvision ImageFolder:

data/
  train/
    class_0/
      img1.jpg
      ...
    class_1/
      ...
  test/
    class_0/
    class_1/
    ...

Training (MoCo + supervised joint optimization)

Defaults match the strongest configuration from the report-style run:

  • epochs=150, batch=8, lr=0.03, warmup+cosine, momentum=0.9, wd=1e-3
  • MoCo dim=128, queue K=8192, m=0.999, T=0.07
  • label smoothing=0.1, dropout=0.2

python main.py --data-dir data --num-classes 16 \
  --arch resnet18 --epochs 150 --batch-size 8 --lr 0.03 \
  --ssl-weight 1.0 --queue-size 8192 --m 0.999 --t 0.07
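The warmup+cosine schedule listed above can be sketched as a per-epoch learning-rate rule. The warmup length here (5 epochs) is an assumption for illustration, not a value read from main.py:

```python
import math

def lr_at(epoch, total_epochs=150, base_lr=0.03, warmup_epochs=5):
    """Linear warmup, then cosine decay toward zero (illustrative)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine anneal over the remaining epochs.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
```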

Outputs (in a timestamped folder under exp/):

  • best_model.pth
  • training_log.txt
  • accuracy_plot.png

Inference (evaluate a checkpoint)

python inference.py --data-dir data --checkpoint exp/<RUN_DIR>/best_model.pth --num-classes 16

Training recipe

SSL (MoCo-style)

  • Two augmented views per image (query/key)
  • Momentum encoder updates: θ_k ← m θ_k + (1-m) θ_q
  • Queue of negatives to stabilize contrastive learning on small batches
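The momentum update and queue above can be sketched as follows (function names and the (dim, K) queue layout are illustrative; the repo's implementation lives in `ssl/moco.py`):

```python
import torch

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m=0.999):
    """θ_k ← m·θ_k + (1-m)·θ_q — the key encoder slowly trails the query encoder."""
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.data.mul_(m).add_(p_q.data, alpha=1 - m)

@torch.no_grad()
def dequeue_and_enqueue(queue, ptr, keys):
    """Overwrite the oldest negatives in a (dim, K) queue with the new keys."""
    K = queue.shape[1]
    b = keys.shape[0]
    queue[:, ptr:ptr + b] = keys.T  # assumes K is divisible by the batch size
    return (ptr + b) % K
```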

Supervised head

  • Cross-entropy (optional label smoothing)
  • Dropout before classifier for regularization
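In PyTorch, that head amounts to something like the sketch below, using ResNet-18's 512-d feature and the defaults listed above (exact module names in main.py may differ):

```python
import torch
import torch.nn as nn

# Dropout before the linear classifier; smoothing is handled by the loss.
classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(512, 16))
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

features = torch.randn(8, 512)      # backbone output for a batch of 8
labels = torch.randint(0, 16, (8,))
loss_cls = criterion(classifier(features), labels)
```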

Joint objective

L_total = L_cls + λ * L_ssl
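Putting it together: a sketch of the MoCo InfoNCE term combined with the classification loss. λ corresponds to --ssl-weight (1.0 in the reference run); the cross-entropy-over-logits formulation follows the MoCo paper, but names and shapes here are illustrative:

```python
import torch
import torch.nn.functional as F

def moco_infonce(q, k, queue, T=0.07):
    """InfoNCE: positive = matching key, negatives = the queue's columns."""
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (B, 1) positive similarities
    l_neg = q @ queue                         # (B, K) negative similarities
    logits = torch.cat([l_pos, l_neg], dim=1) / T
    labels = torch.zeros(q.shape[0], dtype=torch.long)  # positive at index 0
    return F.cross_entropy(logits, labels)

def joint_loss(loss_cls, loss_ssl, ssl_weight=1.0):
    # L_total = L_cls + λ·L_ssl — both terms share one backward pass.
    return loss_cls + ssl_weight * loss_ssl
```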


Drawbacks and mitigations

Overfitting risk: high (tiny dataset).
Mitigations included:

  • dropout (default 0.2)
  • label smoothing (default 0.1)
  • strong augmentation
  • weight decay (1e-3)
  • SSL auxiliary objective for representation robustness
