Fine-Tuning Vision-Language-Action Models¶
Fine-tuning pre-trained VLA models for your specific robot is often more effective than training from scratch.
Why Fine-Tune?¶
Advantages over training from scratch:

- ✓ Sample efficient: 50-200 demos vs 10K+ from scratch
- ✓ Faster convergence: hours instead of days
- ✓ Better generalization: leverages pre-trained knowledge
- ✓ Lower compute: a single GPU vs a multi-GPU cluster
When to fine-tune:

- You have a pre-trained model (OpenVLA, Octo, RT-2)
- Your robot/task is similar to the pre-training data
- You have 50-500 demonstrations available
- You want quick iteration
Fine-Tuning Strategies¶
1. Full Fine-Tuning¶
Update all model parameters:
```python
import torch

def full_finetune(model, dataset, config):
    """Fine-tune all model parameters"""
    # All parameters trainable
    for param in model.parameters():
        param.requires_grad = True

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,  # Typically 1e-5 to 1e-4
        weight_decay=config.weight_decay
    )

    # Training loop
    for epoch in range(config.num_epochs):
        for batch in dataset:
            # Forward pass
            loss = model.compute_loss(batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    return model
```
Pros:

- Maximum adaptation capability
- Best final performance (given enough data)

Cons:

- Risk of overfitting with small datasets
- Slow (updates millions of parameters)
- High memory usage

Best for: 1000+ demonstrations
2. Partial Fine-Tuning¶
Freeze backbone, train task-specific layers:
```python
import torch

def partial_finetune(model, dataset, config):
    """Freeze vision/language encoders, train the action head"""
    # Freeze vision encoder
    for param in model.vision_encoder.parameters():
        param.requires_grad = False

    # Freeze language encoder
    for param in model.language_encoder.parameters():
        param.requires_grad = False

    # Train fusion and action prediction
    for param in model.fusion.parameters():
        param.requires_grad = True
    for param in model.action_head.parameters():
        param.requires_grad = True

    # Only optimize trainable parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.learning_rate
    )

    # Training loop (same as full fine-tuning)
    for epoch in range(config.num_epochs):
        for batch in dataset:
            loss = model.compute_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model
```
Pros:

- Less prone to overfitting
- Faster training
- Lower memory usage

Cons:

- Limited adaptation of perception

Best for: 200-1000 demonstrations
3. LoRA (Low-Rank Adaptation)¶
Most efficient approach for large models:
```python
import torch
from peft import LoraConfig, get_peft_model

def lora_finetune(model, dataset, config):
    """Fine-tune with LoRA adapters"""
    # Configure LoRA
    lora_config = LoraConfig(
        r=16,           # Rank (higher = more capacity, 8-32 typical)
        lora_alpha=32,  # Scaling factor
        target_modules=[
            "q_proj",  # Query projection in attention
            "v_proj",  # Value projection
            "k_proj",  # Key projection
            "o_proj",  # Output projection
        ],
        lora_dropout=0.05,
        bias="none"
    )

    # Apply LoRA to the model
    model = get_peft_model(model, lora_config)

    # Print trainable parameters
    model.print_trainable_parameters()
    # Output: "trainable params: 2.3M || all params: 7B || trainable%: 0.03%"

    # Optimizer (only the LoRA parameters receive gradients)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate  # Can use a higher LR with LoRA
    )

    # Training loop
    for epoch in range(config.num_epochs):
        for batch in dataset:
            loss = model.compute_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model
```
How LoRA Works:
Original layer: \(W \in \mathbb{R}^{d \times k}\)
LoRA decomposition: \(W' = W + \Delta W = W + BA\)
Where:

- \(B \in \mathbb{R}^{d \times r}\), \(A \in \mathbb{R}^{r \times k}\)
- \(r \ll \min(d, k)\) (rank bottleneck)
- Only \(A\) and \(B\) are trained
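To make the savings concrete, here is a quick parameter count for illustrative dimensions (a 4096×4096 projection, typical of attention layers in 7B-scale models, with \(r = 16\)):

```python
# Parameter count for a single LoRA-adapted projection.
# The dimensions below are illustrative, not taken from any specific model.
d, k, r = 4096, 4096, 16

full_params = d * k          # frozen pre-trained weight W
lora_params = d * r + r * k  # trainable B (d x r) and A (r x k)

print(full_params)                         # 16777216
print(lora_params)                         # 131072
print(f"{lora_params / full_params:.2%}")  # 0.78%
```

So the trainable adapter is under 1% of the layer it adapts, which is where LoRA's memory and speed advantages come from.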
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Linear layer with LoRA adapter"""

    def __init__(self, in_features, out_features, r=16, lora_alpha=32):
        super().__init__()

        # Frozen pre-trained weight (random here for illustration)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.weight.requires_grad = False

        # LoRA adapter matrices: A random, B zero, so BA = 0 at init
        self.lora_A = nn.Parameter(torch.randn(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

        # Scaling
        self.scaling = lora_alpha / r

    def forward(self, x):
        # Original projection
        out = F.linear(x, self.weight)

        # Add LoRA adaptation
        lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
        out = out + lora_out * self.scaling
        return out
```
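A quick NumPy sanity check of the zero-initialization property: because \(B\) starts at zero, the adapted layer reproduces the frozen layer exactly before any training step (dimensions here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 8, 6, 2
x = rng.standard_normal((3, k))

W = rng.standard_normal((d, k))  # frozen pre-trained weight
A = rng.standard_normal((r, k))  # LoRA A, random init
B = np.zeros((d, r))             # LoRA B, zero init
scaling = 32 / r

base = x @ W.T
adapted = base + (x @ A.T) @ B.T * scaling

# B = 0 makes BA = 0, so the adapter is a no-op at initialization.
print(np.allclose(base, adapted))  # True
```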
Pros:

- ✓ Extremely parameter efficient (trains ~0.1% of parameters)
- ✓ Fast training (fewer gradients to compute)
- ✓ Low memory (no optimizer state for frozen weights)
- ✓ Strong regularization (less prone to overfitting)
- ✓ Portable (can save/load just the adapter weights)

Cons:

- ✗ Limited capacity (may underfit complex adaptations)

Best for: 50-500 demonstrations; recommended for most use cases
4. Adapter Layers¶
Insert trainable adapter modules:
```python
import torch.nn as nn

class AdapterModule(nn.Module):
    """Bottleneck adapter"""

    def __init__(self, dim, reduction_factor=4):
        super().__init__()
        hidden_dim = dim // reduction_factor
        self.adapter = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim)
        )

        # Zero the output layer so the adapter starts as the identity
        nn.init.zeros_(self.adapter[-1].weight)
        nn.init.zeros_(self.adapter[-1].bias)

    def forward(self, x):
        return x + self.adapter(x)

def add_adapters_to_model(model, reduction_factor=4):
    """Add adapter modules to transformer layers"""
    for layer in model.transformer.layers:
        # Add adapter after attention
        layer.attention_adapter = AdapterModule(model.hidden_dim, reduction_factor)

        # Add adapter after feed-forward
        layer.ff_adapter = AdapterModule(model.hidden_dim, reduction_factor)
    return model

# Modified forward pass (sketch; __init__ omitted)
class TransformerLayerWithAdapter(nn.Module):
    def forward(self, x):
        # Attention
        x = x + self.attention(x)
        x = self.attention_adapter(x)  # Adapter

        # Feed-forward
        x = x + self.feed_forward(x)
        x = self.ff_adapter(x)  # Adapter
        return x
```
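The zero-initialized output layer means each adapter begins as an exact identity, so inserting adapters does not perturb the pre-trained model before training. A small NumPy sketch of AdapterModule's forward pass confirms this:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, hidden = 16, 4  # reduction_factor = 4
x = rng.standard_normal((5, dim))

# First linear + ReLU: random init, as in AdapterModule
W1 = rng.standard_normal((hidden, dim))
b1 = rng.standard_normal(hidden)

# Output linear: zero init, as in AdapterModule
W2 = np.zeros((dim, hidden))
b2 = np.zeros(dim)

h = np.maximum(x @ W1.T + b1, 0.0)  # ReLU bottleneck
out = x + (h @ W2.T + b2)           # residual adapter

print(np.allclose(out, x))  # True: the adapter starts as the identity
```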
Fine-Tuning Configuration¶
Learning Rate¶
Critical hyperparameter - too high causes instability, too low converges slowly:
```python
def get_learning_rate(strategy, base_lr=1e-4):
    """Recommended learning rates per strategy"""
    lr_schedule = {
        'full_finetune': base_lr * 0.1,  # 1e-5
        'partial_finetune': base_lr,     # 1e-4
        'lora': base_lr * 3,             # 3e-4 (can be higher)
        'adapter': base_lr * 2           # 2e-4
    }
    return lr_schedule.get(strategy, base_lr)

# Use warmup
def get_lr_scheduler(optimizer, warmup_steps=100, total_steps=1000):
    """Cosine schedule with warmup"""
    from transformers import get_cosine_schedule_with_warmup
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    return scheduler
```
Data Augmentation¶
Prevent overfitting on small datasets:
```python
import copy
import numpy as np
from torchvision import transforms

class RobotDataAugmentation:
    """Augmentation for robot demonstrations"""

    def __init__(self):
        # Vision augmentations
        self.vision_aug = transforms.Compose([
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.05
            ),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])

    def __call__(self, demo):
        """Augment a demonstration"""
        augmented = copy.deepcopy(demo)

        # Augment images
        for step in augmented:
            step['observation']['image'] = self.vision_aug(
                step['observation']['image']
            )

        # Optional: add small noise to actions
        if np.random.rand() < 0.5:
            for step in augmented:
                noise = np.random.randn(*step['action'].shape) * 0.01
                step['action'] = step['action'] + noise

        return augmented
```
Regularization¶
Prevent catastrophic forgetting:
```python
import torch

class RegularizedFineTuning:
    """Fine-tune with regularization to preserve pre-trained knowledge"""

    def __init__(self, model, regularization_weight=0.01):
        self.model = model
        self.reg_weight = regularization_weight

        # Store initial parameters
        self.initial_params = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def compute_loss(self, batch):
        """Loss with L2 regularization on parameter changes"""
        # Task loss
        task_loss = self.model.compute_loss(batch)

        # Regularization: penalize large deviations from initial parameters
        reg_loss = 0
        for name, param in self.model.named_parameters():
            if name in self.initial_params:
                reg_loss += torch.norm(param - self.initial_params[name]) ** 2

        total_loss = task_loss + self.reg_weight * reg_loss
        return total_loss
```
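The penalty term is just a weighted squared L2 distance between the current and initial parameters. A tiny numeric sketch with made-up parameter values:

```python
import numpy as np

# Hypothetical parameters at the start of fine-tuning and after some updates
initial = {"w": np.array([1.0, -2.0]), "b": np.array([0.5])}
current = {"w": np.array([1.1, -2.0]), "b": np.array([0.0])}
reg_weight = 0.01

# Sum of squared deviations: ||dw||^2 = 0.01, ||db||^2 = 0.25 -> 0.26
reg = sum(np.sum((current[k] - initial[k]) ** 2) for k in initial)

print(round(reg_weight * reg, 6))  # 0.0026
```

This term is added to the task loss, pulling the fine-tuned weights back toward the pre-trained solution.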
Complete Fine-Tuning Pipeline¶
```python
import numpy as np
import torch

class VLAFineTuner:
    """Complete fine-tuning pipeline for VLA models"""

    def __init__(self, pretrained_model, strategy='lora'):
        self.base_model = pretrained_model
        self.strategy = strategy

        # Prepare the model based on strategy
        if strategy == 'lora':
            self.model = self.setup_lora()
        elif strategy == 'adapter':
            self.model = self.setup_adapters()
        elif strategy == 'partial':
            self.model = self.setup_partial()
        else:  # full
            self.model = self.base_model

    def setup_lora(self):
        """Setup LoRA fine-tuning"""
        from peft import LoraConfig, get_peft_model
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none"
        )
        model = get_peft_model(self.base_model, config)
        return model

    def finetune(self, train_dataset, val_dataset, config):
        """Fine-tune the model"""
        # Optimizer
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=get_learning_rate(self.strategy, config.base_lr),
            weight_decay=config.weight_decay
        )

        # Scheduler (train_dataset is assumed to yield batches, so
        # len(train_dataset) is the number of optimizer steps per epoch)
        total_steps = len(train_dataset) * config.num_epochs
        scheduler = get_lr_scheduler(optimizer, config.warmup_steps, total_steps)

        # Data augmentation
        augmenter = RobotDataAugmentation()

        # Training loop with early stopping
        best_val_loss = float('inf')
        patience = 0

        for epoch in range(config.num_epochs):
            # Train
            self.model.train()
            train_losses = []
            for batch in train_dataset:
                # Augment
                if config.use_augmentation:
                    batch = augmenter(batch)

                # Forward
                loss = self.model.compute_loss(batch)

                # Backward
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), config.grad_clip)
                optimizer.step()
                scheduler.step()

                train_losses.append(loss.item())

            # Validate
            self.model.eval()
            val_losses = []
            with torch.no_grad():
                for batch in val_dataset:
                    loss = self.model.compute_loss(batch)
                    val_losses.append(loss.item())

            # Logging
            avg_train_loss = np.mean(train_losses)
            avg_val_loss = np.mean(val_losses)
            print(f"Epoch {epoch+1}/{config.num_epochs}:")
            print(f"  Train Loss: {avg_train_loss:.4f}")
            print(f"  Val Loss: {avg_val_loss:.4f}")

            # Early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience = 0
                self.save_checkpoint('best_model.pt')
            else:
                patience += 1
                if patience >= config.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Load the best model
        self.load_checkpoint('best_model.pt')
        return self.model

    def save_checkpoint(self, path):
        """Save model checkpoint"""
        if self.strategy == 'lora':
            # save_pretrained treats the path as a directory and
            # saves only the LoRA adapter weights
            self.model.save_pretrained(path)
        else:
            torch.save(self.model.state_dict(), path)

    def load_checkpoint(self, path):
        """Load model checkpoint"""
        if self.strategy == 'lora':
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.base_model, path)
        else:
            self.model.load_state_dict(torch.load(path))
```
Domain-Specific Fine-Tuning¶
Action Space Adaptation¶
Adapt to different action dimensions/ranges:
```python
def adapt_action_head(model, new_action_dim, new_action_ranges):
    """Replace the action head for a new robot.

    new_action_ranges would feed the action normalization statistics
    for the new robot (not shown here).
    """
    # Save the old action head (for initialization)
    old_action_head = model.action_head

    # Create a new action head
    model.action_head = ActionHead(
        feature_dim=model.hidden_dim,
        action_dim=new_action_dim,
        action_type=model.action_type
    )

    # Initialize from the old head where dimensions overlap
    if new_action_dim >= old_action_head.action_dim:
        with torch.no_grad():
            model.action_head.mean_head[0].weight[:old_action_head.action_dim] = \
                old_action_head.mean_head[0].weight

    return model
```
Multi-Task Fine-Tuning¶
Fine-tune on multiple related tasks:
```python
import numpy as np
import torch
from torch.utils.data import DataLoader

def multitask_finetune(model, task_datasets, config):
    """Fine-tune on multiple tasks simultaneously"""
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.learning_rate
    )

    # Create iterators for each task
    task_iterators = {
        task_name: iter(DataLoader(dataset, batch_size=config.batch_size, shuffle=True))
        for task_name, dataset in task_datasets.items()
    }

    for step in range(config.max_steps):
        # Sample a task (a curriculum could replace uniform sampling)
        task_name = np.random.choice(list(task_datasets.keys()))

        # Get a batch, resetting the iterator when a task's data is exhausted
        try:
            batch = next(task_iterators[task_name])
        except StopIteration:
            task_iterators[task_name] = iter(DataLoader(
                task_datasets[task_name],
                batch_size=config.batch_size,
                shuffle=True
            ))
            batch = next(task_iterators[task_name])

        # Forward and backward pass
        loss = model.compute_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}, Task: {task_name}, Loss: {loss.item():.4f}")

    return model
```
Evaluation During Fine-Tuning¶
Monitor both offline and online metrics:
```python
import numpy as np
import torch

class FineTuningEvaluator:
    """Evaluate the model during fine-tuning"""

    def __init__(self, model, val_dataset, env=None):
        self.model = model
        self.val_dataset = val_dataset
        self.env = env

    def evaluate(self):
        """Run a comprehensive evaluation"""
        results = {}

        # Offline metrics (from validation data)
        results['offline'] = self.evaluate_offline()

        # Online metrics (if an environment is available)
        if self.env is not None:
            results['online'] = self.evaluate_online()

        return results

    def evaluate_offline(self):
        """Offline evaluation on the validation set"""
        self.model.eval()
        total_loss = 0
        action_errors = []

        with torch.no_grad():
            for batch in self.val_dataset:
                # Prediction loss
                loss = self.model.compute_loss(batch)
                total_loss += loss.item()

                # Action prediction error
                predicted_actions = self.model.predict(batch)
                true_actions = batch['actions']
                error = torch.norm(predicted_actions - true_actions, dim=-1).mean()
                action_errors.append(error.item())

        return {
            'loss': total_loss / len(self.val_dataset),
            'action_error': np.mean(action_errors)
        }

    def evaluate_online(self, num_episodes=10):
        """Online evaluation in the environment"""
        self.model.eval()
        success_count = 0
        episode_lengths = []

        for _ in range(num_episodes):
            obs = self.env.reset()
            done = False
            steps = 0

            while not done and steps < 500:
                # Predict an action
                with torch.no_grad():
                    action = self.model.predict(obs)

                # Execute it
                obs, reward, done, info = self.env.step(action)
                steps += 1

            success = info.get('success', False)
            success_count += int(success)
            episode_lengths.append(steps)

        return {
            'success_rate': success_count / num_episodes,
            'avg_episode_length': np.mean(episode_lengths)
        }
```
Best Practices¶
DO:¶
- ✓ Start with LoRA for most applications
- ✓ Use learning rate warmup
- ✓ Apply data augmentation for small datasets (<200 demos)
- ✓ Monitor both offline and online metrics
- ✓ Save the best checkpoint based on validation performance
- ✓ Use early stopping to prevent overfitting
DON'T:¶
- ✗ Fine-tune with too high a learning rate (causes forgetting)
- ✗ Train without a validation set
- ✗ Ignore domain shift between pre-training and fine-tuning data
- ✗ Over-train on tiny datasets (<50 demos)
- ✗ Forget to normalize observations/actions consistently
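On the last point, a minimal sketch of consistent normalization (values and variable names are illustrative): fit the statistics once on the fine-tuning demonstrations, then reuse exactly the same statistics at deployment rather than refitting.

```python
import numpy as np

# Fit normalization statistics once, on the fine-tuning demonstrations
train_actions = np.array([[0.1, -0.4], [0.3, 0.0], [0.2, 0.4]])
mean = train_actions.mean(axis=0)
std = train_actions.std(axis=0) + 1e-8  # epsilon avoids division by zero

def normalize(a):
    """Apply the training-time statistics (also used at deployment)."""
    return (a - mean) / std

def denormalize(a):
    """Invert normalization before sending actions to the robot."""
    return a * std + mean

a = np.array([0.25, -0.2])
print(np.allclose(denormalize(normalize(a)), a))  # True: exact round trip
```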
Example: Complete Fine-Tuning Script¶
```python
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM

# 1. Load the pre-trained model
# (OpenVLA checkpoints may additionally require trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained('openvla/openvla-7b')

# 2. Prepare the data
from your_robot_data import load_demonstrations
train_demos = load_demonstrations(num_demos=100, split='train')
val_demos = load_demonstrations(num_demos=20, split='val')

# 3. Fine-tune (VLAFineTuner applies its LoRA config internally,
# so the base model is passed in unwrapped)
finetuner = VLAFineTuner(base_model, strategy='lora')
model = finetuner.finetune(train_demos, val_demos, config=SimpleNamespace(
    num_epochs=20,
    base_lr=1e-4,
    batch_size=8,
    warmup_steps=100,
    grad_clip=1.0,
    early_stopping_patience=5,
    weight_decay=0.01,
    use_augmentation=True,
))

# 4. Save the LoRA adapter for deployment
model.save_pretrained('./finetuned_model')
```
Next Steps¶
- Training Guide - Train from scratch
- Deployment - Deploy fine-tuned models
- Optimization - Speed up inference
- Multi-task Learning - Fine-tune on multiple tasks