Offline Reinforcement Learning¶
Offline RL (also called Batch RL) learns policies from fixed datasets without any environment interaction. This is critical for robotics, where online exploration is expensive or unsafe.
Why Offline RL?¶
Problem with Online RL:

- Requires millions of environment interactions
- Unsafe during exploration (real robots can break)
- Expensive (time, wear-and-tear)
- May not be allowed (medical, industrial settings)

Offline RL Solution:

- Learn from pre-collected datasets
- No environment interaction during training
- Safe (no exploration)
- Can leverage existing demonstration data
```python
# Online RL
for episode in range(1_000_000):  # Millions of episodes!
    obs = env.reset()
    done = False
    while not done:
        action = policy.explore(obs)  # Can be unsafe
        next_obs, reward, done, _ = env.step(action)
        buffer.add(obs, action, reward, next_obs)
        obs = next_obs

# Offline RL
dataset = load_fixed_dataset()  # Pre-collected
for epoch in range(100):
    batch = dataset.sample()
    policy.update(batch)  # No environment interaction!
```
The Challenge: Distributional Shift¶
Problem: Policy may select out-of-distribution (OOD) actions at test time
```python
# Training: policy only sees actions from the dataset distribution
dataset_actions = [a1, a2, a3, ...]  # From demonstrations

# Test: policy might choose an unseen action
test_action = policy(obs)  # Could be OOD!
# If OOD: Q-values unreliable -> poor performance
```
Solution: Constrain policy to stay close to data distribution
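The danger is easiest to see with function approximation. In this minimal NumPy sketch (a least-squares line stands in for a neural Q-function; the numbers are made up), the approximator fits the dataset actions well but is wildly wrong on an action outside the data:

```python
import numpy as np

# True Q(a) = a^2, but Q-values were only observed for actions in [0.1, 0.5].
dataset_actions = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
observed_q = dataset_actions ** 2

# A least-squares line stands in for a neural Q-function.
slope, intercept = np.polyfit(dataset_actions, observed_q, deg=1)

in_dist_error = abs((slope * 0.3 + intercept) - 0.3 ** 2)    # inside the data
ood_error = abs((slope * -1.0 + intercept) - (-1.0) ** 2)    # far outside it
# The fit is accurate on dataset actions but unreliable off-distribution,
# so maximizing Q over all actions can latch onto a spurious value.
```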
Offline RL Algorithms¶
1. Conservative Q-Learning (CQL)¶
Idea: Penalize Q-values for out-of-distribution actions
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CQL:
    """
    Conservative Q-Learning

    Paper: Kumar et al., "Conservative Q-Learning for Offline RL", NeurIPS 2020
    """

    def __init__(self, state_dim, action_dim, alpha=1.0, gamma=0.99):
        self.action_dim = action_dim
        self.gamma = gamma
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_q = QNetwork(state_dim, action_dim)
        self.target_q.load_state_dict(self.q_network.state_dict())
        self.alpha = alpha  # CQL penalty weight
        self.optimizer = torch.optim.Adam(
            self.q_network.parameters(),
            lr=3e-4
        )

    def compute_cql_loss(self, batch):
        """
        CQL loss = TD error + penalty for high OOD Q-values

        L = L_TD + alpha * (Q_OOD - Q_data)
        """
        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards']
        next_states = batch['next_states']
        dones = batch['dones']

        # 1. Standard TD loss
        with torch.no_grad():
            next_actions = self.select_action(next_states)
            target_q = self.target_q(next_states, next_actions)
            target = rewards + (1 - dones) * self.gamma * target_q

        current_q = self.q_network(states, actions)
        td_loss = F.mse_loss(current_q, target)

        # 2. CQL penalty
        # Sample random actions as stand-ins for OOD actions
        random_actions = torch.rand_like(actions) * 2 - 1  # [-1, 1]
        ood_q_values = self.q_network(states, random_actions)

        # Q-values for dataset actions
        data_q_values = self.q_network(states, actions)

        # Penalty: push OOD Q-values down relative to dataset Q-values
        cql_penalty = ood_q_values.mean() - data_q_values.mean()

        total_loss = td_loss + self.alpha * cql_penalty
        return total_loss, {
            'td_loss': td_loss.item(),
            'cql_penalty': cql_penalty.item()
        }

    def train_step(self, batch):
        """Single training step"""
        loss, metrics = self.compute_cql_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network
        self.soft_update_target()
        return metrics

    def soft_update_target(self, tau=0.005):
        """Polyak averaging for target network"""
        for target_param, param in zip(
            self.target_q.parameters(),
            self.q_network.parameters()
        ):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data
            )

    def select_action(self, states):
        """
        Select actions by maximizing Q over sampled candidates.

        (Simplified: continuous-action CQL usually pairs the critic
        with an actor network instead of random search.)
        """
        if states.dim() == 1:
            states = states.unsqueeze(0)
        batch_size, num_samples = states.shape[0], 100
        with torch.no_grad():
            # Candidate actions in [-1, 1]: (B, num_samples, action_dim)
            candidates = torch.rand(batch_size, num_samples, self.action_dim) * 2 - 1
            expanded = states.unsqueeze(1).expand(-1, num_samples, -1)
            q_values = self.q_network(
                expanded.reshape(batch_size * num_samples, -1),
                candidates.reshape(batch_size * num_samples, -1),
            ).reshape(batch_size, num_samples)
            best_idx = q_values.argmax(dim=1)
        return candidates[torch.arange(batch_size), best_idx]


class QNetwork(nn.Module):
    """State-action value network"""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)  # Q-value
        )

    def forward(self, states, actions):
        x = torch.cat([states, actions], dim=-1)
        return self.network(x).squeeze(-1)
```
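As a quick sanity check on the penalty term, the toy sketch below (assumed network sizes and synthetic data, not the full CQL update) minimizes only `Q_OOD - Q_data` and confirms that its gradient pushes Q-values at random actions below Q-values at dataset actions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy Q-network over concatenated (state, action) pairs: 2-D state, 1-D action.
q_net = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 1))
opt = torch.optim.SGD(q_net.parameters(), lr=0.1)

states = torch.randn(64, 2)
data_actions = torch.zeros(64, 1)          # dataset actions cluster at 0
rand_actions = torch.rand(64, 1) * 2 - 1   # random "OOD" actions in [-1, 1]

def mean_q(s, a):
    return q_net(torch.cat([s, a], dim=-1)).mean()

with torch.no_grad():
    gap_before = (mean_q(states, rand_actions) - mean_q(states, data_actions)).item()

for _ in range(100):
    opt.zero_grad()
    penalty = mean_q(states, rand_actions) - mean_q(states, data_actions)  # CQL term
    penalty.backward()
    opt.step()

with torch.no_grad():
    gap_after = (mean_q(states, rand_actions) - mean_q(states, data_actions)).item()
# gap_after < gap_before: OOD Q-values are pushed below dataset Q-values.
```

In the full algorithm this term is combined with the TD loss, so the Q-function stays accurate on dataset actions while being conservative everywhere else.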
2. Implicit Q-Learning (IQL)¶
Idea: Learn the value function without ever querying Q on out-of-distribution actions, then extract the policy implicitly via advantage-weighted regression
```python
class IQL:
    """
    Implicit Q-Learning

    Paper: Kostrikov et al., "Offline RL with Implicit Q-Learning", ICLR 2022

    Key insight: decouple value learning from policy extraction.
    Reuses the QNetwork class defined above.
    """

    def __init__(self, state_dim, action_dim, tau=0.7, gamma=0.99):
        # Twin Q-networks
        self.q1 = QNetwork(state_dim, action_dim)
        self.q2 = QNetwork(state_dim, action_dim)

        # Value network V(s), trained by expectile regression
        self.v_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Actor network (for deployment)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh()
        )

        self.tau = tau      # Expectile parameter
        self.gamma = gamma

        self.q_optimizer = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()),
            lr=3e-4
        )
        self.v_optimizer = torch.optim.Adam(self.v_network.parameters(), lr=3e-4)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

    def train_step(self, batch):
        """IQL training step"""
        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards']
        next_states = batch['next_states']
        dones = batch['dones']

        # 1. Update V-network with expectile regression
        with torch.no_grad():
            q_val = torch.min(self.q1(states, actions), self.q2(states, actions))

        v_val = self.v_network(states).squeeze(-1)
        v_loss = self.expectile_loss(q_val - v_val, self.tau)  # asymmetric L2

        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()

        # 2. Update Q-networks toward r + gamma * V(s')
        with torch.no_grad():
            next_v = self.v_network(next_states).squeeze(-1)
            target_q = rewards + (1 - dones) * self.gamma * next_v

        q1_pred = self.q1(states, actions)
        q2_pred = self.q2(states, actions)
        q_loss = F.mse_loss(q1_pred, target_q) + F.mse_loss(q2_pred, target_q)

        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # 3. Update actor with advantage-weighted regression
        with torch.no_grad():
            q_val = torch.min(self.q1(states, actions), self.q2(states, actions))
            v_val = self.v_network(states).squeeze(-1)
            advantage = q_val - v_val
            weights = torch.exp(advantage / 0.1)         # Temperature = 0.1
            weights = torch.clamp(weights, max=100.0)    # Clip for stability

        # Actor loss: advantage-weighted behavior cloning
        predicted_actions = self.actor(states)
        actor_loss = (weights * F.mse_loss(
            predicted_actions, actions, reduction='none'
        ).mean(dim=-1)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return {
            'v_loss': v_loss.item(),
            'q_loss': q_loss.item(),
            'actor_loss': actor_loss.item()
        }

    @staticmethod
    def expectile_loss(diff, expectile):
        """
        Expectile regression loss: an asymmetric L2 loss.

        expectile=0.5 recovers the mean; larger values weight positive
        errors more heavily, approaching the max as expectile -> 1.
        """
        weight = torch.where(diff > 0, expectile, 1 - expectile)
        return (weight * diff ** 2).mean()

    def select_action(self, state):
        """Select action using the learned actor"""
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy()
```
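The expectile loss is what lets IQL approximate max_a Q(s, a) using only dataset actions. A small standalone demo (toy Q samples and scalar gradient descent; `fit_expectile` is illustrative, not part of IQL) shows how the fitted value shifts with tau:

```python
import torch

def expectile_loss(diff, tau):
    # Asymmetric L2: positive residuals weighted by tau, negative by 1 - tau.
    weight = torch.where(diff > 0, torch.full_like(diff, tau),
                         torch.full_like(diff, 1 - tau))
    return (weight * diff ** 2).mean()

# Toy Q samples at one state; fit a scalar v by minimizing the expectile loss.
q = torch.tensor([0.0, 1.0, 2.0, 3.0, 10.0])

def fit_expectile(tau, steps=2000, lr=0.05):
    v = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([v], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        expectile_loss(q - v, tau).backward()
        opt.step()
    return v.item()

# tau=0.5 recovers the mean (3.2); larger tau skews toward the high-value
# samples, approximating the max without ever querying OOD actions.
```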
3. Decision Transformer¶
Idea: Frame RL as sequence modeling - predict actions conditioned on desired returns
```python
class DecisionTransformer(nn.Module):
    """
    Decision Transformer

    Paper: Chen et al., "Decision Transformer: Reinforcement Learning
    via Sequence Modeling", NeurIPS 2021

    Treats RL as conditional sequence modeling:
    (R_1, s_1, a_1, R_2, s_2, a_2, ...) -> predict actions
    """

    def __init__(self, state_dim, action_dim, max_ep_len=1000, hidden_dim=128):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Embedding layers
        self.embed_state = nn.Linear(state_dim, hidden_dim)
        self.embed_action = nn.Linear(action_dim, hidden_dim)
        self.embed_return = nn.Linear(1, hidden_dim)

        # Positional (timestep) embedding
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_dim)

        # Transformer
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=4,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            ),
            num_layers=6
        )

        # Prediction head
        self.predict_action = nn.Linear(hidden_dim, action_dim)

    def forward(self, states, actions, returns_to_go, timesteps):
        """
        Args:
            states: (B, T, state_dim)
            actions: (B, T, action_dim)
            returns_to_go: (B, T, 1) - target cumulative return
            timesteps: (B, T) - timestep indices

        Returns:
            predicted_actions: (B, T, action_dim)
        """
        B, T, _ = states.shape

        # Embed each modality
        state_embeddings = self.embed_state(states)  # (B, T, hidden_dim)
        action_embeddings = self.embed_action(actions)
        return_embeddings = self.embed_return(returns_to_go)

        # Add positional encoding
        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        return_embeddings = return_embeddings + time_embeddings

        # Interleave into (R_1, s_1, a_1, R_2, s_2, a_2, ...): (B, 3*T, hidden_dim)
        sequence = torch.stack([
            return_embeddings,
            state_embeddings,
            action_embeddings
        ], dim=2).reshape(B, 3 * T, self.hidden_dim)

        # Causal mask so tokens cannot attend to the future -- in particular,
        # the state token at step t must not see the action a_t it predicts
        causal_mask = nn.Transformer.generate_square_subsequent_mask(3 * T)
        transformer_out = self.transformer(sequence, mask=causal_mask.to(sequence.device))

        # Extract the state-token outputs (every 3rd token, offset by 1)
        state_hidden = transformer_out[:, 1::3, :]

        # Predict actions
        return self.predict_action(state_hidden)

    def get_action(self, states, actions, returns_to_go, timesteps):
        """
        Inference: predict the next action, conditioned on the desired return
        """
        # Ensure correct shapes
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.action_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        predicted_actions = self.forward(states, actions, returns_to_go, timesteps)

        # Return the last predicted action
        return predicted_actions[0, -1]
```
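The stack-then-reshape interleaving in `forward` is easy to get wrong, so here is a small shape check with toy tensors (the hidden size and lengths are arbitrary) confirming the (R, s, a) token order:

```python
import torch

B, T, H = 2, 4, 8
r = torch.zeros(B, T, H)       # return-to-go embeddings (marked 0)
s = torch.ones(B, T, H)        # state embeddings (marked 1)
a = 2 * torch.ones(B, T, H)    # action embeddings (marked 2)

seq = torch.stack([r, s, a], dim=2).reshape(B, 3 * T, H)
# Tokens now alternate (R_1, s_1, a_1, R_2, s_2, a_2, ...):
# seq[:, 0::3] are returns, seq[:, 1::3] are states, seq[:, 2::3] are actions,
# which is why the model reads action predictions from positions 1, 4, 7, ...
```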
```python
def train_decision_transformer(model, dataset, config):
    """Train Decision Transformer"""
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    for epoch in range(config.num_epochs):
        for batch in dataset:
            # Batch contains full trajectories
            states = batch['states']        # (B, T, state_dim)
            actions = batch['actions']      # (B, T, action_dim)
            rewards = batch['rewards']      # (B, T)
            timesteps = batch['timesteps']  # (B, T)

            # Compute returns-to-go (undiscounted, following Chen et al.,
            # matching the undiscounted decrement used at inference time)
            returns_to_go = torch.zeros_like(rewards)
            for t in reversed(range(rewards.shape[1])):
                if t == rewards.shape[1] - 1:
                    returns_to_go[:, t] = rewards[:, t]
                else:
                    returns_to_go[:, t] = rewards[:, t] + returns_to_go[:, t + 1]
            returns_to_go = returns_to_go.unsqueeze(-1)  # (B, T, 1)

            # Predict actions
            predicted_actions = model(states, actions, returns_to_go, timesteps)

            # Loss: MSE between predicted and true actions
            loss = F.mse_loss(predicted_actions, actions)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model
```
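The undiscounted returns-to-go used for conditioning can also be computed without a Python loop, via a reversed cumulative sum:

```python
import torch

rewards = torch.tensor([[1.0, 0.0, 2.0, 1.0]])  # (B, T)

# Reverse, accumulate, reverse back: rtg[t] = sum of rewards from t onward
rtg = torch.flip(torch.cumsum(torch.flip(rewards, dims=[1]), dim=1), dims=[1])
# rtg = [[4., 3., 3., 1.]]
```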
```python
import numpy as np

# Inference: condition on the desired return
def inference_with_desired_return(model, env, target_return=500):
    """Deploy Decision Transformer conditioned on a target return"""
    obs = env.reset()
    done = False

    # Trajectory history
    states = [obs]
    actions = [np.zeros(model.action_dim)]  # Dummy first action
    returns_to_go = [target_return]
    timesteps = [0]
    t = 0

    while not done:
        # Get action from model
        with torch.no_grad():
            action = model.get_action(
                states=torch.as_tensor(np.array(states), dtype=torch.float32),
                actions=torch.as_tensor(np.array(actions), dtype=torch.float32),
                returns_to_go=torch.as_tensor(returns_to_go, dtype=torch.float32),
                timesteps=torch.as_tensor(timesteps, dtype=torch.long)
            )
        action = action.numpy()

        # Execute
        next_obs, reward, done, _ = env.step(action)

        # Update history
        states.append(next_obs)
        actions.append(action)
        returns_to_go.append(returns_to_go[-1] - reward)  # Decrement return
        timesteps.append(t + 1)
        t += 1

    print(f"Episode return: {target_return - returns_to_go[-1]}")
```
Dataset Quality Matters¶
Offline RL performance heavily depends on dataset quality:
```python
import numpy as np

class OfflineDatasetAnalyzer:
    """Analyze offline RL dataset quality"""

    def __init__(self, dataset):
        self.dataset = dataset

    def analyze(self):
        """Comprehensive dataset analysis"""
        print("=" * 60)
        print("OFFLINE RL DATASET ANALYSIS")
        print("=" * 60)

        # 1. Size
        print(f"\nDataset Size: {len(self.dataset)} transitions")

        # 2. Return distribution
        returns = self.compute_episode_returns()
        print(f"\nReturn Statistics:")
        print(f"  Mean: {np.mean(returns):.2f}")
        print(f"  Std:  {np.std(returns):.2f}")
        print(f"  Min:  {np.min(returns):.2f}")
        print(f"  Max:  {np.max(returns):.2f}")

        # 3. Coverage
        state_coverage = self.estimate_state_coverage()
        print(f"\nState Space Coverage: {state_coverage:.2%}")

        # 4. Quality score
        quality = self.compute_quality_score()
        print(f"\nDataset Quality Score: {quality:.2f} / 10")

        if quality >= 7:
            print("✓ High-quality dataset - offline RL should work well")
        elif quality >= 5:
            print("⚠️ Medium-quality - consider data augmentation")
        else:
            print("✗ Low-quality - offline RL may struggle")
        print("=" * 60)

    def compute_episode_returns(self):
        """Compute the return of each episode"""
        returns = []
        current_return = 0
        for transition in self.dataset:
            current_return += transition['reward']
            if transition['done']:
                returns.append(current_return)
                current_return = 0
        return np.array(returns)

    def estimate_state_coverage(self, bins_per_dim=10):
        """Estimate what fraction of the state space is covered"""
        states = np.array([t['state'] for t in self.dataset])

        # Discretize the state space (assumes states lie in [-1, 1];
        # clamp so boundary states stay in the last valid bin)
        covered_bins = set()
        for state in states:
            bin_coords = tuple(
                min(int((s + 1) / 2 * bins_per_dim), bins_per_dim - 1)
                for s in state
            )
            covered_bins.add(bin_coords)

        total_bins = bins_per_dim ** states.shape[1]
        return len(covered_bins) / total_bins

    def compute_quality_score(self):
        """Overall dataset quality score in [0, 10]"""
        scores = []

        # 1. Return score (task-specific normalization)
        returns = self.compute_episode_returns()
        return_score = float(np.clip(np.mean(returns) / 100, 0, 10))
        scores.append(return_score)

        # 2. Diversity score
        coverage = self.estimate_state_coverage()
        diversity_score = min(10, coverage * 100)
        scores.append(diversity_score)

        # 3. Size score
        size_score = min(10, len(self.dataset) / 100_000)
        scores.append(size_score)

        return np.mean(scores)
```
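To see what the coverage estimate measures, here is a standalone version of the binning logic (with the boundary clamp, applied to two synthetic datasets): clustered states cover few bins, spread-out states cover nearly all of them.

```python
import numpy as np

def coverage(states, bins_per_dim=10):
    # Map each state dim from [-1, 1] into a bin index, clamped to valid range.
    idx = np.clip(((states + 1) / 2 * bins_per_dim).astype(int), 0, bins_per_dim - 1)
    covered = {tuple(row) for row in idx}
    return len(covered) / bins_per_dim ** states.shape[1]

rng = np.random.default_rng(0)
narrow = rng.uniform(-1.0, -0.8, size=(1000, 2))   # states clustered in a corner
broad = rng.uniform(-1.0, 1.0, size=(1000, 2))     # states spread out
# coverage(narrow) is tiny; coverage(broad) is close to 1.0.
```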
Best Practices¶
DO:¶
- ✓ Use high-quality datasets (expert or near-expert)
- ✓ Check dataset coverage before training
- ✓ Start with CQL or IQL (most robust)
- ✓ Use conservative hyperparameters (low learning rate)
- ✓ Evaluate on the same distribution as the training data
- ✓ Monitor Q-value overestimation
DON'T:¶
- ✗ Expect to outperform the dataset performance significantly
- ✗ Use tiny datasets (<1000 transitions)
- ✗ Skip dataset quality analysis
- ✗ Use high exploration during evaluation
- ✗ Ignore distributional-shift warnings
When to Use Offline RL¶
Use when:

- You have an existing dataset of demonstrations
- Online interaction is expensive or unsafe
- You need to leverage sub-optimal data
- Exploration is dangerous

Don't use when:

- You can cheaply collect online data
- You need to significantly exceed dataset performance
- Dataset quality is very poor
- You have very little data (<1000 transitions)
Resources¶
Key Papers:

- Kumar et al., "Conservative Q-Learning for Offline RL", NeurIPS 2020
- Kostrikov et al., "Offline RL with Implicit Q-Learning", ICLR 2022
- Chen et al., "Decision Transformer: Reinforcement Learning via Sequence Modeling", NeurIPS 2021
- Fujimoto & Gu, "A Minimalist Approach to Offline RL", NeurIPS 2021
Next Steps¶
- Model-Based RL - Combine with world models
- Imitation Learning - Alternative for offline learning
- Data Collection - Collect quality datasets
- Evaluation - Evaluate offline policies