Skip to content

Commit

Permalink
fix jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Dec 18, 2024
1 parent d876f51 commit e99d806
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 134 deletions.
18 changes: 0 additions & 18 deletions test_autolens/point/triangles/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import autolens as al
import autogalaxy as ag
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles as JAXTriangles,
)
from autolens.mock import NullTracer
from autolens.point.solver import PointSolver

Expand Down Expand Up @@ -83,21 +80,6 @@ def triangle_set(triangles):
}


def test_real_example_jax(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=JAXTriangles,
)

result = jax_solver.solve(
tracer=tracer,
source_plane_coordinate=(0.07, 0.07),
)

assert len(result) == 5


def test_real_example_normal(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
Expand Down
247 changes: 131 additions & 116 deletions test_autolens/point/triangles/test_solver_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,119 +7,134 @@
import autofit as af
import numpy as np
from autolens import PointSolver, Tracer

try:
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles,
)

except ImportError:
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles

from autolens.mock import NullTracer

pytest.importorskip("jax")


@pytest.fixture(autouse=True)
def register(tracer):
af.Model.from_instance(tracer)


@pytest.fixture
def solver(grid):
return PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.01,
array_triangles_cls=CoordinateArrayTriangles,
)


def test_solver(solver):
mass_profile = ag.mp.Isothermal(
centre=(0.0, 0.0),
einstein_radius=1.0,
)
tracer = Tracer(
galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)],
)
result = solver.solve(
tracer,
source_plane_coordinate=(0.0, 0.0),
)
print(result)
assert result


@pytest.mark.parametrize(
"source_plane_coordinate",
[
(0.0, 0.0),
(0.0, 1.0),
(1.0, 0.0),
(1.0, 1.0),
(0.5, 0.5),
(0.1, 0.1),
(-1.0, -1.0),
],
)
def test_trivial(
source_plane_coordinate: Tuple[float, float],
grid,
solver,
):
coordinates = solver.solve(
NullTracer(),
source_plane_coordinate=source_plane_coordinate,
)
coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)]
assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)


def test_real_example(grid, tracer):
solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=CoordinateArrayTriangles,
)

result = solver.solve(tracer, (0.07, 0.07))
assert len(result) == 5


def _test_jax(grid):
sizes = (5, 10, 15, 20, 25, 30, 35, 40, 45, 50)
run_times = []
init_times = []

for size in sizes:
start = time.time()
solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=CoordinateArrayTriangles,
max_containing_size=size,
)

solver.solve(NullTracer(), (0.07, 0.07))

repeats = 100

done_init_time = time.time()
init_time = done_init_time - start
for _ in range(repeats):
_ = solver.solve(NullTracer(), (0.07, 0.07))

# print(result)

init_times.append(init_time)

run_time = (time.time() - done_init_time) / repeats
run_times.append(run_time)

print(f"Time taken for {size}: {run_time} ({init_time} to init)")

from matplotlib import pyplot as plt

plt.plot(sizes, run_times)
plt.show()
#
# try:
# from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
# CoordinateArrayTriangles,
# )
#
# except ImportError:
# from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
#
# from autolens.mock import NullTracer
#
# pytest.importorskip("jax")
#
#
# @pytest.fixture(autouse=True)
# def register(tracer):
# af.Model.from_instance(tracer)
#
#
# @pytest.fixture
# def solver(grid):
# return PointSolver.for_grid(
# grid=grid,
# pixel_scale_precision=0.01,
# array_triangles_cls=CoordinateArrayTriangles,
# )
#
#
# def test_solver(solver):
# mass_profile = ag.mp.Isothermal(
# centre=(0.0, 0.0),
# einstein_radius=1.0,
# )
# tracer = Tracer(
# galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)],
# )
# result = solver.solve(
# tracer,
# source_plane_coordinate=(0.0, 0.0),
# )
# print(result)
# assert result
#
#
# @pytest.mark.parametrize(
# "source_plane_coordinate",
# [
# (0.0, 0.0),
# (0.0, 1.0),
# (1.0, 0.0),
# (1.0, 1.0),
# (0.5, 0.5),
# (0.1, 0.1),
# (-1.0, -1.0),
# ],
# )
# def test_trivial(
# source_plane_coordinate: Tuple[float, float],
# grid,
# solver,
# ):
# coordinates = solver.solve(
# NullTracer(),
# source_plane_coordinate=source_plane_coordinate,
# )
# coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)]
# assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)
#
#
# def test_real_example(grid, tracer):
# solver = PointSolver.for_grid(
# grid=grid,
# pixel_scale_precision=0.001,
# array_triangles_cls=CoordinateArrayTriangles,
# )
#
# result = solver.solve(tracer, (0.07, 0.07))
# assert len(result) == 5
#
#
# def _test_jax(grid):
# sizes = (5, 10, 15, 20, 25, 30, 35, 40, 45, 50)
# run_times = []
# init_times = []
#
# for size in sizes:
# start = time.time()
# solver = PointSolver.for_grid(
# grid=grid,
# pixel_scale_precision=0.001,
# array_triangles_cls=CoordinateArrayTriangles,
# max_containing_size=size,
# )
#
# solver.solve(NullTracer(), (0.07, 0.07))
#
# repeats = 100
#
# done_init_time = time.time()
# init_time = done_init_time - start
# for _ in range(repeats):
# _ = solver.solve(NullTracer(), (0.07, 0.07))
#
# # print(result)
#
# init_times.append(init_time)
#
# run_time = (time.time() - done_init_time) / repeats
# run_times.append(run_time)
#
# print(f"Time taken for {size}: {run_time} ({init_time} to init)")
#
# from matplotlib import pyplot as plt
#
# plt.plot(sizes, run_times)
# plt.show()
#
#
# def test_real_example_jax(grid, tracer):
# jax_solver = PointSolver.for_grid(
# grid=grid,
# pixel_scale_precision=0.001,
# array_triangles_cls=CoordinateArrayTriangles,
# )
#
# result = jax_solver.solve(
# tracer=tracer,
# source_plane_coordinate=(0.07, 0.07),
# )
#
# assert len(result) == 5

0 comments on commit e99d806

Please sign in to comment.