Skip to content

Commit

Permalink
Suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Aug 28, 2024
1 parent bc1c704 commit 945fc15
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 39 deletions.
19 changes: 5 additions & 14 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Literal, Tuple
from typing import Literal
import warnings
from pathlib import Path
import os
Expand Down Expand Up @@ -930,8 +930,8 @@ def get_rec_attributes(recording):


def do_recording_attributes_match(
recording1: "BaseRecording", recording2_attributes: bool, check_is_filtered: bool = True, check_dtype: bool = True
) -> Tuple[bool, str]:
recording1: "BaseRecording", recording2_attributes: bool, check_dtype: bool = True
) -> tuple[bool, str]:
"""
Check if two recordings have the same attributes
Expand All @@ -941,8 +941,6 @@ def do_recording_attributes_match(
The first recording object
recording2_attributes : dict
The recording attributes to test against
check_is_filtered : bool, default: True
If True, check if the recordings are filtered
check_dtype : bool, default: True
If True, check if the recordings have the same dtype
Expand All @@ -962,31 +960,24 @@ def do_recording_attributes_match(
non_matching_attrs = []

if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]):
attributes_match = False
non_matching_attrs.append("channel_ids")
if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]:
attributes_match = False
non_matching_attrs.append("sampling_frequency")
if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]:
attributes_match = False
non_matching_attrs.append("num_channels")
if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]:
attributes_match = False
non_matching_attrs.append("num_samples")
if check_is_filtered:
if not recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]:
attributes_match = False
non_matching_attrs.append("is_filtered")
# dtype is optional
if "dtype" in recording1_attributes and "dtype" in recording2_attributes:
if check_dtype:
if not recording1_attributes["dtype"] == recording2_attributes["dtype"]:
attributes_match = False
non_matching_attrs.append("dtype")

if len(non_matching_attrs) > 0:
attributes_match = False
exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}"
else:
attributes_match = True
exception_str = ""

return attributes_match, exception_str
8 changes: 2 additions & 6 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,7 @@ def load_from_zarr(cls, folder, recording=None):

return sorting_analyzer

def set_temporary_recording(
self, recording: BaseRecording, check_is_filtered: bool = True, check_dtype: bool = True
):
def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True):
"""
Sets a temporary recording object. This function can be useful to temporarily set
a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
Expand All @@ -622,14 +620,12 @@ def set_temporary_recording(
----------
recording : BaseRecording
The recording object to set as temporary recording.
check_is_filtered : bool, default: True
If True, check that the temporary recording is filtered in the same way as the original recording.
check_dtype : bool, default: True
If True, check that the dtype of the temporary recording is the same as the original recording.
"""
# check that recording is compatible
attributes_match, exception_str = do_recording_attributes_match(
recording, self.rec_attributes, check_is_filtered=check_is_filtered, check_dtype=check_dtype
recording, self.rec_attributes, check_dtype=check_dtype
)
if not attributes_match:
raise ValueError(exception_str)
Expand Down
10 changes: 0 additions & 10 deletions src/spikeinterface/core/tests/test_recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,16 +316,6 @@ def test_do_recording_attributes_match():
assert not do_match
assert "sampling_frequency" in exc

# check is_filtered options
rec_attributes = get_rec_attributes(recording)
rec_attributes["is_filtered"] = not rec_attributes["is_filtered"]

do_match, exc = do_recording_attributes_match(recording, rec_attributes)
assert not do_match
assert "is_filtered" in exc
do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_is_filtered=False)
assert do_match

# check dtype options
rec_attributes = get_rec_attributes(recording)
rec_attributes["dtype"] = "int16"
Expand Down
9 changes: 0 additions & 9 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,6 @@ def test_SortingAnalyzer_tmp_recording(dataset):
with pytest.raises(ValueError):
sorting_analyzer.set_temporary_recording(recording_sliced)

# test with different is_filtered
recording_filt = recording.clone()
recording_filt.annotate(is_filtered=False)
with pytest.raises(ValueError):
sorting_analyzer.set_temporary_recording(recording_filt)

# thest with additional check_is_filtered
sorting_analyzer.set_temporary_recording(recording_filt, check_is_filtered=False)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

Expand Down

0 comments on commit 945fc15

Please sign in to comment.