
Multi-Task Learning for VLA Models

Multi-task learning enables a single VLA model to perform diverse robotic tasks, improving generalization and sample efficiency.

Why Multi-Task Learning?

Single-task models learn one skill at a time:

Model_1: "Pick red block" → Trained on 1000 demos
Model_2: "Open drawer" → Trained on 1000 demos
Model_3: "Wipe table" → Trained on 1000 demos
Total: 3 models, 3000 demos

Multi-task models learn all skills simultaneously:

Model: Any task → Trained on 3000 diverse demos
Total: 1 model, but better generalization

Benefits:

- ✓ Shared representations: common skills transfer across tasks
- ✓ Better generalization: exposure to diverse scenarios
- ✓ Sample efficiency: each task benefits from the others' data
- ✓ Single deployment: one model serves all tasks
- ✓ Compositional skills: learned primitives can be combined

Architecture for Multi-Task VLA

Task-Conditioned Policy

Language naturally provides task conditioning:

class MultiTaskVLA(nn.Module):
    """VLA model for multiple tasks"""
    def __init__(self, vision_encoder, language_encoder, action_head):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_encoder = language_encoder
        self.action_head = action_head

        # Shared representations
        self.fusion = nn.Linear(512 + 768, 512)

    def forward(self, image, task_instruction, state):
        """
        Single forward pass for any task

        Args:
            image: (B, 3, H, W)
            task_instruction: "pick up red block" / "open drawer" / etc.
            state: (B, state_dim)
        Returns:
            action: (B, action_dim)
        """
        # Extract features
        vision_features = self.vision_encoder(image)
        task_embedding = self.language_encoder(task_instruction)

        # Fuse (task conditioning happens here)
        combined = torch.cat([vision_features, task_embedding], dim=-1)
        features = self.fusion(combined)

        # Predict action
        action = self.action_head(features)

        return action

Key insight: Language instruction specifies the task, so no architectural changes needed!
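The fusion pattern above can be exercised end to end with stub encoders. This is a toy sketch only: the `nn.Flatten`-based vision stub, the `nn.EmbeddingBag` "text encoder", and all dimensions are illustrative stand-ins for real backbones, chosen to match the 512 + 768 fusion in the class above.

```python
import torch
import torch.nn as nn

# Stub encoders standing in for real vision/language backbones
# (dimensions are illustrative, matching the 512 + 768 fusion above)
vision = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
language = nn.EmbeddingBag(100, 768)  # toy "text encoder" over token ids
fusion = nn.Linear(512 + 768, 512)
action_head = nn.Linear(512, 7)

image = torch.randn(4, 3, 32, 32)
tokens = torch.randint(0, 100, (4, 6))  # tokenized instruction per sample

v = vision(image)                        # (4, 512) vision features
t = language(tokens)                     # (4, 768) instruction embedding
a = action_head(fusion(torch.cat([v, t], dim=-1)))
print(a.shape)  # torch.Size([4, 7])
```

Swapping the instruction tokens changes the task without touching the architecture, which is the key insight above.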

Task Embeddings

Alternative: Learn task embeddings:

class TaskEmbeddingVLA(nn.Module):
    """VLA with learned task embeddings"""
    def __init__(self, num_tasks=10, embedding_dim=128):
        super().__init__()

        # Discrete task embeddings
        self.task_embeddings = nn.Embedding(num_tasks, embedding_dim)

        # Task ID to embedding
        self.task_names = [
            "pick_red_block",
            "open_drawer",
            "close_drawer",
            "wipe_table",
            # ... more tasks
        ]

    def forward(self, image, task_id, state):
        """
        Args:
            task_id: integer in [0, num_tasks-1]
        """
        # Get task embedding
        task_emb = self.task_embeddings(task_id)

        # Rest same as before
        features = self.encode(image, task_emb, state)
        action = self.action_head(features)

        return action

Pros: efficient, no language encoder needed.
Cons: not compositional; cannot generalize to new task descriptions.
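The fixed-table limitation is easy to see in isolation. A minimal sketch with a hypothetical four-task set:

```python
import torch
import torch.nn as nn

# Toy illustration of discrete task conditioning (hypothetical task set)
task_names = ["pick_red_block", "open_drawer", "close_drawer", "wipe_table"]
task_embeddings = nn.Embedding(num_embeddings=len(task_names), embedding_dim=128)

task_id = torch.tensor([task_names.index("open_drawer")])
emb = task_embeddings(task_id)
print(emb.shape)  # torch.Size([1, 128])

# The table has exactly len(task_names) rows: a task that was not in the
# training set has no embedding, which is the generalization limit noted above
assert "pull_lever" not in task_names
```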

Training Strategies

1. Uniform Sampling

Sample tasks uniformly during training:

def train_multitask_uniform(model, task_datasets, config):
    """Train with uniform task sampling"""

    # Combine all datasets
    all_data = []
    for task_name, dataset in task_datasets.items():
        for demo in dataset:
            demo['task'] = task_name
            all_data.append(demo)

    # Shuffle
    random.shuffle(all_data)

    dataloader = DataLoader(all_data, batch_size=config.batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    for epoch in range(config.num_epochs):
        for batch in dataloader:
            # Forward pass (task specified by language instruction)
            loss = model.compute_loss(batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model

Issue: Tasks with more data dominate training
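The imbalance is purely a counting effect. With hypothetical dataset sizes where one task has six times the demos of the others, uniform sampling over the merged pool gives:

```python
# Hypothetical dataset sizes: one task has 6x the demos of the others
demo_counts = {"pick_red_block": 3000, "open_drawer": 500, "wipe_table": 500}
total = sum(demo_counts.values())

# Under uniform sampling, a task's share of gradient updates equals
# its share of the merged dataset
probs = {task: n / total for task, n in demo_counts.items()}
print(probs)
# {'pick_red_block': 0.75, 'open_drawer': 0.125, 'wipe_table': 0.125}
```

Three quarters of all gradient steps go to one task, which motivates the balanced sampling below.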

2. Balanced Sampling

Balance samples across tasks:

class BalancedTaskSampler:
    """Sample equally from each task"""
    def __init__(self, task_datasets):
        self.task_datasets = task_datasets
        self.task_names = list(task_datasets.keys())

        # Create iterators
        self.iterators = {
            task: iter(DataLoader(dataset, batch_size=1, shuffle=True))
            for task, dataset in task_datasets.items()
        }

    def sample_batch(self, batch_size):
        """Sample batch with balanced tasks"""
        batch = []
        samples_per_task = batch_size // len(self.task_names)

        for task_name in self.task_names:
            for _ in range(samples_per_task):
                try:
                    sample = next(self.iterators[task_name])
                    batch.append(sample)
                except StopIteration:
                    # Restart iterator
                    self.iterators[task_name] = iter(DataLoader(
                        self.task_datasets[task_name],
                        batch_size=1,
                        shuffle=True
                    ))
                    sample = next(self.iterators[task_name])
                    batch.append(sample)

        return batch

# Training loop
sampler = BalancedTaskSampler(task_datasets)

for step in range(config.max_steps):
    batch = sampler.sample_batch(config.batch_size)
    loss = model.compute_loss(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Benefit: Each task gets equal training signal
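PyTorch's built-in `WeightedRandomSampler` achieves the same balancing without custom iterator bookkeeping, by weighting each sample with the inverse frequency of its task. A sketch with hypothetical per-sample task labels:

```python
import torch
from collections import Counter
from torch.utils.data import WeightedRandomSampler

# Hypothetical per-sample task labels for a merged dataset
task_labels = ["pick"] * 3000 + ["drawer"] * 500 + ["wipe"] * 500

# Weight each sample by the inverse frequency of its task, so every
# task is drawn with equal probability in expectation
counts = Counter(task_labels)
weights = [1.0 / counts[t] for t in task_labels]

sampler = WeightedRandomSampler(weights, num_samples=len(task_labels),
                                replacement=True)
drawn = Counter(task_labels[i] for i in sampler)
print({t: round(n / len(task_labels), 2) for t, n in drawn.items()})
# roughly 1/3 per task despite the 6:1:1 raw counts
```

In practice the sampler is passed to `DataLoader(dataset, sampler=sampler, ...)` instead of iterating it directly.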

3. Curriculum Learning

Start with easy tasks, gradually increase difficulty:

class CurriculumScheduler:
    """Schedule task difficulty over training"""
    def __init__(self, task_difficulties):
        """
        Args:
            task_difficulties: dict mapping task -> difficulty score [0-1]
        """
        self.task_difficulties = task_difficulties
        self.current_step = 0

    def get_task_weights(self, total_steps):
        """Get sampling weights for each task"""
        # Progress through curriculum
        progress = self.current_step / total_steps

        weights = {}
        for task, difficulty in self.task_difficulties.items():
            # Easy tasks (low difficulty) early, hard tasks later
            if difficulty <= progress:
                weights[task] = 1.0  # Active
            else:
                weights[task] = 0.1  # Reduced probability

        return weights

# Example
task_difficulties = {
    'reach_target': 0.1,  # Easy
    'pick_object': 0.3,
    'place_object': 0.5,
    'stack_blocks': 0.7,
    'open_drawer': 0.8,
    'complex_assembly': 0.9  # Hard
}

scheduler = CurriculumScheduler(task_difficulties)

for step in range(total_steps):
    # Get current task distribution
    task_weights = scheduler.get_task_weights(total_steps)

    # Sample task
    task = np.random.choice(
        list(task_weights.keys()),
        p=np.array(list(task_weights.values())) / sum(task_weights.values())
    )

    # Train on this task
    batch = task_datasets[task].sample()
    loss = model.compute_loss(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    scheduler.current_step += 1

4. Task-Specific Losses

Different losses for different tasks:

class MultiTaskLoss(nn.Module):
    """Combined loss for multi-task learning"""
    def __init__(self, task_weights=None):
        super().__init__()
        self.task_weights = task_weights or {}

    def forward(self, model, batch):
        """Compute weighted multi-task loss"""
        total_loss = 0
        task_losses = {}

        for task_name in set(batch['task_names']):
            # Select the samples belonging to this task (boolean mask)
            task_mask = torch.tensor(
                [t == task_name for t in batch['task_names']]
            )
            task_batch = {
                k: v[task_mask] for k, v in batch.items()
                if torch.is_tensor(v)
            }

            # Task-specific loss
            if task_name == 'pick_place':
                # Binary cross-entropy for grasp success
                loss = self.pick_place_loss(model, task_batch)
            elif task_name == 'navigation':
                # Distance-based loss
                loss = self.navigation_loss(model, task_batch)
            else:
                # Default: action prediction
                loss = F.mse_loss(
                    model.predict(task_batch),
                    task_batch['actions']
                )

            # Weight and accumulate
            weight = self.task_weights.get(task_name, 1.0)
            total_loss += weight * loss
            task_losses[task_name] = loss.item()

        return total_loss, task_losses

    def pick_place_loss(self, model, batch):
        """Loss for pick-and-place"""
        # Action prediction
        action_loss = F.mse_loss(
            model.predict(batch),
            batch['actions']
        )

        # Auxiliary: grasp success prediction
        grasp_logits = model.predict_grasp_success(batch)
        grasp_loss = F.binary_cross_entropy_with_logits(
            grasp_logits,
            batch['grasp_success']
        )

        return action_loss + 0.1 * grasp_loss

    def navigation_loss(self, model, batch):
        """Loss for navigation"""
        # Predict waypoints
        waypoints_pred = model.predict_waypoints(batch)
        waypoints_true = batch['waypoints']

        # Distance-weighted loss (later waypoints matter less);
        # arange must be float: torch.exp is not defined for integer tensors
        num_waypoints = waypoints_true.shape[0]
        weights = torch.exp(
            -torch.arange(num_waypoints, dtype=torch.float32,
                          device=waypoints_true.device) * 0.1
        )
        loss = (weights * F.mse_loss(
            waypoints_pred,
            waypoints_true,
            reduction='none'
        ).mean(dim=-1)).mean()

        return loss
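An alternative to hand-tuning `task_weights` is to learn the weights from task uncertainty, as in Kendall et al., "Multi-Task Learning Using Uncertainty to Weigh Losses" (the class name and per-task loss values below are illustrative). A minimal sketch:

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Weigh per-task losses by learned homoscedastic uncertainty:
    total = sum_i( exp(-s_i) * L_i + s_i ), where s_i = log(sigma_i^2)."""
    def __init__(self, task_names):
        super().__init__()
        # One learnable log-variance per task, initialized to 0
        self.log_vars = nn.ParameterDict(
            {t: nn.Parameter(torch.zeros(())) for t in task_names}
        )

    def forward(self, task_losses):
        total = 0.0
        for task, loss in task_losses.items():
            s = self.log_vars[task]
            # exp(-s) down-weights noisy tasks; +s regularizes s itself
            total = total + torch.exp(-s) * loss + s
        return total

crit = UncertaintyWeightedLoss(["pick_place", "navigation"])
losses = {"pick_place": torch.tensor(0.8), "navigation": torch.tensor(2.5)}
total = crit(losses)
print(total.item())  # 3.3 at init (all log-variances start at 0)
```

The `log_vars` parameters are optimized jointly with the model, so the relative task weights adapt during training instead of being fixed up front.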

Task Composition

Learn primitives that can be combined:

class HierarchicalMultiTaskVLA(nn.Module):
    """Hierarchical model with task primitives"""
    def __init__(self):
        super().__init__()

        # Low-level controllers (primitives)
        self.primitives = nn.ModuleDict({
            'reach': PrimitiveController(),
            'grasp': PrimitiveController(),
            'move': PrimitiveController(),
            'release': PrimitiveController(),
        })

        # High-level policy (selects primitives)
        self.high_level_policy = HighLevelPolicy()

    def forward(self, observation, task_instruction):
        """
        Two-level policy

        1. High-level: which primitive to execute
        2. Low-level: how to execute primitive
        """
        # Parse task into sequence of primitives
        # E.g., "pick and place" → [reach, grasp, move, release]
        primitive_sequence = self.parse_task(task_instruction)

        # Execute primitives sequentially
        actions = []
        for primitive_name in primitive_sequence:
            # Get primitive controller
            primitive = self.primitives[primitive_name]

            # Execute primitive
            action = primitive(observation)
            actions.append(action)

            # Update state (in simulation or real robot)
            observation = self.step(action)

        return actions

    def parse_task(self, task_instruction):
        """
        Parse language into primitive sequence

        Uses high-level policy (can be learned or rule-based)
        """
        # Option 1: Rule-based
        if "pick and place" in task_instruction:
            return ['reach', 'grasp', 'move', 'release']
        elif "open drawer" in task_instruction:
            return ['reach', 'grasp', 'pull']

        # Option 2: Learned (seq2seq model)
        return self.high_level_policy.predict_primitives(task_instruction)
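The rule-based branch of `parse_task` is easy to test in isolation. A standalone sketch, with a placeholder return value standing in for the learned high-level policy fallback:

```python
# Minimal standalone version of the rule-based branch of parse_task
def parse_task(instruction):
    rules = {
        "pick and place": ["reach", "grasp", "move", "release"],
        "open drawer": ["reach", "grasp", "pull"],
    }
    for phrase, primitives in rules.items():
        if phrase in instruction:
            return primitives
    return ["reach"]  # placeholder for the learned seq2seq fallback

print(parse_task("pick and place the red block"))
# ['reach', 'grasp', 'move', 'release']
print(parse_task("open drawer slowly"))
# ['reach', 'grasp', 'pull']
```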

Multi-Task Evaluation

Evaluate on all tasks simultaneously:

class MultiTaskEvaluator:
    """Evaluate multi-task model"""
    def __init__(self, model, task_envs):
        self.model = model
        self.task_envs = task_envs  # Dict: task_name -> env

    def evaluate(self, num_episodes_per_task=10):
        """Evaluate on all tasks"""
        results = {}

        for task_name, env in self.task_envs.items():
            task_results = self.evaluate_task(task_name, env, num_episodes_per_task)
            results[task_name] = task_results

        # Aggregate metrics
        avg_success_rate = np.mean([r['success_rate'] for r in results.values()])
        results['average_success_rate'] = avg_success_rate

        # Print report
        print("="*60)
        print("MULTI-TASK EVALUATION RESULTS")
        print("="*60)
        for task_name, task_results in results.items():
            if task_name != 'average_success_rate':
                print(f"{task_name}:")
                print(f"  Success Rate: {task_results['success_rate']*100:.1f}%")
                print(f"  Avg Steps: {task_results['avg_steps']:.1f}")
        print(f"\nOverall Success Rate: {avg_success_rate*100:.1f}%")
        print("="*60)

        return results

    def evaluate_task(self, task_name, env, num_episodes):
        """Evaluate on single task"""
        successes = 0
        episode_lengths = []

        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            steps = 0

            # Task instruction
            instruction = env.get_instruction()

            while not done and steps < 500:
                # Predict action
                with torch.no_grad():
                    action = self.model.predict(obs, instruction)

                # Execute
                obs, reward, done, info = env.step(action)
                steps += 1

            success = info.get('success', False)
            successes += int(success)
            episode_lengths.append(steps)

        return {
            'success_rate': successes / num_episodes,
            'avg_steps': np.mean(episode_lengths)
        }

Handling Task Interference

Tasks can interfere with each other during training:

1. Gradient Surgery

Prevent conflicting gradients:

def gradient_surgery(losses, shared_params):
    """
    PCGrad: Project conflicting gradients

    Paper: "Gradient Surgery for Multi-Task Learning"
    """
    # Compute gradients for each task
    grads = {}
    for task_name, loss in losses.items():
        task_grads = torch.autograd.grad(
            loss, shared_params,
            retain_graph=True,
            create_graph=False
        )
        grads[task_name] = task_grads

    # Project conflicting gradients (PCGrad)
    for task1 in grads:
        for task2, grad2 in grads.items():
            if task1 == task2:
                continue
            grad1 = grads[task1]  # may already have been projected

            # Check if gradients conflict (negative dot product)
            dot_product = sum(
                (g1 * g2).sum()
                for g1, g2 in zip(grad1, grad2)
            )

            if dot_product < 0:
                # Project grad1 onto the normal plane of grad2; the
                # squared norm is over the full gradient, not per-tensor
                grad2_norm_sq = sum((g ** 2).sum() for g in grad2) + 1e-8
                grads[task1] = [
                    g1 - (dot_product / grad2_norm_sq) * g2
                    for g1, g2 in zip(grad1, grad2)
                ]

    # Average projected gradients
    avg_grads = [
        sum(grads[task][i] for task in grads.keys()) / len(grads)
        for i in range(len(shared_params))
    ]

    # Apply averaged gradients
    for param, grad in zip(shared_params, avg_grads):
        param.grad = grad
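The projection step is easiest to see on a single 2-D shared parameter with two hand-picked conflicting gradients:

```python
import torch

# Two conflicting task gradients on one 2-D shared parameter
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])

dot = torch.dot(g1, g2)            # -1.0 -> the gradients conflict
assert dot < 0

# Project g1 onto the normal plane of g2 (the PCGrad update)
g1_proj = g1 - (dot / (g2.norm() ** 2)) * g2

print(g1_proj)                     # tensor([0.5000, 0.5000])
print(torch.dot(g1_proj, g2))      # tensor(0.) -- conflict removed
```

After projection, following `g1_proj` no longer moves against task 2's gradient, at the cost of a smaller step for task 1.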

2. Task-Specific Parameters

Give each task some dedicated parameters:

class MultiHeadVLA(nn.Module):
    """VLA with task-specific heads"""
    def __init__(self, num_tasks):
        super().__init__()

        # Shared encoder
        self.shared_encoder = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )

        # Task-specific heads
        self.task_heads = nn.ModuleList([
            nn.Linear(256, 7)  # 7-DoF action
            for _ in range(num_tasks)
        ])

    def forward(self, observation, task_id):
        # Shared features
        features = self.shared_encoder(observation)

        # Task-specific prediction
        action = self.task_heads[task_id](features)

        return action

Advanced: Meta-Learning for Multi-Task

Learn to quickly adapt to new tasks:

class MAMLMultiTask:
    """Model-Agnostic Meta-Learning for multi-task VLA"""
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)

    def meta_train(self, task_datasets, num_iterations=1000):
        """Meta-training loop"""
        for iteration in range(num_iterations):
            # Sample batch of tasks
            task_batch = random.sample(list(task_datasets.keys()), k=4)

            meta_loss = 0.0
            self.meta_optimizer.zero_grad()

            for task_name in task_batch:
                # Clone model (the inner loop updates the copy in place,
                # so this implements first-order MAML)
                task_model = copy.deepcopy(self.model)

                # Inner loop: adapt to task with few examples
                support_data = task_datasets[task_name].sample(k=5)  # 5-shot

                for _ in range(5):  # 5 gradient steps
                    loss = task_model.compute_loss(support_data)

                    # Inner update
                    grads = torch.autograd.grad(loss, task_model.parameters())
                    for param, grad in zip(task_model.parameters(), grads):
                        param.data -= self.inner_lr * grad

                # Query loss after adaptation
                query_data = task_datasets[task_name].sample(k=10)
                query_loss = task_model.compute_loss(query_data)
                meta_loss += query_loss.item()

                # deepcopy severs the autograd graph back to self.model, so
                # calling backward() on query_loss would never reach the
                # meta-parameters; accumulate the adapted model's query
                # gradients onto them explicitly (first-order approximation)
                grads = torch.autograd.grad(query_loss, task_model.parameters())
                for meta_param, grad in zip(self.model.parameters(), grads):
                    meta_param.grad = (grad if meta_param.grad is None
                                       else meta_param.grad + grad)

            # Outer update: improve the initial parameters
            self.meta_optimizer.step()

            if iteration % 100 == 0:
                print(f"Meta-iteration {iteration}, Meta-loss: {meta_loss:.4f}")

        return self.model

Best Practices

DO:

- ✓ Use language for task conditioning (most flexible)
- ✓ Balance task sampling during training
- ✓ Monitor per-task performance separately
- ✓ Start with curriculum learning for complex task sets
- ✓ Share as much as possible (encoder, fusion layers)
- ✓ Use task-specific losses when appropriate

DON'T:

- ✗ Train on an imbalanced task distribution
- ✗ Ignore task interference (use gradient surgery if needed)
- ✗ Use separate models when one multi-task model would work
- ✗ Forget to evaluate on task composition
- ✗ Over-parameterize task-specific components

Example: Complete Multi-Task Training

# Define tasks
task_datasets = {
    'pick_red_block': load_dataset('pick_red'),
    'pick_blue_block': load_dataset('pick_blue'),
    'open_drawer': load_dataset('drawer'),
    'close_drawer': load_dataset('drawer_close'),
    'wipe_table': load_dataset('wipe'),
}

# Create model
model = MultiTaskVLA(
    vision_encoder=CLIPVisionEncoder(),
    language_encoder=CLIPTextEncoder(),
    action_head=ActionHead(action_dim=7)
)

# Balanced sampler
sampler = BalancedTaskSampler(task_datasets)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10000):
    # Sample balanced batch
    batch = sampler.sample_batch(batch_size=32)

    # Compute loss
    loss = model.compute_loss(batch)

    # Update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate periodically
    if step % 1000 == 0:
        evaluator = MultiTaskEvaluator(model, task_envs)
        results = evaluator.evaluate()

# Save model
torch.save(model.state_dict(), 'multitask_vla.pt')

Next Steps