Merge pull request #3215 from samuelgarcia/fix_sa
Implement a simple system for backward compatibility of Analyzer extensions
alejoe91 authored Jul 18, 2024
2 parents 7d426ec + ace9c19 commit 63b295c
Showing 7 changed files with 58 additions and 36 deletions.
30 changes: 10 additions & 20 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -364,6 +364,16 @@ class ComputeTemplates(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = True
need_backward_compatibility_on_load = True

def _handle_backward_compatibility_on_load(self):
if "ms_before" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency

if "ms_after" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None):
operators = operators or ["average", "std"]
@@ -487,31 +497,11 @@ def _compute_and_append_from_waveforms(self, operators):

@property
def nbefore(self):
if "ms_before" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency
warnings.warn(
"The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params."
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)

nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nbefore

@property
def nafter(self):
if "ms_after" not in self.params:
# compatibility february 2024 > july 2024
warnings.warn(
"The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params."
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nafter
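
As a worked example of the conversion done by _handle_backward_compatibility_on_load and the nbefore/nafter properties, here is a small self-contained sketch; the 30 kHz sampling rate and the legacy values are invented for illustration:

# Rough illustration of the nbefore/nafter -> ms_before/ms_after migration.
# The sampling rate and legacy values below are invented for the example.
sampling_frequency = 30_000.0  # Hz
legacy_params = {"nbefore": 30, "nafter": 60}  # old extensions stored sample counts

# conversion done once on load (samples -> milliseconds)
ms_before = legacy_params["nbefore"] * 1000.0 / sampling_frequency  # 1.0 ms
ms_after = legacy_params["nafter"] * 1000.0 / sampling_frequency    # 2.0 ms

# the nbefore/nafter properties recover the sample counts from milliseconds
nbefore = int(ms_before * sampling_frequency / 1000.0)  # 30 samples
nafter = int(ms_after * sampling_frequency / 1000.0)    # 60 samples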

6 changes: 5 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
@@ -347,6 +347,10 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids):
"""
old_unit_ids = np.asarray(old_unit_ids)
dtype = old_unit_ids.dtype
if dtype.kind == "U":
# the new dtype can be longer
dtype = "U"

assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups"

@@ -361,7 +365,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids):
all_unit_ids.remove(unit_id)
if new_unit_id not in all_unit_ids:
all_unit_ids.append(new_unit_id)
return np.array(all_unit_ids)
return np.array(all_unit_ids, dtype=dtype)


def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"):
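
The dtype = "U" change above avoids silent truncation when a merged unit id is longer than the original string ids. A minimal sketch of the failure mode (the ids are made up):

import numpy as np

old_unit_ids = np.array(["0", "1", "2"])  # dtype is '<U1' (1-character strings)
merged_id = "0-1"                         # new id created by a merge, longer than 1 character

# reusing the old fixed-length dtype truncates the new id to "0"
np.array(["2", merged_id], dtype=old_unit_ids.dtype)  # -> array(['2', '0'], dtype='<U1')

# the flexible "U" dtype lets numpy pick an itemsize that fits every id
np.array(["2", merged_id], dtype="U")                 # -> array(['2', '0-1'], dtype='<U3')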
20 changes: 16 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -777,6 +777,10 @@ def _save_or_select_or_merge(
# make a copy of extensions
# note that the copy of extension handle itself the slicing of units when necessary and also the saveing
sorted_extensions = _sort_extensions_by_dependency(self.extensions)
# hack: quality metrics are computed at last
qm_extension_params = sorted_extensions.pop("quality_metrics", None)
if qm_extension_params is not None:
sorted_extensions["quality_metrics"] = qm_extension_params
recompute_dict = {}

for extension_name, extension in sorted_extensions.items():
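
This pop/reinsert works because Python dicts preserve insertion order: re-assigning the popped entry moves quality_metrics to the end of the loop, so it is recomputed after the extensions it depends on. A tiny illustration with placeholder extension names:

sorted_extensions = {"templates": {}, "quality_metrics": {}, "spike_amplitudes": {}}
qm_extension_params = sorted_extensions.pop("quality_metrics", None)
if qm_extension_params is not None:
    sorted_extensions["quality_metrics"] = qm_extension_params
print(list(sorted_extensions))  # ['templates', 'spike_amplitudes', 'quality_metrics']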
@@ -1207,7 +1211,9 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):

# check dependencies
if extension_class.need_recording:
assert self.has_recording(), f"Extension {extension_name} requires the recording"
assert (
self.has_recording() or self.has_temporary_recording()
), f"Extension {extension_name} requires the recording"
for dependency_name in extension_class.depend_on:
if "|" in dependency_name:
ok = any(self.get_extension(name) is not None for name in dependency_name.split("|"))
@@ -1401,9 +1407,7 @@ def load_extension(self, extension_name: str):

extension_class = get_extension_class(extension_name)

extension_instance = extension_class(self)
extension_instance.load_params()
extension_instance.load_data()
extension_instance = extension_class.load(self)

self.extensions[extension_name] = extension_instance

@@ -1702,6 +1706,7 @@ class AnalyzerExtension:
use_nodepipeline = False
nodepipeline_variables = None
need_job_kwargs = False
need_backward_compatibility_on_load = False

def __init__(self, sorting_analyzer):
self._sorting_analyzer = weakref.ref(sorting_analyzer)
@@ -1740,6 +1745,10 @@ def _get_data(self):
# must be implemented in subclass
raise NotImplementedError

def _handle_backward_compatibility_on_load(self):
# must be implemented in subclass only if need_backward_compatibility_on_load=True
raise NotImplementedError

@classmethod
def function_factory(cls):
# make equivalent
@@ -1817,6 +1826,9 @@ def load(cls, sorting_analyzer):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_data()
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()

return ext

def load_params(self):
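
Taken together, the mechanism only asks an extension to opt in via the class attribute and implement the hook; AnalyzerExtension.load then runs it right after load_params() and load_data(). A rough sketch with a hypothetical extension (the class name, extension name and parameter keys are invented, and the other attributes/methods a real extension needs are omitted):

from spikeinterface.core.sortinganalyzer import AnalyzerExtension

class ComputeSomething(AnalyzerExtension):  # hypothetical extension, for illustration only
    extension_name = "something"
    need_backward_compatibility_on_load = True  # opt in to the hook

    def _handle_backward_compatibility_on_load(self):
        # translate a legacy key into the current one, as ComputeTemplates does above
        if "new_key" not in self.params and "old_key" in self.params:
            self.params["new_key"] = self.params.pop("old_key")

# on load, the base class drives the migration:
#   ext = ComputeSomething.load(sorting_analyzer)
#   -> load_params(), load_data(), then _handle_backward_compatibility_on_load()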
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
@@ -161,4 +161,4 @@ def test_generate_unit_ids_for_merge_group():

test_apply_merges_to_sorting()
test_get_ids_after_merging()
test_generate_unit_ids_for_merge_group()
# test_generate_unit_ids_for_merge_group()
@@ -531,7 +531,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
templates[mode] = np.load(template_file)
if len(templates) > 0:
ext = ComputeTemplates(sorting_analyzer)
ext.params = dict(nbefore=nbefore, nafter=nafter, operators=list(templates.keys()))
ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys()))
for mode, arr in templates.items():
ext.data[mode] = arr
sorting_analyzer.extensions["templates"] = ext
@@ -544,10 +544,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext = new_class(sorting_analyzer)
with open(ext_folder / "params.json", "r") as f:
params = json.load(f)
# update params
new_params = ext._set_params()
updated_params = make_ext_params_up_to_date(ext, params, new_params)
ext.set_params(**updated_params)

if new_name == "spike_amplitudes":
amplitudes = []
@@ -604,6 +600,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
pc_all[mask, ...] = pc_one
ext.data["pca_projection"] = pc_all

# update params
new_params = ext._set_params()
updated_params = make_ext_params_up_to_date(ext, params, new_params)
ext.set_params(**updated_params, save=False)
if ext.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()

sorting_analyzer.extensions[new_name] = ext

return sorting_analyzer
@@ -614,13 +617,12 @@ def make_ext_params_up_to_date(ext, old_params, new_params):
old_name = ext.extension_name
updated_params = old_params.copy()
for p, values in old_params.items():
if isinstance(values, dict):
if p not in new_params:
warnings.warn(f"Removing legacy parameter {p} from {old_name} extension")
updated_params.pop(p)
elif isinstance(values, dict):
new_values = new_params.get(p, {})
updated_params[p] = make_ext_params_up_to_date(ext, values, new_values)
else:
if p not in new_params:
warnings.warn(f"Removing legacy param {p} from {old_name} extension")
updated_params.pop(p)
return updated_params
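
With the reordered checks, a legacy top-level key is dropped before any recursion, and nested dicts are still pruned against the new defaults. A rough illustration of the intended behaviour (the parameter names are invented for the example):

# params saved by a legacy analyzer (names invented for the example)
old_params = {"legacy_flag": True, "estimate_kwargs": {"old_tol": 1e-3, "radius": 50.0}}
# defaults produced by the current _set_params()
new_params = {"estimate_kwargs": {"radius": 100.0}, "method": "default"}

# make_ext_params_up_to_date(ext, old_params, new_params) is expected to return
# {"estimate_kwargs": {"radius": 50.0}}:
# "legacy_flag" and "old_tol" are removed (with a warning), the saved "radius" is kept,
# and "method" is not injected because only keys present in old_params are visited.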


7 changes: 7 additions & 0 deletions src/spikeinterface/postprocessing/template_similarity.py
@@ -43,10 +43,17 @@ class ComputeTemplateSimilarity(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _handle_backward_compatibility_on_load(self):
if "max_lag_ms" not in self.params:
# make compatible analyzer created between february 24 and july 24
self.params["max_lag_ms"] = 0.0
self.params["support"] = "union"

def _set_params(self, method="cosine", max_lag_ms=0, support="union"):
if method == "cosine_similarity":
warnings.warn(
7 changes: 7 additions & 0 deletions src/spikeinterface/postprocessing/unit_locations.py
@@ -42,10 +42,17 @@ class ComputeUnitLocations(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _handle_backward_compatibility_on_load(self):
if "method_kwargs" in self.params:
# make compatible analyzer created between february 24 and july 24
method_kwargs = self.params.pop("method_kwargs")
self.params.update(**method_kwargs)

def _set_params(self, method="monopolar_triangulation", **method_kwargs):
params = dict(method=method)
params.update(method_kwargs)
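
The handler above flattens the legacy nested "method_kwargs" dict into the top-level params so that old analyzers match the new flat _set_params signature. The same transformation on a plain dict (the radius_um entry is just an example value):

# legacy layout (analyzers created before this change)
params = {"method": "monopolar_triangulation", "method_kwargs": {"radius_um": 75.0}}

# equivalent of _handle_backward_compatibility_on_load
if "method_kwargs" in params:
    method_kwargs = params.pop("method_kwargs")
    params.update(**method_kwargs)

print(params)  # {'method': 'monopolar_triangulation', 'radius_um': 75.0}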
