Training Image Classification in PyTorch

Deeplake makes it easy to train image classification models by streaming data directly from managed tables into a PyTorch training loop, with no intermediate downloads or format conversion.

Objective

Ingest the Fashion MNIST dataset from HuggingFace into a Deeplake managed table, then train a ResNet18 model by streaming data with a PyTorch DataLoader.

Prerequisites

  • Deeplake SDK: pip install deeplake
  • PyTorch and torchvision: pip install torch torchvision
  • A Deeplake API token.

Set credentials first

export DEEPLAKE_API_KEY="your-token-here"
export DEEPLAKE_WORKSPACE="your-workspace"  # optional, defaults to "default"

Complete Code

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from deeplake import Client

# --- Configuration ---
TABLE_NAME = "fashion_mnist"
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# --- 1. Ingest the Dataset from HuggingFace ---
client = Client()
client.ingest(TABLE_NAME, {"_huggingface": "fashion_mnist"})

# --- 2. Create the DataLoader ---
ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    sample["image"] = tform(sample["image"])
    return sample

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# --- 3. Define the Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 Fashion MNIST classes
model = model.to(device)

# --- 4. Training Loop ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, batch in enumerate(train_loader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if i % 100 == 99:
            print(f"Epoch {epoch+1}, Batch {i+1}: "
                  f"Loss={running_loss/100:.3f}, "
                  f"Acc={100.*correct/total:.1f}%")
            running_loss = 0.0

    print(f"Epoch {epoch+1} complete - Accuracy: {100.*correct/total:.1f}%")

# --- 5. Evaluate ---
model.eval()
test_correct = 0
test_total = 0

# For simplicity this evaluates on the training data; in practice,
# ingest a separate test split and evaluate against that.
with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

print(f"Final Accuracy: {100.*test_correct/test_total:.1f}%")

Step-by-Step Breakdown

1. Ingest the Dataset

Deeplake can ingest datasets directly from HuggingFace with a single call. The _huggingface key tells the platform to pull the dataset by name, automatically mapping its columns (image, label) into a managed table.

client = Client()
client.ingest(TABLE_NAME, {"_huggingface": "fashion_mnist"})

If the table already exists, you can skip this step and go straight to open_table.
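To make this step safe to re-run, you can try opening the table first and ingest only if that fails. A minimal sketch, assuming open_table raises an exception for a missing table (check your SDK version for the exact error type); the open_or_ingest helper below is ours, not part of the Deeplake API:

```python
def open_or_ingest(client, table_name, source):
    """Open table_name, ingesting it from source first if opening fails."""
    try:
        return client.open_table(table_name)       # reuse an existing table
    except Exception:
        client.ingest(table_name, source)          # first run: ingest, then open
        return client.open_table(table_name)

# ds = open_or_ingest(client, TABLE_NAME, {"_huggingface": "fashion_mnist"})
```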

2. Create the DataLoader

Open the managed table and wrap it in a standard PyTorch DataLoader. The ds.pytorch() method returns a map-style dataset that streams data directly from Deeplake's storage engine. Pass a transform function that receives the full sample dict and returns a modified copy.

ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    sample["image"] = tform(sample["image"])
    return sample

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

Each batch is a dictionary where keys match the table columns: batch["image"] contains the image tensors and batch["label"] contains the class indices.
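You can see this batching behavior without a live table: PyTorch's default collation stacks dict samples key by key. A sketch using a hypothetical stand-in dataset (FakeFashionMNIST is ours, for illustration only):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FakeFashionMNIST(Dataset):
    """Stand-in for the streamed table: dict samples with "image"/"label" keys."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"image": torch.zeros(1, 28, 28), "label": idx % 10}

loader = DataLoader(FakeFashionMNIST(), batch_size=4)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 1, 28, 28])
print(batch["label"].shape)  # torch.Size([4])
```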

3. Define the Model

We use a pretrained ResNet18 and adapt it for Fashion MNIST. Two modifications are needed: the first convolutional layer is changed from 3-channel RGB to 1-channel grayscale, and the final fully connected layer is resized to output 10 classes. Note that both replaced layers are randomly initialized; only the rest of the backbone keeps its pretrained weights.

model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

4. Training Loop

A standard PyTorch training loop with CrossEntropyLoss and SGD. Data streams from the managed table through the DataLoader exactly as it would from a local dataset. No special handling is required.

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

5. Evaluate

Switch the model to evaluation mode and run a pass over the data to measure accuracy. For a production setup, you would ingest a separate test split and evaluate against that.

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f"Accuracy: {100.*correct/total:.1f}%")
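Once a test split is available, the accuracy pass generalizes into a reusable helper that works with any loader yielding the same dict batches. A sketch; the evaluate name and signature are ours, not part of the SDK:

```python
import torch

def evaluate(model, loader, device):
    """Return top-1 accuracy (%) of model over a loader of dict batches."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100.0 * correct / total

# print(f"Test accuracy: {evaluate(model, test_loader, device):.1f}%")
```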

Why no REST API?

Streaming high-throughput tensor data over standard REST endpoints adds significant latency and CPU overhead from HTTP framing and JSON serialization. For training workloads, the Python SDK is the only supported method: it streams through optimized C++ kernels instead of going over HTTP per request.

What to try next