f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization

Accept info: NIPS 2016 Spotlight
Authors: Sebastian Nowozin, Botond Cseke, Ryota Tomioka
Affiliation: Microsoft Research
Links: arXiv, OpenReview
TLDR: Generalizes the GAN training objective to arbitrary f-divergences via a variational lower bound.

1. Intuition & Motivation

What is the general goal when we estimate a probabilistic model?
To fit a model distribution \(Q\) to the true data distribution \(P\).

How can we measure the discrepancy between the two distributions \(P\) and \(Q\)?
If we can measure that distance, we can minimize it.

Three main families of distances

  1. Integral Probability Metrics
    \(P\): expectation, \(Q\): expectation
    The distance is defined as a difference of expectations, so it can be approximated using samples only.
    ex. kernel MMD, Wasserstein distance
  2. Proper Scoring Rules
    \(P\): expectation, \(Q\): distribution
    The model distribution must be in a rich enough form to be evaluated point-wise.
    ex. log-likelihood
  3. f-divergences
    \(P\): distribution, \(Q\): distribution
    Requires the density of the true distribution that generated the data.
    ex. KL divergence, Jensen-Shannon divergence

At first glance, f-divergences therefore seem inapplicable to model estimation, because the true data density is unavailable.

The interesting thing about GAN is that it approximately minimizes the Jensen-Shannon divergence.
We can think of GAN as converting the difficult case into the simple one, where only a difference of expectations is needed.

Since GAN shows that the Jensen-Shannon divergence can be approximated using samples alone, it is natural to ask whether every f-divergence can be approximated from samples in the same way.

2. Method

2.1. Overview

(Table 1 of the paper lists the f-divergences covered by the framework with their generator functions; Table 2 lists the recommended output activations \(g_f\).)

The model is fitted by minimizing over \(\theta\) and maximizing over \(\omega\) a sample-based approximation of any f-divergence, \(F(\theta, \omega) = \mathbb{E}_{x \sim P} [g_f(V_{\omega}(x))] + \mathbb{E}_{x \sim Q_{\theta}} [-f^{*}( g_f(V_{\omega}(x)) )]\).

2.2. Preliminary

Convex Conjugate Function
Every convex, lower-semicontinuous function \(f\) has a convex conjugate function \(f^{*}\), also known as Fenchel conjugate.
The function \(f^{*}\) is again convex and lower-semicontinuous, and the pair \((f, f^{*})\) is dual to one another in the sense that \(f^{**} = f\).
High-level intuition: any convex function can be represented as the point-wise supremum of a family of affine functions.

\(f^{*}(t) = \underset{u \in \mathrm{dom}_{f}}{\mathrm{sup}} \{ut - f(u) \}\)
\(f(u) = \underset{t \in \mathrm{dom}_{f^{*}}}{\mathrm{sup}} \{tu - f^{*}(t) \}\)
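To make the conjugate concrete, here is a minimal numerical sketch (assuming NumPy; not from the paper) for the KL generator \(f(u) = u \log u\), whose conjugate has the closed form \(f^{*}(t) = e^{t-1}\): a grid-based supremum over \(u\) reproduces the closed form.

```python
import numpy as np

# KL generator f(u) = u log u and its known Fenchel conjugate f*(t) = exp(t - 1).
f = lambda u: u * np.log(u)
f_star_closed = lambda t: np.exp(t - 1.0)

# Numerical conjugate: f*(t) = sup_u { u t - f(u) }, approximated on a grid of u.
u_grid = np.linspace(1e-6, 50.0, 200_000)

def f_star_numeric(t):
    return np.max(u_grid * t - f(u_grid))

for t in [-1.0, 0.0, 0.5, 1.0, 2.0]:
    print(t, f_star_numeric(t), f_star_closed(t))  # the two columns agree closely
```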

Proof Recipe 1
For a function \(f(t, x)\) of an optimization variable \(t\) and a random variable \(x\), and any class \(\mathcal{T}\) of functions \(T: x \mapsto t\):
\(\underset{t}{\mathrm{sup}}\, f(t, x) \geq f(T(x), x)\) for every \(x\)
\(\mathbb{E}_x [\underset{t}{\mathrm{sup}}\, f(t, x)] \geq \mathbb{E}_x [f(T(x), x)]\)
\(\mathbb{E}_x [\underset{t}{\mathrm{sup}}\, f(t, x)] \geq \underset{T \in \mathcal{T}}{\mathrm{sup}}\, \mathbb{E}_x [f(T(x), x)] \geq \mathbb{E}_x [f(T(x), x)]\)
Moving the supremum outside the expectation, and restricting it to a function class \(\mathcal{T}\), can only decrease the value.
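A tiny Monte-Carlo illustration of the recipe (a sketch assuming NumPy, with \(f(t, x) = -(t - x)^2\) and \(\mathcal{T}\) restricted to constant functions for simplicity): \(\mathbb{E}_x[\sup_t f] \approx 0\) while \(\sup_t \mathbb{E}_x[f] = -1\).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)            # x ~ N(0, 1)
ts = np.linspace(-3.0, 3.0, 301)           # grid of candidate constants t

# f(t, x) = -(t - x)^2: sup_t f(t, x) = 0 (take t = x), while
# E_x[f(t, x)] = -(t^2 + 1) is maximized at t = 0 with value -1.
vals = -(ts[:, None] - x[None, :]) ** 2    # shape (len(ts), len(x))

e_of_sup = np.mean(np.max(vals, axis=0))   # E_x[sup_t f]  ~ 0 (up to grid error)
sup_of_e = np.max(np.mean(vals, axis=1))   # sup_t E_x[f]  ~ -1
print(e_of_sup, sup_of_e)                  # e_of_sup >= sup_of_e, as the recipe states
```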

Proof Recipe 2
\(f(u) = \underset{t \in \mathrm{dom}_{f^{*}}}{\mathrm{sup}} \{ tu - f^{*}(t) \}\)
For a fixed \(u\), the supremum is attained at the \(t\) satisfying \(u = (f^{*})'(t)\).
Moreover, \((f')^{-1} = (f^{*})'\), as the two derivations below show.

\(\frac{d}{dt} \{ tu - f^{*}(t) \} = 0\)
\(u - (f^{*})'(t) = 0\)
\(u = (f^{*})'(t)\)

\(\frac{d}{du} f(u) = \frac{d}{du} (tu - f^{*}(t))\), where \(t = t(u)\) denotes the maximizer
\(= \frac{dt}{du} u + t - (f^{*})'(t) \frac{dt}{du}\)
\(= t\) (the \(\frac{dt}{du}\) terms cancel because \(u = (f^{*})'(t)\) at the maximizer)

\(u = (f^{*})'(t)\)
\(f'(u) = t\)
\(f'((f^{*})'(t)) = t\)
\((f')^{-1} = (f^{*})'\)
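As a quick numerical check of this identity (a sketch assuming NumPy, again using the KL generator \(f(u) = u \log u\), so \(f'(u) = 1 + \log u\) and \(f^{*}(t) = e^{t-1}\)): applying \(f'\) after \((f^{*})'\) returns the input, i.e. \((f')^{-1} = (f^{*})'\).

```python
import numpy as np

# KL generator f(u) = u log u: f'(u) = 1 + log u, f*(t) = exp(t - 1).
f_prime = lambda u: 1.0 + np.log(u)
f_star_prime = lambda t: np.exp(t - 1.0)   # derivative of the conjugate

# (f')^{-1} = (f*)': composing f' with (f*)' gives back the identity.
t = np.linspace(-2.0, 3.0, 11)
print(np.allclose(f_prime(f_star_prime(t)), t))   # True
```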

2.3. Variational Divergence Minimization (VDM)

The f-divergence Family
f-divergence: \(D_f(P \parallel Q) = \int_{\mathcal{X}} q(x) f(\frac{p(x)}{q(x)}) dx\)
The generator function \(f: \mathbb{R}_{+} \to \mathbb{R}\) is a convex, lower-semicontinuous function satisfying \(f(1) = 0\).
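As a concrete instance (listed in Table 1 of the paper), the generator \(f(u) = u \log u\) recovers the KL divergence:
\(D_f(P \parallel Q) = \int_{\mathcal{X}} q(x) \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} dx = \int_{\mathcal{X}} p(x) \log \frac{p(x)}{q(x)} dx = \mathrm{KL}(P \parallel Q)\)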

Variational Estimation of f-divergences
\(D_f(P \parallel Q) = \int_{\mathcal{X}} q(x) f(\frac{p(x)}{q(x)}) dx\)
\(= \int_{\mathcal{X}} q(x) \underset{t \in \mathrm{dom}_{f^{*}}}{\mathrm{sup}} \{t \frac{p(x)}{q(x)} - f^{*}(t) \} dx\)
\(\geq \underset{T \in \mathcal{T}}{\mathrm{sup}} (\int_{\mathcal{X}} q(x) \{T(x) \frac{p(x)}{q(x)} - f^{*}(T(x)) \} dx)\) (by Proof Recipe 1, with \(\mathcal{T}\) an arbitrary class of functions \(T: \mathcal{X} \to \mathbb{R}\))
\(= \underset{T \in \mathcal{T}}{\mathrm{sup}} (\int_{\mathcal{X}} T(x) p(x) dx - \int_{\mathcal{X}} q(x) f^{*}(T(x)) dx)\)
\(= \underset{T \in \mathcal{T}}{\mathrm{sup}} ( \mathbb{E}_{x \sim P} [T(x)] - \mathbb{E}_{x \sim Q} [f^{*}(T(x))] )\)

The bound is tight for \(T^{*}(x) = f'(\frac{p(x)}{q(x)})\), by applying Proof Recipe 2 point-wise with \(u = \frac{p(x)}{q(x)}\):
\(\underset{t \in \mathrm{dom}_{f^{*}}}{\mathrm{sup}} \{t \frac{p(x)}{q(x)} - f^{*}(t) \}\)
\(\frac{p(x)}{q(x)} = (f^{*})'(t)\)
\(f'(\frac{p(x)}{q(x)}) = t\)
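A minimal Monte-Carlo sketch of this tightness (assuming NumPy and two Gaussians \(P = \mathcal{N}(0, 1)\), \(Q = \mathcal{N}(1, 1)\), for which \(\mathrm{KL}(P \parallel Q) = 0.5\); these choices are illustrative, not from the paper): plugging the optimal \(T^{*}\) into the lower bound recovers the analytic KL.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000
xp = rng.normal(0.0, 1.0, n)   # samples from P = N(0, 1)
xq = rng.normal(1.0, 1.0, n)   # samples from Q = N(1, 1)

# KL generator: f(u) = u log u, f'(u) = 1 + log u, f*(t) = exp(t - 1).
# Optimal variational function: T*(x) = f'(p(x)/q(x)) = 1 + log p(x) - log q(x).
log_ratio = lambda x: -0.5 * x**2 + 0.5 * (x - 1.0) ** 2   # log p(x) - log q(x)
T_star = lambda x: 1.0 + log_ratio(x)
f_star = lambda t: np.exp(t - 1.0)

lower_bound = np.mean(T_star(xp)) - np.mean(f_star(T_star(xq)))
print(lower_bound)   # ~ 0.5, the analytic KL(N(0,1) || N(1,1))
```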

Variational Divergence Minimization (VDM)
Since we do not know \(p(x)\), we cannot evaluate \(T^{*}(x) = f'(\frac{p(x)}{q(x)})\) directly; instead we parameterize the variational function with a neural network \(T_{\omega}\) and the model with a generative neural sampler \(Q_{\theta}\).
\(F(\theta, \omega) = \mathbb{E}_{x \sim P} [T_{\omega}(x)] - \mathbb{E}_{x \sim Q_{\theta}} [f^{*}(T_{\omega}(x))]\)

Represent the variational function as \(T_{\omega}(x) = g_f(V_{\omega}(x))\), where \(V_{\omega}: \mathcal{X} \to \mathbb{R}\) is a neural network with unconstrained output and \(g_f: \mathbb{R} \to \mathrm{dom}_{f^{*}}\) is an output activation chosen to respect the domain of \(f^{*}\).
\(F(\theta, \omega) = \mathbb{E}_{x \sim P} [g_f(V_{\omega}(x))] + \mathbb{E}_{x \sim Q_{\theta}} [-f^{*}( g_f(V_{\omega}(x)) )]\)
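As a sanity check, choosing the GAN divergence's output activation \(g_f(v) = -\log(1 + e^{-v}) = \log \sigma(v)\) (Table 2 of the paper) together with its conjugate \(f^{*}(t) = -\log(1 - e^{t})\) gives
\(F(\theta, \omega) = \mathbb{E}_{x \sim P} [\log \sigma(V_{\omega}(x))] + \mathbb{E}_{x \sim Q_{\theta}} [\log(1 - \sigma(V_{\omega}(x)))]\)
which is exactly the original GAN objective with discriminator \(D_{\omega} = \sigma(V_{\omega})\).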

2.4. Single-Step Gradient Method

Algorithm 1 (single-step gradient method) in the paper.

In the original GAN, optimization uses a double-loop method: the inner loop tightens the lower bound on the divergence, whereas the outer loop improves the generator model.
In practice, a single step in the inner loop works well.
We can prove that Algorithm 1 converges geometrically to a saddle point \((\theta^*, \omega^*)\); a minimal implementation sketch of the single-step update follows, and the convergence argument comes after it.
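A minimal sketch of the single-step update (assuming PyTorch; `G` is a generator network mapping noise \(z\) to samples from \(Q_{\theta}\), `V` is the variational network \(V_{\omega}\), and the KL choices \(g_f(v) = v\), \(f^{*}(t) = e^{t-1}\) are used for illustration; none of these names come from the paper's code):

```python
import torch

# Illustrative f-divergence choice (KL): g_f(v) = v, f*(t) = exp(t - 1).
g_f = lambda v: v
f_star = lambda t: torch.exp(t - 1.0)

def single_step(G, V, x_real, z, eta=1e-3):
    """One iteration of the single-step gradient method on F(theta, omega)."""
    # Variational objective F = E_P[g_f(V(x))] - E_Q[f*(g_f(V(G(z))))].
    F = g_f(V(x_real)).mean() - f_star(g_f(V(G(z)))).mean()

    # Gradients of F w.r.t. both parameter groups, taken at the same point.
    theta = list(G.parameters())
    omega = list(V.parameters())
    grads = torch.autograd.grad(F, theta + omega)
    g_theta, g_omega = grads[: len(theta)], grads[len(theta):]

    with torch.no_grad():
        # omega: gradient ascent on F (tighten the divergence lower bound).
        for p, g in zip(omega, g_omega):
            p.add_(eta * g)
        # theta: gradient descent on F (move Q_theta toward P).
        for p, g in zip(theta, g_theta):
            p.sub_(eta * g)
    return F.item()
```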

\(\nabla F(\theta, \omega) = \begin{pmatrix} \nabla_{\theta} F(\theta, \omega) \\ \nabla_{\omega} F(\theta, \omega) \end{pmatrix}, \ \tilde{\nabla} F(\theta, \omega) = \begin{pmatrix} - \nabla_{\theta} F(\theta, \omega) \\ \nabla_{\omega} F(\theta, \omega) \end{pmatrix}\)
\(\pi = (\theta, \omega), \ \pi^* = (\theta^*, \omega^*)\)
\(\pi^{t+1} = \pi^{t} + \eta \tilde{\nabla} F(\pi^{t})\)
\(J(\pi) = \frac{1}{2} \left\| \nabla F(\pi) \right\|^2_2 = \frac{1}{2} \left\| \tilde{\nabla} F(\pi) \right\|^2_2\)

Assumption: in a neighborhood of the saddle point, \(F\) is strongly convex in \(\theta\), strongly concave in \(\omega\), and sufficiently smooth:
\(\nabla_{\theta} F(\pi^*) = 0, \ \nabla_{\omega} F(\pi^*) = 0\)
\(\nabla_{\theta}^2 F(\pi) \succeq \delta I, \ \nabla_{\omega}^2 F(\pi) \preceq - \delta I\)
\(\left\| \nabla J(\pi') - \nabla J(\pi) \right\|_2 \leq L \left\| \pi' - \pi \right\|_2\)

Using the step size \(\eta = \delta / L\) in Algorithm 1, we obtain \(J(\pi^{t}) \leq (1 - \frac{\delta^2}{L})^t J(\pi^0)\).
The squared norm of the gradient \(\nabla F(\pi)\) decreases geometrically, so Algorithm 1 converges to the saddle point \(\pi^*\).

\(\left< \tilde{\nabla} F(\pi), \nabla J(\pi) \right> = \left< \tilde{\nabla} F(\pi), \nabla^2 F(\pi) \nabla F(\pi) \right>\)
(\(\nabla J(\pi) = \nabla (\frac{1}{2} \left\| \nabla F(\pi) \right\|^2_2) = \nabla^2 F(\pi) \nabla F(\pi)\))
\(= \left< \begin{pmatrix} -\nabla_{\theta} F(\theta, \omega) \\ \nabla_{\omega} F(\theta, \omega) \end{pmatrix}, \begin{pmatrix} \nabla_{\theta}^2 F(\theta, \omega) & \nabla_{\theta} \nabla_{\omega} F(\theta, \omega) \\ \nabla_{\omega} \nabla_{\theta} F(\theta, \omega) & \nabla_{\omega}^2 F(\theta, \omega) \\ \end{pmatrix} \begin{pmatrix} \nabla_{\theta} F(\theta, \omega) \\ \nabla_{\omega} F(\theta, \omega) \end{pmatrix} \right>\)
\(= \left< \begin{pmatrix} -\nabla_{\theta} F(\theta, \omega) \\ \nabla_{\omega} F(\theta, \omega) \end{pmatrix}, \begin{pmatrix} \nabla_{\theta}^2 F(\theta, \omega) \nabla_{\theta} F(\theta, \omega) + \nabla_{\theta} \nabla_{\omega} F(\theta, \omega) \nabla_{\omega} F(\theta, \omega) \\ \nabla_{\omega} \nabla_{\theta} F(\theta, \omega) \nabla_{\theta} F(\theta, \omega) + \nabla_{\omega}^2 F(\theta, \omega) \nabla_{\omega} F(\theta, \omega) \end{pmatrix} \right>\)
\(= - \left< \nabla_{\theta} F(\theta, \omega), \nabla_{\theta}^2 F(\theta, \omega) \nabla_{\theta} F(\theta, \omega) \right> + \left< \nabla_{\omega} F(\theta, \omega), \nabla_{\omega}^2 F(\theta, \omega) \nabla_{\omega} F(\theta, \omega) \right>\)
\(\leq - \delta \left\| \nabla_{\theta} F(\theta, \omega) \right\|_2^2 - \delta \left\| \nabla_{\omega} F(\theta, \omega) \right\|_2^2\)
(\(\nabla_{\theta}^2 F(\pi) \succeq \delta I, \ \nabla_{\omega}^2 F(\pi) \preceq - \delta I\))
\(= - \delta \left\| \nabla F(\pi) \right\|_2^2\)

\(\left\| \nabla J(\pi') - \nabla J(\pi) \right\|_2 \leq L \left\| \pi' - \pi \right\|_2\)
\(\frac{\left\| \nabla J(\pi') - \nabla J(\pi) \right\|_2}{\left\| \pi' - \pi \right\|_2} \leq L\)
Taking \(\pi' \to \pi\), the Lipschitz condition bounds the Hessian of \(J\): \(\nabla^2 J(\pi) \preceq L I\) (its operator norm is at most \(L\)).

\(J(\pi') = J(\pi) + \left< \nabla J(\pi), \pi' - \pi \right> + \frac{1}{2}(\pi' - \pi)^T \nabla^2 J(\bar{\pi}) (\pi' - \pi)\) for some \(\bar{\pi}\) on the segment between \(\pi\) and \(\pi'\)
\(J(\pi') \leq J(\pi) + \left< \nabla J(\pi), \pi' - \pi \right> + \frac{L}{2} \left\| \pi' - \pi \right\|_2^2\)

\(\pi^{t+1} - \pi^t = \eta \tilde{\nabla} F(\pi^t)\)
\(J(\pi^{t+1}) \leq J(\pi^t) + \left< \nabla J(\pi^t), \pi^{t+1} - \pi^t \right> + \frac{L}{2} \left\| \pi^{t+1} - \pi^t \right\|_2^2\)
\(J(\pi^{t+1}) \leq J(\pi^t) + \left< \nabla J(\pi^t), \eta \tilde{\nabla} F(\pi^t) \right> + \frac{L}{2} \left\| \eta \tilde{\nabla} F(\pi^t) \right\|_2^2\)
\(J(\pi^{t+1}) \leq J(\pi^t) + \eta \left< \nabla J(\pi^t), \tilde{\nabla} F(\pi^t) \right> + \frac{L \eta^2}{2} \left\| \tilde{\nabla} F(\pi^t) \right\|_2^2\)
\(J(\pi^{t+1}) \leq J(\pi^t) - 2 \eta \delta J(\pi^t) + L \eta^2 J(\pi^t)\)
(\(\left< \tilde{\nabla} F(\pi), \nabla J(\pi) \right> \leq - \delta \left\| \nabla F(\pi) \right\|_2^2, \ J(\pi) = \frac{1}{2} \left\| \nabla F(\pi) \right\|^2_2 = \frac{1}{2} \left\| \tilde{\nabla} F(\pi) \right\|^2_2\))
\(J(\pi^{t+1}) \leq (1 - 2 \eta \delta + L \eta^2) J(\pi^t)\)
\(J(\pi^{t+1}) \leq (1 - \frac{\delta^2}{L}) J(\pi^t)\)
(\(\eta = \delta / L\))
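
To see the geometric rate concretely, here is a minimal numerical sketch (assuming NumPy; a toy quadratic, not one of the paper's experiments) with \(F(\theta, \omega) = \frac{\theta^2}{2} - \frac{\omega^2}{2} + c\,\theta\omega\), for which \(\delta = 1\) and \(L = 1 + c^2\): the measured \(J(\pi^t)\) stays at or below the bound \((1 - \frac{\delta^2}{L})^t J(\pi^0)\).

```python
import numpy as np

# Toy saddle objective F(theta, omega) = theta^2/2 - omega^2/2 + c*theta*omega:
# strongly convex in theta, strongly concave in omega (delta = 1), and
# grad J is Lipschitz with L = 1 + c^2 for this quadratic.
c = 0.5
delta, L = 1.0, 1.0 + c**2
eta = delta / L                        # step size from the theorem

def grad_F(theta, omega):
    return np.array([theta + c * omega, -omega + c * theta])

def J(theta, omega):
    return 0.5 * np.sum(grad_F(theta, omega) ** 2)

theta, omega = 3.0, -2.0
rate = 1.0 - delta**2 / L
for t in range(6):
    print(t, J(theta, omega), rate**t * J(3.0, -2.0))   # measured J vs. the bound
    g_theta, g_omega = grad_F(theta, omega)
    # Single-step update: descent in theta, ascent in omega.
    theta, omega = theta - eta * g_theta, omega + eta * g_omega
```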