Skip to content

Commit

Permalink
Merge branch 'ignore_appear_disappear_costs'
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Jan 22, 2024
2 parents 54aa2c1 + a2d7b9b commit c160b19
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
16 changes: 15 additions & 1 deletion motile/costs/appear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,33 @@ class Appear(Costs):
constant:
A constant cost for each node that starts a track.
ignore_attribute:
The name of an optional node attribute that, if it is set and
evaluates to ``True``, will not set the appear costs for that node.
This is useful to allow nodes in the first frame to appear at no
cost.
"""

def __init__(
self, weight: float = 1, attribute: str | None = None, constant: float = 0
self,
weight: float = 1,
attribute: str | None = None,
constant: float = 0,
ignore_attribute: str | None = None,
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute
self.ignore_attribute = ignore_attribute

def apply(self, solver: Solver) -> None:
appear_indicators = solver.get_variables(NodeAppear)

for node, index in appear_indicators.items():
if self.ignore_attribute is not None:
if solver.graph.nodes[node].get(self.ignore_attribute, False):
continue
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
Expand Down
10 changes: 10 additions & 0 deletions motile/costs/disappear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,23 @@ class Disappear(Costs):
Args:
constant (float):
A constant cost for each node that ends a track.
ignore_attribute:
The name of an optional node attribute that, if it is set and
evaluates to ``True``, will not set the disappear costs for that
node. This is useful to allow nodes in the last frame to disappear
at no cost.
"""

def __init__(self, constant: float) -> None:
self.constant = Weight(constant)
self.ignore_attribute = ignore_attribute

def apply(self, solver: Solver) -> None:
disappear_indicators = solver.get_variables(NodeDisappear)

for index in disappear_indicators.values():
if self.ignore_attribute is not None:
if solver.graph.nodes[node].get(self.ignore_attribute, False):
continue
solver.add_variable_cost(index, 1.0, self.constant)
62 changes: 62 additions & 0 deletions tests/test_costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import motile
from motile.constraints import MaxChildren, MaxParents
from motile.costs import (
Appear,
EdgeDistance,
EdgeSelection,
NodeSelection,
Split,
)
from motile.data import (
arlo_graph,
)


def test_ignore_attributes():
graph = arlo_graph()

# first solve without ignore attribute:

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(
EdgeSelection(weight=0.5, attribute="prediction_distance", constant=-1.0)
)
solver.add_costs(EdgeDistance(position_attributes=("x",), weight=0.5))
solver.add_costs(Appear(constant=200.0, attribute="score", weight=-1.0))
solver.add_costs(Split(constant=100.0, attribute="score", weight=1.0))

solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))

solution = solver.solve()
no_ignore_value = solution.get_value()

# solve and ignore appear costs in frame 0

for first_node in graph.nodes_by_frame(0):
graph.nodes[first_node]["ignore_appear_cost"] = True

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(
EdgeSelection(weight=0.5, attribute="prediction_distance", constant=-1.0)
)
solver.add_costs(EdgeDistance(position_attributes=("x",), weight=0.5))
solver.add_costs(
Appear(
constant=200.0,
attribute="score",
weight=-1.0,
ignore_attribute="ignore_appear_cost",
)
)
solver.add_costs(Split(constant=100.0, attribute="score", weight=1.0))

solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))

solution = solver.solve()
ignore_value = solution.get_value()

assert ignore_value < no_ignore_value

0 comments on commit c160b19

Please sign in to comment.