Merge branch 'SpikeInterface:main' into fix-template-similarity-deprecation-warning
rkim48 authored Sep 2, 2024
2 parents 6f10148 + 161efc4 commit 2f87919
Showing 25 changed files with 469 additions and 210 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -7,7 +7,7 @@ authors = [
]
description = "Python toolkit for analysis, visualization, and comparison of spike sorting output"
readme = "README.md"
requires-python = ">=3.8,<4.0"
requires-python = ">=3.9,<4.0"
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: MIT License",
@@ -181,11 +181,12 @@ test = [

docs = [
"Sphinx",
"sphinx_rtd_theme",
"sphinx_rtd_theme>=1.2",
"sphinx-gallery",
"sphinx-design",
"numpydoc",
"ipython",
"sphinxcontrib-jquery",

# for notebooks in the gallery
"MEArec", # Use as an example
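The Python floor moves from 3.8 to 3.9, and the docs extra pins sphinx_rtd_theme>=1.2 and adds sphinxcontrib-jquery. A minimal, hypothetical guard that mirrors the new floor at import time (not part of this diff):

import sys

# Mirror of the requires-python = ">=3.9,<4.0" constraint above; purely illustrative.
if sys.version_info < (3, 9):
    raise RuntimeError(
        f"spikeinterface now requires Python >= 3.9, found {sys.version.split()[0]}"
    )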
10 changes: 5 additions & 5 deletions src/spikeinterface/comparison/paircomparisons.py
@@ -456,12 +456,12 @@ def print_summary(self, well_detected_score=None, redundant_score=None, overmerg
num_gt=len(self.unit1_ids),
num_tested=len(self.unit2_ids),
num_well_detected=self.count_well_detected_units(well_detected_score),
num_redundant=self.count_redundant_units(redundant_score),
num_overmerged=self.count_overmerged_units(overmerged_score),
)

if self.exhaustive_gt:
txt = txt + _template_summary_part2
d["num_redundant"] = self.count_redundant_units(redundant_score)
d["num_overmerged"] = self.count_overmerged_units(overmerged_score)
d["num_false_positive_units"] = self.count_false_positive_units()
d["num_bad"] = self.count_bad_units()

@@ -676,11 +676,11 @@ def count_units_categories(
GT num_units: {num_gt}
TESTED num_units: {num_tested}
num_well_detected: {num_well_detected}
num_redundant: {num_redundant}
num_overmerged: {num_overmerged}
"""

_template_summary_part2 = """num_false_positive_units {num_false_positive_units}
_template_summary_part2 = """num_redundant: {num_redundant}
num_overmerged: {num_overmerged}
num_false_positive_units {num_false_positive_units}
num_bad: {num_bad}
"""

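With this change, num_redundant and num_overmerged move from the base template into _template_summary_part2, so they are only reported when the comparison has an exhaustive ground truth. A short sketch of how print_summary assembles the text after the change (the counts below are illustrative, not taken from the diff):

_part1 = """GT num_units: {num_gt}
TESTED num_units: {num_tested}
num_well_detected: {num_well_detected}
"""
_part2 = """num_redundant: {num_redundant}
num_overmerged: {num_overmerged}
num_false_positive_units {num_false_positive_units}
num_bad: {num_bad}
"""

exhaustive_gt = True  # mirrors GroundTruthComparison.exhaustive_gt
d = dict(num_gt=30, num_tested=28, num_well_detected=25)

txt = _part1
if exhaustive_gt:
    # the four extra counts are only computed and reported in the exhaustive case
    txt = txt + _part2
    d.update(num_redundant=1, num_overmerged=0, num_false_positive_units=2, num_bad=3)

print(txt.format(**d))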
48 changes: 30 additions & 18 deletions src/spikeinterface/core/base.py
@@ -17,6 +17,7 @@
from .globals import get_global_tmp_folder, is_set_global_tmp_folder
from .core_tools import (
check_json,
clean_zarr_folder_name,
is_dict_extractor,
SIJsonEncoder,
make_paths_relative,
@@ -164,7 +165,7 @@ def id_to_index(self, id) -> int:
def annotate(self, **new_annotations) -> None:
self._annotations.update(new_annotations)

def set_annotation(self, annotation_key, value: Any, overwrite=False) -> None:
def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None:
"""This function adds an entry to the annotations dictionary.
Parameters
@@ -192,7 +193,7 @@ def get_preferred_mp_context(self):
"""
return self._preferred_mp_context

def get_annotation(self, key, copy: bool = True) -> Any:
def get_annotation(self, key: str, copy: bool = True) -> Any:
"""
Get an annotation.
Return a copy by default
@@ -205,7 +206,13 @@ def get_annotation(self, key, copy: bool = True) -> Any:
def get_annotation_keys(self) -> List:
return list(self._annotations.keys())

def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, missing_value: Any = None) -> None:
def set_property(
self,
key,
values: list | np.ndarray | tuple,
ids: list | np.ndarray | tuple | None = None,
missing_value: Any = None,
) -> None:
"""
Set property vector for main ids.
@@ -223,6 +230,7 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi
Array of values for the property
ids : list/np.array, default: None
List of subset of ids to set the values, default: None
If None (the default), values are set or changed for all the ids.
missing_value : object, default: None
In case the property is set on a subset of values ("ids" not None),
it specifies how the missing values should be filled.
@@ -240,23 +248,26 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi
dtype_kind = dtype.kind

if ids is None:
assert values.shape[0] == size
assert (
values.shape[0] == size
), f"Values must have the same size as the main ids: {size} but got array of size {values.shape[0]}"
self._properties[key] = values
else:
ids = np.array(ids)
assert np.unique(ids).size == ids.size, "'ids' are not unique!"
unique_ids = np.unique(ids)
if unique_ids.size != ids.size:
non_unique_ids = [id for id in ids if np.count_nonzero(ids == id) > 1]
raise ValueError(f"IDs are not unique: {non_unique_ids}")

# Not clear where this branch is used, perhaps on aggregation of extractors?
if ids.size < size:
if key not in self._properties:
# create the property with nan or empty
shape = (size,) + values.shape[1:]

if missing_value is None:
if dtype_kind not in self.default_missing_property_values.keys():
raise Exception(
"For values dtypes other than float, string, object or unicode, the missing value "
"cannot be automatically inferred. Please specify it with the 'missing_value' "
"argument."
raise ValueError(
f"Can't infer a natural missing value for dtype {dtype_kind}. "
"Please provide it with the `missing_value` argument"
)
else:
missing_value = self.default_missing_property_values[dtype_kind]
@@ -268,15 +279,18 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi
"as the values."
)

# create the property with nan or empty
shape = (size,) + values.shape[1:]
empty_values = np.zeros(shape, dtype=dtype)
empty_values[:] = missing_value
self._properties[key] = empty_values
if ids.size == 0:
return
else:
assert dtype_kind == self._properties[key].dtype.kind, (
"Mismatch between existing property dtype " "values dtype."
)
existing_property = self._properties[key].dtype
assert (
dtype_kind == existing_property.kind
), f"Mismatch between existing property dtype {existing_property.kind} and provided values dtype {dtype_kind}."

indices = self.ids_to_indices(ids)
self._properties[key][indices] = values
@@ -285,7 +299,7 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi
self._properties[key] = np.zeros_like(values, dtype=values.dtype)
self._properties[key][indices] = values

def get_property(self, key, ids: Optional[Iterable] = None) -> np.ndarray:
def get_property(self, key: str, ids: Optional[Iterable] = None) -> np.ndarray:
values = self._properties.get(key, None)
if ids is not None and values is not None:
inds = self.ids_to_indices(ids)
Expand Down Expand Up @@ -1048,9 +1062,7 @@ def save_to_zarr(
print(f"Use zarr_path={zarr_path}")
else:
if storage_options is None:
folder = Path(folder)
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
folder = clean_zarr_folder_name(folder)
if folder.is_dir() and overwrite:
shutil.rmtree(folder)
zarr_path = folder
Expand Down
10 changes: 7 additions & 3 deletions src/spikeinterface/core/baserecording.py
@@ -96,14 +96,18 @@ def list_to_string(lst, max_size=6):
def _repr_header(self):
num_segments = self.get_num_segments()
num_channels = self.get_num_channels()
sf_hz = self.get_sampling_frequency()
sf_khz = sf_hz / 1000
dtype = self.get_dtype()

total_samples = self.get_total_samples()
total_duration = self.get_total_duration()
total_memory_size = self.get_total_memory_size()
sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

sf_hz = self.get_sampling_frequency()
if not sf_hz.is_integer():
sampling_frequency_repr = f"{sf_hz:f} Hz"
else:
# kHz for high sampling rates and Hz for LFP
sampling_frequency_repr = f"{(sf_hz/1000.0):0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

txt = (
f"{self.name}: "
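The sampling-rate display above now keeps full Hz precision for non-integer rates, and otherwise uses kHz above 10 kHz and plain Hz below (typical for LFP streams). A small standalone sketch of that rule; the helper name is hypothetical, in the diff the logic lives inline in _repr_header:

def sampling_frequency_repr(sf_hz: float) -> str:
    # Non-integer rates are shown in full Hz so no precision is silently dropped.
    if not float(sf_hz).is_integer():
        return f"{sf_hz:f} Hz"
    # Integer rates: kHz for high sampling rates, plain Hz for low-rate (LFP) streams.
    return f"{(sf_hz / 1000.0):0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

print(sampling_frequency_repr(30_000.0))  # 30.0kHz
print(sampling_frequency_repr(2_500.0))   # 2500.0Hz
print(sampling_frequency_repr(29_999.5))  # 29999.500000 Hz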
24 changes: 24 additions & 0 deletions src/spikeinterface/core/core_tools.py
@@ -153,6 +153,13 @@ def check_json(dictionary: dict) -> dict:
return json.loads(json_string)


def clean_zarr_folder_name(folder):
folder = Path(folder)
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
return folder


def add_suffix(file_path, possible_suffix):
file_path = Path(file_path)
if isinstance(possible_suffix, str):
@@ -684,3 +691,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float:
memory = mem_info.total - mem_info.available

return memory


def is_path_remote(path: str | Path) -> bool:
"""
Returns True if the path is a remote path (e.g., s3:// or gcs://).
Parameters
----------
path : str or Path
The path to check.
Returns
-------
bool
Whether the path is a remote path.
"""
return "s3://" in str(path) or "gcs://" in str(path)