What?
Making the reparameterisation trick work for discrete random variables.
Why?
We want low-variance gradient estimation methods (reparameterisation) to work with discrete random variables.
How?
source: original paper
- Stochastic computation graphs
- Nice formalism to represent computation as a bunch of parameterized blocks;
- Blocks can be deterministic or stochastic;
- If we can differentiate the blocks, we can use the chain rule to compute the gradient.
- Adding stochastic nodes to the graph makes our life harder (below, $\theta$ parameterises the deterministic blocks and $\phi$ the stochastic ones)
- Deterministic nodes are still fine. Use Monte Carlo to estimate the gradient:
- $\nabla_\theta L(\theta,\phi) \simeq \frac{1}{S}\sum_{s=1}^S\nabla_\theta f_\theta(X^s)$, where $X^s$ is a sample from a stochastic node $X$;
- Stochastic nodes are problematic
- Can't do MC on the following since it's not in expectation form: $\nabla_\phi L(\theta, \phi) = \nabla_\phi \int p_\phi(x) f_\theta(x)\,dx = \int f_\theta (x) \nabla_\phi p_\phi(x)\,dx$;
- What we can do for stochastic nodes:
- Score function estimators (REINFORCE);
- $\nabla_\phi L(\theta, \phi) = \mathbb{E}_{X \sim p_\phi(x)}[f_\theta (X) \nabla_\phi \log{p_\phi}(X)]$;
- Use MC to estimate;
- High variance (can be reduced with a baseline)
- Reparameterisation trick;
- Sample from $q(z)$ first;
- Transform using $g_\phi (z)$ to get a sample from $p_\phi (x)$
- e.g. to get $\texttt{Normal}(\mu, \sigma)$, sample $Z$ from $\texttt{Normal}(0,1)$ and compute $g_{\mu, \sigma}(Z) = \mu + \sigma Z$.
- With the above, the objective becomes $L(\theta, \phi) = \mathbb{E}_{X \sim p_\phi(x)}[f_\theta (X)] = \mathbb{E}_{Z \sim q(z)}[f_\theta(g_\phi(Z))]$;
- Now, if $f_\theta(x)$ is differentiable w.r.t. $x$ and $g_\phi(z)$ is differentiable w.r.t. $\phi$, we can calculate the gradients: $\nabla_\phi L(\theta, \phi) = \mathbb{E}_{Z \sim q(z)}[\nabla_\phi f_\theta (g_\phi(Z))] = \mathbb{E}_{Z \sim q(z)}[f'_\theta (g_\phi(Z))\,\nabla_\phi g_\phi(Z)]$;
- This should have low variance (see the sketch below for a numerical comparison with the score-function estimator).
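To make the variance difference concrete, here is a minimal NumPy sketch (a toy example of my own, not from the paper): $f(x) = x^2$ with $X \sim \texttt{Normal}(\mu, \sigma)$, so the true gradient of the objective w.r.t. $\mu$ is $2\mu$, and we can compare both estimators against it.

```python
import numpy as np

# Toy setup (mine, for illustration): f(x) = x^2, X ~ Normal(mu, sigma^2).
# True gradient of E[f(X)] w.r.t. mu is 2 * mu.
rng = np.random.default_rng(0)
mu, sigma, S = 1.5, 1.0, 10_000

f = lambda x: x ** 2
df = lambda x: 2 * x  # f'(x), needed for the reparameterised estimator

# Score-function (REINFORCE) estimator: E[f(X) * grad_mu log p(X)],
# where grad_mu log Normal(x; mu, sigma) = (x - mu) / sigma^2.
x = rng.normal(mu, sigma, S)
score_grad = f(x) * (x - mu) / sigma**2

# Reparameterised estimator: X = g(Z) = mu + sigma * Z with Z ~ Normal(0, 1),
# so grad_mu f(g(Z)) = f'(mu + sigma * Z).
z = rng.normal(0.0, 1.0, S)
reparam_grad = df(mu + sigma * z)

print("true gradient  :", 2 * mu)
print("score function :", score_grad.mean(), "+/-", score_grad.std() / np.sqrt(S))
print("reparameterised:", reparam_grad.mean(), "+/-", reparam_grad.std() / np.sqrt(S))
```

Both averages converge to $2\mu$, but the standard error of the score-function estimate should be much larger for the same number of samples.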
- The Concrete Distribution
- Gumbel-Max trick
- Sampling from a discrete distribution;
- Consider a discrete distribution $D \sim \texttt{Discrete}(\alpha)$, where $\alpha = (\alpha_1, \dots, \alpha_n)$ parameterizes the distribution;
- To sample from the distribution, first draw $U_k \sim \texttt{Uniform}(0,1)$ and form Gumbel noise $G_k = -\log{(-\log{U_k})}$;
- Find the $k$ that maximises $\log{\alpha_k} + G_k$;
- Set $D_k = 1$ and all other $D_i = 0$.
- ^^^ This is the same idea as the reparameterisation trick above: turn sampling into a deterministic computation applied to samples from some fixed distribution (see the sketch below).
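A small NumPy sketch of the Gumbel-Max sampler (the $\alpha$ values and the function name are my own, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 7.0])  # made-up unnormalised class probabilities

def gumbel_max_sample(alpha, rng):
    u = rng.uniform(size=alpha.shape)   # U_k ~ Uniform(0, 1)
    g = -np.log(-np.log(u))             # G_k = -log(-log U_k) ~ Gumbel(0, 1)
    k = np.argmax(np.log(alpha) + g)    # deterministic arg-max over noisy logits
    one_hot = np.zeros_like(alpha)
    one_hot[k] = 1.0
    return one_hot

# Sanity check: empirical frequencies should match alpha / alpha.sum().
samples = np.stack([gumbel_max_sample(alpha, rng) for _ in range(10_000)])
print("empirical:", samples.mean(axis=0))
print("expected :", alpha / alpha.sum())
```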
- Concrete random variables
- Okay, how do we backprop through $\arg\max$?
- Use concrete random variables!
- Soften the state of a discrete variable (the one-hot vector from above): $X_k = \frac{\exp{((\log{\alpha_k}+G_k)/\lambda)}}{\sum_{i=1}^n\exp{((\log{\alpha_i}+G_i)/\lambda)}}$
- If $\lambda\rightarrow 0$, the computation becomes $\arg\max$, and we get a discrete random variable.
- The authors derive a distribution such that the definition corresponds to the sampling above.
- Relaxation
- Often, it's infeasible to do the computation with discrete random variables (not sure I understand exactly when);
- Using continuous relaxation gives us a biased estimator, but it's easier to compute;
- For training, use the continuous relaxation; at inference time, use the discrete version (see the sketch below).
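A sketch of the corresponding Concrete (Gumbel-Softmax) sample, continuing the illustrative code above: the only change from the Gumbel-Max sketch is that the hard $\arg\max$ is replaced by a softmax with temperature $\lambda$.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 7.0])  # same made-up unnormalised probabilities as above

def concrete_sample(alpha, lam, rng):
    u = rng.uniform(size=alpha.shape)
    g = -np.log(-np.log(u))               # Gumbel(0, 1) noise, as in the Gumbel-Max trick
    logits = (np.log(alpha) + g) / lam    # temperature lambda controls how soft the sample is
    e = np.exp(logits - logits.max())     # softmax, shifted for numerical stability
    return e / e.sum()

for lam in (5.0, 1.0, 0.1):
    print(f"lambda = {lam}:", np.round(concrete_sample(alpha, lam, rng), 3))
# As lambda -> 0 the soft sample approaches a one-hot vector (the arg-max above):
# train with the soft sample (differentiable w.r.t. log alpha), evaluate with the hard one.
```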
And?
- I picked this paper to understand what the Gumbel-Softmax trick is, so take everything I said above with a grain of salt.
- I like how concrete (pun intended) the paper is. The intro is well-motivated, the background is concise, and it is a pleasure to read.
- These plots are amazing!
This note is a part of my paper notes series. You can find more here or on Twitter. I also have a blog.