Skip to content

Commit

Permalink
feat: add get_selected_solution
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Feb 26, 2024
1 parent aa49bf6 commit 5613843
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions motile/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,46 @@ def _on_weights_modified(self, old_value: float | None, new_value: float) -> Non
logger.info("Weights have changed")

self._weights_changed = True

def get_selected_solution(
self, solution: ilpy.Solution | None = None
) -> TrackGraph:
"""Return TrackGraph with only the selected nodes/edges from the solution.
Args:
solution:
The solution to use. If not provided, the last solution is used.
Returns:
A new TrackGraph with only the selected nodes and edges.
Raises:
RuntimeError: If no solution is provided and the solver has not been solved
yet.
"""
from motile.variables import EdgeSelected, NodeSelected

if solution is None:
solution = self.solution

# TODO:
# in theory this could be made more efficient by using a nx.DiGraph view
# but TrackGraph itself doesn't provide views (and isn't a subclass)
if not solution:
raise RuntimeError(
"No solution available. Run solve() first or manually pass a solution."
)

node_selected = self.get_variables(NodeSelected)
edge_selected = self.get_variables(EdgeSelected)
selected_graph = TrackGraph()

Check warning on line 301 in motile/solver.py

View check run for this annotation

Codecov / codecov/patch

motile/solver.py#L301

Added line #L301 was not covered by tests

for node_id, node in self.graph.nodes.items():
if solution[node_selected[node_id]]:

Check warning on line 304 in motile/solver.py

View check run for this annotation

Codecov / codecov/patch

motile/solver.py#L303-L304

Added lines #L303 - L304 were not covered by tests
selected_graph.add_node(node_id, node)

for edge_id, edge in self.graph.edges.items():
if solution[edge_selected[edge_id]]:
selected_graph.add_edge(edge_id, edge)

Check warning on line 310 in motile/solver.py

View check run for this annotation

Codecov / codecov/patch

motile/solver.py#L309-L310

Added lines #L309 - L310 were not covered by tests
return selected_graph

0 comments on commit 5613843

Please sign in to comment.