diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index cd2f563fba..7cbc236eda 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from copy import deepcopy -from typing import Literal, Tuple +from typing import Literal import warnings from pathlib import Path import os @@ -930,8 +930,8 @@ def get_rec_attributes(recording): def do_recording_attributes_match( - recording1: "BaseRecording", recording2_attributes: bool, check_is_filtered: bool = True, check_dtype: bool = True -) -> Tuple[bool, str]: + recording1: "BaseRecording", recording2_attributes: bool, check_dtype: bool = True +) -> tuple[bool, str]: """ Check if two recordings have the same attributes @@ -941,8 +941,6 @@ def do_recording_attributes_match( The first recording object recording2_attributes : dict The recording attributes to test against - check_is_filtered : bool, default: True - If True, check if the recordings are filtered check_dtype : bool, default: True If True, check if the recordings have the same dtype @@ -962,31 +960,24 @@ def do_recording_attributes_match( non_matching_attrs = [] if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]): - attributes_match = False non_matching_attrs.append("channel_ids") if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]: - attributes_match = False non_matching_attrs.append("sampling_frequency") if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]: - attributes_match = False non_matching_attrs.append("num_channels") if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]: - attributes_match = False non_matching_attrs.append("num_samples") - if check_is_filtered: - if not recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]: - attributes_match = False - non_matching_attrs.append("is_filtered") # dtype is optional if "dtype" in recording1_attributes and "dtype" in recording2_attributes: if check_dtype: if not recording1_attributes["dtype"] == recording2_attributes["dtype"]: - attributes_match = False non_matching_attrs.append("dtype") if len(non_matching_attrs) > 0: + attributes_match = False exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}" else: + attributes_match = True exception_str = "" return attributes_match, exception_str diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d034dcb46a..7687017db6 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -608,9 +608,7 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer - def set_temporary_recording( - self, recording: BaseRecording, check_is_filtered: bool = True, check_dtype: bool = True - ): + def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ Sets a temporary recording object. This function can be useful to temporarily set a "cached" recording object that is not saved in the SortingAnalyzer object to speed up @@ -622,14 +620,12 @@ def set_temporary_recording( ---------- recording : BaseRecording The recording object to set as temporary recording. - check_is_filtered : bool, default: True - If True, check that the temporary recording is filtered in the same way as the original recording. check_dtype : bool, default: True If True, check that the dtype of the temporary recording is the same as the original recording. """ # check that recording is compatible attributes_match, exception_str = do_recording_attributes_match( - recording, self.rec_attributes, check_is_filtered=check_is_filtered, check_dtype=check_dtype + recording, self.rec_attributes, check_dtype=check_dtype ) if not attributes_match: raise ValueError(exception_str) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 8a8fc3a358..23a1574f2a 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -316,16 +316,6 @@ def test_do_recording_attributes_match(): assert not do_match assert "sampling_frequency" in exc - # check is_filtered options - rec_attributes = get_rec_attributes(recording) - rec_attributes["is_filtered"] = not rec_attributes["is_filtered"] - - do_match, exc = do_recording_attributes_match(recording, rec_attributes) - assert not do_match - assert "is_filtered" in exc - do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_is_filtered=False) - assert do_match - # check dtype options rec_attributes = get_rec_attributes(recording) rec_attributes["dtype"] = "int16" diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 9de725239d..689073d6bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -144,15 +144,6 @@ def test_SortingAnalyzer_tmp_recording(dataset): with pytest.raises(ValueError): sorting_analyzer.set_temporary_recording(recording_sliced) - # test with different is_filtered - recording_filt = recording.clone() - recording_filt.annotate(is_filtered=False) - with pytest.raises(ValueError): - sorting_analyzer.set_temporary_recording(recording_filt) - - # thest with additional check_is_filtered - sorting_analyzer.set_temporary_recording(recording_filt, check_is_filtered=False) - def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):