A Comprehensive Overview of Q-Learning and Actor-Critic Methods
Table of Contents
- 1. Q-Learning: A Foundational Approach
- 2. Basic Actor-Critic
- 3. Neural Network Parameterization
- 4. Deep Deterministic Policy Gradient (DDPG)
- 5. Twin Delayed Deep Deterministic Policy Gradient (TD3)
- 6. Proximal Policy Optimization (PPO)
- 7. Soft Actor-Critic (SAC)
- 8. Asynchronous Advantage Actor-Critic (A3C)
- 9. A3C Full Example
- 10. Evolution Timeline of the Methods
- 11. Concluding Remarks
1. Q-Learning: A Foundational Approach
Q-Learning attempts to learn the optimal state-action value function:

$$Q^*(s, a) = \max_{\pi}\, \mathbb{E}\!\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\middle|\, s_0 = s,\ a_0 = a,\ \pi\right]$$

For a discrete action space, we can keep a table or a neural network to represent $Q(s, a)$.
1.1 Bellman Update
The core Bellman optimality update is:

$$Q(s, a) \leftarrow Q(s, a) + \alpha\left[r + \gamma \max_{a'} Q(s', a') - Q(s, a)\right]$$
Python Snippet
```python
import numpy as np

# Suppose Q is a table (2D array: Q[state, action])
def q_learning_update(Q, s, a, r, s_next, alpha, gamma):
    # Q[s, a] = Q[s, a] + alpha * [r + gamma * max(Q[s_next, :]) - Q[s, a]]
    Q[s, a] += alpha * (r + gamma * np.max(Q[s_next]) - Q[s, a])
```
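For instance, with a hypothetical environment of 10 states and 4 actions (sizes chosen purely for illustration), a single update looks like this:

```python
# Hypothetical table: 10 states x 4 actions, initialized to zero
Q = np.zeros((10, 4))

# One observed transition (s=3, a=1, r=1.0, s_next=4) with typical hyperparameters
q_learning_update(Q, s=3, a=1, r=1.0, s_next=4, alpha=0.1, gamma=0.99)
```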
1.2 Deep Q-Network (DQN)
In deep Q-learning, we approximate $Q(s, a)$ with a neural network $Q_\theta(s, a)$. The loss to minimize is:

$$\mathcal{L}(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}}\!\left[\Big(r + \gamma \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a)\Big)^{2}\right]$$

Here, $\theta$ and $\theta^-$ denote the online and target network parameters, respectively.
Python Snippet
```python
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # Q(s) -> R^action_dim

    def forward(self, s):
        x = torch.relu(self.fc1(s))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

def dqn_loss(q_network, target_network, batch, gamma):
    # batch includes (s, a, r, s_next, done)
    s, a, r, s_next, done = batch

    # Q(s, a)
    q_values = q_network(s).gather(1, a.unsqueeze(1)).squeeze(1)

    # target = r + gamma * max_a' Q^-(s_next, a') if not done
    with torch.no_grad():
        q_next = target_network(s_next).max(1)[0]
        q_target = r + gamma * q_next * (1 - done.float())

    loss = nn.MSELoss()(q_values, q_target)
    return loss
```
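The target parameters $\theta^-$ are usually a periodically refreshed copy of the online network. A minimal sketch of that hard update (the sync interval here is an assumption, not prescribed above):

```python
def maybe_sync_target(step, q_network, target_network, sync_every=1000):
    # Hard update: copy the online parameters into the target network
    # every `sync_every` gradient steps.
    if step % sync_every == 0:
        target_network.load_state_dict(q_network.state_dict())
```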
Thus, Q-learning (and its deep counterpart) is typically off-policy because it learns about the greedy policy while potentially following a different data-collecting policy (e.g., $\epsilon$-greedy).
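A minimal $\epsilon$-greedy behavior policy for collecting data could look like the sketch below (the function name and interface are ours, not part of the snippets above):

```python
import random

def epsilon_greedy_action(q_network, state, action_dim, epsilon=0.1):
    # With probability epsilon take a uniformly random action;
    # otherwise act greedily with respect to the current Q estimate.
    if random.random() < epsilon:
        return random.randrange(action_dim)
    with torch.no_grad():
        return int(q_network(state.unsqueeze(0)).argmax(dim=1).item())
```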
2. Basic Actor-Critic
Unlike Q-learning, actor-critic methods maintain:
- A policy $\pi_\theta(a \mid s)$ (the “actor”).
- A value function $V_\phi(s)$ or $Q$-function (the “critic”).
2.1 Policy Gradient Theory
We want to maximize the expected return:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[\sum_{t=0}^{\infty} \gamma^t r_t\right]$$
The policy gradient theorem says:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\big(G_t - b(s_t)\big)\right]$$

where $b(s)$ is a baseline (often the value function $V(s)$).
Python Snippet
```python
import torch

def policy_gradient_loss(log_probs, returns, baselines):
    # log_probs: tensor of log π(a|s) for each step
    # returns:   tensor of G_t
    # baselines: tensor of b(s), often V(s)
    advantage = returns - baselines
    # actor loss = - E[ log pi(a|s) * advantage ]
    loss = -(log_probs * advantage).mean()
    return loss
```
2.2 Critic Objective
The critic (value-based) is learned via MSE:

$$\mathcal{L}(\phi) = \mathbb{E}\!\left[\big(V_\phi(s_t) - G_t\big)^{2}\right]$$
Python Snippet
```python
import torch.nn as nn

def value_loss(value_net, states, returns):
    # value_net(s) -> scalar V_phi(s)
    v = value_net(states)
    loss = nn.MSELoss()(v.squeeze(), returns)
    return loss
```
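Both losses above consume the discounted returns $G_t$. A small helper to compute them from a single trajectory's rewards (a sketch, assuming the whole episode has been collected):

```python
import torch

def discounted_returns(rewards, gamma):
    # G_t = r_t + gamma * G_{t+1}, computed backwards over the trajectory
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)
```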
3. Neural Network Parameterization
Below is a typical two-layer MLP for both actor and critic, with explicit shapes:
- **Actor** $\pi_\theta(a \mid s)$:
  - First hidden layer: `state_dim` → `hidden_dim` (64 in the snippet below).
  - Second hidden layer: `hidden_dim` → `hidden_dim`.
  - Output layer depends on discrete vs. continuous actions.
- **Critic** $V_\phi(s)$ or $Q_\phi(s, a)$:
  - Similarly a two-layer MLP.
  - Output dimension = 1 (scalar).
Python Snippet for a 2-Layer MLP
```python
class MLPActor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64, discrete=False):
        super(MLPActor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Discrete:   outputs action_dim logits
        # Continuous: outputs mean, possibly log std
        self.discrete = discrete
        if discrete:
            self.fc_out = nn.Linear(hidden_dim, action_dim)
        else:
            # For continuous, let's say we output mean only
            self.fc_mean = nn.Linear(hidden_dim, action_dim)
            # could also have self.log_std = nn.Parameter(...)

    def forward(self, s):
        x = torch.relu(self.fc1(s))
        x = torch.relu(self.fc2(x))
        if self.discrete:
            logits = self.fc_out(x)
            return logits  # use softmax outside
        else:
            mean = self.fc_mean(x)
            return mean
```
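A matching critic reuses the same two-layer layout with a scalar output. The class below is an illustrative counterpart to `MLPActor` (its name and exact layout are our own choice):

```python
class MLPCritic(nn.Module):
    def __init__(self, state_dim, hidden_dim=64):
        super(MLPCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, 1)  # scalar V(s)

    def forward(self, s):
        x = torch.relu(self.fc1(s))
        x = torch.relu(self.fc2(x))
        return self.fc_out(x)
```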
4. Deep Deterministic Policy Gradient (DDPG)
4.1 Deterministic Actor, Q-Critic
- Actor: a deterministic policy $a = \mu_\theta(s)$.
- Critic: $Q_\phi(s, a)$.
- Critic Loss:

$$\mathcal{L}(\phi) = \mathbb{E}\!\left[\Big(Q_\phi(s, a) - \big(r + \gamma\, Q_{\phi^-}(s', \mu_{\theta^-}(s'))\big)\Big)^{2}\right]$$
Python Snippet
```python
def ddpg_critic_loss(q_net, target_q_net, batch, gamma, actor, target_actor):
    s, a, r, s_next, done = batch
    q_vals = q_net(s, a)
    with torch.no_grad():
        a_next = target_actor(s_next)
        q_next = target_q_net(s_next, a_next)
        q_target = r + gamma * q_next * (1 - done)
    loss = nn.MSELoss()(q_vals, q_target)
    return loss
```
- Actor Update uses the deterministic policy gradient:

$$\nabla_\theta J(\theta) = \mathbb{E}_{s}\!\left[\nabla_a Q_\phi(s, a)\big|_{a = \mu_\theta(s)}\; \nabla_\theta \mu_\theta(s)\right]$$
Python Snippet
```python
def ddpg_actor_loss(q_net, states, actor):
    # actor(s) -> a
    a = actor(states)
    # compute d/d(a) of Q(s, a), then chain rule through the actor
    q_vals = q_net(states, a)
    # we want to maximize Q, so minimize the negative
    loss = -q_vals.mean()
    return loss
```
DDPG is off-policy and uses a replay buffer plus target networks to improve stability.
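The target networks are typically tracked with Polyak (soft) updates. A minimal sketch, where the value of `tau` is an assumed hyperparameter:

```python
def soft_update(online_net, target_net, tau=0.005):
    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    with torch.no_grad():
        for p, p_target in zip(online_net.parameters(), target_net.parameters()):
            p_target.data.mul_(1.0 - tau).add_(tau * p.data)
```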
5. Twin Delayed Deep Deterministic Policy Gradient (TD3)
5.1 Twin Critics
To reduce overestimation in DDPG, TD3 uses two critics, $Q_{\phi_1}(s, a)$ and $Q_{\phi_2}(s, a)$.
The critic target is:

$$y = r + \gamma\,(1 - d)\,\min_{i=1,2} Q_{\phi_i^-}\!\big(s', \mu_{\theta^-}(s')\big)$$
Python Snippet
```python
def td3_critic_loss(q1, q2, q1_target, q2_target, batch, gamma, actor_target):
    s, a, r, s_next, done = batch

    with torch.no_grad():
        a_next = actor_target(s_next)
        q1_next = q1_target(s_next, a_next)
        q2_next = q2_target(s_next, a_next)
        q_next_min = torch.min(q1_next, q2_next)
        q_target = r + gamma * q_next_min * (1 - done)

    loss1 = nn.MSELoss()(q1(s, a), q_target)
    loss2 = nn.MSELoss()(q2(s, a), q_target)
    return loss1 + loss2
```
5.2 Delayed Updates
TD3 updates the actor (and target networks) every few critic steps, reducing variance.
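A sketch of that schedule, assuming a policy delay of 2, the `td3_critic_loss` above, and the `soft_update` helper from the DDPG section (the optimizer arguments are illustrative):

```python
def td3_update(step, q1, q2, q1_target, q2_target, actor, actor_target,
               critic_opt, actor_opt, batch, gamma, policy_delay=2):
    # Critics are updated at every step
    critic_loss = td3_critic_loss(q1, q2, q1_target, q2_target, batch, gamma, actor_target)
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # Actor and target networks are updated only every `policy_delay` steps
    if step % policy_delay == 0:
        s = batch[0]
        actor_loss = -q1(s, actor(s)).mean()
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()
        for net, target in ((q1, q1_target), (q2, q2_target), (actor, actor_target)):
            soft_update(net, target)
```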
6. Proximal Policy Optimization (PPO)
6.1 Probability Ratio and Clipping
PPO is an on-policy method. We define the probability ratio:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

The clipped objective is:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t\Big)\right]$$
Python Snippet
```python
def ppo_clip_loss(pi_new, pi_old, actions, advantages, epsilon=0.2):
    # pi_new, pi_old: probability of actions under new/old policy
    ratio = pi_new / (pi_old + 1e-8)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    loss = -torch.min(unclipped, clipped).mean()
    return loss
```
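In practice the ratio is usually formed from log-probabilities rather than raw probabilities, which is numerically safer. An equivalent variant of the same loss under that assumed interface:

```python
def ppo_clip_loss_from_logp(logp_new, logp_old, advantages, epsilon=0.2):
    # ratio = exp(log pi_new - log pi_old)
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()
```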
7. Soft Actor-Critic (SAC)
7.1 Maximum Entropy RL
SAC encourages exploration via an entropy term $\mathcal{H}\big(\pi(\cdot \mid s)\big)$. The objective is:

$$J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \pi}\!\left[r(s_t, a_t) + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big)\right]$$
7.2 Two Critics
Like TD3, SAC uses twin critics $Q_{\phi_1}$ and $Q_{\phi_2}$. The target is:

$$y = r + \gamma\,(1 - d)\left[\min_{i=1,2} Q_{\phi_i^-}(s', a') - \alpha \log \pi_\theta(a' \mid s')\right]$$

with $a' \sim \pi_\theta(\cdot \mid s')$.
Python Snippet
```python
def sac_critic_loss(q1, q2, q1_target, q2_target, batch, alpha, gamma, actor_target):
    s, a, r, s_next, done = batch
    with torch.no_grad():
        # sample new action from actor_target
        a_next = actor_target(s_next)
        # log pi(a_next|s_next)
        logp_a_next = actor_target.log_prob(s_next, a_next)
        q1_next = q1_target(s_next, a_next)
        q2_next = q2_target(s_next, a_next)
        q_next_min = torch.min(q1_next, q2_next)
        q_target = r + gamma * (q_next_min - alpha * logp_a_next) * (1 - done)

    loss1 = nn.MSELoss()(q1(s, a), q_target)
    loss2 = nn.MSELoss()(q2(s, a), q_target)
    return loss1 + loss2
```
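The actor, in turn, is trained to trade Q-value off against entropy. A minimal sketch of that loss, assuming the actor exposes a `sample` method returning a reparameterized action and its log-probability (an interface not shown above):

```python
def sac_actor_loss(q1, q2, actor, states, alpha):
    # Reparameterized sample so gradients flow through the action
    a, logp_a = actor.sample(states)
    q_min = torch.min(q1(states, a), q2(states, a))
    # maximize E[min Q - alpha * log pi]  <=>  minimize the negative
    return (alpha * logp_a - q_min).mean()
```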
8. Asynchronous Advantage Actor-Critic (A3C)
8.1 Parallelization Insight
A3C runs multiple worker processes, each with local copies of the policy $\pi_\theta$ and the value function $V_\phi$. They asynchronously update the shared global parameters.
8.2 Advantage Actor-Critic Loss
A typical A3C loss (value-based critic) is:

$$\mathcal{L} = -\mathbb{E}\big[\log \pi_\theta(a_t \mid s_t)\,\hat{A}_t\big] + \mathbb{E}\big[\hat{A}_t^{\,2}\big], \qquad \hat{A}_t = G_t - V_\phi(s_t)$$
Python Snippet
```python
def a3c_loss(policy_net, value_net, states, actions, returns):
    # policy_net -> log pi(a|s), value_net -> V(s)
    log_probs = policy_net.log_prob(states, actions)
    values = value_net(states).squeeze()
    advantage = returns - values

    actor_loss = -(log_probs * advantage).mean()
    critic_loss = advantage.pow(2).mean()
    return actor_loss + critic_loss
```
A3C’s asynchronous updates help decorrelate data and speed up training on CPUs.
9. A3C Full Example
```python
# -*- coding: utf-8 -*-
"""Refactored A3C Example:
 - Renamed classes, functions, and variables for clarity.
 - Maintains the original multi-process A3C logic with shared parameters.

Author: <Your Name>
"""

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.distributions import Categorical


##############################################################################
# 1) UTILITY: A safer environment "step" handling older/newer Gym returns.
##############################################################################
def safe_step_env(env, action):
    """
    Step the environment while handling Gym/Gymnasium return formats:
      - Some return (obs, reward, done, info)
      - Gymnasium can return (obs, reward, done, truncated, info)
    """
    results = env.step(action)
    if len(results) == 5:
        next_obs, reward, done, truncated, info = results
        done = done or truncated
    else:
        next_obs, reward, done, info = results
    return next_obs, reward, done, info


##############################################################################
# 2) SHARED OPTIMIZER
##############################################################################
class SharedAdam(torch.optim.Adam):
    """
    A custom Adam optimizer that uses shared memory for the moving averages,
    allowing multiple processes to update a shared global set of parameters.
    """

    def __init__(self, parameters, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(parameters, lr=lr, betas=betas,
                                         eps=eps, weight_decay=weight_decay)
        for group in self.param_groups:
            for param in group['params']:
                state = self.state[param]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(param.data)
                state['exp_avg_sq'] = torch.zeros_like(param.data)
                # Share memory so that parallel processes can update
                # the optimizer states
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()


##############################################################################
# 3) A3C NETWORK (Actor + Critic)
##############################################################################
class A3CNetwork(nn.Module):
    """
    An Actor-Critic network that outputs:
      - A policy distribution (pi) for selecting actions.
      - A value function (v) for estimating state value.
    """

    def __init__(self, state_dim, num_actions, gamma=0.99):
        super(A3CNetwork, self).__init__()
        self.gamma = gamma

        self.actor_hidden = nn.Linear(*state_dim, 128)
        self.critic_hidden = nn.Linear(*state_dim, 128)

        self.actor_out = nn.Linear(128, num_actions)
        self.critic_out = nn.Linear(128, 1)

        # We'll store trajectories in memory buffers between updates
        self.memory_states = []
        self.memory_actions = []
        self.memory_rewards = []

    def store_experience(self, state, action, reward):
        """Cache a transition (state, action, reward)."""
        self.memory_states.append(state)
        self.memory_actions.append(action)
        self.memory_rewards.append(reward)

    def reset_memory(self):
        """Clear trajectory buffers."""
        self.memory_states = []
        self.memory_actions = []
        self.memory_rewards = []

    def forward(self, state):
        """
        Forward pass:
          state: a torch.Tensor of shape [batch_size, *state_dim].
        Returns:
          - policy_logits (for the actor)
          - state_value (scalar per batch item, for the critic)
        """
        actor_hidden_out = F.relu(self.actor_hidden(state))
        critic_hidden_out = F.relu(self.critic_hidden(state))

        policy_logits = self.actor_out(actor_hidden_out)
        state_value = self.critic_out(critic_hidden_out)
        return policy_logits, state_value

    def compute_returns(self, done):
        """
        Compute discounted returns for each step in the trajectory.
        If the episode ended (done=True), the final state's value is 0.
        Otherwise, we bootstrap from the last state's value.
        """
        states_t = torch.tensor(self.memory_states, dtype=torch.float)
        _, values_t = self.forward(states_t)
        values_t = values_t.squeeze(-1)  # shape [trajectory_len]

        # If done, final state's value is 0; else we take the last state's value
        final_value = 0.0 if done else values_t[-1].item()

        returns = []
        discounted_return = final_value
        for reward in reversed(self.memory_rewards):
            discounted_return = reward + self.gamma * discounted_return
            returns.append(discounted_return)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float)

    def compute_loss(self, done):
        """
        Compute combined Actor + Critic loss over the stored trajectory.
        """
        states_t = torch.tensor(self.memory_states, dtype=torch.float)
        actions_t = torch.tensor(self.memory_actions, dtype=torch.long)
        returns_t = self.compute_returns(done)

        policy_logits, values_t = self.forward(states_t)
        values_t = values_t.squeeze(-1)  # shape [trajectory_len]

        # Critic loss (Mean Squared Error)
        critic_loss = (returns_t - values_t) ** 2

        # Actor loss (REINFORCE with advantage)
        probabilities = F.softmax(policy_logits, dim=1)
        dist = Categorical(probabilities)
        log_probs = dist.log_prob(actions_t)

        advantages = returns_t - values_t.detach()  # no grad wrt. values for advantage
        actor_loss = -log_probs * advantages

        total_loss = (critic_loss + actor_loss).mean()
        return total_loss

    def select_action(self, observation):
        """
        Sample an action according to the policy's distribution.
          observation: raw state from the environment (np array or float list).
        """
        observation_t = torch.tensor([observation], dtype=torch.float)
        policy_logits, _ = self.forward(observation_t)
        probabilities = F.softmax(policy_logits, dim=1)
        dist = Categorical(probabilities)
        action = dist.sample().item()
        return action


##############################################################################
# 4) WORKER AGENT (One process per agent)
##############################################################################
class A3CWorker(mp.Process):
    """
    Each worker interacts with an environment instance, accumulates experience,
    and updates the global A3C network's parameters.
    """

    def __init__(self, global_network, global_optimizer, state_dim, num_actions,
                 gamma, lr, worker_id, global_episode_counter, env_id,
                 max_episodes, update_interval):
        super(A3CWorker, self).__init__()
        self.local_network = A3CNetwork(state_dim, num_actions, gamma)
        self.global_network = global_network
        self.global_optimizer = global_optimizer

        self.worker_name = f"worker_{worker_id:02d}"
        self.episode_counter = global_episode_counter

        self.env = gym.make(env_id)
        self.max_episodes = max_episodes
        self.update_interval = update_interval

    def run(self):
        step_count = 1
        while self.episode_counter.value < self.max_episodes:
            state, _info = self._reset_env()  # handle Gymnasium reset
            done = False
            episode_return = 0.0

            self.local_network.reset_memory()

            while not done:
                action = self.local_network.select_action(state)
                next_state, reward, done, info = safe_step_env(self.env, action)

                episode_return += reward
                self.local_network.store_experience(state, action, reward)

                # Update global network after 'update_interval' steps or on episode done
                if step_count % self.update_interval == 0 or done:
                    loss = self.local_network.compute_loss(done)

                    self.global_optimizer.zero_grad()
                    loss.backward()

                    # Copy grads from local to global
                    for local_param, global_param in zip(
                            self.local_network.parameters(),
                            self.global_network.parameters()):
                        global_param._grad = local_param.grad

                    self.global_optimizer.step()

                    # Sync local network with updated global parameters
                    self.local_network.load_state_dict(self.global_network.state_dict())
                    self.local_network.reset_memory()

                step_count += 1
                state = next_state

            # Increment global episode counter
            with self.episode_counter.get_lock():
                self.episode_counter.value += 1

            print(f"{self.worker_name} | Episode: {self.episode_counter.value} "
                  f"| Return: {episode_return:.1f}")

    def _reset_env(self):
        """
        Reset environment (Gym/Gymnasium). Handles new reset() returning (obs, info).
        """
        initial_obs = self.env.reset()
        if isinstance(initial_obs, tuple) and len(initial_obs) == 2:
            obs, info = initial_obs
        else:
            obs = initial_obs
            info = {}
        return obs, info


##############################################################################
# 5) WATCH A TRAINED AGENT
##############################################################################
def watch_agent(global_network, env_id="CartPole-v1", episodes_to_watch=5):
    """
    Renders a few episodes using the global A3C network's parameters.
    """
    env = gym.make(env_id, render_mode="human")

    # Local copy for inference
    local_network = A3CNetwork([4], 2)  # For CartPole: state_dim=[4], num_actions=2
    local_network.load_state_dict(global_network.state_dict())

    for ep in range(episodes_to_watch):
        state, _info = env.reset()
        done = False
        episode_return = 0.0
        while not done:
            # For older Gym versions, you might do env.render() here
            action = local_network.select_action(state)
            state, reward, done, info = safe_step_env(env, action)
            episode_return += reward

        print(f"Watch Episode {ep + 1}, Return: {episode_return:.1f}")

    env.close()


##############################################################################
# 6) MAIN: TRAINING LOGIC
##############################################################################
if __name__ == '__main__':

    LEARNING_RATE = 1e-4
    ENV_ID = "CartPole-v1"
    env = gym.make(ENV_ID)

    # Observation space shape => typically something like (4,)
    state_dim = env.observation_space.shape

    # Action space => for Discrete(n), .n is the number of possible actions
    num_actions = env.action_space.n
    MAX_EPISODES = 3000
    UPDATE_INTERVAL = 500

    # Create global (shared) A3C network
    shared_network = A3CNetwork(state_dim, num_actions)
    shared_network.share_memory()

    # Create shared optimizer
    shared_optimizer = SharedAdam(shared_network.parameters(),
                                  lr=LEARNING_RATE, betas=(0.92, 0.999))

    global_episode_counter = mp.Value('i', 0)

    # Spawn worker processes
    num_cpus = mp.cpu_count()
    workers = []
    for cpu_id in range(num_cpus):
        worker = A3CWorker(
            global_network=shared_network,
            global_optimizer=shared_optimizer,
            state_dim=state_dim,
            num_actions=num_actions,
            gamma=0.99,
            lr=LEARNING_RATE,
            worker_id=cpu_id,
            global_episode_counter=global_episode_counter,
            env_id=ENV_ID,
            max_episodes=MAX_EPISODES,
            update_interval=UPDATE_INTERVAL
        )
        workers.append(worker)

    # Start and join each worker
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    print("Training complete. Now let's watch the agent in action!")
    watch_agent(shared_network, env_id=ENV_ID, episodes_to_watch=5000)
```
10. Evolution Timeline of the Methods
- **Q-Learning**:
  - Learns $Q(s, a)$ using the Bellman update.
  - Great for discrete action spaces (DQN for the deep version).
  - Off-policy; can be inefficient for large continuous spaces.
- **Actor-Critic (baseline)**:
  - Combines the policy gradient with a critic to reduce variance.
  - Works in both discrete and continuous settings.
- **DDPG**:
  - Deterministic policy + replay buffer + target networks for continuous control.
  - Issues: overestimation, sensitive hyperparameters.
- **A3C**:
  - Multiple asynchronous workers for faster training.
  - No replay buffer, but can have higher variance.
- **TD3**:
  - Twin critics + delayed updates to reduce overestimation in DDPG.
  - Deterministic, needs exploration noise.
- **PPO**:
  - On-policy with a clipped objective for stable learning.
  - Popular and relatively easy to tune.
- **SAC**:
  - Maximum entropy RL for robust exploration.
  - Twin critics to reduce overestimation.
  - Often state-of-the-art in continuous control tasks.
Hence, each method emerges to address specific challenges:
- Overestimation (TD3, SAC).
- Exploration (SAC’s entropy).
- Stability (PPO clipping, twin critics).
- Efficiency (replay buffers, asynchronous runs).
11. Concluding Remarks
- Q-learning (and DQN) forms the foundation for many discrete-action RL approaches.
- Actor-Critic methods extend naturally to continuous actions and can reduce variance with a learned critic.
- DDPG introduced a deterministic actor with an off-policy, replay-buffer approach, later refined by TD3 to address overestimation.
- PPO simplified stable on-policy learning with a clipped objective.
- SAC combined twin critics with maximum entropy to encourage robust exploration.
- A3C leveraged asynchronous CPU processes to speed up training without replay buffers.