Stable-Baselines3: Easy-to-Use RL

Stable-Baselines3 (SB3) is one of the most user-friendly and best-documented RL libraries, well suited to research and rapid prototyping.

Overview

Stable-Baselines3 is the successor to Stable-Baselines (based on OpenAI Baselines), providing:

  • Production-ready implementations of major RL algorithms
  • Extensive documentation and tutorials
  • Easy customization and extension
  • Active maintenance and community support

Key Features:

  • Complete implementations (PPO, A2C, SAC, TD3, DQN, DDPG)
  • Extensive callbacks and logging
  • Pre-trained models (RL Zoo)
  • Hyperparameter optimization
  • TensorBoard and WandB integration
  • Excellent documentation

Official Repository: https://github.com/DLR-RM/stable-baselines3

Installation

Basic Installation

# Install from PyPI (quote the extras so zsh does not expand the brackets)
pip install "stable-baselines3[extra]"

# Core only (minimal dependencies)
pip install stable-baselines3

# Latest from GitHub
pip install git+https://github.com/DLR-RM/stable-baselines3

With Extra Dependencies

# Full installation with all features
pip install "stable-baselines3[extra]"

# Includes:
# - TensorBoard
# - rich (progress bars)
# - matplotlib (plotting)
# - pandas (logging)

RL Zoo (Pre-trained Models)

# Install RL Zoo
pip install rl-zoo3

# Download pre-trained models
python -m rl_zoo3.load_from_hub --algo ppo --env CartPole-v1

Quick Start

Basic Training

import gymnasium as gym
from stable_baselines3 import PPO

# Create environment
env = gym.make("CartPole-v1")

# Create model
model = PPO("MlpPolicy", env, verbose=1)

# Train
model.learn(total_timesteps=10_000)

# Save
model.save("ppo_cartpole")

# Load
model = PPO.load("ppo_cartpole")

# Evaluate
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

Available Algorithms

from stable_baselines3 import PPO, A2C, SAC, TD3, DQN, DDPG

# On-policy
ppo = PPO("MlpPolicy", env)
a2c = A2C("MlpPolicy", env)

# Off-policy (continuous)
sac = SAC("MlpPolicy", env)
td3 = TD3("MlpPolicy", env)
ddpg = DDPG("MlpPolicy", env)

# Off-policy (discrete)
dqn = DQN("MlpPolicy", env)

Core Components

1. Policies

SB3 provides several pre-built policy networks:

from stable_baselines3 import PPO

# Multi-Layer Perceptron (MLP) policy
model = PPO("MlpPolicy", env)

# Convolutional Neural Network (CNN) policy (for images)
model = PPO("CnnPolicy", env)

# Multi-input policy (for dict observations)
model = PPO("MultiInputPolicy", env)

# Custom policy
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    """Custom CNN feature extractor"""
    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)

        n_input_channels = observation_space.shape[0]

        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with torch.no_grad():
            n_flatten = self.cnn(
                torch.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.ReLU()
        )

    def forward(self, observations):
        return self.linear(self.cnn(observations))

# Use custom policy
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=256),
)

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)

2. Training Configuration

Configure algorithm hyperparameters:

import torch.nn as nn
from stable_baselines3 import PPO

model = PPO(
    "MlpPolicy",
    env,
    # Training hyperparameters
    learning_rate=3e-4,
    n_steps=2048,  # Rollout buffer size
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    clip_range_vf=None,  # No clipping for value function
    normalize_advantage=True,
    ent_coef=0.0,  # Entropy coefficient
    vf_coef=0.5,  # Value function coefficient
    max_grad_norm=0.5,

    # Network architecture
    policy_kwargs=dict(
        net_arch=dict(pi=[256, 256], vf=[256, 256]),  # Separate actor and critic networks
        activation_fn=nn.ReLU,
        ortho_init=True
    ),

    # Logging
    verbose=1,
    tensorboard_log="./ppo_tensorboard/",

    # Device
    device="cuda"
)

3. Callbacks

SB3 provides powerful callbacks for monitoring and control:

from stable_baselines3.common.callbacks import (
    EvalCallback,
    CheckpointCallback,
    CallbackList,
    StopTrainingOnRewardThreshold
)

