VLA Finetuning with LeRobot Data

Finetune a Vision-Language-Action (VLA) model on robotics data stored in Deeplake. VLA models take camera images + language instructions and predict robot actions. This guide covers the full end-to-end process: connecting to your data, loading a pretrained VLA, building the training pipeline, and saving the finetuned adapter.

Deeplake's ds.batches() dataloader streams data directly from cloud storage at ~1 Gb/s — no local copy needed.

Ingest your data first

This guide assumes your LeRobot dataset is already in Deeplake. See the LeRobot Integration guide for ingestion.

Prerequisites

  • Deeplake SDK: pip install deeplake
  • PyTorch and vision: pip install torch torchvision
  • VLA dependencies: pip install "transformers>=4.40,<4.46" "timm>=0.9.10,<1.0" accelerate peft pillow
  • A Deeplake API token.
  • A GPU with at least 24 GB VRAM (A10, L4, or better). LoRA keeps memory low enough for a single GPU.

OpenVLA version requirements

OpenVLA's custom model code requires transformers<4.46 and timm<1.0. Newer versions cause import errors or attribute mismatches. Tested with: deeplake==4.5.5, torch==2.10.0, transformers==4.45.2, timm==0.9.16, peft==0.18.1.

Set credentials first

export DEEPLAKE_API_KEY="your-token-here"
export DEEPLAKE_WORKSPACE="your-workspace"  # optional, defaults to "default"

Finetuning Code

import io
import torch
import numpy as np
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from deeplake import Client
from transformers import AutoModelForVision2Seq, AutoProcessor
from peft import LoraConfig, get_peft_model

# ── 1. Open the ingested table with Deeplake dataloader ──────
client = Client()
ds = client.open_table("aloha_shrimp")

# Filter to overhead camera only, query returns a DatasetView
# that supports the same .batches() streaming API
ds = ds.query("select * where camera_name = 'cam_high'")

# ── 2. Load pretrained VLA with LoRA ─────────────────────────
MODEL_ID = "openvla/openvla-7b"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # ~0.5% of total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ── 3. Define batch transform for VLA input format ────────────
# ThreadPool decodes JPEGs in parallel, 1.8x faster than single-threaded PIL
decode_pool = ThreadPoolExecutor(max_workers=8)

def decode_jpeg(img_bytes):
    return Image.open(io.BytesIO(img_bytes)).convert("RGB")

def prepare_vla_batch(batch):
    """Transform a ds.batches() batch dict into VLA model inputs."""
    images = list(decode_pool.map(decode_jpeg, batch["image"]))
    tasks = batch["task"]
    prompts = [
        # OpenVLA's prompt template ends with "Out:" so the model generates
        # action tokens immediately after the prompt
        f"In: What action should the robot take to {(t or 'perform the manipulation task').lower()}?\nOut:"
        for t in tasks
    ]

    inputs = processor(
        images=images,
        text=prompts,
        return_tensors="pt",
        padding="max_length",
        max_length=256,
        truncation=True,
    )

    # batch["action"] may arrive as a list of per-row arrays; np.asarray handles both
    actions = torch.from_numpy(np.asarray(batch["action"], dtype=np.float32))
    return inputs, actions

# ── 4. Training loop using ds.batches() ──────────────────────
# ds.batches() uses Deeplake's C++ async prefetcher (~1 Gb/s).
# This is much faster than DataLoader for cloud-streamed data.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
GRAD_ACCUM_STEPS = 8  # effective batch = 4 × 8 = 32

model.train()
for epoch in range(3):
    total_loss = 0.0
    n_steps = 0

    for batch in ds.batches(batch_size=4):
        inputs, actions_target = prepare_vla_batch(batch)
        actions_target = actions_target.to(device)
        # Cast float tensors to bf16 to match model weights
        batch_inputs = {k: v.to(device, dtype=torch.bfloat16) if v.is_floating_point()
                        else v.to(device) for k, v in inputs.items()}

        outputs = model(**batch_inputs)

        # Simplified regression head: treat the first action_dim logits at the
        # final position as continuous actions (see "Loss function" below)
        logits = outputs.logits[:, -1, :actions_target.shape[-1]]
        loss = torch.nn.functional.mse_loss(logits.float(), actions_target.float())
        loss = loss / GRAD_ACCUM_STEPS
        loss.backward()

        if (n_steps + 1) % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * GRAD_ACCUM_STEPS
        n_steps += 1

    # Flush any gradients left over from a partial accumulation window
    if n_steps % GRAD_ACCUM_STEPS != 0:
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch+1}: avg loss = {total_loss / max(n_steps, 1):.4f}")

# ── 5. Save finetuned adapter ────────────────────────────────
model.save_pretrained("openvla-aloha-lora")
print("LoRA adapter saved to openvla-aloha-lora/")
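To reuse the adapter later for inference, the usual PEFT pattern is to reload the base model and attach the saved adapter. A sketch, assuming the same pinned transformers/peft versions and network access to the base checkpoint:

```python
# Sketch: reload the base model, attach the saved LoRA adapter, and
# (optionally) merge the adapter deltas into the base weights for inference.
import torch
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "openvla/openvla-7b"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
base = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "openvla-aloha-lora")
model = model.merge_and_unload()  # optional: fold LoRA weights into the base
model.eval()
```

Merging removes the PEFT wrapper overhead at inference time; skip `merge_and_unload()` if you want to keep swapping adapters on the same base model.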

Why ds.batches()

ds.batches() uses Deeplake's C++ async prefetcher, which streams at ~1 Gb/s from cloud storage. See the dataloaders benchmark for details.

How the Deeplake Dataloader Works

  • open_table() — connects to the Deeplake storage engine; no SQL layer, no local download
  • .query(...) — filters rows on the dataset and returns a DatasetView with the same .batches() API
  • client.table().select().where().execute() — fluent query builder for exploring data via the managed API
  • client.query(sql) — raw SQL queries via the managed API
  • .batches(batch_size) — C++ async prefetcher (~1 Gb/s) that yields dict batches with multi-threaded I/O

ds.batches() streams directly from cloud storage with constant memory, regardless of dataset size.

Curating Training Data with Queries

Dataset query — filter before streaming

Chain .query() on the dataset to filter what gets streamed to the dataloader:

ds = client.open_table("aloha_shrimp")

# Only overhead camera, first 5 episodes
view = ds.query("""
    select * where camera_name = 'cam_high'
    and episode_index < 5
""")
print(f"Training on {len(view)} frames")

# Stream the filtered view with ds.batches()
for batch in view.batches(batch_size=4):
    inputs, actions = prepare_vla_batch(batch)
    # ... training step
    break

Fluent query — inspect and explore data

Use the fluent query builder on client.table() to explore data before training:

# Fluent chaining — inspect specific frames
frames = (
    client.table("aloha_shrimp")
        .select("frame_index", "camera_name", "action")
        .where("camera_name = 'cam_high'")
        .order_by("frame_index")
        .limit(10)
        .execute()
)
for row in frames:
    print(f"  frame {row['frame_index']}: action={row['action']}")

# Raw SQL — same result
rows = client.query("""
    SELECT frame_index, camera_name, action
    FROM aloha_shrimp
    WHERE camera_name = 'cam_high'
    ORDER BY frame_index
    LIMIT 10
""")

Supported VLA Models

  • OpenVLA — 7B parameters, tokenized action head (256 bins), 1×A100 or 1×L4 with LoRA
  • Octo — 93M parameters, diffusion action head, 1×A10
  • RT-2 — 55B parameters, language-token action head, 4×A100
  • π₀ — 3B parameters, flow-matching action head, 1×A100 or 1×L4 with LoRA

Tips

  • Filter first, train second: Use .query() to select high-quality demonstrations before streaming. Curated subsets often outperform training on everything.
  • LoRA: Full finetuning of a 7B VLA needs 4×A100. LoRA trains ~0.5% of the parameters on a single 24 GB GPU, producing a ~50 MB adapter.
  • Gradient accumulation: With LoRA on 24 GB, batch size 4 + 8 accumulation steps ≈ effective batch 32 without OOM.
  • Image resolution: Most VLAs resize to 224×224 internally. The processor handles it. No need to pre-resize in Deeplake.
  • Action heads vary: OpenVLA discretizes into 256 bins (cross-entropy), Octo uses diffusion, RT-2 outputs language tokens. Check the model's finetuning guide for the correct loss.

Loss function

This example uses MSE loss for simplicity. OpenVLA's official finetuning discretizes actions into tokens and trains with cross-entropy loss via outputs.loss — see the OpenVLA repo for details.
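As an illustrative sketch of that tokenized formulation (the [-1, 1] bin bounds and helper names here are placeholders, not OpenVLA's API; OpenVLA derives per-dimension bounds from dataset action statistics):

```python
# Illustrative sketch of OpenVLA-style action discretization: each action
# dimension is clipped to a fixed range and mapped to one of 256 uniform
# bins, so the model can be trained with cross-entropy over bin ids.

def discretize(a, lo=-1.0, hi=1.0, n_bins=256):
    """Map a continuous action value to an integer bin id in [0, n_bins - 1]."""
    a = min(max(a, lo), hi)  # clip to the bin range
    return round((a - lo) / (hi - lo) * (n_bins - 1))

def undiscretize(bin_id, lo=-1.0, hi=1.0, n_bins=256):
    """Map a bin id back to the continuous value at that bin."""
    return lo + bin_id / (n_bins - 1) * (hi - lo)
```

Round-tripping through 256 bins loses at most half a bin width of precision per dimension, which is why the bin bounds should be fitted to the actual action range rather than left at a loose default.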

What to try next