Training Imitation Learning Policies¶
A practical guide to training imitation learning (IL) policies from demonstrations.
Training Pipeline¶
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# 1. Load demonstrations
dataset = LeRobotDataset(
    repo_id="your_username/your_dataset",
    root="data/",
)

# 2. Create dataloaders
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# 3. Initialize policy (BCPolicy is user-defined; see the sketch below)
policy = BCPolicy(obs_dim=dataset.obs_dim, action_dim=dataset.action_dim)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# 4. Train: standard supervised regression on (observation, action) pairs
for epoch in range(100):
    for batch in train_loader:
        obs = batch["observation"]
        actions = batch["action"]

        predicted_actions = policy(obs)
        loss = criterion(predicted_actions, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 5. Evaluate in the target environment
success_rate = evaluate_policy(policy, env, n_episodes=50)
```
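The pipeline above references a `BCPolicy` class without defining it. Below is a minimal sketch of what it might look like, matching the `mlp` settings in the configuration that follows; the class name and constructor arguments come from the pipeline, everything else is an assumption:

```python
import torch
import torch.nn as nn


class BCPolicy(nn.Module):
    """Minimal MLP policy: maps flat observations to continuous actions."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dims=(256, 256)):
        super().__init__()
        layers = []
        in_dim = obs_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, action_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)
```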
Configuration¶
```yaml
# config/bc_training.yaml
model:
  type: "mlp"
  hidden_dims: [256, 256]
  activation: "relu"

training:
  batch_size: 64
  learning_rate: 1e-3
  weight_decay: 1e-4
  num_epochs: 100
  gradient_clip: 1.0

data:
  augmentation:
    enabled: true
    noise_level: 0.01
    time_warp: true

evaluation:
  eval_freq: 10  # epochs
  n_eval_episodes: 50
  deterministic: true
```
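Note that the training loop above does not apply the `weight_decay` or `gradient_clip` settings. A minimal sketch of wiring them in, assuming the file is loaded with PyYAML; everything besides the config keys is our own naming:

```python
import yaml
import torch

with open("config/bc_training.yaml") as f:
    cfg = yaml.safe_load(f)

train_cfg = cfg["training"]
optimizer = torch.optim.Adam(
    policy.parameters(),
    # float() guards against YAML parsing "1e-3" as a string
    lr=float(train_cfg["learning_rate"]),
    weight_decay=float(train_cfg["weight_decay"]),
)

# Inside the training loop, clip gradients before the optimizer step:
loss.backward()
torch.nn.utils.clip_grad_norm_(policy.parameters(), train_cfg["gradient_clip"])
optimizer.step()
```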
Monitoring¶
```python
import wandb

wandb.init(project="imitation-learning")

for epoch in range(num_epochs):
    # Training and validation losses
    train_loss = train_epoch(policy, train_loader)
    val_loss = validate(policy, val_loader)

    metrics = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}

    # Periodic evaluation in the environment; only log eval metrics when
    # they are freshly computed, so stale values are never re-logged
    if epoch % eval_freq == 0:
        success_rate, avg_return = evaluate_policy(policy, env)
        metrics["success_rate"] = success_rate
        metrics["avg_return"] = avg_return

    wandb.log(metrics)
```
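Both loops call an `evaluate_policy` helper that this guide does not define. A minimal sketch under the assumption of a Gymnasium-style environment that reports success via an `is_success` key in `info`; both that key and the helper's exact signature are assumptions:

```python
import numpy as np
import torch


@torch.no_grad()
def evaluate_policy(policy, env, n_episodes=50):
    """Roll out the policy and return (success_rate, average_return)."""
    successes, returns = [], []
    for _ in range(n_episodes):
        obs, info = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action = policy(torch.as_tensor(obs, dtype=torch.float32))
            obs, reward, terminated, truncated, info = env.step(action.numpy())
            episode_return += reward
            done = terminated or truncated
        successes.append(float(info.get("is_success", False)))
        returns.append(episode_return)
    return np.mean(successes), np.mean(returns)
```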
Troubleshooting¶
| Issue | Likely cause | Solution |
|---|---|---|
| High training loss | Model capacity too low | Increase network size |
| Low validation loss but poor rollouts | Distribution shift (compounding errors) | Use DAgger or data augmentation |
| Jerky actions | No smoothness regularization | Add action smoothing (see below) |
| Mode collapse | Multi-modal expert actions | Use a diffusion or VAE policy |
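For the jerky-actions row, one simple form of action smoothing is an exponential moving average applied at deployment time. A minimal sketch; the wrapper class and the default `alpha` are our own choices, not an API from this guide:

```python
import numpy as np


class SmoothedPolicy:
    """Wraps a policy and exponentially smooths its action outputs."""

    def __init__(self, policy, alpha=0.8):
        self.policy = policy
        self.alpha = alpha  # weight on the new action; lower = smoother
        self.prev_action = None

    def reset(self):
        self.prev_action = None

    def __call__(self, obs):
        action = np.asarray(self.policy(obs))
        if self.prev_action is not None:
            action = self.alpha * action + (1 - self.alpha) * self.prev_action
        self.prev_action = action
        return action
```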
Best Practices¶
- Start simple: begin with vanilla behavior cloning before trying advanced methods
- Visualize: plot predicted vs. ground-truth actions to spot systematic errors
- Augment data: add observation noise and time warping (see the sketch below)
- Regularize: use weight decay (L2) and dropout for better generalization
- Iterate: collect more demonstrations where the policy fails
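A minimal sketch of the Gaussian-noise augmentation enabled in the config above (time warping omitted); the function name and the interpretation of `noise_level` as a standard deviation are assumptions:

```python
import torch


def augment_observations(obs: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
    """Add zero-mean Gaussian noise to observations during training only."""
    return obs + noise_level * torch.randn_like(obs)


# In the training loop, before the forward pass:
# obs = augment_observations(obs, noise_level=0.01)
```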
Next Steps¶
- Methods - Try advanced IL methods
- Data Collection - Improve data quality
- Evaluation - Thorough evaluation