Skip to content

Commit

Permalink
Add bad channels and do version check
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Aug 26, 2024
1 parent 3921806 commit 0ee0c43
Showing 1 changed file with 36 additions and 28 deletions.
64 changes: 36 additions & 28 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Kilosort4Sorter(BaseSorter):
"scaleproc": None,
"save_preprocessed_copy": False,
"torch_device": "auto",
"bad_channels": None,
}

_params_description = {
Expand Down Expand Up @@ -101,6 +102,7 @@ class Kilosort4Sorter(BaseSorter):
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data",
"torch_device": "Select the torch device auto/cuda/cpu",
"bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.",
}

sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching.
Expand All @@ -110,7 +112,7 @@ class Kilosort4Sorter(BaseSorter):
For more information see https://github.com/MouseLand/Kilosort"""

installation_mesg = """\nTo use Kilosort4 run:\n
>>> pip install kilosort==4.0
>>> pip install kilosort --upgrade
More information on Kilosort4 at:
https://github.com/MouseLand/Kilosort
Expand All @@ -134,6 +136,25 @@ def get_sorter_version(cls):
"""kilosort.__version__ <4.0.10 is always '4'"""
return importlib_version("kilosort")

@classmethod
def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder):
if not cls.is_installed():
raise Exception(
f"The sorter {cls.sorter_name} is not installed. Please install it with:\n{cls.installation_mesg}"
)
cls.check_sorter_version()
return super(Kilosort4Sorter, cls).initialize_folder(recording, output_folder, verbose, remove_existing_folder)

@classmethod
def check_sorter_version(cls):
kilosort_version = version.parse(cls.get_sorter_version())
if kilosort_version < version.parse("4.0.16"):
raise Exception(
f"""SpikeInterface only supports kilosort versions 4.0.16 and above. You are running version {kilosort_version}. To install the latest version, run:
>>> pip install kilosort --upgrade
"""
)

@classmethod
def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
from probeinterface import write_prb
Expand Down Expand Up @@ -214,6 +235,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# NOTE: Also modifies settings in-place
data_dir = ""
results_dir = sorter_output_folder
bad_channels = params["bad_channels"]

filename, data_dir, results_dir, probe = set_files(
settings=settings,
Expand All @@ -222,36 +244,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
probe_name=probe_name,
data_dir=data_dir,
results_dir=results_dir,
bad_channels=bad_channels,
)

if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"):
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo)
)
else:
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
)
ops = initialize_ops(
settings=settings,
probe=probe,
data_dtype=recording.get_dtype(),
do_CAR=do_CAR,
invert_sign=invert_sign,
device=device,
save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo)
)

if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"):
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)
else:
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = (
get_run_parameters(ops)
)
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)

# Set preprocessing and drift correction parameters
if not params["skip_kilosort_preprocessing"]:
Expand Down

0 comments on commit 0ee0c43

Please sign in to comment.