Training Imitation Learning Policies

A practical guide to training imitation learning (IL) policies from demonstrations.

Training Pipeline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# 1. Load demonstrations
dataset = LeRobotDataset(
    repo_id="your_username/your_dataset",
    root="data/"
)

# 2. Create dataloaders
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# 3. Initialize policy (BCPolicy is sketched below; obs_dim / action_dim
# stand in for your dataset's state and action feature sizes)
policy = BCPolicy(obs_dim=dataset.obs_dim, action_dim=dataset.action_dim)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# 4. Train
for epoch in range(100):
    for batch in train_loader:
        # Key names follow the dataset's features; LeRobot uses
        # "observation.state" for proprioceptive state and "action"
        obs = batch['observation.state']
        actions = batch['action']

        predicted_actions = policy(obs)
        loss = criterion(predicted_actions, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 5. Evaluate in the target environment (env is your simulation
# environment; evaluate_policy is sketched under Monitoring below)
success_rate = evaluate_policy(policy, env, n_episodes=50)
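
The pipeline above assumes a BCPolicy class, which is not part of LeRobot. A minimal sketch matching the MLP settings in the configuration below (two 256-unit hidden layers, ReLU):

class BCPolicy(nn.Module):
    """Minimal MLP behavior-cloning policy: observation -> action."""

    def __init__(self, obs_dim, action_dim, hidden_dims=(256, 256)):
        super().__init__()
        layers, in_dim = [], obs_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, action_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, obs):
        return self.net(obs)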

Configuration

# config/bc_training.yaml
model:
  type: "mlp"
  hidden_dims: [256, 256]
  activation: "relu"

training:
  batch_size: 64
  learning_rate: 1e-3
  weight_decay: 1e-4
  num_epochs: 100
  gradient_clip: 1.0

data:
  augmentation:
    enabled: true
    noise_level: 0.01
    time_warp: true

evaluation:
  eval_freq: 10  # epochs
  n_eval_episodes: 50
  deterministic: true
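
The loop in the training pipeline hard-codes its hyperparameters; a sketch of wiring this file in with PyYAML instead (the gradient_clip and weight_decay fields map onto clip_grad_norm_ and the Adam optimizer):

import yaml

with open("config/bc_training.yaml") as f:
    config = yaml.safe_load(f)

train_cfg = config["training"]
optimizer = torch.optim.Adam(
    policy.parameters(),
    lr=float(train_cfg["learning_rate"]),        # "1e-3" parses as a string in YAML 1.1
    weight_decay=float(train_cfg["weight_decay"]),
)

# In the training step, clip gradients between backward() and step():
#     loss.backward()
#     torch.nn.utils.clip_grad_norm_(policy.parameters(), train_cfg["gradient_clip"])
#     optimizer.step()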

Monitoring

import wandb

wandb.init(project="imitation-learning")

# num_epochs and eval_freq come from the training config above;
# train_epoch / validate wrap the loop from step 4, with val_loader
# holding a held-out split of the demonstrations.
for epoch in range(num_epochs):
    # Training
    train_loss = train_epoch(policy, train_loader)

    # Validation
    val_loss = validate(policy, val_loader)

    metrics = {
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }

    # Environment rollouts are expensive, so run them only every eval_freq epochs
    if epoch % eval_freq == 0:
        success_rate, avg_return = evaluate_policy(policy, env)
        metrics['success_rate'] = success_rate
        metrics['avg_return'] = avg_return

    # Log losses every epoch, rollout metrics when available
    wandb.log(metrics)
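
evaluate_policy is not a LeRobot function; a minimal sketch for a Gymnasium-style env, assuming success is reported via info['success'] (that key, and the env itself, are assumptions to adapt):

def evaluate_policy(policy, env, n_episodes=50):
    """Roll out the policy; return (success_rate, average_return)."""
    successes, returns = 0, []
    for _ in range(n_episodes):
        obs, info = env.reset()
        done, episode_return = False, 0.0
        while not done:
            with torch.no_grad():
                action = policy(torch.as_tensor(obs, dtype=torch.float32))
            obs, reward, terminated, truncated, info = env.step(action.numpy())
            episode_return += reward
            done = terminated or truncated
        successes += int(info.get('success', False))
        returns.append(episode_return)
    return successes / n_episodes, sum(returns) / len(returns)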

Troubleshooting

| Issue | Cause | Solution |
| --- | --- | --- |
| High training loss | Model too simple | Increase capacity |
| Low validation loss, poor rollout performance | Distribution shift | Use DAgger, augmentation |
| Jerky actions | No smoothness regularization | Add action smoothing |
| Mode collapse | Multi-modal actions | Use diffusion/VAE policy |
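
For the jerky-actions row, one lightweight form of action smoothing is a low-pass filter over successive predictions at deployment time; a sketch (alpha is an assumed tuning knob, not a fixed value):

class ActionSmoother:
    """Exponential moving average over successive predicted actions."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha      # higher alpha tracks the raw policy more closely
        self.prev = None

    def __call__(self, action):
        if self.prev is None:
            self.prev = action
        self.prev = self.alpha * action + (1 - self.alpha) * self.prev
        return self.prev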

Best Practices

  1. Start simple: Begin with vanilla BC
  2. Visualize: Plot predicted vs. actual actions (a sketch follows this list)
  3. Augment data: Noise, time warping
  4. Regularize: L2, dropout for generalization
  5. Iterate: Collect more data if needed
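
A minimal sketch of practice 2, scattering predicted against demonstrated actions for one dimension of a batch (assumes matplotlib and the batch keys used earlier):

import matplotlib.pyplot as plt

batch = next(iter(train_loader))
with torch.no_grad():
    pred = policy(batch['observation.state'])

dim = 0  # action dimension to inspect
plt.scatter(batch['action'][:, dim].numpy(), pred[:, dim].numpy(), s=8)
plt.axline((0, 0), slope=1, linestyle='--', color='gray')  # perfect-prediction line
plt.xlabel(f'demonstrated action[{dim}]')
plt.ylabel(f'predicted action[{dim}]')
plt.show()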

Next Steps