
Diffusion Policies for Robotics

Diffusion Policies are an imitation-learning approach that generates robot actions with a conditional denoising diffusion process, enabling multi-modal action distributions and strong performance on complex manipulation tasks.

Paper: "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" (Chi et al., RSS 2023)

Why Diffusion for Robot Learning?

Traditional behavioral cloning fails on multi-modal action distributions:

# Problem: Averaging diverse expert actions
expert_demo_1: grasp_from_top()     # Action: [0, 0, -1, ...]
expert_demo_2: grasp_from_side()    # Action: [1, 0, 0, ...]

# Behavioral Cloning averages:
bc_action = mean([demo_1, demo_2])  # Result: [0.5, 0, -0.5, ...] ✗ Invalid!

Diffusion Policies model the full distribution instead of collapsing to a single mode.

Core Concept

Diffusion models learn to denoise:

graph LR
    A[Pure Noise] -->|Denoise| B[Slightly Less Noisy]
    B -->|Denoise| C[Recognizable Action]
    C -->|Denoise| D[Clean Action]

    style A fill:#ff9999
    style B fill:#ffcc99
    style C fill:#99ccff
    style D fill:#99ff99

Mathematical Foundation

Forward Process (add noise):

\[ q(a_t | a_{t-1}) = \mathcal{N}(a_t; \sqrt{1-\beta_t} a_{t-1}, \beta_t I) \]

Reverse Process (denoise):

\[ p_\theta(a_{t-1} | a_t, o) = \mathcal{N}(a_{t-1}; \mu_\theta(a_t, t, o), \Sigma_\theta(a_t, t, o)) \]

Where:

- \(a_t\): noisy action at diffusion step \(t\)
- \(o\): observation (image, proprioceptive state, etc.)
- \(\beta_t\): noise schedule
- \(\theta\): neural network parameters
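
Because every forward step is Gaussian, the forward process has a closed form: a noisy action at any step \(t\) can be sampled directly from the clean action \(a_0\), and training reduces to predicting the injected noise. This is exactly what the add_noise and compute_loss methods below implement:

\[ q(a_t | a_0) = \mathcal{N}\!\left(a_t; \sqrt{\bar\alpha_t}\, a_0, (1 - \bar\alpha_t) I\right), \qquad \bar\alpha_t = \prod_{s=1}^{t} (1 - \beta_s) \]

\[ \mathcal{L} = \mathbb{E}_{a_0, o, \epsilon, t}\left[ \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\, a_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\; t,\; o\right) \right\|^2 \right]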

Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionPolicy(nn.Module):
    """
    Diffusion Policy for visuomotor control
    Adapted from Chi et al., RSS 2023
    """
    def __init__(self, obs_dim, action_dim, action_horizon=16, diffusion_steps=100):
        super().__init__()

        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.diffusion_steps = diffusion_steps

        # Vision encoder (ResNet or ViT)
        # (ResNetBlocks is a placeholder backbone, assumed to output 512 channels)
        self.vision_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.GroupNorm(8, 64),
            nn.ReLU(),
            ResNetBlocks(num_blocks=3),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 256)
        )

        # Observation encoder (for proprioception)
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128)
        )

        # Noise prediction network (U-Net style)
        self.noise_pred_net = ConditionalUNet1D(
            input_dim=action_dim,
            global_cond_dim=256 + 128,  # vision + obs
            diffusion_step_embed_dim=128,
            down_dims=[256, 512, 1024],
            kernel_size=5,
            n_groups=8
        )

        # Noise schedule (DDPM); registered as buffers so they follow .to(device)
        self.register_buffer('betas', self.make_beta_schedule(diffusion_steps))
        self.register_buffer('alphas', 1 - self.betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))

    def make_beta_schedule(self, num_steps, schedule='cosine'):
        """Create noise schedule"""
        if schedule == 'linear':
            return torch.linspace(0.0001, 0.02, num_steps)
        elif schedule == 'cosine':
            # Improved cosine schedule from Nichol & Dhariwal 2021
            s = 0.008
            steps = num_steps + 1
            x = torch.linspace(0, num_steps, steps)
            alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            return torch.clip(betas, 0.0001, 0.9999)
        else:
            raise ValueError(f"Unknown beta schedule: {schedule}")

    def forward(self, obs, actions=None):
        """
        Training: predict noise given noisy actions
        Inference: iteratively denoise random noise
        """
        # Encode observations
        if 'image' in obs:
            vision_features = self.vision_encoder(obs['image'])
        else:
            vision_features = torch.zeros(obs['state'].shape[0], 256).to(obs['state'].device)

        obs_features = self.obs_encoder(obs['state'])
        cond = torch.cat([vision_features, obs_features], dim=-1)

        if self.training:
            # Training: denoise actions
            return self.compute_loss(cond, actions)
        else:
            # Inference: generate actions
            return self.generate_actions(cond)

    def compute_loss(self, cond, actions):
        """Compute diffusion training loss"""
        batch_size = actions.shape[0]

        # Sample random diffusion timesteps
        t = torch.randint(0, self.diffusion_steps, (batch_size,), device=actions.device)

        # Add noise to actions
        noise = torch.randn_like(actions)
        noisy_actions = self.add_noise(actions, noise, t)

        # Predict the noise
        predicted_noise = self.noise_pred_net(
            noisy_actions,
            timestep=t,
            global_cond=cond
        )

        # MSE loss
        loss = F.mse_loss(predicted_noise, noise)

        return loss

    def add_noise(self, actions, noise, t):
        """Add noise to actions according to schedule"""
        sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod[t])
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod[t])

        # Expand dimensions for broadcasting
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.view(-1, 1, 1)

        noisy_actions = sqrt_alphas_cumprod * actions + sqrt_one_minus_alphas_cumprod * noise

        return noisy_actions

    @torch.no_grad()
    def generate_actions(self, cond, num_samples=1):
        """Generate action sequence through iterative denoising"""
        batch_size = cond.shape[0]

        # Start from random noise
        actions = torch.randn(
            batch_size,
            self.action_horizon,
            self.action_dim,
            device=cond.device
        )

        # Iteratively denoise
        for t in reversed(range(self.diffusion_steps)):
            # Predict noise
            t_batch = torch.full((batch_size,), t, device=cond.device, dtype=torch.long)
            predicted_noise = self.noise_pred_net(
                actions,
                timestep=t_batch,
                global_cond=cond
            )

            # Denoise step
            actions = self.denoise_step(actions, predicted_noise, t)

        return actions

    def denoise_step(self, noisy_actions, predicted_noise, t):
        """Single denoising step (DDPM)"""
        alpha = self.alphas[t]
        alpha_cumprod = self.alphas_cumprod[t]
        beta = self.betas[t]

        # Predict x_0 (clean actions)
        predicted_actions = (
            noisy_actions - torch.sqrt(1 - alpha_cumprod) * predicted_noise
        ) / torch.sqrt(alpha_cumprod)

        # Compute mean for p(x_{t-1} | x_t)
        if t > 0:
            alpha_cumprod_prev = self.alphas_cumprod[t - 1]
        else:
            alpha_cumprod_prev = torch.tensor(1.0)

        predicted_mean = (
            torch.sqrt(alpha_cumprod_prev) * beta / (1 - alpha_cumprod) * predicted_actions +
            torch.sqrt(alpha) * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod) * noisy_actions
        )

        # Add noise (except at last step)
        if t > 0:
            noise = torch.randn_like(noisy_actions)
            variance = beta * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)
            predicted_mean = predicted_mean + torch.sqrt(variance) * noise

        return predicted_mean
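
A quick shape check of the interface above (a minimal usage sketch; obs_dim=14 and action_dim=7 are arbitrary example values, and ResNetBlocks / ConditionalUNet1D are assumed to be defined as sketched elsewhere on this page):

policy = DiffusionPolicy(obs_dim=14, action_dim=7)

obs = {
    'image': torch.randn(2, 3, 224, 224),  # batch of RGB observations
    'state': torch.randn(2, 14),           # proprioceptive state
}

# Training mode: forward returns the scalar noise-prediction loss
policy.train()
loss = policy(obs, actions=torch.randn(2, 16, 7))

# Eval mode: forward iteratively denoises and returns actions of shape (2, 16, 7)
policy.eval()
action_sequence = policy(obs)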

Key Innovations

1. Action Chunking with Receding Horizon

Predict multiple future actions, execute only the first few:

import time

# robot, get_observation, control_frequency, and task_complete() are
# placeholder interfaces for the robot stack
while not task_complete():
    # Predict a 16-step action sequence from the latest observation
    obs = get_observation()
    action_sequence = diffusion_policy.generate_actions(obs)  # Shape: (1, 16, 7)

    # Execute only the first 8 steps at the control frequency,
    # then re-predict for the next window (receding horizon)
    for i in range(8):
        robot.execute(action_sequence[0, i])
        time.sleep(1 / control_frequency)

Benefits:

- Temporal coherence
- Smoother trajectories
- Better long-horizon planning

2. Conditional U-Net Architecture

class ConditionalUNet1D(nn.Module):
    """
    U-Net for 1D action sequences with global conditioning (simplified sketch).
    SinusoidalPosEmb and ResidualBlock1D are helper modules (a sketch of
    SinusoidalPosEmb is given below); the full implementation transposes
    actions to (B, action_dim, horizon) before the Conv1d layers.
    """
    def __init__(self, input_dim, global_cond_dim, diffusion_step_embed_dim,
                 down_dims, kernel_size=5, n_groups=8):
        super().__init__()

        # Time embedding (sinusoidal)
        self.time_embedding = SinusoidalPosEmb(diffusion_step_embed_dim)

        # Global conditioning projection
        self.global_cond_proj = nn.Linear(global_cond_dim, down_dims[0])

        # Downsampling path
        self.down_blocks = nn.ModuleList()
        in_dim = input_dim
        for out_dim in down_dims:
            self.down_blocks.append(
                ResidualBlock1D(in_dim, out_dim, diffusion_step_embed_dim)
            )
            in_dim = out_dim

        # Upsampling path
        self.up_blocks = nn.ModuleList()
        for out_dim in reversed(down_dims[:-1]):
            self.up_blocks.append(
                ResidualBlock1D(in_dim + out_dim, out_dim, diffusion_step_embed_dim)  # Skip connection
            )
            in_dim = out_dim

        # Final layer
        self.final = nn.Conv1d(in_dim, input_dim, 1)

    def forward(self, x, timestep, global_cond):
        # Time embedding
        t_emb = self.time_embedding(timestep)

        # Global condition
        g_emb = self.global_cond_proj(global_cond)

        # Downsampling with skip connections
        skips = []
        for block in self.down_blocks:
            x = block(x, t_emb, g_emb)
            skips.append(x)

        # Upsampling with skip connections
        for block, skip in zip(self.up_blocks, reversed(skips[:-1])):
            x = torch.cat([x, skip], dim=1)
            x = block(x, t_emb, g_emb)

        # Final projection
        return self.final(x)
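
The sinusoidal timestep embedding referenced above is the standard DDPM-style embedding; a minimal sketch is shown here (ResidualBlock1D, a Conv1d residual block with conditioning, is left to the reference implementation):

class SinusoidalPosEmb(nn.Module):
    """Standard sinusoidal embedding of the scalar diffusion timestep."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t: (B,) integer timesteps -> (B, dim) embedding
        half_dim = self.dim // 2
        freqs = torch.exp(
            -torch.log(torch.tensor(10000.0)) * torch.arange(half_dim, device=t.device) / (half_dim - 1)
        )
        args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)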

3. DDPM vs DDIM Sampling

DDPM (Denoising Diffusion Probabilistic Models):

- Stochastic sampling
- Requires all T diffusion steps
- Higher quality but slower

DDIM (Denoising Diffusion Implicit Models):

- Deterministic sampling
- Can skip steps (e.g., use only 10 steps instead of 100)
- Faster inference

@torch.no_grad()
def ddim_sample(self, cond, ddim_steps=10):
    """Fast sampling with DDIM (deterministic, eta = 0)"""
    # Select an evenly spaced subset of timesteps
    timesteps = torch.linspace(0, self.diffusion_steps - 1, ddim_steps).long()

    # Start from noise
    batch_size = cond.shape[0]
    actions = torch.randn(batch_size, self.action_horizon, self.action_dim, device=cond.device)

    # Iterative denoising (only ddim_steps iterations)
    for i in reversed(range(ddim_steps)):
        t = timesteps[i]
        t_prev = timesteps[i - 1] if i > 0 else None

        # Predict noise
        t_batch = torch.full((batch_size,), int(t), device=cond.device, dtype=torch.long)
        predicted_noise = self.noise_pred_net(actions, timestep=t_batch, global_cond=cond)

        # DDIM update (deterministic)
        alpha_t = self.alphas_cumprod[t]
        # At the final step, the "previous" cumulative alpha is 1 (fully denoised)
        alpha_t_prev = self.alphas_cumprod[t_prev] if t_prev is not None else torch.tensor(1.0)

        predicted_x0 = (actions - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)

        # Direction pointing to x_t
        dir_xt = torch.sqrt(1 - alpha_t_prev) * predicted_noise

        # Deterministic step
        actions = torch.sqrt(alpha_t_prev) * predicted_x0 + dir_xt

    return actions
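
In practice this DDIM sampler (assuming it is attached to DiffusionPolicy as a method) is what makes real-time control feasible; with the 100-step training schedule used here, 10 DDIM steps typically preserves action quality while cutting denoising cost by roughly an order of magnitude:

actions = policy.ddim_sample(cond, ddim_steps=10)  # cond: encoded observation features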

Training

from torch.utils.data import DataLoader

def train_diffusion_policy(model, dataset, config):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)

    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

    for epoch in range(config.num_epochs):
        for batch in dataloader:
            obs = {
                'image': batch['images'].cuda(),
                'state': batch['states'].cuda()
            }
            actions = batch['actions'].cuda()  # Shape: (B, action_horizon, action_dim)

            # Forward pass
            loss = model(obs, actions)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
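
The loop above assumes each batch provides image observations, proprioceptive states, and a fixed-length action chunk. A hypothetical dataset sketch with that interface (the DemoDataset name, keys, and shapes are assumptions for illustration):

class DemoDataset(torch.utils.data.Dataset):
    """Hypothetical demonstration dataset yielding fixed-length action chunks."""
    def __init__(self, trajectories, action_horizon=16):
        self.action_horizon = action_horizon
        self.samples = []
        for traj in trajectories:
            # Slide a window of `action_horizon` actions over each trajectory
            for t in range(len(traj['actions']) - action_horizon + 1):
                self.samples.append((traj, t))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        traj, t = self.samples[idx]
        return {
            'images': traj['images'][t],                            # (3, H, W) current camera frame
            'states': traj['states'][t],                            # (obs_dim,) proprioception
            'actions': traj['actions'][t:t + self.action_horizon],  # (action_horizon, action_dim)
        }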

Performance

Benchmark Results

Success rates on challenging manipulation tasks:

Task               Behavioral Cloning   Diffusion Policy   Improvement
Can Sorting        43%                  86%                +100%
Tool Use           31%                  78%                +152%
Bimanual           28%                  72%                +157%
Square Insertion   12%                  48%                +300%

Why So Much Better?

  1. Multi-modal actions: Handles multiple valid solutions
  2. Temporal coherence: Action chunking provides smooth trajectories
  3. Expressiveness: Can model complex action distributions
  4. Robustness: Less sensitive to distribution shift

Advanced: Image-Based Diffusion Policy

class ImageDiffusionPolicy(nn.Module):
    """Diffusion policy with vision transformer"""
    def __init__(self, action_dim=7, action_horizon=16):
        super().__init__()

        # Vision encoder: DINOv2 ViT
        self.vision_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')

        # Temporal aggregation over observations
        self.obs_horizon = 3  # Use last 3 observations
        self.temporal_aggregator = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True),
            num_layers=2
        )

        # Diffusion U-Net
        self.diffusion_net = ConditionalUNet1D(
            input_dim=action_dim,
            global_cond_dim=768,
            diffusion_step_embed_dim=128,
            down_dims=[256, 512, 1024]
        )

        # ... rest of diffusion machinery

    def encode_observations(self, image_sequence):
        """Encode sequence of images"""
        # image_sequence: (B, T, C, H, W) where T = obs_horizon

        B, T = image_sequence.shape[:2]

        # Encode each image
        features = []
        for t in range(T):
            with torch.no_grad():
                feat = self.vision_encoder(image_sequence[:, t])
            features.append(feat)

        features = torch.stack(features, dim=1)  # (B, T, 768)

        # Temporal aggregation
        aggregated = self.temporal_aggregator(features)

        # Use last timestep
        return aggregated[:, -1, :]
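
For example, with an observation history of three 224x224 frames (224 is a multiple of DINOv2's 14-pixel patch size; instantiating the class pulls the encoder weights via torch.hub on first use):

images = torch.randn(2, 3, 3, 224, 224)   # (B, obs_horizon, C, H, W)
policy = ImageDiffusionPolicy()
cond = policy.encode_observations(images)  # (2, 768) conditioning vector for the U-Net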

Practical Tips

1. Hyperparameter Tuning

# Good default hyperparameters
diffusion:
  num_diffusion_steps: 100  # Training
  ddim_steps: 10            # Inference (much faster)
  action_horizon: 16        # Predict 16 steps ahead
  obs_horizon: 3            # Use last 3 observations
  noise_schedule: 'cosine'  # Better than linear

training:
  learning_rate: 1e-4
  weight_decay: 1e-6
  batch_size: 64
  epochs: 1000
  ema_decay: 0.995  # Exponential moving average

2. Action Normalization

Critical for stable training:

# Scale actions to [-1, 1] per dimension using dataset min/max statistics
def normalize_actions(actions, stats):
    return 2 * (actions - stats['min']) / (stats['max'] - stats['min']) - 1

def denormalize_actions(actions_normalized, stats):
    return (actions_normalized + 1) / 2 * (stats['max'] - stats['min']) + stats['min']
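
The statistics are computed once over all demonstration actions; a minimal sketch (demonstrations is a hypothetical list of trajectories, each holding an (N, action_dim) actions tensor):

all_actions = torch.cat([traj['actions'] for traj in demonstrations], dim=0)
stats = {
    'min': all_actions.min(dim=0).values,
    'max': all_actions.max(dim=0).values,
}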

3. Efficient Inference

import copy

# Maintain an exponential-moving-average (EMA) copy of the weights for inference
ema_model = copy.deepcopy(model)

# During training: after each optimizer step, update the EMA weights
# (num_training_steps, obs, actions come from the training loop shown earlier)
for step in range(num_training_steps):
    optimizer.zero_grad()
    loss = model(obs, actions)
    loss.backward()
    optimizer.step()

    # Update EMA (decay = 0.995)
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(0.995).add_(param.data, alpha=0.005)

# Use EMA model for inference
ema_model.eval()
actions = ema_model.generate_actions(obs)

Comparisons

Method              Multi-modal   Sample Efficiency   Inference Speed
BC                  –             ⭐⭐⭐                 ⭐⭐⭐
VAE Policy          ⭐⭐            –                   ⭐⭐⭐
Normalizing Flows   ⭐⭐            –                   ⭐⭐
Diffusion Policy    ⭐⭐⭐           –                   ⭐ (DDPM) / ⭐⭐ (DDIM)

References

  1. Chi et al., "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion", RSS 2023
  2. Ho et al., "Denoising Diffusion Probabilistic Models", NeurIPS 2020
  3. Song et al., "Denoising Diffusion Implicit Models", ICLR 2021

Code

Full implementation: https://github.com/real-stanford/diffusion_policy

Next Steps