Merge pull request #3533 from yger/clustering_components_api
Clustering components api
samuelgarcia authored Jan 8, 2025
2 parents 9b022da + 33c1b1b commit 1c43ef1
Showing 14 changed files with 47 additions and 94 deletions.
9 changes: 3 additions & 6 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -52,7 +52,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.8},
"job_kwargs": {"n_jobs": 0.5},
"debug": False,
}

@@ -115,7 +115,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
from spikeinterface.sortingcomponents.tools import get_prototype_spike

job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})
@@ -219,7 +218,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["radius_um"] = radius_um
clustering_params["waveforms"]["ms_before"] = ms_before
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["job_kwargs"] = job_kwargs
clustering_params["noise_levels"] = noise_levels
clustering_params["ms_before"] = exclude_sweep_ms
clustering_params["ms_after"] = exclude_sweep_ms
@@ -233,7 +231,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_method = "random_projections"

labels, peak_labels = find_cluster_from_peaks(
recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params
recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs
)

## We get the labels for our peaks
@@ -284,11 +282,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_method = params["matching"].pop("method")
matching_params = params["matching"].copy()
matching_params["templates"] = templates
matching_job_params = job_kwargs.copy()

if matching_method is not None:
spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
recording_w, matching_method, method_kwargs=matching_params, **job_kwargs
)

if params["debug"]:
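Illustrative sketch (not part of this commit): with this change, the sorter resolves a single job_kwargs dict up front and forwards it explicitly to clustering and matching instead of embedding it in the clustering method_kwargs. The snippet assumes an existing recording and peaks array, placeholder parameter values, and that find_cluster_from_peaks is importable from the clustering subpackage and accepts extra keyword arguments as job kwargs, as the call above suggests.

from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks

# Resolve parallelism settings once (a float n_jobs is a fraction of cores, as in the sorter defaults).
job_kwargs = fix_job_kwargs({"n_jobs": 0.5, "progress_bar": True})

# method_kwargs no longer carries a "job_kwargs" entry; job kwargs are passed separately.
clustering_params = {"noise_levels": None}  # placeholder method kwargs
labels, peak_labels = find_cluster_from_peaks(
    recording, peaks, method="random_projections",
    method_kwargs=clustering_params, **job_kwargs,
)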
17 changes: 7 additions & 10 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -18,7 +18,6 @@
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
from spikeinterface.core.template import Templates
@@ -63,16 +62,13 @@ class CircusClustering:
"rank": 5,
"noise_levels": None,
"tmp_folder": None,
"job_kwargs": {},
"verbose": True,
}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed"

job_kwargs = fix_job_kwargs(params["job_kwargs"])

d = params
verbose = d["verbose"]

@@ -248,8 +244,10 @@ def main_function(cls, recording, peaks, params):
probe=recording.get_probe(),
is_scaled=False,
)

if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
@@ -260,13 +258,12 @@ def main_function(cls, recording, peaks, params):
if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

cleaning_matching_params = params["job_kwargs"].copy()
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["progress_bar"] = False
cleaning_job_kwargs = job_kwargs.copy()
cleaning_job_kwargs["progress_bar"] = False
cleaning_params = params["cleaning_kwargs"].copy()

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
)

if verbose:
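Note on the cleaning step above (and the matching one in random_projections.py below): the copied job kwargs previously forced n_jobs = 1 for the duplicate-removal matching; after this change only the progress bar is disabled and the caller's parallelism settings are kept. A condensed sketch of the resulting pattern inside a method body, with templates, peak_labels and params assumed from the surrounding code:

# inside main_function(cls, recording, peaks, params, job_kwargs=dict()):
if params["noise_levels"] is None:
    # noise estimation now also honours the shared job kwargs
    params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

cleaning_job_kwargs = job_kwargs.copy()
cleaning_job_kwargs["progress_bar"] = False  # silence the cleaning pass only
labels, peak_labels = remove_duplicates_via_matching(
    templates, peak_labels, job_kwargs=cleaning_job_kwargs, **params["cleaning_kwargs"]
)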
2 changes: 0 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/clean.py
@@ -32,7 +32,6 @@ def clean_clusters(
count = np.zeros(n, dtype="int64")
for i, label in enumerate(labels_set):
count[i] = np.sum(peak_labels == label)
print(count)

templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels)

@@ -42,6 +41,5 @@
max_values = -np.min(templates, axis=(1, 2))
elif peak_sign == "pos":
max_values = np.max(templates, axis=(1, 2))
print(max_values)

return clean_labels
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/dummy.py
@@ -13,7 +13,7 @@ class DummyClustering:
_default_params = {}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
labels = np.arange(recording.get_num_channels(), dtype="int64")
peak_labels = peaks["channel_index"]
return labels, peak_labels
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/main.py
@@ -41,7 +41,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={},
params = method_class._default_params.copy()
params.update(**method_kwargs)

outputs = method_class.main_function(recording, peaks, params)
outputs = method_class.main_function(recording, peaks, params, job_kwargs=job_kwargs)

if extra_outputs:
return outputs
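For component authors, main.py and dummy.py above show the API change: main_function now receives job_kwargs as an explicit argument instead of reading params["job_kwargs"]. A minimal sketch of a custom clustering component written against the new signature (hypothetical class, modelled on DummyClustering; not part of this commit):

import numpy as np

class MyClustering:
    """Hypothetical clustering component using the updated components API."""

    _default_params = {}

    @classmethod
    def main_function(cls, recording, peaks, params, job_kwargs=dict()):
        # Parallelism options (n_jobs, chunk_duration, progress_bar, ...) arrive in
        # job_kwargs, separately from the method parameters in params.
        labels = np.arange(recording.get_num_channels(), dtype="int64")
        peak_labels = peaks["channel_index"]
        return labels, peak_labels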
5 changes: 2 additions & 3 deletions src/spikeinterface/sortingcomponents/clustering/position.py
@@ -25,18 +25,17 @@ class PositionClustering:
"hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
"debug": False,
"tmp_folder": None,
"job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"},
}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed"
d = params

if d["peak_locations"] is None:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"])
peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **job_kwargs)
else:
peak_locations = d["peak_locations"]

@@ -42,23 +42,14 @@ class PositionAndFeaturesClustering:
"ms_before": 1.5,
"ms_after": 1.5,
"cleaning_method": "dip",
"job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
from sklearn.preprocessing import QuantileTransformer

assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed"

if "n_jobs" in params["job_kwargs"]:
if params["job_kwargs"]["n_jobs"] == -1:
params["job_kwargs"]["n_jobs"] = os.cpu_count()

if "core_dist_n_jobs" in params["hdbscan_kwargs"]:
if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1:
params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count()

d = params

peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
@@ -80,7 +71,7 @@ def main_function(cls, recording, peaks, params):
}

features_data = compute_features_from_peaks(
recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"]
recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **job_kwargs
)

hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32)
@@ -150,10 +141,10 @@ def main_function(cls, recording, peaks, params):
dtype=recording.get_dtype(),
sparsity_mask=None,
copy=True,
**params["job_kwargs"],
**job_kwargs,
)

noise_levels = get_noise_levels(recording, return_scaled=False)
noise_levels = get_noise_levels(recording, return_scaled=False, **job_kwargs)
labels, peak_labels = remove_duplicates(
wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"]
)
@@ -181,7 +172,7 @@ def main_function(cls, recording, peaks, params):
nbefore,
nafter,
return_scaled=False,
**params["job_kwargs"],
**job_kwargs,
)
templates = Templates(
templates_array=templates_array,
@@ -193,7 +184,7 @@
)

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"]
templates, peak_labels, job_kwargs=job_kwargs, **params["cleaning_kwargs"]
)
shutil.rmtree(tmp_folder)

@@ -38,7 +38,6 @@ class PositionAndPCAClustering:
"ms_after": 2.5,
"n_components_by_channel": 3,
"n_components": 5,
"job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
"hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
"hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
"waveform_mode": "shared_memory",
@@ -73,7 +72,7 @@ def _check_params(cls, recording, peaks, params):
return params2

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
# res = PositionClustering(recording, peaks, params)

assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed"
@@ -85,9 +84,7 @@ def main_function(cls, recording, peaks, params):
if params["peak_locations"] is None:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

peak_locations = localize_peaks(
recording, peaks, **params["peak_localization_kwargs"], **params["job_kwargs"]
)
peak_locations = localize_peaks(recording, peaks, **params["peak_localization_kwargs"], **job_kwargs)
else:
peak_locations = params["peak_locations"]

@@ -155,7 +152,7 @@ def main_function(cls, recording, peaks, params):
dtype=recording.get_dtype(),
sparsity_mask=sparsity_mask,
copy=(params["waveform_mode"] == "shared_memory"),
**params["job_kwargs"],
**job_kwargs,
)

noise = get_random_data_chunks(
@@ -222,7 +219,7 @@ def main_function(cls, recording, peaks, params):
dtype=recording.get_dtype(),
sparsity_mask=sparsity_mask3,
copy=(params["waveform_mode"] == "shared_memory"),
**params["job_kwargs"],
**job_kwargs,
)

clean_peak_labels, peak_sample_shifts = auto_clean_clustering(
@@ -26,7 +26,6 @@ class PositionPTPScaledClustering:
"ptps": None,
"scales": (1, 1, 10),
"peak_localization_kwargs": {"method": "center_of_mass"},
"job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
"hdbscan_kwargs": {
"min_cluster_size": 20,
"min_samples": 20,
@@ -38,7 +37,7 @@
}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed"
d = params

@@ -60,7 +59,7 @@

if d["ptps"] is None:
(ptps,) = compute_features_from_peaks(
recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **d["job_kwargs"]
recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **job_kwargs
)
else:
ptps = d["ptps"]
@@ -16,8 +16,7 @@
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
from spikeinterface.core.template import Templates
@@ -55,16 +54,13 @@ class RandomProjectionClustering:
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"tmp_folder": None,
"job_kwargs": {},
"verbose": True,
}

@classmethod
def main_function(cls, recording, peaks, params):
def main_function(cls, recording, peaks, params, job_kwargs=dict()):
assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed"

job_kwargs = fix_job_kwargs(params["job_kwargs"])

d = params
verbose = d["verbose"]

@@ -148,26 +144,20 @@ def main_function(cls, recording, peaks, params):
is_scaled=False,
)
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
templates = remove_empty_templates(templates)

if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

cleaning_matching_params = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in cleaning_matching_params:
cleaning_matching_params[value] = None
cleaning_matching_params["chunk_duration"] = "100ms"
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["progress_bar"] = False

cleaning_job_kwargs = job_kwargs.copy()
cleaning_job_kwargs["progress_bar"] = False
cleaning_params = params["cleaning_kwargs"].copy()

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
)

if verbose: