Sc2 fixes #3250

Merged: 47 commits from yger:sc2_fixes into SpikeInterface:main on Oct 9, 2024.

Commits:
71821c2  Fixes (yger, Jul 19, 2024)
cc89e6e  Patches (yger, Jul 19, 2024)
62be466  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Jul 19, 2024)
4e4b869  Fixes for SC2 and for split clustering (yger, Jul 23, 2024)
f351640  debugging clustering (yger, Jul 24, 2024)
b38c6f2  WIP (yger, Jul 24, 2024)
95ec4fe  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 24, 2024)
5772b2d  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Jul 24, 2024)
8cd975e  WIP (yger, Jul 25, 2024)
dbb21f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 25, 2024)
0c91349  WIP (yger, Jul 30, 2024)
856bd4d  WIP (yger, Jul 31, 2024)
3f6f078  Default params (yger, Jul 31, 2024)
5ddbb91  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2024)
15bf498  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Aug 26, 2024)
485a269  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Aug 27, 2024)
37323ee  Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f… (yger, Aug 28, 2024)
39b0514  WIP (yger, Aug 28, 2024)
e71b7d2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 28, 2024)
f0a9417  Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa… (yger, Sep 10, 2024)
32c0b77  Merge branch 'main' into sc2_fixes (yger, Sep 25, 2024)
dcaa3fb  Adding gather_func to find_spikes (yger, Sep 25, 2024)
2067c37  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2024)
8552358  Merge branch 'main' of github.com:spikeinterface/spikeinterface into … (yger, Sep 26, 2024)
ca5a1f3  Merge remote-tracking branch 'origin/gather_find_spikes' into sc2_fixes (yger, Sep 26, 2024)
355e5f5  Gathering mode more explicit for matching (yger, Sep 26, 2024)
10c7a35  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2024)
728c9f1  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Oct 7, 2024)
6698843  WIP (yger, Oct 7, 2024)
e826830  WIP (yger, Oct 7, 2024)
7975c95  Merge branch 'node_pipeline_skip_no_peaks' of github.com:samuelgarcia… (yger, Oct 7, 2024)
affff49  Fixes for SC2 (yger, Oct 7, 2024)
6f4268d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 7, 2024)
4d39563  Merge branch 'SpikeInterface:main' into sc2_fixes (yger, Oct 8, 2024)
d83aca6  WIP (yger, Oct 8, 2024)
2ba923d  Simplifications (yger, Oct 8, 2024)
fd441a7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 8, 2024)
cf04170  Naming for Sam (yger, Oct 8, 2024)
258617e  Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f… (yger, Oct 8, 2024)
7a91086  Optimize circus matching engine (yger, Oct 8, 2024)
b428050  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 8, 2024)
1cc1578  Optimizations (yger, Oct 8, 2024)
cfaa98b  Remove the limit to chunk sizes in circus-omp-svd (yger, Oct 8, 2024)
40fd191  Naming (yger, Oct 9, 2024)
20fc26e  Merge branch 'main' into sc2_fixes (yger, Oct 9, 2024)
2e82342  Patch imports (yger, Oct 9, 2024)
077e58e  Merge branch 'sc2_fixes' of github.com:yger/spikeinterface into sc2_f… (yger, Oct 9, 2024)
Files changed:
11 changes: 7 additions & 4 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -36,7 +36,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
             "seed": 42,
         },
         "apply_motion_correction": True,
-        "motion_correction": {"preset": "nonrigid_fast_and_accurate"},
+        "motion_correction": {"preset": "dredge_fast"},
         "merging": {
             "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2},
             "correlograms_kwargs": {},
@@ -149,8 +149,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         ## We need to whiten before the template matching step, to boost the results
         # TODO add , regularize=True chen ready
-        recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)
-
+        if num_channels > 1:
+            regularize = True
+        else:
+            regularize = False
+        recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=regularize)
         noise_levels = get_noise_levels(recording_w, return_scaled=False)

         if recording_w.check_serializability("json"):
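For context, local whitening decorrelates neighbouring channels by applying the inverse square root of their covariance matrix; with a single channel that covariance is a 1x1 scalar, which is why the guard above only enables regularize when num_channels > 1. Below is a minimal NumPy sketch of the underlying math on invented data, not SpikeInterface's actual whiten() implementation:

import numpy as np

def whitening_matrix(X, eps=1e-8):
    # X: zero-mean traces, shape (num_samples, num_channels)
    cov = X.T @ X / X.shape[0]              # channel covariance matrix
    evals, evecs = np.linalg.eigh(cov)      # symmetric eigendecomposition
    return evecs @ np.diag(1.0 / np.sqrt(evals + eps)) @ evecs.T  # cov^(-1/2)

rng = np.random.default_rng(42)
X = rng.normal(size=(10_000, 4))
Xw = X @ whitening_matrix(X)
print(np.round(np.cov(Xw.T), 2))            # ~ identity: channels decorrelated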
@@ -196,7 +199,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## We subselect a subset of all the peaks, by making the distributions os SNRs over all
         ## channels as flat as possible
         selection_params = params["selection"]
-        selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
+        selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
         selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

         selection_params.update({"noise_levels": noise_levels})
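The clamp above is two-sided: the request is first capped at the number of detected peaks, then raised to the configured floor (so n_peaks can still exceed len(peaks) when detection yields very few). A toy trace of the logic, with invented values and parameter names mirroring the diff:

num_channels = 64
n_peaks_per_channel = 5000
min_n_peaks = 100_000
detected = 250_000                                           # stands in for len(peaks)

n_peaks = min(detected, n_peaks_per_channel * num_channels)  # min(250_000, 320_000) -> 250_000
n_peaks = max(min_n_peaks, n_peaks)                          # max(100_000, 250_000) -> 250_000
print(n_peaks)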
13 changes: 8 additions & 5 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -136,6 +136,7 @@ def main_function(cls, recording, peaks, params):
         pipeline_nodes = [node0, node1, node2]

         if len(params["recursive_kwargs"]) == 0:
+            from sklearn.decomposition import PCA

             all_pc_data = run_node_pipeline(
                 recording,
@@ -152,9 +153,9 @@ def main_function(cls, recording, peaks, params):
                 sub_data = sub_data.reshape(len(sub_data), -1)

                 if all_pc_data.shape[1] > params["n_svd"][1]:
-                    tsvd = TruncatedSVD(params["n_svd"][1])
+                    tsvd = PCA(params["n_svd"][1], whiten=True)
                 else:
-                    tsvd = TruncatedSVD(all_pc_data.shape[1])
+                    tsvd = PCA(all_pc_data.shape[1], whiten=True)

                 hdbscan_data = tsvd.fit_transform(sub_data)
                 try:
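Swapping TruncatedSVD for PCA(whiten=True) changes two things: the data are mean-centered before projection, and each retained component is rescaled to unit variance, putting the features on a comparable scale before clustering. A small self-contained comparison on invented data:

import numpy as np
from sklearn.decomposition import PCA, TruncatedSVD

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, size=(1000, 20))    # deliberately not zero-mean

Z_pca = PCA(n_components=5, whiten=True).fit_transform(X)
Z_svd = TruncatedSVD(n_components=5).fit_transform(X)

print(np.round(Z_pca.mean(axis=0), 2))      # ~0: PCA centers the data first
print(np.round(Z_pca.std(axis=0), 2))       # ~1: whiten=True equalizes variances
print(np.round(Z_svd.mean(axis=0), 2))      # first component absorbs the mean offset (sign arbitrary)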
@@ -184,14 +185,16 @@ def main_function(cls, recording, peaks, params):
             )

             sparse_mask = node1.neighbours_mask
-            neighbours_mask = get_channel_distances(recording) < radius_um
+            neighbours_mask = get_channel_distances(recording) <= radius_um

             # np.save(features_folder / "sparse_mask.npy", sparse_mask)
             np.save(features_folder / "peaks.npy", peaks)

             original_labels = peaks["channel_index"]
             from spikeinterface.sortingcomponents.clustering.split import split_clusters

+            min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)
+
             peak_labels, _ = split_clusters(
                 original_labels,
                 recording,
@@ -202,7 +205,7 @@ def main_function(cls, recording, peaks, params):
                 feature_name="sparse_tsvd",
                 neighbours_mask=neighbours_mask,
                 waveforms_sparse_mask=sparse_mask,
-                min_size_split=50,
+                min_size_split=min_size,
                 clusterer_kwargs=d["hdbscan_kwargs"],
                 n_pca_features=params["n_svd"][1],
                 scale_n_pca_by_depth=True,
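The < to <= change a few lines up (mirrored in peak_localization.py below) is an off-by-boundary fix: a channel sitting exactly at radius_um now counts as a neighbour. In miniature:

import numpy as np

distances = np.array([0.0, 50.0, 100.0, 150.0])  # invented channel distances (um)
radius_um = 100.0
print(distances < radius_um)                     # [ True  True False False]: boundary channel dropped
print(distances <= radius_um)                    # [ True  True  True False]: boundary channel kept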
@@ -233,7 +236,7 @@ def main_function(cls, recording, peaks, params):
         if d["rank"] is not None:
             from spikeinterface.sortingcomponents.matching.circus import compress_templates

-            _, _, _, templates_array = compress_templates(templates_array, 5)
+            _, _, _, templates_array = compress_templates(templates_array, d["rank"])

         templates = Templates(
             templates_array=templates_array,
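compress_templates now receives the configured d["rank"] instead of a hardcoded 5. As a rough illustration of rank-limited template compression, here is a standalone SVD sketch; the real compress_templates returns four values (per the unpacking above) and its internals may differ, so treat this as an assumption:

import numpy as np

def low_rank_compress(templates_array, rank):
    # templates_array: (num_units, num_samples, num_channels)
    compressed = np.empty_like(templates_array)
    for i, t in enumerate(templates_array):
        U, S, Vt = np.linalg.svd(t, full_matrices=False)
        compressed[i] = (U[:, :rank] * S[:rank]) @ Vt[:rank]  # keep the top `rank` components
    return compressed

templates_array = np.random.default_rng(1).normal(size=(8, 90, 16)).astype("float32")
print(low_rank_compress(templates_array, rank=5).shape)       # (8, 90, 16)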
15 changes: 8 additions & 7 deletions src/spikeinterface/sortingcomponents/clustering/split.py
@@ -24,7 +24,7 @@ def split_clusters(
     peak_labels,
     recording,
     features_dict_or_folder,
-    method="hdbscan_on_local_pca",
+    method="local_feature_clustering",
     method_kwargs={},
     recursive=False,
     recursive_depth=None,
@@ -69,7 +69,7 @@ def split_clusters(

     original_labels = peak_labels
     peak_labels = peak_labels.copy()
-    split_count = np.zeros(peak_labels.size, dtype=int)
+    split_count = np.ones(peak_labels.size, dtype=int)

     Executor = get_poolexecutor(n_jobs)

@@ -101,7 +101,6 @@ def split_clusters(
                 mask = local_labels >= 0
                 peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
                 peak_labels[peak_indices[~mask]] = local_labels[~mask]
-
                 split_count[peak_indices] += 1

                 current_max_label += np.max(local_labels[mask]) + 1
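For intuition on the relabeling block above, a toy run with invented labels: non-negative local labels are shifted by current_max_label so that labels produced by different splits never collide, while negative labels (unassigned or noise) pass through unchanged.

import numpy as np

current_max_label = 10
local_labels = np.array([0, 1, -1, 0])          # result of one local split
peak_indices = np.arange(4)
peak_labels = np.full(4, -10)                   # global labels being updated

mask = local_labels >= 0
peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
peak_labels[peak_indices[~mask]] = local_labels[~mask]
print(peak_labels)                              # [10 11 -1 10]
current_max_label += np.max(local_labels[mask]) + 1   # -> 12, ready for the next split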
@@ -120,6 +119,7 @@ def split_clusters(
                 for label in new_labels_set:
                     peak_indices = np.flatnonzero(peak_labels == label)
                     if peak_indices.size > 0:
+                        # print('Relaunched', label, len(peak_indices), recursion_level)
                         jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
                         if progress_bar:
                             iterator.total += 1
@@ -187,7 +187,7 @@ def split(
         min_size_split=25,
         n_pca_features=2,
         scale_n_pca_by_depth=False,
-        minimum_common_channels=2,
+        minimum_common_channels=1,
     ):
         local_labels = np.zeros(peak_indices.size, dtype=np.int64)

@@ -198,9 +198,8 @@ def split(

         # target channel subset is done intersect local channels + neighbours
         local_chans = np.unique(peaks["channel_index"][peak_indices])
-
         target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
-
+        # print(recursion_level, target_channels)
         # TODO fix this a better way, this when cluster have too few overlapping channels
         if target_channels.size < minimum_common_channels:
             return False, None
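For intuition on the intersection above: np.all(neighbours_mask[local_chans, :], axis=0) keeps only the channels that are neighbours of every channel occupied by the cluster's peaks, and lowering minimum_common_channels to 1 lets a split proceed even when that intersection shrinks to a single channel. A toy intersection with an invented 4-channel mask:

import numpy as np

neighbours_mask = np.array([
    [True,  True,  False, False],
    [True,  True,  True,  False],
    [False, True,  True,  True],
    [False, False, True,  True],
])
local_chans = np.array([0, 1])                   # channels this cluster's peaks sit on
target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
print(target_channels)                           # [0 1]: neighbours of BOTH local channels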
@@ -211,7 +210,7 @@ def split(

         local_labels[dont_have_channels] = -2
         kept = np.flatnonzero(~dont_have_channels)
-
+        # print(recursion_level, kept.size, min_size_split)
         if kept.size < min_size_split:
             return False, None

@@ -222,6 +221,8 @@ def split(
         if flatten_features.shape[1] > n_pca_features:
             from sklearn.decomposition import PCA

+            # from sklearn.decomposition import TruncatedSVD
+
             if scale_n_pca_by_depth:
                 # tsvd = TruncatedSVD(n_pca_features * recursion_level)
                 tsvd = PCA(n_pca_features * recursion_level, whiten=True)
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/peak_localization.py
@@ -131,7 +131,7 @@ def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
         self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs["radius_um"] = radius_um

     def get_dtype(self):