Model-Based Reinforcement Learning¶
Model-based RL learns a model of the environment to dramatically improve sample efficiency, which is critical for real-world robotics.
Overview¶
Model-based RL learns a dynamics model \(p(s_{t+1}|s_t, a_t)\) and optionally a reward model \(r(s_t, a_t)\), then uses this model for:
- Planning: Search for actions using the learned model
- Data augmentation: Generate synthetic rollouts
- Policy learning: Train policy on model predictions
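These three uses share the same outer loop: collect real experience, fit the model, then plan or train against it. Below is a minimal sketch of that loop; `env`, `model`, and `planner` are placeholder arguments rather than any specific library API.

```python
# Minimal sketch of the generic model-based RL outer loop.
# `env`, `model`, and `planner` are placeholders for your own components.
def model_based_rl_loop(env, model, planner, num_iterations=100, steps_per_iter=1000):
    dataset = []  # (obs, action, reward, next_obs) transitions
    obs = env.reset()
    for _ in range(num_iterations):
        # 1. Collect real experience by planning with the current model
        for _ in range(steps_per_iter):
            action = planner.plan(model, obs)              # planning
            next_obs, reward, done, info = env.step(action)
            dataset.append((obs, action, reward, next_obs))
            obs = env.reset() if done else next_obs
        # 2. Fit the dynamics/reward model to everything collected so far
        model.fit(dataset)
        # 3. Optionally generate synthetic rollouts from the model and/or
        #    train a policy on the model's predictions
    return model
```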
Key advantages:

- 10-100x better sample efficiency than model-free methods
- Can plan ahead using the learned model
- Enables sim-to-real transfer
- Supports offline learning
Key challenges:

- Model errors compound over long horizons
- Complex environments are difficult to model
- Computational cost of planning
State-of-the-art algorithms:

- TD-MPC (Temporal Difference Model Predictive Control)
- Dreamer v3 (world models with imagination)
- MuZero (model-based planning without an explicit observation model)
- PlaNet (deep planning networks)
TD-MPC: Temporal Difference Model Predictive Control¶
TD-MPC combines model learning with online MPC planning for state-of-the-art performance.
Paper: Hansen et al., "Temporal Difference Learning for Model Predictive Control", ICML 2022
Mathematical Foundation¶
TD-MPC learns three components:
- Latent dynamics model: \(z_{t+1} = h(z_t, a_t)\)
- Reward predictor: \(\hat{r}_t = r(z_t, a_t)\)
- Q-function: \(Q(z_t, a_t)\) for value estimation
Planning objective (MPC at inference):

\[
a_{t:t+H}^{*} = \arg\max_{a_{t:t+H}} \; \mathbb{E}\left[\sum_{i=0}^{H-1} \gamma^{i}\, r(z_{t+i}, a_{t+i}) + \gamma^{H}\, Q(z_{t+H}, a_{t+H})\right]
\]

where predictions are made using the learned dynamics \(h\), i.e. \(z_{t+i+1} = h(z_{t+i}, a_{t+i})\).

Learning objective (TD learning):

\[
\mathcal{L} = \big\lVert h(z_t, a_t) - z_{t+1} \big\rVert_2^2
+ \big\lVert r(z_t, a_t) - r_t \big\rVert_2^2
+ \big\lVert Q(z_t, a_t) - \big(r_t + \gamma\, Q^{-}(z_{t+1}, a_{t+1})\big) \big\rVert_2^2
\]

where \(z_{t+1}\) is the encoding of the next observation and \(Q^{-}\) is a slowly updated target network.
Complete Implementation¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class TDMPC:
"""
Temporal Difference Model Predictive Control
Reference: Hansen et al., ICML 2022
"""
def __init__(
self,
obs_dim,
action_dim,
latent_dim=256,
hidden_dim=512,
horizon=5,
num_samples=512,
num_elites=64,
num_iterations=6,
lr=1e-3,
gamma=0.99,
tau=0.01,
buffer_size=1_000_000
):
self.action_dim = action_dim
self.latent_dim = latent_dim
self.horizon = horizon
self.num_samples = num_samples
self.num_elites = num_elites
self.num_iterations = num_iterations
self.gamma = gamma
self.tau = tau
# Encoder: observations -> latent states
self.encoder = Encoder(obs_dim, latent_dim, hidden_dim)
# Dynamics model: (latent, action) -> next latent
self.dynamics = DynamicsModel(latent_dim, action_dim, hidden_dim)
# Reward predictor
self.reward = RewardModel(latent_dim, action_dim, hidden_dim)
# Q-function (twin critics)
self.q1 = QNetwork(latent_dim, action_dim, hidden_dim)
self.q2 = QNetwork(latent_dim, action_dim, hidden_dim)
# Target Q-networks
self.q1_target = QNetwork(latent_dim, action_dim, hidden_dim)
self.q2_target = QNetwork(latent_dim, action_dim, hidden_dim)
self.q1_target.load_state_dict(self.q1.state_dict())
self.q2_target.load_state_dict(self.q2.state_dict())
# Single optimizer for all networks
self.optimizer = torch.optim.Adam(
list(self.encoder.parameters()) +
list(self.dynamics.parameters()) +
list(self.reward.parameters()) +
list(self.q1.parameters()) +
list(self.q2.parameters()),
lr=lr
)
# Replay buffer
self.replay_buffer = ReplayBuffer(obs_dim, action_dim, buffer_size)
def plan(self, obs, deterministic=False):
"""
Model Predictive Control planning using CEM
Uses Cross-Entropy Method (CEM) to optimize action sequence
"""
with torch.no_grad():
# Encode observation to latent state
z = self.encoder(torch.FloatTensor(obs).unsqueeze(0))
# Initialize action distribution
mean = torch.zeros(self.horizon, self.action_dim)
std = torch.ones(self.horizon, self.action_dim)
# CEM iterations
for iteration in range(self.num_iterations):
# Sample action sequences
actions = mean.unsqueeze(0) + std.unsqueeze(0) * torch.randn(
self.num_samples, self.horizon, self.action_dim
)
actions = torch.clamp(actions, -1, 1)
# Evaluate action sequences
values = self._evaluate_sequences(z.repeat(self.num_samples, 1), actions)
# Select elite samples
elite_idxs = torch.topk(values, self.num_elites, dim=0).indices
elite_actions = actions[elite_idxs]
# Update distribution
mean = elite_actions.mean(dim=0)
std = elite_actions.std(dim=0)
            # Return the first action of the refined mean action sequence
            return mean[0].cpu().numpy()
def _evaluate_sequences(self, z, action_sequences):
"""Evaluate action sequences using learned model and Q-function"""
total_value = torch.zeros(action_sequences.shape[0])
for t in range(self.horizon):
actions = action_sequences[:, t]
# Predict reward
rewards = self.reward(z, actions).squeeze(-1)
# Predict next latent state
z = self.dynamics(z, actions)
# Accumulate discounted reward
total_value += (self.gamma ** t) * rewards
# Add terminal Q-value (no action, use zeros)
terminal_actions = torch.zeros(z.shape[0], self.action_dim)
terminal_q = torch.min(
self.q1(z, terminal_actions),
self.q2(z, terminal_actions)
).squeeze(-1)
total_value += (self.gamma ** self.horizon) * terminal_q
return total_value
def update(self, batch_size=256):
"""Update all networks using TD learning"""
if len(self.replay_buffer) < batch_size:
return {}
# Sample batch
obs, actions, rewards, next_obs, dones = self.replay_buffer.sample(batch_size)
# Encode to latent space
z = self.encoder(obs)
next_z = self.encoder(next_obs)
# ===== Dynamics Model Loss =====
# Predict next latent state
pred_next_z = self.dynamics(z, actions)
# Dynamics loss (predict next encoding)
dynamics_loss = F.mse_loss(pred_next_z, next_z.detach())
# ===== Reward Model Loss =====
pred_rewards = self.reward(z, actions)
reward_loss = F.mse_loss(pred_rewards, rewards)
# ===== Q-Function Loss (TD learning) =====
# Current Q-values
q1_pred = self.q1(z, actions)
q2_pred = self.q2(z, actions)
with torch.no_grad():
# Use planning to select next action
# (Simplified: use random action for efficiency)
next_actions = torch.clamp(torch.randn_like(actions) * 0.3, -1, 1)
# Compute target Q-value
target_q1 = self.q1_target(next_z, next_actions)
target_q2 = self.q2_target(next_z, next_actions)
target_q = torch.min(target_q1, target_q2)
target = rewards + (1 - dones) * self.gamma * target_q
# Q-function losses
q1_loss = F.mse_loss(q1_pred, target)
q2_loss = F.mse_loss(q2_pred, target)
# ===== Consistency Loss =====
# Ensure Q-function is consistent with reward predictions
with torch.no_grad():
# Multi-step rollout
z_rollout = z
predicted_return = torch.zeros_like(rewards)
for h in range(min(3, self.horizon)): # Short horizon for efficiency
# Predict reward and next state
r_pred = self.reward(z_rollout, actions)
z_rollout = self.dynamics(z_rollout, actions)
predicted_return += (self.gamma ** h) * r_pred
consistency_loss = F.mse_loss(q1_pred, predicted_return.detach())
# ===== Total Loss =====
total_loss = (
dynamics_loss +
reward_loss +
q1_loss +
q2_loss +
0.1 * consistency_loss # Weaker weight for consistency
)
# Optimize
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(
list(self.encoder.parameters()) +
list(self.dynamics.parameters()) +
list(self.reward.parameters()) +
list(self.q1.parameters()) +
list(self.q2.parameters()),
max_norm=10.0
)
self.optimizer.step()
# Update target networks
self._soft_update(self.q1, self.q1_target)
self._soft_update(self.q2, self.q2_target)
return {
'dynamics_loss': dynamics_loss.item(),
'reward_loss': reward_loss.item(),
'q1_loss': q1_loss.item(),
'q2_loss': q2_loss.item(),
'consistency_loss': consistency_loss.item(),
'total_loss': total_loss.item()
}
def _soft_update(self, source, target):
"""Soft update target network"""
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - self.tau) + param.data * self.tau
)
class Encoder(nn.Module):
"""Encode observations to latent states"""
def __init__(self, obs_dim, latent_dim, hidden_dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, latent_dim)
)
def forward(self, obs):
return self.net(obs)
class DynamicsModel(nn.Module):
"""Latent dynamics: z_t+1 = f(z_t, a_t)"""
def __init__(self, latent_dim, action_dim, hidden_dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, latent_dim)
)
def forward(self, z, action):
x = torch.cat([z, action], dim=-1)
return self.net(x)
class RewardModel(nn.Module):
"""Predict reward: r = r(z_t, a_t)"""
def __init__(self, latent_dim, action_dim, hidden_dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
def forward(self, z, action):
x = torch.cat([z, action], dim=-1)
return self.net(x)
class QNetwork(nn.Module):
"""Q-function: Q(z_t, a_t)"""
def __init__(self, latent_dim, action_dim, hidden_dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
def forward(self, z, action):
x = torch.cat([z, action], dim=-1)
return self.net(x)
class ReplayBuffer:
"""Standard replay buffer"""
def __init__(self, obs_dim, action_dim, max_size=1_000_000):
self.max_size = max_size
self.ptr = 0
self.size = 0
self.obs = np.zeros((max_size, obs_dim))
self.actions = np.zeros((max_size, action_dim))
self.rewards = np.zeros((max_size, 1))
self.next_obs = np.zeros((max_size, obs_dim))
self.dones = np.zeros((max_size, 1))
def add(self, obs, action, reward, next_obs, done):
self.obs[self.ptr] = obs
self.actions[self.ptr] = action
self.rewards[self.ptr] = reward
self.next_obs[self.ptr] = next_obs
self.dones[self.ptr] = done
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample(self, batch_size):
indices = np.random.randint(0, self.size, size=batch_size)
return (
torch.FloatTensor(self.obs[indices]),
torch.FloatTensor(self.actions[indices]),
torch.FloatTensor(self.rewards[indices]),
torch.FloatTensor(self.next_obs[indices]),
torch.FloatTensor(self.dones[indices])
)
def __len__(self):
return self.size
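To tie the pieces together, here is a hypothetical training loop for the `TDMPC` class above, assuming a gymnasium-style environment whose actions are already normalized to [-1, 1]; the environment id is a placeholder.

```python
# Hypothetical usage of the TDMPC class above (environment id is a placeholder).
import gymnasium as gym

env = gym.make("YourRobotEnv-v0")   # placeholder: actions assumed in [-1, 1]
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
agent = TDMPC(obs_dim=obs_dim, action_dim=action_dim)

obs, _ = env.reset()
for step in range(100_000):
    action = agent.plan(obs)                                  # MPC planning
    next_obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    agent.replay_buffer.add(obs, action, reward, next_obs, float(done))
    obs = next_obs if not done else env.reset()[0]
    metrics = agent.update(batch_size=256)                    # one gradient step
```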
Hyperparameter Tuning¶
| Parameter | Typical Range | Default | Notes |
|---|---|---|---|
| `horizon` | 3-10 | 5 | Longer = more planning, slower |
| `num_samples` | 256-1024 | 512 | CEM samples per iteration |
| `num_elites` | 32-128 | 64 | Top samples to fit |
| `num_iterations` | 4-10 | 6 | CEM refinement iterations |
| `latent_dim` | 128-512 | 256 | Latent state size |
| `lr` | 1e-4 to 3e-3 | 1e-3 | Single LR for all networks |
Tuning tips:
# For fast control loops (manipulation)
tdmpc_manipulation = TDMPC(
horizon=3, # Short horizon
num_samples=256, # Fewer samples for speed
num_iterations=4,
latent_dim=128
)
# For complex planning (locomotion)
tdmpc_locomotion = TDMPC(
horizon=10, # Long horizon
num_samples=1024, # More samples
num_iterations=8,
latent_dim=512
)
Dreamer v3: World Models¶
Dreamer learns a world model in latent space and uses it to train policies via imagination.
Papers:

- Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination", ICLR 2020
- Hafner et al., "Mastering Diverse Domains through World Models", arXiv 2023 (v3)
Architecture¶
World model components:
- Encoder: \(e: o_t \rightarrow z_t\)
- Recurrent state: \(h_t = f(h_{t-1}, z_t, a_{t-1})\) (RSSM)
- Decoder: \(\hat{o}_t = d(h_t, z_t)\)
- Reward predictor: \(\hat{r}_t = r(h_t, z_t)\)
- Continue predictor: \(\hat{c}_t = c(h_t, z_t)\) (not done)
Actor-critic:

- Actor: \(\pi(a|h_t, z_t)\)
- Critic: \(V(h_t, z_t)\)
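The world model itself is trained on real trajectories. Below is a schematic of its loss, assuming the decoder, reward, and continue heads named above plus `posterior`/`prior` distributions over \(z_t\); the names and weights are illustrative, and Dreamer v3 adds further refinements (e.g., symlog targets) not shown here.

```python
import torch.nn.functional as F
import torch.distributions as D

def world_model_loss(decoder, reward_head, continue_head,
                     posterior, prior, h, z, obs, reward, not_done,
                     kl_weight=1.0):
    """Schematic world-model loss on one batch of real transitions (illustrative)."""
    # Reconstruction: the decoder should reproduce the observation from (h, z)
    recon_loss = F.mse_loss(decoder(h, z), obs)
    # Reward and continue predictions on the same latent state
    # (continue_head is assumed to output a probability in [0, 1])
    reward_loss = F.mse_loss(reward_head(h, z), reward)
    continue_loss = F.binary_cross_entropy(continue_head(h, z), not_done)
    # KL between the posterior (from the encoder) and the prior (from the RSSM)
    kl_loss = D.kl_divergence(posterior, prior).mean()
    return recon_loss + reward_loss + continue_loss + kl_weight * kl_loss
```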
Key Concepts¶
class Dreamer:
"""
Dreamer v3 - World Models for RL
Reference: Hafner et al., 2023
"""
def imagine_trajectories(self, start_state, horizon=15):
"""
Imagine trajectories using world model
This is the key to Dreamer: train policy on imagined data
"""
# Start from real states
h, z = start_state
imagined_trajectory = []
for t in range(horizon):
# Sample action from policy
action = self.actor(h, z).sample()
# Predict next state using world model
h_next, z_next = self.world_model.predict(h, z, action)
# Predict reward and continue
reward = self.reward_model(h_next, z_next)
continue_prob = self.continue_model(h_next, z_next)
imagined_trajectory.append({
'state': (h, z),
'action': action,
'reward': reward,
'continue': continue_prob,
'next_state': (h_next, z_next)
})
h, z = h_next, z_next
return imagined_trajectory
    def update_actor_critic(self, imagined_trajectories):
        """Train actor and critic on imagined rollouts"""
        # Gather imagined states and recompute action log-probs under the actor
        states = [step['state'] for step in imagined_trajectories]
        log_probs = torch.stack([
            self.actor(h, z).log_prob(step['action'])
            for (h, z), step in zip(states, imagined_trajectories)
        ])
        # Compute lambda returns (TD(lambda)) along the imagined trajectory
        returns = self.compute_lambda_returns(imagined_trajectories)
        # Update critic to predict the lambda returns
        values = torch.stack([self.critic(h, z) for h, z in states])
        critic_loss = F.mse_loss(values, returns)
        # Update actor to maximize returns (policy gradient)
        advantages = (returns - values).detach()
        actor_loss = -(advantages * log_probs).mean()
        return actor_loss, critic_loss
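`compute_lambda_returns` is referenced but not shown above; a minimal sketch of TD(λ) returns over an imagined rollout, assuming per-step reward/value/continue tensors (the shapes are illustrative), could look like this:

```python
import torch

def compute_lambda_returns(rewards, values, continues, gamma=0.99, lam=0.95):
    """TD(lambda) returns over an imagined rollout (illustrative shapes).

    rewards, continues: [horizon, batch]; values: [horizon + 1, batch],
    where values[t] = V(s_t) and values[-1] bootstraps the final state.
    """
    horizon = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    next_return = values[-1]  # bootstrap from the last imagined state
    for t in reversed(range(horizon)):
        # R_t = r_t + gamma * c_t * ((1 - lam) * V(s_{t+1}) + lam * R_{t+1})
        returns[t] = rewards[t] + gamma * continues[t] * (
            (1 - lam) * values[t + 1] + lam * next_return
        )
        next_return = returns[t]
    return returns
```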
Advantages:

- Learn from imagined experience (10-100x sample efficiency)
- Handles image observations naturally
- Single algorithm for vision and state-based tasks
When to use Dreamer:

- ✓ Image-based observations
- ✓ Complex, high-dimensional environments
- ✓ Long-horizon tasks
- ✓ Offline RL
- ✗ Real-time planning (slower than TD-MPC)
MuZero: Model-Based Planning¶
MuZero combines a learned model with tree search (like AlphaGo) for planning.
Paper: Schrittwieser et al., "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", Nature 2020
Key Innovation¶
MuZero doesn't learn an explicit model of observations. Instead:
- Representation: \(s^0 = h(o_1, ..., o_t)\)
- Dynamics: \(s^{k+1}, r^k = g(s^k, a^k)\) (in latent space)
- Prediction: \(p^k, v^k = f(s^k)\) (policy and value)
Planning via MCTS (Monte Carlo Tree Search):

- Use learned dynamics \(g\) to simulate trajectories
- Use learned value \(v\) to evaluate states
- Use learned policy \(p\) as prior for search
Simplified Implementation¶
class MuZero:
"""MuZero - Model-Based RL with Tree Search"""
    def __init__(self, obs_dim, action_dim, latent_dim=256, gamma=0.997):
        self.gamma = gamma  # discount used in the MCTS value backup
        # Representation: observations -> latent state
self.representation = nn.Sequential(
nn.Linear(obs_dim, 512),
nn.ReLU(),
nn.Linear(512, latent_dim)
)
# Dynamics: (latent, action) -> next latent, reward
self.dynamics = nn.Sequential(
nn.Linear(latent_dim + action_dim, 512),
nn.ReLU(),
nn.Linear(512, latent_dim + 1) # latent + reward
)
# Prediction: latent -> policy, value
self.prediction_policy = nn.Linear(latent_dim, action_dim)
self.prediction_value = nn.Linear(latent_dim, 1)
def plan(self, observation, num_simulations=50):
"""Plan action using MCTS with learned model"""
# Encode observation
root_state = self.representation(observation)
# Run MCTS
root = Node(root_state)
for _ in range(num_simulations):
self._simulate(root)
# Select action with most visits
action = root.select_best_action()
return action
def _simulate(self, node):
"""Single MCTS simulation"""
# Selection: traverse tree using UCB
path = [node]
while node.is_expanded():
node = node.select_child_ucb()
path.append(node)
# Expansion: expand leaf node
state = node.state
policy_logits, value = self.predict(state)
node.expand(policy_logits)
# Backup: propagate value up the tree
for parent in reversed(path):
parent.value_sum += value
parent.visit_count += 1
value = parent.reward + self.gamma * value
def predict(self, state):
"""Predict policy and value from latent state"""
policy_logits = self.prediction_policy(state)
value = self.prediction_value(state)
return policy_logits, value
def step_dynamics(self, state, action):
"""Predict next state and reward using dynamics model"""
x = torch.cat([state, action], dim=-1)
output = self.dynamics(x)
        next_state = output[..., :-1]  # all but the last feature: next latent
        reward = output[..., -1:]      # last feature: predicted reward
return next_state, reward
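The `Node` class the MCTS sketch relies on is not shown above. Below is a minimal version of the bookkeeping it assumes (visit counts, value sums, priors, UCB child selection); the exploration formula is a simplified UCB, not the exact pUCT rule from the paper.

```python
import math
import torch

class Node:
    """Minimal MCTS node used by the sketch above (illustrative)."""
    def __init__(self, state, prior=1.0, reward=0.0):
        self.state = state
        self.prior = prior        # policy prior for reaching this node
        self.reward = reward      # predicted reward on the edge into this node
        self.children = {}        # action index -> Node
        self.visit_count = 0
        self.value_sum = 0.0

    def value(self):
        return self.value_sum / max(1, self.visit_count)

    def is_expanded(self):
        return len(self.children) > 0

    def expand(self, policy_logits):
        # One child per discrete action, weighted by the policy prior
        priors = torch.softmax(policy_logits, dim=-1).flatten()
        for action, prior in enumerate(priors):
            self.children[action] = Node(state=None, prior=prior.item())

    def select_child_ucb(self, c_explore=1.25):
        # Exploit high values, explore rarely-visited children with high priors
        def ucb(child):
            exploration = c_explore * child.prior * math.sqrt(self.visit_count + 1) / (child.visit_count + 1)
            return child.value() + exploration
        return max(self.children.values(), key=ucb)

    def select_best_action(self):
        # Execute the most-visited action at the root
        return max(self.children.items(), key=lambda kv: kv[1].visit_count)[0]
```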
When to use MuZero:

- ✓ Discrete action spaces
- ✓ Games and board games
- ✓ Environments requiring long-horizon planning
- ✗ Continuous control (use TD-MPC or Dreamer)
- ✗ Real-time constraints (MCTS is slow)
Algorithm Comparison¶
Sample Efficiency¶
On DeepMind Control Suite (1M environment steps):
| Algorithm | Dog-Run | Cheetah-Run | Hopper-Hop | Avg Score |
|---|---|---|---|---|
| TD-MPC | 850 | 720 | 450 | 673 |
| Dreamer v3 | 820 | 690 | 470 | 660 |
| MuZero | N/A | N/A | N/A | (Discrete only) |
| SAC (model-free) | 580 | 650 | 380 | 537 |
| PPO (model-free) | 420 | 520 | 310 | 417 |
Conclusion: Model-based methods achieve 20-50% higher scores with the same amount of data.
Computational Cost¶
| Algorithm | Training Time (relative) | Inference Time | Memory |
|---|---|---|---|
| TD-MPC | 2.0x | Fast (MPC online) | Medium |
| Dreamer v3 | 2.5x | Fast (no planning) | High |
| MuZero | 5.0x | Slow (MCTS) | Very High |
| SAC | 1.0x (baseline) | Fastest | Low |
Recommendation for robotics:

- TD-MPC: Best balance of sample efficiency and inference speed
- Dreamer: Best for vision-based tasks
- SAC/TD3: When sample efficiency matters less than simplicity
Practical Tips¶
Model Accuracy vs Planning Horizon¶
Rule of thumb: Model error compounds exponentially with horizon.
def choose_horizon(model_accuracy, task_horizon):
"""
Choose planning horizon based on model accuracy
model_accuracy: MSE on validation set
task_horizon: Actual task horizon
"""
if model_accuracy < 0.01:
# Very accurate model
return min(10, task_horizon)
elif model_accuracy < 0.1:
# Moderate accuracy
return min(5, task_horizon)
else:
# Poor model, use short horizon
return min(3, task_horizon)
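For example, a model with a validation MSE of 0.05 on a task that runs for 200 steps would be planned with a horizon of 5:

```python
planning_horizon = choose_horizon(model_accuracy=0.05, task_horizon=200)  # -> 5
```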
Ensemble Models¶
Use ensemble of dynamics models to quantify uncertainty:
class EnsembleDynamics:
"""Ensemble of dynamics models for uncertainty"""
def __init__(self, latent_dim, action_dim, num_models=5):
self.models = [
DynamicsModel(latent_dim, action_dim)
for _ in range(num_models)
]
def predict(self, z, action):
"""Predict with uncertainty"""
predictions = [model(z, action) for model in self.models]
mean = torch.stack(predictions).mean(dim=0)
std = torch.stack(predictions).std(dim=0)
return mean, std
    def select_action_pessimistic(self, z, action_candidates):
        """Select the candidate action with the lowest ensemble disagreement (conservative choice)"""
uncertainties = []
for action in action_candidates:
_, std = self.predict(z, action)
uncertainties.append(std.sum().item())
best_idx = np.argmin(uncertainties)
return action_candidates[best_idx]
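One common way to use the ensemble during planning or imagined rollouts is to penalize predicted rewards by the models' disagreement, steering the planner away from regions the model has not learned well. A hedged sketch follows; the function name and penalty weight are illustrative.

```python
def uncertainty_penalized_reward(reward_model, ensemble, z, action, penalty=1.0):
    """Penalize imagined rewards by ensemble disagreement (illustrative)."""
    reward = reward_model(z, action)
    _, std = ensemble.predict(z, action)
    # Larger disagreement between ensemble members -> larger penalty
    return reward - penalty * std.mean(dim=-1, keepdim=True)
```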
Domain Randomization for Model Learning¶
Improve model robustness with randomization:
class RandomizedDynamics(nn.Module):
    """Dynamics model with domain randomization (noise injection)"""
    def __init__(self, latent_dim, action_dim, hidden_dim=512, noise_std=0.1):
        super().__init__()
        self.noise_std = noise_std
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim)
        )
    def forward(self, z, action, randomize=True):
        # Standard prediction
        z_next = self.net(torch.cat([z, action], dim=-1))
        if randomize and self.training:
            # Add noise during training for robustness
            z_next = z_next + torch.randn_like(z_next) * self.noise_std
        return z_next
Common Issues & Solutions¶
Problem: Model Divergence¶
Symptoms: Model predictions diverge from reality after a few steps
Solutions:
# 1. Shorter planning horizon
horizon = 3 # Instead of 10
# 2. Model ensemble for uncertainty
dynamics = EnsembleDynamics(num_models=5)
# 3. Replan more frequently
replan_frequency = 1 # Replan every step
# 4. Add model regularization
model_loss = prediction_error + 0.01 * weight_decay
Problem: Slow Planning¶
Symptoms: MPC planning takes too long for real-time control
Solutions:
# 1. Reduce CEM samples
num_samples = 256 # Instead of 512
# 2. Reduce CEM iterations
num_iterations = 4 # Instead of 6
# 3. Use cached plans (warm start)
def plan_with_warmstart(previous_plan):
# Shift previous plan
initial_mean = torch.cat([previous_plan[1:], torch.zeros(1, action_dim)])
return cem_plan(initial_mean=initial_mean)
# 4. Use GPU for parallel simulation
device = 'cuda'
Problem: Poor Sample Efficiency (Still)¶
Symptoms: Model-based method not much better than model-free
Solutions:
# 1. Increase model capacity
hidden_dim = 1024 # Instead of 512
# 2. Use larger batch size
batch_size = 512 # Instead of 256
# 3. Train model more frequently
model_updates_per_env_step = 4 # Multiple updates
# 4. Increase replay buffer
buffer_size = 5_000_000 # Store more diverse data
References¶
Papers¶
- TD-MPC: Hansen et al., "Temporal Difference Learning for Model Predictive Control", ICML 2022 (arXiv)
- Dreamer v1: Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination", ICLR 2020 (arXiv)
- Dreamer v2: Hafner et al., "Mastering Atari with Discrete World Models", ICLR 2021 (arXiv)
- Dreamer v3: Hafner et al., "Mastering Diverse Domains through World Models", arXiv 2023 (arXiv)
- MuZero: Schrittwieser et al., "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", Nature 2020 (Nature)
- PlaNet: Hafner et al., "Learning Latent Dynamics for Planning from Pixels", ICML 2019 (arXiv)
- MBPO: Janner et al., "When to Trust Your Model: Model-Based Policy Optimization", NeurIPS 2019 (arXiv)
Books¶
- Sutton & Barto, "Reinforcement Learning: An Introduction", 2nd Edition, 2018
    - Chapter 8: Planning and Learning
- Bertsekas, "Dynamic Programming and Optimal Control", 2017
    - Comprehensive coverage of planning algorithms
- LaValle, "Planning Algorithms", 2006
    - Free online: http://lavalle.pl/planning/
Code Implementations¶
- TD-MPC Official: https://github.com/nicklashansen/tdmpc
    - Clean PyTorch implementation
    - DMC and Meta-World benchmarks
- Dreamer v3 Official: https://github.com/danijar/dreamerv3
    - JAX implementation
    - Supports Atari, DMC, Minecraft
- MuZero PyTorch: https://github.com/werner-duvaud/muzero-general
    - Unofficial but well-documented
    - Supports many environments
- MBRL-Lib (Facebook): https://github.com/facebookresearch/mbrl-lib
    - Model-based RL library
    - Includes PETS, MBPO, PlaNet
Tutorials¶
- Model-Based RL Tutorial (Berkeley): https://sites.google.com/view/mbrl-tutorial
    - Comprehensive slides and videos
    - Sergey Levine's CS 285 lectures
- World Models Blog: https://worldmodels.github.io/
    - Visual introduction to world models
    - Interactive demos
- Dreamer Tutorial: https://danijar.com/project/dreamerv3/
    - Official tutorial and tips
    - Training recipes
Next Steps¶
- Off-Policy Algorithms (SAC, TD3) - Model-free comparison
- Distributed Training - Scale up model-based RL
- Sim-to-Real - Transfer learned models
Framework Guides¶
- RSL-RL - Isaac Lab model-based extensions
- MBRL-Lib - Facebook's model-based library
- Dreamer Implementations - JAX and PyTorch