From 2c15221e41b2f6c2efbdd425790a6cba45779906 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 Nov 2023 10:39:53 -0400 Subject: [PATCH] refactor: minimize ilpy --- motile/constraints/constraint.py | 8 +- motile/constraints/expression.py | 23 +- motile/constraints/max_children.py | 2 +- motile/constraints/max_parents.py | 2 +- motile/costs/features.py | 7 +- motile/expressions.py | 508 +++++++++++++++++++++++++++++ motile/solver.py | 40 ++- motile/ssvm.py | 2 +- motile/variables/__init__.py | 4 +- motile/variables/edge_selected.py | 4 +- motile/variables/node_appear.py | 9 +- motile/variables/node_disappear.py | 9 +- motile/variables/node_selected.py | 4 +- motile/variables/node_split.py | 33 +- motile/variables/variable.py | 14 +- tests/test_variables.py | 10 +- 16 files changed, 591 insertions(+), 88 deletions(-) create mode 100644 motile/expressions.py diff --git a/motile/constraints/constraint.py b/motile/constraints/constraint.py index 3b72095..595f9bd 100644 --- a/motile/constraints/constraint.py +++ b/motile/constraints/constraint.py @@ -3,9 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Iterable -if TYPE_CHECKING: - import ilpy +from motile.expressions import Expression +if TYPE_CHECKING: from motile.solver import Solver @@ -13,9 +13,7 @@ class Constraint(ABC): """A base class for a constraint that can be added to a solver.""" @abstractmethod - def instantiate( - self, solver: Solver - ) -> Iterable[ilpy.Constraint | ilpy.Expression]: + def instantiate(self, solver: Solver) -> Iterable[Expression]: """Create and return specific linear constraints for the given solver. Args: diff --git a/motile/constraints/expression.py b/motile/constraints/expression.py index 78fffe7..8523f39 100644 --- a/motile/constraints/expression.py +++ b/motile/constraints/expression.py @@ -4,9 +4,9 @@ import contextlib from typing import TYPE_CHECKING, Union -import ilpy +from motile.expressions import Expression, Constant -from ..variables import EdgeSelected, NodeSelected, Variable +from ..variables import EdgeSelected, NodeSelected, Variables from .constraint import Constraint if TYPE_CHECKING: @@ -76,13 +76,13 @@ def __init__( self.eval_nodes = eval_nodes self.eval_edges = eval_edges - def instantiate(self, solver: Solver) -> list[ilpy.Constraint]: + def instantiate(self, solver: Solver) -> list[Expression]: # create two constraints: one to select nodes/edges, and one to exclude - select = ilpy.Constraint() - exclude = ilpy.Constraint() + select: Expression = Constant(0) + exclude: Expression = Constant(0) n_selected = 0 # number of nodes/edges selected - to_evaluate: list[tuple[NodesOrEdges, type[Variable]]] = [] + to_evaluate: list[tuple[NodesOrEdges, type[Variables]]] = [] if self.eval_nodes: to_evaluate.append((solver.graph.nodes, NodeSelected)) if self.eval_edges: @@ -99,17 +99,14 @@ def instantiate(self, solver: Solver) -> list[ilpy.Constraint]: # contextlib.suppress (above) will just skip it and move on... if eval(self._expression, None, node_or_edge): # if the expression evaluates to True, we select the node/edge - select.set_coefficient(indicator_variables[id_], 1) + select += indicator_variables[id_] n_selected += 1 else: # Otherwise, we exclude it. - exclude.set_coefficient(indicator_variables[id_], 1) + exclude += indicator_variables[id_] # finally, apply the relation and value to the constraints - select.set_relation(ilpy.Relation.Equal) - select.set_value(n_selected) - - exclude.set_relation(ilpy.Relation.Equal) - exclude.set_value(0) + select = select == n_selected + exclude = exclude == 0 return [select, exclude] diff --git a/motile/constraints/max_children.py b/motile/constraints/max_children.py index ced891d..95db007 100644 --- a/motile/constraints/max_children.py +++ b/motile/constraints/max_children.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Iterable -from ilpy.expressions import Constant, Expression +from motile.expressions import Constant, Expression from ..variables import EdgeSelected from .constraint import Constraint diff --git a/motile/constraints/max_parents.py b/motile/constraints/max_parents.py index 3c26c9b..1b1fe7a 100644 --- a/motile/constraints/max_parents.py +++ b/motile/constraints/max_parents.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Iterable -from ilpy.expressions import Constant, Expression +from motile.expressions import Constant, Expression from ..variables import EdgeSelected from .constraint import Constraint diff --git a/motile/costs/features.py b/motile/costs/features.py index e429c6f..6cce9fb 100644 --- a/motile/costs/features.py +++ b/motile/costs/features.py @@ -1,11 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numpy as np -if TYPE_CHECKING: - import ilpy +from motile.expressions import Variable class Features: @@ -53,7 +50,7 @@ def _increase_features(self, num_features: int) -> None: self._values = np.hstack((self._values, new_features)) def add_feature( - self, variable_index: int | ilpy.Variable, feature_index: int, value: float + self, variable_index: int | Variable, feature_index: int, value: float ) -> None: """Add a value to a feature. diff --git a/motile/expressions.py b/motile/expressions.py new file mode 100644 index 0000000..774cf77 --- /dev/null +++ b/motile/expressions.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +import ast +from enum import IntEnum, auto +from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Union + +if TYPE_CHECKING: + import ilpy + +Number = Union[float, int] + + +class VariableType(IntEnum): + Continuous = 0 + Integer = auto() + Binary = auto() + + +class Relation(IntEnum): + LessEqual = 0 + Equal = auto() + GreaterEqual = auto() + + +class Sense(IntEnum): + Minimize = 0 + Maximize = auto() + + +class Expression(ast.AST): + """Base class for all expression nodes. + + Expressions allow ilpy to represent mathematical expressions in an + intuitive syntax, and then convert to a native Constraint object. + + This class provides all of the operators and methods needed to build + expressions. For example, to create the expression ``2 * x - y >= 0``, you can + write ``2 * Variable('x') - Variable('y') >= 0``. + + Tip: you can use ``ast.dump`` to see the AST representation of an expression. + Or, use ``print(expr)` to see the string representation of an expression. + """ + + def __hash__(self) -> int: + # allow use as dict key + return id(self) + + def as_ilpy_constraint(self) -> ilpy.Constraint: + """Create an ilpy.Constraint object from this expression.""" + import ilpy + + l_coeffs, q_coeffs, value = _get_coeff_indices(self) + return ilpy.Constraint.from_coefficients( + coefficients=l_coeffs, + quadratic_coefficients=q_coeffs, + relation=_get_ilpy_relation(self) or ilpy.Relation.LessEqual, + value=-value, # negate value to convert to RHS form + ) + + def as_ilpy_objective(self, sense: Sense = Sense.Minimize) -> ilpy.Objective: + """Create a linear objective from this expression.""" + import ilpy + + if _get_ilpy_relation(self) is not None: + # TODO: may be supported in the future, eg. for piecewise objectives? + raise ValueError(f"Objective function cannot have comparisons: {self}") + + l_coeffs, q_coeffs, value = _get_coeff_indices(self) + return ilpy.Objective.from_coefficients( + coefficients=l_coeffs, + quadratic_coefficients=q_coeffs, + constant=value, + sense=ilpy.Sense(sense), + ) + + @staticmethod + def _cast(obj: Any) -> Expression: + """Cast object into an Expression.""" + return obj if isinstance(obj, Expression) else Constant(obj) + + def __str__(self) -> str: + """Serialize this expression to string form.""" + return str(_ExprSerializer(self)) + + # comparisons + + def __lt__(self, other: Expression | float) -> Compare: + return Compare(self, [ast.Lt()], [other]) + + def __le__(self, other: Expression | float) -> Compare: + return Compare(self, [ast.LtE()], [other]) + + def __eq__(self, other: Expression | float) -> Compare: # type: ignore + return Compare(self, [ast.Eq()], [other]) + + def __ne__(self, other: Expression | float) -> Compare: # type: ignore + return Compare(self, [ast.NotEq()], [other]) + + def __gt__(self, other: Expression | float) -> Compare: + return Compare(self, [ast.Gt()], [other]) + + def __ge__(self, other: Expression | float) -> Compare: + return Compare(self, [ast.GtE()], [other]) + + # binary operators + # (note that __and__ and __or__ are reserved for boolean operators.) + + def __add__(self, other: Expression | Number) -> BinOp: + return BinOp(self, ast.Add(), other) + + def __radd__(self, other: Expression | Number) -> BinOp: + return BinOp(other, ast.Add(), self) + + def __sub__(self, other: Expression | Number) -> BinOp: + return BinOp(self, ast.Sub(), other) + + def __rsub__(self, other: Expression | Number) -> BinOp: + return BinOp(other, ast.Sub(), self) + + def __mul__(self, other: Any) -> BinOp | Constant: + return BinOp(self, ast.Mult(), other) + + def __rmul__(self, other: Number) -> BinOp | Constant: + if not isinstance(other, (int, float)): + raise TypeError("Right multiplication must be with a number") + return Constant(other) * self + + def __truediv__(self, other: Number) -> BinOp: + return BinOp(self, ast.Div(), other) + + def __rtruediv__(self, other: Number) -> BinOp: + return BinOp(other, ast.Div(), self) + + # unary operators + + def __neg__(self) -> UnaryOp: + return UnaryOp(ast.USub(), self) + + def __pos__(self) -> UnaryOp: + # usually a no-op + return UnaryOp(ast.UAdd(), self) + + # specifically not implemented on Expression for now. + # We don't want to reimplement a full CAS like sympy. + # (But we could support sympy expressions!) + # Implemented below only on Constant and Variable. + + # def __pow__(self, other: Number) -> BinOp: + # return BinOp(self, ast.Pow(), other) + + +class Compare(Expression, ast.Compare): + """A comparison of two or more values. + + `left` is the first value in the comparison, `ops` the list of operators, + and `comparators` the list of values after the first element in the + comparison. + """ + + def __init__( + self, + left: Expression, + ops: Sequence[ast.cmpop], + comparators: Sequence[Expression | Number], + **kwargs: Any, + ) -> None: + super().__init__( + Expression._cast(left), + ops, + [Expression._cast(c) for c in comparators], + **kwargs, + ) + + +class BinOp(Expression, ast.BinOp): + """A binary operation (like addition or division). + + `op` is the operator, and `left` and `right` are any expression nodes. + """ + + def __init__( + self, + left: Expression | Number, + op: ast.operator, + right: Expression | Number, + **kwargs: Any, + ) -> None: + super().__init__(Expression._cast(left), op, Expression._cast(right), **kwargs) + + +class UnaryOp(Expression, ast.UnaryOp): + """A unary operation. + + `op` is the operator, and `operand` any expression node. + """ + + def __init__(self, op: ast.unaryop, operand: Expression, **kwargs: Any) -> None: + super().__init__(op, Expression._cast(operand), **kwargs) + + +class Constant(Expression, ast.Constant): + """A constant value. + + The `value` attribute contains the Python object it represents. + types supported: int, float + """ + + def __init__(self, value: Number, kind: str | None = None, **kwargs: Any) -> None: + if not isinstance(value, (float, int)): + raise TypeError("Constants must be numbers") + super().__init__(value, kind, **kwargs) + + def __mul__(self, other: Any) -> BinOp | Constant: + if isinstance(other, Constant): + return Constant(self.value**other.value) + if isinstance(other, (float, int)): + return Constant(self.value * other) + return super().__mul__(other) + + def __pow__(self, other: Number) -> Expression: + if not isinstance(other, (int, float)): + raise TypeError("Exponent must be a number") + return Constant(self.value**other) + + +class Variable(Expression, ast.Name): + """A variable. + + `id` holds the index as a string (becuase ast.Name requires a string). + + The special attribute `index` is added here for the purpose of storing + the index of a variable in a solver's variable list: ``Variable('u', index=0)`` + """ + + def __init__(self, id: str, index: int | None = None) -> None: + self.index = index + super().__init__(str(id), ctx=ast.Load()) + + def __pow__(self, other: Number) -> Expression: + if not isinstance(other, (int, float)): + raise TypeError("Exponent must be a number") + if other == 2: + return BinOp(self, ast.Mult(), self) + elif other == 1: + return self + raise ValueError("Only quadratic variables are supported") + + def __hash__(self) -> int: + # allow use as dict key + return id(self) + + def __int__(self) -> int: + if self.index is None: + raise TypeError(f"Variable {self!r} has no index") + return int(self.index) + + __index__ = __int__ + + def __repr__(self) -> str: + return f"motile.Variable({self.id!r}, index={self.index!r})" + + +# conversion between ast comparison operators and ilpy relations +# TODO: support more less/greater than operators +OPERATOR_MAP: dict[type[ast.cmpop], Relation] = { + ast.LtE: Relation.LessEqual, + ast.Eq: Relation.Equal, + ast.GtE: Relation.GreaterEqual, +} + + +def _get_ilpy_relation(expr: Expression) -> ilpy.Relation | None: + import ilpy + + seen_compare = False + relation: Relation | None = None + for sub in ast.walk(expr): + if isinstance(sub, Compare): + if seen_compare: + raise ValueError("Only single comparisons are supported") + + op_type = type(sub.ops[0]) + try: + relation = OPERATOR_MAP[op_type] + except KeyError as e: + raise ValueError(f"Unsupported comparison operator: {op_type}") from e + seen_compare = True + return ilpy.Relation(relation) if relation is not None else None + + +def _get_coeff_indices( + expr: Expression, +) -> tuple[dict[int, float], dict[tuple[int, int], float], float]: + l_coeffs: dict[int, float] = {} + q_coeffs: dict[tuple[int, int], float] = {} + constant = 0.0 + for var, coefficient in _get_coefficients(expr).items(): + if var is None: + constant = coefficient + elif isinstance(var, tuple): + q_coeffs[(_ensure_index(var[0]), _ensure_index(var[1]))] = coefficient + elif coefficient != 0: + l_coeffs[_ensure_index(var)] = coefficient + return l_coeffs, q_coeffs, constant + + +def _ensure_index(var: Variable) -> int: + if var.index is None: + raise ValueError("All variables in an Expression must have an index") + return var.index + + +def _get_coefficients( + expr: Expression | ast.expr, + coeffs: dict[Variable | None | tuple[Variable, Variable], float] | None = None, + scale: int = 1, + var_scale: Variable | None = None, +) -> dict[Variable | None | tuple[Variable, Variable], float]: + """Get the coefficients of an expression. + + The coefficients are returned as a dictionary mapping Variable to coefficient. + The key `None` is used for the constant term. Quadratic coefficients are + represented with a two-tuple of variables. + + Note also that expressions on the right side of a comparison are negated, + (so that the comparison is effectively against zero.) + + Args: + expr: The expression to get the coefficients of. + coeffs: The dictionary to add the coefficients to. If not given, a new + dictionary is created. + scale: The scale to apply to the coefficients. This is used to negate + expressions on the right side of a comparison or scale for multiplication. + var_scale: The variable to scale by. This is used to represent multiplication + or division between two variables. + + Example: + >>> u = Variable('u') + >>> v = Variable('v') + >>> _get_coefficients(2 * u - 5 * v <= 7) + {u: 2, v: -5, None: -7} + + coefficients are simplified in the process: + >>> _get_coefficients(2 * u - (u + 2 * u) <= 7) + {u: -1, None: -7} + """ + if coeffs is None: + coeffs = {} + + if isinstance(expr, Constant): + if var_scale is not None: + breakpoint() + coeffs.setdefault(None, 0) + coeffs[None] += expr.value * scale + + elif isinstance(expr, UnaryOp): + if var_scale is not None: + breakpoint() + if isinstance(expr.op, ast.USub): + scale = -scale + _get_coefficients(expr.operand, coeffs, scale, var_scale) + + elif isinstance(expr, Variable): + if var_scale is not None: + # multiplication or division between two variables + key = _sort_vars(expr, var_scale) + coeffs.setdefault(key, 0) + coeffs[key] += scale + else: + coeffs.setdefault(expr, 0) + coeffs[expr] += scale + + elif isinstance(expr, Compare): + if len(expr.ops) != 1: + raise ValueError("Only single comparisons are supported") + _get_coefficients(expr.left, coeffs, scale, var_scale) + # negate the right hand side of the comparison + _get_coefficients(expr.comparators[0], coeffs, scale * -1, var_scale) + + elif isinstance(expr, BinOp): + if isinstance(expr.op, (ast.Mult, ast.Div)): + _process_mult_op(expr, coeffs, scale, var_scale) + elif isinstance(expr.op, (ast.Add, ast.UAdd, ast.USub, ast.Sub)): + _get_coefficients(expr.left, coeffs, scale, var_scale) + if isinstance(expr.op, (ast.USub, ast.Sub)): + scale = -scale + _get_coefficients(expr.right, coeffs, scale, var_scale) + else: + raise ValueError(f"Unsupported binary operator: {type(expr.op)}") + + else: + breakpoint() + raise ValueError(f"Unsupported expression type: {type(expr)}") + + return coeffs + + +def _sort_vars(v1: Variable, v2: Variable) -> tuple[Variable, Variable]: + """Sort variables by index, or by id if index is None. + + This is so that a pair of variables can be used as a dictionary key. + Without worrying about the order of the variables (and without using a set, + which would exclude the possibility of having the same variable twice). + """ + # two lines are used to tell mypy it's a length 2 tuple + _v1, _v2 = sorted((v1, v2), key=lambda v: getattr(v, "index", id(v))) + return _v1, _v2 + + +def _process_mult_op( + expr: BinOp, + coeffs: dict[Variable | None | tuple[Variable, Variable], float], + scale: int, + var_scale: Variable | None = None, +) -> None: + """Helper function for _get_coefficients to process multiplication and division.""" + if isinstance(expr.right, Constant): + v = expr.right.value + scale *= 1 / v if isinstance(expr.op, ast.Div) else v + _get_coefficients(expr.left, coeffs, scale, var_scale) + elif isinstance(expr.left, Constant): + v = expr.left.value + scale *= 1 / v if isinstance(expr.op, ast.Div) else v + _get_coefficients(expr.right, coeffs, scale, var_scale) + elif isinstance(expr.left, Variable): + if var_scale is not None: + raise TypeError("Cannot multiply by more than two variables.") + _get_coefficients(expr.right, coeffs, scale, expr.left) + elif isinstance(expr.right, Variable): + if var_scale is not None: + raise TypeError("Cannot multiply by more than two variables.") + _get_coefficients(expr.left, coeffs, scale, expr.right) + else: + raise TypeError( + "Unexpected multiplcation or division between " + f"{type(expr.left)} and {type(expr.right)}" + ) + + +class _ExprSerializer(ast.NodeVisitor): + """Serializes an :class:`Expression` into a string. + + Used above in `Expression.__str__`. + """ + + OP_MAP: ClassVar[ + dict[type[ast.operator] | type[ast.cmpop] | type[ast.unaryop], str] + ] = { + # ast.cmpop + ast.Eq: "==", + ast.Gt: ">", + ast.GtE: ">=", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + # ast.operator + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + # ast.unaryop + ast.UAdd: "+", + ast.USub: "-", + } + + def __init__(self, node: Expression | None = None) -> None: + self._result: list[str] = [] + + def write(*params: ast.AST | str) -> None: + for item in params: + if isinstance(item, ast.AST): + self.visit(item) + elif item: + self._result.append(item) + + self.write = write + + if node is not None: + self.visit(node) + + def __str__(self) -> str: + return "".join(self._result) + + def visit_Variable(self, node: Variable) -> None: + self.write(node.id) + + def visit_Constant(self, node: ast.Constant) -> None: + self.write(repr(node.value)) + + def visit_Compare(self, node: ast.Compare) -> None: + self.visit(node.left) + for op, right in zip(node.ops, node.comparators): + self.write(f" {self.OP_MAP[type(op)]} ", right) + + def visit_BinOp(self, node: ast.BinOp) -> None: + opstring = f" {self.OP_MAP[type(node.op)]} " + args: list[ast.AST | str] = [node.left, opstring, node.right] + # wrap in parentheses if the left or right side is a binary operation + if isinstance(node.op, ast.Mult): + if isinstance(node.left, ast.BinOp): + args[:1] = ["(", node.left, ")"] + if isinstance(node.right, ast.BinOp): + args[2:] = ["(", node.right, ")"] + self.write(*args) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> None: + sym = self.OP_MAP[type(node.op)] + self.write(sym, " " if sym.isalpha() else "", node.operand) diff --git a/motile/solver.py b/motile/solver.py index 6c9328a..d8c61a0 100644 --- a/motile/solver.py +++ b/motile/solver.py @@ -3,22 +3,24 @@ import logging from typing import TYPE_CHECKING, TypeVar, cast -import ilpy import numpy as np from .constraints import SelectEdgeNodes from .constraints.constraint import Constraint from .costs import Features, Weight, Weights +from .expressions import Expression, Variable, VariableType from .ssvm import fit_weights logger = logging.getLogger(__name__) if TYPE_CHECKING: + import ilpy + from motile.costs import Costs from motile.track_graph import TrackGraph - from motile.variables import Variable + from motile.variables import Variables - V = TypeVar("V", bound=Variable) + V = TypeVar("V", bound=Variables) class Solver: @@ -38,17 +40,15 @@ def __init__( self, track_graph: TrackGraph, skip_core_constraints: bool = False ) -> None: self.graph = track_graph - self.variables: dict[type[Variable], Variable] = {} - self.variable_types: dict[int, ilpy.VariableType] = {} + self.variables: dict[type[Variables], Variables] = {} + self.variable_types: dict[int, VariableType] = {} self.weights = Weights() self.weights.register_modify_callback(self._on_weights_modified) self._weights_changed = True self.features = Features() - self.ilp_solver: ilpy.Solver | None = None - self.objective: ilpy.Objective | None = None - self.constraints = ilpy.Constraints() + self._constraints: set[Expression] = set() self.num_variables: int = 0 self._costs = np.zeros((0,), dtype=np.float32) @@ -103,7 +103,16 @@ def add_constraints(self, constraints: Constraint) -> None: logger.info("Adding %s constraints...", type(constraints).__name__) for constraint in constraints.instantiate(self): - self.constraints.add(constraint) + self._constraints.add(constraint) + + @property + def ilpy_constraints(self) -> ilpy.Constraints: + import ilpy + + constraints = ilpy.Constraints() + for expr in self._constraints: + constraints.add(expr.as_ilpy_constraint()) + return constraints def solve(self, timeout: float = 0.0, num_threads: int = 1) -> ilpy.Solution: """Solve the global optimization problem. @@ -122,21 +131,26 @@ def solve(self, timeout: float = 0.0, num_threads: int = 1) -> ilpy.Solution: :func:`get_variables` to find the indices of variables in this vector. """ + import ilpy + self.objective = ilpy.Objective(self.num_variables) for i, c in enumerate(self.costs): logger.debug("Setting cost of var %d to %.3f", i, c) self.objective.set_coefficient(i, c) # TODO: support other variable types + ilpy_var_types = { + i: ilpy.VariableType(v) for i, v in self.variable_types.items() + } self.ilp_solver = ilpy.Solver( self.num_variables, ilpy.VariableType.Binary, - variable_types=self.variable_types, + variable_types=ilpy_var_types, preference=ilpy.Preference.Any, ) self.ilp_solver.set_objective(self.objective) - self.ilp_solver.set_constraints(self.constraints) + self.ilp_solver.set_constraints(self.ilpy_constraints) self.ilp_solver.set_num_threads(num_threads) if timeout > 0: @@ -171,7 +185,7 @@ def get_variables(self, cls: type[V]) -> V: return cast("V", self.variables[cls]) def add_variable_cost( - self, index: int | ilpy.Variable, value: float, weight: Weight + self, index: int | Variable, value: float, weight: Weight ) -> None: """Add costs for an individual variable. @@ -238,7 +252,7 @@ def _add_variables(self, cls: type[V]) -> None: self.variable_types[index] = cls.variable_type for constraint in cls.instantiate_constraints(self): - self.constraints.add(constraint) + self._constraints.add(constraint) self.features.resize(num_variables=self.num_variables) diff --git a/motile/ssvm.py b/motile/ssvm.py index 43c2c64..279f379 100644 --- a/motile/ssvm.py +++ b/motile/ssvm.py @@ -59,7 +59,7 @@ def fit_weights( ground_truth[index] = gt loss = ssvm.SoftMarginLoss( - solver.constraints, + solver.ilpy_constraints, features.T, # TODO: fix in ssvm ground_truth, ssvm.HammingCosts(ground_truth, mask), diff --git a/motile/variables/__init__.py b/motile/variables/__init__.py index 49cc1c2..317971d 100644 --- a/motile/variables/__init__.py +++ b/motile/variables/__init__.py @@ -3,7 +3,7 @@ from .node_disappear import NodeDisappear from .node_selected import NodeSelected from .node_split import NodeSplit -from .variable import Variable +from .variable import Variables __all__ = [ "EdgeSelected", @@ -11,5 +11,5 @@ "NodeDisappear", "NodeSelected", "NodeSplit", - "Variable", + "Variables", ] diff --git a/motile/variables/edge_selected.py b/motile/variables/edge_selected.py index e3e0e06..7e7af80 100644 --- a/motile/variables/edge_selected.py +++ b/motile/variables/edge_selected.py @@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Collection -from .variable import Variable +from .variable import Variables if TYPE_CHECKING: from motile._types import EdgeId from motile.solver import Solver -class EdgeSelected(Variable["EdgeId"]): +class EdgeSelected(Variables["EdgeId"]): """Binary variable indicates whether an edge is part of the solution or not.""" @staticmethod diff --git a/motile/variables/node_appear.py b/motile/variables/node_appear.py index 37caecb..7a761f6 100644 --- a/motile/variables/node_appear.py +++ b/motile/variables/node_appear.py @@ -4,16 +4,15 @@ from .edge_selected import EdgeSelected from .node_selected import NodeSelected -from .variable import Variable +from .variable import Variables if TYPE_CHECKING: - import ilpy - from motile._types import NodeId + from motile.expressions import Expression from motile.solver import Solver -class NodeAppear(Variable["NodeId"]): +class NodeAppear(Variables["NodeId"]): r"""Binary variable indicating whether a node is the start of a track. (i.e., the node is selected and has no selected incoming edges). @@ -39,7 +38,7 @@ def instantiate(solver: Solver) -> Collection[NodeId]: return solver.graph.nodes @staticmethod - def instantiate_constraints(solver: Solver) -> Iterable[ilpy.Expression]: + def instantiate_constraints(solver: Solver) -> Iterable[Expression]: appear_indicators = solver.get_variables(NodeAppear) node_indicators = solver.get_variables(NodeSelected) edge_indicators = solver.get_variables(EdgeSelected) diff --git a/motile/variables/node_disappear.py b/motile/variables/node_disappear.py index 93e9683..6507e63 100644 --- a/motile/variables/node_disappear.py +++ b/motile/variables/node_disappear.py @@ -4,16 +4,15 @@ from .edge_selected import EdgeSelected from .node_selected import NodeSelected -from .variable import Variable +from .variable import Variables if TYPE_CHECKING: - import ilpy - + from motile.expressions import Expression from motile._types import NodeId from motile.solver import Solver -class NodeDisappear(Variable["NodeId"]): +class NodeDisappear(Variables["NodeId"]): r"""Binary variable to indicate whether a node disappears. This variable indicates whether the node is the end of a track (i.e., the node is @@ -39,7 +38,7 @@ def instantiate(solver: Solver) -> Collection[NodeId]: return solver.graph.nodes @staticmethod - def instantiate_constraints(solver: Solver) -> Iterable[ilpy.Expression]: + def instantiate_constraints(solver: Solver) -> Iterable[Expression]: node_indicators = solver.get_variables(NodeSelected) edge_indicators = solver.get_variables(EdgeSelected) disappear_indicators = solver.get_variables(NodeDisappear) diff --git a/motile/variables/node_selected.py b/motile/variables/node_selected.py index eb4dd1a..fbfa52f 100644 --- a/motile/variables/node_selected.py +++ b/motile/variables/node_selected.py @@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Collection -from .variable import Variable +from .variable import Variables if TYPE_CHECKING: from motile._types import NodeId from motile.solver import Solver -class NodeSelected(Variable["NodeId"]): +class NodeSelected(Variables["NodeId"]): """Binary variable indicating whether a node is part of the solution or not.""" @staticmethod diff --git a/motile/variables/node_split.py b/motile/variables/node_split.py index 4e8c1bb..291f5c3 100644 --- a/motile/variables/node_split.py +++ b/motile/variables/node_split.py @@ -2,17 +2,17 @@ from typing import TYPE_CHECKING, Collection, Iterable -import ilpy +from motile.expressions import Constant, Expression from .edge_selected import EdgeSelected -from .variable import Variable +from .variable import Variables if TYPE_CHECKING: from motile._types import NodeId from motile.solver import Solver -class NodeSplit(Variable): +class NodeSplit(Variables): r"""Binary variable indicating whether a node has more than one child. (i.e., the node is selected and has more than one selected outgoing edge). @@ -36,13 +36,11 @@ def instantiate(solver: Solver) -> Collection[NodeId]: return solver.graph.nodes @staticmethod - def instantiate_constraints(solver: Solver) -> Iterable[ilpy.Constraint]: + def instantiate_constraints(solver: Solver) -> Iterable[Expression]: split_indicators = solver.get_variables(NodeSplit) edge_indicators = solver.get_variables(EdgeSelected) for node in solver.graph.nodes: - next_edges = solver.graph.next_edges[node] - # Ensure that the following holds: # # split = 0 <=> sum(next_selected) <= 1 @@ -53,21 +51,16 @@ def instantiate_constraints(solver: Solver) -> Iterable[ilpy.Constraint]: # (1) 2 * split - sum(next_selected) <= 0 # (2) (num_next - 1) * split - sum(next_selected) >= -1 - constraint1 = ilpy.Constraint() - constraint2 = ilpy.Constraint() + c1: Expression = Constant(0) + c2: Expression = Constant(0) - constraint1.set_coefficient(split_indicators[node], 2.0) - constraint2.set_coefficient(split_indicators[node], len(next_edges) - 1.0) + next_edges = solver.graph.next_edges[node] + c1 += 2.0 * split_indicators[node] + c2 += (len(next_edges) - 1.0) * split_indicators[node] for next_edge in next_edges: - constraint1.set_coefficient(edge_indicators[next_edge], -1.0) - constraint2.set_coefficient(edge_indicators[next_edge], -1.0) - - constraint1.set_relation(ilpy.Relation.LessEqual) - constraint2.set_relation(ilpy.Relation.GreaterEqual) - - constraint1.set_value(0.0) - constraint2.set_value(-1.0) + c1 -= edge_indicators[next_edge] + c2 -= edge_indicators[next_edge] - yield constraint1 - yield constraint2 + yield c1 <= 0.0 + yield c2 >= -1.0 diff --git a/motile/variables/variable.py b/motile/variables/variable.py index 24ecb84..6da3491 100644 --- a/motile/variables/variable.py +++ b/motile/variables/variable.py @@ -12,7 +12,7 @@ TypeVar, ) -import ilpy +from motile.expressions import Expression, Variable, VariableType if TYPE_CHECKING: from motile.solver import Solver @@ -20,7 +20,7 @@ _KT = TypeVar("_KT", bound=Hashable) -class Variable(ABC, Mapping[_KT, ilpy.Variable]): +class Variables(ABC, Mapping[_KT, Variable]): """Base class for solver variables. New variables can be introduced by inheriting from this base class and @@ -53,7 +53,7 @@ class Variable(ABC, Mapping[_KT, ilpy.Variable]): """ # default variable type, replace in subclasses to override - variable_type: ClassVar[ilpy.VariableType] = ilpy.VariableType.Binary + variable_type: ClassVar[VariableType] = VariableType.Binary @staticmethod @abstractmethod @@ -90,9 +90,7 @@ def instantiate(solver): pass @staticmethod - def instantiate_constraints( - solver: Solver, - ) -> Iterable[ilpy.Constraint | ilpy.Expression]: + def instantiate_constraints(solver: Solver) -> Iterable[Expression]: """Add constraints for this variable to the solver. This ensures that these variables are coupled to other variables of the solver. @@ -127,9 +125,9 @@ def __repr__(self) -> str: rs.append(r) return "\n".join(rs) - def __getitem__(self, key: _KT) -> ilpy.Variable: + def __getitem__(self, key: _KT) -> Variable: name = f"{type(self).__name__}({key})" - return ilpy.Variable(name, index=self._index_map[key]) + return Variable(name, index=self._index_map[key]) def __iter__(self) -> Iterator[_KT]: return iter(self._index_map) diff --git a/tests/test_variables.py b/tests/test_variables.py index e90be51..aa5273c 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -2,14 +2,14 @@ from typing import Collection, Hashable, Iterable -import ilpy import pytest from motile import Solver, data -from motile.variables import Variable +from motile.expressions import Expression +from motile.variables import Variables -@pytest.mark.parametrize("VarCls", Variable.__subclasses__()) -def test_variable_subclass_protocols(VarCls: type[Variable]) -> None: +@pytest.mark.parametrize("VarCls", Variables.__subclasses__()) +def test_variable_subclass_protocols(VarCls: type[Variables]) -> None: """Test that all Variable subclasses properly implement the Variable protocol.""" solver = Solver(data.arlo_graph()) @@ -19,4 +19,4 @@ def test_variable_subclass_protocols(VarCls: type[Variable]) -> None: constraints = VarCls.instantiate_constraints(solver) assert isinstance(constraints, Iterable) - assert all(isinstance(c, (ilpy.Expression, ilpy.Constraint)) for c in constraints) + assert all(isinstance(c, Expression) for c in constraints)