Skip to content

Commit

Permalink
set aggregate to True by default
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed May 8, 2024
1 parent 77da67c commit bcc736a
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def apg_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""RWS objective that exploits conditional dependency."""
del incoming_log_weight, incremental_log_weight
Expand Down Expand Up @@ -104,7 +104,7 @@ def avo_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Annealed Variational Objective."""
del q_trace, p_trace
Expand All @@ -124,7 +124,7 @@ def elbo_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Evidence Lower Bound objective."""
del q_trace, p_trace
Expand All @@ -144,7 +144,7 @@ def fkl_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Forward KL objective. Here we do not optimize p."""
del p_trace
Expand Down Expand Up @@ -185,7 +185,7 @@ def iwae_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Importance Weighted Autoencoder objective."""
del q_trace, p_trace
Expand All @@ -206,7 +206,7 @@ def rkl_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Reverse KL objective."""
batch_ndims = incoming_log_weight.ndim
Expand Down Expand Up @@ -252,7 +252,7 @@ def rws_loss(
p_trace,
incoming_log_weight,
incremental_log_weight,
aggregate=False,
aggregate=True,
):
"""Reweighted Wake-Sleep objective."""
batch_ndims = incoming_log_weight.ndim
Expand Down

0 comments on commit bcc736a

Please sign in to comment.