Skip to content

LeRobot Examples

Complete examples of working with LeRobot datasets.

Example 1: Basic Behavioral Cloning

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load dataset
dataset = LeRobotDataset("lerobot/pusht")

# Split train/val (90% / 10% of the indexable items in the dataset)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split so the same items land in train/val on every run; with the
# default global RNG the validation loss is not comparable between runs.
# NOTE(review): the split is over individual dataset items, so items from one
# episode can presumably end up on both sides — confirm this is acceptable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create data loaders (only the training data is reshuffled each epoch)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Define policy
class BCPolicy(nn.Module):
    """MLP behavioral-cloning policy mapping an observation vector to an action.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` stored in
    ``self.net``.

    Args:
        obs_dim: Size of the last dimension of the observation input.
        action_dim: Size of the predicted action vector.
        hidden_dim: Width of both hidden layers. Defaults to 256, the value
            that used to be hard-coded.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, obs):
        """Return the predicted action(s) for ``obs``."""
        return self.net(obs)

# Initialize
# Infer the network sizes from one sample, using the same keys the training
# loop reads below, instead of relying on `dataset.obs_dim` /
# `dataset.action_dim`.
# NOTE(review): those attributes do not appear to exist on LeRobotDataset —
# confirm against the installed lerobot version.
first_sample = dataset[0]
obs_dim = first_sample['observation.state'].shape[-1]
action_dim = first_sample['action'].shape[-1]

policy = BCPolicy(obs_dim=obs_dim, action_dim=action_dim)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Train
for epoch in range(100):
    policy.train()
    for batch in train_loader:
        obs = batch['observation.state']
        actions = batch['action']

        predicted_actions = policy(obs)
        loss = criterion(predicted_actions, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate: average the per-batch MSE over the validation loader
    policy.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            obs = batch['observation.state']
            actions = batch['action']
            predicted = policy(obs)
            val_loss += criterion(predicted, actions).item()

    # A very small dataset can leave the validation split empty; report NaN
    # instead of dividing by zero.
    num_val_batches = len(val_loader)
    mean_val_loss = val_loss / num_val_batches if num_val_batches else float('nan')
    print(f"Epoch {epoch}: Val Loss = {mean_val_loss:.4f}")

# Save model (weights only)
torch.save(policy.state_dict(), "bc_policy.pt")

Example 2: Vision-Based Policy

import torch.nn as nn
from torchvision import transforms

# Image transforms: resize to 224x224 and normalize with the mean/std values
# given below
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset with transforms
# NOTE(review): LeRobotDataset takes per-image transforms through the
# `image_transforms` keyword (there is no `transforms` keyword) — confirm
# against the installed lerobot version.
dataset = LeRobotDataset("lerobot/aloha_static", image_transforms=transform)

# Vision-based policy
class VisionPolicy(nn.Module):
    """Image-to-action policy: a ResNet-18 encoder followed by an MLP head.

    Args:
        action_dim: Size of the predicted action vector.
    """

    def __init__(self, action_dim):
        super().__init__()

        from torchvision.models import resnet18

        # Vision encoder: ResNet-18 whose final fully connected layer is
        # swapped for an identity, so the encoder returns features instead of
        # class scores.
        backbone = resnet18(pretrained=True)
        backbone.fc = nn.Identity()
        self.vision_encoder = backbone

        # Action decoder: 512 -> 256 -> action_dim
        head_layers = [
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        ]
        self.action_head = nn.Sequential(*head_layers)

    def forward(self, image):
        """Encode ``image`` and decode the resulting features into actions."""
        return self.action_head(self.vision_encoder(image))

# Training loop
# Batch the dataset loaded for this example. Iterating `train_loader` from
# Example 1 would train the vision policy on a different dataset.
vision_train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Size the action head from a real sample so it matches the targets in the
# loss, rather than hard-coding the dimension.
policy = VisionPolicy(action_dim=dataset[0]['action'].shape[-1])
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
criterion = nn.MSELoss()  # built once instead of on every step

policy.train()
for epoch in range(50):
    for batch in vision_train_loader:
        # NOTE(review): the camera key is dataset-specific; some datasets
        # presumably use keys like 'observation.images.<camera>' — confirm
        # for the dataset loaded above.
        images = batch['observation.image']
        actions = batch['action']

        predicted_actions = policy(images)
        loss = criterion(predicted_actions, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Example 3: Creating Custom Dataset

import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd

class RobotDataCollector:
    """Record robot episodes as one parquet table plus one MP4 video each.

    Files are written below ``save_dir``:
    ``episodes/episode_XXXXXX.parquet``, ``videos/episode_XXXXXX.mp4`` and
    ``meta/info.json``.

    Args:
        robot: Object exposing ``get_state()``; the returned value must
            support ``.tolist()``.
        camera: Object exposing ``get_image()``; frames are assumed to be
            RGB arrays, since they are converted RGB -> BGR before writing.
        save_dir: Root directory of the dataset (created if missing).
    """

    def __init__(self, robot, camera, save_dir):
        self.robot = robot
        self.camera = camera
        self.save_dir = Path(save_dir)
        self.episodes = []
        # Frame rate of the most recent recording; written to the metadata.
        self.fps = 30

        # Create directories
        (self.save_dir / 'episodes').mkdir(parents=True, exist_ok=True)
        (self.save_dir / 'videos').mkdir(parents=True, exist_ok=True)

    def get_action(self):
        """Return the action for the current step.

        Override this for your setup (human demonstration or autonomous
        execution). The returned value must support ``.tolist()``.

        Raises:
            NotImplementedError: Always, until overridden.
        """
        raise NotImplementedError(
            "Override RobotDataCollector.get_action() to return the current action"
        )

    def collect_episode(self, duration=10.0, fps=30):
        """Record one episode and save its video and parquet table.

        Args:
            duration: Recording length in seconds.
            fps: Target sampling rate; also the frame rate of the video.

        Raises:
            ValueError: If ``fps`` is not positive.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        episode_idx = len(self.episodes)
        episode_data = []
        frames = []
        frame_period = 1.0 / fps

        # A monotonic clock keeps elapsed-time measurements immune to
        # system clock adjustments.
        start_time = time.monotonic()

        while time.monotonic() - start_time < duration:
            # Get observation
            image = self.camera.get_image()
            state = self.robot.get_state()

            # Get action (from human demonstration or autonomous execution)
            action = self.get_action()

            # Record
            episode_data.append({
                'episode_index': episode_idx,
                'frame_index': len(frames),
                'timestamp': time.monotonic() - start_time,
                'observation.state': state.tolist(),
                'action': action.tolist()
            })
            frames.append(image)

            # Sleep until the next frame slot rather than a fixed 1/fps, so
            # time spent acquiring data does not lower the real frame rate
            # below the rate the video is tagged with.
            next_tick = start_time + len(frames) * frame_period
            time.sleep(max(0.0, next_tick - time.monotonic()))

        if not frames:
            # E.g. duration <= 0: there is nothing to encode or store.
            print(f"Episode {episode_idx}: no frames captured, nothing saved")
            return

        # Save video
        video_name = f'videos/episode_{episode_idx:06d}.mp4'
        self.save_video(frames, self.save_dir / video_name, fps)

        # Update episode data with video paths
        for i, data in enumerate(episode_data):
            data['observation.image'] = f'{video_name}#frame={i}'

        # Save episode
        df = pd.DataFrame(episode_data)
        df.to_parquet(self.save_dir / f'episodes/episode_{episode_idx:06d}.parquet')

        self.fps = fps
        self.episodes.append(episode_data)
        print(f"Collected episode {episode_idx} with {len(episode_data)} frames")

    def save_video(self, frames, path, fps):
        """Encode ``frames`` into an ``mp4v`` video at ``path``.

        Args:
            frames: Non-empty sequence of image arrays of identical size.
            path: Destination file path.
            fps: Frame rate stored in the video.

        Raises:
            ValueError: If ``frames`` is empty.
            OSError: If the video writer cannot be opened for ``path``.
        """
        if len(frames) == 0:
            raise ValueError("Cannot save a video without frames")

        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(str(path), fourcc, fps, (width, height))
        if not writer.isOpened():
            writer.release()
            raise OSError(f"Could not open video writer for {path}")

        try:
            for frame in frames:
                writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            # Release the writer even if a frame fails to convert or write.
            writer.release()

    def save_metadata(self):
        """Write ``meta/info.json`` summarizing the episodes collected so far."""
        import json

        metadata = {
            'fps': self.fps,
            'total_episodes': len(self.episodes),
            'total_frames': sum(len(ep) for ep in self.episodes),
            'robot_type': 'franka'
        }

        meta_dir = self.save_dir / 'meta'
        meta_dir.mkdir(parents=True, exist_ok=True)
        with open(meta_dir / 'info.json', 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

# Usage
# `robot` and `camera` are not defined in this snippet; provide your own
# objects before running it.
NUM_EPISODES = 10  # Collect 10 episodes
collector = RobotDataCollector(robot, camera, 'data/my_dataset')

for _ in range(NUM_EPISODES):
    # Wait for the operator before each 5-second recording.
    input("Press Enter to start episode...")
    collector.collect_episode(duration=5.0)

# Write meta/info.json once all episodes are recorded.
collector.save_metadata()

Example 4: Multi-Task Dataset

# Create multi-task dataset
class MultiTaskDataset:
    """Hold one dataset per task and draw random samples from them.

    Args:
        tasks: Iterable of task names; each is loaded from
            ``username/<task_name>``.
    """

    def __init__(self, tasks):
        self.datasets = {}
        for task_name in tasks:
            self.datasets[task_name] = LeRobotDataset(f"username/{task_name}")

    def get_task_batch(self, task_name, batch_size):
        """Return ``batch_size`` samples drawn with replacement from one task."""
        source = self.datasets[task_name]
        picks = np.random.randint(0, len(source), batch_size)
        samples = []
        for pick in picks:
            samples.append(source[pick])
        return samples

    def get_mixed_batch(self, batch_size):
        """Return ``batch_size`` samples, each from a randomly chosen task.

        Every returned sample gets a ``'task'`` entry naming its task.
        """
        names = list(self.datasets.keys())
        mixed = []

        for _ in range(batch_size):
            # Pick the task first, then a position inside that task's dataset.
            chosen = np.random.choice(names)
            source = self.datasets[chosen]
            position = np.random.randint(0, len(source))
            item = source[position]
            item['task'] = chosen
            mixed.append(item)

        return mixed

# Usage
TASK_NAMES = ['pick_cube', 'place_cube', 'push_cube']
multi_task = MultiTaskDataset(TASK_NAMES)

# Train on mixed batches
NUM_ITERATIONS = 1000
for iteration in range(NUM_ITERATIONS):
    # Each batch is a list of 32 samples drawn across all tasks.
    batch = multi_task.get_mixed_batch(batch_size=32)
    # Train multi-task policy

Example 5: Dataset Analysis

import matplotlib.pyplot as plt
import numpy as np

def analyze_dataset(dataset):
    """Print summary statistics for ``dataset`` and save two histogram plots.

    Args:
        dataset: Object exposing ``num_episodes`` and ``get_episode(idx)``.
            An episode is a sequence of frame dicts with an ``'action'``
            entry; the last frame may carry a truthy ``'next.success'`` flag.
            NOTE(review): confirm that the dataset class passed in actually
            provides ``get_episode``.

    Side effects:
        Writes ``action_distribution.png`` and ``episode_lengths.png`` to the
        current working directory and closes both figures afterwards.
        Nothing is plotted when the dataset has no episodes.
    """
    num_episodes = dataset.num_episodes
    if num_episodes == 0:
        # Nothing to average or plot; avoid dividing by zero below.
        print(f"Total episodes: {num_episodes}")
        return

    all_actions = []
    episode_lengths = []
    success_count = 0

    for episode_idx in range(num_episodes):
        episode = dataset.get_episode(episode_idx)
        episode_lengths.append(len(episode))
        if len(episode) == 0:
            # An empty episode has no actions and no last frame to inspect.
            continue

        actions = np.array([frame['action'] for frame in episode])
        if actions.ndim == 1:
            # Scalar actions: add a column axis so they plot as one dimension.
            actions = actions[:, np.newaxis]
        all_actions.append(actions)

        if episode[-1].get('next.success', False):
            success_count += 1

    # Statistics
    print(f"Total episodes: {num_episodes}")
    print(f"Success rate: {success_count / num_episodes * 100:.1f}%")
    print(f"Avg episode length: {np.mean(episode_lengths):.1f} frames")
    print(f"Total frames: {sum(episode_lengths)}")

    # Plot action distribution (one histogram per action dimension)
    if all_actions:
        all_actions_flat = np.concatenate(all_actions, axis=0)
        num_dims = all_actions_flat.shape[1]

        # squeeze=False keeps `axes` a 2-D array even when there is a single
        # action dimension, so the loop below always works.
        fig, axes = plt.subplots(1, num_dims, figsize=(15, 3), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ax.hist(all_actions_flat[:, i], bins=50)
            ax.set_title(f'Action Dim {i}')
            ax.set_xlabel('Value')
            ax.set_ylabel('Count')

        fig.tight_layout()
        fig.savefig('action_distribution.png')
        plt.close(fig)  # release the figure so repeated calls do not pile up

    # Plot episode lengths
    length_fig = plt.figure()
    plt.hist(episode_lengths, bins=30)
    plt.xlabel('Episode Length')
    plt.ylabel('Count')
    plt.title('Episode Length Distribution')
    plt.savefig('episode_lengths.png')
    plt.close(length_fig)

# Usage
PUSHT_REPO_ID = "lerobot/pusht"
dataset = LeRobotDataset(PUSHT_REPO_ID)
# Prints the statistics and writes the two PNG plots.
analyze_dataset(dataset)

Next Steps