Skip to content

Commit

Permalink
Merge pull request #14 from funkelab/graph_from_nodes
Browse files Browse the repository at this point in the history
Create graph from nodes list
  • Loading branch information
cmalinmayor authored Jun 18, 2024
2 parents 478b38d + 24bf2fe commit b8c2df9
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .compute_graph import get_candidate_graph
from .compute_graph import get_candidate_graph, get_candidate_graph_from_points_list
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_to_nx import graph_to_nx
from .iou import add_iou
Expand Down
31 changes: 30 additions & 1 deletion src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .conflict_sets import compute_conflict_sets
from .iou import add_iou
from .utils import add_cand_edges, nodes_from_segmentation
from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,3 +58,32 @@ def get_candidate_graph(
conflicts.extend(compute_conflict_sets(segs, time))

return cand_graph, conflicts


def get_candidate_graph_from_points_list(
points_list: np.ndarray,
max_edge_distance: float,
) -> nx.DiGraph:
"""Construct a candidate graph from a points list.
Args:
points_list (np.ndarray): An NxD numpy array with N points and D
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
Returns:
nx.DiGraph: A candidate graph that can be passed to the motile solver.
Multiple hypotheses not supported for points input.
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_points_list(points_list)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")
# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
return cand_graph
41 changes: 38 additions & 3 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging
import math
from typing import Any, Iterable

import networkx as nx
import numpy as np
from skimage.measure import regionprops
from scipy.spatial import KDTree
from skimage.measure import regionprops
from tqdm import tqdm

from .graph_attributes import EdgeAttr, NodeAttr
from .graph_attributes import NodeAttr

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,6 +81,42 @@ def nodes_from_segmentation(
return cand_graph, node_frame_dict


def nodes_from_points_list(
points_list: np.ndarray,
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a list of points. Uses the index of the
point in the list as its unique id.
Returns a networkx graph with only nodes, and also a dictionary from frames to
node_ids for efficient edge adding.
Args:
points_list (np.ndarray): An NxD numpy array with N points and D
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
Returns:
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
and a mapping from time frames to node ids.
"""
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict: dict[int, list[Any]] = {}
print("Extracting nodes from points list")
for i, point in enumerate(points_list):
# assume t, [z], y, x
t = point[0]
pos = point[1:]
node_id = i
attrs = {
NodeAttr.TIME.value: t,
NodeAttr.POS.value: pos,
}
cand_graph.add_node(node_id, **attrs)
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].append(node_id)
return cand_graph, node_frame_dict


def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]:
"""Compute dictionary from time frames to node ids for candidate graph.
Expand Down
19 changes: 19 additions & 0 deletions tests/test_candidate_graph/test_compute_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from collections import Counter

import numpy as np
import pytest
from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph
from motile_toolbox.candidate_graph.compute_graph import (
get_candidate_graph_from_points_list,
)


def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
Expand Down Expand Up @@ -82,3 +86,18 @@ def test_graph_from_multi_segmentation_2d(
assert Counter(list(cand_graph.edges)) == Counter(
[("0_0_1", "1_0_2"), ("0_0_1", "1_1_2"), ("0_1_1", "1_1_2")]
)


def test_graph_from_points_list():
points_list = np.array(
[
[0, 1, 1, 1],
[2, 3, 3, 3],
[1, 2, 2, 2],
[2, 6, 6, 6],
[2, 1, 1, 1],
]
)
cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3)
assert cand_graph.number_of_edges() == 3
assert len(cand_graph.in_edges(3)) == 0
23 changes: 20 additions & 3 deletions tests/test_candidate_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

import networkx as nx
import numpy as np
import pytest
from motile_toolbox.candidate_graph import (
EdgeAttr,
NodeAttr,
add_cand_edges,
get_node_id,
nodes_from_segmentation,
)
from motile_toolbox.candidate_graph.utils import _compute_node_frame_dict
from motile_toolbox.candidate_graph.utils import (
_compute_node_frame_dict,
nodes_from_points_list,
)


# nodes_from_segmentation
Expand Down Expand Up @@ -98,3 +99,19 @@ def test_compute_node_frame_dict(graph_2d):
1: ["1_1", "1_2"],
}
assert node_frame_dict == expected


def test_nodes_from_points_list_2d():
points_list = np.array(
[
[0, 1, 2, 3],
[2, 3, 4, 5],
[1, 2, 3, 4],
]
)
cand_graph, node_frame_dict = nodes_from_points_list(points_list)
assert Counter(list(cand_graph.nodes)) == Counter([0, 1, 2])
assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0
assert (cand_graph.nodes[0][NodeAttr.POS.value] == np.array([1, 2, 3])).all()
assert cand_graph.nodes[1][NodeAttr.TIME.value] == 2
assert (cand_graph.nodes[1][NodeAttr.POS.value] == np.array([3, 4, 5])).all()

0 comments on commit b8c2df9

Please sign in to comment.