Soft Actor-Critic (SAC)
Algorithm Review and Notes
Note: The SAC implementation I use for my work is available here. This post makes lots of comparisons to PPO, which I described in an earlier post.
Overview
Soft Actor-Critic (SAC) is a widely used algorithm in reinforcement learning, possibly the most widely used besides PPO (as far as I know). PPO is an on-policy approach: it constantly requires fresh data from the current policy, and its data collection is often parallelized. SAC is off-policy, meaning that it can learn from any data gathered by interacting with the environment, regardless of which policy produced it. In practice, SAC keeps a replay buffer holding the history of transitions it has encountered during training, and it can draw on any of these experiences during a network update. This typically makes SAC much more sample-efficient than PPO, and we usually do not need parallelization to aid data collection. Note that both algorithms are online, meaning they actively collect data (it is easy to confuse online with on-policy).
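To make the buffer idea concrete, here is a minimal sketch of a fixed-capacity replay buffer in plain Python. The class name, method names, and the (state, action, reward, next_state, done) tuple layout are illustrative assumptions, not taken from the implementation linked above.

```python
import random


class ReplayBuffer:
    """Minimal fixed-size ring buffer of (s, a, r, s_next, done) transitions (illustrative)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.storage = []
        self.next_idx = 0  # slot that the next transition overwrites once the buffer is full

    def add(self, state, action, reward, next_state, done):
        transition = (state, action, reward, next_state, done)
        if len(self.storage) < self.capacity:
            self.storage.append(transition)
        else:
            self.storage[self.next_idx] = transition  # overwrite the oldest entry
        self.next_idx = (self.next_idx + 1) % self.capacity

    def sample(self, batch_size, rng=random):
        # Uniform sampling: any stored transition can be reused, whichever policy produced it
        return rng.sample(self.storage, batch_size)

    def __len__(self):
        return len(self.storage)
```

In use, every environment step calls `add(...)` and every network update draws a minibatch with `sample(batch_size)`. Real implementations usually preallocate arrays for speed, but the behavior is the same: old data rolls out, and everything still stored is fair game.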
Stealing from my PPO post:
We have some environment that we can interact with: it provides us a current "state", and we supply an "action" to take from that state. It then responds with the consequences of that action: the resulting new state and a reward, a numerical quantity indicating our performance in the environment. Our goal is to learn a neural network for which we can input a current state, and the output will be an action which optimizes our total collected reward over one "episode" (one play) of the environment.
SAC learns two neural networks (actually more than this, keep reading...):
The policy, which takes the current state and outputs an action that should maximize long-term reward. Learning this policy is the main goal of SAC (and RL in general).
A Q-value function, which takes as input a state and an action from that state, and predicts the discounted future reward we can expect if we take that action and continue from there. Compared to the on-policy value function in PPO, there are two main differences:
We commit to an action, rather than simply evaluating a state. This is what makes this a "Q-value" instead of a "value".
Our network is trained with off-policy data, drawn from a buffer filled by many past versions of the policy. Q(s,a) still estimates the returns from taking action a in state s and then following the current policy, because the bootstrapped target (below) queries the current policy for the next action; the off-policy data mainly determines which (state, action) pairs we have training signal for. Since the policy is simultaneously trained to maximize Q, this estimate keeps shifting toward the returns of an improving policy.
To use these networks and update them, we have the following high-level algorithm:
SAC: Repeat until converged:
Use the policy to collect data from the environment, adding it to a large rolling buffer of transitions (often one million or more).
Every few steps (~50), grab some data from our buffer to update the Q-network:
- The Q-network should predict the expected returns from each transition onwards. We use a one-step Bellman update for this.
Every few steps (~50), grab some data from our buffer to update the policy:
- The policy should try to maximize Q(s,a). If we predict an action from each state, we can feed this into our (frozen) Q-network and simply try to maximize the output. Backprop will flow through the Q-network into the policy.
Most of the complexity here lies in the Q-network, which we rely on to update our policy. Therefore, we use a lot of tricks to make this as accurate as possible. I will break this down first, and then discuss the policy network.
Nominal Q-Network Update
The Q-network is nominally updated with the Bellman equation, which takes an observed step and ties together the Q-values before and after it: the Q-value of the current step (s,a) should equal the resulting reward plus the discounted Q-value of a hypothetical subsequent step, whose action is sampled from the current policy:
$$Q(s_t, a_t) \leftarrow r_t + \gamma\, Q(s_{t+1}, a_{t+1}), \qquad a_{t+1} \sim \pi(\cdot \mid s_{t+1})$$
where the second term is omitted if the step is terminal. Using this update, we bootstrap our Q-values by tying the target values to the output of the Q-network itself. The one thing anchoring this update to an interpretable value is the occasional terminal step, in which case only the reward is used as a target.
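A tiny tabular example shows both the bootstrapping and the anchoring role of terminal steps. This is a sketch on a hypothetical three-state chain with one action per state (so Q is indexed by state only); the chain, the reward placement, and gamma = 0.9 are assumptions chosen for illustration.

```python
def bellman_target(reward, next_q, done, gamma=0.99):
    # On terminal steps the bootstrapped term is dropped and only the reward remains
    return reward + gamma * (1.0 - float(done)) * next_q


def run_chain_example(gamma=0.9, sweeps=10):
    # Hypothetical chain 0 -> 1 -> 2 -> end, one action per state,
    # reward 1 only on the final (terminal) transition.
    transitions = [(0, 0.0, 1, False), (1, 0.0, 2, False), (2, 1.0, None, True)]
    q = {0: 0.0, 1: 0.0, 2: 0.0}
    for _ in range(sweeps):
        for state, reward, next_state, done in transitions:
            next_q = 0.0 if done else q[next_state]
            q[state] = bellman_target(reward, next_q, done, gamma)
    return q


print(run_chain_example())
```

The values start at zero and are pulled into place backwards from the terminal reward: state 2 settles at 1, state 1 at gamma, and state 0 at gamma squared. Without the terminal step there would be nothing fixing the overall scale.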
Q-values have a tendency to run away from you (be vastly overestimated) when learned this way, because errors can easily propagate through our updates. We are, after all, using Q to produce the training target for itself, and since the policy is trained to seek out high Q-values, it tends to exploit exactly the errors that are too high. There are two interventions we can take to make this more stable:
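Firstly, we can learn multiple estimates of Q instead of one (multiple Q-networks) and take the minimum of these when computing the target. Generally we just use two Q-networks. Our update for both networks now uses the minimum target:
$$Q_i(s_t, a_t) \leftarrow r_t + \gamma \min_{j=1,2} Q_j(s_{t+1}, a_{t+1}), \qquad a_{t+1} \sim \pi(\cdot \mid s_{t+1}), \quad i \in \{1, 2\}$$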
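Why does the minimum help? A quick simulation, under simplifying assumptions: the true Q-value of every action is 0, each critic is the truth plus independent Gaussian noise, and the "policy" simply picks the action its critic rates highest (standing in for a policy trained to maximize Q). The number of actions, noise level, and trial count are arbitrary choices for illustration.

```python
import random


def overestimation_demo(n_actions=10, trials=20000, noise=1.0, seed=0):
    """True Q is 0 for every action; both critics are hypothetical noisy estimates of it."""
    rng = random.Random(seed)
    single, clipped = 0.0, 0.0
    for _ in range(trials):
        q1 = [rng.gauss(0.0, noise) for _ in range(n_actions)]
        q2 = [rng.gauss(0.0, noise) for _ in range(n_actions)]
        # One critic: pick the action it likes best and trust its own estimate
        single += max(q1)
        # Two critics: pick and evaluate with the elementwise minimum
        clipped += max(min(a, b) for a, b in zip(q1, q2))
    return single / trials, clipped / trials


single_bias, clipped_bias = overestimation_demo()
print(f"single critic: {single_bias:.2f}, min of two: {clipped_bias:.2f}")
```

With these settings the single critic's estimate at its chosen action averages roughly 1.5 (the expected maximum of 10 standard normals), while the min-of-two estimate is noticeably smaller. The minimum can also underestimate, which is generally considered the less harmful error because the policy does not actively seek out underestimated actions.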
Secondly, we use delayed Q-networks: copies of the Q-networks whose parameters are rolling averages of the actively updated ones. In total, we need to maintain four Q-networks: two current networks, which are noisy, and two delayed networks ($\bar{Q}_1$, $\bar{Q}_2$), which are rolling averages of the first two. Our update is the same as above, except that the delayed networks produce the target:
$$Q_i(s_t, a_t) \leftarrow r_t + \gamma \min_{j=1,2} \bar{Q}_j(s_{t+1}, a_{t+1}), \qquad a_{t+1} \sim \pi(\cdot \mid s_{t+1}), \quad i \in \{1, 2\}$$
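And we roll our delayed networks via Polyak averaging over their parameters ($\theta_j$ for the current networks, $\bar{\theta}_j$ for the delayed ones):
$$\bar{\theta}_j \leftarrow k\, \bar{\theta}_j + (1 - k)\, \theta_j, \qquad j \in \{1, 2\}$$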
where k is usually something like 0.99 or 0.999.
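Polyak averaging is just an exponential moving average of parameters. A minimal sketch on plain Python lists standing in for network weights (the list representation and k = 0.99 are illustrative assumptions):

```python
import copy


def polyak_update(target_params, current_params, k=0.99):
    """In-place Polyak averaging: target <- k * target + (1 - k) * current."""
    for i, (t, c) in enumerate(zip(target_params, current_params)):
        target_params[i] = k * t + (1.0 - k) * c


current = [1.0, -2.0]            # stand-in for a current Q-network's parameters
target = copy.deepcopy(current)  # the delayed copy starts identical to the current network
current[:] = [3.0, 0.0]          # pretend gradient steps moved the current network
for _ in range(100):
    polyak_update(target, current, k=0.99)
# After n updates toward a fixed current value, the remaining gap has shrunk by k**n
print(target)
```

With PyTorch modules, the same update is commonly written in place under `torch.no_grad()`, for example `p_targ.data.mul_(k)` followed by `p_targ.data.add_((1 - k) * p.data)` for each pair of delayed and current parameters.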
Nominal Policy Update
The policy update is a bit simpler, but can be strange if you have not seen a mechanism like it before. We sample some states from our buffer, predict actions with our policy, and pass the resulting (s,a) pairs into the (frozen) Q-networks. We then treat the Q-value as an objective to maximize (equivalently, we minimize its negative). Backprop traces through the Q-network, through the action, and into the policy network, updating the policy so that it outputs actions with higher Q at each state. For gradients to pass through the action, the sampled action must be a differentiable function of the policy's outputs; this is the reparameterization trick (the rsample call in the code below).
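Just as in the Q-network update, we use the minimum of two networks, here the two current ones, giving a loss to minimize:
$$\mathcal{L}_\pi = -\min_{j=1,2} Q_j(s_t, \tilde{a}_t), \qquad \tilde{a}_t \sim \pi(\cdot \mid s_t)$$
where $\tilde{a}_t$ is a fresh action from the current policy, not the action stored in the buffer.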
Note that we do not use the delayed networks here. We have already used them to stabilize the learning process via the Q update.
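The gradient path from Q back into the policy can be traced by hand in a toy setting. In this sketch, the two frozen critics are hypothetical quadratics Q_j(s, a) = -k_j (a - 0.5 s)^2 that both peak at a = 0.5 s, and the "policy" is a deterministic linear map a = w * s with a single parameter. Everything here is an illustrative assumption; a real SAC policy is stochastic, and autograd computes these derivatives for you.

```python
def q_and_grad(k, s, a):
    """Hypothetical frozen critic Q(s, a) = -k * (a - 0.5 * s)**2 and its derivative dQ/da."""
    diff = a - 0.5 * s
    return -k * diff ** 2, -2.0 * k * diff


def policy_update_step(w, states, lr=0.05, critic_scales=(1.0, 2.0)):
    # Toy policy a = w * s; only w changes, the critics stay frozen
    grad_w = 0.0
    for s in states:
        a = w * s
        # Evaluate both critics and keep the one with the smaller (pessimistic) value
        _, dq_da = min((q_and_grad(k, s, a) for k in critic_scales), key=lambda qg: qg[0])
        grad_w += dq_da * s  # chain rule: dQ/dw = dQ/da * da/dw
    # Gradient ascent on mean min_j Q_j, i.e. descent on the loss -min_j Q_j
    return w + lr * grad_w / len(states)


w = 0.0
states = [0.5, 1.0, 1.5, 2.0]
for _ in range(200):
    w = policy_update_step(w, states)
print(round(w, 4))  # the critics peak at a = 0.5 * s, so w converges to 0.5
```

The critic parameters (here the constants k_j and the 0.5 peak) never change; they only shape the gradient that flows into w.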
Making Things "Soft": Adding Entropy
Taken together, the Q-network updates and policy updates give us an actor-critic algorithm. However, we can modify this further by encouraging our policy to learn a wide distribution where possible (i.e. increased entropy), which aids exploration and helps uncover useful behaviors during learning. Concretely, we model our policy as a normal distribution over actions and encourage, via additional loss terms, the standard deviation to be higher where possible. If we want to maximize performance during deployment, we can simply use the mean of this distribution directly.
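The entropy of a distribution is defined as the expected negative log-probability:
$$\mathcal{H}\big(\pi(\cdot \mid s)\big) = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\big[-\log \pi(a \mid s)\big]$$
and since averaging over minibatches of states and sampled actions takes care of the expectation for us, we can maximize entropy by adding a single-sample term to the loss we minimize:
$$\log \pi(\tilde{a} \mid s), \qquad \tilde{a} \sim \pi(\cdot \mid s)$$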
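As a sanity check on the idea that averaging takes care of the expectation, the sketch below estimates the entropy of a one-dimensional Gaussian by averaging -log pi(a) over sampled actions and compares it with the closed form 0.5 log(2 pi e sigma^2). The sample count and seed are arbitrary choices.

```python
import math
import random


def gaussian_log_prob(x, mu, std):
    return -0.5 * ((x - mu) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def mc_entropy(mu, std, n=100000, seed=0):
    # Entropy = E[-log pi(a)], estimated by averaging over samples, as a minibatch would
    rng = random.Random(seed)
    return sum(-gaussian_log_prob(rng.gauss(mu, std), mu, std) for _ in range(n)) / n


def gaussian_entropy(std):
    # Closed form for a 1-D Gaussian: 0.5 * log(2 * pi * e * std^2)
    return 0.5 * math.log(2 * math.pi * math.e * std ** 2)


for std in (0.5, 1.0, 2.0):
    print(std, round(mc_entropy(0.0, std), 3), round(gaussian_entropy(std), 3))
```

The estimate tracks the closed form, and a wider distribution has higher entropy, which is exactly what minimizing the added log pi term rewards.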
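However, we do not want to simply add this entropy bonus to the loss on its own. We want a trade-off that balances the entropy of the policy against the long-term discounted returns it achieves. To achieve this, we can treat the policy's entropy at each state we visit as something we collect while interacting with the environment, just like reward:
$$\pi^* = \arg\max_\pi \mathbb{E}_{\tau \sim \pi}\Big[\sum_{t=0}^{\infty} \gamma^t \Big(r_t + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big)\Big)\Big]$$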
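When we define Q, we want to capture not only the discounted future rewards but also the discounted entropy of the future, essentially treating the future entropy as an additional reward term. The entropy sum starts at t = 1 because the first action is given rather than sampled from the policy:
$$Q(s, a) = \mathbb{E}_{\tau \sim \pi}\Big[\sum_{t=0}^{\infty} \gamma^t r_t + \alpha \sum_{t=1}^{\infty} \gamma^t\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \,\Big|\, s_0 = s,\ a_0 = a\Big]$$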
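Writing this definition recursively for an observed transition, two terms end up multiplied by gamma, so we can group them and then substitute the definition of entropy to get our full equation:
$$Q(s_t, a_t) = r_t + \gamma\, \mathbb{E}_{a_{t+1}}\big[Q(s_{t+1}, a_{t+1})\big] + \gamma\, \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_{t+1})\big) = r_t + \gamma \Big(\mathbb{E}_{a_{t+1}}\big[Q(s_{t+1}, a_{t+1})\big] + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_{t+1})\big)\Big) = r_t + \gamma\, \mathbb{E}_{a_{t+1} \sim \pi(\cdot \mid s_{t+1})}\Big[Q(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1} \mid s_{t+1})\Big]$$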
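where alpha is a coefficient that balances the effect of the entropy term. This equation is used to update Q in place of the pure Bellman update discussed above. More fully, with the minimum over the two delayed networks and a single sampled next action, we have:
$$Q_i(s_t, a_t) \leftarrow r_t + \gamma\Big(\min_{j=1,2} \bar{Q}_j(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1} \mid s_{t+1})\Big), \qquad a_{t+1} \sim \pi(\cdot \mid s_{t+1}), \quad i \in \{1, 2\}$$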
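For a single transition, the full soft target can be sketched as below. The function names, argument layout, and the defaults gamma = 0.99 and alpha = 0.2 are assumptions for illustration (common choices, not necessarily those of the implementation linked above). The done flag implements the rule that the bootstrapped term is dropped on terminal steps.

```python
def soft_q_target(reward, done, next_q1_targ, next_q2_targ, next_logp, gamma=0.99, alpha=0.2):
    """One-sample soft Bellman target for a single transition.

    next_q1_targ, next_q2_targ: delayed critics at (s_{t+1}, a_{t+1}), a_{t+1} ~ pi(.|s_{t+1})
    next_logp: log pi(a_{t+1} | s_{t+1}) for that same sampled action
    """
    soft_next_value = min(next_q1_targ, next_q2_targ) - alpha * next_logp
    return reward + gamma * (1.0 - float(done)) * soft_next_value


def critic_loss(q1, q2, target):
    # Both current critics regress toward the same target, which is treated as a constant
    return (q1 - target) ** 2 + (q2 - target) ** 2


y = soft_q_target(reward=1.0, done=False, next_q1_targ=5.0, next_q2_targ=4.0, next_logp=-1.0)
print(y, critic_loss(4.8, 5.3, y))
```

Note that a negative log-probability (a fairly spread-out policy at the next state) raises the target, which is how future entropy is paid out as a bonus. In a real update the target is computed without gradient tracking (e.g. under `torch.no_grad()` in PyTorch) and the loss is averaged over the minibatch.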
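Additionally, when we update the policy, we want to maximize this entropy-augmented quantity rather than purely maximize Q:
$$Q(s_t, \tilde{a}_t) - \alpha \log \pi(\tilde{a}_t \mid s_t), \qquad \tilde{a}_t \sim \pi(\cdot \mid s_t)$$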
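More fully, using the minimum of the two current networks and writing it as a loss to minimize, we have:
$$\mathcal{L}_\pi = \alpha \log \pi(\tilde{a}_t \mid s_t) - \min_{j=1,2} Q_j(s_t, \tilde{a}_t), \qquad \tilde{a}_t \sim \pi(\cdot \mid s_t)$$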
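Two small sketches of the policy side. The first writes the minibatch loss out in plain Python. The second is a closed-form toy showing the trade-off alpha controls: with a hypothetical quadratic critic Q(a) = -(a - c)^2 and an unsquashed Gaussian policy N(mu, sigma^2), the expected objective is -((mu - c)^2 + sigma^2) + alpha * 0.5 log(2 pi e sigma^2), which is maximized at mu = c and sigma = sqrt(alpha / 2). The critic, the lack of tanh squashing, and the grid search are all assumptions for illustration.

```python
import math


def soft_policy_loss(logp, q1, q2, alpha=0.2):
    # Mean over the minibatch of alpha * log pi(a~|s) - min_j Q_j(s, a~)
    return sum(alpha * lp - min(a, b) for lp, a, b in zip(logp, q1, q2)) / len(logp)


def expected_soft_objective(mu, sigma, c, alpha):
    # E[Q(a)] for Q(a) = -(a - c)^2 and a ~ N(mu, sigma^2), plus alpha * Gaussian entropy
    expected_q = -((mu - c) ** 2 + sigma ** 2)
    entropy = 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)
    return expected_q + alpha * entropy


def best_sigma(alpha, c=1.0):
    # Grid search over sigma, with the mean fixed at the critic's peak (its optimal value)
    grid = [i / 10000 for i in range(1, 30001)]
    return max(grid, key=lambda s: expected_soft_objective(c, s, c, alpha))


for alpha in (0.02, 0.5, 2.0):
    print(alpha, best_sigma(alpha), math.sqrt(alpha / 2))
```

A larger alpha buys a wider policy at the cost of expected Q; as alpha goes to zero, the optimal policy collapses onto the critic's peak.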
The Full Algorithm
To recap, here is the full algorithm of SAC:
SAC: Repeat until converged:
Collect experience with the current policy and save it into a large rolling buffer.
Every few steps (~50), update the Q-networks. Sample transitions from the buffer and:
Update the two current networks using the modified update above, taking the minimum of the two delayed predictions and including entropy maximization.
Use Polyak averaging to update the delayed Q-networks.
Every few steps (~50), update the Policy. Sample transitions from the buffer and:
Freeze the Q-networks
Predict the actions from the states, and use the above modified policy loss to backprop through the Q-network and into the policy.
Unfreeze the Q-networks
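The recap above maps onto a short loop. This skeleton is a sketch under stated assumptions: the environment, policy, and update functions are passed in as hypothetical callables, the `env_step` signature returning `(next_state, reward, done)` is invented for illustration, and doing `update_every` gradient steps every `update_every` environment steps is one common schedule rather than the only one.

```python
import random
from collections import deque


def train_sac(env_reset, env_step, act, update_critics, update_policy, update_targets,
              total_steps=1000, update_every=50, batch_size=32, buffer_size=1_000_000, seed=0):
    """Skeleton of the SAC loop; every callable argument is a hypothetical stand-in."""
    rng = random.Random(seed)
    buffer = deque(maxlen=buffer_size)  # rolling buffer: the oldest transitions fall out
    state = env_reset()
    n_updates = 0
    for step in range(1, total_steps + 1):
        action = act(state)  # sample from the current stochastic policy
        next_state, reward, done = env_step(state, action)
        buffer.append((state, action, reward, next_state, done))
        state = env_reset() if done else next_state
        if step % update_every == 0 and len(buffer) >= batch_size:
            # One common choice: as many gradient steps as environment steps since the last round
            for _ in range(update_every):
                idx = rng.sample(range(len(buffer)), batch_size)
                batch = [buffer[i] for i in idx]
                update_critics(batch)  # soft Bellman targets from the delayed critics
                update_policy(batch)   # critics frozen; only the policy parameters change
                update_targets()       # Polyak-average the delayed critics
                n_updates += 1
    return n_updates
```

A deque with random indexing is fine for a sketch, but indexing into the middle of a deque takes linear time; larger implementations typically use preallocated arrays.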
Examining the Policy Network
The policy network in SAC is a bit more complex than in PPO. For one, since we want to allow the policy to maximize entropy when it can, the standard deviation of the distribution has to be conditioned on the state (not a constant property of the policy).
Additionally, we squash the action outputs into a bounded range (-1 to 1) with tanh, and compute the log-probabilities of the squashed actions accordingly. This is not so critical for PPO, where we are examining specific actions that were taken and trying to make them more or less likely based on outcomes. If multiple sampled actions all end up having the same outcome because they are clipped to a valid range, that doesn't really change our update.
However, SAC needs to account for this because it is looking to maximize the entropy of the whole action distribution, not only the actions that were taken when collecting data. Therefore, if actions are clipped, this directly changes the entropy of the distribution over actions we can take.
Adding these changes to the PPO network, we get the following:
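```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

# Assumed clamp range for the predicted log std (common choices in SAC implementations)
LOG_STD_MIN = -20
LOG_STD_MAX = 2


class SquashedGaussianActor(nn.Module):
    def __init__(self, obs_dim, act_dim, policy_net_fn, act_limit):
        super().__init__()
        # policy_net_fn builds a module mapping obs -> (mu, log_std), each of size act_dim
        self.net = policy_net_fn(obs_dim, act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        mu, log_std = self.net(obs)
        # The standard deviation is state-dependent, so the policy can control its entropy
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            # rsample() is reparameterized, so gradients flow back into mu and std
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # NOTE: The correction formula is a little bit magic. To get an understanding
            # of where it comes from, check out the original SAC paper (arXiv 1801.01290)
            # and look in appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
            # Sum over the last (action) dimension so unbatched inputs also work
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(dim=-1)
        else:
            logp_pi = None

        # Squash into (-1, 1), then scale to the environment's action range.
        # (logp_pi is the log-density of the tanh output, before scaling by act_limit.)
        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action
        return pi_action, logp_pi
```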
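The "magic" correction in the code is the change-of-variables term for a = tanh(u): log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2). Since 1 - tanh(u)^2 = 1/cosh(u)^2, one can rewrite log(1 - tanh(u)^2) = -2 log cosh(u) = 2(log 2 - u - softplus(-2u)). The pure-Python check below (helper names are illustrative, not part of the implementation above) confirms that the two forms agree for moderate u and shows why the rewritten one is preferred.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_tanh_jacobian_naive(u):
    # log |d tanh(u) / du| = log(1 - tanh(u)^2); fails once tanh(u) rounds to +/-1
    return math.log(1.0 - math.tanh(u) ** 2)


def log_tanh_jacobian_stable(u):
    # Same quantity via log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
    return 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))


for u in (-3.0, -0.5, 0.0, 0.5, 3.0):
    print(u, log_tanh_jacobian_naive(u), log_tanh_jacobian_stable(u))

# For large |u|, tanh(u) is exactly 1.0 in float64, so the naive form would take log(0)
print(log_tanh_jacobian_stable(20.0))
```

Pre-squash samples with large magnitude are common once the policy becomes confident, so the stable form avoids infinite log-probabilities exactly where the policy is most deterministic.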