Proximal Policy Optimization (PPO)
Algorithm Review and Notes
Note: The PPO implementation I use for my work is available here.
The Big Idea
This post assumes you are somewhat familiar with RL already. To be brief- Proximal Policy Optimization (PPO) is a popular technique in deep reinforcement learning. We have some environment that we can interact with: it provides us a current "state", and we supply an "action" to take from that state. It then responds with the consequences of that action: the resulting new state and a reward, a numerical quantity indicating our performance in the environment. Our goal is to learn a neural network that takes the current state as input and outputs an action that maximizes our total collected reward over one "episode" (one play) of the environment.
In fact, PPO learns two neural networks simultaneously:
The policy network, which takes states as input and outputs a distribution over actions. This is what we really care about.
The value network, which measures the expected returns from a given state if we were to follow our policy. This is used to estimate "how good" a current state is for us to be in, to put it very informally.
PPO is online, meaning that we actively interact with the environment to collect data about it. PPO is also on-policy, meaning that we update our model(s) using the latest data collected by the policy. Our models are not necessarily useful beyond understanding the current policy.
Because PPO is online and on-policy, we need to constantly interact with the environment during training to understand the behavior of the policy. The high-level algorithm is as follows:
PPO: Repeat until converged:
Use the policy to collect data from the environment.
Use this data to update the policy:
- The policy should predict actions proportional to how beneficial they are.
Use this data to update the value network:
- The value function should predict the expected return from each state (the experienced return).
Throw away the data: it no longer reflects the current policy's behavior.
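To make the loop concrete, here is a minimal sketch of the outer loop in Python (the function names are placeholders for illustration, not the routines from the linked implementation):

def train(env, policy, value_fn, collect_rollout, update_policy, update_value, num_iterations):
    # PPO outer loop: collect, update the policy, update the value network, discard the data.
    for _ in range(num_iterations):
        data = collect_rollout(env, policy, value_fn)   # states, actions, rewards, log-probs
        update_policy(policy, value_fn, data)           # weight actions by how beneficial they were
        update_value(value_fn, data)                    # regress toward the experienced returns
        # `data` is dropped here: it no longer reflects the current policy's behavior.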
What is really nice about this setup (or wasteful if you care about data efficiency) is that it is massively parallelizable if we want it to be. At no point do we store data beyond a single iteration, so we can kick off many copies of this and churn through tons of data if we have the required compute:
PPO Parallel: Repeat until converged:
Use the policy to collect data from the environment across N copies of the algorithm, each with their own environment instance.
Use this data to update the policy:
Compute a local update for the policy based on locally collected data
Synchronize the update among all copies (average across copies)
Apply update
Use this data to update the value network:
Compute a local update for the value network based on locally collected data
Synchronize the update among all copies (average across copies)
Apply update
Throw away the data: it no longer reflects the current policy's behavior.
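For the synchronization steps, a rough sketch of gradient averaging with mpi4py might look like the following (illustrative only; the repo's own sync_weights() and average_gradients() helpers are mentioned further down):

import numpy as np
import torch
from mpi4py import MPI

def average_gradients(module):
    # Average each parameter's gradient across all MPI processes, in place.
    comm = MPI.COMM_WORLD
    if comm.Get_size() == 1:
        return
    for p in module.parameters():
        if p.grad is None:
            continue
        local = np.ascontiguousarray(p.grad.detach().cpu().numpy())
        summed = np.zeros_like(local)
        comm.Allreduce(local, summed, op=MPI.SUM)   # sum the local gradients from every copy
        p.grad.copy_(torch.as_tensor(summed / comm.Get_size()))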
Hopefully the data collection piece needs no further explanation, but the update steps definitely do.
The Value Loss Function
Our value network should predict the expected discounted return of each state, and from our data collection we have direct experience of the returns we collect after a given state, so we have everything we need. For every state along a collected trajectory, we calculate the target:

$$\hat{R}_0 = r_0 + \gamma r_1 + \gamma^2 r_2 + \dots + \gamma^{T-1} r_{T-1} + \gamma^T V(s_T)$$

where the current state is "state zero" and the trajectory ends at state $T$. The last term is a bit of a subtlety- if we stopped data collection short of a terminal condition, we use $V(s_T)$ to estimate future returns beyond our experience (aka "bootstrapping"). If we stopped because the episode actually ended, we use $0$ to indicate no future returns.
Our loss function is then just the MSE against this target.
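A small sketch of the target and the loss (variable names are illustrative, not the repo's buffer code):

import torch

def return_targets(rewards, last_value, gamma=0.99):
    # Discounted return target for every state along one trajectory.
    # `last_value` is V(s_T) if the rollout was cut off early (bootstrapping),
    # or 0.0 if the episode actually terminated.
    targets = []
    running = last_value
    for r in reversed(rewards):
        running = r + gamma * running
        targets.append(running)
    targets.reverse()
    return torch.tensor(targets, dtype=torch.float32)

def value_loss(value_net, states, targets):
    # MSE between predicted values and the experienced (bootstrapped) returns.
    return ((value_net(states).squeeze(-1) - targets) ** 2).mean()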
The Policy Loss Function
The policy update is the primary contribution of PPO. I will just try to build some intuition around it- for a full derivation starting with the policy gradient, see Lilian Weng's explanation.
We wish to update our policy according to the objective:

$$\max_\theta \; \mathbb{E}\big[\, \pi_\theta(a \mid s)\, A(s, a) \,\big]$$

where our policy (the probability of a given action) is denoted by $\pi_\theta(a \mid s)$, and $A(s, a)$ is the advantage of each transition- a measurement of how much better or worse the outcome of the action was than we would expect based on the state alone. If we maximize this quantity, we should have high probability associated with high advantage (and conversely, low probability with low advantage). See the end section for a breakdown of this quantity.
For now, note that we do not need the calculation of advantage to be part of our computation graph- it is simply a weighting term that we calculate separately and use to weight our policy outputs. This gives us something like the following:
Update Policy:
Compute the advantage for each (s,a) transition using the value function (which has not been updated yet).
For a minibatch of (s,a) tuples:
For each (s,a) transition, determine how probable a was under the policy.
Update the policy to maximize prob(a)*A
Note that in an automatic differentiation framework we can literally just set the loss to the negative of this quantity in order to maximize it. When I was new to RL this was hard to get my head around, since I was used to seeing loss functions as nice packaged things from a library, but the loss can really just be any quantity we want to maximize or minimize.
Practically, we just take a minibatch $B$ of experience and average over it, giving us:

$$\frac{1}{|B|} \sum_{(s, a) \in B} \pi_\theta(a \mid s)\, A(s, a)$$
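In PyTorch this might look something like the sketch below, assuming the policy returns a torch.distributions object for a batch of states (as the policy head at the end of this post does):

import torch

def simple_policy_loss(pi, actions, advantages):
    # Probability of each action we actually took; for a multi-dimensional Normal,
    # sum log_prob over the last axis first (as in the policy head shown later).
    prob_a = pi.log_prob(actions).exp()
    # Negate the minibatch-averaged objective so that minimizing the loss maximizes it.
    return -(prob_a * advantages).mean()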
Improving this loss function:
There are two changes we make to this to arrive at the PPO loss function. Firstly, the above technique is only valid for a single network update. As soon as we change the policy, we have a bit of an issue: the data was generated under a different policy than we are now working with. This is a problem because our objective is an expectation: a weighted average over our data proportional to how probable the data is to occur. When we collect our data, we also experience our data according to how probable it is, so everything is fine for the first update. If we take a simple average over the data, it empirically reflects the underlying probability and therefore our expectation is valid.
This is a subtle but critical point to understanding PPO. After we change the policy network, we need to correct for the fact that the data itself would have different probabilities under the updated policy. To make this change, we can use importance sampling: we simply divide by the probability of the "old" policy (the one that collected the data). This allows us to estimate the expectation under the current policy even though the data was sampled under the old one:

$$\mathbb{E}_{(s, a) \sim \pi_{\theta_{\text{old}}}}\left[ \frac{\pi_\theta(a \mid s)}{\pi_{\theta_{\text{old}}}(a \mid s)}\, A(s, a) \right]$$
To implement this, we just need to record the probability of the data when it was collected vs when we are doing our update:
Update policy (with collected data, including the original probabilities of the action at the time they were collected):
Compute the advantage for each (s,a) transition using the value function (which has not been updated yet).
For a minibatch of (s,a) tuples:
For each (s,a) transition, determine how probable a was under the policy.
Update the policy to maximize (prob(a)/original_prob(a))*A
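Since we usually store log-probabilities rather than raw probabilities, the ratio is computed in log space. A sketch (illustrative names):

import torch

def importance_weighted_loss(logp, logp_old, advantages):
    # logp: log-probability of each stored action under the *current* policy.
    # logp_old: log-probability recorded when the data was collected.
    ratio = torch.exp(logp - logp_old)   # prob(a) / original_prob(a)
    return -(ratio * advantages).mean()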
The second addition we make to the loss function is to prevent it from making egregiously large updates to our policy. We want to make sure that our new policy does not stray too far from the previous one which generated our data in the first place (i.e. we want the current policy to stay in the proximity of the old one, hence the name PPO).
We actually already have a measurement of how "far apart" the two policies are: we just took a ratio of their probabilities in the previous step. If this ratio was, say, 1.2, it would indicate that under the new policy an action is 20% more probable than it was originally. If the ratio is 0.9, it indicates said action is 10% less probable.
So, to prevent our updates from becoming out-of-whack, we simply clip this correction factor to be within some small boundary around 1, usually [0.8, 1.2] or [0.9, 1.1]. If it is outside of this range, we simply clip it and thus represent the datapoint as being closer to the original policy than it really is:

$$\mathbb{E}\big[\, \text{clip}(r(\theta),\, 1 - \epsilon,\, 1 + \epsilon)\, A(s, a) \,\big]$$

where:

$$r(\theta) = \frac{\pi_\theta(a \mid s)}{\pi_{\theta_{\text{old}}}(a \mid s)}$$
Note that this does not actually prevent our new policy from straying too far. It simply prevents any one datapoint from dominating our update. If the ratio of the two policies is something extreme for a single datapoint in our minibatch (say 100x more probable under the new policy), we clip 100 down to 1+epsilon so that it doesn't destroy our update.
To more fully prevent the divergence of the policies, we can also monitor this ratio during the update steps and stop updating if it gets too extreme. Putting these together, we can often update our policy many times from a single data collection iteration.
Finally, we take the minimum between this clipped term and the un-clipped objective. This prevents our policy from assigning huge probabilities to high-performing actions (or extremely low probabilities to low-performing actions) during a single iteration. This gives us the actual, final, PPO objective:

$$\mathbb{E}\left[\, \min\!\big( r(\theta)\, A(s, a),\ \text{clip}(r(\theta),\, 1 - \epsilon,\, 1 + \epsilon)\, A(s, a) \big) \,\right]$$
One easy thing to forget when reasoning about this update is that advantages can be negative (In fact, advantages from a single data collection round are often zero-centered for stability). This means that a term in the above equation which is very large does not necessarily reflect high probability. It could also reflect low probability and negative advantage. Or, since they are multiplied, it could represent a very average probability against a very high advantage. Both of these terms are working in tandem.
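Putting the ratio, the clip, and the min together, an update step looks roughly like this (a sketch in the spirit of the Spinning Up loss, with an approximate KL estimate for the early stopping mentioned above):

import torch

def ppo_clip_loss(logp, logp_old, advantages, clip_ratio=0.2):
    ratio = torch.exp(logp - logp_old)
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Elementwise minimum of the clipped and un-clipped objectives, negated for minimization.
    loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    # Rough measure of how far the policy has drifted; stop updating this batch if it grows too large.
    approx_kl = (logp_old - logp).mean().item()
    return loss, approx_kl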
My "Implementation"
I put implementation in quotes here because in this case I cannot say I wrote most of it; the implementation I use currently is derived from Spinning Up, a clean and readable implementation of RL algorithms from OpenAI. I did some restructuring to allow for more customization, namely:
Separating the main algorithm flow into its own routine, so that it is easier to see the big picture. A separate worker class implements routines for data collection, policy updates, and value updates.
Allowing the policy and value function to wrap arbitrary torch modules, so that you can put something fancy in there if you want to.
Some simplified MPI routines like sync_weights() and average_gradients().
I also made some changes:
Removed support for a combined policy-value network, since I'm not sure that is widely used anymore. It feels a lot cleaner to keep them separate.
Some estimates of FPS and time remaining. PPO can take a lonnnnnng time to work, so it's nice to know how long you need to wait for it (or if you should just go to bed).
Added support for an entropy coefficient to encourage higher-entropy policies, which may help with convergence and stability.
Finally, I added some hooks to support algorithms like GAIL. More on that in another post.
Additional Notes
Advantage
The Advantage measures the difference in expected discounted return between taking some action from a state and simply being in that state (without taking an action). The expected return for a state on its own is exactly what is provided by the value function, $V(s)$.
In other areas of RL, we may also predict the returns after an action is taken, represented as $Q(s, a)$, but in PPO we can instead look directly at the data we collect: after taking some action from a state, were the experienced returns greater or lesser than what we would have predicted with our value network?
Note that since PPO is on-policy, all of our measurements reflect operating under the current policy. $V(s)$ represents the expected return from the state if the current policy is followed. If our value function were perfect and we fully understood the current policy, we would have zero advantage (on average): the difference between our expected returns and the true returns would be zero (on average), as we would perfectly predict them. So here, the advantage is measuring how much better or worse we fared than what we would have predicted based on our current understanding of the policy.
So, we need to actually quantify this, and we have a lot of options. At one extreme, we could measure how well we fared over just a single step- our reward plus the expected future value of the resulting state, compared to our original understanding of the value of the original state:

$$A_0 = r_0 + \gamma V(s_1) - V(s_0)$$
At the other extreme, we could look at the entire trajectory that we experienced. How did we fare over many steps, compared to what was expected?:

$$A_0 = r_0 + \gamma r_1 + \gamma^2 r_2 + \dots + \gamma^{T-1} r_{T-1} + \gamma^T V(s_T) - V(s_0)$$
In between these two extremes (the 1-step and T-step versions), we could choose any other horizon and it would also be valid.
As all of these provide different estimates over our understanding of our data, we can also average over all of them, which is called the TD-lambda return. All of the possible N-step returns are weighted by $(1-\lambda)\lambda^{N-1}$, such that all weights sum to 1.0. Typically a value around 0.95 is used for lambda.
There are several tricks to efficiently compute all of these returns for all steps of a trajectory, so the code for computing these returns may initially be foreign looking. In this case, the advantages are computed within the buffer when the trajectory ends. Look there for more details. Also note, the advantages are normalized over all data collected, since we really only care how our actions fared relative to each other.
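For reference, a sketch of what that end-of-trajectory computation typically looks like (a GAE-style calculation; illustrative rather than the repo's exact buffer code):

import numpy as np

def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    # rewards: r_0..r_{T-1}; values: V(s_0)..V(s_{T-1}) from the not-yet-updated value network.
    # last_value: V(s_T) for a cut-off rollout, or 0.0 for a true terminal state.
    values = np.append(np.asarray(values), last_value)
    deltas = np.asarray(rewards) + gamma * values[1:] - values[:-1]   # one-step TD errors
    advantages = np.zeros_like(deltas)
    running = 0.0
    # Discounted cumulative sum of the TD errors, computed backwards along the trajectory.
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages   # normalization happens later, over all collected data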
Parallelization
The parallelization in PPO is different from other areas of software, in which parallelization is used to get through a set amount of data more quickly. If we have 1000 things to process and we can split this into 10 independent jobs, we can do 100 things on each job and thus be done 10 times faster.
This is not why we parallelize in PPO. We parallelize here so that our updates reflect large quantities of data and are therefore more stable / better updates. If we do N updates on one thread, or N updates across 20 threads, either way we have updated the network N times (and both take a similar amount of time). The key difference is that the updates averaged across 20 threads will be better updates, since they give us a better understanding of the policy.
IMO, if you are using PPO, don't worry about data efficiency. That's not really what it is designed for- just run it on as many cpu cores as you can live without for as long as you are willing to wait (or until it clearly converges).
Policy Network
It is worth looking at the policy network briefly in code to point out some aspects of how it works:
import numpy as np
import torch
from torch.distributions.normal import Normal

# Actor is the repo's base policy class (not shown here).
class GaussianHead(Actor):
def __init__(self, stem, act_dim):
super().__init__()
log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
self.mu_net = stem
def _distribution(self, obs):
mu = self.mu_net(obs)
std = torch.exp(self.log_std)
return Normal(mu, std)
def _log_prob_from_distribution(self, pi, act):
return pi.log_prob(act).sum(axis=-1) # Last axis sum needed for Torch Normal distribution
def forward(self, obs, act=None):
# Produce action distributions for given observations, and
# optionally compute the log likelihood of given actions under
# those distributions.
pi = self._distribution(obs)
entropy = pi.entropy()
logp_a = None
if act is not None: logp_a = self._log_prob_from_distribution(pi, act)
return pi, logp_a, entropy
The forward function is interesting- it can be used in one of two ways. A state is always provided, and it computes a Gaussian distribution (in the continuous action case) over actions to output. If an action is also provided, we can now query how probable that action is under the policy from that state. This allows us to collect data with our policy, and later ask how probable it is (even after we have updated the policy a few times). If no action is provided, we can sample one from the returned distribution.
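Roughly how those two modes get used in practice (a hedged sketch with made-up variable names; actor is an instance of the head above):

import torch

# At data-collection time: sample an action and record its log-probability.
with torch.no_grad():
    pi, _, _ = actor(obs)   # no action given: just build the distribution
    act = pi.sample()
    logp_old = actor._log_prob_from_distribution(pi, act)

# At update time, possibly several gradient steps later: re-query the log-probability
# of that same stored action under the *current* policy and form the ratio.
_, logp, _ = actor(obs, act)
ratio = torch.exp(logp - logp_old)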
Also note that the standard deviation is a separate parameter, and is not conditioned on the state. It is more a property of the policy itself- how confident are we about each action dimension, in general? This differs from the typical approach in SAC, in which the standard deviation is predicted on a per-state basis.
Possible future topic for its own post: a deeper dive into TD-lambda and how to calculate it.