From f9ff667188d2b70026dcf2267c367a0e92a9ce4d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:03:06 +0100 Subject: [PATCH 01/90] Add for 'set_files'. --- src/spikeinterface/sorters/external/kilosort4.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..92bfabbe73 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -205,7 +205,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + + filename, data_dir, results_dir, probe = set_files( + settings=settings, + filename=filename, + probe=probe, + probe_name=probe_name, + data_dir=data_dir, + results_dir=results_dir, + ) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( From aaa389f78243ac5c40c89f78cf282e69b591aebe Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:06:30 +0100 Subject: [PATCH 02/90] Add for 'initialize_ops'. 
--- .../sorters/external/kilosort4.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 92bfabbe73..b723e7a2bb 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -216,12 +216,27 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocessed_copy=False, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) else: - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) From 3ea9b8da6c9c0ce3672fc2b60d551cbfa96f8552 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:07:08 +0100 Subject: [PATCH 03/90] Add for 'compute_preprocessing'. 
--- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index b723e7a2bb..d8b1f1a60a 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -243,7 +243,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( From 28656425eac96ef7be1573256c316c23b057f1c5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:07:43 +0100 Subject: [PATCH 04/90] Add for 'compute_drift_correction'. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index d8b1f1a60a..d187b445ef 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -278,7 +278,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) # Sort spikes and save results From 9e0207aed2f92424e5d8d8088ce6c95de286eb38 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:11:42 +0100 Subject: [PATCH 05/90] Add for detect_spikes, cluster_spikes, save_sorting. 
--- .../sorters/external/kilosort4.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index d187b445ef..032f980ee2 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -282,14 +282,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + + clu, Wall = cluster_spikes( + st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + ) + if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + ) @classmethod def _get_result_from_folder(cls, sorter_output_folder): From b07359ff360a210ae864fbf43c23d805b9507300 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:13:46 +0100 Subject: [PATCH 06/90] Add for 'load_probe', 'RecordingExtractorAsArray'. 
--- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 032f980ee2..ba1b10b793 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -176,12 +176,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # load probe recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - probe = load_probe(probe_filename) + probe = load_probe(probe_path=probe_filename) probe_name = "" filename = "" # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording) + file_object = RecordingExtractorAsArray(recording_extractor=recording) do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] From ac844e9d550624f007f851c1cc061e5c36abb002 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:47:22 +0100 Subject: [PATCH 07/90] Add for BinaryFiltered + some generate notes. --- .../sorters/external/kilosort4.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index ba1b10b793..47ef328b28 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -17,6 +17,7 @@ class Kilosort4Sorter(BaseSorter): requires_locations = True gpu_capability = "nvidia-optional" + # Q: Should we take these directly from the KS defaults? 
https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 _default_params = { "batch_size": 60000, "nblocks": 1, @@ -25,8 +26,8 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, - "shift": None, - "scale": None, + "shift": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 + "scale": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 "artifact_threshold": None, "nskip": 25, "whitening_range": 32, @@ -247,16 +248,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( - ops["filename"], - n_chan_bin, - fs, - NT, - nt, - twav_min, - chan_map, + filename=ops["filename"], + n_chan_bin=n_chan_bin, + fs=fs, + nT=NT, + nt=nt, + nt0min=twav_min, + chan_map=chan_map, hp_filter=None, device=device, - do_CAR=do_CAR, + do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? invert_sign=invert, dtype=dtype, tmin=tmin, From 44835bb397a36ebfd914a6c2a8038bf3727b95e3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:50:12 +0100 Subject: [PATCH 08/90] Update note on DEFAULT_SETTINGS. --- src/spikeinterface/sorters/external/kilosort4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 47ef328b28..bcd8ddc617 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -18,6 +18,8 @@ class Kilosort4Sorter(BaseSorter): gpu_capability = "nvidia-optional" # Q: Should we take these directly from the KS defaults? 
https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 + # I see these overwrite the `DEFAULT_SETTINGS`. Do we want to do this? There is benefit to fixing on the SI side, but users switching KS version would expect + # the defaults to represent the KS version. This could lead to divergence in result between users running KS directly vs. the SI wrapper. _default_params = { "batch_size": 60000, "nblocks": 1, From 5bdc31e1ac6f2b3ecde2f2d428f4bae306dacfb3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:59:16 +0100 Subject: [PATCH 09/90] Remove some TODO and notes. --- src/spikeinterface/sorters/external/kilosort4.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index bcd8ddc617..cba7e65517 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -17,9 +17,6 @@ class Kilosort4Sorter(BaseSorter): requires_locations = True gpu_capability = "nvidia-optional" - # Q: Should we take these directly from the KS defaults? https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 - # I see these overwrite the `DEFAULT_SETTINGS`. Do we want to do this? There is benefit to fixing on the SI side, but users switching KS version would expect - # the defaults to represent the KS version. This could lead to divergence in result between users running KS directly vs. the SI wrapper. _default_params = { "batch_size": 60000, "nblocks": 1, @@ -28,8 +25,8 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, - "shift": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 - "scale": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. 
Need to distinguish version +/ 4.0.9 + "shift": None, + "scale": None, "artifact_threshold": None, "nskip": 25, "whitening_range": 32, From dc848eb2f8691206826d9545927d9cf28fbcd558 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 02:04:22 +0100 Subject: [PATCH 10/90] Use version to handle all KS versions some which are missing .__version__ attribute. --- .../sorters/external/kilosort4.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index cba7e65517..ed41baeff9 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version PathType = Union[str, Path] @@ -129,9 +130,8 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - import kilosort as ks - - return ks.__version__ + """kilosort version <0.0.10 is always '4' z""" + return version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @@ -216,6 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + # TODO: save_preprocessed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -225,9 +226,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): device=device, save_preprocessed_copy=False, ) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) else: ops = initialize_ops( settings=settings, @@ -237,6 +235,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): invert_sign=invert_sign, device=device, ) + + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): + # TODO: shift, scaled 
added + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) + else: n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) @@ -259,10 +264,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? invert_sign=invert, dtype=dtype, - tmin=tmin, + tmin=tmin, # TODO: exposing tmin, max? tmax=tmax, artifact_threshold=artifact, - file_object=file_object, + file_object=file_object, # TODO: exposing shift, scale when skipping preprocessing? ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) From 69e72bf0ddfafe42577959e35eac96f184acb727 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 02:06:20 +0100 Subject: [PATCH 11/90] Remove unused vars that were left over I think from prev KS versions. --- src/spikeinterface/sorters/external/kilosort4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index ed41baeff9..9320022a20 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -94,11 +94,9 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. 
A value of 0 disables this step. Default value: 7.", - "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", - "scaleproc": "int16 scaling of whitened data, if None set to 200.", "torch_device": "Select the torch device auto/cuda/cpu", } From c3b2bdda3d2f2f1db009302529c6c9b50a3781b9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 12:19:16 +0100 Subject: [PATCH 12/90] Use importlib version instead of .__version__ --- src/spikeinterface/sorters/external/kilosort4.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 9320022a20..65f1483348 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,7 +6,6 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from importlib.metadata import version PathType = Union[str, Path] @@ -129,7 +128,10 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): """kilosort version <0.0.10 is always '4' z""" - return version("kilosort") + # Note this import clashes with version! + from importlib.metadata import version as importlib_version + + return importlib_version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): From 52457224b0c724e5c0ee4f5d1e659ae7c3159b91 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 12:28:49 +0100 Subject: [PATCH 13/90] Add kilosort test script and CI workflow. 
--- .github/workflows/test_kilosort4.yml | 61 +++ .../temp_test_file_dir/test_kilosort4_new.py | 472 ++++++++++++++++++ 2 files changed, 533 insertions(+) create mode 100644 .github/workflows/test_kilosort4.yml create mode 100644 src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml new file mode 100644 index 0000000000..8e57f79786 --- /dev/null +++ b/.github/workflows/test_kilosort4.yml @@ -0,0 +1,61 @@ +name: Testing Kilosort4 + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + +# env: +# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} +# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + +# concurrency: # Cancel previous workflows on the same pull request +# group: ${{ github.workflow }}-${{ github.ref }} +# cancel-in-progress: true + +jobs: + run: + name: ${{ matrix.os }} Python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support + os: [ubuntu-latest] # TODO: macos-13, windows-latest, + ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR + steps: + - uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install packages + # TODO: maybe dont need full? 
+ run: | + pip install -e .[test] + # git config --global user.email "CI@example.com" + # git config --global user.name "CI Almighty" + # pip install tabulate + shell: bash + + - name: Install Kilosort + run: | + pip install kilosort==${{ matrix.ks_version }} + shell: bash + + - name: Run new kilosort4 tests + # run: chmod +x .github/test_kilosort4.sh + # TODO: figure out the paths to be able to run this by calling the file directly + run: | + pytest -k test_kilosort4_new --durations=0 + shell: bash + +# TODO: pip install -e .[full,dev] is failing # +#The conflict is caused by: +# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs" +# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test" diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py new file mode 100644 index 0000000000..0fb9841728 --- /dev/null +++ b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py @@ -0,0 +1,472 @@ +import copy +from typing import Any +import spikeinterface.full as si +import numpy as np +import torch +import kilosort +from kilosort.io import load_probe +import pandas as pd + +import pytest +from probeinterface.io import write_prb +from kilosort.parameters import DEFAULT_SETTINGS +from packaging.version import parse +from importlib.metadata import version + +# TODO: duplicate_spike_bins to duplicate_spike_ms +# TODO: write an issue on KS about bin! vs bin_ms! +# TODO: expose tmin and tmax +# TODO: expose save_preprocessed_copy +# TODO: make here a log of all API changes (or on kilosort4.py) +# TODO: try out longer recordings and do some benchmarking tests.. 
+# TODO: expose tmin and tmax +# There is no way to skip HP spatial filter +# might as well expose tmin and tmax +# might as well expose preprocessing save (across the two functions that use it) +# BinaryFilter added scale and shift as new arguments recently +# test with docker +# test all params once +# try and read func / class object to see kwargs +# Shift and scale are also taken as a function on BinaryFilter. Do we want to apply these even when +# do kilosort preprocessing is false? probably +# TODO: find a test case for the other annoying ones (larger recording, variable amplitude) +# TODO: test docker +# TODO: test multi-segment recording +# TODO: test do correction, skip preprocessing +# TODO: can we rename 'save_extra_kwargs' to 'save_extra_vars'. Currently untested. +# nt : # TODO: can't kilosort figure this out from sampling rate? +# TODO: also test runtimes +# TODO: test skip preprocessing separately +# TODO: the pure default case is not tested +# TODO: shift and scale - this is also added to BinaryFilter + +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # TODO: ignore some of these +# "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. + + +PARAMS_TO_TEST = [ + # Not tested + # ("torch_device", "auto") + # Stable across KS version 4.0.01 - 4.0.12 + ("change_nothing", None), + ("nblocks", 0), + ("do_CAR", False), + ("batch_size", 42743), # Q: how much do these results change with batch size? 
+ ("Th_universal", 12), + ("Th_learned", 14), + ("invert_sign", True), + ("nt", 93), + ("nskip", 1), + ("whitening_range", 16), + ("sig_interp", 5), + ("nt0min", 25), + ("dmin", 15), + ("dminx", 16), + ("min_template_size", 15), + ("template_sizes", 10), + ("nearest_chans", 8), + ("nearest_templates", 35), + ("max_channel_distance", 5), + ("templates_from_data", False), + ("n_templates", 10), + ("n_pcs", 3), + ("Th_single_ch", 4), + ("acg_threshold", 0.001), + ("x_centers", 5), + ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS + ("binning_depth", 1), + ("artifact_threshold", 200), + ("ccg_threshold", 1e9), + ("cluster_downsampling", 1e9), + ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! +] + +# Update PARAMS_TO_TEST with version-dependent kwargs +if parse(version("kilosort")) >= parse("4.0.12"): + pass # TODO: expose? +# PARAMS_TO_TEST.extend( +# [ +# ("save_preprocessed_copy", False), +# ] +# ) +if parse(version("kilosort")) >= parse("4.0.11"): + PARAMS_TO_TEST.extend( + [ + ("shift", 1e9), + ("scale", -1e9), + ] + ) +if parse(version("kilosort")) == parse("4.0.9"): + # bug in 4.0.9 for "nblocks=0" + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] + +if parse(version("kilosort")) >= parse("4.0.8"): + PARAMS_TO_TEST.extend( + [ + ("drift_smoothing", [250, 250, 250]), + ] + ) +if parse(version("kilosort")) <= parse("4.0.6"): + # AFAIK this parameter was always unused in KS (that's why it was removed) + PARAMS_TO_TEST.extend( + [ + ("cluster_pcs", 1e9), + ] + ) +if parse(version("kilosort")) <= parse("4.0.3"): + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] + + +class TestKilosort4Long: + + # Fixtures ###### + @pytest.fixture(scope="session") + def recording_and_paths(self, tmp_path_factory): + """ """ + tmp_path = tmp_path_factory.mktemp("kilosort4_tests") + + np.random.seed(0) # TODO: 
check below... + + recording = self._get_ground_truth_recording() + + paths = self._save_ground_truth_recording(recording, tmp_path) + + return (recording, paths) + + @pytest.fixture(scope="session") + def default_results(self, recording_and_paths): + """ """ + recording, paths = recording_and_paths + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths) + + defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=defaults_ks_output_dir, + ) + + default_results = self._get_sorting_output(defaults_ks_output_dir) + + return default_results + + # Tests ###### + def test_params_to_test(self): + """ + Test that all parameters in PARAMS_TO_TEST are + different than the default value used in Kilosort, otherwise + there is no point to the test. + + TODO: need to use _default_params vs. DEFAULT_SETTINGS + depending on decision + + TODO: write issue on this, we hope it will be on DEFAULT_SETTINGS + TODO: duplicate_spike_ms in POSTPROCESSING but seems unused? + """ + for parameter in PARAMS_TO_TEST: + + param_key, param_value = parameter + + if param_key == "change_nothing": + continue + + if param_key not in RUN_KILOSORT_ARGS: + assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + + def test_default_settings_all_represented(self): + """ + Test that every entry in DEFAULT_SETTINGS is tested in + PARAMS_TO_TEST, otherwise we are missing settings added + on the KS side. + """ + tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + + for param_key in DEFAULT_SETTINGS: + + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." 
+ + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) + def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): + """ """ + recording, paths = recording_and_paths + param_key, param_value = parameter + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + extra_ks_settings = {} + if param_key == "binning_depth": + extra_ks_settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + extra_ks_settings.update({param_key: param_value}) + run_kilosort_kwargs = {} + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + **run_kilosort_kwargs, + ) + + extra_si_settings = {} + if param_key != "change_nothing": + extra_si_settings.update({param_key: param_value}) + + if param_key == "binning_depth": + extra_si_settings.update({"nblocks": 5}) + + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + + assert all( + results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0] + ), f"{param_key} cluster assignment different" + assert all( + results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1] + ), f"{param_key} cluster quality different" # TODO: check pandas probably better way + + # This is saved on the SI side so not an extremely + # robust addition, but it can't hurt. 
+ if param_key != "change_nothing": + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value + + # Finally, check out test parameters actually changes stuff! + if parse(version("kilosort")) > parse("4.0.4"): + self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key) + + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): + """ """ + recording, paths = recording_and_paths + + kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lost of copying here + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=True, + ) + + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6}) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + + assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]) + assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]) + + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): + """ """ + recording = self._get_ground_truth_recording() + + # We need to filter and whiten the recording here to KS takes forever. + # Do this in a way differnt to KS. 
+ recording = si.highpass_filter(recording, 300) + recording = si.whiten(recording, mode="local", apply_mean=False) + + paths = self._save_ground_truth_recording(recording, tmp_path) + + kilosort_default_output_dir = tmp_path / "kilosort_default_output_dir" + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_default_output_dir, + do_CAR=False, + ) + + # Now the tricky bit, we need to turn off preprocessing in kilosort. + # This is not exposed by run_kilosort() arguments (at 4.0.12 at least) + # and so we need to monkeypatch the internal functions. The easiest + # thing to do would be to set `get_highpass_filter()` and + # `get_whitening_matrix()` to return `None` so these steps are skipped + # in BinaryFilter. Unfortunately the ops saving machinery requires + # these to be torch arrays and will error otherwise, so instead + # we must set the filter (in frequency space) and whitening matrix + # to unity operations so the filter and whitening do nothing. It is + # also required to turn off motion correection to avoid some additional + # magic KS is doing at the whitening step when motion correction is on. 
+ fake_filter = np.ones(60122, dtype="float32") # TODO: hard coded + fake_filter = torch.from_numpy(fake_filter).to("cpu") + + fake_white_matrix = np.eye(recording.get_num_channels(), dtype="float32") + fake_white_matrix = torch.from_numpy(fake_white_matrix).to("cpu") + + def fake_fft_highpass(*args, **kwargs): + return fake_filter + + def fake_get_whitening_matrix(*args, **kwargs): + return fake_white_matrix + + def fake_fftshift(X, dim): + return X + + monkeypatch.setattr("kilosort.io.fft_highpass", fake_fft_highpass) + monkeypatch.setattr("kilosort.preprocessing.get_whitening_matrix", fake_get_whitening_matrix) + monkeypatch.setattr("kilosort.io.fftshift", fake_fftshift) + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=False, + ) + + monkeypatch.undo() + + # Now, run kilosort through spikeinterface with the same options. + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0}) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=False, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) + + default_results = self._get_sorting_output(kilosort_default_output_dir) + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + # Check that out intervention actually make some difference to KS output + # (or this test would do nothing). Then check SI and KS outputs with + # preprocessing skipped are identical. 
+ assert not np.array_equal(default_results["ks"]["st"], results["ks"]["st"]) + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + + # Helpers ###### + def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): + """ """ + if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: + num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size + num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size + + if param_key == "change_nothing": + # TODO: lol + assert ( + (results["si"]["st"].size == default_results["ks"]["st"].size) + and num_clus == num_clus_default + and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} changed somehow!." + else: + assert ( + (results["si"]["st"].size != default_results["ks"]["st"].size) + or num_clus != num_clus_default + or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} results did not change with parameter change." + + def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None): + """ """ + # dont actually run KS here because we will overwrite the defaults! + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } + + if extra_settings is not None: + settings.update(extra_settings) + + ks_format_probe = load_probe(paths["probe_path"]) + + return settings, ks_format_probe + + def _get_spikeinterface_settings(self, extra_settings=None): + """ """ + # dont actually run here. 
+ settings = copy.deepcopy(DEFAULT_SETTINGS) + + if extra_settings is not None: + settings.update(extra_settings) + + for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # TODO: check tmin and tmax + settings.pop(name) + + return settings + + def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: + """ """ + results = { + "si": {}, + "ks": {}, + } + if kilosort_output_dir: + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = pd.read_table(kilosort_output_dir / "cluster_group.tsv") + + if spikeinterface_output_dir: + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = pd.read_table(spikeinterface_output_dir / "sorter_output" / "cluster_group.tsv") + + return results + + def _get_ground_truth_recording(self): + """ """ + # Chosen so all parameter changes to indeed change the output + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording + + def _save_ground_truth_recording(self, recording, tmp_path): + """ """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } + + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths From ede9dd482163728901dd118973c86d946ffd5f16 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 14:20:43 +0100 Subject: [PATCH 14/90] Fix save_preprocesed copy, argument mispelled. 
--- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 65f1483348..449ddfbff1 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -216,7 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - # TODO: save_preprocessed_copy added + # TODO: save_preprocesed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -224,7 +224,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, invert_sign=invert_sign, device=device, - save_preprocessed_copy=False, + save_preprocesed_copy=False, ) else: ops = initialize_ops( diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py index 0fb9841728..e4d48a1344 100644 --- a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py +++ b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py @@ -13,6 +13,7 @@ from packaging.version import parse from importlib.metadata import version +# TODO: save_preprocesed_copy is misspelled in KS4. # TODO: duplicate_spike_bins to duplicate_spike_ms # TODO: write an issue on KS about bin! vs bin_ms! 
# TODO: expose tmin and tmax From 9570c9273b4f86bd800120c6d05096ebfc82e85d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 14:36:55 +0100 Subject: [PATCH 15/90] Fix NT format for BinaryFiltered, double-check all again --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 449ddfbff1..28a3c3ffa3 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -255,7 +255,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): filename=ops["filename"], n_chan_bin=n_chan_bin, fs=fs, - nT=NT, + NT=NT, nt=nt, nt0min=twav_min, chan_map=chan_map, From a8489a50a0d4ccaa1c6e75307b73fcae7a8c4bc2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 15:25:08 +0100 Subject: [PATCH 16/90] Add CI to test all kilosort4 versions. --- .github/scripts/README.MD | 2 + .github/scripts/check_kilosort4_releases.py | 20 ++++ .../scripts/test_kilosort4_ci.py | 106 +++++++++++++++++- .github/workflows/test_kilosort4.yml | 63 ++++++----- conftest.py | 7 +- 5 files changed, 170 insertions(+), 28 deletions(-) create mode 100644 .github/scripts/README.MD create mode 100644 .github/scripts/check_kilosort4_releases.py rename src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py => .github/scripts/test_kilosort4_ci.py (83%) diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. 
diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..3d04d6948a --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,20 @@ +import os +import re +from pathlib import Path +import requests +import json + + +def get_pypi_versions(package_name): + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + return list(sorted(data["releases"].keys())) + + +if __name__ == "__main__": + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + json.dump(versions, f) diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/.github/scripts/test_kilosort4_ci.py similarity index 83% rename from src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py rename to .github/scripts/test_kilosort4_ci.py index e4d48a1344..4684038bd0 100644 --- a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -12,6 +12,14 @@ from kilosort.parameters import DEFAULT_SETTINGS from packaging.version import parse from importlib.metadata import version +from inspect import signature +from kilosort.run_kilosort import (set_files, initialize_ops, + compute_preprocessing, + compute_drift_correction, detect_spikes, + cluster_spikes, save_sorting, + get_run_parameters, ) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered +from kilosort.parameters import DEFAULT_SETTINGS # TODO: save_preprocesed_copy is misspelled in KS4. # TODO: duplicate_spike_bins to duplicate_spike_ms @@ -190,6 +198,102 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." 
+ def test_set_files_arguments(self): + self._check_arguments( + set_files, + ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + ) + + def test_initialize_ops_arguments(self): + + expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] + + if parse(version("kilosort")) >= parse("4.0.12"): + expected_arguments.append("save_preprocesed_copy") + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments( + compute_preprocessing, + ["ops", "device", "tic0", "file_object"] + ) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, + ["ops", "device", "tic0", "progress_bar", "file_object"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments( + detect_spikes, + ["ops", "device", "bfile", "tic0", "progress_bar"] + ) + + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, + ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_save_sorting_arguments(self): + + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + if parse(version("kilosort")) > parse("4.0.11"): + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments( + save_sorting, + expected_arguments + ) + + def test_get_run_parameters(self): + self._check_arguments( + get_run_parameters, + ["ops"] + ) + + def test_load_probe_parameters(self): + self._check_arguments( + load_probe, + ["probe_path"] + ) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments( + RecordingExtractorAsArray, + ["recording_extractor"] + ) + + def test_binary_filtered_arguments(self): + + expected_arguments = [ + "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", + "chan_map", "hp_filter", "whiten_mat", "dshift", + "device", "do_CAR", "artifact_threshold", 
"invert_sign", + "dtype", "tmin", "tmax", "file_object" + ] + + if parse(version("kilosort")) >= parse("4.0.11"): + expected_arguments.pop(-1) + expected_arguments.extend(["shift", "scale", "file_object"]) + + self._check_arguments( + BinaryFiltered, + expected_arguments + ) + + def _check_arguments(self, object_, expected_arguments): + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): """ """ @@ -381,7 +485,7 @@ def fake_fftshift(X, dim): # Helpers ###### def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): """ """ - if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: + if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]: num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 8e57f79786..c216be20d0 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -9,38 +9,56 @@ on: branches: - main -# env: -# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} -# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} +jobs: + versions: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 -# concurrency: # Cancel previous workflows on the same pull request -# group: ${{ github.workflow }}-${{ github.ref }} -# cancel-in-progress: true + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 -jobs: - run: - name: ${{ matrix.os }} Python ${{ matrix.python-version }} + - name: Install 
dependencies + run: | + pip install requests + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support - os: [ubuntu-latest] # TODO: macos-13, windows-latest, - ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: - - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install packages - # TODO: maybe dont need full? 
+ - name: Install SpikeInterface run: | pip install -e .[test] - # git config --global user.email "CI@example.com" - # git config --global user.name "CI Almighty" - # pip install tabulate shell: bash - name: Install Kilosort @@ -49,13 +67,6 @@ jobs: shell: bash - name: Run new kilosort4 tests - # run: chmod +x .github/test_kilosort4.sh - # TODO: figure out the paths to be able to run this by calling the file directly run: | - pytest -k test_kilosort4_new --durations=0 + pytest .github/scripts/test_kilosort4_ci.py shell: bash - -# TODO: pip install -e .[full,dev] is failing # -#The conflict is caused by: -# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs" -# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test" diff --git a/conftest.py b/conftest.py index c4bac6628a..8c06830d25 100644 --- a/conftest.py +++ b/conftest.py @@ -19,6 +19,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -28,7 +29,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: # TODO: make a note on this, check with Herberto its okay. + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: From 159e2b0a92b87ebaddedbf12cc68062bd0e5e5eb Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 01:33:48 +0100 Subject: [PATCH 17/90] Tidying up tests and removing comments from kilosort4.py. 
--- .github/scripts/test_kilosort4_ci.py | 442 ++++++++++-------- conftest.py | 2 +- .../sorters/external/kilosort4.py | 14 +- 3 files changed, 247 insertions(+), 211 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 4684038bd0..8a455a41fe 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -1,3 +1,23 @@ +""" +This file tests the SpikeInterface wrapper of the Kilosort4. The general logic +of the tests are: +- Change every exposed parameter one at a time (PARAMS_TO_TEST). Check that + the result of the SpikeInterface wrapper and Kilosort run natively are + the same. The SpikeInterface wrapper is non-trivial and decomposes the + kilosort pipeline to allow additions such as skipping preprocessing. Therefore, + the idea is that is it safer to rely on the output directly rather than + try monkeypatching. One thing can could be better tested is parameter + changes when skipping KS4 preprocessing is true, because this takes a slightly + different path through the kilosort4.py wrapper logic. + This also checks that changing the parameter changes the test output from default + on our test case (otherwise, the test could not detect a failure). This is possible + for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + +- Test that kilosort functions called from `kilosort4.py` wrapper have the expected + input signatures + +- Do some tests to check all KS4 parameters are tested against. +""" import copy from typing import Any import spikeinterface.full as si @@ -20,47 +40,21 @@ get_run_parameters, ) from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered from kilosort.parameters import DEFAULT_SETTINGS +from kilosort import preprocessing as ks_preprocessing -# TODO: save_preprocesed_copy is misspelled in KS4. -# TODO: duplicate_spike_bins to duplicate_spike_ms -# TODO: write an issue on KS about bin! vs bin_ms! 
-# TODO: expose tmin and tmax -# TODO: expose save_preprocessed_copy -# TODO: make here a log of all API changes (or on kilosort4.py) -# TODO: try out longer recordings and do some benchmarking tests.. -# TODO: expose tmin and tmax -# There is no way to skip HP spatial filter -# might as well expose tmin and tmax -# might as well expose preprocessing save (across the two functions that use it) -# BinaryFilter added scale and shift as new arguments recently -# test with docker -# test all params once -# try and read func / class object to see kwargs -# Shift and scale are also taken as a function on BinaryFilter. Do we want to apply these even when -# do kilosort preprocessing is false? probably -# TODO: find a test case for the other annoying ones (larger recording, variable amplitude) -# TODO: test docker -# TODO: test multi-segment recording -# TODO: test do correction, skip preprocessing -# TODO: can we rename 'save_extra_kwargs' to 'save_extra_vars'. Currently untested. -# nt : # TODO: can't kilosort figure this out from sampling rate? -# TODO: also test runtimes -# TODO: test skip preprocessing separately -# TODO: the pure default case is not tested -# TODO: shift and scale - this is also added to BinaryFilter - -RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # TODO: ignore some of these +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. - +# Setup Params to test #### PARAMS_TO_TEST = [ # Not tested # ("torch_device", "auto") + # Stable across KS version 4.0.01 - 4.0.12 ("change_nothing", None), ("nblocks", 0), ("do_CAR", False), - ("batch_size", 42743), # Q: how much do these results change with batch size? 
+ ("batch_size", 42743), ("Th_universal", 12), ("Th_learned", 14), ("invert_sign", True), @@ -80,14 +74,15 @@ ("n_templates", 10), ("n_pcs", 3), ("Th_single_ch", 4), - ("acg_threshold", 0.001), ("x_centers", 5), - ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS ("binning_depth", 1), + # Note: These don't change the results from + # default when applied to the test case. ("artifact_threshold", 200), - ("ccg_threshold", 1e9), - ("cluster_downsampling", 1e9), - ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! + ("ccg_threshold", 1e12), + ("acg_threshold", 1e12), + ("cluster_downsampling", 2), + ("duplicate_spike_bins", 5), ] # Update PARAMS_TO_TEST with version-dependent kwargs @@ -131,11 +126,13 @@ class TestKilosort4Long: # Fixtures ###### @pytest.fixture(scope="session") def recording_and_paths(self, tmp_path_factory): - """ """ + """ + Create a ground-truth recording, and save it to binary + so KS4 can run on it. Fixture is set up once and shared between + all tests. + """ tmp_path = tmp_path_factory.mktemp("kilosort4_tests") - np.random.seed(0) # TODO: check below... - recording = self._get_ground_truth_recording() paths = self._save_ground_truth_recording(recording, tmp_path) @@ -144,10 +141,17 @@ def recording_and_paths(self, tmp_path_factory): @pytest.fixture(scope="session") def default_results(self, recording_and_paths): - """ """ + """ + Because we check each parameter at a time and check the + KS4 and SpikeInterface versions match, if changing the parameter + had no effect as compared to default then the test would not test + anything. Therefore, the default results are run once and stored, + to check changing params indeed changes the results during testing. + This is possibly for nearly all parameters. 
+ """ recording, paths = recording_and_paths - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" @@ -162,18 +166,46 @@ def default_results(self, recording_and_paths): return default_results - # Tests ###### - def test_params_to_test(self): + def _get_ground_truth_recording(self): """ - Test that all parameters in PARAMS_TO_TEST are - different than the default value used in Kilosort, otherwise - there is no point to the test. + A ground truth recording chosen to be as small as possible (for speed). + But contain enough information so that changing most parameters + changes the results. + """ + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording - TODO: need to use _default_params vs. DEFAULT_SETTINGS - depending on decision + def _save_ground_truth_recording(self, recording, tmp_path): + """ + Save the recording and its probe to file, so it can be + loaded by KS4. + """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } - TODO: write issue on this, we hope it will be on DEFAULT_SETTINGS - TODO: duplicate_spike_ms in POSTPROCESSING but seems unused? + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths + + # Tests ###### + def test_params_to_test(self): + """ + Test that all values in PARAMS_TO_TEST are + different to the default values used in Kilosort, + otherwise there is no point to the test. 
""" for parameter in PARAMS_TO_TEST: @@ -198,6 +230,7 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + # Testing Arguments ### def test_set_files_arguments(self): self._check_arguments( set_files, @@ -205,7 +238,6 @@ def test_set_files_arguments(self): ) def test_initialize_ops_arguments(self): - expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] if parse(version("kilosort")) >= parse("4.0.12"): @@ -234,7 +266,6 @@ def test_detect_spikes_arguments(self): ["ops", "device", "bfile", "tic0", "progress_bar"] ) - def test_cluster_spikes_arguments(self): self._check_arguments( cluster_spikes, @@ -242,7 +273,6 @@ def test_cluster_spikes_arguments(self): ) def test_save_sorting_arguments(self): - expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] if parse(version("kilosort")) > parse("4.0.11"): @@ -272,7 +302,6 @@ def test_recording_extractor_as_array_arguments(self): ) def test_binary_filtered_arguments(self): - expected_arguments = [ "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", "chan_map", "hp_filter", "whiten_mat", "dshift", @@ -294,27 +323,23 @@ def _check_arguments(self, object_, expected_arguments): obj_arguments = list(sig.parameters.keys()) assert expected_arguments == obj_arguments + # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): - """ """ + """ + Given a recording, paths to raw data, and a parameter to change, + run KS4 natively and within the SpikeInterface wrapper with the + new parameter value (all other values default) and + check the outputs are the same. 
+ """ recording, paths = recording_and_paths param_key, param_value = parameter + # Setup parameters for KS4 and run it natively kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - extra_ks_settings = {} - if param_key == "binning_depth": - extra_ks_settings.update({"nblocks": 5}) - - if param_key in RUN_KILOSORT_ARGS: - run_kilosort_kwargs = {param_key: param_value} - else: - if param_key != "change_nothing": - extra_ks_settings.update({param_key: param_value}) - run_kilosort_kwargs = {} - - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings) + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) kilosort.run_kilosort( settings=settings, @@ -324,14 +349,9 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet **run_kilosort_kwargs, ) - extra_si_settings = {} - if param_key != "change_nothing": - extra_si_settings.update({param_key: param_value}) + # Setup Parameters for SI and KS4 through SI + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) - if param_key == "binning_depth": - extra_si_settings.update({"nblocks": 5}) - - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings) si.run_sorter( "kilosort4", recording, @@ -340,36 +360,41 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet **spikeinterface_settings, ) + # Get the results and check they match results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" - assert all( - results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0] - ), 
f"{param_key} cluster assignment different" - assert all( - results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1] - ), f"{param_key} cluster quality different" # TODO: check pandas probably better way - - # This is saved on the SI side so not an extremely - # robust addition, but it can't hurt. + # Check the ops file in KS4 output is as expected. This is saved on the + # SI side so not an extremely robust addition, but it can't hurt. if param_key != "change_nothing": ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) ops = ops.tolist() # strangely this makes a dict assert ops[param_key] == param_value - # Finally, check out test parameters actually changes stuff! + # Finally, check out test parameters actually change the output of + # KS4, ensuring our tests are actually doing something. This is not + # done prior to 4.0.4 because a number of parameters seem to stop + # having an effect. This is probably due to small changes in their + # behaviour, and the test file chosen here. if parse(version("kilosort")) > parse("4.0.4"): - self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key) + self._check_test_parameters_are_changing_the_output(results, default_results, param_key) def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): - """ """ + """ + Test the SpikeInterface wrappers `do_correction` argument. We set + `nblocks=0` for KS4 native, turning off motion correction. Then + we run KS$ through SpikeInterface with `do_correction=False` but + `nblocks=1` (KS4 default) - checking that `do_correction` overrides + this and the result matches KS4 when run without motion correction. 
+ """ recording, paths = recording_and_paths - kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lost of copying here + kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "nblocks", 0) kilosort.run_kilosort( settings=settings, @@ -379,7 +404,7 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): do_CAR=True, ) - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6}) + spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) si.run_sorter( "kilosort4", recording, @@ -392,186 +417,199 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + @pytest.mark.parametrize("param_to_test", [ + ("change_nothing", None), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ]) + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): + """ + Test that skipping KS4 preprocessing works as expected. Run + KS4 natively, monkeypatching the relevant preprocessing functions + such that preprocessing is not performed. Then run in SpikeInterface + with `skip_kilosort_preprocessing=True` and check the outputs match. + + Run with a few randomly chosen parameters to check these are propagated + under this condition. 
- assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]) - assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]) + TODO + ---- + It would be nice to check a few additional parameters here. Screw it! + """ + param_key, param_value = param_to_test - def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): - """ """ recording = self._get_ground_truth_recording() # We need to filter and whiten the recording here to KS takes forever. - # Do this in a way differnt to KS. + # Do this in a way different to KS. recording = si.highpass_filter(recording, 300) recording = si.whiten(recording, mode="local", apply_mean=False) paths = self._save_ground_truth_recording(recording, tmp_path) - kilosort_default_output_dir = tmp_path / "kilosort_default_output_dir" kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] - kilosort.run_kilosort( - settings=ks_settings, - probe=ks_format_probe, - data_dtype="float32", - results_dir=kilosort_default_output_dir, - do_CAR=False, - ) + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass - # Now the tricky bit, we need to turn off preprocessing in kilosort. - # This is not exposed by run_kilosort() arguments (at 4.0.12 at least) - # and so we need to monkeypatch the internal functions. 
The easiest - # thing to do would be to set `get_highpass_filter()` and - # `get_whitening_matrix()` to return `None` so these steps are skipped - # in BinaryFilter. Unfortunately the ops saving machinery requires - # these to be torch arrays and will error otherwise, so instead - # we must set the filter (in frequency space) and whitening matrix - # to unity operations so the filter and whitening do nothing. It is - # also required to turn off motion correection to avoid some additional - # magic KS is doing at the whitening step when motion correction is on. - fake_filter = np.ones(60122, dtype="float32") # TODO: hard coded - fake_filter = torch.from_numpy(fake_filter).to("cpu") - - fake_white_matrix = np.eye(recording.get_num_channels(), dtype="float32") - fake_white_matrix = torch.from_numpy(fake_white_matrix).to("cpu") - - def fake_fft_highpass(*args, **kwargs): - return fake_filter - - def fake_get_whitening_matrix(*args, **kwargs): - return fake_white_matrix - - def fake_fftshift(X, dim): + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass return X - monkeypatch.setattr("kilosort.io.fft_highpass", fake_fft_highpass) - monkeypatch.setattr("kilosort.preprocessing.get_whitening_matrix", fake_get_whitening_matrix) - monkeypatch.setattr("kilosort.io.fftshift", fake_fftshift) + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", + monkeypatch_filter_function) + + ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + ks_settings["nblocks"] = 0 + + # Be explicit here and don't rely on defaults. 
+ do_CAR = param_value if param_key == "do_CAR" else False kilosort.run_kilosort( settings=ks_settings, probe=ks_format_probe, data_dtype="float32", results_dir=kilosort_output_dir, - do_CAR=False, + do_CAR=do_CAR, ) monkeypatch.undo() # Now, run kilosort through spikeinterface with the same options. - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0}) + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + spikeinterface_settings["nblocks"] = 0 + + do_CAR = False if param_key != "do_CAR" else spikeinterface_settings.pop("do_CAR") + si.run_sorter( "kilosort4", recording, remove_existing_folder=True, folder=spikeinterface_output_dir, - do_CAR=False, + do_CAR=do_CAR, skip_kilosort_preprocessing=True, **spikeinterface_settings, ) - default_results = self._get_sorting_output(kilosort_default_output_dir) + # There is a very slight difference caused by the batching between load vs. + # memory file. Because in this test recordings are preprocessed, there are + # some filter edge effects that depend on the chunking in `get_traces()`. + # These are all extremely close (usually just 1 spike, 1 idx different). results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - - # Check that out intervention actually make some difference to KS output - # (or this test would do nothing). Then check SI and KS outputs with - # preprocessing skipped are identical. 
- assert not np.array_equal(default_results["ks"]["st"], results["ks"]["st"]) - assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) # Helpers ###### - def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): - """ """ - if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]: - num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size - num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size + def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): + """ + If nothing is changed, default vs. results outputs are identical. + Otherwise, check they are not the same. Can't figure out how to get + the skipped three parameters below to change the results on this + small test file. + """ + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling"]: + return + + if param_key == "change_nothing": + assert all( + default_results["ks"]["st"] == results["ks"]["st"] + ) and all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} changed somehow!." + else: + assert not ( + default_results["ks"]["st"].size == results["ks"]["st"].size + ) or not all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} results did not change with parameter change." - if param_key == "change_nothing": - # TODO: lol - assert ( - (results["si"]["st"].size == default_results["ks"]["st"].size) - and num_clus == num_clus_default - and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) - ), f"{param_key} changed somehow!." 
- else: - assert ( - (results["si"]["st"].size != default_results["ks"]["st"].size) - or num_clus != num_clus_default - or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) - ), f"{param_key} results did not change with parameter change." - - def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None): - """ """ - # dont actually run KS here because we will overwrite the defaults! + def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): + """ + Function to generate the settings and function inputs to run kilosort. + Note when `binning_depth` is used we need to set `nblocks` high to + get the results to change from default. + + Some settings in KS4 are passed by `settings` dict while others + are through the function, these are split here. + """ settings = { "data_dir": paths["recording_path"], "n_chan_bin": recording.get_num_channels(), "fs": recording.get_sampling_frequency(), } - if extra_settings is not None: - settings.update(extra_settings) + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"]) - return settings, ks_format_probe + return settings, run_kilosort_kwargs, ks_format_probe - def _get_spikeinterface_settings(self, extra_settings=None): - """ """ - # dont actually run here. + def _get_spikeinterface_settings(self, param_key, param_value): + """ + Generate settings kwargs for running KS4 in SpikeInterface. + See `_get_kilosort_native_settings()` for some details. 
+ """ settings = copy.deepcopy(DEFAULT_SETTINGS) - if extra_settings is not None: - settings.update(extra_settings) + if param_key != "change_nothing": + settings.update({param_key: param_value}) + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) - for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # TODO: check tmin and tmax + for name in ["n_chan_bin", "fs", "tmin", "tmax"]: settings.pop(name) return settings def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: - """ """ + """ + Load the results of sorting into a dict for easy comparison. + """ results = { "si": {}, "ks": {}, } if kilosort_output_dir: results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") - results["ks"]["clus"] = pd.read_table(kilosort_output_dir / "cluster_group.tsv") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") if spikeinterface_output_dir: results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") - results["si"]["clus"] = pd.read_table(spikeinterface_output_dir / "sorter_output" / "cluster_group.tsv") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") return results - - def _get_ground_truth_recording(self): - """ """ - # Chosen so all parameter changes to indeed change the output - num_channels = 32 - recording, _ = si.generate_ground_truth_recording( - durations=[5], - seed=0, - num_channels=num_channels, - num_units=5, - generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), - ) - return recording - - def _save_ground_truth_recording(self, recording, tmp_path): - """ """ - paths = { - "session_scope_tmp_path": tmp_path, - "recording_path": tmp_path / "my_test_recording", - "probe_path": tmp_path / "my_test_probe.prb", - } - - recording.save(folder=paths["recording_path"], overwrite=True) - - probegroup = recording.get_probegroup() - 
write_prb(paths["probe_path"].as_posix(), probegroup) - - return paths diff --git a/conftest.py b/conftest.py index 8c06830d25..544c2fb6cb 100644 --- a/conftest.py +++ b/conftest.py @@ -29,7 +29,7 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - try: # TODO: make a note on this, check with Herberto its okay. + try: rel_path = Path(item.fspath).relative_to(modules_location) except: continue diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 28a3c3ffa3..8721ce1b89 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -127,8 +127,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <0.0.10 is always '4' z""" - # Note this import clashes with version! + """kilosort version <0.0.10 is always '4'""" from importlib.metadata import version as importlib_version return importlib_version("kilosort") @@ -216,7 +215,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - # TODO: save_preprocesed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -237,7 +235,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): - # TODO: shift, scaled added n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) @@ -261,22 +258,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): chan_map=chan_map, hp_filter=None, device=device, - do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? + do_CAR=do_CAR, invert_sign=invert, dtype=dtype, - tmin=tmin, # TODO: exposing tmin, max? 
+ tmin=tmin, tmax=tmax, artifact_threshold=artifact, - file_object=file_object, # TODO: exposing shift, scale when skipping preprocessing? + file_object=file_object, ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches + # bfile.close() # TODO: KS do this after preprocessing? np.random.seed(1) torch.cuda.manual_seed_all(1) torch.random.manual_seed(1) - # if not params["skip_kilosort_preprocessing"]: + if not params["do_correction"]: print("Skipping drift correction.") ops["nblocks"] = 0 From 0817a5b3f10c986db04632fb979e2c30cf501dbc Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 02:30:19 +0100 Subject: [PATCH 18/90] Add tests to check _default_params against KS params. --- .github/scripts/test_kilosort4_ci.py | 25 +++++++++++++++---- .../sorters/external/kilosort4.py | 7 +++--- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 8a455a41fe..ecc931781c 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -26,7 +26,7 @@ import kilosort from kilosort.io import load_probe import pandas as pd - +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter import pytest from probeinterface.io import write_prb from kilosort.parameters import DEFAULT_SETTINGS @@ -230,6 +230,21 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + def test_spikeinterface_defaults_against_kilsort(self): + """ + Here check that all _ + Don't check that every default in KS is exposed in params, + because they change across versions. Instead, this check + is performed here against PARAMS_TO_TEST. 
+ """ + params = copy.deepcopy(Kilosort4Sorter._default_params) + + for key in params.keys(): + # "artifact threshold" is set to `np.inf` if `None` in + # the body of the `Kilosort4Sorter` class. + if key in DEFAULT_SETTINGS and key not in ["artifact_threshold"]: + assert params[key] == DEFAULT_SETTINGS[key], f"{key} is not the same across versions." + # Testing Arguments ### def test_set_files_arguments(self): self._check_arguments( @@ -533,7 +548,7 @@ def _check_test_parameters_are_changing_the_output(self, results, default_result the skipped three parameters below to change the results on this small test file. """ - if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling"]: + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: return if param_key == "change_nothing": @@ -583,7 +598,7 @@ def _get_spikeinterface_settings(self, param_key, param_value): Generate settings kwargs for running KS4 in SpikeInterface. See `_get_kilosort_native_settings()` for some details. 
""" - settings = copy.deepcopy(DEFAULT_SETTINGS) + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) if param_key != "change_nothing": settings.update({param_key: param_value}) @@ -591,8 +606,8 @@ def _get_spikeinterface_settings(self, param_key, param_value): if param_key == "binning_depth": settings.update({"nblocks": 5}) - for name in ["n_chan_bin", "fs", "tmin", "tmax"]: - settings.pop(name) + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # settings.pop(name) return settings diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8721ce1b89..82c033f61d 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -35,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32, + "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.0.1") else None, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -50,7 +51,7 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7, + "duplicate_spike_bins": 7 if version.parse(importlib_version("kilosort")) >= version.parse("4.0.4") else 15, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -128,8 +129,6 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): """kilosort version <0.0.10 is always '4'""" - from importlib.metadata import version as importlib_version - return importlib_version("kilosort") @classmethod From c8779fc87dfaa6aa1d2bdb72d6fa58ed36c7da7c Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 03:01:37 +0100 Subject: [PATCH 19/90] Skip tests where relevant, try on 
slightly earlier python version to avoid weird xlabel bug. --- .github/scripts/test_kilosort4_ci.py | 3 +++ .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index ecc931781c..3e74fa708e 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -228,6 +228,8 @@ def test_default_settings_all_represented(self): for param_key in DEFAULT_SETTINGS: if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": + continue assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." def test_spikeinterface_defaults_against_kilsort(self): @@ -434,6 +436,7 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert np.array_equal(results["ks"]["st"], results["si"]["st"]) assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") @pytest.mark.parametrize("param_to_test", [ ("change_nothing", None), ("do_CAR", False), diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index c216be20d0..3ad61c0d2e 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.10"] # TODO: just checking python version is not cause of failing test. 
os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 82c033f61d..811a6e8452 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -36,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.0.1") else None, + "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.2") else None, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -128,7 +128,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <0.0.10 is always '4'""" + """kilosort version <4.0.10 is always '4'""" return importlib_version("kilosort") @classmethod From 867729102ee5a76f412f1a8e7c025ceefadb7bff Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 09:37:21 +0100 Subject: [PATCH 20/90] Don't support 4.0.4 --- .github/scripts/check_kilosort4_releases.py | 7 +++++++ .github/scripts/test_kilosort4_ci.py | 3 ++- .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 5 +++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 3d04d6948a..9572f88330 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -6,14 +6,21 @@ def get_pypi_versions(package_name): + """ + Make an API call to pypi to retrieve all + available versions of the kilosort package. 
+ """ url = f"https://pypi.org/pypi/{package_name}/json" response = requests.get(url) response.raise_for_status() data = response.json() + versions = list(sorted(data["releases"].keys())) + versions.pop(versions.index("4.0.4")) return list(sorted(data["releases"].keys())) if __name__ == "__main__": + # Get all KS4 versions from pipi and write to file. package_name = "kilosort" versions = get_pypi_versions(package_name) with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 3e74fa708e..c894ed71ff 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -342,7 +342,7 @@ def _check_arguments(self, object_, expected_arguments): # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) - def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): + def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): """ Given a recording, paths to raw data, and a parameter to change, run KS4 natively and within the SpikeInterface wrapper with the @@ -398,6 +398,7 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet if parse(version("kilosort")) > parse("4.0.4"): self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. 
We set diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 3ad61c0d2e..03db2b6170 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] # TODO: just checking python version is not cause of failing test. + python-version: ["3.12"] # TODO: just checking python version is not cause of failing test. os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 811a6e8452..55e694a02f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -163,6 +163,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) + if cls.get_sorter_version() == version.parse("4.0.4"): + raise RuntimeError( + "Kilosort version 4.0.4 is not supported" "in SpikeInterface. Please change Kilosort version." + ) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" From 21caaf99bd93e7189725acc3de3079264e79d710 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 10:09:58 +0100 Subject: [PATCH 21/90] Remove support for versions earlier that 4.0.5. 
--- .github/scripts/check_kilosort4_releases.py | 5 +++-- src/spikeinterface/sorters/external/kilosort4.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 9572f88330..05d8c0c614 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -3,7 +3,7 @@ from pathlib import Path import requests import json - +from packaging.version import parse def get_pypi_versions(package_name): """ @@ -15,8 +15,9 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) + versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] versions.pop(versions.index("4.0.4")) - return list(sorted(data["releases"].keys())) + return versions if __name__ == "__main__": diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 55e694a02f..dba28f7244 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -36,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.2") else None, + "dminx": 32, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -51,7 +51,7 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7 if version.parse(importlib_version("kilosort")) >= version.parse("4.0.4") else 15, + "duplicate_spike_bins": 7, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -163,9 +163,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) - if cls.get_sorter_version() == version.parse("4.0.4"): + if 
cls.get_sorter_version() < version.parse("4.0.5"): raise RuntimeError( - "Kilosort version 4.0.4 is not supported" "in SpikeInterface. Please change Kilosort version." + "Kilosort versions before 4.0.5 are not supported" + "in SpikeInterface. " + "Please upgrade Kilosort version." ) sorter_output_folder = sorter_output_folder.absolute() From 9bc18978fbb56917b0f4fe46df7c3bc531f850a4 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 10:40:50 +0100 Subject: [PATCH 22/90] Add packaging to CI dependency. On branch add_kilosort4_wrapper_tests --- .github/scripts/check_kilosort4_releases.py | 1 - .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 05d8c0c614..de11dc974b 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -16,7 +16,6 @@ def get_pypi_versions(package_name): data = response.json() versions = list(sorted(data["releases"].keys())) versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] - versions.pop(versions.index("4.0.4")) return versions diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 03db2b6170..088dd1a6a4 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | - pip install requests + pip install requests packaging - name: Fetch package versions from PyPI run: | diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index dba28f7244..eb1df7c455 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -163,7 +163,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): 
logging.basicConfig(level=logging.INFO) - if cls.get_sorter_version() < version.parse("4.0.5"): + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): raise RuntimeError( "Kilosort versions before 4.0.5 are not supported" "in SpikeInterface. " From 23d2c77533a2bc65791bd6d07eda9b8723133c33 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 12:30:05 +0100 Subject: [PATCH 23/90] Add some more documentation to .yml --- .github/workflows/test_kilosort4.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 088dd1a6a4..13d70acf88 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -11,6 +11,8 @@ on: jobs: versions: + # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # and store them in a matrix for the next job. runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} From 1bad6d6e3fb26b4f3e4bae9876bf25b056077280 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 16:01:22 +0100 Subject: [PATCH 24/90] Remove unused rng. --- src/spikeinterface/generation/drifting_generator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index b439c57c52..7f8682035c 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -348,9 +348,6 @@ def generate_drifting_recording( This can be helpfull for motion benchmark. """ - - rng = np.random.default_rng(seed=seed) - # probe if generate_probe_kwargs is None: generate_probe_kwargs = _toy_probes[probe_name] From 0f9c32cbdb82ce48120830fe55c88d0376a350b8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 17:25:00 +0100 Subject: [PATCH 25/90] Add 'int' type to 'num_samples' on 'InjectTemplatesRecording'. 
--- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 62aa7f37c3..e53f8cc539 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1714,7 +1714,7 @@ def __init__( amplitude_factor: Union[List[List[float]], List[float], float, None] = None, parent_recording: Union[BaseRecording, None] = None, num_samples: Optional[List[int]] = None, - upsample_vector: Union[List[int], None] = None, + upsample_vector: Union[List[int], int, None] = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) From 73c146f29828d073af08e91a7fd45a38430cff71 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:01:22 +0100 Subject: [PATCH 26/90] Remove some errneous Optional type hints and convert to | on 'generate_recording'. --- src/spikeinterface/core/generate.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e53f8cc539..e9255d55cc 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -27,12 +27,12 @@ def _ensure_seed(seed): def generate_recording( - num_channels: Optional[int] = 2, - sampling_frequency: Optional[float] = 30000.0, - durations: Optional[List[float]] = [5.0, 2.5], - set_probe: Optional[bool] = True, - ndim: Optional[int] = 2, - seed: Optional[int] = None, + num_channels: int = 2, + sampling_frequency: float = 30000.0, + durations: List[float] = [5.0, 2.5], + set_probe: bool | None = True, + ndim: int | None = 2, + seed: int | None = None, ) -> BaseRecording: """ Generate a lazy recording object. 
@@ -1090,7 +1090,7 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_levels: float = 1.0, + noise_levels: float | np.array = 1.0, cov_matrix: Optional[np.array] = None, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, From e5701f6202106c14c4307e2e541ad8167319dfdd Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:01:31 +0100 Subject: [PATCH 27/90] Remove some erroneous Optional type hints and convert to | on 'generate_recording'. --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e9255d55cc..c4665a7bd5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1090,7 +1090,7 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_levels: float | np.array = 1.0, + noise_levels: float = 1.0, cov_matrix: Optional[np.array] = None, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, From 098f8071b1aa0e7170b04d0966fc35526839b1e9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:12:56 +0100 Subject: [PATCH 28/90] Convert NoiseGeneratorRecording. --- src/spikeinterface/core/generate.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index c4665a7bd5..9037109549 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1064,11 +1064,11 @@ class NoiseGeneratorRecording(BaseRecording): The durations of each segment in seconds. Note that the length of this list is the number of segments.
noise_levels: float or array, default: 1 Std of the white noise (if an array, defined by per channels) - cov_matrix: np.array, default None + cov_matrix: np.array | None, default None The covariance matrix of the noise - dtype : Optional[Union[np.dtype, str]], default: "float32" + dtype : np.dtype | str |None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. - seed : Optional[int], default: None + seed : int | None, default: None The seed for np.random.default_rng. strategy : "tile_pregenerated" or "on_the_fly" The strategy of generating noise chunk: @@ -1090,10 +1090,10 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_levels: float = 1.0, - cov_matrix: Optional[np.array] = None, - dtype: Optional[Union[np.dtype, str]] = "float32", - seed: Optional[int] = None, + noise_levels: float | np.array = 1.0, + cov_matrix: np.array | None = None, + dtype: np.dtype | str | None = "float32", + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): From db0c30d37d443fcffad7caead7412d672e553d81 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:13:27 +0100 Subject: [PATCH 29/90] Remove duplicate noise level keys in NoiseGeneratorRecording. 
--- src/spikeinterface/core/generate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9037109549..a3e77b57f0 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1150,7 +1150,6 @@ def __init__( "sampling_frequency": sampling_frequency, "noise_levels": noise_levels, "cov_matrix": cov_matrix, - "noise_levels": noise_levels, "dtype": dtype, "seed": seed, "strategy": strategy, From 013c834aba8ad0f0ac6a956db30bdc4a5e8b3598 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:16:44 +0100 Subject: [PATCH 30/90] substitute get_traces(). --- src/spikeinterface/core/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index a3e77b57f0..0d95668f2e 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1204,9 +1204,9 @@ def get_num_samples(self) -> int: def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: List | None = None, ) -> np.ndarray: start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size From 0cede16b03ca2b404783806353b30591e0116d03 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:17:17 +0100 Subject: [PATCH 31/90] Remove unused argument to generate_recording_by_size. 
--- src/spikeinterface/core/generate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0d95668f2e..b48d8b40df 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1260,7 +1260,6 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 384, seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: From bbc55b48233e1cca578c69bb6af0223fc0e0c1d0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:17:38 +0100 Subject: [PATCH 32/90] Convert 'generate_recording_by_size'. --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index b48d8b40df..41b44792be 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1260,7 +1260,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - seed: Optional[int] = None, + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: """ From c5af7f36b75d1f37d0e7e2d54ff81383a7928cc2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:40:02 +0100 Subject: [PATCH 33/90] Fix type hints on InjectTemplatesRecording and convert. --- src/spikeinterface/core/generate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 41b44792be..b9ab8f6d25 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1694,7 +1694,7 @@ class InjectTemplatesRecording(BaseRecording): num_samples: list[int] | int | None The number of samples in the recording per segment. 
You can use int for mono-segment objects. - upsample_vector: np.array or None, default: None. + upsample_vector: np.array | None, default: None. When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. @@ -1708,11 +1708,11 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore: Union[List[int], int, None] = None, - amplitude_factor: Union[List[List[float]], List[float], float, None] = None, - parent_recording: Union[BaseRecording, None] = None, - num_samples: Optional[List[int]] = None, - upsample_vector: Union[List[int], int, None] = None, + nbefore: List[int] | int | None = None, + amplitude_factor: List[List[float]] | List[float] | float | None = None, + parent_recording: BaseRecording | None = None, + num_samples: List[int] | int | None = None, + upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) From 16bf359de2c364f9681f193601aeb46304053762 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:43:55 +0100 Subject: [PATCH 34/90] Remove bad type hint on 'InjectTemplatesRecording'. 
--- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index b9ab8f6d25..21edae8447 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1709,7 +1709,7 @@ def __init__( sorting: BaseSorting, templates: np.ndarray, nbefore: List[int] | int | None = None, - amplitude_factor: List[List[float]] | List[float] | float | None = None, + amplitude_factor: List[float] | float | None = None, parent_recording: BaseRecording | None = None, num_samples: List[int] | int | None = None, upsample_vector: np.array | None = None, From 1760d0f5a213356390069b387d4674fafbd314bb Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:54:30 +0100 Subject: [PATCH 35/90] Fix all other cases --- src/spikeinterface/core/generate.py | 60 ++++++++++++++--------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 21edae8447..6fc21a34dc 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Union, Optional, List, Literal +from typing import List, Literal from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -47,10 +47,10 @@ def generate_recording( durations: List[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. - set_probe: bool, default: True - ndim : int, default: 2 + set_probe: bool | None, default: True + ndim : int | None, default: 2 The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe. 
- seed : Optional[int] + seed : int | None, default: None A seed for the np.ramdom.default_rng function Returns @@ -253,13 +253,13 @@ def generate_sorting_to_inject( num_samples: list of size num_segments. The number of samples in all the segments of the sorting, to generate spike times covering entire the entire duration of the segments. - max_injected_per_unit: int, default 1000 + max_injected_per_unit: int, default: 1000 The maximal number of spikes injected per units. - injected_rate: float, default 0.05 + injected_rate: float, default: 0.05 The rate at which spikes are injected. - refractory_period_ms: float, default 1.5 + refractory_period_ms: float, default: 1.5 The refractory period that should not be violated while injecting new spikes. - seed: int, default None + seed: int, default: None The random seed. Returns @@ -313,13 +313,13 @@ class TransformSorting(BaseSorting): ---------- sorting : BaseSorting The sorting object. - added_spikes_existing_units : np.array (spike_vector) + added_spikes_existing_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for existing units. - added_spikes_new_units: np.array (spike_vector) + added_spikes_new_units: np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for new units. - new_units_ids: list + new_units_ids: list[str, int] | None, default: None The unit_ids that should be added if spikes for new units are added. - refractory_period_ms : float, default None + refractory_period_ms : float | None, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. 
@@ -333,10 +333,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units=None, - added_spikes_new_units=None, - new_unit_ids: Optional[List[Union[str, int]]] = None, - refractory_period_ms: Optional[float] = None, + added_spikes_existing_units: np.array | None = None, + added_spikes_new_units: np.array | None = None, + new_unit_ids: List[str | int] | None = None, + refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) @@ -432,7 +432,7 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe The first sorting. sorting2: BaseSorting The second sorting. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -498,7 +498,7 @@ def add_from_unit_dict( The first sorting dict_list: list of dict A list of dict with unit_ids as keys and spike times as values. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -528,7 +528,7 @@ def from_times_labels( unit_ids: list or None, default: None The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list). - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. 
@@ -1064,7 +1064,7 @@ class NoiseGeneratorRecording(BaseRecording): The durations of each segment in seconds. Note that the length of this list is the number of segments. noise_levels: float or array, default: 1 Std of the white noise (if an array, defined by per channels) - cov_matrix: np.array | None, default None + cov_matrix: np.array | None, default: None The covariance matrix of the noise dtype : np.dtype | str |None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. @@ -1279,7 +1279,7 @@ def generate_recording_by_size( The size in gigabytes (GiB) of the recording. num_channels: int Number of channels. - seed : int, default: None + seed : int | None, default: None The seed for np.random.default_rng. Returns @@ -1688,10 +1688,10 @@ class InjectTemplatesRecording(BaseRecording): Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). Can be a vector with same shape of spike_vector of the sorting. - parent_recording: BaseRecording | None + parent_recording: BaseRecording | None, default: None The recording over which to add the templates. If None, will default to traces containing all 0. - num_samples: list[int] | int | None + num_samples: list[int] | int | None, default: None The number of samples in the recording per segment. You can use int for mono-segment objects. upsample_vector: np.array | None, default: None. 
@@ -1844,10 +1844,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector: Union[List[float], None], - upsample_vector: Union[List[float], None], - parent_recording_segment: Union[BaseRecordingSegment, None] = None, - num_samples: Union[int, None] = None, + amplitude_vector: List[float] | None, + upsample_vector: List[float] | None, + parent_recording_segment: BaseRecordingSegment | None = None, + num_samples: int | None = None, ) -> None: BaseRecordingSegment.__init__( self, @@ -1867,9 +1867,9 @@ def __init__( def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: List | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] From f694e30e06ffaf3895ff3cbafd060551aabedd08 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 18 Jul 2024 18:55:40 +0100 Subject: [PATCH 36/90] List -> list. --- src/spikeinterface/core/generate.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6fc21a34dc..ea58ab6ef8 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import List, Literal +from typing import Literal from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -29,7 +29,7 @@ def _ensure_seed(seed): def generate_recording( num_channels: int = 2, sampling_frequency: float = 30000.0, - durations: List[float] = [5.0, 2.5], + durations: list[float] = [5.0, 2.5], set_probe: bool | None = True, ndim: int | None = 2, seed: int | None = None, @@ -44,7 +44,7 @@ def generate_recording( The number of channels in the recording. sampling_frequency : float, default: 30000. 
(in Hz) The sampling frequency of the recording, default: 30000. - durations: List[float], default: [5.0, 2.5] + durations: list[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. set_probe: bool | None, default: True @@ -236,7 +236,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): def generate_sorting_to_inject( sorting: BaseSorting, - num_samples: List[int], + num_samples: list[int], max_injected_per_unit: int = 1000, injected_rate: float = 0.05, refractory_period_ms: float = 1.5, @@ -335,7 +335,7 @@ def __init__( sorting: BaseSorting, added_spikes_existing_units: np.array | None = None, added_spikes_new_units: np.array | None = None, - new_unit_ids: List[str | int] | None = None, + new_unit_ids: list[str | int] | None = None, refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() @@ -1060,7 +1060,7 @@ class NoiseGeneratorRecording(BaseRecording): The number of channels. sampling_frequency : float The sampling frequency of the recorder. - durations : List[float] + durations : list[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. 
noise_levels: float or array, default: 1 Std of the white noise (if an array, defined by per channels) @@ -1089,7 +1089,7 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations: List[float], + durations: list[float], noise_levels: float | np.array = 1.0, cov_matrix: np.array | None = None, dtype: np.dtype | str | None = "float32", @@ -1206,7 +1206,7 @@ def get_traces( self, start_frame: int | None = None, end_frame: int | None = None, - channel_indices: List | None = None, + channel_indices: list | None = None, ) -> np.ndarray: start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size @@ -1708,10 +1708,10 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore: List[int] | int | None = None, - amplitude_factor: List[float] | float | None = None, + nbefore: list[int] | int | None = None, + amplitude_factor: list[float] | float | None = None, parent_recording: BaseRecording | None = None, - num_samples: List[int] | int | None = None, + num_samples: list[int] | int | None = None, upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: @@ -1844,8 +1844,8 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector: List[float] | None, - upsample_vector: List[float] | None, + amplitude_vector: list[float] | None, + upsample_vector: list[float] | None, parent_recording_segment: BaseRecordingSegment | None = None, num_samples: int | None = None, ) -> None: @@ -1869,7 +1869,7 @@ def get_traces( self, start_frame: int | None = None, end_frame: int | None = None, - channel_indices: List | None = None, + channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] From bc2cc8a965ec8871c7a8f91edde8bf104792fde5 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:22:18 +0100 Subject: 
[PATCH 37/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ea58ab6ef8..04d2135670 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -44,7 +44,7 @@ def generate_recording( The number of channels in the recording. sampling_frequency : float, default: 30000. (in Hz) The sampling frequency of the recording, default: 30000. - durations: list[float], default: [5.0, 2.5] + durations : list[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. set_probe: bool | None, default: True From ff66a3815663e9f4909a505231008d5bd1779fb7 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:22:28 +0100 Subject: [PATCH 38/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 04d2135670..325008a4f2 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -259,7 +259,7 @@ def generate_sorting_to_inject( The rate at which spikes are injected. refractory_period_ms: float, default: 1.5 The refractory period that should not be violated while injecting new spikes. - seed: int, default: None + seed : int, default: None The random seed. 
Returns From 90b366dfe4b5fd7c4a73b858b5aff844d4703f1e Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:22:58 +0100 Subject: [PATCH 39/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 325008a4f2..10139918c2 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1688,7 +1688,7 @@ class InjectTemplatesRecording(BaseRecording): Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). Can be a vector with same shape of spike_vector of the sorting. - parent_recording: BaseRecording | None, default: None + parent_recording : BaseRecording | None, default: None The recording over which to add the templates. If None, will default to traces containing all 0. num_samples: list[int] | int | None, default: None From adc40e6de25a3e001a4a2d0e1385fb18dec35b4a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:23:17 +0100 Subject: [PATCH 40/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 10139918c2..d1f9ff97f3 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1694,7 +1694,7 @@ class InjectTemplatesRecording(BaseRecording): num_samples: list[int] | int | None, default: None The number of samples in the recording per segment. You can use int for mono-segment objects. - upsample_vector: np.array | None, default: None. + upsample_vector : np.array | None, default: None. 
When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. From 4c5d198b20806521b013c3ea9d37b067032edd03 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:23:32 +0100 Subject: [PATCH 41/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d1f9ff97f3..3be5e166ab 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1708,11 +1708,11 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore: list[int] | int | None = None, - amplitude_factor: list[float] | float | None = None, - parent_recording: BaseRecording | None = None, - num_samples: list[int] | int | None = None, - upsample_vector: np.array | None = None, + nbefore : list[int] | int | None = None, + amplitude_factor : list[float] | float | None = None, + parent_recording : BaseRecording | None = None, + num_samples : list[int] | int | None = None, + upsample_vector : np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) From a9400f999a73a7997c1160a90375d91210c1a5c0 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:23:47 +0100 Subject: [PATCH 42/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 3be5e166ab..6fc231f6f4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1844,10 
+1844,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector: list[float] | None, - upsample_vector: list[float] | None, - parent_recording_segment: BaseRecordingSegment | None = None, - num_samples: int | None = None, + amplitude_vector : list[float] | None, + upsample_vector : list[float] | None, + parent_recording_segment : BaseRecordingSegment | None = None, + num_samples : int | None = None, ) -> None: BaseRecordingSegment.__init__( self, From ef4d9e39440bb2f1924b2bba023a2f8d4ebc6c5a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:24:01 +0100 Subject: [PATCH 43/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6fc231f6f4..d51fe8101c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1867,9 +1867,9 @@ def __init__( def get_traces( self, - start_frame: int | None = None, - end_frame: int | None = None, - channel_indices: list | None = None, + start_frame : int | None = None, + end_frame : int | None = None, + channel_indices : list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] From 6519ffa1e001f5abb37f07d88364af15665fb9c4 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:24:12 +0100 Subject: [PATCH 44/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d51fe8101c..1098df8275 100644 --- a/src/spikeinterface/core/generate.py +++ 
b/src/spikeinterface/core/generate.py @@ -253,7 +253,7 @@ def generate_sorting_to_inject( num_samples: list of size num_segments. The number of samples in all the segments of the sorting, to generate spike times covering entire the entire duration of the segments. - max_injected_per_unit: int, default: 1000 + max_injected_per_unit : int, default: 1000 The maximal number of spikes injected per units. injected_rate: float, default: 0.05 The rate at which spikes are injected. From ab80c707be525130f5f50c33dddf1fa3868b08e0 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:24:27 +0100 Subject: [PATCH 45/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1098df8275..68aa558543 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1089,11 +1089,11 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations: list[float], - noise_levels: float | np.array = 1.0, - cov_matrix: np.array | None = None, - dtype: np.dtype | str | None = "float32", - seed: int | None = None, + durations : list[float], + noise_levels : float | np.array = 1.0, + cov_matrix : np.array | None = None, + dtype : np.dtype | str | None = "float32", + seed : int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): From 061a5fa47004c2aca2b5c0c28d6437a6e1af6d6a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:24:42 +0100 Subject: [PATCH 46/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 68aa558543..c73954dcf5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1064,7 +1064,7 @@ class NoiseGeneratorRecording(BaseRecording): The durations of each segment in seconds. Note that the length of this list is the number of segments. noise_levels: float or array, default: 1 Std of the white noise (if an array, defined by per channels) - cov_matrix: np.array | None, default: None + cov_matrix : np.array | None, default: None The covariance matrix of the noise dtype : np.dtype | str |None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. From 34d09a9d9dd2b5db4213e0f2434ffd033d102ba0 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:24:53 +0100 Subject: [PATCH 47/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index c73954dcf5..b57767161f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -255,7 +255,7 @@ def generate_sorting_to_inject( covering entire the entire duration of the segments. max_injected_per_unit : int, default: 1000 The maximal number of spikes injected per units. - injected_rate: float, default: 0.05 + injected_rate : float, default: 0.05 The rate at which spikes are injected. refractory_period_ms: float, default: 1.5 The refractory period that should not be violated while injecting new spikes. 
From 8257cd9a0783c48e93db69011d5960cc80ae1059 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:25:33 +0100 Subject: [PATCH 48/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index b57767161f..4abd407681 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -333,10 +333,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units: np.array | None = None, - added_spikes_new_units: np.array | None = None, - new_unit_ids: list[str | int] | None = None, - refractory_period_ms: float | None = None, + added_spikes_existing_units : np.array | None = None, + added_spikes_new_units : np.array | None = None, + new_unit_ids : list[str | int] | None = None, + refractory_period_ms : float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) From c175b6b72563661a400752bece72e873e56e1abc Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:25:48 +0100 Subject: [PATCH 49/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4abd407681..28cf7ec404 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -315,7 +315,7 @@ class TransformSorting(BaseSorting): The sorting object. 
added_spikes_existing_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for existing units. - added_spikes_new_units: np.array (spike_vector) | None, default: None + added_spikes_new_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for new units. new_units_ids: list[str, int] | None, default: None The unit_ids that should be added if spikes for new units are added. From 39d46a12aa14db08f23d18fe68f1b9a408be7bf3 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:26:04 +0100 Subject: [PATCH 50/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 28cf7ec404..09db185776 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -257,7 +257,7 @@ def generate_sorting_to_inject( The maximal number of spikes injected per units. injected_rate : float, default: 0.05 The rate at which spikes are injected. - refractory_period_ms: float, default: 1.5 + refractory_period_ms : float, default: 1.5 The refractory period that should not be violated while injecting new spikes. seed : int, default: None The random seed. 
From 16023bba5f325017f20a91fb4f6740edd9445ed5 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:26:19 +0100 Subject: [PATCH 51/90] space before colon Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 09db185776..57f79c87ae 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -317,7 +317,7 @@ class TransformSorting(BaseSorting): The spikes that should be added to the sorting object, for existing units. added_spikes_new_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for new units. - new_units_ids: list[str, int] | None, default: None + new_units_ids : list[str, int] | None, default: None The unit_ids that should be added if spikes for new units are added. 
refractory_period_ms : float | None, default: None The refractory period violation to prevent duplicates and/or unphysiological addition From 259562e2d554464a29bc68c944afc4ff3a9bbb65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 15:29:32 +0000 Subject: [PATCH 52/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 42 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 57f79c87ae..5749f31b10 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -333,10 +333,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units : np.array | None = None, - added_spikes_new_units : np.array | None = None, - new_unit_ids : list[str | int] | None = None, - refractory_period_ms : float | None = None, + added_spikes_existing_units: np.array | None = None, + added_spikes_new_units: np.array | None = None, + new_unit_ids: list[str | int] | None = None, + refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) @@ -1089,11 +1089,11 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations : list[float], - noise_levels : float | np.array = 1.0, - cov_matrix : np.array | None = None, - dtype : np.dtype | str | None = "float32", - seed : int | None = None, + durations: list[float], + noise_levels: float | np.array = 1.0, + cov_matrix: np.array | None = None, + dtype: np.dtype | str | None = "float32", + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): @@ -1708,11 +1708,11 @@ def __init__( self, sorting: 
BaseSorting, templates: np.ndarray, - nbefore : list[int] | int | None = None, - amplitude_factor : list[float] | float | None = None, - parent_recording : BaseRecording | None = None, - num_samples : list[int] | int | None = None, - upsample_vector : np.array | None = None, + nbefore: list[int] | int | None = None, + amplitude_factor: list[float] | float | None = None, + parent_recording: BaseRecording | None = None, + num_samples: list[int] | int | None = None, + upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) @@ -1844,10 +1844,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector : list[float] | None, - upsample_vector : list[float] | None, - parent_recording_segment : BaseRecordingSegment | None = None, - num_samples : int | None = None, + amplitude_vector: list[float] | None, + upsample_vector: list[float] | None, + parent_recording_segment: BaseRecordingSegment | None = None, + num_samples: int | None = None, ) -> None: BaseRecordingSegment.__init__( self, @@ -1867,9 +1867,9 @@ def __init__( def get_traces( self, - start_frame : int | None = None, - end_frame : int | None = None, - channel_indices : list | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] From 67174c2893f4afaf767433bed223ba5906f7956f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 22 Jul 2024 17:05:20 +0100 Subject: [PATCH 53/90] Add a few more fixes to docstrings. 
--- src/spikeinterface/core/generate.py | 185 +++++++++++++++------------- 1 file changed, 98 insertions(+), 87 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 57f79c87ae..a195b73aab 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -47,7 +47,7 @@ def generate_recording( durations : list[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. - set_probe: bool | None, default: True + set_probe : bool | None, default: True ndim : int | None, default: 2 The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe. seed : int | None, default: None @@ -188,7 +188,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): ---------- sorting : BaseSorting The sorting object. - sync_event_ratio : float + sync_event_ratio : float, default: 0.3 The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). @@ -250,7 +250,7 @@ def generate_sorting_to_inject( ---------- sorting : BaseSorting The sorting object. - num_samples: list of size num_segments. + num_samples : list[int] of size num_segments. The number of samples in all the segments of the sorting, to generate spike times covering entire the entire duration of the segments. 
max_injected_per_unit : int, default: 1000 @@ -333,10 +333,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units : np.array | None = None, - added_spikes_new_units : np.array | None = None, - new_unit_ids : list[str | int] | None = None, - refractory_period_ms : float | None = None, + added_spikes_existing_units: np.array | None = None, + added_spikes_new_units: np.array | None = None, + new_unit_ids: list[str | int] | None = None, + refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) @@ -428,9 +428,9 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting. - sorting2: BaseSorting + sorting2 : BaseSorting The second sorting. refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition @@ -484,7 +484,7 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe @staticmethod def add_from_unit_dict( - sorting1: BaseSorting, units_dict_list: dict, refractory_period_ms=None + sorting1: BaseSorting, units_dict_list: list[dict] | dict, refractory_period_ms=None ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting with a @@ -494,9 +494,9 @@ def add_from_unit_dict( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - dict_list: list of dict + dict_list : list[dict] | dict A list of dict with unit_ids as keys and spike times as values. 
refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition @@ -519,13 +519,15 @@ def from_times_labels( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - times_list: list of array (or array) + times_list : list[np.array] | np.array An array of spike times (in frames). - labels_list: list of array (or array) + labels_list : list[np.array] | np.array An array of spike labels corresponding to the given times. - unit_ids: list or None, default: None + sampling_frequency : float, default: 30000. (in Hz) + The sampling frequency of the recording, default: 30000. + unit_ids : list | None, default: None The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list). refractory_period_ms : float, default: None @@ -592,7 +594,7 @@ def generate_snippets( nafter=44, num_channels=2, wf_folder=None, - sampling_frequency=30000.0, # in Hz + sampling_frequency=30000.0, durations=[10.325, 3.5], #  in s for 2 segments set_probe=True, ndim=2, @@ -613,7 +615,7 @@ def generate_snippets( Number of channels. wf_folder : str | Path | None, default: None Optional folder to save the waveform snippets. If None, snippets are in memory. - sampling_frequency : float, default: 30000.0 + sampling_frequency : float, default: 30000.0 (in Hz) The sampling frequency of the snippets. ndim : int, default: 2 The number of dimensions of the probe. @@ -690,7 +692,7 @@ def synthesize_poisson_spike_vector( ---------- num_units : int, default: 20 Number of neuronal units to simulate. - sampling_frequency : float, default: 30000.0 + sampling_frequency : float, default: 30000.0 (in Hz) Sampling frequency in Hz. duration : float, default: 60.0 Duration of the simulation in seconds. @@ -793,20 +795,20 @@ def synthesize_random_firings( Parameters ---------- - num_units : int + num_units : int, default: 20 Number of units. 
- sampling_frequency : float + sampling_frequency : float, default: 30000.0 (in Hz) Sampling rate. - duration : float + duration : float, default: 60 Duration of the segment in seconds. - refractory_period_ms: float + refractory_period_ms : float, default: 4.0 Refractory period in ms. - firing_rates: float or list[float] + firing_rates : float or list[float], default: 3.0 The firing rate of each unit (in Hz). If float, all units will have the same firing rate. - add_shift_shuffle: bool, default: False + add_shift_shuffle : bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. - seed: int, default: None + seed : int, default: None Seed for the generator. Returns @@ -899,12 +901,14 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No ---------- sorting : Original sorting. - num : int + num : int, default: 4 Number of injected units. - max_shift : int + max_shift : int, default: 5 range of the shift in sample. - ratio: float + ratio : float | None, default: None Proportion of original spike in the injected units. + seed : int, default: None + Seed for the generator. Returns ------- @@ -1062,21 +1066,21 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : list[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_levels: float or array, default: 1 + noise_levels : float | np.array, default: 1.0 Std of the white noise (if an array, defined by per channels) cov_matrix : np.array | None, default: None The covariance matrix of the noise - dtype : np.dtype | str |None, default: "float32" + dtype : np.dtype | str | None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. seed : int | None, default: None The seed for np.random.default_rng. 
- strategy : "tile_pregenerated" or "on_the_fly" + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) - noise_block_size: int + noise_block_size : int, default: 30000 Size in sample of noise block. Note @@ -1089,11 +1093,11 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations : list[float], - noise_levels : float | np.array = 1.0, - cov_matrix : np.array | None = None, - dtype : np.dtype | str | None = "float32", - seed : int | None = None, + durations: list[float], + noise_levels: float | np.array = 1.0, + cov_matrix: np.array | None = None, + dtype: np.dtype | str | None = "float32", + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): @@ -1277,11 +1281,14 @@ def generate_recording_by_size( ---------- full_traces_size_GiB : float The size in gigabytes (GiB) of the recording. - num_channels: int - Number of channels. seed : int | None, default: None The seed for np.random.default_rng. - + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" + The strategy of generating noise chunk: + * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it + very fast and cusume only one noise block. + * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index + no memory preallocation but a bit more computaion (random) Returns ------- GeneratorRecording @@ -1517,25 +1524,25 @@ def generate_templates( Parameters ---------- - channel_locations: np.ndarray + channel_locations : np.ndarray Channel locations. 
- units_locations: np.ndarray + units_locations : np.ndarray Must be 3D. - sampling_frequency: float + sampling_frequency : float Sampling frequency. - ms_before: float + ms_before : float Cut out in ms before spike peak. - ms_after: float + ms_after : float Cut out in ms after spike peak. - seed: int or None + seed : int | None A seed for random. - dtype: numpy.dtype, default: "float32" + dtype : numpy.dtype, default: "float32" Templates dtype - upsample_factor: None or int + upsample_factor : int | None, default: None If not None then template are generated upsampled by this factor. Then a new dimention (axis=3) is added to the template with intermediate inter sample representation. This allow easy random jitter by choising a template this new dim - unit_params: dict of arrays or dict of scalar of dict of tuple + unit_params : dict[np.array] | dict[float] | dict[tuple] | None, default: None An optional dict containing parameters per units. Keys are parameter names: @@ -1552,6 +1559,10 @@ def generate_templates( * array of the same length of units * scalar, then an array is created * tuple, then this difine a range for random values. + mode : "ellipsoid" | "sphere", default: "ellipsoid" + Method used to calculate the distance between unit and channel location. + Ellipoid injects some anisotropy dependent on unit shape, sphere is equivalent + to Euclidean distance. Returns ------- @@ -1672,18 +1683,18 @@ class InjectTemplatesRecording(BaseRecording): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting Sorting object containing all the units and their spike train. - templates: np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] + templates : np.ndarray[n_units, n_samples, n_channels] | np.ndarray[n_units, n_samples, n_oversampling] Array containing the templates to inject for all the units. 
Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. - nbefore: list[int] | int | None, default: None + nbefore : list[int] | int | None, default: None The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. - amplitude_factor: list[float] | float | None, default: None + amplitude_factor : list[float] | float | None, default: None The amplitude of each spike for each unit. Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). @@ -1691,7 +1702,7 @@ class InjectTemplatesRecording(BaseRecording): parent_recording : BaseRecording | None, default: None The recording over which to add the templates. If None, will default to traces containing all 0. - num_samples: list[int] | int | None, default: None + num_samples : list[int] | int | None, default: None The number of samples in the recording per segment. You can use int for mono-segment objects. upsample_vector : np.array | None, default: None. 
@@ -1708,11 +1719,11 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore : list[int] | int | None = None, - amplitude_factor : list[float] | float | None = None, - parent_recording : BaseRecording | None = None, - num_samples : list[int] | int | None = None, - upsample_vector : np.array | None = None, + nbefore: list[int] | int | None = None, + amplitude_factor: list[float] | float | None = None, + parent_recording: BaseRecording | None = None, + num_samples: list[int] | int | None = None, + upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) @@ -1844,10 +1855,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector : list[float] | None, - upsample_vector : list[float] | None, - parent_recording_segment : BaseRecordingSegment | None = None, - num_samples : int | None = None, + amplitude_vector: list[float] | None, + upsample_vector: list[float] | None, + parent_recording_segment: BaseRecordingSegment | None = None, + num_samples: int | None = None, ) -> None: BaseRecordingSegment.__init__( self, @@ -1867,9 +1878,9 @@ def __init__( def get_traces( self, - start_frame : int | None = None, - end_frame : int | None = None, - channel_indices : list | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] @@ -2040,55 +2051,55 @@ def generate_ground_truth_recording( Parameters ---------- - durations: list of float, default: [10.] + durations : list[float], default: [10.] Durations in seconds for all segments. - sampling_frequency: float, default: 25000 + sampling_frequency : float, default: 25000.0 Sampling frequency. - num_channels: int, default: 4 + num_channels : int, default: 4 Number of channels, not used when probe is given. 
- num_units: int, default: 10 + num_units : int, default: 10 Number of units, not used when sorting is given. - sorting: Sorting or None + sorting : Sorting | None An external sorting object. If not provide, one is genrated. - probe: Probe or None + probe : Probe | None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. - generate_probe_kwargs: dict + generate_probe_kwargs : dict A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. - templates: np.array or None + templates : np.array | None The templates of units. If None they are generated. Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. - ms_before: float, default: 1.5 + ms_before : float, default: 1.5 Cut out in ms before spike peak. - ms_after: float, default: 3 + ms_after : float, default: 3 Cut out in ms after spike peak. - upsample_factor: None or int, default: None + upsample_factor : None | int, default: None A upsampling factor used only when templates are not provided. - upsample_vector: np.array or None + upsample_vector : np.array | None Optional the upsample_vector can given. This has the same shape as spike_vector - generate_sorting_kwargs: dict + generate_sorting_kwargs : dict When sorting is not provide, this dict is used to generated a Sorting. - noise_kwargs: dict + noise_kwargs : dict Dict used to generated the noise with NoiseGeneratorRecording. - generate_unit_locations_kwargs: dict + generate_unit_locations_kwargs : dict Dict used to generated template when template not provided. - generate_templates_kwargs: dict + generate_templates_kwargs : dict Dict used to generated template when template not provided. - dtype: np.dtype, default: "float32" + dtype : np.dtype, default: "float32" The dtype of the recording. - seed: int or None + seed : int | None Seed for random initialization. 
If None a diffrent Recording is generated at every call. Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. Returns ------- - recording: Recording + recording : Recording The generated recording extractor. - sorting: Sorting + sorting : Sorting The generated sorting extractor. """ generate_templates_kwargs = generate_templates_kwargs or dict() From 535fe17e83872d983ef5fe96abde34c944d74a1d Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:04:46 +0100 Subject: [PATCH 54/90] typo fix Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index a195b73aab..d2a5f98fe8 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1561,7 +1561,7 @@ def generate_templates( * tuple, then this difine a range for random values. mode : "ellipsoid" | "sphere", default: "ellipsoid" Method used to calculate the distance between unit and channel location. - Ellipoid injects some anisotropy dependent on unit shape, sphere is equivalent + Ellipsoid injects some anisotropy dependent on unit shape, sphere is equivalent to Euclidean distance. 
Returns From e7ce974fe0f9e5fc72b8ebe708e58ac5371e120f Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:05:03 +0100 Subject: [PATCH 55/90] typo fix Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d2a5f98fe8..09965b4550 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1286,7 +1286,7 @@ def generate_recording_by_size( strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it - very fast and cusume only one noise block. + very fast and consume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) Returns From c7b5aa6810ab9019b54e10aef973d62a31f015e1 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:05:15 +0100 Subject: [PATCH 56/90] typo fix Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 09965b4550..b2ffdcd88a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1288,7 +1288,7 @@ def generate_recording_by_size( * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and consume only one noise block. 
* "on_the_fly": generate on the fly a new noise block by combining seed + noise block index - no memory preallocation but a bit more computaion (random) + no memory preallocation but a bit more computation (random) Returns ------- GeneratorRecording From 92ea6093e2312e60212798089dc552e6b2f15d21 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 24 Jul 2024 15:10:39 +0100 Subject: [PATCH 57/90] Move 'in Hz' to description for sampling frequency docstring. --- src/spikeinterface/core/generate.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index b2ffdcd88a..73a159380c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -42,8 +42,8 @@ def generate_recording( ---------- num_channels : int, default: 2 The number of channels in the recording. - sampling_frequency : float, default: 30000. (in Hz) - The sampling frequency of the recording, default: 30000. + sampling_frequency : float, default: 30000.0 + The sampling frequency of the recording in Hz durations : list[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. @@ -105,7 +105,7 @@ def generate_sorting( num_units : int, default: 5 Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency. + The sampling frequency of the recording in Hz. durations : list, default: [10.325, 3.5] Duration of each segment in s. firing_rates : float, default: 3.0 @@ -525,8 +525,8 @@ def from_times_labels( An array of spike times (in frames). labels_list : list[np.array] | np.array An array of spike labels corresponding to the given times. - sampling_frequency : float, default: 30000. (in Hz) - The sampling frequency of the recording, default: 30000. 
+ sampling_frequency : float, default: 30000.0 + The sampling frequency of the recording in Hz. unit_ids : list | None, default: None The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list). @@ -615,8 +615,8 @@ def generate_snippets( Number of channels. wf_folder : str | Path | None, default: None Optional folder to save the waveform snippets. If None, snippets are in memory. - sampling_frequency : float, default: 30000.0 (in Hz) - The sampling frequency of the snippets. + sampling_frequency : float, default: 30000.0 + The sampling frequency of the snippets in Hz. ndim : int, default: 2 The number of dimensions of the probe. num_units : int, default: 5 @@ -692,7 +692,7 @@ def synthesize_poisson_spike_vector( ---------- num_units : int, default: 20 Number of neuronal units to simulate. - sampling_frequency : float, default: 30000.0 (in Hz) + sampling_frequency : float, default: 30000.0 Sampling frequency in Hz. duration : float, default: 60.0 Duration of the simulation in seconds. @@ -797,8 +797,8 @@ def synthesize_random_firings( ---------- num_units : int, default: 20 Number of units. - sampling_frequency : float, default: 30000.0 (in Hz) - Sampling rate. + sampling_frequency : float, default: 30000.0 + Sampling rate in Hz. duration : float, default: 60 Duration of the segment in seconds. 
refractory_period_ms : float, default: 4.0 From 05be8efe5334aa970ba7885fedba6c253fc8a85d Mon Sep 17 00:00:00 2001 From: jonahpearl Date: Thu, 25 Jul 2024 13:31:58 -0400 Subject: [PATCH 58/90] always split job kwargs --- src/spikeinterface/core/sortinganalyzer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..cda0e10ff7 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1209,11 +1209,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar print(f"Deleting {child}") self.delete_extension(child) - if extension_class.need_job_kwargs: - params, job_kwargs = split_job_kwargs(kwargs) - else: - params = kwargs - job_kwargs = {} + params, job_kwargs = split_job_kwargs(kwargs) # check dependencies if extension_class.need_recording: From 3c3cb933f7f51615e02e04335123899046708335 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 29 Jul 2024 13:37:24 -0300 Subject: [PATCH 59/90] drop python 3.8 in pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 71919c072b..eb2c0f2fe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.8,<4.0" +requires-python = ">=3.9,<4.0" classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 4336f0d6d3ef5595a45242cab6985fa2e638fe3f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 29 Jul 2024 19:24:28 +0100 Subject: [PATCH 60/90] Expose 'save_preprocessed_copy' in KS4 wrapper. 
--- src/spikeinterface/sorters/external/kilosort4.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..a904866629 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -56,6 +56,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": False, "skip_kilosort_preprocessing": False, "scaleproc": None, + "save_preprocessed_copy": False, "torch_device": "auto", } @@ -98,6 +99,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory", "torch_device": "Select the torch device auto/cuda/cpu", } @@ -186,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] + save_preprocessed_copy = (params["save_preprocessed_copy"],) progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -207,7 +210,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings, + probe, + recording.get_dtype(), + do_CAR, + invert_sign, + device, + 
save_preprocesed_copy=save_preprocessed_copy, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) From 20b4d2fcf95171ea5304339c689170e52429d4fd Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:04:57 +0100 Subject: [PATCH 61/90] Edit kilosort4.py to match the ks4 'run_sorter' function body. --- .../sorters/external/kilosort4.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a904866629..16918128a2 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -155,7 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_sorting, get_run_parameters, ) - from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing from kilosort.parameters import DEFAULT_SETTINGS import time @@ -188,7 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] - save_preprocessed_copy = (params["save_preprocessed_copy"],) + save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -268,6 +268,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) + if save_preprocessed_copy: + save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) + # Sort spikes and save results st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) clu, Wall = cluster_spikes(st, 
tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) @@ -276,7 +279,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + _ = save_sorting( + ops, + results_dir, + st, + clu, + tF, + Wall, + bfile.imin, + tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + else: + _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) @classmethod def _get_result_from_folder(cls, sorter_output_folder): From c320d6c09761ea673a5c24e06ea55622997f4d9f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:18:09 +0100 Subject: [PATCH 62/90] Add clarification on typo. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 16918128a2..250c2865f9 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -217,7 +217,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR, invert_sign, device, - save_preprocesed_copy=save_preprocessed_copy, + save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) From e51088ab0e5f56596a78fc4cfd4e9a6d50f71414 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:20:32 +0100 Subject: [PATCH 63/90] Extend param description. 
--- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 250c2865f9..6d83249653 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -99,7 +99,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", - "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", } From 4668ab7ca434e004f16ebfd454923c7adf3d6943 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:29:56 -0400 Subject: [PATCH 64/90] begin to add examples to docstrings --- src/spikeinterface/extractors/neoextractors/alphaomega.py | 8 +++++++- src/spikeinterface/extractors/neoextractors/axona.py | 5 +++++ src/spikeinterface/extractors/neoextractors/ced.py | 5 +++++ src/spikeinterface/extractors/neoextractors/intan.py | 8 +++++++- src/spikeinterface/extractors/neoextractors/plexon.py | 5 +++++ src/spikeinterface/extractors/neoextractors/plexon2.py | 5 +++++ .../extractors/neoextractors/spikegadgets.py | 5 +++++ src/spikeinterface/extractors/neoextractors/spikeglx.py | 7 +++++++ 8 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index b3f671ebf3..cf47b9819c 100644 --- 
a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -18,7 +18,7 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): folder_path : str or Path-like The folder path to the AlphaOmega recordings. lsx_files : list of strings or None, default: None - A list of listings files that refers to mpx files to load. + A list of files that refers to mpx files to load. stream_id : {"RAW", "LFP", "SPK", "ACC", "AI", "UD"}, default: "RAW" If there are several streams, specify the stream id you want to load. stream_name : str, default: None @@ -28,6 +28,12 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_alphaomega + >>> recording = read_alphaomega(folder_path="alphaomega_folder") + """ NeoRawIOClass = "AlphaOmegaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index adfdccddd9..9de39bef2e 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -22,6 +22,11 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. 
+ + Examples + -------- + >>> from spikeinterface.extractors import read_axona + >>> recording = read_axona(file_path=r'my_data.set') """ NeoRawIOClass = "AxonaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index a42a2d75a5..992d1a8941 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -28,6 +28,11 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_ced + >>> recording = read_ced(file_path=r'my_data.smr') """ NeoRawIOClass = "CedRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index f0a1894f25..261472ede9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -34,7 +34,13 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): In Intan the ids provided by NeoRawIO are the hardware channel ids while the names are custom names given by the user - + Examples + -------- + >>> from spikeinterface.extractors import read_intan + # intan amplifier data is stored in stream_id = '0' + >>> recording = read_intan(file_path=r'my_data.rhd', stream_id='0') + # intan has multi-file formats as well, but in this case our path should point to the header file 'info.rhd' + >>> recording = read_intan(file_path=r'info.rhd', stream_id='0') """ NeoRawIOClass = "IntanRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 0adddc2439..a10c231e13 100644 --- 
a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -30,6 +30,11 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Example for wideband signals: names: ["WB01", "WB02", "WB03", "WB04"] ids: ["0" , "1", "2", "3"] + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon + >>> recording = read_plexon(file_path=r'my_data.plx') """ NeoRawIOClass = "PlexonRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 4434d02cc1..2f360ed864 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,11 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon2 + >>> recording = read_plexon2(file_path=r'my_data.pl2') """ NeoRawIOClass = "Plexon2RawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 89c457a573..e91a81398b 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -29,6 +29,11 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. 
+ + Examples + -------- + >>> from spikeinterface.extractors import read_spikegadgets + >>> recording = read_spikegadgets(file_path=r'my_data.rec') """ NeoRawIOClass = "SpikeGadgetsRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index cfe20bbfa6..874a65c045 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -41,6 +41,13 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_spikeglx + >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=False) + # we can load the sync channel, but then the probe is not loaded + >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=True) """ NeoRawIOClass = "SpikeGLXRawIO" From cd66e7ba8875161700d5523200930f9067ad2cde Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 8 Aug 2024 16:30:29 -0300 Subject: [PATCH 65/90] fix sampling repr --- src/spikeinterface/core/baserecording.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e65afabaca..edcc23f339 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -96,14 +96,18 @@ def list_to_string(lst, max_size=6): def _repr_header(self): num_segments = self.get_num_segments() num_channels = self.get_num_channels() - sf_hz = self.get_sampling_frequency() - sf_khz = sf_hz / 1000 dtype = self.get_dtype() total_samples = self.get_total_samples() total_duration =
self.get_total_duration() total_memory_size = self.get_total_memory_size() - sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" + + sf_hz = self.get_sampling_frequency() + if not sf_hz.is_integer(): + sampling_frequency_repr = f"{sf_hz:f} Hz" + else: + # kHz for high sampling rate and Hz for LFP + sampling_frequency_repr = f"{(sf_hz/1000.0):0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( f"{self.name}: " From 9583aae767322f1b4ab8a0b6b47057ce3bfc1d71 Mon Sep 17 00:00:00 2001 From: Christian Tabedzki <35670232+tabedzki@users.noreply.github.com> Date: Thu, 15 Aug 2024 09:02:31 -0400 Subject: [PATCH 66/90] Added sphinxcontrib-jquery --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 71919c072b..f759e839e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ docs = [ "sphinx-design", "numpydoc", "ipython", + "sphinxcontrib-jquery", # for notebooks in the gallery "MEArec", # Use as an example From 3b77acbada0ebef08cee3203f504dc96d741d238 Mon Sep 17 00:00:00 2001 From: Robin Kim Date: Thu, 15 Aug 2024 15:09:00 -0500 Subject: [PATCH 67/90] Add no merge test --- src/spikeinterface/curation/curation_format.py | 3 ++- .../tests/sv-sorting-curation-no-merge.json | 1 + .../curation/tests/test_sortingview_curation.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index babe7aac40..5a57692597 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -92,7 +92,8 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo """ assert destination_format == "1" - + if "mergeGroups" not in sortingview_dict.keys(): + sortingview_dict["mergeGroups"] = [] merge_groups =
sortingview_dict["mergeGroups"] merged_units = sum(merge_groups, []) if len(merged_units) > 0: diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json new file mode 100644 index 0000000000..2a350340f3 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json @@ -0,0 +1 @@ +{"labelsByUnit":{"2":["mua"],"3":["mua"],"4":["mua"],"5":["accept"],"6":["accept"],"7":["accept"],"8":["artifact"],"9":["artifact"]}} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index bb152e7f71..24bd44a4c8 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -243,6 +243,18 @@ def test_label_inheritance_str(): assert np.all(sorting_include_accept.get_property("accept")) +def test_json_no_merge_curation(): + """ + Test curation with no merges using a JSON file. 
+ """ + sorting = generate_sorting(num_units=10) + + # from curation.json + json_file = parent_folder / "sv-sorting-curation-no-merge.json" + # print(f"Sorting: {sorting.get_unit_ids()}") + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file) + + if __name__ == "__main__": # generate_sortingview_curation_dataset() # test_sha1_curation() @@ -251,3 +263,4 @@ def test_label_inheritance_str(): test_false_positive_curation() test_label_inheritance_int() test_label_inheritance_str() + test_json_no_merge_curation() From 1356d3a362174cdfdb3382e777e0f3c3af126e4e Mon Sep 17 00:00:00 2001 From: Robin Kim Date: Thu, 15 Aug 2024 15:11:39 -0500 Subject: [PATCH 68/90] Add comment describing test fail --- src/spikeinterface/curation/tests/test_sortingview_curation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 24bd44a4c8..6c6dc482c3 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -249,10 +249,9 @@ def test_json_no_merge_curation(): """ sorting = generate_sorting(num_units=10) - # from curation.json json_file = parent_folder / "sv-sorting-curation-no-merge.json" - # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file) + # ValueError: Curation format: some labeled units are not in the unit list if __name__ == "__main__": From 0b2c237076aca81ab9920db953f8ac5fb2fdb4b4 Mon Sep 17 00:00:00 2001 From: Christian Tabedzki <35670232+tabedzki@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:21:25 -0400 Subject: [PATCH 69/90] Added sphinx-rtd-theme minimum version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f759e839e8..ed5f6e6fa7 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -181,7 +181,7 @@ test = [ docs = [ "Sphinx", - "sphinx_rtd_theme", + "sphinx_rtd_theme>=1.2", "sphinx-gallery", "sphinx-design", "numpydoc", From d029f7d974020145ad6f309e05b3b1456693318b Mon Sep 17 00:00:00 2001 From: Robin Kim Date: Fri, 16 Aug 2024 11:56:39 -0500 Subject: [PATCH 70/90] Fix value error by checking first dict key type --- src/spikeinterface/curation/curation_format.py | 7 +++++-- .../curation/tests/test_sortingview_curation.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 5a57692597..511abb7801 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -96,10 +96,13 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo sortingview_dict["mergeGroups"] = [] merge_groups = sortingview_dict["mergeGroups"] merged_units = sum(merge_groups, []) - if len(merged_units) > 0: - unit_id_type = int if isinstance(merged_units[0], int) else str + + first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int else: unit_id_type = str + all_units = [] all_labels = [] manual_labels = [] diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 6c6dc482c3..945aca7937 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -250,13 +250,13 @@ def test_json_no_merge_curation(): sorting = generate_sorting(num_units=10) json_file = parent_folder / "sv-sorting-curation-no-merge.json" - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file) - # ValueError: Curation format: some labeled units are not in the unit list + sorting_curated = apply_sortingview_curation(sorting, 
uri_or_json=json_file) if __name__ == "__main__": # generate_sortingview_curation_dataset() # test_sha1_curation() + test_gh_curation() test_json_curation() test_false_positive_curation() From 2af85b3f27a5463f0ddbf306229e5f8df1298106 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 19 Aug 2024 17:51:07 +0200 Subject: [PATCH 71/90] Enable cloud-loading for analyzer Zarr --- src/spikeinterface/core/core_tools.py | 17 +++++++++ src/spikeinterface/core/sortinganalyzer.py | 41 ++++++++++++---------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index aad7613d01..b38222391c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -684,3 +684,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: memory = mem_info.total - mem_info.available return memory + + +def is_path_remote(path: str | Path) -> bool: + """ + Returns True if the path is a remote path (e.g., s3:// or gcs://). + + Parameters + ---------- + path : str or Path + The path to check. + + Returns + ------- + bool + Whether the path is a remote path. 
+ """ + return "s3://" in str(path) or "gcs://" in str(path) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..eb6233bf86 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match -from .core_tools import check_json, retrieve_importing_provenance +from .core_tools import check_json, retrieve_importing_provenance, is_path_remote from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -195,6 +195,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, + storage_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -204,6 +205,7 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + self.storage_options = storage_options # this is used to store temporary recording self._temporary_recording = None @@ -276,17 +278,15 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto"): + def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. Otherwise the recording is loaded when possible. 
""" - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" if format == "auto": # make better assumption and check for auto guess format - if folder.suffix == ".zarr": + if Path(folder).suffix == ".zarr": format = "zarr" else: format = "binary_folder" @@ -294,12 +294,18 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_zarr( + folder, recording=recording, storage_options=storage_options + ) - sorting_analyzer.folder = folder + if is_path_remote(str(folder)): + sorting_analyzer.folder = folder + # in this case we only load extensions when needed + else: + sorting_analyzer.folder = Path(folder) - if load_extensions: - sorting_analyzer.load_all_saved_extension() + if load_extensions: + sorting_analyzer.load_all_saved_extension() return sorting_analyzer @@ -470,7 +476,9 @@ def load_from_binary_folder(cls, folder, recording=None): def _get_zarr_root(self, mode="r+"): import zarr - zarr_root = zarr.open(self.folder, mode=mode) + if is_path_remote(str(self.folder)): + mode = "r" + zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -552,25 +560,22 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") @classmethod - def load_from_zarr(cls, folder, recording=None): + def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - folder = Path(folder) - assert folder.is_dir(), f"This folder does not exist {folder}" - - zarr_root = zarr.open(folder, mode="r") + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory - # 
TODO propagate storage_options sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True, copy_spike_vector=True + ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + with_metadata=True, + copy_spike_vector=True, ) # load recording if possible if recording is None: rec_dict = zarr_root["recording"][0] try: - recording = load_extractor(rec_dict, base_folder=folder) except: recording = None From 33e27b1c621aeca2a99a32d3ddf44f4a5fadf022 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Aug 2024 09:15:51 +0200 Subject: [PATCH 72/90] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index eb6233bf86..45f1f881b4 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -563,7 +563,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) + zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory sorting = NumpySorting.from_sorting( From c1239397b9bec50acf06c8b3cbc11ee93861786f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Aug 2024 15:14:12 +0200 Subject: [PATCH 73/90] Lazy loading of zarr timestamps --- src/spikeinterface/core/baserecording.py | 7 +++---- src/spikeinterface/core/zarrextractors.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e65afabaca..efa6d03f56 100644 --- a/src/spikeinterface/core/baserecording.py +++ 
b/src/spikeinterface/core/baserecording.py @@ -811,10 +811,9 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): def get_times(self): if self.time_vector is not None: - if isinstance(self.time_vector, np.ndarray): - return self.time_vector - else: - return np.array(self.time_vector) + if not isinstance(self.time_vector, np.ndarray): + self.time_vector = np.array(self.time_vector) + return self.time_vector else: time_vector = np.arange(self.get_num_samples(), dtype="float64") time_vector /= self.sampling_frequency diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1b9637e097..17f1ac08b3 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -66,7 +66,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) time_kwargs = {} time_vector = self._root.get(f"times_seg{segment_index}", None) if time_vector is not None: - time_kwargs["time_vector"] = time_vector[:] + time_kwargs["time_vector"] = time_vector else: if t_starts is None: t_start = None From 32568ca1a9637a7dc167dbf1a56e214dbe13cfb5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 15:46:49 +0100 Subject: [PATCH 74/90] Remove run CI on main, only run on cron job. 
--- .github/workflows/test_kilosort4.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 13d70acf88..24b2e29440 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -4,10 +4,6 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC - pull_request: - types: [synchronize, opened, reopened] - branches: - - main jobs: versions: From 8580c975e0d26db4006883da7ff2c36a58a5832a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:42 +0100 Subject: [PATCH 75/90] Update .github/scripts/test_kilosort4_ci.py --- .github/scripts/test_kilosort4_ci.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index c894ed71ff..10855f2120 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -336,6 +336,10 @@ def test_binary_filtered_arguments(self): ) def _check_arguments(self, object_, expected_arguments): + """ + Check that the argument signature of `object_` is as expected + (i..e has not changed across kilosort versions). 
+ """ sig = signature(object_) obj_arguments = list(sig.parameters.keys()) assert expected_arguments == obj_arguments From b3c6680f859d165bc6f4e11ea8d91cfd6c95eaf1 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:52 +0100 Subject: [PATCH 76/90] Update src/spikeinterface/sorters/external/kilosort4.py --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index eb1df7c455..3f7a0f7abe 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -128,7 +128,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <4.0.10 is always '4'""" + """kilosort.__version__ <4.0.10 is always '4'""" return importlib_version("kilosort") @classmethod From 23c39831a9cadba7ab50c88c53536723e93fba2f Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:48:10 +0100 Subject: [PATCH 77/90] Update .github/workflows/test_kilosort4.yml --- .github/workflows/test_kilosort4.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 24b2e29440..95fc30b0b2 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] # TODO: just checking python version is not cause of failing test. + python-version: ["3.12"] os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: From ed9ef3251504a8d2388a5c461e5c8531113ccb09 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 16:04:16 +0100 Subject: [PATCH 78/90] Fix linting. 
--- .github/scripts/test_kilosort4_ci.py | 2 +- .github/workflows/test_kilosort4.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 10855f2120..3ac8c7dd2b 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -337,7 +337,7 @@ def test_binary_filtered_arguments(self): def _check_arguments(self, object_, expected_arguments): """ - Check that the argument signature of `object_` is as expected + Check that the argument signature of `object_` is as expected (i..e has not changed across kilosort versions). """ sig = signature(object_) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 95fc30b0b2..390bec98be 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.12"] os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: From 522260cdda14ffab5a49675bf63b9bc8c44cbec5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Aug 2024 18:51:20 +0200 Subject: [PATCH 79/90] asarray and annotations --- src/spikeinterface/core/baserecording.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index efa6d03f56..dbec8a3730 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -422,7 +422,7 @@ def get_time_info(self, segment_index=None) -> dict: return time_kwargs - def get_times(self, segment_index=None): + def get_times(self, segment_index=None) -> np.ndarray: """Get time vector for a recording segment. If the segment has a time_vector, then it is returned. 
Otherwise @@ -809,10 +809,9 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): BaseSegment.__init__(self) - def get_times(self): + def get_times(self) -> np.ndarray: if self.time_vector is not None: - if not isinstance(self.time_vector, np.ndarray): - self.time_vector = np.array(self.time_vector) + self.time_vector = np.asarray(self.time_vector) return self.time_vector else: time_vector = np.arange(self.get_num_samples(), dtype="float64") From be0fd8afacea9508d38e65a9f89655a5e25bba57 Mon Sep 17 00:00:00 2001 From: JuanPimiento <148992347+JuanPimientoCaicedo@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:52:10 -0400 Subject: [PATCH 80/90] Add causal filtering to filter.py (#3172) --- doc/api.rst | 1 + src/spikeinterface/preprocessing/filter.py | 127 ++++++++++++++-- .../preprocessing/preprocessinglist.py | 1 + .../preprocessing/tests/test_filter.py | 137 +++++++++++++++++- 4 files changed, 254 insertions(+), 12 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 1966b48a37..42f9fec299 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -171,6 +171,7 @@ spikeinterface.preprocessing .. autofunction:: interpolate_bad_channels .. autofunction:: normalize_by_quantile .. autofunction:: notch_filter + .. autofunction:: causal_filter .. autofunction:: phase_shift .. autofunction:: rectify .. 
autofunction:: remove_artifacts diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 54c5ab2b2d..a67d163d3d 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -24,10 +24,12 @@ class FilterRecording(BasePreprocessor): """ - Generic filter class based on: - - * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + A generic filter class based on: + For filter coefficient generation: + * scipy.signal.iirfilter + For filter application: + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward" + * scipy.signal.lfilter or scipy.signal.sosfilt when direction = "forward" or "backward" BandpassFilterRecording is built on top of it. @@ -56,6 +58,11 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward" | "forward-backward", default: "forward-backward" + Direction of filtering: + - "forward" - filter is applied to the timeseries in one direction, creating phase shifts + - "backward" - the timeseries is reversed, the filter is applied and filtered timeseries reversed again. Creates phase shifts in the opposite direction to "forward" + - "forward-backward" - Applies the filter in the forward and backward direction, resulting in zero-phase filtering. Note this doubles the effective filter order. 
Returns ------- @@ -75,6 +82,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + direction="forward-backward", ): import scipy.signal @@ -106,7 +114,13 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=add_reflect_padding, + direction=direction, ) ) @@ -121,14 +135,25 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__( + self, + parent_recording_segment, + coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=False, + direction="forward-backward", + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -150,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if self.direction == "forward-backward": + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + else: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered_traces = 
scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -289,6 +327,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") + +def causal_filter( + recording, + direction="forward", + band=[300.0, 6000.0], + btype="bandpass", + filter_order=5, + ftype="butter", + filter_mode="sos", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, +): + """ + Generic causal filter built on top of the filter function. + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + direction : "forward" | "backward", default: "forward" + Direction of causal filter. The "backward" option flips the traces in time before applying the filter + and then flips them back. + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. 
+ filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + + Returns + ------- + filter_recording : FilterRecording + The causal-filtered recording extractor object + """ + assert direction in ["forward", "backward"], "Direction must be either 'forward' or 'backward'" + return filter( + recording=recording, + direction=direction, + band=band, + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + margin_ms=margin_ms, + add_reflect_padding=add_reflect_padding, + coeff=coeff, + dtype=dtype, + ) + + bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 149c6eb458..bdf5f2219c 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -12,6 +12,7 @@ notch_filter, HighpassFilterRecording, highpass_filter, + causal_filter, ) from .filter_gaussian import GaussianFilterRecording, gaussian_filter from .normalize_scale import ( diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 68790b3273..9df60af3db 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -4,7 +4,140 @@ from spikeinterface.core import generate_recording from spikeinterface import NumpyRecording, set_global_tmp_folder -from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter +from spikeinterface.preprocessing 
import filter, bandpass_filter, notch_filter, causal_filter + + +class TestCausalFilter: + """ + The only thing that is not tested (JZ, as of 23/07/2024) is the + propagation of margin kwargs, these are general filter params + and can be tested in an upcoming PR. + """ + + @pytest.fixture(scope="session") + def recording_and_data(self): + recording = generate_recording(durations=[1]) + raw_data = recording.get_traces() + + return (recording, raw_data) + + def test_causal_filter_main_kwargs(self, recording_and_data): + """ + Perform a test that expected output is returned under change + of all key filter-related kwargs. First run the filter in + the forward direction with options and compare it + to the expected output from scipy. + + Next, change every filter-related kwarg and set in the backwards + direction. Again check it matches expected scipy output. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + # First, check in the forward direction with + # the default set of kwargs + options = self._get_filter_options() + + sos = self._run_iirfilter(options, recording) + + test_data = sosfilt(sos, raw_data, axis=0) + test_data.astype(recording.dtype) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + # Then, change all kwargs to ensure they are propagated + # and check the backwards version. 
+ options["band"] = [671] + options["btype"] = "highpass" + options["filter_order"] = 8 + options["ftype"] = "bessel" + options["filter_mode"] = "ba" + options["dtype"] = np.float16 + + b, a = self._run_iirfilter(options, recording) + + flip_raw = np.flip(raw_data, axis=0) + test_data = lfilter(b, a, flip_raw, axis=0) + test_data = np.flip(test_data, axis=0) + test_data = test_data.astype(options["dtype"]) + + filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + def test_causal_filter_custom_coeff(self, recording_and_data): + """ + A different path is taken when custom coeff is selected. + Therefore, explicitly test the expected outputs are obtained + when passing custom coeff, under the "ba" and "sos" conditions. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + options = self._get_filter_options() + options["filter_mode"] = "ba" + options["coeff"] = (np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])) + + # Check the custom coeff are propagated in both modes. 
+ # First, in "ba" mode + test_data = lfilter(options["coeff"][0], options["coeff"][1], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + # Next, in "sos" mode + options["filter_mode"] = "sos" + options["coeff"] = np.ones((2, 6)) + + test_data = sosfilt(options["coeff"], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + def test_causal_kwarg_error_raised(self, recording_and_data): + """ + Test that passing the "forward-backward" direction results in + an error. It is is critical this error is raised, + otherwise the filter will no longer be causal. + """ + recording, raw_data = recording_and_data + + with pytest.raises(BaseException) as e: + filt_data = causal_filter(recording, direction="forward-backward") + + def _run_iirfilter(self, options, recording): + """ + Convenience function to convert Si kwarg + names to Scipy. 
+ """ + from scipy.signal import iirfilter + + return iirfilter( + N=options["filter_order"], + Wn=options["band"], + btype=options["btype"], + ftype=options["ftype"], + output=options["filter_mode"], + fs=recording.get_sampling_frequency(), + ) + + def _get_filter_options(self): + return { + "band": [300.0, 6000.0], + "btype": "bandpass", + "filter_order": 5, + "ftype": "butter", + "filter_mode": "sos", + "coeff": None, + } def test_filter(): @@ -28,6 +161,8 @@ def test_filter(): # other filtering types rec3 = filter(rec, band=500.0, btype="highpass", filter_mode="ba", filter_order=2) rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.0) + rec5 = causal_filter(rec, direction="forward") + rec6 = causal_filter(rec, direction="backward") # filter from coefficients from scipy.signal import iirfilter From ae44b4a908855b8495d1d9807fddc73d8452b86a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 21:20:27 +0100 Subject: [PATCH 81/90] Remove 'save_preprocessed' test. --- .github/scripts/test_kilosort4_ci.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 3ac8c7dd2b..e0d1f2a504 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -85,14 +85,6 @@ ("duplicate_spike_bins", 5), ] -# Update PARAMS_TO_TEST with version-dependent kwargs -if parse(version("kilosort")) >= parse("4.0.12"): - pass # TODO: expose? -# PARAMS_TO_TEST.extend( -# [ -# ("save_preprocessed_copy", False), -# ] -# ) if parse(version("kilosort")) >= parse("4.0.11"): PARAMS_TO_TEST.extend( [ From 642eea9b2c1242000dd847701eb89dc533def6be Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 21 Aug 2024 12:55:59 +0100 Subject: [PATCH 82/90] Update KS4 versions to test on, add a warning for the next version. 
--- .github/scripts/check_kilosort4_releases.py | 10 +++++++++- .github/scripts/kilosort4-latest-version.json | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 .github/scripts/kilosort4-latest-version.json diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index de11dc974b..92e7bf277f 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -4,6 +4,7 @@ import requests import json from packaging.version import parse +import spikeinterface def get_pypi_versions(package_name): """ @@ -15,7 +16,13 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) - versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] + + assert parse(spikeinterface.__version__) < parse("0.101.1"), ( + "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." + "At version 0.101.1, this should be updated to support newer" + "kilosort verrsions." 
+ ) + versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] return versions @@ -24,4 +31,5 @@ def get_pypi_versions(package_name): package_name = "kilosort" versions = get_pypi_versions(package_name) with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + print(versions) json.dump(versions, f) diff --git a/.github/scripts/kilosort4-latest-version.json b/.github/scripts/kilosort4-latest-version.json new file mode 100644 index 0000000000..03629ff842 --- /dev/null +++ b/.github/scripts/kilosort4-latest-version.json @@ -0,0 +1 @@ +["4.0.10", "4.0.11", "4.0.12", "4.0.5", "4.0.6", "4.0.7", "4.0.8", "4.0.9"] From 3ffe6dfda36b00be3e67ba181b60db7a209363d8 Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Sun, 25 Aug 2024 01:41:14 +0100 Subject: [PATCH 83/90] fix: download apptainer images without docker client --- src/spikeinterface/sorters/container_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 6406919455..6b194c0702 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -99,7 +99,7 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): singularity_image = sif_file else: - docker_image = self._get_docker_image(container_image) + docker_image = Client.load('docker://'+container_image) if docker_image and len(docker_image.tags) > 0: tag = docker_image.tags[0] print(f"Building singularity image from local docker image: {tag}") From 4644dd14e4ff3dbd414a720dc9656c8b0d1faade Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 00:43:05 +0000 Subject: [PATCH 84/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/container_tools.py 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 6b194c0702..f9611586c9 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -99,7 +99,7 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): singularity_image = sif_file else: - docker_image = Client.load('docker://'+container_image) + docker_image = Client.load("docker://" + container_image) if docker_image and len(docker_image.tags) > 0: tag = docker_image.tags[0] print(f"Building singularity image from local docker image: {tag}") From 50643c195750f505038b535c37ba99e7b9d7031a Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Mon, 26 Aug 2024 09:57:50 +0100 Subject: [PATCH 85/90] Fix for ninor changes in latest Kilosort4 API --- src/spikeinterface/sorters/external/kilosort4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8499cef11f..8a81274d24 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -222,6 +222,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name=probe_name, data_dir=data_dir, results_dir=results_dir, + bad_channels=None ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): @@ -232,7 +233,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, invert_sign=invert_sign, device=device, - save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo) ) else: ops = initialize_ops( From d1f546b7a493d6caa40fccfddb5fa0608d7a797d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 
Mon, 26 Aug 2024 08:58:56 +0000 Subject: [PATCH 86/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8a81274d24..2ec6055d9b 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -222,7 +222,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name=probe_name, data_dir=data_dir, results_dir=results_dir, - bad_channels=None + bad_channels=None, ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): From d4ea5c58a662497af47ac64977dd7fdbaf20edeb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 00:43:05 +0000 Subject: [PATCH 87/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/container_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 6b194c0702..f9611586c9 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -99,7 +99,7 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): singularity_image = sif_file else: - docker_image = Client.load('docker://'+container_image) + docker_image = Client.load("docker://" + container_image) if docker_image and len(docker_image.tags) > 0: tag = docker_image.tags[0] print(f"Building singularity image from local docker image: {tag}") From 3c39b3c1035272b0e3a9a81e46292e4915eb670f Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Mon, 26 Aug 2024 10:11:05 +0100 Subject: [PATCH 
88/90] Revert "Fix for ninor changes in latest Kilosort4 API" This reverts commit 50643c195750f505038b535c37ba99e7b9d7031a. --- src/spikeinterface/sorters/external/kilosort4.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8a81274d24..8499cef11f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -222,7 +222,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name=probe_name, data_dir=data_dir, results_dir=results_dir, - bad_channels=None ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): @@ -233,7 +232,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, invert_sign=invert_sign, device=device, - save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) ) else: ops = initialize_ops( From 8c80ecbd1d1b3b0168b9873bbab45bc71617dbd6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 27 Aug 2024 17:14:37 +0200 Subject: [PATCH 89/90] Make InterpolateMotionRecording not JSON-serializable --- .../sortingcomponents/motion/motion_interpolation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 11ce11e1aa..2108fdf9ca 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -386,6 +386,10 @@ def __init__( ) self.add_recording_segment(rec_segment) + # this object is currently not JSON-serializable because the Motion obejct cannot be reloaded properly + # see issue #3313 + self._serializability["json"] = False + self._kwargs = dict( recording=recording, 
motion=motion, From d6f3ced15f938943f83741b6867de7e7916a3de8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 28 Aug 2024 13:09:18 +0200 Subject: [PATCH 90/90] Update src/spikeinterface/core/recording_tools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index cd2f563fba..5833f81ff8 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -951,7 +951,7 @@ def do_recording_attributes_match( bool True if the recordings have the same attributes str - A string with the an exception message with attributes that do not match + A string with the exception message with the attributes that do not match """ recording1_attributes = get_rec_attributes(recording1) recording2_attributes = deepcopy(recording2_attributes)