
Model-Based Reinforcement Learning

Model-based RL learns a model of the environment to dramatically improve sample efficiency, which is critical for real-world robotics.

Overview

Model-based RL learns a dynamics model \(p(s_{t+1}|s_t, a_t)\) and optionally a reward model \(r(s_t, a_t)\), then uses this model for:

  1. Planning: Search for actions using the learned model
  2. Data augmentation: Generate synthetic rollouts (see the sketch after this list)
  3. Policy learning: Train policy on model predictions
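
A minimal sketch of item 2 (data augmentation with a learned model), assuming hypothetical, already-trained dynamics_model, reward_model, and policy callables:

import torch

def generate_synthetic_rollouts(dynamics_model, reward_model, policy, start_states, horizon=5):
    """Roll the learned model forward from real start states to create synthetic transitions.

    All three callables are illustrative placeholders:
    dynamics_model(s, a) -> s', reward_model(s, a) -> r, policy(s) -> a.
    """
    synthetic_transitions = []
    states = start_states
    with torch.no_grad():
        for _ in range(horizon):
            actions = policy(states)                       # act with the current policy
            next_states = dynamics_model(states, actions)  # model-predicted next states
            rewards = reward_model(states, actions)        # model-predicted rewards
            synthetic_transitions.append((states, actions, rewards, next_states))
            states = next_states                           # continue the imagined rollout
    return synthetic_transitions

Synthetic transitions generated this way are typically mixed with real transitions in the replay buffer (as in MBPO), and the rollout horizon is kept short so that model errors do not compound.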

Key advantages:

  • 10-100x better sample efficiency than model-free methods
  • Can plan ahead using the learned model
  • Enables sim-to-real transfer
  • Supports offline learning

Key challenges:

  • Model errors compound over long horizons
  • Complex environments are difficult to model
  • Computational cost of planning

State-of-the-art algorithms:

  • TD-MPC (Temporal Difference Model Predictive Control)
  • Dreamer v3 (world models with imagination)
  • MuZero (model-based planning without an explicit observation model)
  • PlaNet (deep planning networks)

TD-MPC: Temporal Difference Model Predictive Control

TD-MPC combines model learning with online MPC planning for state-of-the-art performance.

Paper: Hansen et al., "Temporal Difference Learning for Model Predictive Control", ICML 2022

Mathematical Foundation

TD-MPC learns three components:

  1. Latent dynamics model: \(z_{t+1} = h(z_t, a_t)\)
  2. Reward predictor: \(\hat{r}_t = r(z_t, a_t)\)
  3. Q-function: \(Q(z_t, a_t)\) for value estimation

Planning objective (MPC at inference):

\[ a_{t:t+H}^{*} = \arg\max_{a_{t:t+H}} \; \sum_{i=0}^{H-1} \gamma^i \hat{r}_{t+i} + \gamma^{H} Q(z_{t+H}, a_{t+H}) \]

Where latent states are rolled out with the learned dynamics \(h\), rewards come from the learned reward model, the terminal term bootstraps from the Q-function, and only the first action of the optimized sequence is executed before replanning.

Learning objective (TD learning):

\[ \mathcal{L} = \mathbb{E} \left[ (Q(z_t, a_t) - (r_t + \gamma Q(z_{t+1}, \pi(z_{t+1}))))^2 \right] \]

Complete Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class TDMPC:
    """
    Temporal Difference Model Predictive Control

    Reference: Hansen et al., ICML 2022
    """
    def __init__(
        self,
        obs_dim,
        action_dim,
        latent_dim=256,
        hidden_dim=512,
        horizon=5,
        num_samples=512,
        num_elites=64,
        num_iterations=6,
        lr=1e-3,
        gamma=0.99,
        tau=0.01,
        buffer_size=1_000_000
    ):
        self.action_dim = action_dim
        self.latent_dim = latent_dim
        self.horizon = horizon
        self.num_samples = num_samples
        self.num_elites = num_elites
        self.num_iterations = num_iterations
        self.gamma = gamma
        self.tau = tau

        # Encoder: observations -> latent states
        self.encoder = Encoder(obs_dim, latent_dim, hidden_dim)

        # Dynamics model: (latent, action) -> next latent
        self.dynamics = DynamicsModel(latent_dim, action_dim, hidden_dim)

        # Reward predictor
        self.reward = RewardModel(latent_dim, action_dim, hidden_dim)

        # Q-function (twin critics)
        self.q1 = QNetwork(latent_dim, action_dim, hidden_dim)
        self.q2 = QNetwork(latent_dim, action_dim, hidden_dim)

        # Target Q-networks
        self.q1_target = QNetwork(latent_dim, action_dim, hidden_dim)
        self.q2_target = QNetwork(latent_dim, action_dim, hidden_dim)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        # Single optimizer for all networks
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.dynamics.parameters()) +
            list(self.reward.parameters()) +
            list(self.q1.parameters()) +
            list(self.q2.parameters()),
            lr=lr
        )

        # Replay buffer
        self.replay_buffer = ReplayBuffer(obs_dim, action_dim, buffer_size)

    def plan(self, obs, deterministic=False):
        """
        Model Predictive Control planning using CEM

        Uses Cross-Entropy Method (CEM) to optimize action sequence
        """
        with torch.no_grad():
            # Encode observation to latent state
            z = self.encoder(torch.FloatTensor(obs).unsqueeze(0))

            # Initialize action distribution
            mean = torch.zeros(self.horizon, self.action_dim)
            std = torch.ones(self.horizon, self.action_dim)

            # CEM iterations
            for iteration in range(self.num_iterations):
                # Sample action sequences
                actions = mean.unsqueeze(0) + std.unsqueeze(0) * torch.randn(
                    self.num_samples, self.horizon, self.action_dim
                )
                actions = torch.clamp(actions, -1, 1)

                # Evaluate action sequences
                values = self._evaluate_sequences(z.repeat(self.num_samples, 1), actions)

                # Select elite samples
                elite_idxs = torch.topk(values, self.num_elites, dim=0).indices
                elite_actions = actions[elite_idxs]

                # Update distribution
                mean = elite_actions.mean(dim=0)
                std = elite_actions.std(dim=0)

            # Return first action of best sequence
            # Execute the first action of the final elite mean
            return mean[0].cpu().numpy()

    def _evaluate_sequences(self, z, action_sequences):
        """Evaluate action sequences using learned model and Q-function"""
        total_value = torch.zeros(action_sequences.shape[0])

        for t in range(self.horizon):
            actions = action_sequences[:, t]

            # Predict reward
            rewards = self.reward(z, actions).squeeze(-1)

            # Predict next latent state
            z = self.dynamics(z, actions)

            # Accumulate discounted reward
            total_value += (self.gamma ** t) * rewards

        # Add terminal Q-value (no action, use zeros)
        terminal_actions = torch.zeros(z.shape[0], self.action_dim)
        terminal_q = torch.min(
            self.q1(z, terminal_actions),
            self.q2(z, terminal_actions)
        ).squeeze(-1)

        total_value += (self.gamma ** self.horizon) * terminal_q

        return total_value

    def update(self, batch_size=256):
        """Update all networks using TD learning"""
        if len(self.replay_buffer) < batch_size:
            return {}

        # Sample batch
        obs, actions, rewards, next_obs, dones = self.replay_buffer.sample(batch_size)

        # Encode to latent space
        z = self.encoder(obs)
        next_z = self.encoder(next_obs)

        # ===== Dynamics Model Loss =====

        # Predict next latent state
        pred_next_z = self.dynamics(z, actions)

        # Dynamics loss (predict next encoding)
        dynamics_loss = F.mse_loss(pred_next_z, next_z.detach())

        # ===== Reward Model Loss =====

        pred_rewards = self.reward(z, actions)
        reward_loss = F.mse_loss(pred_rewards, rewards)

        # ===== Q-Function Loss (TD learning) =====

        # Current Q-values
        q1_pred = self.q1(z, actions)
        q2_pred = self.q2(z, actions)

        with torch.no_grad():
            # Use planning to select next action
            # (Simplified: use random action for efficiency)
            next_actions = torch.clamp(torch.randn_like(actions) * 0.3, -1, 1)

            # Compute target Q-value
            target_q1 = self.q1_target(next_z, next_actions)
            target_q2 = self.q2_target(next_z, next_actions)
            target_q = torch.min(target_q1, target_q2)

            target = rewards + (1 - dones) * self.gamma * target_q

        # Q-function losses
        q1_loss = F.mse_loss(q1_pred, target)
        q2_loss = F.mse_loss(q2_pred, target)

        # ===== Consistency Loss =====
        # Ensure Q-function is consistent with reward predictions

        with torch.no_grad():
            # Multi-step rollout
            z_rollout = z
            predicted_return = torch.zeros_like(rewards)

            for h in range(min(3, self.horizon)):  # Short horizon for efficiency
                # Predict reward and next state
                # (simplified: the same batch action is reused at every rollout step)
                r_pred = self.reward(z_rollout, actions)
                z_rollout = self.dynamics(z_rollout, actions)

                predicted_return += (self.gamma ** h) * r_pred

        consistency_loss = F.mse_loss(q1_pred, predicted_return.detach())

        # ===== Total Loss =====

        total_loss = (
            dynamics_loss +
            reward_loss +
            q1_loss +
            q2_loss +
            0.1 * consistency_loss  # Weaker weight for consistency
        )

        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.encoder.parameters()) +
            list(self.dynamics.parameters()) +
            list(self.reward.parameters()) +
            list(self.q1.parameters()) +
            list(self.q2.parameters()),
            max_norm=10.0
        )
        self.optimizer.step()

        # Update target networks
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)

        return {
            'dynamics_loss': dynamics_loss.item(),
            'reward_loss': reward_loss.item(),
            'q1_loss': q1_loss.item(),
            'q2_loss': q2_loss.item(),
            'consistency_loss': consistency_loss.item(),
            'total_loss': total_loss.item()
        }

    def _soft_update(self, source, target):
        """Soft update target network"""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )


class Encoder(nn.Module):
    """Encode observations to latent states"""
    def __init__(self, obs_dim, latent_dim, hidden_dim=512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, obs):
        return self.net(obs)


class DynamicsModel(nn.Module):
    """Latent dynamics: z_t+1 = f(z_t, a_t)"""
    def __init__(self, latent_dim, action_dim, hidden_dim=512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, z, action):
        x = torch.cat([z, action], dim=-1)
        return self.net(x)


class RewardModel(nn.Module):
    """Predict reward: r = r(z_t, a_t)"""
    def __init__(self, latent_dim, action_dim, hidden_dim=512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, z, action):
        x = torch.cat([z, action], dim=-1)
        return self.net(x)


class QNetwork(nn.Module):
    """Q-function: Q(z_t, a_t)"""
    def __init__(self, latent_dim, action_dim, hidden_dim=512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, z, action):
        x = torch.cat([z, action], dim=-1)
        return self.net(x)


class ReplayBuffer:
    """Standard replay buffer"""
    def __init__(self, obs_dim, action_dim, max_size=1_000_000):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0

        self.obs = np.zeros((max_size, obs_dim))
        self.actions = np.zeros((max_size, action_dim))
        self.rewards = np.zeros((max_size, 1))
        self.next_obs = np.zeros((max_size, obs_dim))
        self.dones = np.zeros((max_size, 1))

    def add(self, obs, action, reward, next_obs, done):
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_obs[self.ptr] = next_obs
        self.dones[self.ptr] = done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        indices = np.random.randint(0, self.size, size=batch_size)

        return (
            torch.FloatTensor(self.obs[indices]),
            torch.FloatTensor(self.actions[indices]),
            torch.FloatTensor(self.rewards[indices]),
            torch.FloatTensor(self.next_obs[indices]),
            torch.FloatTensor(self.dones[indices])
        )

    def __len__(self):
        return self.size

Hyperparameter Tuning

| Parameter      | Typical Range | Default | Notes                          |
|----------------|---------------|---------|--------------------------------|
| horizon        | 3-10          | 5       | Longer = more planning, slower |
| num_samples    | 256-1024      | 512     | CEM samples per iteration      |
| num_elites     | 32-128        | 64      | Top samples to fit             |
| num_iterations | 4-10          | 6       | CEM refinement iterations      |
| latent_dim     | 128-512       | 256     | Latent state size              |
| lr             | 1e-4 to 3e-3  | 1e-3    | Single LR for all networks     |

Tuning tips:

# For fast control loops (manipulation)
tdmpc_manipulation = TDMPC(
    horizon=3,  # Short horizon
    num_samples=256,  # Fewer samples for speed
    num_iterations=4,
    latent_dim=128
)

# For complex planning (locomotion)
tdmpc_locomotion = TDMPC(
    horizon=10,  # Long horizon
    num_samples=1024,  # More samples
    num_iterations=8,
    latent_dim=512
)

Dreamer v3: World Models

Dreamer learns a world model in latent space and uses it to train policies via imagination.

Papers:

  • Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination", ICLR 2020
  • Hafner et al., "Mastering Diverse Domains through World Models", arXiv 2023 (v3)

Architecture

World model components:

  1. Encoder: \(e: o_t \rightarrow z_t\)
  2. Recurrent state: \(h_t = f(h_{t-1}, z_t, a_{t-1})\) (RSSM)
  3. Decoder: \(\hat{o}_t = d(h_t, z_t)\)
  4. Reward predictor: \(\hat{r}_t = r(h_t, z_t)\)
  5. Continue predictor: \(\hat{c}_t = c(h_t, z_t)\) (not done)

Actor-critic:

  • Actor: \(\pi(a|h_t, z_t)\)
  • Critic: \(V(h_t, z_t)\)
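
A minimal sketch of the recurrent state update in the spirit of the RSSM, using simple Gaussian latents rather than Dreamer v3's discrete latents; all module names and sizes here are illustrative, not the official implementation:

import torch
import torch.nn as nn

class SimpleRSSM(nn.Module):
    """Illustrative recurrent state-space model: h_t = f(h_{t-1}, z_{t-1}, a_{t-1})."""
    def __init__(self, obs_dim, action_dim, latent_dim=32, hidden_dim=256):
        super().__init__()
        self.gru = nn.GRUCell(latent_dim + action_dim, hidden_dim)             # deterministic path
        self.prior_net = nn.Linear(hidden_dim, 2 * latent_dim)                 # p(z_t | h_t)
        self.posterior_net = nn.Linear(hidden_dim + obs_dim, 2 * latent_dim)   # q(z_t | h_t, o_t)

    def step(self, h, z, action, obs=None):
        """One latent step: use the posterior when an observation is available, else the prior."""
        h = self.gru(torch.cat([z, action], dim=-1), h)
        if obs is not None:
            mean, log_std = self.posterior_net(torch.cat([h, obs], dim=-1)).chunk(2, dim=-1)
        else:
            mean, log_std = self.prior_net(h).chunk(2, dim=-1)
        z = mean + log_std.exp() * torch.randn_like(mean)                       # reparameterized sample
        return h, z

During imagination only the prior is used (no observations are available), which is exactly what lets the policy train on model-generated rollouts.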

Key Concepts

class Dreamer:
    """
    Dreamer v3 - World Models for RL

    Reference: Hafner et al., 2023
    """

    def imagine_trajectories(self, start_state, horizon=15):
        """
        Imagine trajectories using world model

        This is the key to Dreamer: train policy on imagined data
        """
        # Start from real states
        h, z = start_state

        imagined_trajectory = []

        for t in range(horizon):
            # Sample action from policy
            action = self.actor(h, z).sample()

            # Predict next state using world model
            h_next, z_next = self.world_model.predict(h, z, action)

            # Predict reward and continue
            reward = self.reward_model(h_next, z_next)
            continue_prob = self.continue_model(h_next, z_next)

            imagined_trajectory.append({
                'state': (h, z),
                'action': action,
                'reward': reward,
                'continue': continue_prob,
                'next_state': (h_next, z_next)
            })

            h, z = h_next, z_next

        return imagined_trajectory

    def update_actor_critic(self, imagined_trajectories):
        """Train actor and critic on imagined rollouts"""

        # Compute lambda returns (TD(λ)) over the imagined horizon
        returns = self.compute_lambda_returns(imagined_trajectories)

        # Gather imagined states and the log-probs of the imagined actions
        states = [step['state'] for step in imagined_trajectories]
        values = torch.stack([self.critic(*s) for s in states])
        log_probs = torch.stack([
            self.actor(*step['state']).log_prob(step['action'])
            for step in imagined_trajectories
        ])

        # Update critic to predict the lambda returns
        critic_loss = F.mse_loss(values, returns.detach())

        # Update actor to maximize returns (REINFORCE-style gradient on imagined rollouts)
        advantages = (returns - values).detach()
        actor_loss = -(advantages * log_probs).mean()

        return actor_loss, critic_loss
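
compute_lambda_returns is referenced above but not shown; a minimal sketch of the TD(λ) recursion it reduces to, assuming per-step lists of rewards, values, and continue probabilities have already been extracted from the imagined trajectory (the γ and λ defaults are illustrative):

def compute_lambda_returns(rewards, values, continues, gamma=0.997, lam=0.95):
    """TD(lambda) returns computed backwards over an imagined rollout.

    rewards[t], continues[t]: predicted reward / continue probability at step t (length H).
    values[t]: critic value for the state at step t (length H + 1, so the final
    entry is the bootstrap value of the last imagined state).
    """
    horizon = len(rewards)
    returns = [None] * horizon
    next_return = values[horizon]  # bootstrap from the final imagined state
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap with the longer lambda return
        bootstrap = (1 - lam) * values[t + 1] + lam * next_return
        returns[t] = rewards[t] + gamma * continues[t] * bootstrap
        next_return = returns[t]
    return returns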

Advantages:

  • Learns from imagined experience (10-100x sample efficiency)
  • Handles image observations naturally
  • Single algorithm for vision- and state-based tasks

When to use Dreamer:

  • ✓ Image-based observations
  • ✓ Complex, high-dimensional environments
  • ✓ Long-horizon tasks
  • ✓ Offline RL
  • ✗ Real-time planning (slower than TD-MPC)

MuZero: Model-Based Planning

MuZero combines learned model with tree search (like AlphaGo) for planning.

Paper: Schrittwieser et al., "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", Nature 2020

Key Innovation

MuZero doesn't learn an explicit model of observations. Instead:

  1. Representation: \(s^0 = h(o_1, ..., o_t)\)
  2. Dynamics: \(s^{k+1}, r^k = g(s^k, a^k)\) (in latent space)
  3. Prediction: \(p^k, v^k = f(s^k)\) (policy and value)

Planning via MCTS (Monte Carlo Tree Search):

  • Use the learned dynamics \(g\) to simulate trajectories
  • Use the learned value \(v\) to evaluate states
  • Use the learned policy \(p\) as a prior for search

Simplified Implementation

class MuZero(nn.Module):
    """MuZero - Model-Based RL with Tree Search (simplified sketch)"""

    def __init__(self, obs_dim, action_dim, latent_dim=256, gamma=0.99):
        super().__init__()
        self.gamma = gamma  # discount used in MCTS backups

        # Representation: observations -> latent state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

        # Dynamics: (latent, action) -> next latent, reward
        self.dynamics = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim + 1)  # latent + reward
        )

        # Prediction: latent -> policy, value
        self.prediction_policy = nn.Linear(latent_dim, action_dim)
        self.prediction_value = nn.Linear(latent_dim, 1)

    def plan(self, observation, num_simulations=50):
        """Plan action using MCTS with learned model"""

        # Encode observation
        root_state = self.representation(observation)

        # Run MCTS
        root = Node(root_state)
        for _ in range(num_simulations):
            self._simulate(root)

        # Select action with most visits
        action = root.select_best_action()

        return action

    def _simulate(self, node):
        """Single MCTS simulation"""

        # Selection: traverse tree using UCB
        path = [node]
        while node.is_expanded():
            node = node.select_child_ucb()
            path.append(node)

        # Expansion: expand leaf node
        # (in a full implementation, each child's latent state and reward would be
        #  computed with step_dynamics on first visit; omitted here for brevity)
        state = node.state
        policy_logits, value = self.predict(state)
        node.expand(policy_logits)

        # Backup: propagate value up the tree
        for parent in reversed(path):
            parent.value_sum += value
            parent.visit_count += 1
            value = parent.reward + self.gamma * value

    def predict(self, state):
        """Predict policy and value from latent state"""
        policy_logits = self.prediction_policy(state)
        value = self.prediction_value(state)
        return policy_logits, value

    def step_dynamics(self, state, action):
        """Predict next state and reward using the dynamics model"""
        x = torch.cat([state, action], dim=-1)
        output = self.dynamics(x)
        next_state = output[..., :-1]  # all but the last feature is the next latent
        reward = output[..., -1:]      # last feature is the predicted reward
        return next_state, reward
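
The Node class used by plan and _simulate above is not defined; a minimal sketch of what it could look like, with a simplified UCB rule standing in for MuZero's full pUCT formula (the structure and constants are illustrative):

import math
import torch

class Node:
    """Minimal MCTS node over latent states (illustrative, not MuZero's exact node)."""
    def __init__(self, state, prior=1.0, reward=0.0):
        self.state = state
        self.prior = prior          # policy prior assigned to the edge into this node
        self.reward = reward        # predicted reward on the edge into this node
        self.children = {}          # action index -> Node
        self.visit_count = 0
        self.value_sum = 0.0

    def value(self):
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def is_expanded(self):
        return len(self.children) > 0

    def expand(self, policy_logits):
        priors = torch.softmax(policy_logits.squeeze(), dim=-1)
        for action, prior in enumerate(priors.tolist()):
            # Child latent state is filled in by the dynamics model on first visit
            self.children[action] = Node(state=None, prior=prior)

    def select_child_ucb(self, c_ucb=1.25):
        def ucb(child):
            exploration = c_ucb * child.prior * math.sqrt(self.visit_count + 1) / (1 + child.visit_count)
            return child.value() + exploration
        return max(self.children.values(), key=ucb)

    def select_best_action(self):
        # Action chosen at the root: the most-visited child
        return max(self.children.items(), key=lambda kv: kv[1].visit_count)[0]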

When to use MuZero:

  • ✓ Discrete action spaces
  • ✓ Games and board games
  • ✓ Environments requiring long-horizon planning
  • ✗ Continuous control (use TD-MPC or Dreamer)
  • ✗ Real-time constraints (MCTS is slow)

Algorithm Comparison

Sample Efficiency

On DeepMind Control Suite (1M environment steps):

| Algorithm        | Dog-Run | Cheetah-Run | Hopper-Hop | Avg Score       |
|------------------|---------|-------------|------------|-----------------|
| TD-MPC           | 850     | 720         | 450        | 673             |
| Dreamer v3       | 820     | 690         | 470        | 660             |
| MuZero           | N/A     | N/A         | N/A        | (Discrete only) |
| SAC (model-free) | 580     | 650         | 380        | 537             |
| PPO (model-free) | 420     | 520         | 310        | 417             |

Conclusion: With the same amount of data, model-based methods achieve 20-50% higher scores.

Computational Cost

| Algorithm  | Training Time (relative) | Inference Time     | Memory    |
|------------|--------------------------|--------------------|-----------|
| TD-MPC     | 2.0x                     | Fast (MPC online)  | Medium    |
| Dreamer v3 | 2.5x                     | Fast (no planning) | High      |
| MuZero     | 5.0x                     | Slow (MCTS)        | Very High |
| SAC        | 1.0x (baseline)          | Fastest            | Low       |

Recommendation for robotics:

  • TD-MPC: best balance of sample efficiency and inference speed
  • Dreamer: best for vision-based tasks
  • SAC/TD3: when sample efficiency matters less than simplicity

Practical Tips

Model Accuracy vs Planning Horizon

Rule of thumb: Model error compounds exponentially with horizon.

def choose_horizon(model_error, task_horizon):
    """
    Choose planning horizon based on model accuracy

    model_error: one-step prediction MSE on a validation set
    task_horizon: actual task horizon
    """
    if model_error < 0.01:
        # Very accurate model: plan further ahead
        return min(10, task_horizon)
    elif model_error < 0.1:
        # Moderate accuracy
        return min(5, task_horizon)
    else:
        # Poor model: keep the horizon short
        return min(3, task_horizon)

Ensemble Models

Use an ensemble of dynamics models to quantify uncertainty:

class EnsembleDynamics:
    """Ensemble of dynamics models for uncertainty"""

    def __init__(self, latent_dim, action_dim, num_models=5):
        self.models = [
            DynamicsModel(latent_dim, action_dim)
            for _ in range(num_models)
        ]

    def predict(self, z, action):
        """Predict next latent state with uncertainty (ensemble disagreement)"""
        predictions = torch.stack([model(z, action) for model in self.models])

        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)

        return mean, std

    def select_action_pessimistic(self, z, action_candidates):
        """Conservative heuristic: pick the candidate the ensemble disagrees on least"""
        uncertainties = []

        for action in action_candidates:
            _, std = self.predict(z, action)
            uncertainties.append(std.sum().item())

        best_idx = np.argmin(uncertainties)
        return action_candidates[best_idx]

Domain Randomization for Model Learning

Improve model robustness with randomization:

class RandomizedDynamics(nn.Module):
    """Dynamics model with noise injection for robustness"""

    def __init__(self, latent_dim, action_dim, hidden_dim=512, noise_std=0.1):
        super().__init__()
        self.noise_std = noise_std
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, z, action, randomize=True):
        # Standard prediction
        z_next = self.net(torch.cat([z, action], dim=-1))

        if randomize and self.training:
            # Add noise during training so the policy does not overfit to exact transitions
            noise = torch.randn_like(z_next) * self.noise_std
            z_next = z_next + noise

        return z_next

Common Issues & Solutions

Problem: Model Divergence

Symptoms: Model predictions diverge from reality after few steps

Solutions:

# 1. Shorter planning horizon
horizon = 3  # Instead of 10

# 2. Model ensemble for uncertainty
dynamics = EnsembleDynamics(num_models=5)

# 3. Replan more frequently
replan_frequency = 1  # Replan every step

# 4. Add model regularization
model_loss = prediction_error + 0.01 * weight_decay

Problem: Slow Planning

Symptoms: MPC planning takes too long for real-time control

Solutions:

# 1. Reduce CEM samples
num_samples = 256  # Instead of 512

# 2. Reduce CEM iterations
num_iterations = 4  # Instead of 6

# 3. Use cached plans (warm start)
def plan_with_warmstart(previous_plan):
    # Shift the previous action sequence forward one step and pad with a zero action,
    # then seed CEM with this mean (cem_plan and action_dim are placeholders)
    initial_mean = torch.cat([previous_plan[1:], torch.zeros(1, action_dim)])
    return cem_plan(initial_mean=initial_mean)

# 4. Use GPU for parallel simulation
device = 'cuda'

Problem: Poor Sample Efficiency (Still)

Symptoms: Model-based method not much better than model-free

Solutions:

# 1. Increase model capacity
hidden_dim = 1024  # Instead of 512

# 2. Use larger batch size
batch_size = 512  # Instead of 256

# 3. Train model more frequently
model_updates_per_env_step = 4  # Multiple updates

# 4. Increase replay buffer
buffer_size = 5_000_000  # Store more diverse data

References

Papers

  1. TD-MPC: Hansen et al., "Temporal Difference Learning for Model Predictive Control", ICML 2022 (arXiv)
  2. Dreamer v1: Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination", ICLR 2020 (arXiv)
  3. Dreamer v2: Hafner et al., "Mastering Atari with Discrete World Models", ICLR 2021 (arXiv)
  4. Dreamer v3: Hafner et al., "Mastering Diverse Domains through World Models", arXiv 2023 (arXiv)
  5. MuZero: Schrittwieser et al., "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", Nature 2020 (Nature)
  6. PlaNet: Hafner et al., "Learning Latent Dynamics for Planning from Pixels", ICML 2019 (arXiv)
  7. MBPO: Janner et al., "When to Trust Your Model: Model-Based Policy Optimization", NeurIPS 2019 (arXiv)

Books

  1. Sutton & Barto, "Reinforcement Learning: An Introduction", 2nd Edition, 2018
     • Chapter 8: Planning and Learning

  2. Bertsekas, "Dynamic Programming and Optimal Control", 2017
     • Comprehensive coverage of planning algorithms

  3. LaValle, "Planning Algorithms", 2006
     • Free online: http://lavalle.pl/planning/

Code Implementations

  • TD-MPC Official: https://github.com/nicklashansen/tdmpc
     • Clean PyTorch implementation
     • DMC and Meta-World benchmarks

  • Dreamer v3 Official: https://github.com/danijar/dreamerv3
     • JAX implementation
     • Supports Atari, DMC, Minecraft

  • MuZero PyTorch: https://github.com/werner-duvaud/muzero-general
     • Unofficial but well-documented
     • Supports many environments

  • MBRL-Lib (Facebook): https://github.com/facebookresearch/mbrl-lib
     • Model-based RL library
     • Includes PETS, MBPO, PlaNet

Tutorials

  • Model-Based RL Tutorial (Berkeley): https://sites.google.com/view/mbrl-tutorial
     • Comprehensive slides and videos
     • Sergey Levine's CS 285 lectures

  • World Models Blog: https://worldmodels.github.io/
     • Visual introduction to world models
     • Interactive demos

  • Dreamer Tutorial: https://danijar.com/project/dreamerv3/
     • Official tutorial and tips
     • Training recipes

Next Steps

Framework Guides