diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index 8b67fe2..3293cf3 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,4 +1,4 @@ -from .compute_graph import get_candidate_graph +from .compute_graph import get_candidate_graph, get_candidate_graph_from_points_list from .graph_attributes import EdgeAttr, NodeAttr from .graph_to_nx import graph_to_nx from .iou import add_iou diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index 31994ff..2c911cc 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -6,7 +6,7 @@ from .conflict_sets import compute_conflict_sets from .iou import add_iou -from .utils import add_cand_edges, nodes_from_segmentation +from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation logger = logging.getLogger(__name__) @@ -58,3 +58,32 @@ def get_candidate_graph( conflicts.extend(compute_conflict_sets(segs, time)) return cand_graph, conflicts + + +def get_candidate_graph_from_points_list( + points_list: np.ndarray, + max_edge_distance: float, +) -> nx.DiGraph: + """Construct a candidate graph from a points list. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver. + Multiple hypotheses not supported for points input. + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_points_list(points_list) + logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + return cand_graph diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 61307ba..3450a98 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -1,14 +1,13 @@ import logging -import math from typing import Any, Iterable import networkx as nx import numpy as np -from skimage.measure import regionprops from scipy.spatial import KDTree +from skimage.measure import regionprops from tqdm import tqdm -from .graph_attributes import EdgeAttr, NodeAttr +from .graph_attributes import NodeAttr logger = logging.getLogger(__name__) @@ -82,6 +81,42 @@ def nodes_from_segmentation( return cand_graph, node_frame_dict +def nodes_from_points_list( + points_list: np.ndarray, +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a list of points. Uses the index of the + point in the list as its unique id. + Returns a networkx graph with only nodes, and also a dictionary from frames to + node_ids for efficient edge adding. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict: dict[int, list[Any]] = {} + print("Extracting nodes from points list") + for i, point in enumerate(points_list): + # assume t, [z], y, x + t = point[0] + pos = point[1:] + node_id = i + attrs = { + NodeAttr.TIME.value: t, + NodeAttr.POS.value: pos, + } + cand_graph.add_node(node_id, **attrs) + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node_id) + return cand_graph, node_frame_dict + + def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: """Compute dictionary from time frames to node ids for candidate graph. diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 77ce13d..19dcf92 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -1,7 +1,11 @@ from collections import Counter +import numpy as np import pytest from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph +from motile_toolbox.candidate_graph.compute_graph import ( + get_candidate_graph_from_points_list, +) def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): @@ -82,3 +86,18 @@ def test_graph_from_multi_segmentation_2d( assert Counter(list(cand_graph.edges)) == Counter( [("0_0_1", "1_0_2"), ("0_0_1", "1_1_2"), ("0_1_1", "1_1_2")] ) + + +def test_graph_from_points_list(): + points_list = np.array( + [ + [0, 1, 1, 1], + [2, 3, 3, 3], + [1, 2, 2, 2], + [2, 6, 6, 6], + [2, 1, 1, 1], + ] + ) + cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3) + assert cand_graph.number_of_edges() == 3 + assert len(cand_graph.in_edges(3)) == 0 diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index 9b4dda6..3bd96ac 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -2,15 +2,16 @@ import networkx as nx import numpy as np -import pytest from motile_toolbox.candidate_graph import ( - EdgeAttr, NodeAttr, add_cand_edges, get_node_id, nodes_from_segmentation, ) -from motile_toolbox.candidate_graph.utils import _compute_node_frame_dict +from motile_toolbox.candidate_graph.utils import ( + _compute_node_frame_dict, + nodes_from_points_list, +) # nodes_from_segmentation @@ -98,3 +99,19 @@ def test_compute_node_frame_dict(graph_2d): 1: ["1_1", "1_2"], } assert node_frame_dict == expected + + +def test_nodes_from_points_list_2d(): + points_list = np.array( + [ + [0, 1, 2, 3], + [2, 3, 4, 5], + [1, 2, 3, 4], + ] + ) + cand_graph, node_frame_dict = nodes_from_points_list(points_list) + assert Counter(list(cand_graph.nodes)) == Counter([0, 1, 2]) + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0 + assert (cand_graph.nodes[0][NodeAttr.POS.value] == np.array([1, 2, 3])).all() + assert cand_graph.nodes[1][NodeAttr.TIME.value] == 2 + assert (cand_graph.nodes[1][NodeAttr.POS.value] == np.array([3, 4, 5])).all()