Skip to content

Commit

Permalink
adapt images model image now property
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Dec 22, 2023
1 parent 55856e9 commit 1273e19
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 21 deletions.
16 changes: 13 additions & 3 deletions autolens/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,27 @@ def make_fit_point_dict_x2_plane():


def make_adapt_galaxy_name_image_dict_7x7():

image_0 = ag.Array2D(
np.full(fill_value=2.0, shape=make_mask_2d_7x7().pixels_in_mask),
mask=make_mask_2d_7x7(),
)

image_1 = ag.Array2D(
np.full(fill_value=3.0, shape=make_mask_2d_7x7().pixels_in_mask),
mask=make_mask_2d_7x7(),
)

adapt_galaxy_name_image_dict = {
"('galaxies', 'lens')": make_adapt_galaxy_image_0_7x7(),
"('galaxies', 'source')": make_adapt_galaxy_image_1_7x7(),
"('galaxies', 'lens')": image_0,
"('galaxies', 'source')": image_1,
}

return adapt_galaxy_name_image_dict


def make_adapt_images_7x7():
return ag.AdaptImages(
model_image=make_adapt_model_image_7x7(),
galaxy_name_image_dict=make_adapt_galaxy_name_image_dict_7x7(),
)

Expand Down
4 changes: 0 additions & 4 deletions test_autolens/aggregator/test_aggregator_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ def test__fit_imaging__adapt_images(
for fit_gen in fit_pdf_gen:
for fit_list in fit_gen:
i += 1

assert (
fit_list[0].adapt_images.model_image == adapt_images_7x7.model_image
).all()
assert (
list(fit_list[0].adapt_images.galaxy_image_dict.values())[0]
== list(adapt_images_7x7.galaxy_name_image_dict.values())[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ def test__fit_interferometer__adapt_images(
for fit_gen in fit_pdf_gen:
for fit_list in fit_gen:
i += 1

assert (
fit_list[0].adapt_images.model_image == adapt_images_7x7.model_image
).all()
assert (
list(fit_list[0].adapt_images.galaxy_image_dict.values())[0]
== list(adapt_images_7x7.galaxy_name_image_dict.values())[0]
Expand Down
5 changes: 0 additions & 5 deletions test_autolens/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,6 @@ def make_adapt_galaxy_image_0_7x7():
return fixtures.make_adapt_galaxy_image_0_7x7()


@pytest.fixture(name="adapt_model_image_7x7")
def make_adapt_model_image_7x7():
return fixtures.make_adapt_model_image_7x7()


@pytest.fixture(name="adapt_galaxy_name_image_dict_7x7")
def make_adapt_galaxy_name_image_dict_7x7():
return fixtures.make_adapt_galaxy_name_image_dict_7x7()
Expand Down
10 changes: 7 additions & 3 deletions test_autolens/imaging/test_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test__model_image__with_and_without_psf_blurring(



def test__fit_figure_of_merit(masked_imaging_7x7, masked_imaging_covariance_7x7, adapt_model_image_7x7):
def test__fit_figure_of_merit(masked_imaging_7x7, masked_imaging_covariance_7x7):

g0 = al.Galaxy(
redshift=0.5,
Expand Down Expand Up @@ -214,9 +214,13 @@ def test__fit_figure_of_merit(masked_imaging_7x7, masked_imaging_covariance_7x7,

tracer = al.Tracer.from_galaxies(galaxies=[g0, galaxy_pix])

model_image = al.Array2D(
np.full(fill_value=5.0, shape=masked_imaging_7x7.mask.pixels_in_mask),
mask=masked_imaging_7x7.mask,
)

adapt_images = al.AdaptImages(
model_image=adapt_model_image_7x7,
galaxy_image_dict={galaxy_pix: adapt_model_image_7x7},
galaxy_image_dict={galaxy_pix: model_image},
)

fit = al.FitImaging(
Expand Down
4 changes: 2 additions & 2 deletions test_autolens/lens/test_to_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test__adapt_galaxy_image_pg_list(sub_grid_2d_7x7):

gal_pix = al.Galaxy(redshift=0.5, pixelization=pixelization)

adapt_images = al.AdaptImages(model_image=1, galaxy_image_dict={gal_pix: 1})
adapt_images = al.AdaptImages(galaxy_image_dict={gal_pix: 1})

tracer = al.Tracer.from_galaxies(galaxies=[gal_pix, gal])

Expand All @@ -201,7 +201,7 @@ def test__adapt_galaxy_image_pg_list(sub_grid_2d_7x7):
gal_pix2 = al.Galaxy(redshift=2.0, pixelization=pixelization)

adapt_images = al.AdaptImages(
model_image=1, galaxy_image_dict={gal_pix0: 1, gal_pix1: 2, gal_pix2: 3}
galaxy_image_dict={gal_pix0: 1, gal_pix1: 2, gal_pix2: 3}
)

tracer = al.Tracer.from_galaxies(
Expand Down

0 comments on commit 1273e19

Please sign in to comment.