Renyi divergence #769
base: master
Changes from 32 commits
edward/inferences/renyi_divergence.py
@@ -0,0 +1,176 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six
import tensorflow as tf

from edward.inferences.variational_inference import VariationalInference
from edward.models import RandomVariable
from edward.util import copy

# These imports are only used to check that the installed TensorFlow
# version is supported.
try:
  from edward.models import Normal
  from tensorflow.contrib.distributions import kl_divergence
except Exception as e:
  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))


class RenyiDivergence(VariationalInference):
  """Variational inference with the Renyi divergence [@li2016renyi].

  It minimizes the Renyi divergence

  $\text{D}_{R}^{(\alpha)}(q(z) || p(z \mid x)) =
      \frac{1}{\alpha - 1} \log \int q(z)^{\alpha} p(z \mid x)^{1 - \alpha} dz.$

  The optimization is performed using the gradient estimator defined in
  [@li2016renyi].
Review comment: The citekey is being used as a direct object, so it should be …
Reply: Done

  #### Notes
  + The gradient estimator used here does not have any analytic version.
Review comment: With Markdown formatting, you don't need the 4 spaces of indentation. E.g., you can just do …
Reply: Done
  + The gradient estimator used here does not have any version for
    non-reparameterizable models.
  + backward_pass = 'max': (extreme case $\alpha \rightarrow -\infty$)
    the algorithm chooses the sample that has the maximum unnormalised
    importance weight. This no longer minimizes the Renyi divergence.
  + backward_pass = 'min': (extreme case $\alpha \rightarrow +\infty$)
    the algorithm chooses the sample that has the minimum unnormalised
    importance weight. This no longer minimizes the Renyi divergence.
    This mode is not described in the paper but is implemented in the
    publicly available implementation of the paper's experiments.
  """

  def __init__(self, *args, **kwargs):
    super(RenyiDivergence, self).__init__(*args, **kwargs)

    # The estimator requires reparameterization gradients, so record
    # whether every latent variable is fully reparameterizable.
    self.is_reparameterizable = all([
        rv.reparameterization_type ==
        tf.contrib.distributions.FULLY_REPARAMETERIZED
        for rv in six.itervalues(self.latent_vars)])

  def initialize(self,
                 n_samples=32,
                 alpha=1.0,
                 backward_pass='full',
                 *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
    and builds ops for the algorithm's computation graph.

    Args:
      n_samples: int, optional.
        Number of samples from the variational model for calculating
        stochastic gradients.
      alpha: float, optional.
        Renyi divergence coefficient, $\alpha \in \mathbb{R}$.
        When $\alpha < 0$, the algorithm still does something sensible
        but no longer minimizes the Renyi divergence
        (see [@li2016renyi], Section 4.2).
      backward_pass: str, optional.
        Backward pass mode to be used.
        Options: 'min', 'max', 'full'
        (see [@li2016renyi], Section 4.2).
    """
    self.n_samples = n_samples
    self.alpha = alpha
    self.backward_pass = backward_pass

    return super(RenyiDivergence, self).initialize(*args, **kwargs)

  def build_loss_and_gradients(self, var_list):
    """Build the Renyi ELBO function.

    Its automatic differentiation is a stochastic gradient of

    $\mathcal{L}_{R}^{\alpha}(q; x) =
        \frac{1}{1 - \alpha} \log \mathbb{E}_{q} \left[
        \left( \frac{p(x, z)}{q(z)} \right)^{1 - \alpha} \right].$

    It uses:
    + Monte Carlo approximation of the ELBO [@li2016renyi].
    + Reparameterization gradients [@kingma2014auto].
    + Stochastic approximation of the joint distribution [@li2016renyi].

    #### Notes
    + If the model is not reparameterizable, it raises a
      NotImplementedError.
    + See Renyi Divergence Variational Inference [@li2016renyi] for
      more details.
    """
    if self.is_reparameterizable:
      p_log_prob = [0.0] * self.n_samples
      q_log_prob = [0.0] * self.n_samples
      base_scope = tf.get_default_graph().unique_name("inference") + '/'
      for s in range(self.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(self.data):
          if isinstance(x, RandomVariable):
            if isinstance(qx, RandomVariable):
              qx_copy = copy(qx, scope=scope)
              dict_swap[x] = qx_copy.value()
            else:
              dict_swap[x] = qx

        for z, qz in six.iteritems(self.latent_vars):
          # Copy q(z) to obtain new set of posterior samples.
          qz_copy = copy(qz, scope=scope)
          dict_swap[z] = qz_copy.value()
          q_log_prob[s] += tf.reduce_sum(
              self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

        for z in six.iterkeys(self.latent_vars):
          z_copy = copy(z, dict_swap, scope=scope)
          p_log_prob[s] += tf.reduce_sum(
              self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

        for x in six.iterkeys(self.data):
          if isinstance(x, RandomVariable):
            x_copy = copy(x, dict_swap, scope=scope)
            p_log_prob[s] += tf.reduce_sum(
                self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

      # Per-sample log importance ratios log p(x, z_s) - log q(z_s).
      log_ratios = [p - q for p, q in zip(p_log_prob, q_log_prob)]

      if self.backward_pass == 'max':
        # Extreme case alpha -> -inf: keep only the sample with the
        # largest unnormalised importance weight.
        log_ratios = tf.stack(log_ratios)
        log_ratios = tf.reduce_max(log_ratios, 0)
        loss = tf.reduce_mean(log_ratios)
Review comment: If I understood the code correctly, is …
Reply: You're right. Thanks for spotting this.
      elif self.backward_pass == 'min':
        # Extreme case alpha -> +inf: keep only the sample with the
        # smallest unnormalised importance weight.
        log_ratios = tf.stack(log_ratios)
        log_ratios = tf.reduce_min(log_ratios, 0)
        loss = tf.reduce_mean(log_ratios)
      elif np.abs(self.alpha - 1.0) < 10e-3:
        # alpha ~ 1 recovers the standard ELBO.
        loss = tf.reduce_mean(log_ratios)
      else:
        # Renyi bound: (1 / (1 - alpha)) * log E_q[(p / q)^(1 - alpha)],
        # estimated with a max-shifted log-mean-exp for numerical stability.
        log_ratios = tf.stack(log_ratios)
        log_ratios = log_ratios * (1 - self.alpha)
        log_ratios_max = tf.reduce_max(log_ratios, 0)
        log_ratios = tf.log(
            tf.maximum(1e-9,
                       tf.reduce_mean(tf.exp(log_ratios - log_ratios_max), 0)))
        log_ratios = (log_ratios + log_ratios_max) / (1 - self.alpha)
        loss = tf.reduce_mean(log_ratios)

      # Negate to obtain a loss for minimization.
      loss = -loss

      if self.logging:
        p_log_prob = tf.reduce_mean(p_log_prob)
        q_log_prob = tf.reduce_mean(q_log_prob)
        tf.summary.scalar("loss/p_log_prob", p_log_prob,
                          collections=[self._summary_key])
        tf.summary.scalar("loss/q_log_prob", q_log_prob,
                          collections=[self._summary_key])

      grads = tf.gradients(loss, var_list)
      grads_and_vars = list(zip(grads, var_list))
      return loss, grads_and_vars
    else:
      raise NotImplementedError(
          "Variational Renyi inference only works with reparameterizable"
          " models.")
@@ -0,0 +1,98 @@
#!/usr/bin/env python
"""Variational auto-encoder for MNIST data using the Renyi variational
objective [@li2016renyi].

#### Notes
This example is almost identical to examples/vae.py.
Review comment: Sorry for the miscommunication. What I meant was that you can edit …
Reply: Removed vae_renyi.py and modified vae.py instead.
Review comment: Are you using the latest version of Edward? We updated a few details in vae.py so it actually runs better. For example, you should be using the observations library and a generator, which is far more transparent than the … In addition, since … If you have thoughts otherwise, happy to take alternative suggestions.
The only difference is the use of the Renyi objective.
For $\alpha = 1.0$, the Renyi objective is equivalent to the KL objective and
the standard VAE is recovered.

References
----------
http://edwardlib.org/tutorials/decoder
http://edwardlib.org/tutorials/inference-networks
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
from edward.inferences.renyi_divergence import RenyiDivergence
import numpy as np
import os
import tensorflow as tf

from edward.models import Bernoulli, Normal
from edward.util import Progbar
from keras.layers import Dense
from scipy.misc import imsave
from tensorflow.examples.tutorials.mnist import input_data

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

ed.set_seed(42)

M = 100  # batch size during training
d = 2  # latent dimension
alpha = 0.5  # alpha value for the Renyi divergence
n_samples = 5  # number of samples used to estimate the Renyi ELBO
backward_pass = 'max'  # backward pass mode ('min', 'max' or 'full')

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR)

# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
# size M.
z = Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d]))
hidden = Dense(256, activation='relu')(z.value())
x = Bernoulli(logits=Dense(28 * 28)(hidden))

# INFERENCE
# Define a subgraph of the variational model, corresponding to a
# minibatch of size M.
x_ph = tf.placeholder(tf.int32, [M, 28 * 28])
hidden = Dense(256, activation='relu')(tf.cast(x_ph, tf.float32))
qz = Normal(loc=Dense(d)(hidden),
            scale=Dense(d, activation='softplus')(hidden))

# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
inference = RenyiDivergence({z: qz}, data={x: x_ph})
optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
inference.initialize(optimizer=optimizer,
                     n_samples=n_samples,
                     alpha=alpha,
                     backward_pass=backward_pass)
sess = ed.get_session()
tf.global_variables_initializer().run()

n_epoch = 100
n_iter_per_epoch = 1000
for epoch in range(n_epoch):
  avg_loss = 0.0

  pbar = Progbar(n_iter_per_epoch)
  for t in range(1, n_iter_per_epoch + 1):
    pbar.update(t)
    x_train, _ = mnist.train.next_batch(M)
    x_train = np.random.binomial(1, x_train)
    info_dict = inference.update(feed_dict={x_ph: x_train})
    avg_loss += info_dict['loss']

  # Print a lower bound to the average marginal likelihood for an
  # image.
  avg_loss = avg_loss / n_iter_per_epoch
  avg_loss = avg_loss / M
  print("log p(x) >= {:0.3f}".format(avg_loss))

  # Prior predictive check.
  imgs = sess.run(x)
  for m in range(M):
    imsave(os.path.join(IMG_DIR, '%d.png') % m, imgs[m].reshape(28, 28))
Review comment: As convention, we alphabetize the ordering of the import libraries.
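As a complement to the MNIST example, here is a minimal usage sketch of the new inference class on a toy Normal-Normal model. It assumes Edward 1.x with the RenyiDivergence class from this pull request; the data values and hyperparameter settings are arbitrary placeholders, not taken from the PR.

```python
import edward as ed
import numpy as np
import tensorflow as tf

from edward.inferences.renyi_divergence import RenyiDivergence
from edward.models import Normal

ed.set_seed(42)

# Toy data (arbitrary values).
x_data = np.array([0.2, -0.4, 0.1, 0.7, -0.3], dtype=np.float32)

# Model: mu ~ Normal(0, 1); x_n ~ Normal(mu, 1) for n = 1, ..., 5.
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=tf.ones(5) * mu, scale=1.0)

# Fully reparameterizable variational family q(mu).
qmu = Normal(loc=tf.Variable(0.0),
             scale=tf.nn.softplus(tf.Variable(0.0)))

# Run Renyi variational inference; n_samples, alpha, and backward_pass
# are forwarded to RenyiDivergence.initialize().
inference = RenyiDivergence({mu: qmu}, data={x: x_data})
inference.run(n_iter=500, n_samples=10, alpha=0.5, backward_pass='full')
```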