diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py index 573697f1a..10ac27474 100644 --- a/src/spyglass/spikesorting/analysis/v1/group.py +++ b/src/spyglass/spikesorting/analysis/v1/group.py @@ -96,9 +96,10 @@ def filter_units( for ind, unit_labels in enumerate(labels): if isinstance(unit_labels, str): unit_labels = [unit_labels] - if np.all(~np.isin(unit_labels, include_labels)) or np.any( - np.isin(unit_labels, exclude_labels) - ): + if ( + include_labels.size > 0 + and np.all(~np.isin(unit_labels, include_labels)) + ) or np.any(np.isin(unit_labels, exclude_labels)): # if the unit does not have any of the include labels # or has any of the exclude labels, skip continue