Skip to content

Commit

Permalink
Merge pull request #5 from jax-ml:key
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551556715
  • Loading branch information
The coix Authors committed Mar 28, 2024
2 parents 766b9ca + cd6103c commit ed92054
Show file tree
Hide file tree
Showing 15 changed files with 86 additions and 834 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# coix

[![Unittests](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml)
[![Documentation Status](https://readthedocs.org/projects/coix/badge/?version=latest)](https://coix.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://badge.fury.io/py/coix.svg)](https://badge.fury.io/py/coix)

Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box.
Inference Combinators in JAX (Coix) is a machine learning framework used to
develop inference algorithms that are composed of probabilistic programs.

*This is not an officially supported Google product.*

14 changes: 0 additions & 14 deletions coix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2024 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""coix API."""

from coix import algo
Expand Down
19 changes: 2 additions & 17 deletions coix/algo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2024 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference algorithms."""

import functools
Expand Down Expand Up @@ -144,7 +130,6 @@ def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
if _use_fori_loop(targets, num_targets):

def body_fun(i, q):
assert callable(targets)
p = extend(compose(momentum, targets(i), suffix=False), refreshment)
return propose(p, compose(refreshment, compose(leapfrog, q)))

Expand All @@ -156,7 +141,7 @@ def body_fun(i, q):

targets = [compose(momentum, p, suffix=False) for p in targets]
q = targets[0]
loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,)
loss_fns = [None] * (len(targets) - 2) + [iwae_loss]
for p, loss_fn in zip(targets[1:], loss_fns):
q = compose(refreshment, compose(leapfrog, q))
q = propose(extend(p, refreshment), q, loss_fn=loss_fn)
Expand Down Expand Up @@ -414,7 +399,7 @@ def body_fun(i, q):
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,)
loss_fns = [None] * (len(proposals) - 2) + [iwae_loss]
for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns):
q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn)
return q
116 changes: 0 additions & 116 deletions coix/algo_test.py

This file was deleted.

Loading

0 comments on commit ed92054

Please sign in to comment.