diff --git a/navis/graph/graph_utils.py b/navis/graph/graph_utils.py index 6cff299b..c2adc88d 100644 --- a/navis/graph/graph_utils.py +++ b/navis/graph/graph_utils.py @@ -340,9 +340,9 @@ def _edge_count_to_root_old(x: 'core.TreeNeuron') -> dict: @utils.map_neuronlist(desc='Classifying', allow_parallel=True) @utils.lock_neuron -def classify_nodes(x: 'core.NeuronObject', - inplace: bool = True - ) -> Optional['core.NeuronObject']: +def _classify_nodes_old(x: 'core.NeuronObject', + inplace: bool = True + ) -> Optional['core.NeuronObject']: """Classify neuron's nodes into end nodes, branches, slabs or root. Adds ``'type'`` column to ``x.nodes``. @@ -414,6 +414,70 @@ def classify_nodes(x: 'core.NeuronObject', return x + +@utils.map_neuronlist(desc='Classifying', allow_parallel=True) +@utils.lock_neuron +def classify_nodes(x: 'core.NeuronObject', + categorical=True, + inplace: bool = True + ) -> Optional['core.NeuronObject']: + """Classify neuron's nodes into end nodes, branches, slabs or root. + + Adds ``'type'`` column to ``x.nodes`` table. + + Parameters + ---------- + x : TreeNeuron | NeuronList + Neuron(s) whose nodes to classify. + categorical : bool + If True (default), will use categorical data type which takes + up much less memory at a small run-time overhead. + inplace : bool, optional + If ``False``, nodes will be classified on a copy which is then + returned leaving the original neuron unchanged. + + Returns + ------- + TreeNeuron/List + + Examples + -------- + >>> import navis + >>> nl = navis.example_neurons(2) + >>> _ = navis.graph.classify_nodes(nl, inplace=True) + + """ + if not inplace: + x = x.copy() + + if not isinstance(x, core.TreeNeuron): + raise TypeError(f'Expected TreeNeuron(s), got "{type(x)}"') + + # At this point x is TreeNeuron + x: core.TreeNeuron + + # Make sure there are nodes to classify + if not x.nodes.empty: + x.nodes['type'] = 'slab' + x.nodes.loc[~x.nodes.node_id.isin(x.nodes.parent_id), 'type'] = 'end' + bp = x.nodes.parent_id.value_counts() + bp = bp[bp > 1].index.values + x.nodes.loc[x.nodes.node_id.isin(bp), 'type'] = 'branch' + x.nodes.loc[x.nodes.parent_id < 0, 'type'] = 'root' + else: + x.nodes['type'] = None + + # Turn into categorical data - saves tons of memory + # Note that we have to make sure all categories are set even if they + # don't exist (e.g. if a neuron has no branch points) + if categorical: + cat_types = CategoricalDtype(categories=["end", "branch", "root", "slab"], + ordered=False) + x.nodes['type'] = x.nodes['type'].astype(cat_types) + + return x + + # only this combination will return a single bool @overload def distal_to(x: 'core.TreeNeuron',