Training with PyTorch Lightning

PyTorch Lightning removes boilerplate from your training code while giving you full control over the model, optimizer, and data flow. Combined with Deeplake's streaming DataLoader, you get a clean, scalable training pipeline that streams data directly from managed tables, with no local downloads required.

Objective

Train a ResNet-18 image classifier on Fashion MNIST using PyTorch Lightning, with data streamed from a Deeplake managed table.

Prerequisites

  • Deeplake SDK: pip install deeplake
  • PyTorch and Lightning: pip install torch torchvision pytorch-lightning
  • A Deeplake API token.

Set credentials first

export DEEPLAKE_API_KEY="your-token-here"
export DEEPLAKE_WORKSPACE="your-workspace"  # optional, defaults to "default"
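
In a training script it can help to fail fast if the token is missing rather than erroring mid-run. A minimal sketch that mirrors the defaults above (load_deeplake_config is a hypothetical helper, not part of the SDK):

```python
import os

def load_deeplake_config(env=os.environ):
    """Read Deeplake credentials from the environment.

    Raises early if the token is missing; the workspace falls back to
    "default", matching the behavior described above.
    """
    api_key = env.get("DEEPLAKE_API_KEY")
    if not api_key:
        raise RuntimeError("Set DEEPLAKE_API_KEY before running the training script")
    return api_key, env.get("DEEPLAKE_WORKSPACE", "default")
```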

Complete Code

import torch
from torch.utils.data import DataLoader
from torchvision import models
import pytorch_lightning as pl
from deeplake import Client

# 1. Setup
client = Client()
client.ingest("fashion_mnist", {"_huggingface": "fashion_mnist"})

# 2. Create DataLoaders
ds = client.open_table("fashion_mnist")

# For a real project, ingest train and test splits as separate tables.
train_loader = DataLoader(ds.pytorch(), batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(ds.pytorch(), batch_size=32, num_workers=4)

# 3. Define the LightningModule
class FashionClassifier(pl.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.model = models.resnet18(weights=None)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"].squeeze()
        preds = self(images)
        loss = self.loss_fn(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"].squeeze()
        preds = self(images).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        self.log("val_acc", acc, prog_bar=True)

# 4. Train
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model=FashionClassifier(), train_dataloaders=train_loader, val_dataloaders=val_loader)

Step-by-Step Breakdown

1. Ingest the Dataset

client.ingest("fashion_mnist", {"_huggingface": "fashion_mnist"})

Deeplake pulls Fashion MNIST directly from Hugging Face and stores it as a managed table. This only needs to run once; subsequent runs can skip this step and go straight to open_table().
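
One way to make the script safe to rerun is to wrap the ingest in a small helper. This is a sketch under an assumption: that open_table raises an exception when the table does not exist (ensure_ingested is a hypothetical helper, not an SDK call):

```python
def ensure_ingested(client, table_name, source):
    """Open table_name if it already exists; otherwise ingest it first.

    Assumes client.open_table raises when the table is missing.
    """
    try:
        return client.open_table(table_name)
    except Exception:
        client.ingest(table_name, source)
        return client.open_table(table_name)
```

With this in place, ds = ensure_ingested(client, "fashion_mnist", {"_huggingface": "fashion_mnist"}) works on both first and repeat runs.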

2. Create DataLoaders

ds = client.open_table("fashion_mnist")

train_loader = DataLoader(ds.pytorch(), batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(ds.pytorch(), batch_size=32, num_workers=4)

ds.pytorch() returns a map-style dataset that plugs directly into a standard PyTorch DataLoader. Each batch is a dictionary with keys matching the table columns (image, label). Deeplake handles streaming and decompression behind the scenes.
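
The dictionary batching itself is standard PyTorch behavior: the default collate function stacks each key across the samples in a batch. A self-contained sketch with a stand-in dataset (DictDataset is illustrative, shaped like the Fashion MNIST table):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    """Stand-in for ds.pytorch(): each item is a dict of tensors,
    mirroring the table's columns."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"image": torch.zeros(1, 28, 28), "label": torch.tensor(idx % 10)}

loader = DataLoader(DictDataset(), batch_size=4)
batch = next(iter(loader))
print(sorted(batch.keys()))   # ['image', 'label']
print(batch["image"].shape)   # torch.Size([4, 1, 28, 28])
```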

3. Define the LightningModule

The FashionClassifier wraps a ResNet-18 backbone adapted for single-channel grayscale images. Lightning's structure keeps the training logic organized:

  • training_step: computes cross-entropy loss and logs accuracy.
  • validation_step: tracks validation accuracy for monitoring.
  • configure_optimizers: returns a standard SGD optimizer.
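
The loss and accuracy math inside training_step can be checked in isolation on random tensors; this sketch reproduces those two lines outside of Lightning:

```python
import torch

torch.manual_seed(0)
preds = torch.randn(32, 10)            # stand-in model outputs (logits)
labels = torch.randint(0, 10, (32,))   # stand-in batch labels

# Same computations as in training_step:
loss = torch.nn.CrossEntropyLoss()(preds, labels)
acc = (preds.argmax(dim=-1) == labels).float().mean()
# loss is a scalar tensor; acc is the fraction of correct predictions in [0, 1]
```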

4. Train with Lightning Trainer

trainer = pl.Trainer(max_epochs=3)
trainer.fit(model=FashionClassifier(), train_dataloaders=train_loader, val_dataloaders=val_loader)

Lightning handles the epoch loop, metric logging, device placement, and checkpointing. If a GPU is available, Trainer will use it automatically.

What to try next