Overview

I recently stumbled upon the world of distributional Q-learning, and I hope to share some of the insights I've gained from reading the following papers:

1. *A Distributional Perspective on Reinforcement Learning* (Bellemare, Dabney, and Munos, 2017), which introduces C51
2. *Implicit Quantile Networks for Distributional Reinforcement Learning* (Dabney, Ostrovski, Silver, and Munos, 2018)

This article will loosely work through the two papers in order, as they build on each other, but hopefully I can trim off most of the extraneous information and present you with a nice overview of distributional RL, how it works, and how to improve upon the most basic distributional algorithms to get to the current state-of-the-art.

First I'll introduce distributional Q-learning and try to provide some motivations for using it. Then I'll highlight the strategies used in the development of C51, one of the first highly successful distributional Q-learning algorithms (paper #1). Then I'll introduce implicit quantile networks (IQNs) and explain their improvements to C51 (paper #2).

Quick disclaimer: I'm assuming you're familiar with how Q-learning works. That includes V and Q functions, Bellman backups, and the various learning stability tricks like target networks and replay buffers that are commonly used.

Another important note is that these algorithms are only for discrete action spaces.

Motivations for Distributional Deep Q-Learning

In standard Q-learning, we attempt to learn a function $Q(s, a): \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}$ that maps state-action pairs to the expected return from that state-action pair. This gives us a pretty accurate idea of how good specific actions are in specific states (if our $Q$ is accurate), but it's missing some information. Each state-action pair has an entire distribution of returns we could receive, and the expectations/means of these distributions are what $Q$ attempts to learn. But why only learn the expectation? Why not try to learn the whole distribution?

Before diving into the algorithms that have been developed for this specific purpose, it's helpful to think about why this is beneficial in the first place. After all, learning a distribution is a lot more complicated than learning a single number, and we don't want to waste precious computational resources on doing something that doesn't help much.

Stabilized Learning

The first possibility I'll throw out there is that learning distributions could stabilize learning. This may seem unintuitive at first, seeing as we're trying to learn something much more complicated than an ordinary $Q$ function. But let's think about what happens when stochasticity in our environment results in our agent receiving a highly unusual return. I'll use the example of driving a car through an intersection.

Let's say you're waiting at a red light that turns green. You begin to drive forward, expecting to simply cruise through the intersection and be on your way. Your internal model of your driving is probably saying "there's no way anything bad will happen if you go straight right now", and there's no reason to think otherwise. But now let's say another driver on the road perpendicular to yours runs straight through their red light and crashes into you. You would be right to be incredibly surprised by this turn of events (and hopefully not dead, either), but how surprised should you be?

If your internal driving model was based only on expected returns, then you wouldn't predict that this accident would occur at all. And since it just did happen, you may be tempted to drastically change your internal model and, as a result, be scared of intersections for quite a bit until you're convinced that they're safe again; however, what if your driving model was based on a distribution over all possible returns? If you mentally assigned a probability of 0.00001 to this accident occurring, and if you've driven through 100,000 intersections before throughout your lifetime, then this accident isn't really that surprising. It still totally sucks and your car is probably totaled, but you shouldn't be irrationally scared of intersections now. After all, you just proved that your model was right!

So yeah, that's kinda dark, but I think it highlights how learning a distribution instead of an expectation can reduce the effects of environment stochasticity.1

Risk Sensitive Policies

Using distributions over returns also allows us to create brand new classes of policies that take risk into account when deciding which actions to take. I'll use another example that doesn't involve driving but is just as deadly :) Let's say you need to cross a gorge in the shortest amount of time possible (I'm not sure why, but you do. This is a poorly formulated example). You have two options: you could use a sketchy bridge that looks like it may fall apart at any moment, or you could walk down a set of stairs on one side of the gorge and then up a set of stairs on the other side. The latter option is incredibly safe. It'll still take significantly longer than using the bridge, though, so is it worth it?

For the purposes of this example, let's give dying a reward of $-1000$ and give every non-deadly interaction with the environment a reward of $-1$. Let's also say that taking the bridge gets you across the gorge in $10$ seconds, with probability $0.5$ of making it across safely. Taking the stairs gets you across the gorge $100\%$ of the time, but it takes $100$ seconds instead.

Given this information, we can quickly calculate expected returns for each of the two actions

$$ \mathbb{E}[\text{return}_\text{bridge}] = (-1000 * 0.5) + (-10 * 0.5) = -505 \\ \mathbb{E}[\text{return}_\text{stairs}] = -100 $$

If you made decisions like a standard Q-learning agent, you would never take the bridge. The expected return is much worse than that of taking the stairs, so there's no reason to choose it. But if you made decisions like a distributional Q-learning agent, your decision could be much better informed. You would be aware of the trade-off between the probability of dying and the chance of getting across the gorge more quickly by using the bridge. If the risk of falling to your death is worth it in your particular situation (let's say you're being chased by a wild animal who can run much faster than you), then taking the bridge instead of the stairs could end up being what you want.

Although this example was pretty contrived, it highlights how using return distributions allows us to choose policies that before would have been impossible to formulate. Want a policy that takes as little risk as possible? We can do that now. Want a policy that takes as much risk as possible? Go right ahead, but please don't fall into any gorges.

The Distributional Q-Learning Framework

So now we have a few reasons why using distributions over returns instead of just expected return can be useful, but we need to formulate a few things first so that we can use Q-learning strategies in this new setting.

We'll define $Z(s, a)$ to be the distribution of returns at a given state-action pair, where $Q(s, a)$ is the expected value of $Z(s, a)$.

The usual Bellman equation for $Q$ is defined

$$ Q(s, a) = \mathbb{E}[r(s, a)] + \gamma \mathbb{E}[Q(s', a')] $$

Now we'll change this to be defined in terms of entire distributions instead of just expectations by using $Z$ instead of $Q$. We'll denote the distribution of rewards for a single state-action pair $R(s,a)$.

$$ Z(s, a) = R(s, a) + \gamma Z(s', a') $$

All we need now is a way of iteratively enforcing this Bellman constraint on our $Z$ function. With standard Q-learning, we can do that quite simply by minimizing the mean squared error between the outputs of a neural network (which approximates $Q$) and the target values $\mathbb{E}[r(s, a)] + \gamma \mathbb{E}[Q(s', a')]$, computed using a target Q-network and transitions sampled from a replay buffer.
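As a point of reference, here's a minimal sketch of that standard update in PyTorch; `q_net`, `target_net`, and the fields of `batch` are hypothetical stand-ins rather than anything defined in the papers.

```python
import torch
import torch.nn.functional as F

def q_learning_loss(q_net, target_net, batch, gamma=0.99):
    # batch is assumed to hold tensors: states, actions (int64), rewards, next_states, dones
    q_values = q_net(batch.states)                                 # (B, num_actions)
    q_sa = q_values.gather(1, batch.actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        next_q = target_net(batch.next_states).max(dim=1).values
        target = batch.rewards + gamma * (1 - batch.dones) * next_q

    return F.mse_loss(q_sa, target)
```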

Such a straightforward solution doesn't exist in the distributional case because the output from our Z-network is so much more complex than from a Q-network. First we have to decide what kind of distribution to output. Can we approximate return distributions with a simple Gaussian? A mixture of Gaussians? Is there a way to output a distribution of arbitrary complexity? Even if we can output really complex distributions, can we sample from that in a tractable way? And once we've decided on how we'll represent the output distribution, we'll then have to choose a new metric to optimize other than mean squared error since we're no longer working with just scalar outputs. Many ways of measuring the difference between probability distributions exist, but we'll have to choose one to use.

These two problems are what the C51 and IQN papers deal with. They both take different approaches to approximating arbitrarily complex return distributions, and they optimize them differently as well. Let's start off with C51: the algorithm itself is a bit complex, but its foundational ideas are rather simple. I won't dive into the math behind C51, and I'll instead save that for IQN since that's the better algorithm.

C51

The main idea behind C51 is to approximate the return distribution using a set of discrete bars which the paper authors call 'atoms'. This is like using a histogram to plot out a distribution. It's not the most accurate, but it gives us a good sense of what the distribution looks like in general. This strategy also leads to an optimization strategy that isn't too computationally expensive, which is what we want.

Our network can simply output $N$ probabilities that sum to $1$. Each probability is the height of one bar in our distribution approximation, with each bar sitting at a fixed return value. The paper recommends using $51$ atoms (hence the name C51) based on empirical tests, but the algorithm itself works for any choice of $N$.
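As a rough sketch (not the exact architecture from the paper), the output head could look something like this in PyTorch, with `feature_dim` and the surrounding network left as assumptions:

```python
import torch
import torch.nn as nn

class C51Head(nn.Module):
    """Minimal sketch of a C51-style output head: for each action,
    emit num_atoms probabilities that sum to 1."""

    def __init__(self, feature_dim, num_actions, num_atoms=51):
        super().__init__()
        self.num_actions = num_actions
        self.num_atoms = num_atoms
        self.fc = nn.Linear(feature_dim, num_actions * num_atoms)

    def forward(self, features):
        logits = self.fc(features).view(-1, self.num_actions, self.num_atoms)
        # softmax over the atom dimension -> one categorical distribution per action
        return torch.softmax(logits, dim=-1)
```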

To minimize the difference between our current distribution outputs and their target values, the paper recommends minimizing the KL divergence of the two distributions. They accomplish this indirectly by minimizing the cross entropy between the distributions instead.

The idea behind this is simple enough, but the math gets a bit funky. Since the distribution that our network outputs is split into discrete atoms, the theoretical Bellman update has to be projected back onto that discrete support, with the probability mass of each shifted atom split between its nearest neighbors on the support.

To actually use the discretized distribution to make action choices, the paper's authors just take the probability-weighted mean of the atoms. This weighted mean is effectively just an approximation of the standard Q-value.
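A minimal sketch of that action-selection step, assuming the `C51Head`-style probabilities from above and a fixed support of atom values (`atom_values` is a placeholder, e.g. evenly spaced return values):

```python
import torch

def c51_q_values(probs, atom_values):
    """probs: (B, num_actions, num_atoms) from the C51 head above.
    atom_values: (num_atoms,) fixed support of return values."""
    return (probs * atom_values).sum(dim=-1)            # (B, num_actions)

def greedy_action(probs, atom_values):
    # act greedily with respect to the probability-weighted mean of the atoms
    return c51_q_values(probs, atom_values).argmax(dim=-1)
```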

IQN

C51 works well, but it has some pretty obvious flaws. First off, its distribution approximations aren't going to be very precise. We can use a massive neural network during training, but all of that information gets funneled into just $N$ output atoms at the end of the day. This is the bottleneck on how accurate our network can get, and while we could loosen it by increasing the number of atoms, doing so also increases the amount of computation the algorithm requires.

A second issue with C51 is that it doesn't take full advantage of knowing return distributions. When deciding which actions to take, it just uses the mean of its approximate return distribution, which, at decision time, is really no different from standard Q-learning.

Implicit quantile networks address both of these issues: they allow us to approximate much more complex distributions without additional computation requirements, and they also allow us to easily decide how risky our agent will be when acting.

Implicit Networks

The first issue with C51 is addressed by not explicitly representing the return distribution with our neural network. If we represent the distribution explicitly, the chosen representation acts as a major bottleneck on how accurate our approximations can be. Additionally, sampling from arbitrarily complex distributions is intractable if we want to represent them explicitly. IQN's solution: don't train a network to explicitly represent a distribution; train a network to provide samples from the distribution instead.

Since we aren't explicitly representing any distributions, that means our accuracy bottleneck rests entirely in the size of our neural network. This means we can easily make our distribution approximations more accurate without adding on much to the amount of required computation.

Additionally, since our network is being trained to provide us samples from some unknown distribution, the intractable sampling problem goes away.

The second issue with C51 (not using risk-sensitive policies) is also addressed by using implicit networks. We haven't gone over how we'll actually implement such networks, but trust me when I say that we'll be able to easily manipulate the input to them to induce risky or risk-averse action decisions.

Quantile Functions

Before we go through the implementation of these mysterious implicit networks, we have to go over a few other things about probability distributions that we'll use when deriving the IQN algorithm.

First off, every probability distribution has what's called a cumulative distribution function (CDF). If the probability of getting the value $35$ out of a probability distribution $P(X)$ is denoted $P(X = 35)$, then the cumulative probability of getting $35$ from that distribution is $P(X \leq 35)$.

The CDF of a distribution defines exactly that, except it gives a cumulative probability for every possible output of the distribution. You can think of the CDF as just the integral of the distribution's density from the far left up to a given point. A nice property of CDFs is that their outputs are bounded between 0 and 1. This should be pretty intuitive, since the integral over an entire probability distribution has to equal 1. An example of a CDF for a unit Gaussian distribution is shown below.

[Figure: CDF of a unit Gaussian distribution]

Quantile functions are closely related to CDFs. In fact, they're just the inverse. CDFs take in an $x$ and return a probability, but quantile functions take in a probability and return an $x$. The quantile function for a unit Gaussian (same as with the previous example CDF) is shown below.

[Figure: quantile function of a unit Gaussian distribution]
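To make the inverse relationship concrete, here's a tiny sanity check using SciPy's unit Gaussian (SciPy calls the quantile function the "percent point function"):

```python
from scipy.stats import norm

x = 1.3
p = norm.cdf(x)       # CDF: value -> cumulative probability (about 0.903)
x_back = norm.ppf(p)  # quantile function: probability -> value
print(p, x_back)      # x_back recovers 1.3 (up to floating point error)
```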

Representing an Implicit Distribution

Now we can finally get to the fun stuff: figuring out how to represent an arbitrarily complex distribution implicitly. Seeing as I just went on a bit of a detour to talk about quantile functions, you probably already know that that's what we're gonna use. But how and why will that work for us?

First off, quantile functions all have the same input domain, regardless of the distribution they belong to. Your distribution could be uniform, Gaussian, energy-based, whatever really, and its quantile function would only accept input values between 0 and 1. Since we want to represent any arbitrary distribution, this is definitely a property we want to take advantage of.

Additionally, using quantile functions allows us to sample directly from our distribution without ever having an explicit representation of the distribution. Sampling from the uniform distribution $U([0, 1])$ and passing that as input to our quantile function is equivalent to sampling directly from $Z(s, a)$. Since we can implement this entirely within a neural network, this means there's no major accuracy bottleneck either.
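As a sketch of what that sampling looks like in code, assuming a hypothetical `quantile_net(state, taus)` that returns one return sample per (tau, action) pair:

```python
import torch

def sample_returns(quantile_net, state, num_samples=32):
    """Draw samples from the implicit return distribution Z(s, .) by
    feeding uniform taus through the learned quantile function."""
    taus = torch.rand(num_samples)     # tau ~ U([0, 1])
    # hypothetical signature: quantile_net(state, taus) -> (num_samples, num_actions)
    return quantile_net(state, taus)
```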

We can also add in another feature to our implicit network to give us the ability to make risk-sensitive policy decisions. We can quite simply distort the input to our quantile network. If we want to make the tails of our distribution less important, for example, then we can map input values closer to 0.5 before passing them to our quantile function.
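As a toy example of such a distortion (the `scale` parameter is just for illustration), a mapping that pulls every input value toward 0.5 might look like:

```python
def squash_toward_median(tau, scale=0.5):
    """Distortion beta(tau): pull taus toward 0.5 so the tails of the
    return distribution matter less (scale < 1 shrinks, scale = 1 is the identity)."""
    return 0.5 + scale * (tau - 0.5)
```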

Formalization

We've gone over a lot, so let's take a step back and formalize it a bit. The usual convention for denoting a quantile function over random variable $Z$ (our return) would be $F^{-1}_{Z}(\tau)$, where $\tau \in [0, 1]$. For simplicity's sake, though, we'll define

$$ Z_\tau \doteq F^{-1}_{Z}(\tau) $$

We can also define sampling from $Z(s, a)$ with the following

$$ Z_\tau(s, a), \\ \tau \sim U([0, 1]) $$

To distort our $\tau$ values, we'll define a mapping

$$ \beta : [0, 1] \rightarrow [0, 1] $$

Putting these definitions together, we can recover a distorted Q-value

$$ Q_{\beta}(s, a) \doteq \mathbb{E}_{\tau \sim U([0, 1])} [Z_{\beta(\tau)}(s, a)] $$

To define our policy, we can just take whichever action maximizes this distorted Q-value

$$ \pi_{\beta}(s) = \arg\max\limits_{a \in \mathcal{A}} Q_{\beta}(s, a) $$

Optimization

Now to figure out a way to iteratively update our distribution approximations... We'll use the quantile Huber loss, a nice metric that extends the Huber loss to work with quantiles instead of just scalar outputs

$$ \rho^\kappa_\tau(\delta_{ij}) = | \tau - \mathbb{I}\{ \delta_{ij} < 0 \} | \frac{\mathcal{L}_\kappa(\delta_{ij})}{\kappa}, \quad \text{with} \\ \mathcal{L}_\kappa(\delta_{ij}) = \begin{cases} \frac{1}{2} \delta^2_{ij} &\text{if } | \delta_{ij} | \leq \kappa \\ \kappa (| \delta_{ij} | - \frac{1}{2} \kappa) &\text{otherwise} \end{cases} $$

This is a messy loss term, but it essentially tries to minimize TD error while keeping the network's output close to what we expect the quantile function to look like (according to our current approximation).
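Here's a minimal sketch of $\rho^\kappa_\tau$ as a function of TD errors and their associated quantile levels (shapes are assumed to broadcast against each other):

```python
import torch

def quantile_huber(delta, tau, kappa=1.0):
    """rho^kappa_tau(delta): asymmetric Huber loss, applied element-wise.
    delta: TD errors; tau: the quantile levels they were evaluated at."""
    abs_delta = delta.abs()
    huber = torch.where(
        abs_delta <= kappa,
        0.5 * delta ** 2,
        kappa * (abs_delta - 0.5 * kappa),
    )
    # |tau - 1{delta < 0}| weights under- and over-estimation asymmetrically
    weight = (tau - (delta < 0).float()).abs()
    return weight * huber / kappa
```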

This loss metric is based on the TD error $\delta_{ij}$, which we can define just like normal TD error

$$ \delta_{ij} = r + \gamma Z_{\tau_j}(s', \pi_\beta(s')) - Z_{\tau_i}(s, a) $$

Notice how in this definition, $\tau_i$ and $\tau_j$ are two separate samples from the $U([0, 1])$ distribution. We use two separate samples to keep the two terms in the TD error decorrelated. To get a more accurate estimate of the loss, we'll sample it multiple times in the following fashion

$$ \mathcal{L} = \frac{1}{N'} \sum_{i=1}^N \sum_{j=1}^{N'} \rho^\kappa_{\tau_i}(\delta_{ij}) $$

where the $\tau_i$ and $\tau_j$ values are all sampled independently from $U([0, 1])$ for each update.
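Putting the pieces together, a sketch of the sampled loss for a single transition might look like the following. The `quantile_net(state, taus)` / `target_net(state, taus)` signatures and the pre-computed greedy next action `a_next` are assumptions for illustration, and the sketch reuses the `quantile_huber` helper from above.

```python
import torch

def iqn_loss(quantile_net, target_net, s, a, r, s_next, a_next,
             gamma=0.99, N=8, N_prime=8, kappa=1.0):
    """Sampled quantile regression loss for one transition (s, a, r, s_next).
    a_next is the greedy next action chosen by (an approximation of) pi_beta."""
    taus_i = torch.rand(N)        # quantile levels for the current state-action
    taus_j = torch.rand(N_prime)  # quantile levels for the bootstrapped target

    # hypothetical signature: net(state, taus) -> (len(taus), num_actions)
    z_current = quantile_net(s, taus_i)[:, a]                           # (N,)
    with torch.no_grad():
        z_target = r + gamma * target_net(s_next, taus_j)[:, a_next]    # (N',)

    # pairwise TD errors: delta[i, j] = r + gamma * Z_{tau_j}(s', a') - Z_{tau_i}(s, a)
    delta = z_target.unsqueeze(0) - z_current.unsqueeze(1)              # (N, N')
    loss = quantile_huber(delta, taus_i.unsqueeze(1), kappa)            # (N, N')
    return loss.sum(dim=0).mean()  # (1/N') * sum over i and j
```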

Finally, we'll approximate $\pi_\beta$, which we defined earlier, using a similar sampling technique

$$ \tilde{\pi}_\beta(s) = \arg\max\limits_{a \in \mathcal{A}} \frac{1}{K} \sum_{k=1}^K Z_{\beta(\tau_k)}(s, a) $$

where $\tau_k$ is newly sampled every time as well.
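A sketch of this sampled, risk-sensitive action selection, again assuming the hypothetical `quantile_net(state, taus)` signature and some distortion `beta` (e.g. the `squash_toward_median` toy above, or the identity):

```python
import torch

def risk_sensitive_action(quantile_net, beta, state, K=32):
    """Approximate pi_beta(s): estimate the distorted Q-value of each action
    from K distorted-tau samples, then act greedily."""
    taus = beta(torch.rand(K))              # distort tau ~ U([0, 1]) with beta
    z_samples = quantile_net(state, taus)   # hypothetical: (K, num_actions)
    q_beta = z_samples.mean(dim=0)          # (num_actions,)
    return q_beta.argmax().item()
```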

That was a lot, but it's all we need to make an IQN. We could spend time thinking about different choices of $\beta$, but that's really a choice that depends on your specific environment. And during implementation, you can just decide that $\beta$ will be the identity function and then change it later if you think you can get better performance with risk-aware action selection.

Review

We started off with the more obvious way of implementing distributional deep Q-learning, which was explicitly representing the return distribution. Although it worked well, using an explicit representation of the return distribution created an accuracy bottleneck that was hard to overcome. It was also difficult to inject risk-sensitivity into the algorithm.

Using an implicit distribution instead allowed us to get around those two problems, giving us much greater representational power and allowing us much greater control over how our agent handles risk.

Of course, there's always room for improvement. Small techniques like using prioritized experience replay and n-step returns for calculating TD error can be used to make the IQN algorithm more powerful. And since distributional RL is still a pretty new field, there will no doubt be major improvements coming down the academia pipeline to be on the lookout for.

Footnotes

1 See paper #1, section 6.1 for a short discussion of what the paper's authors call 'chattering'.