Auto-Encoding Variational Bayes
Accept info: ICLR 2014 Oral
Authors: Diederik P Kingma, Max Welling
Affiliation: Machine Learning Group, Universiteit van Amsterdam
Links: arXiv, OpenReview
TLDR: Train directed probabilistic models with continuous latent variables by maximizing a variational lower bound, using the reparameterization trick to enable efficient gradient-based optimization.
1. Intuition & Motivation
- Goal: a general and efficient algorithm for learning and inference in directed probabilistic models whose continuous latent variables have intractable posterior distributions, scalable to large datasets
Variational Bayesian (VB) approach: approximate the intractable posterior \(p_{\theta}(z \mid x)\) with a tractable distribution \(q_{\phi}(z \mid x)\)
\(\mathrm{log} \ p_{\theta}(x) = \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x, z)}{p_{\theta}(z \mid x)} \frac{q_{\phi}(z \mid x)}{q_{\phi}(z \mid x)}]\)
\(= \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{q_{\phi}(z \mid x)}{p_{\theta}(z \mid x)} + \mathrm{log} \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}]\)
\(= D_{KL} \left (q_{\phi}(z \mid x) \parallel p_{\theta}(z \mid x) \right ) + \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}]\)
\(\geq \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}], \quad \text{since } D_{KL} \geq 0\)
Variational lower bound: \(L(\theta, \phi; x) = \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}]\)
\(\mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}] = \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(x \mid z) \ p_{\theta}(z)}{q_{\phi}(z \mid x)}]\)
\(= \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \frac{p_{\theta}(z)}{q_{\phi}(z \mid x)} + \mathrm{log} \ p_{\theta}(x \mid z)]\)
\(= - D_{KL} \left ( q_{\phi}(z \mid x) \parallel p_{\theta}(z) \right ) + \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \ p_{\theta}(x \mid z)]\)
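For the variational auto-encoder example in the paper, where \(q_{\phi}(z \mid x) = \mathcal{N}(z; \mu, \sigma^{2} I)\) is a diagonal Gaussian and the prior is \(p_{\theta}(z) = \mathcal{N}(z; 0, I)\), the KL term can be integrated analytically (Appendix B of the paper), so only the expected reconstruction term needs Monte Carlo estimation:

\[- D_{KL} \left ( q_{\phi}(z \mid x) \parallel p_{\theta}(z) \right ) = \frac{1}{2} \sum_{j=1}^{J} \left ( 1 + \mathrm{log} \ \sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2} \right )\]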
The gradient of the variational lower bound \(L(\theta, \phi; x)\) w.r.t. the variational parameters \(\phi\) is problematic, because the expectation is taken over \(q_{\phi}\), which itself depends on \(\phi\). The usual Monte Carlo approach uses the log-derivative identity:
\(\nabla_{\phi} \mathbb{E}_{z \sim q_{\phi}(z)} [f(z)] = \nabla_{\phi} \int q_{\phi}(z) \ f(z) \ dz\)
\(= \int \nabla_{\phi} q_{\phi}(z) \ f(z) \ dz\)
\(= \int q_{\phi}(z) \ \nabla_{\phi} \mathrm{log} q_{\phi}(z) \ f(z) \ dz\)
\(= \mathbb{E}_{z \sim q_{\phi}(z)} [f(z) \ \nabla_{\phi} \mathrm{log} q_{\phi}(z)]\)
\(\simeq \frac{1}{L} \sum_{l=1}^{L} [f(z^{(l)}) \ \nabla_{\phi} \mathrm{log} q_{\phi}(z^{(l)})]\)
The naive Monte Carlo gradient estimator suffers from high variance, because the gradient \(\nabla_{\phi} \mathrm{log} q_{\phi}(z)\) is scaled by the potentially large and noisy term \(f(z)\), making it impractical.
Removing the dependency on \(\phi\) from the sampling distribution in \(\mathbb{E}_{z \sim q_{\phi}(z)} [f(z)]\) would allow the gradient to be moved inside the expectation, enabling an efficient, low-variance Monte Carlo estimator of the variational lower bound, as the toy comparison below illustrates.
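As a toy illustration (not from the paper; the problem \(f(z) = z^{2}\) with \(q_{\phi}(z) = \mathcal{N}(\phi, 1)\) and all constants are assumptions for this sketch), the naive score-function estimator can be compared against the reparameterized estimator of Section 2 on a case where the true gradient \(2\phi\) is known:

```python
import numpy as np

# Toy problem: d/dphi E_{z ~ N(phi, 1)}[z^2]; true gradient is 2*phi since E[z^2] = phi^2 + 1.
rng = np.random.default_rng(0)
phi, num_samples = 3.0, 10_000
eps = rng.standard_normal(num_samples)
z = phi + eps                                  # samples from q_phi(z) = N(phi, 1)

# Naive (score-function) estimator: f(z) * d/dphi log q_phi(z) = z^2 * (z - phi)
score_grads = z ** 2 * (z - phi)

# Reparameterized estimator: z = phi + eps, so d/dphi f(z) = 2 * z
reparam_grads = 2 * z

print("true gradient            :", 2 * phi)
print("score-function  mean, var:", score_grads.mean(), score_grads.var())
print("reparameterized mean, var:", reparam_grads.mean(), reparam_grads.var())
```

Both estimators are unbiased, but on this problem the score-function estimator's variance is more than an order of magnitude larger, which is exactly the issue described above.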
2. Reparameterization Trick, SGVB Estimator and AEVB Algorithm
2.1. Approach Overview
Reparameterization Trick: rewrite \(\mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \ p_{\theta}(x \mid z)]\) such that the Monte Carlo estimate of the expectation is differentiable w.r.t. \(\phi\).
SGVB Estimator: apply the reparameterization trick to the variational lower bound to obtain a differentiable Monte Carlo estimate whose gradient can be taken directly.
AEVB Algorithm: learning algorithm that optimizes the SGVB estimator over minibatches with stochastic gradient methods.
2.2. Reparameterization Trick
Express the random variable \(z\) as a deterministic variable \(z = g_{\phi}(\epsilon, x)\), where \(\epsilon\) is an auxiliary variable with independent marginal \(p(\epsilon)\), and \(g_{\phi}(\cdot)\) is some vector-valued function parameterized by \(\phi\).
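A minimal sketch of the Gaussian case used in the paper's VAE example, where \(g_{\phi}(\epsilon, x) = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\) and \((\mu, \mathrm{log} \ \sigma^{2})\) produced by an encoder network; the PyTorch encoder below (layer sizes, activation) is an illustrative assumption, not the paper's exact architecture:

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """Maps x to (mu, log_var) of q_phi(z | x); layer sizes are illustrative."""
    def __init__(self, x_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(x_dim, h_dim), nn.Tanh())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_var = nn.Linear(h_dim, z_dim)

    def forward(self, x):
        h = self.hidden(x)
        return self.mu(h), self.log_var(h)

def reparameterize(mu, log_var):
    """z = g_phi(eps, x) = mu + sigma * eps, with eps ~ N(0, I).

    The sample depends on phi only through the deterministic mapping,
    so gradients flow back to mu and log_var.
    """
    eps = torch.randn_like(mu)          # auxiliary noise, independent of phi
    return mu + torch.exp(0.5 * log_var) * eps
```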
2.3. SGVB Estimator
With \(z = g_{\phi}(\epsilon, x)\) and \(\epsilon \sim p(\epsilon)\), the change of variables gives \(q_{\phi}(z \mid x) \ dz = p(\epsilon) \ d\epsilon\), so
\(\mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \ p_{\theta}(x \mid z)] = \int q_{\phi}(z \mid x) \ \mathrm{log} \ p_{\theta}(x \mid z) \ dz\)
\(= \int p(\epsilon) \ \mathrm{log} \ p_{\theta}(x \mid g_{\phi}(\epsilon, x)) \ d\epsilon\)
\(= \mathbb{E}_{\epsilon \sim p(\epsilon)} [\mathrm{log} \ p_{\theta}(x \mid g_{\phi}(\epsilon, x))]\)
\(L(\theta, \phi; x) = - D_{KL} \left ( q_{\phi}(z \mid x) \parallel p_{\theta}(z) \right ) + \mathbb{E}_{z \sim q_{\phi}(z \mid x)} [\mathrm{log} \ p_{\theta}(x \mid z)]\)
\(= - D_{KL} \left ( q_{\phi}(z \mid x) \parallel p_{\theta}(z) \right ) + \mathbb{E}_{\epsilon \sim p(\epsilon)} [\mathrm{log} \ p_{\theta}(x \mid g_{\phi}(\epsilon, x))]\)
\(\tilde{L}^{B}(\theta, \phi; x) = - D_{KL} \left ( q_{\phi}(z \mid x) \parallel p_{\theta}(z) \right ) + \frac{1}{L} \sum_{l=1}^{L} \mathrm{log} \ p_{\theta}(x \mid z^{(l)}), \quad \text{where } z^{(l)} = g_{\phi}(\epsilon^{(l)}, x), \ \epsilon^{(l)} \sim p(\epsilon)\)
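A sketch of \(\tilde{L}^{B}\), continuing the encoder/`reparameterize` sketch above, for a diagonal-Gaussian posterior with \(\mathcal{N}(0, I)\) prior (closed-form KL term) and a Bernoulli decoder as in the paper's MNIST example; the `decoder` interface is an assumption for illustration. The paper notes that \(L = 1\) sample per datapoint suffices as long as the minibatch size is large enough (e.g. \(M = 100\)).

```python
def sgvb_bound(x, encoder, decoder, num_samples=1):
    """Estimator L~^B(theta, phi; x) for a batch of datapoints.

    x: tensor of shape (batch, x_dim) with values in [0, 1].
    decoder: assumed to return Bernoulli means of shape (batch, x_dim).
    """
    mu, log_var = encoder(x)

    # Closed-form -KL(q_phi(z|x) || p_theta(z)) for diagonal Gaussian q and N(0, I) prior.
    neg_kl = 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)

    # Monte Carlo estimate of E_q[log p_theta(x | z)] with reparameterized samples.
    rec = 0.0
    for _ in range(num_samples):
        z = reparameterize(mu, log_var)
        x_mean = decoder(z).clamp(1e-7, 1 - 1e-7)   # Bernoulli parameters, clamped for log
        rec = rec + torch.sum(
            x * torch.log(x_mean) + (1 - x) * torch.log(1 - x_mean), dim=1
        )
    rec = rec / num_samples

    return neg_kl + rec                             # per-datapoint lower bound
```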
Given a dataset \(X\) with \(N\) datapoints, we can construct an estimator of the variational lower bound of the full dataset based on minibatches, where the minibatch \(X^M = \{x^{(1)}, \ldots, x^{(M)}\}\) is a randomly drawn sample of \(M\) datapoints from \(X\).
\[L(\theta, \phi; X) \simeq \tilde{L}^{B}(\theta, \phi; X^M) = \frac{N}{M} \sum_{i=1}^{M} \tilde{L}^{B}(\theta, \phi; x^{(i)})\]
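Putting the pieces together, a hedged sketch of one AEVB update using the minibatch estimator; the decoder architecture and the optimizer are illustrative choices here, not prescribed by the paper:

```python
decoder = nn.Sequential(            # illustrative Bernoulli decoder p_theta(x | z)
    nn.Linear(20, 400), nn.Tanh(), nn.Linear(400, 784), nn.Sigmoid()
)
encoder = GaussianEncoder()
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)

def aevb_step(x_batch, dataset_size):
    """One AEVB update on a random minibatch X^M of size M."""
    M = x_batch.shape[0]
    # Full-dataset bound estimate: (N / M) * sum_i L~^B(theta, phi; x^(i))
    bound = sgvb_bound(x_batch, encoder, decoder).sum() * (dataset_size / M)
    loss = -bound                    # maximize the bound = minimize its negative
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return bound.item()
```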