Skip to content

Commit

Permalink
test jitting
Browse files Browse the repository at this point in the history
  • Loading branch information
rhayes777 committed Jan 8, 2025
1 parent 62f7c54 commit 07f27f6
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions test_autolens/point/model/test_andrew_implementation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
try:
import jax

JAX_INSTALLED = True
except ImportError:
JAX_INSTALLED = False

import numpy as np
import pytest

Expand All @@ -14,23 +21,23 @@ def noise_map():
return np.array([1.0, 1.0])


def test_andrew_implementation(
data,
noise_map,
):
@pytest.fixture
def fit(data, noise_map):
model_positions = np.array(
[
(-1.0749, -1.1),
(1.19117, 1.175),
]
)

fit = Fit(
return Fit(
data=data,
noise_map=noise_map,
model_positions=model_positions,
)


def test_andrew_implementation(fit):
assert np.allclose(
fit.all_permutations_log_likelihoods(),
[
Expand All @@ -41,6 +48,11 @@ def test_andrew_implementation(
assert fit.log_likelihood() == -4.40375330990644


@pytest.mark.skipif(not JAX_INSTALLED, reason="JAX is not installed")
def test_jax(fit):
assert jax.jit(fit.log_likelihood)() == -4.40375330990644


def test_nan_model_positions(
data,
noise_map,
Expand Down

0 comments on commit 07f27f6

Please sign in to comment.