Skip to content

Commit

Permalink
Merge pull request #52 from fusion-energy/adding_slice_index_to_slice…
Browse files Browse the repository at this point in the history
…_value

added get_slice_index_from_axis_value
  • Loading branch information
shimwell authored Feb 24, 2023
2 parents 87b093c + 0a5d967 commit 06138dc
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,32 @@ def get_mpl_plot_extent(self, view_direction: str = "x"):
return (x_min, x_max, y_min, y_max)


def get_slice_axis_value_from_index(self, slice_index: int, view_direction: str):
x_array = self.centroids

index_val = {"x": 0, "y": 1, "z": 2}[view_direction]

if view_direction == "x":
axis_values = x_array[index_val, slice_index, :, :]
if view_direction == "y":
axis_values = x_array[index_val, :, slice_index, :]
if view_direction == "z":
axis_values = x_array[index_val, :, :, slice_index]

# values should all be the same so picking the first
return axis_values[0][0]


# TODO
# def get_slice_index_from_axis_value(
# self,
# axis_value: int,
# view_direction: str
# ):

# return slice_index


def get_side_extent(self, side: str, view_direction: str = "x", bb=None):
if bb is None:
bb = (self.lower_left, self.upper_right)
Expand Down Expand Up @@ -194,6 +220,11 @@ def get_axis_labels(self, view_direction):
return xlabel, ylabel


openmc.RegularMesh.get_slice_axis_value_from_index = get_slice_axis_value_from_index
openmc.mesh.RegularMesh.get_slice_axis_value_from_index = (
get_slice_axis_value_from_index
)

openmc.RegularMesh.reshape_data = reshape_data
openmc.mesh.RegularMesh.reshape_data = reshape_data

Expand Down
108 changes: 108 additions & 0 deletions tests/test_slice_index_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import openmc
import regular_mesh_plotter


def test_simple_mesh():
mesh = openmc.RegularMesh()
mesh.lower_left = (-10, -10, -10)
mesh.upper_right = (10, 10, 10)
mesh.dimension = (10, 10, 10)

assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=5) == 1.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=4) == -1.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="y", slice_index=5) == 1.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="y", slice_index=4) == -1.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="z", slice_index=5) == 1.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="z", slice_index=4) == -1.0
)

mesh = openmc.RegularMesh()
mesh.lower_left = (-5, -5, -5)
mesh.upper_right = (5, 5, 5)
mesh.dimension = (10, 10, 10)

assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=5) == 0.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=4) == -0.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="y", slice_index=5) == 0.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="y", slice_index=4) == -0.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="z", slice_index=5) == 0.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="z", slice_index=4) == -0.5
)


def test_non_central_mesh():
mesh = openmc.RegularMesh()
mesh.lower_left = (0, 0, 0)
mesh.upper_right = (10, 10, 10)
mesh.dimension = (10, 10, 10)

assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=5) == 5.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=4) == 4.5
)


def test_unequal_dimensions():
mesh = openmc.RegularMesh()
mesh.lower_left = (0, 0, 0)
mesh.upper_right = (10, 10, 10)
mesh.dimension = (1, 2, 4)

assert (
mesh.get_slice_axis_value_from_index(view_direction="x", slice_index=0) == 5.0
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="y", slice_index=0) == 2.5
)
assert (
mesh.get_slice_axis_value_from_index(view_direction="z", slice_index=0) == 1.25
)


# TODO
# def test_simple_mesh():

# mesh = openmc.RegularMesh()
# mesh.lower_left = (-10, -10, -10)
# mesh.upper_right = (10, 10, 10)
# mesh.dimension = (10, 10, 10)

# assert get_slice_index_from_axis_value(view_direction='x', axis_value=-10) == 0
# assert get_slice_index_from_axis_value(view_direction='y', axis_value=-10) == 0
# assert get_slice_index_from_axis_value(view_direction='z', axis_value=-10) == 0

# assert get_slice_index_from_axis_value(view_direction='x', axis_value=10) == 10
# assert get_slice_index_from_axis_value(view_direction='y', axis_value=10) == 10
# assert get_slice_index_from_axis_value(view_direction='z', axis_value=10) == 10

# assert get_slice_index_from_axis_value(view_direction='x', axis_value=0.1) == 5
# assert get_slice_index_from_axis_value(view_direction='y', axis_value=0.1) == 5
# assert get_slice_index_from_axis_value(view_direction='z', axis_value=0.1) == 5

# assert get_slice_index_from_axis_value(view_direction='x', axis_value=-0.1) == 4
# assert get_slice_index_from_axis_value(view_direction='y', axis_value=-0.1) == 4
# assert get_slice_index_from_axis_value(view_direction='z', axis_value=-0.1) == 4

0 comments on commit 06138dc

Please sign in to comment.