Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pytest fixtures to serve the test data #109

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions motile/data.py → tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import networkx as nx

import pytest
from motile import TrackGraph


@pytest.fixture
def arlo_graph_nx() -> nx.DiGraph:
"""Create the "Arlo graph", a simple toy graph for testing.

Expand Down Expand Up @@ -44,11 +45,13 @@ def arlo_graph_nx() -> nx.DiGraph:
return nx_graph


def arlo_graph() -> TrackGraph:
@pytest.fixture
def arlo_graph(arlo_graph_nx) -> TrackGraph:
"""Return the "Arlo graph" as a :class:`motile.TrackGraph` instance."""
return TrackGraph(arlo_graph_nx())
return TrackGraph(arlo_graph_nx)


@pytest.fixture
def toy_graph_nx() -> nx.DiGraph:
"""Return variation of the "Arlo graph".

Expand Down Expand Up @@ -93,11 +96,13 @@ def toy_graph_nx() -> nx.DiGraph:
return nx_graph


def toy_graph() -> TrackGraph:
@pytest.fixture
def toy_graph(toy_graph_nx) -> TrackGraph:
"""Return the `toy_graph_nx` as a :class:`motile.TrackGraph` instance."""
return TrackGraph(toy_graph_nx())
return TrackGraph(toy_graph_nx)


@pytest.fixture
def toy_hypergraph_nx() -> nx.DiGraph:
"""Return variation of `toy_graph` with an edge modified and one hyperedge added.

Expand Down Expand Up @@ -151,6 +156,7 @@ def toy_hypergraph_nx() -> nx.DiGraph:
return nx_graph


def toy_hypergraph() -> TrackGraph:
@pytest.fixture
def toy_hypergraph(toy_hypergraph_nx) -> TrackGraph:
"""Return the `toy_hypergraph_nx` as a :class:`motile.TrackGraph` instance."""
return TrackGraph(toy_hypergraph_nx())
return TrackGraph(toy_hypergraph_nx)
20 changes: 7 additions & 13 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@
NodeSelection,
Split,
)
from motile.data import (
arlo_graph,
arlo_graph_nx,
toy_hypergraph,
toy_hypergraph_nx,
)
from motile.variables import EdgeSelected, NodeSelected


Expand All @@ -30,15 +24,15 @@ def _selected_edges(solver: motile.Solver) -> list:
return sorted([e for e, i in edge_indicators.items() if solution[i] > 0.5])


def test_graph_creation_with_hyperedges():
graph = toy_hypergraph()
def test_graph_creation_with_hyperedges(toy_hypergraph):
graph = toy_hypergraph
assert len(graph.nodes) == 7
assert len(graph.edges) == 10


def test_graph_creation_from_multiple_nx_graphs():
g1 = toy_hypergraph_nx()
g2 = arlo_graph_nx()
def test_graph_creation_from_multiple_nx_graphs(toy_hypergraph_nx, arlo_graph_nx):
g1 = toy_hypergraph_nx
g2 = arlo_graph_nx
graph = motile.TrackGraph()

graph.add_from_nx_graph(g1)
Expand All @@ -54,8 +48,8 @@ def test_graph_creation_from_multiple_nx_graphs():
assert "prediction_distance" in graph.edges[(0, 2)]


def test_solver():
graph = arlo_graph()
def test_solver(arlo_graph):
graph = arlo_graph

solver = motile.Solver(graph)
solver.add_cost(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
Expand Down
9 changes: 4 additions & 5 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,18 @@
Pin,
)
from motile.costs import Appear, EdgeSelection, NodeSelection
from motile.data import arlo_graph, arlo_graph_nx

from .test_api import _selected_edges, _selected_nodes


@pytest.fixture
def solver():
return motile.Solver(arlo_graph())
def solver(arlo_graph):
return motile.Solver(arlo_graph)


def test_graph_casting() -> None:
def test_graph_casting(arlo_graph_nx) -> None:
with pytest.warns(UserWarning, match="Coercing networkx graph to TrackGraph"):
motile.Solver(arlo_graph_nx())
motile.Solver(arlo_graph_nx)


def test_pin(solver: motile.Solver) -> None:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
NodeSelection,
Split,
)
from motile.data import (
arlo_graph,
)


def test_ignore_attributes():
graph = arlo_graph()
def test_ignore_attributes(arlo_graph):
graph = arlo_graph

# first solve without ignore attribute:

Expand Down
16 changes: 6 additions & 10 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import motile
import pytest
from motile.costs import Appear, EdgeSelection, NodeSelection, Split
from motile.data import arlo_graph
from motile.plot import draw_solution, draw_track_graph

try:
Expand All @@ -11,21 +10,17 @@


@pytest.fixture
def graph() -> motile.TrackGraph:
return arlo_graph()


@pytest.fixture
def solver(graph: motile.TrackGraph) -> motile.Solver:
solver = motile.Solver(graph)
def solver(arlo_graph: motile.TrackGraph) -> motile.Solver:
solver = motile.Solver(arlo_graph)
solver.add_cost(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_cost(EdgeSelection(weight=1.0, attribute="prediction_distance"))
solver.add_cost(Appear(constant=200.0))
solver.add_cost(Split(constant=100.0))
return solver


def test_plot_graph(graph: motile.TrackGraph) -> None:
def test_plot_graph(arlo_graph: motile.TrackGraph) -> None:
graph = arlo_graph
assert isinstance(draw_track_graph(graph), go.Figure)
assert isinstance(draw_track_graph(graph, alpha_attribute="score"), go.Figure)
assert isinstance(draw_track_graph(graph, alpha_func=lambda _: 0.5), go.Figure)
Expand All @@ -40,7 +35,8 @@ def label_func(node):
)


def test_plot_solution(graph: motile.TrackGraph, solver: motile.Solver) -> None:
def test_plot_solution(arlo_graph: motile.TrackGraph, solver: motile.Solver) -> None:
graph = arlo_graph
with pytest.raises(RuntimeError):
draw_solution(graph, solver)
solver.solve()
Expand Down
5 changes: 2 additions & 3 deletions tests/test_structsvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, EdgeSelection, NodeSelection
from motile.data import toy_graph
from motile.variables import EdgeSelected, NodeSelected

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -56,8 +55,8 @@ def create_toy_solver(graph):
return solver


def test_structsvm_common_toy_example():
graph = toy_graph()
def test_structsvm_common_toy_example(toy_graph):
graph = toy_graph

solver = create_toy_solver(graph)

Expand Down
8 changes: 5 additions & 3 deletions tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

import ilpy
import pytest
from motile import Solver, data
from motile import Solver, TrackGraph
from motile.variables import Variable


@pytest.mark.parametrize("VarCls", Variable.__subclasses__())
def test_variable_subclass_protocols(VarCls: type[Variable]) -> None:
def test_variable_subclass_protocols(
arlo_graph: TrackGraph, VarCls: type[Variable]
) -> None:
"""Test that all Variable subclasses properly implement the Variable protocol."""
solver = Solver(data.arlo_graph())
solver = Solver(arlo_graph)

keys = VarCls.instantiate(solver)
assert isinstance(keys, Collection)
Expand Down