Add aggregate argument for more flexible loss (#38)
* add aggregate function

* format

* set aggregate to True by default
fehiepsi authored May 8, 2024
1 parent fb84070 commit dd17d17
Showing 1 changed file: coix/loss.py (70 additions, 14 deletions)
@@ -29,7 +29,13 @@
 ]


-def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def apg_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """RWS objective that exploits conditional dependency."""
   del incoming_log_weight, incremental_log_weight
   p_log_probs = {
@@ -87,35 +93,59 @@ def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   surrogate_loss = target_lp + forward_lp
   log_weight = target_lp + reverse_lp - (forward_lp + proposal_lp)
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def avo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def avo_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Annealed Variational Objective."""
   del q_trace, p_trace
   surrogate_loss = incremental_log_weight
   if jnp.ndim(incoming_log_weight) > 0:
     w1 = 1.0 / incoming_log_weight.shape[0]
   else:
     w1 = 1.0
-  loss = -(w1 * surrogate_loss).sum()
+  loss = -(w1 * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def elbo_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Evidence Lower Bound objective."""
   del q_trace, p_trace
   surrogate_loss = incremental_log_weight
   if jnp.ndim(incoming_log_weight) > 0:
     w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
   else:
     w1 = 1.0
-  loss = -(w1 * surrogate_loss).sum()
+  loss = -(w1 * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def fkl_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Forward KL objective. Here we do not optimize p."""
   del p_trace
   batch_ndims = incoming_log_weight.ndim
@@ -144,11 +174,19 @@ def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss - w1 * proposal_lp).sum()
+  loss = -(w * surrogate_loss - w1 * proposal_lp)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def iwae_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def iwae_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Importance Weighted Autoencoder objective."""
   del q_trace, p_trace
   log_weight = incoming_log_weight + incremental_log_weight
@@ -157,11 +195,19 @@ def iwae_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
     w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
   else:
     w = 1.0
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def rkl_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Reverse KL objective."""
   batch_ndims = incoming_log_weight.ndim
   p_log_probs = {
@@ -195,11 +241,19 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   )
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w1 * surrogate_loss - w * target_lp).sum()
+  loss = -(w1 * surrogate_loss - w * target_lp)
+  if aggregate:
+    loss = loss.sum()
   return loss


-def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def rws_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Reweighted Wake-Sleep objective."""
   batch_ndims = incoming_log_weight.ndim
   p_log_probs = {
@@ -234,5 +288,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   surrogate_loss = (target_lp - proposal_lp) + forward_lp
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
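
The change is the same in every objective: the weighted surrogate loss is kept elementwise, and .sum() is applied only when aggregate=True, the default, which preserves the previous behavior. Below is a minimal usage sketch, not part of the commit: the particle count and random log weights are made up, and it leans on the fact (visible in the diff's "del q_trace, p_trace") that iwae_loss ignores its trace arguments.

import jax
import jax.numpy as jnp

from coix.loss import iwae_loss

num_particles = 4
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
incoming = jax.random.normal(key1, (num_particles,))     # log weights so far
incremental = jax.random.normal(key2, (num_particles,))  # this step's log weights

# Default (aggregate=True): a scalar loss, the behavior before this commit.
scalar_loss = iwae_loss(None, None, incoming, incremental)

# aggregate=False: one weighted surrogate term per particle, left for the
# caller to inspect or reduce with a custom scheme.
per_particle = iwae_loss(None, None, incoming, incremental, aggregate=False)
assert per_particle.shape == (num_particles,)

# Summing the elementwise loss recovers the aggregated one by construction.
assert jnp.allclose(per_particle.sum(), scalar_loss)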
