diff --git a/motile/data.py b/tests/conftest.py similarity index 93% rename from motile/data.py rename to tests/conftest.py index 81f6f85..477b93b 100644 --- a/motile/data.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import networkx as nx - +import pytest from motile import TrackGraph +@pytest.fixture def arlo_graph_nx() -> nx.DiGraph: """Create the "Arlo graph", a simple toy graph for testing. @@ -44,11 +45,13 @@ def arlo_graph_nx() -> nx.DiGraph: return nx_graph -def arlo_graph() -> TrackGraph: +@pytest.fixture +def arlo_graph(arlo_graph_nx) -> TrackGraph: """Return the "Arlo graph" as a :class:`motile.TrackGraph` instance.""" - return TrackGraph(arlo_graph_nx()) + return TrackGraph(arlo_graph_nx) +@pytest.fixture def toy_graph_nx() -> nx.DiGraph: """Return variation of the "Arlo graph". @@ -93,11 +96,13 @@ def toy_graph_nx() -> nx.DiGraph: return nx_graph -def toy_graph() -> TrackGraph: +@pytest.fixture +def toy_graph(toy_graph_nx) -> TrackGraph: """Return the `toy_graph_nx` as a :class:`motile.TrackGraph` instance.""" - return TrackGraph(toy_graph_nx()) + return TrackGraph(toy_graph_nx) +@pytest.fixture def toy_hypergraph_nx() -> nx.DiGraph: """Return variation of `toy_graph` with an edge modified and one hyperedge added. @@ -151,6 +156,7 @@ def toy_hypergraph_nx() -> nx.DiGraph: return nx_graph -def toy_hypergraph() -> TrackGraph: +@pytest.fixture +def toy_hypergraph(toy_hypergraph_nx) -> TrackGraph: """Return the `toy_hypergraph_nx` as a :class:`motile.TrackGraph` instance.""" - return TrackGraph(toy_hypergraph_nx()) + return TrackGraph(toy_hypergraph_nx) diff --git a/tests/test_api.py b/tests/test_api.py index fab1f27..46786ca 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -9,12 +9,6 @@ NodeSelection, Split, ) -from motile.data import ( - arlo_graph, - arlo_graph_nx, - toy_hypergraph, - toy_hypergraph_nx, -) from motile.variables import EdgeSelected, NodeSelected @@ -30,15 +24,15 @@ def _selected_edges(solver: motile.Solver) -> list: return sorted([e for e, i in edge_indicators.items() if solution[i] > 0.5]) -def test_graph_creation_with_hyperedges(): - graph = toy_hypergraph() +def test_graph_creation_with_hyperedges(toy_hypergraph): + graph = toy_hypergraph assert len(graph.nodes) == 7 assert len(graph.edges) == 10 -def test_graph_creation_from_multiple_nx_graphs(): - g1 = toy_hypergraph_nx() - g2 = arlo_graph_nx() +def test_graph_creation_from_multiple_nx_graphs(toy_hypergraph_nx, arlo_graph_nx): + g1 = toy_hypergraph_nx + g2 = arlo_graph_nx graph = motile.TrackGraph() graph.add_from_nx_graph(g1) @@ -54,8 +48,8 @@ def test_graph_creation_from_multiple_nx_graphs(): assert "prediction_distance" in graph.edges[(0, 2)] -def test_solver(): - graph = arlo_graph() +def test_solver(arlo_graph): + graph = arlo_graph solver = motile.Solver(graph) solver.add_cost(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 6eb1baa..a4db99d 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -8,19 +8,18 @@ Pin, ) from motile.costs import Appear, EdgeSelection, NodeSelection -from motile.data import arlo_graph, arlo_graph_nx from .test_api import _selected_edges, _selected_nodes @pytest.fixture -def solver(): - return motile.Solver(arlo_graph()) +def solver(arlo_graph): + return motile.Solver(arlo_graph) -def test_graph_casting() -> None: +def test_graph_casting(arlo_graph_nx) -> None: with pytest.warns(UserWarning, match="Coercing networkx graph to TrackGraph"): - motile.Solver(arlo_graph_nx()) + motile.Solver(arlo_graph_nx) def test_pin(solver: motile.Solver) -> None: diff --git a/tests/test_costs.py b/tests/test_costs.py index 57145d2..9eafbb4 100644 --- a/tests/test_costs.py +++ b/tests/test_costs.py @@ -7,13 +7,10 @@ NodeSelection, Split, ) -from motile.data import ( - arlo_graph, -) -def test_ignore_attributes(): - graph = arlo_graph() +def test_ignore_attributes(arlo_graph): + graph = arlo_graph # first solve without ignore attribute: diff --git a/tests/test_plot.py b/tests/test_plot.py index 3fa047f..aec8cb5 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,7 +1,6 @@ import motile import pytest from motile.costs import Appear, EdgeSelection, NodeSelection, Split -from motile.data import arlo_graph from motile.plot import draw_solution, draw_track_graph try: @@ -11,13 +10,8 @@ @pytest.fixture -def graph() -> motile.TrackGraph: - return arlo_graph() - - -@pytest.fixture -def solver(graph: motile.TrackGraph) -> motile.Solver: - solver = motile.Solver(graph) +def solver(arlo_graph: motile.TrackGraph) -> motile.Solver: + solver = motile.Solver(arlo_graph) solver.add_cost(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) solver.add_cost(EdgeSelection(weight=1.0, attribute="prediction_distance")) solver.add_cost(Appear(constant=200.0)) @@ -25,7 +19,8 @@ def solver(graph: motile.TrackGraph) -> motile.Solver: return solver -def test_plot_graph(graph: motile.TrackGraph) -> None: +def test_plot_graph(arlo_graph: motile.TrackGraph) -> None: + graph = arlo_graph assert isinstance(draw_track_graph(graph), go.Figure) assert isinstance(draw_track_graph(graph, alpha_attribute="score"), go.Figure) assert isinstance(draw_track_graph(graph, alpha_func=lambda _: 0.5), go.Figure) @@ -40,7 +35,8 @@ def label_func(node): ) -def test_plot_solution(graph: motile.TrackGraph, solver: motile.Solver) -> None: +def test_plot_solution(arlo_graph: motile.TrackGraph, solver: motile.Solver) -> None: + graph = arlo_graph with pytest.raises(RuntimeError): draw_solution(graph, solver) solver.solve() diff --git a/tests/test_structsvm.py b/tests/test_structsvm.py index eb0815f..a8845a9 100644 --- a/tests/test_structsvm.py +++ b/tests/test_structsvm.py @@ -5,7 +5,6 @@ import numpy as np from motile.constraints import MaxChildren, MaxParents from motile.costs import Appear, EdgeSelection, NodeSelection -from motile.data import toy_graph from motile.variables import EdgeSelected, NodeSelected logger = logging.getLogger(__name__) @@ -56,8 +55,8 @@ def create_toy_solver(graph): return solver -def test_structsvm_common_toy_example(): - graph = toy_graph() +def test_structsvm_common_toy_example(toy_graph): + graph = toy_graph solver = create_toy_solver(graph) diff --git a/tests/test_variables.py b/tests/test_variables.py index e90be51..705d9ca 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -4,14 +4,16 @@ import ilpy import pytest -from motile import Solver, data +from motile import Solver, TrackGraph from motile.variables import Variable @pytest.mark.parametrize("VarCls", Variable.__subclasses__()) -def test_variable_subclass_protocols(VarCls: type[Variable]) -> None: +def test_variable_subclass_protocols( + arlo_graph: TrackGraph, VarCls: type[Variable] +) -> None: """Test that all Variable subclasses properly implement the Variable protocol.""" - solver = Solver(data.arlo_graph()) + solver = Solver(arlo_graph) keys = VarCls.instantiate(solver) assert isinstance(keys, Collection)