
Fix dtype of quality metrics before and after merging #3497

Merged: 23 commits, Jan 10, 2025
Changes from 18 commits
27 changes: 25 additions & 2 deletions src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -16,7 +16,8 @@
compute_pc_metrics,
_misc_metric_name_to_func,
_possible_pc_metric_names,
qm_compute_name_to_column_names,
compute_name_to_column_names,
column_name_to_column_dtype,
)
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params
@@ -140,13 +141,20 @@ def _merge_extension_data(
all_unit_ids = new_sorting_analyzer.unit_ids
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

# this creates a new metrics dataframe, but the dtype of every column will be
# object, so we will need to fix this after computing the metrics
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(
new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
)

# we need to fix the dtypes after we compute everything, because we have NaNs;
# we can iterate through the columns and convert them back to the dtype
# of the original quality-metrics dataframe.
for column in old_metrics.columns:
metrics[column] = metrics[column].astype(old_metrics[column].dtype)

new_data = dict(metrics=metrics)
return new_data
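
The comment above points at a pandas pitfall worth seeing in isolation: a DataFrame built from only an index and column labels starts with every column as object, and row-wise .loc assignments do not restore the original dtypes. A minimal standalone sketch of the problem and of the column-wise cast used in _merge_extension_data (the unit ids and metric values are made up for illustration):

import pandas as pd

# stand-in for the metrics computed before the merge (made-up values)
old_metrics = pd.DataFrame(
    {"num_spikes": [120, 80], "firing_rate": [1.2, 0.8]},
    index=["u0", "u1"],
)

# a DataFrame created from index/columns alone is dtype object everywhere
all_unit_ids = ["u1", "merged"]
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
print(metrics.dtypes.tolist())  # [dtype('O'), dtype('O')]

# copying rows over does not bring the original dtypes back
metrics.loc[["u1"], :] = old_metrics.loc[["u1"], :]
metrics.loc[["merged"], :] = [200, 2.0]
print(metrics.dtypes.tolist())  # still object

# the fix applied above: cast each column back to its original dtype
for column in old_metrics.columns:
    metrics[column] = metrics[column].astype(old_metrics[column].dtype)
print(metrics.dtypes.tolist())  # [dtype('int64'), dtype('float64')]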

@@ -229,10 +237,25 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
# add NaN for empty units
if len(empty_unit_ids) > 0:
metrics.loc[empty_unit_ids] = np.nan
# num_spikes is an int and should be 0
if "num_spikes" in metrics.columns:
metrics.loc[empty_unit_ids, ["num_spikes"]] = 0

# we use convert_dtypes to convert the columns to the most appropriate dtype,
# avoiding object columns (in case of NaN values)
metrics = metrics.convert_dtypes()

# we do this because convert_dtypes sometimes infers the wrong types;
# the actual types for the columns can be found in the column_name_to_column_dtype dictionary.
for column in metrics.columns:
# the synchrony columns are named based on what the user has input as arguments,
# so we need a way to handle them separately;
# everything else can be handled by the column name.
if "sync" in column:
metrics[column] = metrics[column].astype(column_name_to_column_dtype["sync"])
Review comment (Collaborator, PR author): I would argue we keep this for backward compatibility no? I could add a comment saying we can simplify this in a couple versions.

else:
metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])

return metrics

def _run(self, verbose=False, **job_kwargs):
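
The convert_dtypes caveat in the comments above is easy to reproduce: pandas promotes a float column whose values happen to be whole numbers to a nullable integer dtype. A small sketch (with invented values) of why the explicit cast through column_name_to_column_dtype is still needed afterwards:

import numpy as np
import pandas as pd

# amplitude_cutoff is conceptually a float metric, but these particular
# values are whole numbers, plus a NaN for an empty unit
metrics = pd.DataFrame({"amplitude_cutoff": [1.0, 2.0, np.nan]})

converted = metrics.convert_dtypes()
print(converted["amplitude_cutoff"].dtype)  # Int64 -- wrongly inferred as integer

# forcing the dtype recorded in column_name_to_column_dtype fixes it
converted["amplitude_cutoff"] = converted["amplitude_cutoff"].astype(float)
print(converted["amplitude_cutoff"].dtype)  # float64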
39 changes: 38 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -66,7 +66,11 @@
"amplitude_cutoff": ["amplitude_cutoff"],
"amplitude_median": ["amplitude_median"],
"amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
"synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
"synchrony": [
"sync_spike_2",
"sync_spike_4",
"sync_spike_8",
], # we probably shouldn't hard code this. This is determined by the arguments in the function...

Review comment (Member): but I think we agreed that we can just hard-code this at the QM level, so it should be ok. Let's keep the comment until this is actually hard-coded!
Review comment (Member): Actually, this was done already: #3559

"firing_range": ["firing_range"],
"drift": ["drift_ptp", "drift_std", "drift_mad"],
"sd_ratio": ["sd_ratio"],
@@ -79,3 +83,36 @@
"silhouette": ["silhouette"],
"silhouette_full": ["silhouette_full"],
}

# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them
column_name_to_column_dtype = {
"num_spikes": int,
"firing_rate": float,
"presence_ratio": float,
"snr": float,
"isi_violations_ratio": float,
"isi_violations_count": float,
"rp_violations": float,
"rp_contamination": float,
"sliding_rp_violation": float,
"amplitude_cutoff": float,
"amplitude_median": float,
"amplitude_cv_median": float,
"amplitude_cv_range": float,
"sync": float,

Review comment (Member): then this could become sync_spike_2, sync_spike_4, sync_spike_8 as well

"firing_range": float,
"drift_ptp": float,
"drift_std": float,
"drift_mad": float,
"sd_ratio": float,
"isolation_distance": float,
"l_ratio": float,
"d_prime": float,
"nn_hit_rate": float,
"nn_miss_rate": float,
"nn_isolation": float,
"nn_unit_id": float,
"nn_noise_overlap": float,
"silhouette": float,
"silhouette_full": float,
}
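
Because the synchrony column names are generated from user-supplied sizes (sync_spike_2, sync_spike_4, ...), they cannot be enumerated exhaustively in this mapping, hence the single "sync" key and the substring check in _compute_metrics. A sketch of that lookup as a standalone helper (resolve_column_dtype is a hypothetical name for illustration, not part of this PR):

# hypothetical helper mirroring the lookup logic in _compute_metrics
def resolve_column_dtype(column_name, name_to_dtype):
    # synchrony columns are named from user arguments (sync_spike_<size>),
    # so they all share the single "sync" entry of the mapping
    if "sync" in column_name:
        return name_to_dtype["sync"]
    return name_to_dtype[column_name]

name_to_dtype = {"num_spikes": int, "sync": float}  # trimmed-down mapping
print(resolve_column_dtype("sync_spike_4", name_to_dtype))  # <class 'float'>
print(resolve_column_dtype("num_spikes", name_to_dtype))    # <class 'int'>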
@@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
assert "isolation_distance" in metrics.columns


def test_merging_quality_metrics(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple

metrics = compute_quality_metrics(
sorting_analyzer,
metric_names=None,
qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
skip_pc_metrics=False,
seed=2205,
)

# sorting_analyzer_simple has ten units
new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]])

new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data()

# we should copy over the metrics after merge
for column in metrics.columns:
assert column in new_metrics.columns
# should copy dtype too
assert metrics[column].dtype == new_metrics[column].dtype

# 10 units vs 9 units
assert len(metrics.index) > len(new_metrics.index)


def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple
@@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple):
seed=2205,
)

for empty_unit_id in sorting_empty.get_empty_unit_ids():
# num_spikes is an int, not NaN, so we confirm that empty units are NaN for
# everything except num_spikes, which should be 0
nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"]
for empty_unit_ids in sorting_empty.get_empty_unit_ids():
from pandas import isnull

assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))
assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values))
if "num_spikes" in metrics_empty.columns:
assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0


# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()
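
The special-casing of num_spikes in this test (and in _compute_metrics above) comes down to a NumPy/pandas constraint: NaN is a float, so an integer column cannot hold it. A quick illustration, with made-up spike counts, of why empty units get 0 spikes rather than NaN:

import numpy as np
import pandas as pd

metrics = pd.DataFrame({"num_spikes": [120, 80, 5]}, index=["u0", "u1", "empty"])
print(metrics["num_spikes"].dtype)  # int64

# writing NaN into an int64 column upcasts it to float64
# (newer pandas versions warn about, or disallow, this silent upcast)
with_nan = metrics.copy()
with_nan.loc["empty", "num_spikes"] = np.nan
print(with_nan["num_spikes"].dtype)  # float64

# writing 0 instead keeps num_spikes an integer column, as asserted above
metrics.loc["empty", "num_spikes"] = 0
print(metrics["num_spikes"].dtype)  # int64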