# Stop training when a reward threshold is reached
# (must be attached to an EvalCallback via callback_on_new_best;
# it cannot run standalone)
callback_on_best = StopTrainingOnRewardThreshold(
    reward_threshold=200,
    verbose=1
)

# Evaluation callback
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=callback_on_best,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False
)

# Checkpoint callback
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model"
)

# Combine callbacks
callbacks = CallbackList([
    eval_callback,
    checkpoint_callback
])

# Train with callbacks
model.learn(total_timesteps=1_000_000, callback=callbacks)

4. Custom Callbacks

Create custom callbacks for specific needs:

import os

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    """
    Custom callback for saving model based on custom metric

    Example: Save model when episode length > threshold
    """
    def __init__(self, check_freq: int, save_path: str, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_mean_length = 0

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Retrieve episode lengths
            if len(self.model.ep_info_buffer) > 0:
                mean_length = np.mean([ep_info["l"] for ep_info in self.model.ep_info_buffer])

                if mean_length > self.best_mean_length:
                    self.best_mean_length = mean_length
                    if self.verbose > 0:
                        print(f"Saving new best model with mean length: {mean_length:.2f}")
                    self.model.save(os.path.join(self.save_path, "best_model"))

        return True

# Use custom callback
custom_callback = CustomCallback(check_freq=1000, save_path="./logs/")
model.learn(total_timesteps=100_000, callback=custom_callback)

Advanced Features

Vectorized Environments

Train on multiple environments in parallel:

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

# Method 1: Simple wrapper (single process)
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=DummyVecEnv)

# Method 2: Multi-process (faster)
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

# Train on vectorized environment
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=100_000)

Wrappers

SB3 includes many useful wrappers:

from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper
import gymnasium as gym

# Normalize observations and rewards
env = make_vec_env("Ant-v4", n_envs=4)
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)

# Frame stacking (for Atari/vision)
env = make_vec_env("PongNoFrameskip-v4", n_envs=4)
env = VecFrameStack(env, n_stack=4)

# Atari preprocessing
env = gym.make("PongNoFrameskip-v4")
env = AtariWrapper(env)

# Train (each snippet above builds an independent env; CnnPolicy expects image observations)
model = PPO("CnnPolicy", env)
model.learn(total_timesteps=1_000_000)

# Save normalization statistics (call this on the VecNormalize env from above)
env.save("vec_normalize.pkl")

# Load for evaluation
env = make_vec_env("Ant-v4", n_envs=1)
env = VecNormalize.load("vec_normalize.pkl", env)
env.training = False  # Don't update statistics
env.norm_reward = False  # Don't normalize rewards

Replay Buffers

For off-policy algorithms:

from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    env,
    buffer_size=1_000_000,  # Replay buffer size
    learning_starts=10_000,  # Start training after N steps
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=1,  # Update every step
    gradient_steps=1,  # Number of gradient steps per update
)

# Access replay buffer
print(f"Buffer size: {model.replay_buffer.size()}")
print(f"Buffer capacity: {model.replay_buffer.buffer_size}")

# Sample from buffer
if model.replay_buffer.size() > 0:
    replay_data = model.replay_buffer.sample(batch_size=256)
    print(f"Observations shape: {replay_data.observations.shape}")

Hyperparameter Optimization

Use Optuna for automated hyperparameter tuning:

import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import gymnasium as gym

def optimize_ppo(trial):
    """Objective function for Optuna"""
    # Sample hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    n_steps = trial.suggest_categorical("n_steps", [256, 512, 1024, 2048])
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
    n_epochs = trial.suggest_int("n_epochs", 5, 20)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999, log=True)
    gae_lambda = trial.suggest_float("gae_lambda", 0.8, 0.99)
    clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
    ent_coef = trial.suggest_float("ent_coef", 1e-8, 1e-1, log=True)

    # Create environment
    env = gym.make("CartPole-v1")
    eval_env = gym.make("CartPole-v1")

    # Create model with sampled hyperparameters
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        n_epochs=n_epochs,
        gamma=gamma,
        gae_lambda=gae_lambda,
        clip_range=clip_range,
        ent_coef=ent_coef,
        verbose=0
    )

    # Train
    model.learn(total_timesteps=100_000)

    # Evaluate
    mean_reward, std_reward = evaluate_policy(
        model, eval_env, n_eval_episodes=10
    )

    return mean_reward

