Sc2 fixes #3250

Merged · 47 commits · Oct 9, 2024
Commits
71821c2
Fixes
yger Jul 19, 2024
cc89e6e
Patches
yger Jul 19, 2024
62be466
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Jul 19, 2024
4e4b869
Fixes for SC2 and for split clustering
yger Jul 23, 2024
f351640
debugging clustering
yger Jul 24, 2024
b38c6f2
WIP
yger Jul 24, 2024
95ec4fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
5772b2d
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Jul 24, 2024
8cd975e
WIP
yger Jul 25, 2024
dbb21f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
0c91349
WIP
yger Jul 30, 2024
856bd4d
WIP
yger Jul 31, 2024
3f6f078
Default params
yger Jul 31, 2024
5ddbb91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2024
15bf498
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Aug 26, 2024
485a269
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Aug 27, 2024
37323ee
Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f…
yger Aug 28, 2024
39b0514
WIP
yger Aug 28, 2024
e71b7d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
f0a9417
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Sep 10, 2024
32c0b77
Merge branch 'main' into sc2_fixes
yger Sep 25, 2024
dcaa3fb
Adding gather_func to find_spikes
yger Sep 25, 2024
2067c37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2024
8552358
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Sep 26, 2024
ca5a1f3
Merge remote-tracking branch 'origin/gather_find_spikes' into sc2_fixes
yger Sep 26, 2024
355e5f5
Gathering mode more explicit for matching
yger Sep 26, 2024
10c7a35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
728c9f1
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Oct 7, 2024
6698843
WIP
yger Oct 7, 2024
e826830
WIP
yger Oct 7, 2024
7975c95
Merge branch 'node_pipeline_skip_no_peaks' of github.com:samuelgarcia…
yger Oct 7, 2024
affff49
Fixes for SC2
yger Oct 7, 2024
6f4268d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
4d39563
Merge branch 'SpikeInterface:main' into sc2_fixes
yger Oct 8, 2024
d83aca6
WIP
yger Oct 8, 2024
2ba923d
Simplifications
yger Oct 8, 2024
fd441a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
cf04170
Naming for Sam
yger Oct 8, 2024
258617e
Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f…
yger Oct 8, 2024
7a91086
Optimize circus matching engine
yger Oct 8, 2024
b428050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
1cc1578
Optimizations
yger Oct 8, 2024
cfaa98b
Remove the limit to chunk sizes in circus-omp-svd
yger Oct 8, 2024
40fd191
Naming
yger Oct 9, 2024
20fc26e
Merge branch 'main' into sc2_fixes
yger Oct 9, 2024
2e82342
Patch imports
yger Oct 9, 2024
077e58e
Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f…
yger Oct 9, 2024
41 changes: 18 additions & 23 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -24,9 +24,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"whitening": {"mode": "local", "regularize": True},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
"method": "uniform",
@@ -36,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"seed": 42,
},
"apply_motion_correction": True,
"motion_correction": {"preset": "nonrigid_fast_and_accurate"},
"motion_correction": {"preset": "dredge_fast"},
"merging": {
"similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2},
"correlograms_kwargs": {},
@@ -46,7 +47,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
},
},
"clustering": {"legacy": True},
"matching": {"method": "wobble"},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
@@ -62,6 +63,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
and also the radius_um used to be considered during clustering",
"sparsity": "A dictionary to be passed to all the calls to sparsify the templates",
"filtering": "A dictionary for the high_pass filter to be used during preprocessing",
"whitening": "A dictionary for the whitening option to be used during preprocessing",
"detection": "A dictionary for the peak detection node (locally_exclusive)",
"selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\
and 5000 peaks per electrode on average.",
@@ -109,8 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
from spikeinterface.sortingcomponents.tools import get_prototype_spike

job_kwargs = params["job_kwargs"]
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
@@ -119,7 +120,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
num_channels = recording.get_num_channels()
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 100)
radius_um = params["general"].get("radius_um", 75)
exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

## First, we are filtering the data
@@ -143,14 +144,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["motion_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["motion_correction"])
recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
else:
motion_folder = None

## We need to whiten before the template matching step, to boost the results
# TODO add , regularize=True chen ready
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)
whitening_kwargs = params["whitening"].copy()
whitening_kwargs["dtype"] = "float32"
whitening_kwargs["radius_um"] = radius_um
if num_channels == 1:
whitening_kwargs["regularize"] = False

recording_w = whiten(recording_f, **whitening_kwargs)
noise_levels = get_noise_levels(recording_w, return_scaled=False)

if recording_w.check_serializability("json"):
@@ -172,20 +178,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

if params["matched_filtering"]:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before

for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in detection_params:
detection_params.pop(value)

detection_params["chunk_duration"] = "100ms"

peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
else:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

if verbose:
print("We found %d peaks in total" % len(peaks))
@@ -196,7 +196,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We subselect a subset of all the peaks, by making the distributions os SNRs over all
## channels as flat as possible
selection_params = params["selection"]
selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

selection_params.update({"noise_levels": noise_levels})
@@ -281,11 +281,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_job_params = job_kwargs.copy()

if matching_method is not None:
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
)
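For orientation, here is a minimal sketch of how the updated defaults above could be overridden from user code. It assumes the standard run_sorter entry point; the exact name of the output-folder argument varies between SpikeInterface versions, and the folder name used here is purely illustrative.

from spikeinterface.sorters import run_sorter

# Sketch only: `recording` is assumed to be an already-loaded recording extractor.
# The nested dictionaries mirror the _default_params shown in the diff above.
sorting = run_sorter(
    "spykingcircus2",
    recording,
    folder="sc2_output",                                  # illustrative output location
    general={"ms_before": 2, "ms_after": 2, "radius_um": 75},
    whitening={"mode": "local", "regularize": True},      # new "whitening" entry introduced by this PR
    matching={"method": "circus-omp-svd"},                # new default matching engine
    motion_correction={"preset": "dredge_fast"},          # new default motion-correction preset
    apply_motion_correction=True,
)
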
18 changes: 8 additions & 10 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -136,6 +136,7 @@ def main_function(cls, recording, peaks, params):
pipeline_nodes = [node0, node1, node2]

if len(params["recursive_kwargs"]) == 0:
from sklearn.decomposition import PCA

all_pc_data = run_node_pipeline(
recording,
@@ -152,9 +153,9 @@ def main_function(cls, recording, peaks, params):
sub_data = sub_data.reshape(len(sub_data), -1)

if all_pc_data.shape[1] > params["n_svd"][1]:
tsvd = TruncatedSVD(params["n_svd"][1])
tsvd = PCA(params["n_svd"][1], whiten=True)
else:
tsvd = TruncatedSVD(all_pc_data.shape[1])
tsvd = PCA(all_pc_data.shape[1], whiten=True)

hdbscan_data = tsvd.fit_transform(sub_data)
try:
@@ -184,14 +185,16 @@ def main_function(cls, recording, peaks, params):
)

sparse_mask = node1.neighbours_mask
neighbours_mask = get_channel_distances(recording) < radius_um
neighbours_mask = get_channel_distances(recording) <= radius_um

# np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters

min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)

peak_labels, _ = split_clusters(
original_labels,
recording,
Expand All @@ -202,7 +205,7 @@ def main_function(cls, recording, peaks, params):
feature_name="sparse_tsvd",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=50,
min_size_split=min_size,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=params["n_svd"][1],
scale_n_pca_by_depth=True,
@@ -233,7 +236,7 @@ def main_function(cls, recording, peaks, params):
if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates

_, _, _, templates_array = compress_templates(templates_array, 5)
_, _, _, templates_array = compress_templates(templates_array, d["rank"])

templates = Templates(
templates_array=templates_array,
@@ -258,13 +261,8 @@ def main_function(cls, recording, peaks, params):
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

cleaning_matching_params = params["job_kwargs"].copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in cleaning_matching_params:
cleaning_matching_params.pop(value)
cleaning_matching_params["chunk_duration"] = "100ms"
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()

labels, peak_labels = remove_duplicates_via_matching(
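The clustering change above swaps TruncatedSVD for a whitened PCA before HDBSCAN, and reuses the requested rank when compressing templates. Below is a small self-contained sketch of the reduction step; the array sizes and variable names are illustrative, not taken from the PR.

import numpy as np
from sklearn.decomposition import PCA

n_svd = 5                                        # upper bound on the number of kept components
sub_data = np.random.randn(1000, 3 * 32)         # (n_peaks, n_samples * n_channels), toy data
n_components = min(n_svd, sub_data.shape[1])

# whiten=True rescales each kept component to unit variance before clustering
tsvd = PCA(n_components, whiten=True)
hdbscan_data = tsvd.fit_transform(sub_data)      # (n_peaks, n_components)
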
22 changes: 13 additions & 9 deletions src/spikeinterface/sortingcomponents/clustering/split.py
@@ -24,7 +24,7 @@ def split_clusters(
peak_labels,
recording,
features_dict_or_folder,
method="hdbscan_on_local_pca",
method="local_feature_clustering",
method_kwargs={},
recursive=False,
recursive_depth=None,
@@ -81,7 +81,6 @@
) as pool:
labels_set = np.setdiff1d(peak_labels, [-1])
current_max_label = np.max(labels_set) + 1

jobs = []
for label in labels_set:
peak_indices = np.flatnonzero(peak_labels == label)
@@ -95,15 +94,14 @@

for res in iterator:
is_split, local_labels, peak_indices = res.result()
# print(is_split, local_labels, peak_indices)
if not is_split:
continue

mask = local_labels >= 0
peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
peak_labels[peak_indices[~mask]] = local_labels[~mask]

split_count[peak_indices] += 1

current_max_label += np.max(local_labels[mask]) + 1

if recursive:
@@ -120,6 +118,7 @@
for label in new_labels_set:
peak_indices = np.flatnonzero(peak_labels == label)
if peak_indices.size > 0:
# print('Relaunched', label, len(peak_indices), recursion_level)
jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
if progress_bar:
iterator.total += 1
@@ -187,7 +186,7 @@ def split(
min_size_split=25,
n_pca_features=2,
scale_n_pca_by_depth=False,
minimum_common_channels=2,
minimum_overlap_ratio=0.25,
):
local_labels = np.zeros(peak_indices.size, dtype=np.int64)

@@ -199,19 +198,22 @@
# target channel subset is done intersect local channels + neighbours
local_chans = np.unique(peaks["channel_index"][peak_indices])

target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0))
num_intersection = len(target_intersection_channels)
num_union = len(target_union_channels)

# TODO fix this a better way, this when cluster have too few overlapping channels
if target_channels.size < minimum_common_channels:
if (num_intersection / num_union) < minimum_overlap_ratio:
return False, None

aligned_wfs, dont_have_channels = aggregate_sparse_features(
peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels
peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels
)

local_labels[dont_have_channels] = -2
kept = np.flatnonzero(~dont_have_channels)

# print(recursion_level, kept.size, min_size_split)
if kept.size < min_size_split:
return False, None

@@ -222,6 +224,8 @@
if flatten_features.shape[1] > n_pca_features:
from sklearn.decomposition import PCA

# from sklearn.decomposition import TruncatedSVD

if scale_n_pca_by_depth:
# tsvd = TruncatedSVD(n_pca_features * recursion_level)
tsvd = PCA(n_pca_features * recursion_level, whiten=True)
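The change to split() above replaces the fixed minimum number of channels common to all local channels with a test on the ratio between the intersection and the union of their neighbourhoods. Here is a standalone sketch of that test; the mask and channel indices are made up for illustration, and the zero-union guard is added only to keep the toy example safe.

import numpy as np

neighbours_mask = np.random.rand(16, 16) < 0.4   # (n_channels, n_channels) boolean neighbourhood mask (toy)
local_chans = np.array([2, 3, 5])                # channels carrying the peaks of the cluster being split
minimum_overlap_ratio = 0.25

# channels in every local channel's neighbourhood vs. channels in at least one of them
target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0))
num_intersection = len(target_intersection_channels)
num_union = len(target_union_channels)

if num_union == 0 or (num_intersection / num_union) < minimum_overlap_ratio:
    print("overlap too small, cluster is not split")
else:
    print("enough overlap, proceed with the local split")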