diff --git a/README.md b/README.md index 8b4857d..1e82492 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,9 @@ # coix [![Unittests](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml) -[![Documentation Status](https://readthedocs.org/projects/coix/badge/?version=latest)](https://coix.readthedocs.io/en/latest/?badge=latest) [![PyPI version](https://badge.fury.io/py/coix.svg)](https://badge.fury.io/py/coix) -Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box. +Inference Combinators in JAX (Coix) is a machine learning framework used to +develop inference algorithms that are composed of probabilistic programs. *This is not an officially supported Google product.* - diff --git a/coix/__init__.py b/coix/__init__.py index d46c565..69cefab 100644 --- a/coix/__init__.py +++ b/coix/__init__.py @@ -1,17 +1,3 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """coix API.""" from coix import algo diff --git a/coix/algo.py b/coix/algo.py index f600fbe..45ce271 100644 --- a/coix/algo.py +++ b/coix/algo.py @@ -1,17 +1,3 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Inference algorithms.""" import functools @@ -144,7 +130,6 @@ def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None): if _use_fori_loop(targets, num_targets): def body_fun(i, q): - assert callable(targets) p = extend(compose(momentum, targets(i), suffix=False), refreshment) return propose(p, compose(refreshment, compose(leapfrog, q))) @@ -156,7 +141,7 @@ def body_fun(i, q): targets = [compose(momentum, p, suffix=False) for p in targets] q = targets[0] - loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,) + loss_fns = [None] * (len(targets) - 2) + [iwae_loss] for p, loss_fn in zip(targets[1:], loss_fns): q = compose(refreshment, compose(leapfrog, q)) q = propose(extend(p, refreshment), q, loss_fn=loss_fn) @@ -414,7 +399,7 @@ def body_fun(i, q): return propose(targets(num_targets - 1), q, loss_fn=iwae_loss) q = propose(targets[0], proposals[0]) - loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,) + loss_fns = [None] * (len(proposals) - 2) + [iwae_loss] for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns): q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn) return q diff --git a/coix/algo_test.py b/coix/algo_test.py deleted file mode 100644 index 4cd81fc..0000000 --- a/coix/algo_test.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for algo.py.""" - -import functools - -import coix -import jax -from jax import random -import jax.numpy as jnp -import numpy as np -import numpyro.distributions as dist -import optax - -coix.set_backend("coix.oryx") - -np.random.seed(0) -num_data, dim = 4, 2 -data = np.random.randn(num_data, dim).astype(np.float32) -loc_p = np.random.randn(dim).astype(np.float32) -precision_p = np.random.rand(dim).astype(np.float32) -scale_p = np.sqrt(1 / precision_p) -precision_x = np.random.rand(dim).astype(np.float32) -scale_x = np.sqrt(1 / precision_x) -precision_q = precision_p + num_data * precision_x -loc_q = (data.sum(0) * precision_x + loc_p * precision_p) / precision_q -log_scale_q = -0.5 * np.log(precision_q) - - -def model(params, key): - del params - key_z, key_next = random.split(key) - z = coix.rv(dist.Normal(loc_p, scale_p), name="z")(key_z) - z = jnp.broadcast_to(z, (num_data, dim)) - x = coix.rv(dist.Normal(z, scale_x), obs=data, name="x") - return key_next, z, x - - -def guide(params, key, *args): - del args - key, _ = random.split(key) # split here to test tie_in - scale_q = jnp.exp(params["log_scale_q"]) - z = coix.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key) - return z - - -def check_ess(make_program): - params = {"loc_q": loc_q, "log_scale_q": log_scale_q} - p = jax.vmap(functools.partial(model, params)) - q = jax.vmap(functools.partial(guide, params)) - program = make_program(p, q) - - keys = random.split(random.PRNGKey(0), 5) - ess = coix.traced_evaluate(program)(keys)[2]["ess"] - np.testing.assert_allclose(ess, 5.0) - - -def run_inference(make_program, num_steps=1000): - """Performs inference given an algorithm `make_program`.""" - - def loss_fn(params, key): - p = jax.vmap(functools.partial(model, params)) - q = jax.vmap(functools.partial(guide, params)) - program = make_program(p, q) - - keys = random.split(key, 5) - metrics = coix.traced_evaluate(program)(keys)[2] - return metrics["loss"], metrics - - init_params = { - "loc_q": jnp.zeros_like(loc_q), - "log_scale_q": jnp.zeros_like(log_scale_q), - } - params, _ = coix.util.train( - loss_fn, init_params, optax.adam(0.01), num_steps=num_steps - ) - - np.testing.assert_allclose(params["loc_q"], loc_q, atol=0.2) - np.testing.assert_allclose(params["log_scale_q"], log_scale_q, atol=0.2) - - -def test_apgs(): - check_ess(lambda p, q: coix.algo.apgs(p, [q])) - run_inference(lambda p, q: coix.algo.apgs(p, [q])) - - -def test_rws(): - check_ess(coix.algo.rws) - run_inference(coix.algo.rws) - - -def test_svi_elbo(): - check_ess(coix.algo.svi) - run_inference(coix.algo.svi) - - -def test_svi_iwae(): - check_ess(coix.algo.svi_iwae) - run_inference(coix.algo.svi_iwae) - - -def test_svi_stl(): - check_ess(coix.algo.svi_stl) - run_inference(coix.algo.svi_stl) diff --git a/coix/api.py b/coix/api.py index f6f55f5..c4d6101 100644 --- a/coix/api.py +++ b/coix/api.py @@ -1,22 +1,7 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Program combinators. The implement is pretty much backend-agnostic. We just assume that the core backend supports the following functionality: - + `suffix(p)`: rename latent variables of the program `p`, + `traced_evaluate(p, latents=None)`: execute `p` and collect trace, metrics, optionally we can substitute values in `latents` to `p`, @@ -32,23 +17,6 @@ import jax.numpy as jnp import numpy as np -# pytype: disable=module-attr -try: - wrap_key_data = jax.random.wrap_key_data -except AttributeError: - try: - wrap_key_data = jax.extend.random.wrap_key_data - except AttributeError: - - def _identity(k): - return k - - wrap_key_data = _identity - - -# pytype: enable=module-attr - - __all__ = [ "compose", "extend", @@ -59,14 +27,14 @@ def _identity(k): def compose(q2, q1, suffix=True): - r"""Executes q2(\*q1(...)). + """Executes q2(*q1(...)). Note: We only allow at most one of `q1` or `q2` is weighted. Args: q2: a program q1: a program - suffix: whether to add suffix `\_PREV\_` to variables in `q1` + suffix: whether to add suffix `_PREV_` to variables in `q1` Returns: q: the composed program @@ -80,7 +48,7 @@ def wrapped(*args, **kwargs): def extend(p, f): - r"""Executes f(\*p(...)) with random variables in f marked as auxiliary. + """Executes f(*p(...)) with random variables in f marked as auxiliary. Note: We don't allow recursively marginalize out `p` yet. @@ -100,20 +68,41 @@ def wrapped(*args, **kwargs): return wrapped -def _reshape_key(key, shape): - if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): - return jnp.reshape(key, shape) - else: - return jnp.reshape(key, shape + (2,)) +def _get_batch_ndims(log_probs): + if not log_probs: + return 0 + min_ndim = min(jnp.ndim(lp) for lp in log_probs) + batch_ndims = 0 + for i in range(min_ndim): + if len(set(jnp.shape(lp)[i] for lp in log_probs)) > 1: + break + batch_ndims = batch_ndims + 1 + return batch_ndims + + +def _get_log_weight(trace, batch_ndims): + """Computes log weight of the trace and keeps its batch dimensions.""" + log_weight = jnp.zeros((1,) * batch_ndims) + for site in trace.values(): + lp = util.get_site_log_prob(site) + if util.is_observed_site(site): + log_weight = log_weight + jnp.sum( + lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0)) + ) + else: + log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims]) + return log_weight def _split_key(key): - keys = jax.vmap(jax.random.split, out_axes=1)(_reshape_key(key, (-1,))) - return keys[0].reshape(key.shape), keys[1].reshape(key.shape) + keys = jax.vmap(jax.random.split)(key.reshape(-1, 2)).reshape( + key.shape[:-1] + (2, 2) + ) + return keys[..., 0, :], keys[..., 1, :] def _fold_in_key(key, i): - key_new = jax.vmap(jax.random.fold_in, (0, None))(_reshape_key(key, (-1,)), i) + key_new = jax.vmap(jax.random.fold_in, (0, None))(key.reshape(-1, 2), i) return key_new.reshape(key.shape) @@ -161,7 +150,7 @@ def wrapped(*args, **kwargs): name: util.get_site_log_prob(site) for name, site in q_trace.items() } log_probs = list(p_log_probs.values()) + list(q_log_probs.values()) - batch_ndims = util.get_batch_ndims(log_probs) + batch_ndims = _get_batch_ndims(log_probs) if "log_weight" in q_metrics: in_log_weight = q_metrics["log_weight"] @@ -170,7 +159,7 @@ def wrapped(*args, **kwargs): axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), ) else: - in_log_weight = util.get_log_weight(q_trace, batch_ndims) + in_log_weight = _get_log_weight(q_trace, batch_ndims) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() @@ -179,7 +168,7 @@ def wrapped(*args, **kwargs): # Note: We include superfluous variables, whose `name in p_trace`. q_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) - for lp in q_log_probs.values() + for name, lp in q_log_probs.items() ) incremental_log_weight = p_log_weight - q_log_weight log_weight = in_log_weight + incremental_log_weight @@ -232,20 +221,12 @@ def _maybe_get_along_first_axis(x, idx, n, squeeze=False): x = np.array(x) # Special treatment for cascades. if hasattr(x, "value"): - setattr( - x, - "value", - _maybe_get_along_first_axis( - util.get_site_value(x), idx, n, squeeze=squeeze - ), + x.value = _maybe_get_along_first_axis( + util.get_site_value(x), idx, n, squeeze=squeeze ) if hasattr(x, "log_density"): - setattr( - x, - "log_density", - _maybe_get_along_first_axis( - util.get_site_log_prob(x), idx, n, squeeze=squeeze - ), + x.log_density = _maybe_get_along_first_axis( + util.get_site_log_prob(x), idx, n, squeeze=squeeze ) if ( isinstance(x, (np.ndarray, jnp.ndarray)) @@ -255,12 +236,6 @@ def _maybe_get_along_first_axis(x, idx, n, squeeze=False): idx = idx.reshape(idx.shape + (1,) * (x.ndim - idx.ndim)) if isinstance(x, np.ndarray): y = np.take_along_axis(x, idx, axis=0) - elif jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key): - x_data = jax.random.key_data(x) - idx = idx.reshape(idx.shape + (1,) * (x_data.ndim - idx.ndim)) - y_data = jnp.take_along_axis(x_data, idx, axis=0) - y_data = y_data[0] if (idx.shape[0] == 1 and squeeze) else y_data - y = wrap_key_data(y_data) else: y = jnp.take_along_axis(x, idx, axis=0) y = y.tolist() if is_list else y @@ -286,7 +261,7 @@ def fn(*args, **kwargs): if util.can_extract_key(args): key_r, key_q = _split_key(args[0]) # We just need a single key for resampling. - key_r = _reshape_key(key_r, (-1,))[0] + key_r = key_r.reshape((-1, 2)).sum(0) args = (key_q,) + args[1:] else: key_r = core.prng_key() @@ -294,7 +269,7 @@ def fn(*args, **kwargs): log_probs = { name: util.get_site_log_prob(site) for name, site in trace.items() } - batch_ndims = util.get_batch_ndims(log_probs.values()) + batch_ndims = _get_batch_ndims(log_probs.values()) weighted = ("log_weight" in q_metrics) or any( util.is_observed_site(site) for site in trace.values() ) @@ -309,7 +284,7 @@ def fn(*args, **kwargs): axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), ) else: - in_log_weight = util.get_log_weight(trace, batch_ndims) + in_log_weight = _get_log_weight(trace, batch_ndims) n = in_log_weight.shape[0] k = n if num_samples is None else num_samples log_weight = jax.nn.logsumexp(in_log_weight, 0) - jnp.log(k if k else 1) @@ -346,20 +321,15 @@ def _add_missing_metrics(metrics, trace): name: util.get_site_log_prob(site) for name, site in trace.items() } if "log_weight" not in metrics: - batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1) - log_weight = util.get_log_weight(trace, batch_ndims) + batch_ndims = min(_get_batch_ndims(list(log_probs.values())), 1) + log_weight = _get_log_weight(trace, batch_ndims) full_metrics["log_weight"] = log_weight - else: - batch_ndims = metrics["log_weight"].ndim - log_weight = metrics["log_weight"] - # leftmost dimension is particle dimension - if batch_ndims and "ess" not in metrics: - assert "log_Z" not in metrics - ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0) - full_metrics["ess"] = ess.mean() - n = log_weight.shape[0] - log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n) - full_metrics["log_Z"] = log_z.mean() + if batch_ndims: # leftmost dimension is particle dimension + ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0) + full_metrics["ess"] = ess.mean() + n = log_weight.shape[0] + log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n) + full_metrics["log_Z"] = log_z.mean() if "loss" not in metrics: full_metrics["loss"] = jnp.array(0.0) if "log_density" not in metrics: @@ -383,18 +353,17 @@ def fori_loop(lower, upper, body_fun, init_program): """ def fn(*args, **kwargs): - def trace_arg_key(fn, key): - return core.traced_evaluate(fn)(key, *args[1:], **kwargs) - - def trace_with_seed(fn, key): - return core.traced_evaluate(fn, seed=key)(*args, **kwargs) - if util.can_extract_key(args): key = args[0] - trace_fn = trace_arg_key + + def trace_fn(fn, key): + return core.traced_evaluate(fn)(key, *args[1:], **kwargs) + else: key = core.prng_key() - trace_fn = trace_with_seed + + def trace_fn(fn, key): + return core.traced_evaluate(fn, seed=key)(*args, **kwargs) key_body, key_init = _split_key(key) @@ -443,7 +412,7 @@ def memoize(p, q, memory=None, memory_size=None): def wrapped(*args, **kwargs): if util.can_extract_key(args): key = args[0] - p_key, q_key = _split_key(key) + p_key, q_key = key + jnp.asarray([1, 0], dtype=key.dtype), key + 1 p_args = (p_key,) + args[1:] q_args = (q_key,) + args[1:] else: @@ -461,11 +430,11 @@ def wrapped(*args, **kwargs): p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items() } - batch_ndims = util.get_batch_ndims(p_log_probs.values()) + batch_ndims = _get_batch_ndims(p_log_probs.values()) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) - for lp in p_log_probs.values() + for name, lp in p_log_probs.items() ) marginal_trace = { @@ -476,7 +445,6 @@ def wrapped(*args, **kwargs): new_memory = { name: util.get_site_value(site) for name, site in marginal_trace.items() } - assert not isinstance(p_log_weight, int) num_particles = p_log_weight.shape[0] batch_dim = p_log_weight.ndim flat_memory = { diff --git a/coix/api_test.py b/coix/api_test.py deleted file mode 100644 index 765e0c4..0000000 --- a/coix/api_test.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for api.py.""" - -import coix -import jax -from jax import random -import numpy as np -import numpyro.distributions as dist -import pytest - -coix.set_backend("coix.oryx") - - -def test_compose(): - def p(key): - key, subkey = random.split(key) - x = coix.rv(dist.Normal(0, 1), name="x")(subkey) - return key, x - - def f(key, x): - return coix.rv(dist.Normal(x, 1), name="z")(key) - - _, p_trace, _ = coix.traced_evaluate(coix.compose(f, p))(random.PRNGKey(0)) - assert set(p_trace.keys()) == {"x", "z"} - - -def test_extend(): - def p(key): - key, subkey = random.split(key) - x = coix.rv(dist.Normal(0, 1), name="x")(subkey) - return key, x - - def f(key, x): - return (coix.rv(dist.Normal(x, 1), name="z")(key),) - - def g(z): - return z + 1 - - key = random.PRNGKey(0) - out, trace, _ = coix.traced_evaluate(coix.extend(p, f))(key) - assert set(trace.keys()) == {"x", "z"} - - expected_key, expected_x = p(key) - expected_key = random.key_data(expected_key) - actual_key = random.key_data(out[0]) - np.testing.assert_allclose(actual_key, expected_key) - np.testing.assert_allclose(out[1], expected_x) - - marginal_pfg = coix.traced_evaluate(coix.extend(p, coix.compose(g, f)))(key)[ - 0 - ] - actual_key2, actual_x2 = marginal_pfg - actual_key2 = random.key_data(actual_key2) - np.testing.assert_allclose(actual_key2, expected_key) - np.testing.assert_allclose(actual_x2, expected_x) - - -def test_propose(): - def p(key): - key, subkey = random.split(key) - x = coix.rv(dist.Normal(0, 1), name="x")(subkey) - return key, x - - def f(key, x): - return coix.rv(dist.Normal(x, 1), name="z")(key) - - def q(key): - return coix.rv(dist.Normal(1, 2), name="x")(key) - - program = coix.propose(coix.extend(p, f), q) - key = random.PRNGKey(0) - out, trace, metrics = coix.traced_evaluate(program)(key) - assert set(trace.keys()) == {"x", "z"} - assert isinstance(out, tuple) and len(out) == 2 - assert out[0].shape == key.shape - with np.testing.assert_raises(AssertionError): - np.testing.assert_allclose(metrics["log_density"], 0.0) - - particle_program = coix.propose(jax.vmap(coix.extend(p, f)), jax.vmap(q)) - keys = random.split(key, 3) - particle_out = particle_program(keys) - assert isinstance(particle_out, tuple) and len(particle_out) == 2 - assert particle_out[0].shape == keys.shape - - -def test_resample(): - def q(key): - return coix.rv(dist.Normal(1, 2), name="x")(key) - - particle_program = jax.vmap(q) - keys = random.split(random.PRNGKey(0), 3) - particle_out = coix.resample(particle_program)(keys) - assert particle_out.shape == (3,) - - -def test_resample_one(): - def q(key): - x = coix.rv(dist.Normal(1, 2), name="x")(key) - return coix.rv(dist.Normal(x, 1), name="z", obs=0.0) - - particle_program = jax.vmap(q) - keys = random.split(random.PRNGKey(0), 3) - particle_out = coix.resample(particle_program, num_samples=())(keys) - assert not particle_out.shape - - -def test_fori_loop(): - def drift(key, x): - key_out, key = random.split(key) - x_new = coix.rv(dist.Normal(x, 1.0), name="x")(key) - return key_out, x_new - - compile_time = {"value": 0} - - def body_fun(_, q): - compile_time["value"] += 1 - return coix.propose(drift, coix.compose(drift, q)) - - q = drift - for i in range(5): - q = body_fun(i, q) - x_init = np.zeros(3, np.float32) - q(random.PRNGKey(0), x_init) - assert compile_time["value"] == 5 - - random_walk = coix.fori_loop(0, 5, body_fun, drift) - random_walk(random.PRNGKey(0), x_init) - assert compile_time["value"] == 6 - - -# TODO(phandu): Support memoised arrays. -@pytest.mark.skip(reason="Currently, we only support memoised lists.") -def test_memoize(): - def model(key): - x = coix.rv(dist.Normal(0, 1), name="x")(key) - y = coix.rv(dist.Normal(x, 1), name="y", obs=0.0) - return x, y - - def guide(key): - return coix.rv(dist.Normal(1, 2), name="x")(key) - - def vmodel(key): - return jax.vmap(model)(random.split(key, 5)) - - def vguide(key): - return jax.vmap(guide)(random.split(key, 3)) - - memory = {"x": np.array([2, 4])} - program = coix.memoize(vmodel, vguide, memory) - out, trace, metrics = coix.traced_evaluate(program)(random.PRNGKey(0)) - assert set(trace.keys()) == {"x"} - assert "memory" in metrics - assert metrics["memory"]["x"].shape == (2,) - assert out[0].shape == (2,) diff --git a/coix/core.py b/coix/core.py index 3d4e2b0..cafa417 100644 --- a/coix/core.py +++ b/coix/core.py @@ -1,17 +1,3 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Program transforms.""" import importlib @@ -20,12 +6,13 @@ "detach", "empirical", "factor", + "get_backend_name", "prng_key", "rv", - "register_backend", - "set_backend", "stick_the_landing", "suffix", + "register_backend", + "set_backend", "traced_evaluate", ] @@ -92,8 +79,8 @@ def get_backend_name(): def get_backend(): backend = _COIX_BACKEND if backend is None: - set_backend("coix.numpyro") - return _BACKENDS["coix.numpyro"] + set_backend("coix.oryx") + return _BACKENDS["coix.oryx"] else: return _BACKENDS[backend] @@ -129,12 +116,11 @@ def desuffix(trace): return new_trace -def traced_evaluate(p, latents=None, seed=None, **kwargs): +def traced_evaluate(p, latents=None, rng_seed=None, **kwargs): """Performs traced evaluation for a program `p`.""" - # Work around some backends not having `seed` keyword. kwargs = kwargs.copy() - if seed is not None: - kwargs["seed"] = seed + if rng_seed is not None: + kwargs["rng_seed"] = rng_seed fn = get_backend()["traced_evaluate"](p, latents=latents, **kwargs) def wrapped(*args, **kwargs): diff --git a/coix/core_test.py b/coix/core_test.py deleted file mode 100644 index a3c70c6..0000000 --- a/coix/core_test.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for core.py.""" - -import coix.core - - -def test_desuffix(): - trace = { - "z_PREV__PREV_": 0, - "v_PREV__PREV_": 1, - "z_PREV_": 2, - "v_PREV_": 3, - "v": 4, - } - desuffix_trace = { - "z_PREV_": 0, - "v_PREV__PREV_": 1, - "z": 2, - "v_PREV_": 3, - "v": 4, - } - assert coix.core.desuffix(trace) == desuffix_trace diff --git a/coix/loss.py b/coix/loss.py index bc7fda4..b8c03df 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -1,17 +1,3 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Inference objectives.""" from coix import util diff --git a/coix/loss_test.py b/coix/loss_test.py deleted file mode 100644 index 46fde42..0000000 --- a/coix/loss_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for loss.py.""" - -import coix -import jax.numpy as jnp -import numpy as np - -p_trace = { - "x": {"log_prob": np.full((3, 2), 2.0)}, - "y": {"log_prob": np.array([3.0, 0.0, -2.0])}, - "x_PREV_": {"log_prob": np.ones((3, 2))}, -} -q_trace = { - "x": {"log_prob": np.ones((3, 2))}, - "y": {"log_prob": np.array([1.0, 1.0, 0.0])}, - "x_PREV_": {"log_prob": np.full((3, 2), 3.0)}, -} -incoming_weight = np.zeros(3) -incremental_weight = np.log(np.array([1 / 6, 1 / 3, 1 / 2])) - - -def test_apg(): - result = coix.loss.apg_loss( - q_trace, p_trace, incoming_weight, incremental_weight - ) - np.testing.assert_allclose(result, -6.0) - - -def test_elbo(): - result = coix.loss.elbo_loss( - q_trace, p_trace, incoming_weight, incremental_weight - ) - expected = -incremental_weight.sum() / 3 - np.testing.assert_allclose(result, expected) - - -def test_iwae(): - result = coix.loss.iwae_loss( - q_trace, p_trace, incoming_weight, incremental_weight - ) - w = incoming_weight + incremental_weight - expected = -(jnp.exp(w) * w).sum() - np.testing.assert_allclose(result, expected, rtol=1e-6) - - -def test_rws(): - result = coix.loss.rws_loss( - q_trace, p_trace, incoming_weight, incremental_weight - ) - np.testing.assert_allclose(result, 1.0, rtol=1e-6) diff --git a/coix/numpyro.py b/coix/numpyro.py index e624272..a34e845 100644 --- a/coix/numpyro.py +++ b/coix/numpyro.py @@ -1,22 +1,5 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Backend implementation for NumPyro.""" -from coix.util import get_batch_ndims -from coix.util import get_log_weight -from coix.util import get_site_log_prob import jax import jax.numpy as jnp import numpyro @@ -50,11 +33,6 @@ def wrapped(*args, **kwargs): for name, site in tr.items() if site["type"] == "metric" } - # add log_weight to metrics - if "log_weight" not in metrics: - log_probs = [get_site_log_prob(site) for site in trace.values()] - weight = get_log_weight(trace, get_batch_ndims(log_probs)) - metrics = {**metrics, "log_weight": weight} return out, trace, metrics return wrapped diff --git a/coix/oryx.py b/coix/oryx.py index b03e15f..68bd481 100644 --- a/coix/oryx.py +++ b/coix/oryx.py @@ -1,26 +1,9 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Program primitives and transforms.""" import functools import inspect import itertools -from coix.util import get_batch_ndims -from coix.util import get_log_weight -from coix.util import get_site_log_prob import jax import jax.numpy as jnp @@ -420,10 +403,6 @@ def wrapped(*args, **kwargs): if "log_density" not in metrics: log_density = sum(jnp.sum(site["log_prob"]) for site in trace.values()) metrics["log_density"] = jnp.array(0.0) + log_density - if "log_weight" not in metrics: - log_probs = [get_site_log_prob(site) for site in trace.values()] - weight = get_log_weight(trace, get_batch_ndims(log_probs)) - metrics = {**metrics, "log_weight": weight} return out, trace, metrics return wrapped diff --git a/coix/oryx_test.py b/coix/oryx_test.py deleted file mode 100644 index f925e41..0000000 --- a/coix/oryx_test.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for oryx.py.""" - -import coix -import coix.core -import coix.oryx -import jax -from jax import random -import jax.numpy as jnp -import numpy as np -import numpyro.distributions as dist - -coix.set_backend("coix.oryx") - - -def test_call_and_reap_tags(): - def model(key): - return coix.rv(dist.Normal(0, 1), name="x")(key) - - _, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0)) - assert set(trace.keys()) == {"x"} - assert set(trace["x"].keys()) == {"value", "log_prob"} - - -def test_delta_distribution(): - def model(key): - x = random.normal(key) - return coix.rv(dist.Delta(x, 5.0), name="x")(key) - - _, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0)) - assert set(trace.keys()) == {"x"} - - -def test_detach(): - def model(x): - return coix.rv(dist.Delta(x, 0.0), name="x")(None) * x - - x = 2.0 - np.testing.assert_allclose(jax.grad(coix.detach(model))(x), x) - - -def test_detach_vmap(): - def model(x): - return coix.rv(dist.Normal(x, 1.0), name="x")(random.PRNGKey(0)) - - outs = coix.detach(jax.vmap(model))(jnp.ones(2)) - np.testing.assert_allclose(outs[0], outs[1]) - - -def test_distribution(): - def model(key): - x = random.normal(key) - return coix.rv(dist.Delta(x, 5.0), name="x")(key) - - f = coix.oryx.call_and_reap_tags( - coix.oryx.tag_distribution(model), coix.oryx.DISTRIBUTION - ) - assert set(f(random.PRNGKey(0))[1][coix.oryx.DISTRIBUTION].keys()) == {"x"} - - -def test_empirical_program(): - def model(x): - trace = { - "x": {"value": x, "log_prob": 11.0}, - "y": {"value": x + 1, "log_prob": 9.0, "is_observed": True}, - } - return coix.empirical(0.0, trace, {})() - - _, trace, _ = coix.traced_evaluate(model)(1.0) - samples = {name: site["value"] for name, site in trace.items()} - jax.tree_util.tree_map( - np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0} - ) - assert "is_observed" not in trace["x"] - assert trace["y"]["is_observed"] - - -def test_factor(): - def model(x): - return coix.factor(x, name="x") - - _, trace, _ = coix.traced_evaluate(model)(10.0) - assert "x" in trace - np.testing.assert_allclose(trace["x"]["log_prob"], 10.0) - - -def test_log_prob_detach(): - def model(loc): - x = coix.rv(dist.Normal(loc, 1), name="x")(random.PRNGKey(0)) - return x - - def actual_fn(x): - return coix.traced_evaluate(coix.detach(model))(x)[1]["x"]["log_prob"] - - def expected_fn(x): - return dist.Normal(x, 1).log_prob(model(1.0)) - - actual = jax.grad(actual_fn)(1.0) - expect = jax.grad(expected_fn)(1.0) - np.testing.assert_allclose(actual, expect) - - -def test_observed(): - def model(a): - return coix.rv(dist.Delta(2.0, 3.0), obs=1.0, name="x") + a - - _, trace, _ = coix.traced_evaluate(model)(2.0) - assert "x" in trace - np.testing.assert_allclose(trace["x"]["value"], 1.0) - assert trace["x"]["is_observed"] - - -def test_stick_the_landing(): - def model(lp): - return coix.rv(dist.Delta(0.0, lp), name="x")(None) - - def p(x): - return coix.traced_evaluate(coix.detach(model))(x)[1]["x"]["log_prob"] - - def q(x): - model_stl = coix.detach(coix.stick_the_landing(model)) - return coix.traced_evaluate(model_stl)(x)[1]["x"]["log_prob"] - - np.testing.assert_allclose(jax.grad(p)(5.0), 1.0) - np.testing.assert_allclose(jax.grad(q)(5.0), 0.0) - - -def test_substitute(): - def model(key): - return coix.rv(dist.Delta(1.0, 5.0), name="x")(key) - - expected = {"x": 9.0} - _, trace, _ = coix.traced_evaluate(model, expected)(random.PRNGKey(0)) - actual = {"x": trace["x"]["value"]} - jax.tree_util.tree_map(np.testing.assert_allclose, actual, expected) - - -def test_suffix(): - def model(x): - return coix.rv(dist.Delta(x, 5.0), name="x")(None) - - f = coix.oryx.call_and_reap_tags( - coix.core.suffix(model), coix.oryx.RANDOM_VARIABLE - ) - jax.tree_util.tree_map( - np.testing.assert_allclose, - f(1.0)[1][coix.oryx.RANDOM_VARIABLE], - {"x_PREV_": 1.0}, - ) diff --git a/coix/util.py b/coix/util.py index 2ebf2ac..9a8ce02 100644 --- a/coix/util.py +++ b/coix/util.py @@ -1,17 +1,3 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """Utilities.""" import functools @@ -74,14 +60,9 @@ def can_extract_key(args): return ( args and isinstance(args[0], jnp.ndarray) - and ( - jax.dtypes.issubdtype(args[0].dtype, jax.dtypes.prng_key) - or ( - (args[0].dtype == jnp.uint32) - and (jnp.ndim(args[0]) >= 1) - and (args[0].shape[-1] == 2) - ) - ) + and (args[0].dtype == jnp.uint32) + and (jnp.ndim(args[0]) >= 1) + and (args[0].shape[-1] == 2) ) @@ -136,11 +117,6 @@ def __call__(self, *args, **kwargs): return self.module.apply(self.params, *args, **kwargs) -def _skip_update(grad, opt_state, params): - del params - return jax.tree_util.tree_map(jnp.zeros_like, grad), opt_state - - def train( loss_fn, init_params, @@ -151,8 +127,6 @@ def train( jit_compile=True, eval_fn=None, log_every=None, - init_step=0, - opt_state=None, **kwargs, ): """Optimize the parameters.""" @@ -164,15 +138,10 @@ def step_fn(params, opt_state, *args, **kwargs): grads = jax.tree_util.tree_map( lambda x, y: x.astype(y.dtype), grads, params ) - # Helpful metric to print out during training. - squared_grad_norm = sum( - jnp.square(p).sum() for p in jax.tree_util.tree_leaves(grads) - ) - metrics["squared_grad_norm"] = squared_grad_norm updates, opt_state = jax.lax.cond( jnp.isfinite(jax.flatten_util.ravel_pytree(grads)[0]).all(), optimizer.update, - _skip_update, + lambda g, o, p: (jax.tree_util.tree_map(jnp.zeros_like, g), o), grads, opt_state, params, @@ -184,27 +153,27 @@ def step_fn(params, opt_state, *args, **kwargs): maybe_jitted_step_fn = jit_compile(step_fn) else: maybe_jitted_step_fn = jax.jit(step_fn) if jit_compile else step_fn - opt_state = optimizer.init(init_params) if opt_state is None else opt_state + opt_state = optimizer.init(init_params) params = init_params - run_key = random.PRNGKey(seed) if isinstance(seed, int) else seed + run_key = random.PRNGKey(seed) log_every = max(num_steps // 20, 1) if log_every is None else log_every space = str(len(str(num_steps - 1))) kwargs = kwargs.copy() if eval_fn is not None: print("Evaluating with the initial params...", flush=True) tic = time.time() - eval_fn(init_step, params, opt_state, metrics=None) + eval_fn(0, params, **kwargs) print("Time to compile an eval step:", time.time() - tic, flush=True) print("Compiling the first train step...", flush=True) tic = time.time() metrics = None - for step in range(init_step + 1, num_steps + 1): + for step in range(1, num_steps + 1): key = random.fold_in(run_key, step) args = (key, next(dataloader)) if dataloader is not None else (key,) params, opt_state, metrics = maybe_jitted_step_fn( params, opt_state, *args, **kwargs ) - for name in kwargs: + for name, value in kwargs.items(): if name in metrics: kwargs[name] = metrics[name] if step == 1: @@ -216,10 +185,10 @@ def step_fn(params, opt_state, *args, **kwargs): if np.isscalar(value) or ( isinstance(value, (np.ndarray, jnp.ndarray)) and (value.ndim == 0) ): - log += f" | {name} {float(value):10.4f}" + log += f" | {name} {value:10.4f}" print(log, flush=True) if eval_fn is not None: - eval_fn(step, params, opt_state, metrics) + eval_fn(step, params, **kwargs) return params, metrics @@ -247,30 +216,3 @@ def desuffix(trace): raw_name = names_to_raw_names[name] new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name] return new_trace - - -def get_batch_ndims(xs): - """Gets the number of same-size leading dimensions of the elements in xs.""" - if not xs: - return 0 - min_ndim = min(jnp.ndim(lp) for lp in xs) - batch_ndims = 0 - for i in range(min_ndim): - if len(set(jnp.shape(lp)[i] for lp in xs)) > 1: - break - batch_ndims = batch_ndims + 1 - return batch_ndims - - -def get_log_weight(trace, batch_ndims): - """Computes log weight of the trace and keeps its batch dimensions.""" - log_weight = jnp.zeros((1,) * batch_ndims) - for site in trace.values(): - lp = get_site_log_prob(site) - if is_observed_site(site): - log_weight = log_weight + jnp.sum( - lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0)) - ) - else: - log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims]) - return log_weight diff --git a/coix/util_test.py b/coix/util_test.py index fa0832f..71ffabb 100644 --- a/coix/util_test.py +++ b/coix/util_test.py @@ -1,20 +1,6 @@ -# Copyright 2024 The coix Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""Test for util.py.""" -"""Tests for util.py.""" - -import coix +from coix import util import jax import numpy as np import pytest @@ -25,7 +11,7 @@ def test_systematic_resampling_uniform(seed): log_weights = np.zeros(5) rng_key = jax.random.PRNGKey(seed) if seed is not None else None num_samples = 5 - resample_indices = coix.util.get_systematic_resampling_indices( + resample_indices = util.get_systematic_resampling_indices( log_weights, rng_key, num_samples ) np.testing.assert_allclose(resample_indices, np.arange(5))