This article will guide you through developing a simple image classification model using PyTorch Lightning. Since model development involves reusable building blocks, you can easily adapt this project to suit your specific needs.
When to use deep learning models?
Deep learning models typically outperform traditional machine learning algorithms at feature extraction and semantic understanding. With earlier methods, engineers had to manually define and identify task-specific features. This is especially hard for unstructured data such as text, images, video, or audio. Deep learning models instead learn these features automatically from raw data. We therefore recommend deep learning models for unstructured data, or for any scenario where designing good features by hand is particularly difficult.
Why PyTorch Lightning?
PyTorch Lightning is a lightweight wrapper around PyTorch. It abstracts away much of the boilerplate code involved in training neural networks, such as the training loop, device management, and logging. The result is a more concise, focused codebase that is easier to understand, develop, and maintain.

For example, in plain PyTorch you write every step of the training loop yourself: moving batches to the device, zeroing gradients, backpropagating, stepping the optimizer, and logging. In Lightning, you implement training_step and configure_optimizers, and the Trainer runs the loop for you.
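For comparison, here is a minimal sketch of the hand-written loop that plain PyTorch requires. The function name train_plain_pytorch and the random MNIST-shaped tensors are hypothetical and for illustration only; they are not part of this project. Each commented step is something Lightning's Trainer handles once training_step and configure_optimizers are defined.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train_plain_pytorch(model, loader, epochs=1, lr=1e-3):
    """Hypothetical hand-written training loop; returns the mean training loss per epoch."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # device management: done by hand
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        running = 0.0
        for x, y in loader:
            x, y = x.to(device), y.to(device)  # move every batch to the device
            optimizer.zero_grad()  # clear gradients from the previous step
            loss = loss_fn(model(x), y)
            loss.backward()  # backpropagation
            optimizer.step()  # parameter update
            running += loss.item() * x.size(0)
        epoch_losses.append(running / len(loader.dataset))
        print(f"epoch {epoch}: train_loss={epoch_losses[-1]:.4f}")  # logging: also by hand
    return epoch_losses


if __name__ == "__main__":
    # Random tensors shaped like MNIST stand in for the real dataset (assumption for the demo).
    torch.manual_seed(0)
    data = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    train_plain_pytorch(model, DataLoader(data, batch_size=32, shuffle=True), epochs=2)
```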
The standard development workflow
Problem formulation: Understanding inputs and outputs, along with the key characteristics of the target mapping function, such as producing interpretable outputs.
Model selection: Based on the problem definition, select an appropriate model from a model zoo. If no model matches the task exactly, consider models designed for a similar task. For ML beginners, resources like Hugging Face and Papers with Code offer a comprehensive overview of machine learning tasks, models, and datasets.
Model development: Implement the Lightning model (a LightningModule) and the Lightning data module (a LightningDataModule).
Training & evaluation: Train and evaluate the model, logging the training process with TensorBoard.
Debugging & optimization: Find and resolve the main training issues that affect models’ performance and reliability. For systematic guidance, you can check Debugging Model Training.
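For the problem-formulation step, the input/output contract of MNIST classification can be written down as a shape check that any candidate model must pass: single-channel 28x28 images in, 10 class scores out. This is a sketch only. check_io_contract and the placeholder linear model are hypothetical names, and the default arguments assume MNIST.

```python
import torch
from torch import nn


def check_io_contract(model, batch_size=8, channels=1, height=28, width=28, num_classes=10):
    """Hypothetical helper: check that a model maps an image batch to one logit per class.

    The defaults describe MNIST: single-channel 28x28 images and 10 digit classes.
    """
    model.eval()
    with torch.no_grad():
        x = torch.randn(batch_size, channels, height, width)  # dummy input batch
        logits = model(x)
    expected = (batch_size, num_classes)
    if tuple(logits.shape) != expected:
        raise ValueError(f"expected output shape {expected}, got {tuple(logits.shape)}")
    return tuple(logits.shape)


if __name__ == "__main__":
    # A deliberately tiny placeholder model; swap in the real model once it exists.
    baseline = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    print(check_io_contract(baseline))  # prints: (8, 10)
```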
Some tips for model training
Use pre-trained weights instead of training a model from scratch.
Use a three-stage process to debug model training:
Model issues: Ensure the model has sufficient capacity to capture the data distribution. Monitor the training loss: if it plateaus at a high value, the model is likely underfitting.
Data issues: Ensure the training data is sufficiently representative. For instance, if the data is only a small subset of the actual distribution, it can result in a significant performance disparity between training and testing.
Overfitting issues: Check whether the model is memorizing nuances or noise in the training data by comparing the training loss with the validation loss. A training loss that keeps falling while the validation loss rises is the typical sign. To address overfitting, apply regularization techniques such as L1 or L2 regularization, dropout, data augmentation, or batch normalization.
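The three stages above can be turned into a rough first-pass triage of per-epoch loss curves. This is a heuristic sketch, not a standard algorithm. The function name diagnose and every threshold (high_loss, gap_ratio, rise_tol) are illustrative assumptions that depend on the loss scale of your task. It checks for a rising validation curve before checking the train/validation gap, because a rising curve is the more specific signal.

```python
def diagnose(train_losses, val_losses, high_loss=0.5, gap_ratio=2.0, rise_tol=0.05):
    """Heuristic triage of per-epoch loss curves into the three debugging stages.

    Hypothetical helper: every threshold is an illustrative assumption and should be
    tuned to the loss scale of your task.
    """
    if not train_losses or len(train_losses) != len(val_losses):
        raise ValueError("expected one training and one validation loss per epoch")
    final_train, final_val = train_losses[-1], val_losses[-1]
    best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
    best_val = val_losses[best_epoch]

    # Stage 1 (model issue): the model cannot even fit the training data.
    if final_train > high_loss:
        return "model"
    # Stage 3 (overfitting): validation loss climbs back up from its best value
    # while the training loss keeps falling.
    if final_val > best_val * (1 + rise_tol) and final_train < train_losses[best_epoch]:
        return "overfitting"
    # Stage 2 (data issue): a large gap without a rising validation curve can mean the
    # training data does not represent the evaluation distribution well.
    if final_val > gap_ratio * final_train:
        return "data"
    return "ok"


if __name__ == "__main__":
    print(diagnose([1.0, 0.4, 0.2, 0.1], [1.0, 0.5, 0.6, 0.8]))  # prints: overfitting
```

The thresholds only flag where to look first. The actual fix still follows the checklist above.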
Project structure
```
.
├── configs            # experiment configs
│   └── resnet.yaml    # train resnet on MNIST
├── datasets.py        # dataset module
├── models.py          # lightning model module
├── main.py            # profiler, training & evaluation
└── readme.md          # notes
```
Model definition
```python
"""
models.py
"""
import lightning as L
from torch import nn, optim
from torch.nn import functional as F
from torchvision.models import ResNet18_Weights, resnet18


# Define the LightningModule for ResNet18
class LitResNet18(L.LightningModule):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.save_hyperparameters()  # record num_classes in checkpoints
        # Load ImageNet-pretrained ResNet18 (the weights API replaces the deprecated pretrained=True)
        self.model = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Adjust for single-channel MNIST input; this new layer is randomly initialized,
        # so the pretrained weights of the original conv1 are not used
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace the ImageNet classifier head with one output per MNIST class
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch  # labels are not needed for prediction
        return self(x)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)
```
Data module
```python
"""
datasets.py
"""
import lightning as L
import torch
from torch.utils.data import DataLoader, random_split

# Note - you must have torchvision installed for this example
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Convert images to tensors in [0, 1], then normalize with the MNIST mean and std
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # Download the train and test splits once
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders (also needed by the validate stage)
        if stage in ("fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Reshuffle the training set every epoch
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished
        self.mnist_train = None
        self.mnist_val = None
        self.mnist_test = None
        self.mnist_predict = None
```
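The Normalize constants (0.1307,) and (0.3081,) are the mean and standard deviation of the MNIST training pixels after ToTensor scales them to [0, 1]. For a different dataset, the values can be recomputed with a sketch like the following. channel_mean_std is a hypothetical helper, not a torchvision API. It computes population statistics over every pixel, accumulating in float64.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(dataset, batch_size=1024):
    """Hypothetical helper: per-channel mean and std over all pixels of a dataset of (C, H, W) images."""
    loader = DataLoader(dataset, batch_size=batch_size)
    total_pixels = 0
    sum_ = sum_sq = None
    for x, _ in loader:
        x = x.double()  # float64 avoids precision loss over tens of millions of pixels
        batch_sum = x.sum(dim=(0, 2, 3))  # collapse batch and spatial dims -> (C,)
        batch_sq = (x * x).sum(dim=(0, 2, 3))
        sum_ = batch_sum if sum_ is None else sum_ + batch_sum
        sum_sq = batch_sq if sum_sq is None else sum_sq + batch_sq
        total_pixels += x.shape[0] * x.shape[2] * x.shape[3]
    mean = sum_ / total_pixels
    std = (sum_sq / total_pixels - mean ** 2).clamp(min=0).sqrt()  # Var = E[x^2] - E[x]^2
    return mean.float(), std.float()
```

Applied to MNIST(data_dir, train=True, transform=transforms.ToTensor()), this should give values close to the constants used in MNISTDataModule.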
Train and evaluate the model
PyTorch Lightning's LightningCLI lets you train models directly from YAML config files. We follow this approach and use a config file to manage each experiment's hyperparameters.
```yaml
# configs/resnet.yaml
model: LitResNet18
trainer:
  max_epochs: 10
```
```python
"""
main.py
"""
from lightning.pytorch.cli import LightningCLI

from datasets import MNISTDataModule
from models import LitResNet18  # noqa: F401 (the import makes LitResNet18 selectable via --model or the config)


def cli_main():
    # No model_class is given, so the model is chosen at run time via --model or the config file
    LightningCLI(datamodule_class=MNISTDataModule)


if __name__ == "__main__":
    cli_main()
```
```bash
# train a ResNet18 on MNIST
python -W ignore main.py fit --config configs/resnet.yaml
# test a trained checkpoint (replace xxx with the path to a .ckpt file)
python -W ignore main.py test --model=LitResNet18 --ckpt_path xxx
```
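If the checkpoint passed to --ckpt_path should be the best one rather than simply the latest, note that ModelCheckpoint without a monitored metric keeps only the checkpoint from the most recent epoch. The variant of configs/resnet.yaml below is a sketch, assuming models.py is importable from the working directory. It names the model by import path, sets the data directory, and adds a ModelCheckpoint callback that keeps only the checkpoint with the lowest val_loss. The keys follow the LightningCLI YAML format, but check them against your installed Lightning version.

```yaml
# configs/resnet.yaml (more explicit variant; a sketch)
model:
  class_path: models.LitResNet18
  init_args:
    num_classes: 10
data:
  data_dir: ./data
trainer:
  max_epochs: 10
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss  # logged in validation_step
        mode: min
        save_top_k: 1
```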
Monitor the training process
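By default, PyTorch Lightning records the values logged with self.log (e.g., self.log("train_loss", loss)) in TensorBoard format, provided the tensorboard package is installed. The logs are written under lightning_logs/ in the working directory. After training, you can inspect them with the following command:
```bash
tensorboard --logdir .
```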
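Checkpoints end up next to these logs. With the default logger and checkpoint callback, Lightning saves them under lightning_logs/version_<n>/checkpoints/ in the working directory. The test command needs a concrete checkpoint path. A small helper such as the hypothetical latest_checkpoint below (not part of Lightning) can pick the most recently written .ckpt file under that default layout.

```python
from pathlib import Path


def latest_checkpoint(log_root="lightning_logs"):
    """Hypothetical helper: return the most recently written .ckpt under log_root, or None.

    Assumes Lightning's default layout: <log_root>/version_<n>/checkpoints/*.ckpt
    """
    ckpts = list(Path(log_root).glob("version_*/checkpoints/*.ckpt"))
    if not ckpts:
        return None
    return str(max(ckpts, key=lambda p: p.stat().st_mtime))  # newest by modification time


if __name__ == "__main__":
    print(latest_checkpoint())
```

The printed path can then be passed to the test command as --ckpt_path.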
GitHub
All the code and examples are available on GitHub.
Miscellaneous
PyTorch Lightning is a widely used framework built on PyTorch that simplifies training by abstracting away engineering boilerplate. It lets researchers concentrate on model design rather than on the engineering details of the training loop.
Conclusion
This guide walks through image classification model development with PyTorch Lightning. It emphasizes structured workflows (modular code, systematic debugging), practical adaptation of pre-trained models (e.g., ResNet18 for MNIST), and extensibility to future tasks. By abstracting away boilerplate code and integrating tools like TensorBoard, PyTorch Lightning speeds up experimentation, and config-driven runs make experiments easier to reproduce. This flexibility supports rapid customization and scaling, making the framework well suited both to prototyping and to building robust deep learning solutions.