Fine-Tuning Vision-Language-Action Models

Fine-tuning pre-trained VLA models for your specific robot is often more effective than training from scratch.

Why Fine-Tune?

Advantages over training from scratch:
- ✓ Sample efficient: 50-200 demos vs 10K+ from scratch
- ✓ Faster convergence: Hours instead of days
- ✓ Better generalization: Leverages pre-trained knowledge
- ✓ Lower compute: Single GPU vs multi-GPU cluster

When to fine-tune:
- You have a pre-trained model (OpenVLA, Octo, RT-2)
- Your robot/task is similar to training data
- You have 50-500 demonstrations available
- You want quick iteration

Fine-Tuning Strategies

1. Full Fine-Tuning

Update all model parameters:

def full_finetune(model, dataset, config):
    """Fine-tune all model parameters"""

    # All parameters trainable
    for param in model.parameters():
        param.requires_grad = True

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,  # Typically 1e-5 to 1e-4
        weight_decay=config.weight_decay
    )

    # Training loop
    for epoch in range(config.num_epochs):
        for batch in dataset:
            # Forward pass
            loss = model.compute_loss(batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    return model

Pros:
- Maximum adaptation capability
- Best final performance (if you have enough data)

Cons:
- Risk of overfitting with small datasets
- Slow (every parameter receives gradient updates)
- High memory usage (gradients and optimizer state for the whole model)

Best for: 1000+ demonstrations

2. Partial Fine-Tuning

Freeze backbone, train task-specific layers:

def partial_finetune(model, dataset, config):
    """Freeze vision/language encoders, train action head"""

    # Freeze vision encoder
    for param in model.vision_encoder.parameters():
        param.requires_grad = False

    # Freeze language encoder
    for param in model.language_encoder.parameters():
        param.requires_grad = False

    # Train fusion and action prediction
    for param in model.fusion.parameters():
        param.requires_grad = True

    for param in model.action_head.parameters():
        param.requires_grad = True

    # Only optimize trainable parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.learning_rate
    )

    # Training loop (same as full fine-tuning)
    for epoch in range(config.num_epochs):
        for batch in dataset:
            loss = model.compute_loss(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model

Pros:
- Less prone to overfitting
- Faster training
- Lower memory usage

Cons:
- Limited adaptation of perception

Best for: 200-1000 demonstrations

3. LoRA (Low-Rank Adaptation)

Most efficient approach for large models:

from peft import LoraConfig, get_peft_model

def lora_finetune(model, dataset, config):
    """Fine-tune with LoRA adapters"""

    # Configure LoRA
    lora_config = LoraConfig(
        r=16,  # Rank (higher = more capacity, 8-32 typical)
        lora_alpha=32,  # Scaling factor
        target_modules=[
            "q_proj",  # Query projection in attention
            "v_proj",  # Value projection
            "k_proj",  # Key projection
            "o_proj",  # Output projection
        ],
        lora_dropout=0.05,
        bias="none"
    )

    # Apply LoRA to model
    model = get_peft_model(model, lora_config)

    # Print trainable parameters
    model.print_trainable_parameters()
    # Prints trainable vs. total parameter counts; with LoRA the trainable
    # share is typically well under 1% of the model

    # Optimizer: pass only the trainable (LoRA) parameters explicitly;
    # model.parameters() on its own also yields the frozen base weights
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=config.learning_rate  # Can use higher LR with LoRA
    )

    # Training loop
    for epoch in range(config.num_epochs):
        for batch in dataset:
            loss = model.compute_loss(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model

How LoRA Works:

Original layer: \(W \in \mathbb{R}^{d \times k}\)

LoRA decomposition: \(W' = W + \Delta W = W + BA\)

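Where:
- \(B \in \mathbb{R}^{d \times r}\), \(A \in \mathbb{R}^{r \times k}\)
- \(r \ll \min(d, k)\) (rank bottleneck)
- Only \(A\) and \(B\) are trained; \(W\) stays frozen
- In practice the update \(BA\) is scaled by \(\alpha / r\) (lora_alpha / r in the code), and \(B\) starts at zero so that \(W' = W\) before training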
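
To make the savings concrete, the arithmetic below compares the number of trainable values in a full update of \(W\) with the number in its LoRA factors. The layer size (4096 x 4096) and the helper name lora_param_counts are illustrative assumptions, not figures taken from a specific model.

```python
def lora_param_counts(d, k, r):
    """Return (full, lora) trainable-parameter counts for one d x k layer."""
    full = d * k          # updating W directly
    lora = r * (d + k)    # B is d x r, A is r x k
    return full, lora


# Illustrative sizes (assumed): a square 4096 x 4096 projection, rank 16
full, lora = lora_param_counts(d=4096, k=4096, r=16)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
# full: 16,777,216  lora: 131,072  ratio: 0.78%
```

The ratio applies only to the layers LoRA targets; measured against the whole model, the trainable share is smaller still because untargeted layers contribute nothing.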

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Linear layer with LoRA adapter"""
    def __init__(self, in_features, out_features, r=16, lora_alpha=32):
        super().__init__()

        # Frozen pre-trained weight (random placeholder here; in practice
        # copy it from the pre-trained layer)
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False
        )

        # LoRA adapter matrices
        self.lora_A = nn.Parameter(torch.randn(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

        # Scaling
        self.scaling = lora_alpha / r

    def forward(self, x):
        # Original projection
        out = F.linear(x, self.weight)

        # Add LoRA adaptation
        lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
        out = out + lora_out * self.scaling

        return out

Pros:
- ✓ Extremely parameter efficient (typically well under 1% of parameters are trained)
- ✓ Fast training (fewer gradients to compute)
- ✓ Low memory (no gradients or optimizer state for the frozen weights)
- ✓ Lower overfitting risk (the low-rank constraint acts as a regularizer)
- ✓ Portable (can save/load just adapter weights)

Cons:
- ✗ Limited capacity (may underfit complex adaptations)

Best for: 50-500 demonstrations, recommended for most use cases
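
The "Best for" ranges given for the three strategies so far overlap, so one way to read them is as a simple rule of thumb keyed on dataset size. The sketch below is only that: suggest_strategy is a hypothetical helper, and where the ranges overlap it prefers LoRA, following the recommendation above.

```python
def suggest_strategy(num_demos):
    """Map a demonstration count to a strategy name (rule of thumb only)."""
    if num_demos <= 500:
        # LoRA is listed for 50-500 demos; below 50, expect to need more data
        return "lora"
    if num_demos < 1000:
        # Partial fine-tuning is listed for 200-1000 demos
        return "partial"
    # Full fine-tuning is listed for 1000+ demos
    return "full"


for n in (100, 700, 5000):
    print(n, suggest_strategy(n))
```

Treat the result as a starting point and confirm it against validation performance on your own robot data.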

4. Adapter Layers

Insert trainable adapter modules:

class AdapterModule(nn.Module):
    """Bottleneck adapter"""
    def __init__(self, dim, reduction_factor=4):
        super().__init__()
        hidden_dim = dim // reduction_factor

        self.adapter = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim)
        )

        # Initialize to near-identity
        nn.init.zeros_(self.adapter[-1].weight)
        nn.init.zeros_(self.adapter[-1].bias)

    def forward(self, x):
        return x + self.adapter(x)

def add_adapters_to_model(model, reduction_factor=4):
    """Add adapter modules to transformer layers"""

    for layer in model.transformer.layers:
        # Add adapter after attention
        layer.attention_adapter = AdapterModule(
            model.hidden_dim,
            reduction_factor
        )

        # Add adapter after feed-forward
        layer.ff_adapter = AdapterModule(
            model.hidden_dim,
            reduction_factor
        )

    return model

# Modified forward pass
class TransformerLayerWithAdapter(nn.Module):
    def forward(self, x):
        # Attention
        x = x + self.attention(x)
        x = self.attention_adapter(x)  # Adapter

        # Feed-forward
        x = x + self.feed_forward(x)
        x = self.ff_adapter(x)  # Adapter

        return x
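
As a rough sizing aid, the parameter count of one bottleneck adapter follows directly from the two linear layers in AdapterModule. The hidden size of 768 below is an assumed example value, and adapter_param_count is a hypothetical helper.

```python
def adapter_param_count(dim, reduction_factor=4):
    """Weights and biases in one bottleneck adapter (down- and up-projection)."""
    hidden_dim = dim // reduction_factor
    down = dim * hidden_dim + hidden_dim   # nn.Linear(dim, hidden_dim)
    up = hidden_dim * dim + dim            # nn.Linear(hidden_dim, dim)
    return down + up


# Assumed example: hidden size 768, two adapters per transformer layer
per_adapter = adapter_param_count(768)
print(per_adapter, 2 * per_adapter)
# 295872 591744
```

A larger reduction_factor shrinks the bottleneck and the parameter count, at the cost of adaptation capacity.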

Fine-Tuning Configuration

Learning Rate

The learning rate is a critical hyperparameter: set too high it destabilizes training, set too low it converges slowly. Suggested starting points per strategy:

def get_learning_rate(strategy, base_lr=1e-4):
    """Recommended learning rates (inline values assume base_lr=1e-4)"""
    # Keys match the strategy names used by VLAFineTuner
    lr_schedule = {
        'full': base_lr * 0.1,  # 1e-5
        'partial': base_lr,  # 1e-4
        'lora': base_lr * 3,  # 3e-4 (can be higher)
        'adapter': base_lr * 2  # 2e-4
    }
    return lr_schedule.get(strategy, base_lr)

# Use warmup
def get_lr_scheduler(optimizer, warmup_steps=100, total_steps=1000):
    """Cosine schedule with warmup"""
    from transformers import get_cosine_schedule_with_warmup

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    return scheduler
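
To see what this schedule does to the learning rate, the multiplier can be written out by hand. The function below is a from-scratch sketch intended to mirror the shape of a cosine schedule with linear warmup (half a cosine cycle, decaying to zero); cosine_warmup_multiplier is a hypothetical name, not part of any library.

```python
import math


def cosine_warmup_multiplier(step, warmup_steps=100, total_steps=1000):
    """LR multiplier in [0, 1]: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# The effective learning rate is base_lr * multiplier at each step
for step in (0, 50, 100, 550, 1000):
    print(step, round(cosine_warmup_multiplier(step), 3))
```

The multiplier rises linearly to 1.0 at the end of warmup, passes 0.5 halfway through the decay phase, and reaches 0.0 at total_steps, which is why total_steps should match the number of optimizer steps actually taken.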

Data Augmentation

Prevent overfitting on small datasets:

import copy

import numpy as np
from torchvision import transforms

class RobotDataAugmentation:
    """Augmentation for robot demonstrations"""
    def __init__(self):
        # Vision augmentations
        self.vision_aug = transforms.Compose([
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.05
            ),
            # RandomErasing operates on tensor images, not PIL images
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
        ])

    def __call__(self, demo):
        """Augment demonstration"""
        augmented = copy.deepcopy(demo)

        # Augment images
        for step in augmented:
            step['observation']['image'] = self.vision_aug(
                step['observation']['image']
            )

        # Optional: Add small noise to actions
        if np.random.rand() < 0.5:
            for step in augmented:
                noise = np.random.randn(*step['action'].shape) * 0.01
                step['action'] = step['action'] + noise

        return augmented

Regularization

Prevent catastrophic forgetting:

class RegularizedFineTuning:
    """Fine-tune with regularization to preserve pre-trained knowledge"""
    def __init__(self, model, regularization_weight=0.01):
        self.model = model
        self.reg_weight = regularization_weight

        # Store initial parameters
        self.initial_params = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def compute_loss(self, batch):
        """Loss with L2 regularization on parameter changes"""
        # Task loss
        task_loss = self.model.compute_loss(batch)

        # Regularization: penalize large changes from initial parameters
        reg_loss = 0
        for name, param in self.model.named_parameters():
            if name in self.initial_params:
                reg_loss += torch.norm(param - self.initial_params[name]) ** 2

        total_loss = task_loss + self.reg_weight * reg_loss

        return total_loss
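
Written out, the objective that compute_loss returns is the task loss plus a squared-distance penalty to the starting weights. The symbols are notation introduced here for illustration: \(\theta_i\) is the \(i\)-th trainable parameter tensor, \(\theta_i^{(0)}\) is its stored copy in initial_params, and \(\lambda\) is regularization_weight.

\[
\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \sum_{i} \left\lVert \theta_i - \theta_i^{(0)} \right\rVert_2^2
\]

With \(\lambda = 0\) this reduces to plain fine-tuning; larger values keep the model closer to its pre-trained weights at the cost of slower adaptation.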

Complete Fine-Tuning Pipeline

class VLAFineTuner:
    """Complete fine-tuning pipeline for VLA models"""
    def __init__(self, pretrained_model, strategy='lora'):
        self.base_model = pretrained_model
        self.strategy = strategy

        # Prepare model based on strategy
        if strategy == 'lora':
            self.model = self.setup_lora()
        elif strategy == 'adapter':
            self.model = self.setup_adapters()
        elif strategy == 'partial':
            self.model = self.setup_partial()
        else:  # full
            self.model = self.base_model

    def setup_lora(self):
        """Setup LoRA fine-tuning"""
        from peft import LoraConfig, get_peft_model

        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none"
        )

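        model = get_peft_model(self.base_model, config)
        return model

    def setup_adapters(self):
        """Setup adapter fine-tuning (uses add_adapters_to_model from above)"""
        for param in self.base_model.parameters():
            param.requires_grad = False
        # Adapters are created after freezing, so they remain trainable
        return add_adapters_to_model(self.base_model)

    def setup_partial(self):
        """Freeze the encoders; fusion and action head stay trainable"""
        for param in self.base_model.vision_encoder.parameters():
            param.requires_grad = False
        for param in self.base_model.language_encoder.parameters():
            param.requires_grad = False
        return self.base_model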

    def finetune(self, train_dataset, val_dataset, config):
        """Fine-tune model"""
        # Optimizer
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=get_learning_rate(self.strategy, config.base_lr),
            weight_decay=config.weight_decay
        )

        # Scheduler: train_dataset is iterated batch by batch below, so each
        # epoch performs len(train_dataset) optimizer steps
        total_steps = len(train_dataset) * config.num_epochs
        scheduler = get_lr_scheduler(optimizer, config.warmup_steps, total_steps)

        # Data augmentation
        augmenter = RobotDataAugmentation()

        # Training loop
        best_val_loss = float('inf')
        patience = 0

        for epoch in range(config.num_epochs):
            # Train
            self.model.train()
            train_losses = []

            for batch in train_dataset:
                # Augment
                if config.use_augmentation:
                    batch = augmenter(batch)

                # Forward
                loss = self.model.compute_loss(batch)

                # Backward
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), config.grad_clip)
                optimizer.step()
                scheduler.step()

                train_losses.append(loss.item())

            # Validate
            self.model.eval()
            val_losses = []

            with torch.no_grad():
                for batch in val_dataset:
                    loss = self.model.compute_loss(batch)
                    val_losses.append(loss.item())

            # Logging
            avg_train_loss = np.mean(train_losses)
            avg_val_loss = np.mean(val_losses)

            print(f"Epoch {epoch+1}/{config.num_epochs}:")
            print(f"  Train Loss: {avg_train_loss:.4f}")
            print(f"  Val Loss: {avg_val_loss:.4f}")

            # Early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience = 0
                self.save_checkpoint('best_model.pt')
            else:
                patience += 1
                if patience >= config.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Load best model
        self.load_checkpoint('best_model.pt')

        return self.model

    def save_checkpoint(self, path):
        """Save model checkpoint"""
        if self.strategy == 'lora':
            # Save only LoRA weights (save_pretrained treats path as a directory)
            self.model.save_pretrained(path)
        else:
            torch.save(self.model.state_dict(), path)

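    def load_checkpoint(self, path):
        """Load model checkpoint"""
        if self.strategy == 'lora':
            # get_peft_model() already injected LoRA layers into base_model,
            # so reload the saved weights into the existing adapter instead of
            # wrapping the base model a second time
            self.model.load_adapter(path, adapter_name="default")
        else:
            self.model.load_state_dict(torch.load(path))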
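
The early-stopping bookkeeping inside finetune can be isolated and checked without a model. The sketch below reproduces the same patience logic on a made-up sequence of validation losses; EarlyStopper is a hypothetical helper, not part of the pipeline above.

```python
class EarlyStopper:
    """Track the best validation loss and signal when patience runs out."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def update(self, val_loss):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


# Made-up validation losses: improvement stalls after epoch 2
stopper = EarlyStopper(patience=2)
for epoch, loss in enumerate([0.9, 0.7, 0.75, 0.72, 0.6], start=1):
    improved, stop = stopper.update(loss)
    if stop:
        print(f"stop at epoch {epoch}, best={stopper.best}")
        break
# stop at epoch 4, best=0.7
```

Note that the run stops before reaching the later loss of 0.6: a small patience saves compute but can end training during a temporary plateau.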

Domain-Specific Fine-Tuning

Action Space Adaptation

Adapt to different action dimensions/ranges:

def adapt_action_head(model, new_action_dim, new_action_ranges):
    """Replace action head for new robot"""

    # Save old action head (for initialization)
    old_action_head = model.action_head

    # Create new action head
    model.action_head = ActionHead(
        feature_dim=model.hidden_dim,
        action_dim=new_action_dim,
        action_type=model.action_type
    )

    # Initialize from old head if possible
    if new_action_dim >= old_action_head.action_dim:
        # Copy weights for overlapping dimensions
        with torch.no_grad():
            model.action_head.mean_head[0].weight[:old_action_head.action_dim] = \
                old_action_head.mean_head[0].weight

    return model
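
The new_action_ranges argument is accepted by adapt_action_head but not used in its body. One common way to use such ranges is to rescale every action dimension to [-1, 1] for training and map predictions back to robot units at deployment. The sketch below assumes ranges are given as (low, high) pairs per dimension; make_action_normalizer and the example ranges are hypothetical.

```python
def make_action_normalizer(action_ranges):
    """Build (normalize, unnormalize) for per-dimension (low, high) ranges."""

    def normalize(action):
        # Map each dimension from [low, high] to [-1, 1]
        return [2.0 * (a - lo) / (hi - lo) - 1.0
                for a, (lo, hi) in zip(action, action_ranges)]

    def unnormalize(action):
        # Inverse map from [-1, 1] back to robot units
        return [lo + (a + 1.0) * (hi - lo) / 2.0
                for a, (lo, hi) in zip(action, action_ranges)]

    return normalize, unnormalize


# Hypothetical ranges: two joints in radians, one gripper opening in metres
ranges = [(-3.14, 3.14), (-1.57, 1.57), (0.0, 0.08)]
normalize, unnormalize = make_action_normalizer(ranges)

raw = [0.5, -1.0, 0.02]
restored = unnormalize(normalize(raw))
print(normalize(raw), restored)
```

Using the same pair of functions for training targets and for deployed predictions is what keeps normalization consistent between the two.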

Multi-Task Fine-Tuning

Fine-tune on multiple related tasks:

def multitask_finetune(model, task_datasets, config):
    """Fine-tune on multiple tasks simultaneously"""

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.learning_rate
    )

    # Create iterators for each task
    task_iterators = {
        task_name: iter(DataLoader(dataset, batch_size=config.batch_size, shuffle=True))
        for task_name, dataset in task_datasets.items()
    }

    for step in range(config.max_steps):
        # Sample task (can use curriculum)
        task_name = np.random.choice(list(task_datasets.keys()))

        # Get batch
        try:
            batch = next(task_iterators[task_name])
        except StopIteration:
            # Reset iterator
            task_iterators[task_name] = iter(DataLoader(
                task_datasets[task_name],
                batch_size=config.batch_size,
                shuffle=True
            ))
            batch = next(task_iterators[task_name])

        # Forward pass
        loss = model.compute_loss(batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}, Task: {task_name}, Loss: {loss.item():.4f}")

    return model
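
The loop above samples tasks uniformly, which over-samples small datasets relative to large ones. A possible alternative is to sample in proportion to dataset size, softened by a temperature. The sketch below is one such scheme under that assumption; task_sampling_probs and the dataset sizes are hypothetical.

```python
import random


def task_sampling_probs(task_sizes, temperature=1.0):
    """Probabilities proportional to size ** (1 / temperature)."""
    weights = {name: size ** (1.0 / temperature)
               for name, size in task_sizes.items()}
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}


# Hypothetical dataset sizes (number of demonstrations per task)
sizes = {"pick": 400, "place": 100}
probs = task_sampling_probs(sizes, temperature=2.0)
print(probs)

# Draw the task for one training step
task_name = random.choices(list(probs), weights=list(probs.values()), k=1)[0]
```

A temperature of 1.0 samples strictly by size, while larger temperatures move toward the uniform sampling used in the original loop.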

Evaluation During Fine-Tuning

Monitor both offline and online metrics:

class FineTuningEvaluator:
    """Evaluate model during fine-tuning"""
    def __init__(self, model, val_dataset, env=None):
        self.model = model
        self.val_dataset = val_dataset
        self.env = env

    def evaluate(self):
        """Run comprehensive evaluation"""
        results = {}

        # Offline metrics (from validation data)
        results['offline'] = self.evaluate_offline()

        # Online metrics (if environment available)
        if self.env is not None:
            results['online'] = self.evaluate_online()

        return results

    def evaluate_offline(self):
        """Offline evaluation on validation set"""
        self.model.eval()

        total_loss = 0
        action_errors = []

        with torch.no_grad():
            for batch in self.val_dataset:
                # Prediction loss
                loss = self.model.compute_loss(batch)
                total_loss += loss.item()

                # Action prediction error
                predicted_actions = self.model.predict(batch)
                true_actions = batch['actions']
                error = torch.norm(predicted_actions - true_actions, dim=-1).mean()
                action_errors.append(error.item())

        return {
            'loss': total_loss / len(self.val_dataset),
            'action_error': np.mean(action_errors)
        }

    def evaluate_online(self, num_episodes=10):
        """Online evaluation in environment"""
        self.model.eval()

        success_count = 0
        episode_lengths = []

        for _ in range(num_episodes):
            obs = self.env.reset()
            done = False
            steps = 0

            while not done and steps < 500:
                # Predict action
                with torch.no_grad():
                    action = self.model.predict(obs)

                # Execute
                obs, reward, done, info = self.env.step(action)
                steps += 1

            success = info.get('success', False)
            success_count += int(success)
            episode_lengths.append(steps)

        return {
            'success_rate': success_count / num_episodes,
            'avg_episode_length': np.mean(episode_lengths)
        }
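
With only 10 evaluation episodes, the measured success rate is a noisy estimate. Reporting an interval alongside it helps avoid over-reading small differences between checkpoints. The sketch below computes an approximate 95% Wilson score interval; wilson_interval is a hypothetical helper and the 7-out-of-10 result is a made-up example.

```python
import math


def wilson_interval(successes, trials, z=1.96):
    """Approximate 95% Wilson score interval for a success rate."""
    if trials == 0:
        return 0.0, 1.0
    p = successes / trials
    denom = 1.0 + z * z / trials
    centre = (p + z * z / (2 * trials)) / denom
    half = z * math.sqrt(
        p * (1 - p) / trials + z * z / (4 * trials * trials)
    ) / denom
    return max(0.0, centre - half), min(1.0, centre + half)


# Made-up result: 7 successes in 10 episodes
low, high = wilson_interval(successes=7, trials=10)
print(f"success rate 0.70, 95% interval [{low:.2f}, {high:.2f}]")
```

For this example the interval spans roughly 0.40 to 0.89, so a checkpoint scoring 6/10 is not distinguishable from one scoring 8/10 without more episodes.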

Best Practices

DO:

- ✓ Start with LoRA for most applications
- ✓ Use learning rate warmup
- ✓ Apply data augmentation for small datasets (<200 demos)
- ✓ Monitor both offline and online metrics
- ✓ Save best checkpoint based on validation performance
- ✓ Use early stopping to prevent overfitting

DON'T:

- ✗ Fine-tune with too high a learning rate (causes forgetting)
- ✗ Train without a validation set
- ✗ Ignore domain shift between pre-training and fine-tuning data
- ✗ Over-train on tiny datasets (< 50 demos)
- ✗ Forget to normalize observations/actions consistently

Example: Complete Fine-Tuning Script

from types import SimpleNamespace

import torch
from transformers import AutoModelForVision2Seq

# VLAFineTuner is the class from "Complete Fine-Tuning Pipeline" above;
# your_robot_data is a placeholder for your own data-loading module
from your_robot_data import load_demonstrations

# 1. Load pre-trained model (OpenVLA ships custom modeling code on the Hub)
base_model = AutoModelForVision2Seq.from_pretrained(
    'openvla/openvla-7b',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

# 2. Setup LoRA: VLAFineTuner applies the adapters itself in setup_lora(),
#    so pass the base model; wrapping it here as well would apply LoRA twice
finetuner = VLAFineTuner(base_model, strategy='lora')

# 3. Prepare data
train_demos = load_demonstrations(num_demos=100, split='train')
val_demos = load_demonstrations(num_demos=20, split='val')

# 4. Fine-tune (finetune() reads config fields as attributes, not dict keys)
config = SimpleNamespace(
    num_epochs=20,
    base_lr=1e-4,
    weight_decay=0.01,  # assumed value (the AdamW default)
    batch_size=8,
    warmup_steps=100,
    grad_clip=1.0,
    use_augmentation=True,
    early_stopping_patience=5
)
model = finetuner.finetune(train_demos, val_demos, config)

# 5. Deploy
model.save_pretrained('./finetuned_model')

Next Steps