# Run optimization
study = optuna.create_study(direction="maximize")
study.optimize(optimize_ppo, n_trials=100, timeout=3600)

print("Best hyperparameters:")
print(study.best_params)
print(f"Best reward: {study.best_value}")

Pre-trained Models (RL Zoo)

Use pre-trained models from RL Zoo:

# Browse available pre-trained models on the Hugging Face Hub:
# https://huggingface.co/sb3

# Download pre-trained model
python -m rl_zoo3.load_from_hub --algo ppo --env HalfCheetah-v4 -f logs/ -orga sb3

# Evaluate pre-trained model
python -m rl_zoo3.enjoy --algo ppo --env HalfCheetah-v4 -f logs/ -n 5000

Python API:

import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Download a checkpoint from the Hugging Face Hub
# (uses the huggingface_sb3 package: pip install huggingface-sb3)
model_path = load_from_hub(
    repo_id="sb3/ppo-HalfCheetah-v4",
    filename="ppo-HalfCheetah-v4.zip"
)

# Load and use
model = PPO.load(model_path)

env = gym.make("HalfCheetah-v4")
obs, info = env.reset()

for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)

Algorithm-Specific Examples

PPO (On-Policy)

from stable_baselines3 import PPO

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.0,
    verbose=1
)

model.learn(total_timesteps=1_000_000)

SAC (Off-Policy, Continuous)

from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    buffer_size=1_000_000,
    learning_starts=10_000,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef='auto',  # Automatic entropy tuning
    verbose=1
)

model.learn(total_timesteps=1_000_000)

TD3 (Off-Policy, Continuous)

import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# Action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions)
)

model = TD3(
    "MlpPolicy",
    env,
    learning_rate=1e-3,
    buffer_size=1_000_000,
    learning_starts=10_000,
    batch_size=100,
    tau=0.005,
    gamma=0.99,
    train_freq=(1, "episode"),
    gradient_steps=-1,  # Update at end of each episode
    action_noise=action_noise,
    policy_delay=2,  # Delayed policy updates
    target_policy_noise=0.2,
    target_noise_clip=0.5,
    verbose=1
)

model.learn(total_timesteps=1_000_000)

DQN (Off-Policy, Discrete)

from stable_baselines3 import DQN

model = DQN(
    "MlpPolicy",
    env,
    learning_rate=1e-4,
    buffer_size=100_000,
    learning_starts=10_000,
    batch_size=32,
    tau=1.0,
    gamma=0.99,
    train_freq=4,
    gradient_steps=1,
    target_update_interval=10_000,
    exploration_fraction=0.1,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,
    verbose=1
)

model.learn(total_timesteps=1_000_000)

Robotics Integration

Custom Gym Environment

import gymnasium as gym
from gymnasium import spaces
import numpy as np

class RobotEnv(gym.Env):
    """Custom robot environment"""

    def __init__(self):
        super().__init__()

        # Define action and observation space
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(4,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32
        )

        # Initialize robot (pseudocode)
        # self.robot = Robot()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        # Reset robot to initial state
        # observation = self.robot.get_observation()
        observation = np.zeros(10, dtype=np.float32)
        info = {}

        return observation, info

    def step(self, action):
        # Apply action to robot
        # self.robot.apply_action(action)

        # Get new observation
        # observation = self.robot.get_observation()
        observation = np.zeros(10, dtype=np.float32)

        # Calculate reward
        reward = self._calculate_reward(observation, action)

        # Check if done
        terminated = False  # Task completed
        truncated = False  # Time limit reached

        info = {}

        return observation, reward, terminated, truncated, info

    def _calculate_reward(self, observation, action):
        # Example reward function
        reward = -np.linalg.norm(observation[:3])  # Distance to goal
        reward -= 0.01 * np.linalg.norm(action)  # Action penalty
        return reward

