Training with PyTorch Lightning¶
PyTorch Lightning removes boilerplate from your training code while giving you full control over the model, optimizer, and data flow. By combining it with Deeplake's streaming DataLoader, you get a clean, scalable training pipeline that streams data directly from managed tables. No local downloads required.
Objective¶
Train a ResNet-18 image classifier on Fashion MNIST using PyTorch Lightning, with data streamed from a Deeplake managed table.
Prerequisites¶
- Deeplake SDK: pip install deeplake
- PyTorch and Lightning: pip install torch torchvision pytorch-lightning
- A Deeplake API token. Set credentials first.
Complete Code¶
import torch
from torch.utils.data import DataLoader
from torchvision import models
import pytorch_lightning as pl
from deeplake import Client

# 1. Setup
client = Client()
client.ingest("fashion_mnist", {"_huggingface": "fashion_mnist"})

# 2. Create DataLoaders
ds = client.open_table("fashion_mnist")
# For a real project, ingest train and test splits as separate tables.
train_loader = DataLoader(ds.pytorch(), batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(ds.pytorch(), batch_size=32, num_workers=4)

# 3. Define the LightningModule
class FashionClassifier(pl.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.model = models.resnet18(weights=None)
        # Replace the first conv layer so the network accepts
        # single-channel grayscale input instead of RGB.
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.1)

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"].squeeze()
        preds = self(images)
        loss = self.loss_fn(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"].squeeze()
        preds = self(images).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        self.log("val_acc", acc, prog_bar=True)

# 4. Train
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model=FashionClassifier(), train_dataloaders=train_loader, val_dataloaders=val_loader)
Step-by-Step Breakdown¶
1. Ingest the Dataset¶
Deeplake pulls Fashion MNIST directly from Hugging Face and stores it as a managed table. This only needs to run once; subsequent runs can skip this step and go straight to open_table().
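These are the setup lines from the complete listing above (they require a Deeplake account with credentials configured):

```python
from deeplake import Client

client = Client()
# Pull Fashion MNIST from Hugging Face into a managed table (one-time step).
client.ingest("fashion_mnist", {"_huggingface": "fashion_mnist"})
```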
2. Create DataLoaders¶
ds = client.open_table("fashion_mnist")
train_loader = DataLoader(ds.pytorch(), batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(ds.pytorch(), batch_size=32, num_workers=4)
ds.pytorch() returns a map-style dataset that plugs directly into a standard PyTorch DataLoader. Each batch is a dictionary with keys matching the table columns (image, label). Deeplake handles streaming and decompression behind the scenes.
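To make the batch structure concrete, here is a minimal pure-Python sketch of the map-style protocol: DictDataset is a stand-in for what ds.pytorch() returns (not the actual Deeplake class), and collate mimics how DataLoader's default collate_fn batches dict samples field-wise, so column names survive as batch keys.

```python
class DictDataset:
    """Stand-in for a map-style dataset: defines __len__ and __getitem__,
    and each sample is a dict keyed by column name."""
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "label": self.labels[idx]}

def collate(samples):
    # Batch a list of dict samples field-wise, like the default collate_fn.
    return {key: [s[key] for s in samples] for key in samples[0]}

ds = DictDataset(images=["img0", "img1"], labels=[3, 7])
batch = collate([ds[0], ds[1]])
print(batch)  # {'image': ['img0', 'img1'], 'label': [3, 7]}
```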
3. Define the LightningModule¶
The FashionClassifier wraps a ResNet-18 backbone adapted for single-channel grayscale images. Lightning's structure keeps the training logic organized:
- training_step: computes cross-entropy loss and logs accuracy.
- validation_step: tracks validation accuracy for monitoring.
- configure_optimizers: returns a standard SGD optimizer.
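The accuracy metric used in both steps reduces to a simple argmax comparison; this plain-Python sketch mirrors `(preds.argmax(dim=-1) == labels).float().mean()` without needing tensors:

```python
def accuracy(logits, labels):
    # Take the argmax over class scores for each sample,
    # then return the fraction of predictions that match the labels.
    preds = [row.index(max(row)) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

logits = [[0.1, 2.0, 0.3],   # argmax -> class 1
          [1.5, 0.2, 0.1]]   # argmax -> class 0
print(accuracy(logits, [1, 0]))  # 1.0
```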
4. Train with Lightning Trainer¶
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model=FashionClassifier(), train_dataloaders=train_loader, val_dataloaders=val_loader)
Lightning handles the epoch loop, metric logging, device placement, and checkpointing. If a GPU is available, Trainer will use it automatically.
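For intuition, here is a rough sketch of the loop trainer.fit() drives for you; the real Trainer additionally handles device placement, gradient context management, logging, and checkpointing:

```python
def fit(model, train_loader, val_loader, max_epochs):
    # Simplified view of one fit() call: Lightning asks the module for its
    # optimizer, then alternates training and validation loops each epoch.
    optimizer = model.configure_optimizers()
    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(train_loader):
            loss = model.training_step(batch, batch_idx)
            loss.backward()        # Lightning calls backward for you
            optimizer.step()
            optimizer.zero_grad()
        for batch_idx, batch in enumerate(val_loader):
            model.validation_step(batch, batch_idx)  # metrics only, no grads
```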
What to try next¶
- GPU-Streaming Pipeline: learn about direct-to-GPU data streaming.
- Massive Ingestion: prepare large-scale datasets for training.
- Reference: Querying: details on open_table().