-
Notifications
You must be signed in to change notification settings - Fork 758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Renyi divergence #769
base: master
Are you sure you want to change the base?
Renyi divergence #769
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have to go catch a flight but some preliminary comments:
.gitignore
Outdated
@@ -100,3 +100,9 @@ docs/*.html | |||
# IDE related | |||
.idea/ | |||
.vscode/ | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove changes that aren't relevant for this PR? This includes changes to .gitignore
here as well as deletion of CSVs.
from edward.util import copy | ||
|
||
try: | ||
from edward.models import Normal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As convention, we use 2-space indent.
from __future__ import print_function | ||
|
||
import six | ||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As convention, we alphabetize the ordering of the import libraries.
"{0}. Your TensorFlow version is not supported.".format(e)) | ||
|
||
|
||
class Renyi_divergence(VariationalInference): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As convention, we use CamelCase for class names.
To perform the optimization, this class uses the techniques from | ||
Renyi Divergence Variational Inference (Y. Li & al, 2016) | ||
|
||
# Notes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstrings are parsed as Markdown and formatted in a somewhat specific way as they appear on the API docs. I recommend following the other classes, where you would denote a subsection as #### Notes
and when writing bullet points, do, e.g.,
#### Notes
+ bullet 1
+ bullet 2
+ maybe bulleted list in a bullet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Some comments below. The code looks correct and only minor suggestions with respect to formatting are laid out.
Can you include a unit test? See, e.g., how KLpq
is tested under the file tests/inferences/test_klpq.py
.
$ \text{D}_{R}^{(\alpha)}(q(z)||p(z \mid x)) | ||
= \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz $ | ||
|
||
To perform the optimization, this class uses the techniques from |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Periods at end of sentences. (If you'd look at the generated API for the class, I recommend compiling the website following instructions from docs/
.)
= \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz $ | ||
|
||
To perform the optimization, this class uses the techniques from | ||
Renyi Divergence Variational Inference (Y. Li & al, 2016) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use bibtex for handling references in docstrings. This is handled by adding the appropriate bib entry to docs/tex/bib.bib
; make sure it's also written in the right order: we sort bib entries by their year, then alphabetically according to their citekey within each year.
When using references, you can produce (Li et al., 2016)
and Li et al. (2016)
by writing [@li2016renyi]
and @li2016renyirespectively, assuming that
li2016renyi` is the citekey.
|
||
# Notes: | ||
- Renyi divergence does not have any analytic version. | ||
- Renyi divergence does not have any version for non reparametrizable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does but the gradient estimator in @li2016variational
doesn't. I recommend just stating that this inference algorithm is restricted to variational approximations whose random variables all satisfy rv.reparameterization_type == tf.contrib.distributions.FULLY_REPARAMETERIZED
.
Also, instead of checking this during build_loss_and_gradients
I recommend checking this during the __init__
. This sort of check is done statically any graph construction similar to how we check for compatible shapes in all latent variables and data during __init__
.
|
||
def initialize(self, | ||
n_samples=32, | ||
alpha=1., |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As convention, we append all numerics with 0, e.g., 1.0
.
Number of samples from variational model for calculating | ||
stochastic gradients. | ||
alpha: float, optional. | ||
Renyi divergence coefficient. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be useful to specify the domain of the coefficient. E.g., Must be greater than 0.
or etc.
"Variational Renyi inference only works with reparameterizable" | ||
" models") | ||
|
||
######### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is only used in one location and is a one-liner; could you write that line instead of defining a new function?
examples/vae_renyi.py
Outdated
scale=Dense(d, activation='softplus')(hidden)) | ||
|
||
# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x. | ||
inference = Renyi_divergence({z: qz}, data={x: x_ph}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code looks exactly the same as an older version of vae.py
but only differs in this line. To keep the VAE versions better synced, could you add a comment suggesting that this is also an alternative in the existing vae.py
?
Ideally, we'd like a specific application where ed.RenyiDivergence
produces better results by some metric than alternatives. IIRC, the paper had some interesting results for a Bayesian neural net on some specific UCI data sets. That would be great to have and reproduce some of their results.
If you don't have time for this, we can leave it off for now and raise it as a Github issue post-merging this PR.
self.scale.get(x, 1.0) | ||
* x_copy.log_prob(dict_swap[x])) | ||
|
||
logF = [p - q for p, q in zip(p_log_prob, q_log_prob)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of logF
, what about something like log_ratios
, which is more Pythonic in snake_case and also more semantically meaningful?
Replaced LogF by log_ratios Fix convention errors
Thanks for the suggestion and the very informative feedback.
Will do later today. |
I've added some testing in a similar way as KLqp. (both normal_normal and the bernouilli distribution. |
import tensorflow as tf | ||
|
||
from edward.models import Bernoulli, Normal | ||
from edward.inferences.renyi_divergence import RenyiDivergence |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should check the import works by instead using ed.RenyiDivergence
in the test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got some issue with this but it should be working now.
[@li2016renyi]. | ||
|
||
#### Notes | ||
+ The gradient estimator used here does not have any analytic version. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With Markdown formatting, you don't need the 4 spaces of indentation. E.g., you can just do
#### Notes
+ The gradient estimator ...
+ ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
= \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz.$ | ||
|
||
The optimization is performed using the gradient estimator as defined in | ||
[@li2016renyi]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The citekey is being used as a direct object so it should be [@li2016renyi]
-> @li2016renyi
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
+ See Renyi Divergence Variational Inference [@li2016renyi] for | ||
more details. | ||
""" | ||
if self.is_reparameterizable: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_reparameterizable
should be checked with the possible raising error during the __init__
, and since it's checked there it doesn't need to be stored in the class. This also helps to remove one layer of indentation in this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if self.backward_pass == 'max': | ||
log_ratios = tf.stack(log_ratios) | ||
log_ratios = tf.reduce_max(log_ratios, 0) | ||
loss = tf.reduce_mean(log_ratios) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understood the code correctly, log_ratios
when first created is a list of n_samples
elements, where each element is a log ratio calculation per sample from q. For the min / max modes, we take the min / max of these log ratios, which is a scalar.
Is tf.reduce_mean
for the loss needed? You can also remove the tf.stack
line in the min and max cases in the same way you didn't use it for the self.alpha \approx 1
case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right. Thanks for spotting this.
examples/vae_renyi.py
Outdated
[@li2016renyi] | ||
|
||
#### Notes | ||
This example is almost exactly similar to example/vae.py. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the miscommunication. What I meant was that you can edit vae.py
, comment out the 1-2 lines of code to use ed.RenyiDivergence
, and add these notes there. This helps to compress the content in the examples, c.f., https://github.com/blei-lab/edward/blob/master/examples/bayesian_logistic_regression.py#L51.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed vae_renyi.py and modifed vae.py instead.
The version of vae.py I had wasn't running though. So I've modified it quite a bit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you using the latest version of Edward? We updated a few details in vae.py so it actually runs better. For example, you should be using the observations library and a generator, which is far more transparent than the mnist_data
class from that TensorFlow tutorial.
In addition, since vae.py
is also our canonical VAE example, I prefer keeping it as ed.KLqp
as the default, and with the renyi divergence option commented out; similarly, the top-level comments should be written in-line near the renyi divergence option instead.
If you have thoughts otherwise, happy to take alternative suggestions.
|
||
class test_renyi_divergence_class(tf.test.TestCase): | ||
|
||
def _test_normal_normal(self, Inference, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since RenyiDivergence
is used across all tests, you don't need Inference
as an arg to the test functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've used the same template as test_klpq
where only KLpq
is used during the tests and where Inference
is stil an argument to the test functions.
But I think I have modified to be closer to what you had in mind
Merged example to VAE Simplify tests
Modify examples and test calls
It keeps failing the travis-ci check for python 2.7 but before getting into the proper testing of my code (fail to install matplotlib and seaborn). |
Looks like this is happening in Travis on any build. I'll look into it. |
Here is an implementation of the Renyi divergence variational inference.
There's also an example on VAEs.
Here is a link to the edward forum with some more info:
https://discourse.edwardlib.org/t/renyi-divergence-variational-inference/366/3
ps: Sorry for the quite messy commit history.