Training Image Classification in PyTorch¶
Deeplake makes it easy to train image classification models by streaming data from managed tables directly into a PyTorch training loop, with no intermediate download step.
Objective¶
Ingest the Fashion MNIST dataset from HuggingFace into a Deeplake managed table, then train a ResNet18 model by streaming data with a PyTorch DataLoader.
Prerequisites¶
- Deeplake SDK: pip install deeplake
- PyTorch and torchvision: pip install torch torchvision
- A Deeplake API token (set your credentials before running the code below).
Complete Code¶
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms

from deeplake import Client

# --- Configuration ---
TABLE_NAME = "fashion_mnist"
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# --- 1. Ingest the Dataset from HuggingFace ---
client = Client()
client.ingest(TABLE_NAME, {"_huggingface": "fashion_mnist"})

# --- 2. Create the DataLoader ---
ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    sample["image"] = tform(sample["image"])
    return sample

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# --- 3. Define the Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 Fashion MNIST classes
model = model.to(device)

# --- 4. Training Loop ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, batch in enumerate(train_loader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if i % 100 == 99:
            print(f"Epoch {epoch+1}, Batch {i+1}: "
                  f"Loss={running_loss/100:.3f}, "
                  f"Acc={100.*correct/total:.1f}%")
            running_loss = 0.0

    print(f"Epoch {epoch+1} complete - Accuracy: {100.*correct/total:.1f}%")

# --- 5. Evaluate ---
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

print(f"Final Accuracy: {100.*test_correct/test_total:.1f}%")
Step-by-Step Breakdown¶
1. Ingest the Dataset¶
Deeplake can ingest datasets directly from HuggingFace with a single call. The _huggingface key tells the platform to pull the dataset by name, automatically mapping its columns (image, label) into a managed table.
If the table already exists, you can skip this step and go straight to open_table.
2. Create the DataLoader¶
Open the managed table and wrap it in a standard PyTorch DataLoader. The ds.pytorch() method returns a map-style dataset that streams data directly from Deeplake's storage engine. Pass a transform function that receives the full sample dict and returns a modified copy.
ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    sample["image"] = tform(sample["image"])
    return sample

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
Each batch is a dictionary where keys match the table columns: batch["image"] contains the image tensors and batch["label"] contains the class indices.
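The dict-batching behavior itself is standard PyTorch: the default collate function stacks each key across the samples in a batch. The toy map-style dataset below mimics the structure of a table sample (the keys "image" and "label" match the tutorial's columns; the data is synthetic):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToySamples(Dataset):
    """Map-style dataset yielding dict samples, like a Deeplake table would."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"image": torch.randn(1, 28, 28), "label": idx % 10}

loader = DataLoader(ToySamples(), batch_size=4)
batch = next(iter(loader))

print(batch["image"].shape)  # torch.Size([4, 1, 28, 28])
print(batch["label"].shape)  # torch.Size([4])
```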
3. Define the Model¶
We use a pretrained ResNet18 and adapt it for Fashion MNIST. Two modifications are needed: the first convolutional layer is changed from 3-channel RGB to 1-channel grayscale, and the final fully connected layer is resized to output 10 classes.
model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
4. Training Loop¶
A standard PyTorch training loop with CrossEntropyLoss and SGD. Data streams from the managed table through the DataLoader exactly as it would from a local dataset. No special handling is required.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
5. Evaluate¶
Switch the model to evaluation mode and run a pass over the data to measure accuracy. For a production setup, you would ingest a separate test split and evaluate against that.
model.eval()
with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
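The accuracy bookkeeping in the full listing reduces to counting argmax matches. On a toy batch of fake logits:

```python
import torch

outputs = torch.tensor([[0.1, 2.0, 0.3],
                        [3.0, 0.2, 0.1],
                        [0.2, 0.1, 4.0]])  # fake logits: 3 samples, 3 classes
labels = torch.tensor([1, 0, 1])           # the third prediction will be wrong

_, predicted = outputs.max(1)              # argmax over the class dimension
correct = predicted.eq(labels).sum().item()
accuracy = 100.0 * correct / labels.size(0)

print(predicted.tolist())  # [1, 0, 2]
print(correct)             # 2
```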
Why no REST API?¶
Streaming high-performance tensor data over standard REST endpoints introduces significant latency and CPU overhead due to HTTP headers and JSON serialization. For high-throughput training, the Python SDK is the only supported method as it uses optimized C++ streaming kernels.
What to try next¶
- GPU-Streaming Pipeline: learn more about direct-to-GPU data streaming.
- Massive Ingestion: prepare large-scale datasets for training.
- Semantic Search: search your dataset by content similarity.
- Reference: Querying: details on open_table().