[WIP] Autonormal encoder #2849
base: dev
Changes from all commits: 92e80b0, 6a3261f, 823dce2, 065acfa, df4f012, 6ec42bc
@@ -24,6 +24,7 @@ def model():
from types import SimpleNamespace
from typing import Callable, Dict, Union

import numpy as np
import torch
from torch import nn
from torch.distributions import biject_to

@@ -32,7 +33,11 @@ def model():
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import affine_autoregressive, iterated
from pyro.distributions.transforms import (
    SoftplusTransform,
    affine_autoregressive,
    iterated,
)
from pyro.distributions.util import eye_like, is_identically_zero, sum_rightmost
from pyro.infer.autoguide.initialization import (
    InitMessenger,
@@ -269,6 +274,25 @@ def median(self, *args, **kwargs):
            result.update(part.median(*args, **kwargs))
        return result

    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns the posterior quantile values of each latent variable.

        Parameters
        ----------
        quantiles
            A list of requested quantiles between 0 and 1.

        Returns
        -------
        A dict mapping sample site name to quantiles tensor.
        """

        result = {}
        for part in self:
            result.update(part.quantiles(quantiles, *args, **kwargs))
        return result


class AutoCallable(AutoGuide):
    """

@@ -1482,3 +1506,350 @@ def median(self, *args, **kwargs):
            loc = loc.reshape(shape)
            result[name] = biject_to(site["fn"].support)(loc)
        return result


class AutoNormalEncoder(AutoGuide):
    """
    AutoNormal posterior approximation for amortised inference.

    For every random variable specified by the user, this guide computes the
    variational parameters (locs/scales) as

        loc = NN(data) @ hidden2loc
        scales = softplus(NN(data) @ hidden2scales - 2),

    where NN is the encoder network and hidden2loc/hidden2scales are tensors
    that convert hidden-layer activations into locs/scales.

    The class supports a single encoder NN shared by all variables as well as
    one encoder NN per variable. The output of the encoder network is treated
    as a hidden layer: mean and sd are linear functions of the hidden-layer
    nodes, and sd is mapped to a positive scale using softplus. Input data is
    transformed with `data_transform` before being passed to the encoder.

    This class requires an `amortised_plate_sites` dictionary describing the
    amortised variables (see below).

    The guide has the same call signature as the model, so any argument to the
    model can be used for encoding, as annotated in `amortised_plate_sites`;
    it does not have to be the observed data of the model.
    """

    def __init__(
        self,
        model,
        amortised_plate_sites: dict,
        n_in: int,
        n_hidden: int = 200,
        init_param=0,
        init_param_scale: float = 1 / 50,
        data_transform=lambda x: x,
        encoder_class=None,  # TODO this needs to be defined somewhere (e.g. simpler version of scvi.nn.FCLayers)
        encoder_kwargs=None,
        create_plates=None,
        single_encoder: bool = True,
    ):
        """
        Parameters
        ----------
        model
            Pyro model
        amortised_plate_sites
            Dictionary with amortised plate details:
            the name of the observation/minibatch plate,
            indexes of model args to provide to the encoder,
            variable names that belong to the observation plate,
            and the number of dimensions in the non-plate axis of each variable - for example::

                {
                    "name": "obs_plate",
                    "in": [0],  # expression data + (optional) batch index ([0, 2])
                    "sites": {
                        "n_s_cells_per_location": 1,
                        "y_s_groups_per_location": 1,
                        "z_sr_groups_factors": model.n_groups,
                        "w_sf": model.n_factors,
                        "l_s_add": 1,
                    }
                }
        n_in
            Number of input dimensions (for encoder_class).
        n_hidden
            Number of hidden nodes in each layer, including the final layer.
        init_param
            Not implemented yet - initial values for amortised variables.
        init_param_scale
            How to scale/normalise the initial values of the weights that convert the hidden layer to mean and sd.
        data_transform
            Function used to transform data before passing it to the encoder network.
        encoder_class
            Class defining the encoder network.
        encoder_kwargs
            Keyword arguments for encoder_class.
        create_plates
            Function for creating plates.
        single_encoder
            Use a single encoder for all variables (True) or one encoder per variable (False).
        """

        super().__init__(model, create_plates=create_plates)
        self.amortised_plate_sites = amortised_plate_sites
        self.single_encoder = single_encoder

        self.softplus = SoftplusTransform()

        encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
        encoder_kwargs["n_hidden"] = n_hidden
        self.encoder_kwargs = encoder_kwargs

        self.n_in = n_in
        self.n_out = (
            np.sum(
                [
                    np.sum(amortised_plate_sites["sites"][k])
                    for k in amortised_plate_sites["sites"].keys()
                ]
            )
            * 2
        )
        self.n_hidden = n_hidden
        self.encoder_class = encoder_class
        if self.single_encoder:
            # create a single encoder NN
            self.encoder = encoder_class(
                n_in=self.n_in, n_out=self.n_hidden, **self.encoder_kwargs
            )

        self.init_param_scale = init_param_scale
        self.data_transform = data_transform

    def _setup_prototype(self, *args, **kwargs):

        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.hidden2locs = PyroModule()
        self.hidden2scales = PyroModule()

        if not self.single_encoder:
            # create module class for collecting multiple encoder NN
            self.encoder = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = (
                    biject_to(site["fn"].support).inv(site["value"].detach()).detach()
                )
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # add linear layer for locs and scales
            param_dim = (self.n_hidden, self.amortised_plate_sites["sites"][name])
            init_param = torch.normal(
                torch.full(size=param_dim, fill_value=0.0, device=site["value"].device),
                torch.full(
                    size=param_dim,
                    fill_value=(1 * self.init_param_scale) / np.sqrt(self.n_hidden),
                    device=site["value"].device,
                ),
            )
Comment on lines +1649 to +1656:
If I use torch.normal rather than numpy.random.normal, I get this warning: […]
I also get different results after training the model.

Numpy alternative:
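The snippet attached to the "Numpy alternative" note is collapsed in the rendered page; a rough sketch, assuming the intent is to replace the torch.normal call above while keeping the same shape and scale, could be:

# hypothetical NumPy-based alternative to the torch.normal initialisation above
# (zero mean, sd scaled by init_param_scale / sqrt(n_hidden), same param_dim shape)
init_param = np.random.normal(
    np.zeros(param_dim),
    (np.ones(param_dim) * self.init_param_scale) / np.sqrt(self.n_hidden),
).astype("float32")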
            _deep_setattr(
                self.hidden2locs,
                name,
                PyroParam(
                    torch.tensor(
Review comment: I believe that

    PyroParam(
-       torch.tensor(
+       torch.as_tensor(
            init_param, ...
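For context (a hypothetical snippet, not part of the diff): torch.as_tensor reuses the existing tensor's storage when dtype and device already match, whereas torch.tensor always copies and emits a UserWarning when given a tensor:

import torch

x = torch.zeros(3)
print(torch.as_tensor(x) is x)                     # True: same tensor, no copy
print(torch.tensor(x).data_ptr() == x.data_ptr())  # False: copied (and a UserWarning)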
                        init_param, device=site["value"].device, requires_grad=True
                    )
                ),
            )

            init_param = torch.normal(
                torch.full(size=param_dim, fill_value=0.0, device=site["value"].device),
                torch.full(
                    size=param_dim,
                    fill_value=(1 * self.init_param_scale) / np.sqrt(self.n_hidden),
                    device=site["value"].device,
                ),
            )
            _deep_setattr(
                self.hidden2scales,
                name,
                PyroParam(
                    torch.tensor(
Review comment: ditto (torch.as_tensor here as well).
                        init_param, device=site["value"].device, requires_grad=True
                    )
                ),
            )

            if not self.single_encoder:
                _deep_setattr(
                    self.encoder,
                    name,
                    self.encoder_class(
                        n_in=self.n_in, n_out=self.n_hidden, **self.encoder_kwargs
                    ).to(site["value"].device),
                )

    def _get_loc_and_scale(self, name, encoded_hidden):
        """
        Get the mean (loc) and sd (scale) of the posterior distribution as a linear function of the encoder hidden layer.

        Parameters
        ----------
        name
            Variable name.
        encoded_hidden
            A tensor when `single_encoder=True`, or a dictionary of tensors (one per site) when `single_encoder=False`.

        Returns
        -------
        Tuple of locs and scales for this site.
        """
        linear_locs = _deep_getattr(self.hidden2locs, name)
        linear_scales = _deep_getattr(self.hidden2scales, name)

        if not self.single_encoder:
            # when using multiple encoders extract hidden layer for this parameter
            encoded_hidden = encoded_hidden[name]

        locs = encoded_hidden @ linear_locs
        scales = self.softplus((encoded_hidden @ linear_scales) - 2)

        return locs, scales

    def encode(self, *args, **kwargs):
        """
        Apply the encoder network to the input data to obtain the hidden-layer encoding.

        Parameters
        ----------
        args
            Pyro model args
        kwargs
            Pyro model kwargs

        Returns
        -------
        Encoded hidden layer: a tensor (single encoder) or a dict of tensors keyed by site name (per-site encoders).
        """
        in_names = self.amortised_plate_sites["in"]
        x_in = [kwargs[i] if i in kwargs.keys() else args[i] for i in in_names]
        # apply data_transform
        x_in = [self.data_transform(x) for x in x_in]
        # when there are multiple encoders fetch encoders and encode data
        if not self.single_encoder:
            res = {
                name: _deep_getattr(self.encoder, name)(*x_in)
                for name, site in self.prototype_trace.iter_stochastic_nodes()
            }
        else:
            # encode with a single encoder
            res = self.encoder(*x_in)
        return res

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        encoded_hidden = self.encode(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name, encoded_hidden)
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                if pyro.poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result

    @torch.no_grad()
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        encoded_latent = self.encode(*args, **kwargs)

        medians = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, _ = self._get_loc_and_scale(name, encoded_latent)
            median = biject_to(site["fn"].support)(site_loc)
            if median is site_loc:
                median = median.clone()
            medians[name] = median

        return medians

    @torch.no_grad()
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles of each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a list of quantile values.
        :rtype: dict
        """
        encoded_latent = self.encode(*args, **kwargs)

        results = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, site_scale = self._get_loc_and_scale(name, encoded_latent)

            site_quantiles = torch.tensor(
                quantiles, dtype=site_loc.dtype, device=site_loc.device
            )
            site_quantiles_values = dist.Normal(site_loc, site_scale).icdf(
                site_quantiles
            )
            constrained_site_quantiles = biject_to(site["fn"].support)(
                site_quantiles_values
            )
            results[name] = constrained_site_quantiles

        return results
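For orientation, a minimal, hypothetical usage sketch of the new guide; ToyEncoder, the model, the site name "w" and all sizes are invented for illustration, and the import assumes this class lands in pyro.infer.autoguide:

import torch
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalEncoder  # assumes this PR is merged
from pyro.optim import Adam


class ToyEncoder(nn.Module):
    # minimal encoder accepting the kwargs AutoNormalEncoder passes: n_in, n_out, n_hidden
    def __init__(self, n_in, n_out, n_hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out)
        )

    def forward(self, x):
        return self.net(x)


n_features = 5


def model(data):
    with pyro.plate("obs_plate", data.shape[0]):
        # per-observation latent with n_features dims on the non-plate axis
        w = pyro.sample("w", dist.Normal(0.0, 1.0).expand([n_features]).to_event(1))
        pyro.sample("obs", dist.Normal(w, 1.0).to_event(1), obs=data)


guide = AutoNormalEncoder(
    model,
    amortised_plate_sites={"name": "obs_plate", "in": [0], "sites": {"w": n_features}},
    n_in=n_features,
    n_hidden=64,
    encoder_class=ToyEncoder,
)

svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
data = torch.randn(100, n_features)
for _ in range(200):
    svi.step(data)

print(guide.median(data)["w"].shape)           # torch.Size([100, 5])
print(guide.quantiles([0.5], data)["w"].shape) # amortised posterior quantiles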
Review comment: One of the main assumptions at the moment is that encoded variables are 2D tensors of shape (plate subsample_size, i.e. batch size, self.amortised_plate_sites["sites"][name]). I guess the shape could be inferred automatically, but I have not thought that through and do not have applications where variables are more than 2D.
Review comment: I think it's more important to get an initial version merged quickly than to make something fully general right away. So WDYT about simply adding assertions or NotImplementedError("not yet supported") checks for your current assumptions? Also feel free to start the class docstring with EXPERIMENTAL and add a ".. warning:: Interface may change" to give you/us room to slightly change the interface later, in case the fully general version needs slight changes.
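A rough sketch of how the suggested guard and docstring warning could look (illustrative only; exact placement and wording would be up to the author):

class AutoNormalEncoder(AutoGuide):
    """
    EXPERIMENTAL AutoNormal posterior approximation for amortised inference.

    .. warning:: Interface may change.
    """

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # guard the current assumption that encoded variables are 2D:
            # (plate/batch dimension, amortised_plate_sites["sites"][name])
            if site["value"].dim() != 2:
                raise NotImplementedError(
                    f"site '{name}': variables with more than 2 dims are not yet supported"
                )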