Skip to content

Commit

Permalink
Improve do_recording_attributes_match impelmentation, errors, and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Aug 20, 2024
1 parent df23822 commit bc1c704
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 20 deletions.
52 changes: 42 additions & 10 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Literal
from typing import Literal, Tuple
import warnings
from pathlib import Path
import os
Expand Down Expand Up @@ -929,7 +929,9 @@ def get_rec_attributes(recording):
return rec_attributes


def do_recording_attributes_match(recording1, recording2_attributes) -> bool:
def do_recording_attributes_match(
recording1: "BaseRecording", recording2_attributes: bool, check_is_filtered: bool = True, check_dtype: bool = True
) -> Tuple[bool, str]:
"""
Check if two recordings have the same attributes
Expand All @@ -939,22 +941,52 @@ def do_recording_attributes_match(recording1, recording2_attributes) -> bool:
The first recording object
recording2_attributes : dict
The recording attributes to test against
check_is_filtered : bool, default: True
If True, check if the recordings are filtered
check_dtype : bool, default: True
If True, check if the recordings have the same dtype
Returns
-------
bool
True if the recordings have the same attributes
str
A string with the an exception message with attributes that do not match
"""
recording1_attributes = get_rec_attributes(recording1)
recording2_attributes = deepcopy(recording2_attributes)
recording1_attributes.pop("properties")
recording2_attributes.pop("properties")

return (
np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"])
and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]
and recording1_attributes["num_channels"] == recording2_attributes["num_channels"]
and recording1_attributes["num_samples"] == recording2_attributes["num_samples"]
and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]
and recording1_attributes["dtype"] == recording2_attributes["dtype"]
)
attributes_match = True
non_matching_attrs = []

if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]):
attributes_match = False
non_matching_attrs.append("channel_ids")
if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]:
attributes_match = False
non_matching_attrs.append("sampling_frequency")
if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]:
attributes_match = False
non_matching_attrs.append("num_channels")
if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]:
attributes_match = False
non_matching_attrs.append("num_samples")
if check_is_filtered:
if not recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]:
attributes_match = False
non_matching_attrs.append("is_filtered")
# dtype is optional
if "dtype" in recording1_attributes and "dtype" in recording2_attributes:
if check_dtype:
if not recording1_attributes["dtype"] == recording2_attributes["dtype"]:
attributes_match = False
non_matching_attrs.append("dtype")

if len(non_matching_attrs) > 0:
exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}"
else:
exception_str = ""

return attributes_match, exception_str
19 changes: 14 additions & 5 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,9 @@ def load_from_zarr(cls, folder, recording=None):

return sorting_analyzer

def set_temporary_recording(self, recording: BaseRecording):
def set_temporary_recording(
self, recording: BaseRecording, check_is_filtered: bool = True, check_dtype: bool = True
):
"""
Sets a temporary recording object. This function can be useful to temporarily set
a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
Expand All @@ -620,12 +622,19 @@ def set_temporary_recording(self, recording: BaseRecording):
----------
recording : BaseRecording
The recording object to set as temporary recording.
check_is_filtered : bool, default: True
If True, check that the temporary recording is filtered in the same way as the original recording.
check_dtype : bool, default: True
If True, check that the dtype of the temporary recording is the same as the original recording.
"""
# check that recording is compatible
assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match."
assert np.array_equal(
recording.get_channel_locations(), self.get_channel_locations()
), "Recording channel locations do not match."
attributes_match, exception_str = do_recording_attributes_match(
recording, self.rec_attributes, check_is_filtered=check_is_filtered, check_dtype=check_dtype
)
if not attributes_match:
raise ValueError(exception_str)
if not np.array_equal(recording.get_channel_locations(), self.get_channel_locations()):
raise ValueError("Recording channel locations do not match.")
if self._recording is not None:
warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
self._temporary_recording = recording
Expand Down
41 changes: 41 additions & 0 deletions src/spikeinterface/core/tests/test_recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
get_channel_distances,
get_noise_levels,
order_channels_by_depth,
do_recording_attributes_match,
get_rec_attributes,
)


Expand Down Expand Up @@ -300,6 +302,45 @@ def test_order_channels_by_depth():
assert np.array_equal(order_2d[::-1], order_2d_fliped)


def test_do_recording_attributes_match():
recording = NoiseGeneratorRecording(
num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
)
rec_attributes = get_rec_attributes(recording)
do_match, _ = do_recording_attributes_match(recording, rec_attributes)
assert do_match

rec_attributes = get_rec_attributes(recording)
rec_attributes["sampling_frequency"] = 1.0
do_match, exc = do_recording_attributes_match(recording, rec_attributes)
assert not do_match
assert "sampling_frequency" in exc

# check is_filtered options
rec_attributes = get_rec_attributes(recording)
rec_attributes["is_filtered"] = not rec_attributes["is_filtered"]

do_match, exc = do_recording_attributes_match(recording, rec_attributes)
assert not do_match
assert "is_filtered" in exc
do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_is_filtered=False)
assert do_match

# check dtype options
rec_attributes = get_rec_attributes(recording)
rec_attributes["dtype"] = "int16"
do_match, exc = do_recording_attributes_match(recording, rec_attributes)
assert not do_match
assert "dtype" in exc
do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_dtype=False)
assert do_match

# check missing dtype
rec_attributes.pop("dtype")
do_match, exc = do_recording_attributes_match(recording, rec_attributes)
assert do_match


if __name__ == "__main__":
# Create a temporary folder using the standard library
import tempfile
Expand Down
11 changes: 10 additions & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,18 @@ def test_SortingAnalyzer_tmp_recording(dataset):
recording_sliced = recording.channel_slice(recording.channel_ids[:-1])

# wrong channels
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
sorting_analyzer.set_temporary_recording(recording_sliced)

# test with different is_filtered
recording_filt = recording.clone()
recording_filt.annotate(is_filtered=False)
with pytest.raises(ValueError):
sorting_analyzer.set_temporary_recording(recording_filt)

# thest with additional check_is_filtered
sorting_analyzer.set_temporary_recording(recording_filt, check_is_filtered=False)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
else:
rec_attributes["probegroup"] = None

if "dtype" not in rec_attributes:
warnings.warn("dtype not found in rec_attributes. Setting to float32")
rec_attributes["dtype"] = "float32"

# recording
recording = None
if (folder / "recording.json").exists():
Expand Down

0 comments on commit bc1c704

Please sign in to comment.