MNIST SimpleNet with PyTorch

A small example project that trains a simple fully connected neural network on the MNIST handwritten digits dataset using PyTorch.
The project includes:

  • Training and evaluation scripts
  • Automatic download of MNIST
  • Saved model checkpoint
  • Plots for training loss and sample predictions embedded in this README

Project structure

.
├── data.py              # MNIST transforms and DataLoaders
├── model.py             # SimpleNet model definition
├── train.py             # Training loop and history logging
├── eval.py              # Evaluation and plot generation
├── models/              # Saved model checkpoints (created at runtime)
├── assets/              # Training history and plots (created at runtime)
├── requirements.txt
└── README.md

Getting started

Prerequisites

  • Python 3.10 or later
  • pip
  • (Optional) CUDA-capable GPU and drivers for faster training

Install dependencies

git clone <your-repo-url> mnist-simplenet
cd mnist-simplenet

# (Recommended) create and activate a virtual environment
python -m venv .venv
source .venv/bin/activate   # on Windows: .venv\Scripts\activate

pip install -r requirements.txt

The MNIST dataset will be downloaded automatically to a data/ folder when you run the training script.


Training

Run the training script:

python train.py --epochs 5 --batch-size 64

This will:

  • Download MNIST if needed
  • Train SimpleNet for 5 epochs
  • Print epoch loss and accuracy
  • Save the trained model to models/simple_net_model.pth
  • Save training history (loss and accuracy per epoch) to assets/training_history.json
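
Saving the checkpoint and history at the end of training typically uses torch.save on the model's state dict plus a JSON dump (a sketch; the exact history keys recorded by train.py are assumptions, and the values below are illustrative):

```python
import json
from pathlib import Path

import torch
import torch.nn as nn

# Stand-in model with the same shape as SimpleNet
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
history = {"loss": [0.42, 0.31], "accuracy": [0.88, 0.92]}  # per-epoch values (illustrative)

Path("models").mkdir(exist_ok=True)
Path("assets").mkdir(exist_ok=True)

# Save only the learned parameters, not the whole module object
torch.save(model.state_dict(), "models/simple_net_model.pth")
with open("assets/training_history.json", "w") as f:
    json.dump(history, f)
```

Saving the state dict rather than the full model keeps the checkpoint portable across code changes, since loading only requires a module with matching parameter names.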

You can customize training using command line arguments:

python train.py \
  --epochs 10 \
  --batch-size 128 \
  --lr 0.005 \
  --data-root data \
  --models-dir models \
  --assets-dir assets \
  --device auto \
  --seed 123

Parameters:

  • --epochs: number of training epochs
  • --batch-size: batch size
  • --lr: learning rate
  • --data-root: where MNIST is stored
  • --models-dir: directory for checkpoint files
  • --assets-dir: directory for JSON history and plots
  • --device: "auto", "cpu", or "cuda"
  • --seed: random seed for reproducibility
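
Under the hood, each epoch follows the standard PyTorch training pattern. The sketch below shows one epoch with synthetic data standing in for the MNIST loader (function and variable names are assumptions, not the exact code in train.py):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, optimizer, criterion, device="cpu"):
    """Run one pass over the loader; return (mean loss, accuracy)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()          # backpropagate
        optimizer.step()         # update weights
        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen

# Synthetic stand-in for the MNIST training loader
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
data = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64, shuffle=True)
loss, acc = train_one_epoch(model, loader,
                            torch.optim.SGD(model.parameters(), lr=0.01),
                            nn.CrossEntropyLoss())
```

The per-epoch loss and accuracy returned here are what train.py prints and appends to the JSON history.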

Evaluation and plot generation

After training, run:

python eval.py

This will:

  • Load the trained model from models/simple_net_model.pth
  • Evaluate it on the MNIST test set
  • Print the test accuracy
  • Load assets/training_history.json
  • Generate and save:
    • assets/training_loss.png
    • assets/mnist_predictions.png
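
Test accuracy is computed with the usual no-grad evaluation loop, along these lines (a sketch; eval.py's actual helpers may be named differently):

```python
import torch

@torch.no_grad()  # disable gradient tracking during evaluation
def evaluate(model, loader, device="cpu"):
    """Return classification accuracy on the given loader."""
    model.eval()  # e.g. disables dropout/batch-norm updates, if present
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # most likely digit per image
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
    return correct / seen
```

torch.no_grad keeps evaluation memory-light, since no computation graph is built for backpropagation.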

You can also override defaults:

python eval.py \
  --batch-size 128 \
  --data-root data \
  --models-dir models \
  --assets-dir assets \
  --device auto

Model architecture

SimpleNet is a small feedforward network:

  • Input: 28x28 grayscale image (flattened internally to 784 features)
  • Hidden layer: Linear(784, 128) with ReLU activation
  • Output layer: Linear(128, 10) for the 10 digit classes (0 to 9)

This keeps the project lightweight and easy to understand while still demonstrating the full training and evaluation pipeline.
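
The architecture described above corresponds to a model along these lines (a sketch; the class and attribute names in model.py may differ):

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer fully connected network for 28x28 MNIST digits."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()             # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = nn.Linear(28 * 28, 128)      # hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)  # logits for digits 0-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)  # raw logits; CrossEntropyLoss applies softmax
```

Note that the output is raw logits: nn.CrossEntropyLoss expects unnormalized scores, so no softmax is applied in forward.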


Reproducibility

The training script supports a --seed argument and sets seeds for:

  • Python random
  • NumPy
  • PyTorch (CPU and CUDA)

Using the same seed and hyperparameters should give similar results across runs on the same hardware.
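
Seeding all three sources of randomness is commonly done with a small helper like this (a sketch of the usual pattern, not necessarily the exact code in train.py):

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU RNG
    torch.cuda.manual_seed_all(seed)  # all GPUs; a no-op without CUDA

set_seed(123)
```

Note that seeding alone does not guarantee bit-identical results on GPU, since some CUDA kernels are nondeterministic; hence "similar results" rather than identical ones.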


Extending this project

Some ideas for future improvements:

  • Replace SimpleNet with a convolutional neural network for higher accuracy
  • Add a validation split and early stopping
  • Log metrics with TensorBoard or another experiment tracker
  • Add unit tests for the data pipeline and model
  • Add a GitHub Actions workflow to run tests and linting on each push