# Register environment
gym.register(
    id='Robot-v0',
    entry_point='__main__:RobotEnv',
    max_episode_steps=1000
)

# Use with SB3
env = gym.make('Robot-v0')
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=100_000)

Real Robot Example

from stable_baselines3 import SAC
import gymnasium as gym

# Connect to real robot
env = gym.make('RealRobot-v0')  # Your robot environment

# Load pre-trained model (from sim), overriding the learning rate for
# fine-tuning. Note: setting model.learning_rate after loading has no
# effect, because the schedule is built when the model is created/loaded.
model = SAC.load(
    "sim_policy",
    env=env,
    custom_objects={"learning_rate": 1e-5}
)

# Fine-tune with safety limits
from stable_baselines3.common.callbacks import BaseCallback

class SafetyCallback(BaseCallback):
    """Stop training if unsafe behavior detected"""
    def __init__(self, force_limit=100.0):
        super().__init__()
        self.force_limit = force_limit

    def _on_step(self) -> bool:
        # Check safety constraints
        if "force" in self.locals["infos"][0]:
            force = self.locals["infos"][0]["force"]
            if force > self.force_limit:
                print(f"Safety violation! Force: {force}")
                return False  # Stop training
        return True

safety_callback = SafetyCallback(force_limit=100.0)
model.learn(total_timesteps=10_000, callback=safety_callback)

Tips & Best Practices

Choosing an Algorithm

# For continuous control (robotics)
# - Sample efficient: SAC
# - Stable: TD3
# - Fast convergence: PPO

# For discrete control (games)
# - Value-based, replay-efficient: DQN
# - On-policy: A2C/PPO

# Quick guide
import gymnasium as gym
from stable_baselines3 import SAC, TD3, PPO, DQN

if isinstance(env.action_space, gym.spaces.Box):
    # Continuous actions: SAC (sample efficient), TD3 (stable),
    # or PPO (simple, fast wall-clock convergence)
    algorithm = SAC
else:
    # Discrete actions
    algorithm = DQN

Debugging Tips

# 1. Start simple
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

# 2. Check environment
from stable_baselines3.common.env_checker import check_env
check_env(env)

# 3. Monitor training
from stable_baselines3.common.monitor import Monitor
env = Monitor(env, "./logs/")

# 4. Visualize with TensorBoard (run in a shell):
#    tensorboard --logdir ./ppo_tensorboard/

# 5. Evaluate frequently
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")

Common Issues

Problem: Slow training

# Solution: Use vectorized environments
env = make_vec_env("Ant-v4", n_envs=8, vec_env_cls=SubprocVecEnv)

Problem: Unstable training

# Solution: Normalize observations/rewards
env = VecNormalize(env, norm_obs=True, norm_reward=True)

Problem: Poor exploration

# Solution: Increase entropy coefficient (PPO)
model = PPO("MlpPolicy", env, ent_coef=0.01)

# Or increase exploration noise (TD3/DDPG); mean and sigma must be arrays
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.3 * np.ones(n_actions)
)
model = TD3("MlpPolicy", env, action_noise=action_noise)

References

Official Resources

  • Documentation: https://stable-baselines3.readthedocs.io/
  • GitHub: https://github.com/DLR-RM/stable-baselines3
  • RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo

Tutorials

  • Getting Started: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html
  • Custom Policies: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html
  • Callbacks: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html

Papers

SB3 implements algorithms from:

  1. PPO: Schulman et al., "Proximal Policy Optimization", 2017
  2. SAC: Haarnoja et al., "Soft Actor-Critic", 2018
  3. TD3: Fujimoto et al., "Addressing Function Approximation Error", 2018
  4. DQN: Mnih et al., "Human-level control through deep RL", Nature 2015

Community

  • GitHub Discussions: https://github.com/DLR-RM/stable-baselines3/discussions
  • Discord: https://discord.gg/nnWPWFbcCK
  • RL Discord: https://discord.gg/xhfNqQv

Next Steps

  • RSL-RL - Isaac Lab specialized library
  • RL Games - High-performance alternative
  • SB3 Contrib - Additional algorithms (TQC, QR-DQN, etc.)