Skip to content

Commit

Permalink
Make Cut2D accept axis specification and no-axis specification. Test …
Browse files Browse the repository at this point in the history
…both cases.
  • Loading branch information
gwm17 committed Jun 14, 2024
1 parent bfdfad1 commit 548aed3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
16 changes: 8 additions & 8 deletions src/spyral_utils/plot/cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,21 +355,21 @@ def deserialize_cut(filepath: Path) -> Cut2D | None:
with open(filepath, "r") as input:
buffer = input.read()
cut_dict = json.loads(buffer)
if not (
"name" in cut_dict
and "vertices" in cut_dict
and "xaxis" in cut_dict
and "yaxis" in cut_dict
):
if not ("name" in cut_dict and "vertices" in cut_dict):
print(
f"Data in file {filepath} is not the right format for Cut2D, could not load"
)
return None
xaxis = DEFAULT_CUT_AXIS
yaxis = DEFAULT_CUT_AXIS
if "xaxis" in cut_dict and "yaxis" in cut_dict:
xaxis = cut_dict["xaxis"]
yaxis = cut_dict["yaxis"]
return Cut2D(
cut_dict["name"],
cut_dict["vertices"],
x_axis=cut_dict["xaxis"],
y_axis=cut_dict["yaxis"],
x_axis=xaxis,
y_axis=yaxis,
)
except Exception as error:
print(
Expand Down
4 changes: 2 additions & 2 deletions tests/cut.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "test_cut",
"xaxis": "my_x",
"yaxis": "my_y",
"xaxis": "x",
"yaxis": "y",
"vertices": [
[
0.0,
Expand Down
25 changes: 25 additions & 0 deletions tests/cut_noaxis.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"name": "test_cut",
"vertices": [
[
0.0,
0.0
],
[
1.0,
0.0
],
[
1.0,
1.0
],
[
0.0,
1.0
],
[
0.0,
0.0
]
]
}
15 changes: 15 additions & 0 deletions tests/test_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ def test_cut():
handler = CutHandler()
df = pl.DataFrame({"x": [0.4, 0.2], "y": [0.4, 0.2]})

assert isinstance(cut, Cut2D)
assert cut.is_point_inside(0.5, 0.5)
assert not cut.is_point_inside(-1.0, -1.0)
df_gated = df.filter(
pl.struct([cut.get_x_axis(), cut.get_y_axis()]).map_batches(cut.is_cols_inside)
)
rows = len(df_gated.select("x").to_numpy())
assert rows == 2


def test_cut_noaxis():
cut = deserialize_cut(CUT_JSON_PATH)
handler = CutHandler()
df = pl.DataFrame({"x": [0.4, 0.2], "y": [0.4, 0.2]})

assert isinstance(cut, Cut2D)
assert cut.is_point_inside(0.5, 0.5)
assert not cut.is_point_inside(-1.0, -1.0)
Expand Down

0 comments on commit 548aed3

Please sign in to comment.