From 0aeb12a3557ed4bf255f76777fd03108a59c113c Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 14 Aug 2024 10:53:46 +0200
Subject: [PATCH 01/11] Add levels to GTStudy widgets

---
 .../comparison/groundtruthstudy.py    | 261 ++++++++++++++++--
 src/spikeinterface/widgets/gtstudy.py | 117 +++++---
 2 files changed, 316 insertions(+), 62 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index ba7268b4f0..bd608b62ce 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -57,6 +57,29 @@ def __init__(self, study_folder):

     @classmethod
     def create(cls, study_folder, datasets={}, cases={}, levels=None):
+        """
+        Create a GroundTruthStudy object from a dictionary of datasets and cases.
+
+        Parameters
+        ----------
+        study_folder : str
+            The folder where the GroundTruthStudy will be saved.
+        datasets : dict
+            A dictionary with the dataset keys and values as (recording, gt_sorting) tuples.
+        cases : dict
+            A dictionary with the case keys and values as dictionaries with the following keys:
+            * dataset: the key of the dataset
+            * label: a label for the case
+            * run_sorter_params: the parameters to run the sorter
+            If the keys are tuples of strings, they can represent several levels of parameters.
+        levels : str or list or None, default: None
+            The levels of the case keys. If None, the levels are named "level0", "level1", etc.
+
+        Returns
+        -------
+        study : GroundTruthStudy
+            The GroundTruthStudy object.
+        """
         # check that cases keys are homogeneous
         key0 = list(cases.keys())[0]
         if isinstance(key0, str):
@@ -112,6 +135,9 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
         return cls(study_folder)

     def scan_folder(self):
+        """
+        Scan the folder to load or reload the datasets, cases, sortings, and comparisons.
+        """
         if not (self.folder / "datasets").exists():
             raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}")

@@ -154,6 +180,9 @@ def __repr__(self):
         return t

     def key_to_str(self, key):
+        """
+        Convert a case key to a string.
+        """
         if isinstance(key, str):
             return key
         elif isinstance(key, tuple):
@@ -162,6 +191,14 @@ def key_to_str(self, key):
             raise ValueError("Keys for cases must str or tuple")

     def remove_sorting(self, key):
+        """
+        Remove the sorting for a given case key.
+
+        Parameters
+        ----------
+        key : str or tuple
+            The case key to remove the sorting for.
+        """
         sorting_folder = self.folder / "sortings" / self.key_to_str(key)
         log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
         comparison_file = self.folder / "comparisons" / self.key_to_str(key)
@@ -174,6 +211,17 @@ def remove_sorting(self, key):
             f.unlink()

     def set_colors(self, colors=None, map_name="tab20"):
+        """
+        Set the colors for the cases. The self.colors will be a dictionary with
+        the case keys as keys and the colors as values.
+
+        Parameters
+        ----------
+        colors : list or None
+            The colors to use. If None, the colors are automatically generated.
+        map_name : str
+            The name of the color map to use.
+        """
         from spikeinterface.widgets import get_some_colors

         if colors is None:
@@ -184,15 +232,37 @@ def set_colors(self, colors=None, map_name="tab20"):
         else:
             self.colors = colors

-    def get_colors(self):
+    def get_colors(self, map_name="tab20"):
+        """
+        Return the colors for the cases. If the colors are not set, they are automatically generated.
+        """
         if self.colors is None:
-            self.set_colors()
+            self.set_colors(map_name=map_name)
         return self.colors

-    def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
+    def run_sorters(self, case_keys=None, engine="loop", engine_kwargs=None, keep=True, verbose=False):
+        """
+        Runs the sorters for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to run the sorters for. If None, all cases are run.
+        engine : "loop" | "slurm", default: "loop"
+            The engine to use. Can be "loop" or "slurm".
+        engine_kwargs : dict or None, default: None
+            The kwargs to pass to the engine.
+        keep : bool, default: True
+            If True, the sorting is kept if it already exists.
+        verbose : bool, default: False
+            If True, print more information.
+        """
         if case_keys is None:
             case_keys = self.cases.keys()

+        if engine_kwargs is None:
+            engine_kwargs = {}
+
         job_list = []
         for key in case_keys:
             sorting_folder = self.folder / "sortings" / self.key_to_str(key)
@@ -239,6 +309,17 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
         self.copy_sortings(case_keys)

     def copy_sortings(self, case_keys=None, force=True):
+        """
+        Copy the sortings from the sorter-specific folder to the sortings folder
+        using the numpy_folder format.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to copy the sortings for. If None, all cases are copied.
+        force : bool, default: True
+            If True, the sorting is copied even if it already exists.
+        """
         if case_keys is None:
             case_keys = self.cases.keys()

@@ -268,6 +349,18 @@ def copy_sortings(self, case_keys=None, force=True):
                     shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file)

     def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs):
+        """
+        Run the comparisons with the ground-truth sorting for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to run the comparisons for. If None, all cases are run.
+        comparison_class : class, default: GroundTruthComparison
+            The class to use for the comparison.
+        kwargs : dict
+            The kwargs to pass to the comparison class.
+        """
         if case_keys is None:
             case_keys = self.cases.keys()

@@ -286,6 +379,19 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
                 pickle.dump(comp, f)

     def get_run_times(self, case_keys=None):
+        """
+        Return the run times for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the run times for. If None, all cases are returned.
+
+        Returns
+        -------
+        run_times : dict
+            A dictionary with the case keys as keys and the run times as values.
+        """
         import pandas as pd

         if case_keys is None:
@@ -303,7 +409,21 @@ def get_run_times(self, case_keys=None):

         return pd.Series(run_times, name="run_time")

-    def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms_params={}, **job_kwargs):
+    def create_sorting_analyzer_gt(self, case_keys=None, random_params=None, template_params=None, **job_kwargs):
+        """
+        Create the sorting analyzer for the ground-truth sorting for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to create the sorting analyzer for. If None, all cases are created.
+        random_params : dict or None
+            The parameters to pass to the `random_spikes` computation.
+        template_params : dict or None
+            The parameters to pass to the `templates` computation.
+        job_kwargs : keyword arguments
+            The kwargs to pass to the sorting analyzer.
+        """
         if case_keys is None:
             case_keys = self.cases.keys()

@@ -312,16 +432,33 @@ def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms
         dataset_keys = [self.cases[key]["dataset"] for key in case_keys]
         dataset_keys = set(dataset_keys)

+        random_params = random_params if random_params is not None else {}
+        template_params = template_params if template_params is not None else {}
         for dataset_key in dataset_keys:
             # the waveforms depend on the dataset key
             folder = base_folder / self.key_to_str(dataset_key)
             recording, gt_sorting = self.datasets[dataset_key]
             sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder)
             sorting_analyzer.compute("random_spikes", **random_params)
-            sorting_analyzer.compute("templates", **job_kwargs)
+            sorting_analyzer.compute("templates", **template_params, **job_kwargs)
             sorting_analyzer.compute("noise_levels")

     def get_sorting_analyzer(self, case_key=None, dataset_key=None):
+        """
+        Get the ground-truth sorting analyzer for the given case key.
+
+        Parameters
+        ----------
+        case_key : str or tuple
+            The case key to get the sorting analyzer for.
+        dataset_key : str
+            The dataset key to get the sorting analyzer for.
+
+        Returns
+        -------
+        sorting_analyzer : SortingAnalyzer
+            The sorting analyzer.
+        """
         if case_key is not None:
             dataset_key = self.cases[case_key]["dataset"]

@@ -329,12 +466,21 @@ def get_sorting_analyzer(self, case_key=None, dataset_key=None):
         sorting_analyzer = load_sorting_analyzer(folder)
         return sorting_analyzer

-    # def get_templates(self, key, mode="average"):
-    #     analyzer = self.get_sorting_analyzer(case_key=key)
-    #     templates = sorting_analyzer.get_all_templates(mode=mode)
-    #     return templates
-
-    def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False):
+    def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False, **kwargs):
+        """
+        Compute quality metrics for ground-truth units for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to compute the metrics for. If None, all cases are computed.
+        metric_names : list, default: ["snr", "firing_rate"]
+            The quality metrics to compute.
+        force : bool, default: False
+            If True, the metrics are recomputed even if they already exist.
+        kwargs : keyword arguments
+            The kwargs to pass to the quality metrics computation.
+        """
         if case_keys is None:
             case_keys = self.cases.keys()

@@ -352,28 +498,76 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f
             else:
                 continue
             analyzer = self.get_sorting_analyzer(key)
-            metrics = compute_quality_metrics(analyzer, metric_names=metric_names)
+            metrics = compute_quality_metrics(analyzer, metric_names=metric_names, **kwargs)
             metrics.to_csv(filename, sep="\t", index=True)

-    def get_metrics(self, key):
+    def get_metrics(self, case_keys=None):
+        """
+        Return the metrics for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the metrics for. If None, all cases are returned.
+
+        Returns
+        -------
+        metrics : pandas.DataFrame
+            The metrics for each case. The dataframe has a MultiIndex with
+            the case keys and a "gt_unit_id" column with the ground-truth unit ids.
+        """
         import pandas as pd

-        dataset_key = self.cases[key]["dataset"]
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        metrics = None
+        for key in case_keys:
+            dataset_key = self.cases[key]["dataset"]

-        filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv"
-        if not filename.exists():
-            return
-        metrics = pd.read_csv(filename, sep="\t", index_col=0)
-        dataset_key = self.cases[key]["dataset"]
-        recording, gt_sorting = self.datasets[dataset_key]
-        metrics.index = gt_sorting.unit_ids
+            filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv"
+            if not filename.exists():
+                continue
+            new_metrics = pd.read_csv(filename, sep="\t", index_col=0)
+            _, gt_sorting = self.datasets[dataset_key]
+            new_metrics.loc[:, "gt_unit_id"] = gt_sorting.unit_ids
+            new_metrics.index = pd.MultiIndex.from_tuples([key] * len(new_metrics), names=self.levels)
+            if metrics is None:
+                metrics = new_metrics
+            else:
+                metrics = pd.concat([metrics, new_metrics])
         return metrics

-    def get_units_snr(self, key):
-        """ """
-        return self.get_metrics(key)["snr"]
+    def get_units_snr(self, case_keys=None):
+        """
+        Return the snr for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the snr for. If None, all cases are returned.
+
+        Returns
+        -------
+        snr : pandas.Series
+            The snr for each case.
+        """
+        return self.get_metrics(case_keys)["snr"]

     def get_performance_by_unit(self, case_keys=None):
+        """
+        Return a dataframe with the performance of the sorter for each unit for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the performance for. If None, all cases are returned.
+
+        Returns
+        -------
+        perf_by_unit : pandas.DataFrame
+            The performance for each unit and each case.
+        """
         import pandas as pd

         if case_keys is None:
@@ -401,6 +595,25 @@ def get_performance_by_unit(self, case_keys=None):
         return perf_by_unit

     def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
+        """
+        Return a dataframe with the count of units for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the count for. If None, all cases are returned.
+        well_detected_score : float | None, default: None
+            The score to consider a unit as well detected.
+        redundant_score : float | None, default: None
+            The score to consider a unit as redundant.
+        overmerged_score : float | None, default: None
+            The score to consider a unit as overmerged.
+
+        Returns
+        -------
+        count_units : pandas.DataFrame
+            The count of units for each case.
+        """
         import pandas as pd

         if case_keys is None:
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index 85043d0d12..e286a32c02 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -14,15 +14,18 @@ class StudyRunTimesWidget(BaseWidget):
     ----------
     study : GroundTruthStudy
         A study object.
-    case_keys : list or None
-        A selection of cases to plot, if None, then all.
-
+    case_keys : list or None, default: None
+        A selection of cases to plot, if None, then all cases are plotted.
+    levels : str or list-like or None, default: None
+        A selection of levels to group cases by, if None, then all
+        cases are treated as separate.
     """

     def __init__(
         self,
         study,
         case_keys=None,
+        levels=None,
         backend=None,
         **backend_kwargs,
     ):
@@ -30,7 +33,11 @@ def __init__(
             case_keys = list(study.cases.keys())

         plot_data = dict(
-            study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors()
+            study=study,
+            run_times=study.get_run_times(case_keys),
+            case_keys=case_keys,
+            levels=levels,
+            colors=study.get_colors(),
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -51,7 +58,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         self.ax.legend()


-# TODO : plot optionally average on some levels using group by
 class StudyUnitCountsWidget(BaseWidget):
     """
     Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"
@@ -61,52 +67,66 @@ class StudyUnitCountsWidget(BaseWidget):
     ----------
     study : GroundTruthStudy
         A study object.
-    case_keys : list or None
-        A selection of cases to plot, if None, then all.
-
+    case_keys : list or None, default: None
+        A selection of cases to plot, if None, then all cases are plotted.
+    levels : str or list-like or None, default: None
+        A selection of levels to group cases by, if None, then all
+        cases are treated as separate.
     """

     def __init__(
         self,
         study,
         case_keys=None,
+        levels=None,
         backend=None,
         **backend_kwargs,
     ):
-        if case_keys is None:
-            case_keys = list(study.cases.keys())
-
         plot_data = dict(
-            study=study,
-            count_units=study.get_count_units(case_keys=case_keys),
-            case_keys=case_keys,
+            study=study, count_units=study.get_count_units(case_keys=case_keys), case_keys=case_keys, levels=levels
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

     def plot_matplotlib(self, data_plot, **backend_kwargs):
-        import matplotlib.pyplot as plt
         from .utils_matplotlib import make_mpl_figure
         from .utils import get_some_colors

+        import pandas as pd
+
         dp = to_attr(data_plot)
+        count_units = dp.count_units
+        study = dp.study

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        columns = dp.count_units.columns.tolist()
+        if dp.case_keys is None:
+            case_keys = list(study.cases.keys())
+            labels = {key: study.cases[key]["label"] for key in case_keys}
+
+        columns = count_units.columns.tolist()
         columns.remove("num_gt")
         columns.remove("num_sorter")

+        if dp.levels is not None:
+            drop_levels = [l for l in study.levels if l not in dp.levels]
+            count_units = count_units.droplevel(drop_levels).sort_index()
+            case_keys = list(np.unique(count_units.index))
+            if isinstance(count_units.index, pd.MultiIndex):
+                labels = {key: "-".join(key) for key in case_keys}
+            else:
+                labels = {key: key for key in case_keys}
+
         ncol = len(columns)

         colors = get_some_colors(columns, color_engine="auto", map_name="hot")
         colors["num_well_detected"] = "green"

         xticklabels = []
-        for i, key in enumerate(dp.case_keys):
+        for i, key in enumerate(case_keys):
             for c, col in enumerate(columns):
                 x = i + 1 + c / (ncol + 1)
-                y = dp.count_units.loc[key, col]
+                y = count_units.loc[key, col]
                 if not "well_detected" in col:
                     y = -y

@@ -116,8 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                     label = None

                 self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col])
-
-            xticklabels.append(dp.study.cases[key]["label"])
+            xticklabels.append(labels[key])

         self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1)
         self.ax.set_xticklabels(xticklabels)
@@ -141,8 +160,11 @@ class StudyPerformances(BaseWidget):
       * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details)
     performance_names : list or tuple, default: ("accuracy", "precision", "recall")
         Which performances to plot ("accuracy", "precision", "recall")
-    case_keys : list or None
-        A selection of cases to plot, if None, then all.
+    case_keys : list or None, default: None
+        A selection of cases to plot, if None, then all cases are plotted.
+    levels : str or list-like or None, default: None
+        A selection of levels to group cases by, if None, then all
+        cases are treated as separate.
     """

     def __init__(
         self,
         study,
         mode="ordered",
         performance_names=("accuracy", "precision", "recall"),
         case_keys=None,
+        levels=None,
+        cmap="tab20",
         backend=None,
         **backend_kwargs,
     ):
-        if case_keys is None:
-            case_keys = list(study.cases.keys())
+        perfs_by_unit = study.get_performance_by_unit(case_keys=case_keys)
+        if mode == "snr":
+            metrics = study.get_metrics(case_keys=case_keys)
+            perfs_by_unit = perfs_by_unit.merge(metrics, on=list(study.levels) + ["gt_unit_id"])

         plot_data = dict(
             study=study,
-            perfs=study.get_performance_by_unit(case_keys=case_keys),
+            perfs=perfs_by_unit,
             mode=mode,
             performance_names=performance_names,
+            levels=levels,
             case_keys=case_keys,
+            cmap=cmap,
         )

-        self.colors = study.get_colors()
-
         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

     def plot_matplotlib(self, data_plot, **backend_kwargs):
@@ -181,6 +207,21 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         perfs = dp.perfs
         study = dp.study

+        if dp.case_keys is None:
+            case_keys = list(study.cases.keys())
+            labels = {key: study.cases[key]["label"] for key in case_keys}
+
+        if dp.levels is not None:
+            drop_levels = [l for l in study.levels if l not in dp.levels]
+            perfs = perfs.droplevel(drop_levels).sort_index()
+            case_keys = list(np.unique(perfs.index))
+            if isinstance(perfs.index, pd.MultiIndex):
+                labels = {key: "-".join(key) for key in case_keys}
+            else:
+                labels = {key: key for key in case_keys}
+
+        colors = get_some_colors(case_keys, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)
+
         if dp.mode in ("ordered", "snr"):
             backend_kwargs["num_axes"] = len(dp.performance_names)
             self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
@@ -188,11 +229,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         if dp.mode == "ordered":
             for count, performance_name in enumerate(dp.performance_names):
                 ax = self.axes.flatten()[count]
-                for key in dp.case_keys:
-                    label = study.cases[key]["label"]
+                for key in case_keys:
+                    label = labels[key]
                     val = perfs.xs(key).loc[:, performance_name].values
                     val = np.sort(val)[::-1]
-                    ax.plot(val, label=label, c=self.colors[key])
+                    ax.plot(val, label=label, c=colors[key])
                 ax.set_title(performance_name)
                 if count == len(dp.performance_names) - 1:
                     ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8)
@@ -203,11 +244,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 ax = self.axes.flatten()[count]

                 max_metric = 0
-                for key in dp.case_keys:
-                    x = study.get_metrics(key).loc[:, metric_name].values
+                for key in case_keys:
+                    x = perfs.xs(key).loc[:, metric_name].values
                     y = perfs.xs(key).loc[:, performance_name].values
-                    label = study.cases[key]["label"]
-                    ax.scatter(x, y, s=10, label=label, color=self.colors[key])
+                    label = labels[key]
+                    ax.scatter(x, y, s=10, label=label, color=colors[key], alpha=0.5)
                     max_metric = max(max_metric, np.max(x))
                 ax.set_title(performance_name)
                 ax.set_xlim(0, max_metric * 1.05)
@@ -216,7 +257,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                     ax.legend(loc="lower right")

         elif dp.mode == "swarm":
-            levels = perfs.index.names
+            levels = perfs.index.names if dp.levels is None else dp.levels
             df = pd.melt(
                 perfs.reset_index(),
                 id_vars=levels,
@@ -236,8 +277,8 @@ class StudyAgreementMatrix(BaseWidget):
    ----------
    study : GroundTruthStudy
        A study object.
-    case_keys : list or None
-        A selection of cases to plot, if None, then all.
+    case_keys : list or None, default: None
+        A selection of cases to plot, if None, then all cases are plotted.
    ordered : bool
        Order units with best agreement scores.
        This enable to see agreement on a diagonal.
@@ -310,7 +351,7 @@ class StudySummary(BaseWidget):
     study : GroundTruthStudy
         A study object.
     case_keys : list or None, default: None
-        A selection of cases to plot, if None, then all.
+        A selection of cases to plot, if None, then all cases are plotted.
     """

     def __init__(

From 529e1b3052bc3c7f8b9471f6e17f04b08a244f9f Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 19 Aug 2024 18:07:45 +0200
Subject: [PATCH 02/11] Improve pickling of comparison objects in GT studies

---
 .../comparison/groundtruthstudy.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index bd608b62ce..b9c0836580 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -158,6 +158,7 @@ def scan_folder(self):
         self.sortings = {k: None for k in self.cases}
         self.comparisons = {k: None for k in self.cases}
         for key in self.cases:
+            gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
             sorting_folder = self.folder / "sortings" / self.key_to_str(key)
             if sorting_folder.exists():
                 self.sortings[key] = load_extractor(sorting_folder)
@@ -167,6 +168,9 @@ def scan_folder(self):
                 with open(comparison_file, mode="rb") as f:
                     try:
                         self.comparisons[key] = pickle.load(f)
+                        # since we avoided pickling the absolute sorting paths, we need to set them here
+                        self.comparisons[key].sorting1 = gt_sorting
+                        self.comparisons[key].sorting2 = self.sortings[key]
                     except Exception:
                         pass

@@ -375,8 +379,19 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
             self.comparisons[key] = comp

             comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
-            with open(comparison_file, mode="wb") as f:
-                pickle.dump(comp, f)
+            # Since dumping to pickle hard-codes the sorting paths, here we temporarily set the sorting paths to None
+            # so that the comparison object can be pickled
+            # Upon reloading, we will set the sorting paths back to the correct values
+            comp.sorting1 = None
+            comp.sorting2 = None
+            # we also need a try-except block in case the folder is read-only
+            try:
+                with open(comparison_file, mode="wb") as f:
+                    pickle.dump(comp, f)
+            except:
+                pass
+            comp.sorting1 = gt_sorting
+            comp.sorting2 = sorting
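
The save/restore pattern in the patch above can be sketched in isolation. The comparison object keeps references to two sorting extractors, and pickling those bakes absolute paths into the file; nulling the references before dumping keeps the pickle small and path-independent, at the cost of re-attaching them after loading. A minimal stand-alone sketch (the `Comparison` class here is an illustrative stand-in, not the SpikeInterface one):

    import pickle

    class Comparison:
        """Stand-in for a comparison object holding references to two sortings."""

        def __init__(self, sorting1, sorting2, scores):
            self.sorting1 = sorting1  # heavy object that would serialize absolute paths
            self.sorting2 = sorting2
            self.scores = scores      # the lightweight result we actually want on disk

    def save_comparison(comp, path):
        # temporarily drop the sorting references so only the scores are pickled
        sorting1, sorting2 = comp.sorting1, comp.sorting2
        comp.sorting1 = comp.sorting2 = None
        try:
            with open(path, "wb") as f:
                pickle.dump(comp, f)
        finally:
            # always restore the in-memory object, even if the folder is read-only
            comp.sorting1, comp.sorting2 = sorting1, sorting2

    def load_comparison(path, sorting1, sorting2):
        with open(path, "rb") as f:
            comp = pickle.load(f)
        # re-attach the references that were deliberately not pickled
        comp.sorting1, comp.sorting2 = sorting1, sorting2
        return comp

A `try/finally` guarantees the in-memory object is usable after a failed write, which is the same reason the patch re-assigns the sortings after its bare `except` block.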
From 95dba5e04d22eec70e9bbdb755a261c6fadb7ba7 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 20 Aug 2024 11:52:28 +0200
Subject: [PATCH 03/11] fix tests

---
 src/spikeinterface/comparison/groundtruthstudy.py  | 12 ++++++------
 .../comparison/tests/test_groundtruthstudy.py      | 14 ++++++++++----
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index b9c0836580..16975cda30 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -169,8 +169,8 @@ def scan_folder(self):
                     try:
                         self.comparisons[key] = pickle.load(f)
                         # since we avoided pickling the absolute sorting paths, we need to set them here
-                        self.comparisons[key].sorting1 = gt_sorting
-                        self.comparisons[key].sorting2 = self.sortings[key]
+                        self.comparisons[key]._sorting1 = gt_sorting
+                        self.comparisons[key]._sorting2 = self.sortings[key]
                     except Exception:
                         pass

@@ -382,16 +382,16 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
             # Since dumping to pickle hard-codes the sorting paths, here we temporarily set the sorting paths to None
             # so that the comparison object can be pickled
             # Upon reloading, we will set the sorting paths back to the correct values
-            comp.sorting1 = None
-            comp.sorting2 = None
+            comp._sorting1 = None
+            comp._sorting2 = None
             # we also need a try-except block in case the folder is read-only
             try:
                 with open(comparison_file, mode="wb") as f:
                     pickle.dump(comp, f)
             except:
                 pass
-            comp.sorting1 = gt_sorting
-            comp.sorting2 = sorting
+            comp._sorting1 = gt_sorting
+            comp._sorting2 = sorting

diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
index a92d6e9f77..f85ba50812 100644
--- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
+++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
@@ -85,8 +85,11 @@ def test_GroundTruthStudy(setup_module):

     study.compute_metrics()

-    for key in study.cases:
-        metrics = study.get_metrics(key)
+    all_metrics = study.get_metrics()
+    print(all_metrics)
+
+    for key in study.cases.keys():
+        metrics = study.get_metrics(case_keys=[key])
         print(metrics)

     study.get_performance_by_unit()
@@ -94,5 +97,8 @@ def test_GroundTruthStudy(setup_module):

 if __name__ == "__main__":
-    setup_module()
-    test_GroundTruthStudy()
+    study_folder = Path("test")
+    if study_folder.is_dir():
+        shutil.rmtree(study_folder)
+    create_a_study(study_folder)
+    test_GroundTruthStudy(study_folder)
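
The updated test shows the two intended call patterns for the reworked `get_metrics`: one call for a concatenated table over all cases, or a filtered call per case key. A usage sketch (the study folder path and the level names are placeholders):

    from spikeinterface.comparison import GroundTruthStudy

    study = GroundTruthStudy("path/to/study_folder")  # hypothetical existing study

    # all cases at once: rows carry the case-key levels as a MultiIndex,
    # with the ground-truth unit ids in the "gt_unit_id" column
    all_metrics = study.get_metrics()

    # a single case: pass the key inside a list
    key = list(study.cases.keys())[0]
    one_case = study.get_metrics(case_keys=[key])

    # with a MultiIndex, one level value can then be selected directly,
    # e.g. all rows of one sorter if "sorter" is the first level:
    # per_sorter = all_metrics.xs("sorter1", level=0)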
From d83273a1ce64d2431a6005a78b45e6c5e35fb967 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 20 Aug 2024 16:15:18 +0200
Subject: [PATCH 04/11] wip: handle levels in plots

---
 .../comparison/groundtruthstudy.py    | 97 ++++++++++---------
 src/spikeinterface/widgets/gtstudy.py | 60 ++++++------
 2 files changed, 83 insertions(+), 74 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 16975cda30..90dbb0556d 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -169,8 +169,8 @@ def scan_folder(self):
                     try:
                         self.comparisons[key] = pickle.load(f)
                         # since we avoided pickling the absolute sorting paths, we need to set them here
-                        self.comparisons[key]._sorting1 = gt_sorting
-                        self.comparisons[key]._sorting2 = self.sortings[key]
+                        self.comparisons[key].sorting1 = gt_sorting
+                        self.comparisons[key].sorting2 = self.sortings[key]
                     except Exception:
                         pass

@@ -382,47 +382,16 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
             # Since dumping to pickle hard-codes the sorting paths, here we temporarily set the sorting paths to None
             # so that the comparison object can be pickled
             # Upon reloading, we will set the sorting paths back to the correct values
-            comp._sorting1 = None
-            comp._sorting2 = None
+            comp.sorting1 = None
+            comp.sorting2 = None
             # we also need a try-except block in case the folder is read-only
             try:
                 with open(comparison_file, mode="wb") as f:
                     pickle.dump(comp, f)
             except:
                 pass
-            comp._sorting1 = gt_sorting
-            comp._sorting2 = sorting
-
-    def get_run_times(self, case_keys=None):
-        """
-        Return the run times for the given case keys.
-
-        Parameters
-        ----------
-        case_keys : list or None
-            The case keys to get the run times for. If None, all cases are returned.
-
-        Returns
-        -------
-        run_times : dict
-            A dictionary with the case keys as keys and the run times as values.
-        """
-        import pandas as pd
-
-        if case_keys is None:
-            case_keys = self.cases.keys()
-
-        log_folder = self.folder / "sortings" / "run_logs"
-
-        run_times = {}
-        for key in case_keys:
-            log_file = log_folder / f"{self.key_to_str(key)}.json"
-            with open(log_file, mode="r") as logfile:
-                log = json.load(logfile)
-                run_time = log.get("run_time", None)
-            run_times[key] = run_time
-
-        return pd.Series(run_times, name="run_time")
+            comp.sorting1 = gt_sorting
+            comp.sorting2 = sorting

     def create_sorting_analyzer_gt(self, case_keys=None, random_params=None, template_params=None, **job_kwargs):
         """
@@ -546,7 +515,11 @@ def get_metrics(self, case_keys=None):
             new_metrics = pd.read_csv(filename, sep="\t", index_col=0)
             _, gt_sorting = self.datasets[dataset_key]
             new_metrics.loc[:, "gt_unit_id"] = gt_sorting.unit_ids
-            new_metrics.index = pd.MultiIndex.from_tuples([key] * len(new_metrics), names=self.levels)
+            if isinstance(key, str):
+                index = [key] * len(new_metrics)
+            elif isinstance(key, tuple):
+                index = pd.MultiIndex.from_tuples([key] * len(new_metrics), names=self.levels)
+            new_metrics.index = index
             if metrics is None:
                 metrics = new_metrics
             else:
                 metrics = pd.concat([metrics, new_metrics])
         return metrics

     def get_units_snr(self, case_keys=None):
         """
         Return the snr for the given case keys.

         Parameters
         ----------
         case_keys : list or None
             The case keys to get the snr for. If None, all cases are returned.

         Returns
         -------
         snr : pandas.Series
             The snr for each case.
         """
         return self.get_metrics(case_keys)["snr"]

+    def get_run_times(self, case_keys=None):
+        """
+        Return the run times for the given case keys.
+
+        Parameters
+        ----------
+        case_keys : list or None
+            The case keys to get the run times for. If None, all cases are returned.
+
+        Returns
+        -------
+        run_times : dict
+            A dictionary with the case keys as keys and the run times as values.
+        """
+        import pandas as pd
+
+        if case_keys is None:
+            case_keys = list(self.cases.keys())
+
+        log_folder = self.folder / "sortings" / "run_logs"
+
+        run_times = []
+
+        if isinstance(case_keys[0], str):
+            index = case_keys
+        elif isinstance(case_keys[0], tuple):
+            index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)
+
+        for key in case_keys:
+            log_file = log_folder / f"{self.key_to_str(key)}.json"
+            with open(log_file, mode="r") as logfile:
+                log = json.load(logfile)
+                run_time = log.get("run_time", None)
+            run_times.append(run_time)
+
+        run_times_df = pd.DataFrame(data={"run_time": run_times}, index=index)
+
+        return run_times_df
+
     def get_performance_by_unit(self, case_keys=None):
         """
         Return a dataframe with the performance of the sorter for each unit for the given case keys.
@@ -594,18 +606,15 @@ def get_performance_by_unit(self, case_keys=None):
             assert comp is not None, "You need to do study.run_comparisons() first"

             perf = comp.get_performance(method="by_unit", output="pandas")
-
+            perf.loc[:, "gt_unit_id"] = perf.index
             if isinstance(key, str):
-                perf[self.levels] = key
+                perf.index = [key] * len(perf)
             elif isinstance(key, tuple):
-                for col, k in zip(self.levels, key):
-                    perf[col] = k
+                perf.index = pd.MultiIndex.from_tuples([key] * len(perf), names=self.levels)

-            perf = perf.reset_index()
             perf_by_unit.append(perf)

         perf_by_unit = pd.concat(perf_by_unit)
-        perf_by_unit = perf_by_unit.set_index(self.levels)
         perf_by_unit = perf_by_unit.sort_index()

         return perf_by_unit
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index e286a32c02..e5db4e2ee8 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -5,6 +5,23 @@
 from .base import BaseWidget, to_attr


+def handle_levels(df, study, case_keys, levels):
+    if case_keys is None:
+        case_keys = list(study.cases.keys())
+        labels = {key: study.cases[key]["label"] for key in case_keys}
+
+    if levels is not None:
+        drop_levels = [l for l in study.levels if l not in levels]
+        df = df.droplevel(drop_levels).sort_index()
+        case_keys = list(np.unique(df.index))
+        if isinstance(df.index, df.MultiIndex):
+            labels = {key: "-".join(key) for key in case_keys}
+        else:
+            labels = {key: key for key in case_keys}
+
+    return df, case_keys, labels
+
+
 class StudyRunTimesWidget(BaseWidget):
     """
     Plot sorter run times for a GroundTruthStudy
@@ -32,6 +49,9 @@ def __init__(
         if case_keys is None:
             case_keys = list(study.cases.keys())

+        if levels is not None:
+            assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"
+
         plot_data = dict(
             study=study,
             run_times=study.get_run_times(case_keys),
@@ -50,9 +70,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        for i, key in enumerate(dp.case_keys):
+        run_times, case_keys, labels = handle_levels(dp.run_times, dp.study, dp.case_keys, dp.levels)
+
+        for i, key in enumerate(case_keys):
             label = dp.study.cases[key]["label"]
-            rt = dp.run_times.loc[key]
+            rt = run_times.loc[key]
             self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key])
         self.ax.set_ylabel("run time (s)")
         self.ax.legend()
@@ -82,6 +104,8 @@ def __init__(
         backend=None,
         **backend_kwargs,
     ):
+        if levels is not None:
+            assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"
         plot_data = dict(
             study=study, count_units=study.get_count_units(case_keys=case_keys), case_keys=case_keys, levels=levels
         )
@@ -96,27 +120,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

         dp = to_attr(data_plot)
         count_units = dp.count_units
-        study = dp.study

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        if dp.case_keys is None:
-            case_keys = list(study.cases.keys())
-            labels = {key: study.cases[key]["label"] for key in case_keys}
+        count_units, case_keys, labels = handle_levels(dp.count_units, dp.study, dp.case_keys, dp.levels)

         columns = count_units.columns.tolist()
         columns.remove("num_gt")
         columns.remove("num_sorter")

-        if dp.levels is not None:
-            drop_levels = [l for l in study.levels if l not in dp.levels]
-            count_units = count_units.droplevel(drop_levels).sort_index()
-            case_keys = list(np.unique(count_units.index))
-            if isinstance(count_units.index, pd.MultiIndex):
-                labels = {key: "-".join(key) for key in case_keys}
-            else:
-                labels = {key: key for key in case_keys}
-
         ncol = len(columns)

         colors = get_some_colors(columns, color_engine="auto", map_name="hot")
@@ -178,6 +190,8 @@ def __init__(
         backend=None,
         **backend_kwargs,
     ):
+        if levels is not None:
+            assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"
         perfs_by_unit = study.get_performance_by_unit(case_keys=case_keys)
         if mode == "snr":
             metrics = study.get_metrics(case_keys=case_keys)
@@ -204,21 +218,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         import seaborn as sns

         dp = to_attr(data_plot)
-        perfs = dp.perfs
-        study = dp.study
-
-        if dp.case_keys is None:
-            case_keys = list(study.cases.keys())
-            labels = {key: study.cases[key]["label"] for key in case_keys}
-
-        if dp.levels is not None:
-            drop_levels = [l for l in study.levels if l not in dp.levels]
-            perfs = perfs.droplevel(drop_levels).sort_index()
-            case_keys = list(np.unique(perfs.index))
-            if isinstance(perfs.index, pd.MultiIndex):
-                labels = {key: "-".join(key) for key in case_keys}
-            else:
-                labels = {key: key for key in case_keys}
+        perfs, case_keys, labels = handle_levels(dp.perfs, dp.study, dp.case_keys, dp.levels)

         colors = get_some_colors(case_keys, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)
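
The level handling that `handle_levels` centralizes above is plain pandas MultiIndex manipulation: drop the levels the user does not want, then let rows from different cases share a key. A toy sketch of the same `droplevel` idea (toy data, not study output):

    import pandas as pd

    # performance-like table indexed by two case-key levels
    index = pd.MultiIndex.from_tuples(
        [("sorterA", "probe1"), ("sorterA", "probe2"), ("sorterB", "probe1")],
        names=["sorter", "dataset"],
    )
    df = pd.DataFrame({"accuracy": [0.9, 0.8, 0.7]}, index=index)

    # keep only the "sorter" level: drop the others and sort for stable slicing
    keep = ["sorter"]
    drop_levels = [l for l in df.index.names if l not in keep]
    grouped = df.droplevel(drop_levels).sort_index()

    # rows from different datasets now share a key, so .loc / .xs
    # return all of a sorter's rows at once
    print(grouped.loc["sorterA"])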
From 15a59b0b32966e9ed3cbfec8f39a32b83d78c76a Mon Sep 17 00:00:00 2001
From: alejoe91
Date: Tue, 20 Aug 2024 17:04:10 +0000
Subject: [PATCH 05/11] wip: handle levels in counts plot

---
 .../comparison/groundtruthstudy.py    | 12 +--
 src/spikeinterface/widgets/gtstudy.py | 78 +++++++++++++++----
 2 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 90dbb0556d..d4fecfd8c0 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -169,8 +169,8 @@ def scan_folder(self):
                     try:
                         self.comparisons[key] = pickle.load(f)
                         # since we avoided pickling the absolute sorting paths, we need to set them here
-                        self.comparisons[key].sorting1 = gt_sorting
-                        self.comparisons[key].sorting2 = self.sortings[key]
+                        self.comparisons[key]._sorting1 = gt_sorting
+                        self.comparisons[key]._sorting2 = self.sortings[key]
                     except Exception:
                         pass

@@ -382,16 +382,16 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
             # Since dumping to pickle hard-codes the sorting paths, here we temporarily set the sorting paths to None
             # so that the comparison object can be pickled
             # Upon reloading, we will set the sorting paths back to the correct values
-            comp.sorting1 = None
-            comp.sorting2 = None
+            comp._sorting1 = None
+            comp._sorting2 = None
             # we also need a try-except block in case the folder is read-only
             try:
                 with open(comparison_file, mode="wb") as f:
                     pickle.dump(comp, f)
             except:
                 pass
-            comp.sorting1 = gt_sorting
-            comp.sorting2 = sorting
+            comp._sorting1 = gt_sorting
+            comp._sorting2 = sorting
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index e5db4e2ee8..600f77ea07 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -6,15 +6,19 @@

 def handle_levels(df, study, case_keys, levels):
+    import pandas as pd
+
     if case_keys is None:
         case_keys = list(study.cases.keys())
         labels = {key: study.cases[key]["label"] for key in case_keys}

     if levels is not None:
         drop_levels = [l for l in study.levels if l not in levels]
         df = df.droplevel(drop_levels).sort_index()
+        if len(levels) > 1:
+            df = df.reorder_levels(levels)
         case_keys = list(np.unique(df.index))
-        if isinstance(df.index, df.MultiIndex):
+        if isinstance(df.index, pd.MultiIndex):
             labels = {key: "-".join(key) for key in case_keys}
         else:
             labels = {key: key for key in case_keys}
@@ -35,7 +39,11 @@ class StudyRunTimesWidget(BaseWidget):
         A selection of cases to plot, if None, then all cases are plotted.
     levels : str or list-like or None, default: None
         A selection of levels to group cases by, if None, then all
-        cases are treated as separate.
+        cases are treated as separate in a bar plot.
+        When specified, ff levels is a string or a 1-element tuple/list,
+        then it will be treated as the "x" variable of a boxplot. In case it's a
+        2-element object, the first element is "x", the second is "hue".
+        More than 2 elements are not supported
     """

     def __init__(
         self,
         study,
         case_keys=None,
         levels=None,
+        cmap="tab20",
         backend=None,
         **backend_kwargs,
     ):
@@ -45,6 +53,9 @@ def __init__(
             case_keys = list(study.cases.keys())

         if levels is not None:
+            if isinstance(levels, str):
+                levels = [levels]
+            assert len(levels) < 3, "You can pass at most 2 levels to plot against!"
             assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"

         plot_data = dict(
@@ -58,12 +70,16 @@ def __init__(
             case_keys=case_keys,
             levels=levels,
             colors=study.get_colors(),
+            cmap=cmap
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

     def plot_matplotlib(self, data_plot, **backend_kwargs):
         import matplotlib.pyplot as plt
+        import seaborn as sns
+
+        from .utils import get_some_colors
         from .utils_matplotlib import make_mpl_figure

         dp = to_attr(data_plot)
@@ -72,12 +88,29 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

         run_times, case_keys, labels = handle_levels(dp.run_times, dp.study, dp.case_keys, dp.levels)

-        for i, key in enumerate(case_keys):
-            label = dp.study.cases[key]["label"]
-            rt = run_times.loc[key]
-            self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key])
+        if dp.levels is None:
+            x = None
+            hue = case_keys
+            colors = get_some_colors(case_keys, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)
+            plt_fun = sns.barplot
+        elif len(dp.levels) == 1:
+            x = None
+            colors = get_some_colors(case_keys, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)
+            hue = dp.levels[0]
+            plt_fun = sns.boxplot
+        elif len(dp.levels) == 2:
+            x, hue = dp.levels
+            hues = np.unique([c[1] for c in case_keys])
+            colors = get_some_colors(hues, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)
+            plt_fun = sns.boxplot
+
+        plt_fun(data=run_times, y="run_time", x=x, hue=hue, ax=self.ax, palette=colors)
+
         self.ax.set_ylabel("run time (s)")
-        self.ax.legend()
+        sns.despine(ax=self.ax)
+        if dp.levels is None:
+            h, l = self.ax.get_legend_handles_labels()
+            self.ax.legend(h, list(labels.values()))


 class StudyUnitCountsWidget(BaseWidget):
@@ -138,6 +171,8 @@ def __init__(
         **backend_kwargs,
     ):
         if levels is not None:
+            if isinstance(levels, str):
+                levels = [levels]
             assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"
         plot_data = dict(
             study=study, count_units=study.get_count_units(case_keys=case_keys), case_keys=case_keys, levels=levels
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

     def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import seaborn as sns
+        import pandas as pd
+
         from .utils_matplotlib import make_mpl_figure
         from .utils import get_some_colors

-        import pandas as pd
-
         dp = to_attr(data_plot)
         count_units = dp.count_units

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

         count_units, case_keys, labels = handle_levels(dp.count_units, dp.study, dp.case_keys, dp.levels)

         columns = count_units.columns.tolist()
         columns.remove("num_gt")
         columns.remove("num_sorter")
-
         ncol = len(columns)

+        if dp.levels is not None:
+            if len(columns) > 1:
+                assert len(levels) == 1, f"Only one level at a time is allowed to display {ncol} counts: {columns}"
+
         colors = get_some_colors(columns, color_engine="auto", map_name="hot")
         colors["num_well_detected"] = "green"

         xticklabels = []
+
+        # use melt
         for i, key in enumerate(case_keys):
             for c, col in enumerate(columns):
                 x = i + 1 + c / (ncol + 1)
                 y = count_units.loc[key, col]
+                print(y)
                 if not "well_detected" in col:
                     y = -y
@@ -186,10 +224,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col])
             xticklabels.append(labels[key])

-        self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1)
+        self.ax.set_xticks(np.arange(len(case_keys)) + 1)
         self.ax.set_xticklabels(xticklabels)
         self.ax.legend()
+        sns.despine(ax=self.ax)


 class StudyPerformances(BaseWidget):
@@ -252,15 +291,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         import matplotlib.pyplot as plt
-        from .utils_matplotlib import make_mpl_figure
-        from .utils import get_some_colors
-        import pandas as pd
         import seaborn as sns

+        from .utils_matplotlib import make_mpl_figure
+        from .utils import get_some_colors
+
         dp = to_attr(data_plot)
         perfs, case_keys, labels = handle_levels(dp.perfs, dp.study, dp.case_keys, dp.levels)

         colors = get_some_colors(case_keys, map_name=dp.cmap, color_engine="matplotlib", shuffle=False, margin=0)

         if dp.mode in ("ordered", "snr"):
@@ -277,6 +315,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 ax.set_title(performance_name)
                 if count == len(dp.performance_names) - 1:
                     ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8)
+                sns.despine(ax=ax)

         elif dp.mode == "snr":
             metric_name = dp.mode
@@ -296,6 +335,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 ax.set_ylim(0, 1.05)
                 if count == 0:
                     ax.legend(loc="lower right")
+                sns.despine(ax=ax)

         elif dp.mode == "swarm":
             levels = perfs.index.names if dp.levels is None else dp.levels
@@ -306,7 +346,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 value_vars=dp.performance_names,
             )
             df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1)
-            sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True)
+            sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=self.ax)
+            sns.despine(ax=self.ax)

From b4d4fe02ec13cbe64098b2fd061203f028700e08 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 20 Aug 2024 17:06:06 +0000
Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/widgets/gtstudy.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index 600f77ea07..fd1f6ccc4a 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -39,9 +39,9 @@ class StudyRunTimesWidget(BaseWidget):
         A selection of cases to plot, if None, then all cases are plotted.
     levels : str or list-like or None, default: None
         A selection of levels to group cases by, if None, then all
-        cases are treated as separate in a bar plot.
+        cases are treated as separate in a bar plot.
         When specified, ff levels is a string or a 1-element tuple/list,
-        then it will be treated as the "x" variable of a boxplot. In case it's a
+        then it will be treated as the "x" variable of a boxplot. In case it's a
         2-element object, the first element is "x", the second is "hue".
         More than 2 elements are not supported
     """
@@ -70,7 +70,7 @@ def __init__(
             case_keys=case_keys,
             levels=levels,
             colors=study.get_colors(),
-            cmap=cmap
+            cmap=cmap,
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -110,7 +110,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         sns.despine(ax=self.ax)
         if dp.levels is None:
             h, l = self.ax.get_legend_handles_labels()
-            self.ax.legend(h, list(labels.values()))
+            self.ax.legend(h, list(labels.values()))


@@ -154,7 +154,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         from .utils_matplotlib import make_mpl_figure
         from .utils import get_some_colors

-
         dp = to_attr(data_plot)
         count_units = dp.count_units
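
The run-times widget above delegates grouping to seaborn, mapping one case-key level to `x` and a second to `hue`. The same pattern in isolation (toy data; `reset_index` is used here so seaborn sees the index levels as ordinary columns):

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    index = pd.MultiIndex.from_tuples(
        [("sorterA", "probe1"), ("sorterA", "probe2"),
         ("sorterB", "probe1"), ("sorterB", "probe2")],
        names=["sorter", "dataset"],
    )
    run_times = pd.DataFrame({"run_time": [10.0, 12.0, 30.0, 28.0]}, index=index)

    fig, ax = plt.subplots()
    # first level on x, second as hue: one box per (sorter, dataset) group
    sns.boxplot(data=run_times.reset_index(), x="sorter", hue="dataset", y="run_time", ax=ax)
    sns.despine(ax=ax)
    plt.show()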
From 58c7611bf62f5a010b120ec888545c63e9767bc2 Mon Sep 17 00:00:00 2001
From: alejoe91
Date: Wed, 21 Aug 2024 14:47:43 +0000
Subject: [PATCH 07/11] fix unit counts

---
 src/spikeinterface/widgets/gtstudy.py | 104 ++++++++++++++++++--------
 1 file changed, 71 insertions(+), 33 deletions(-)

diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index fd1f6ccc4a..643563465a 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -40,7 +40,7 @@ class StudyRunTimesWidget(BaseWidget):
     levels : str or list-like or None, default: None
         A selection of levels to group cases by, if None, then all
         cases are treated as separate in a bar plot.
-        When specified, ff levels is a string or a 1-element tuple/list,
+        When specified, if levels is a string or a 1-element tuple/list,
         then it will be treated as the "x" variable of a boxplot. In case it's a
         2-element object, the first element is "x", the second is "hue".
         More than 2 elements are not supported
@@ -127,6 +127,17 @@ class StudyUnitCountsWidget(BaseWidget):
     levels : str or list-like or None, default: None
         A selection of levels to group cases by, if None, then all
         cases are treated as separate.
+        When specified, if levels is a string or a 1-element tuple/list,
+        then it will be treated as the "x" variable of a boxplot. In case it's a
+        2-element object, the first element is "x", the second is "hue".
+        More than 2 elements are not supported.
+        If the number of counts to plot is more than one (e.g., in case of exhaustive
+        ground truth), then only one level at a time is supported.
+    labels : dict or None, default: None
+        The labels to use for each case key in case levels is None.
+    rotation : int or None, default: 45
+        The rotation for the x tick labels
     """

     def __init__(
         self,
         study,
         case_keys=None,
         levels=None,
+        labels=None,
+        rotation=45,
         backend=None,
+        cmap="tab20",
         **backend_kwargs,
     ):
         if levels is not None:
             if isinstance(levels, str):
                 levels = [levels]
             assert all([l in study.levels for l in levels]), f"levels must be in {study.levels}"
         plot_data = dict(
-            study=study, count_units=study.get_count_units(case_keys=case_keys), case_keys=case_keys, levels=levels
+            study=study,
+            count_units=study.get_count_units(case_keys=case_keys),
+            case_keys=case_keys,
+            levels=levels,
+            labels=labels,
+            cmap=cmap,
+            rotation=rotation
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -196,47 +221,65 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         from .utils import get_some_colors

         dp = to_attr(data_plot)
+        study = dp.study
         count_units = dp.count_units
+        levels = dp.levels

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        count_units, case_keys, labels = handle_levels(dp.count_units, dp.study, dp.case_keys, dp.levels)
+        count_units, case_keys, labels = handle_levels(dp.count_units, dp.study, dp.case_keys, levels)
+        count_units = count_units.drop(columns=["num_gt", "num_sorter"])

-        columns = count_units.columns.tolist()
-        columns.remove("num_gt")
-        columns.remove("num_sorter")
-        ncol = len(columns)
+        if dp.labels is not None:
+            labels = dp.labels

-        if dp.levels is not None:
-            if len(columns) > 1:
-                assert len(levels) == 1, f"Only one level at a time is allowed to display {ncol} counts: {columns}"
+        for col in count_units.columns:
+            vals = count_units[col].values
+            if not "well_detected" in col:
+                vals = -vals
+            col_name = col.replace("num_", "").replace("_", " ").title()
+            count_units.loc[:, col_name] = vals
+            del count_units[col]

-        colors = get_some_colors(columns, color_engine="auto", map_name="hot")
-        colors["num_well_detected"] = "green"
+        columns = count_units.columns.tolist()
+        ncol = len(columns)

-        xticklabels = []
-
-        # use melt
-        for i, key in enumerate(case_keys):
-            for c, col in enumerate(columns):
-                x = i + 1 + c / (ncol + 1)
-                y = count_units.loc[key, col]
-                print(y)
-                if not "well_detected" in col:
-                    y = -y
-
-                if i == 0:
-                    label = col.replace("num_", "").replace("_", " ").title()
-                else:
-                    label = None
-
-                self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col])
-            xticklabels.append(labels[key])
+        count_units = count_units.reset_index()
+        if levels is not None:
+            if len(levels) == 1:
+                var_name = "Metric"
+                x = levels[0]
+                y = "Num Units"
+                hue = "Metric"
+                color_list = columns
+            else:
+                assert len(columns) == 1, (
+                    f"Multi-levels is not supported when multiple metrics counts are available ({columns})"
+                )
+                var_name = None
+                x, hue = levels
+                y = columns[0]
+                color_list = list(np.unique(count_units[hue]))
+        else:
+            count_units.loc[:, "Label"] = labels.values()
+            levels = study.levels + ["Label"]
+            var_name = "Metric"
+            x = "Label"
+            y = "Num Units"
+            hue = "Metric"
+            color_list = columns

-        self.ax.set_xticks(np.arange(len(case_keys)) + 1)
-        self.ax.set_xticklabels(xticklabels)
-        self.ax.legend()
+        colors = get_some_colors(color_list, color_engine="auto", map_name=dp.cmap)
+        # Well Detected is always present
+        colors["Well Detected"] = "green"
+        if var_name is not None:
+            df = count_units.melt(id_vars=levels, var_name=var_name, value_name="Num Units")
+        else:
+            df = count_units
+
+        sns.barplot(df, x=x, y=y, hue=hue, ax=self.ax, palette=colors,)
+        _ = self.ax.set_xticklabels(self.ax.get_xticklabels(), rotation=dp.rotation)
         sns.despine(ax=self.ax)
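
The rewritten counts plot follows a standard melt-then-barplot recipe: the wide per-metric columns become a long table with a single "Metric" column, which seaborn then groups by hue. The recipe in isolation (toy counts, with the non-well-detected counts negated so their bars hang below zero, as in the widget):

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    counts = pd.DataFrame(
        {
            "Label": ["case1", "case2"],
            "Well Detected": [38, 41],
            "False Positive": [-4, -2],  # negative so the bars point downward
            "Redundant": [-3, -5],
        }
    )

    # wide -> long: one row per (case, metric) pair
    df = counts.melt(id_vars=["Label"], var_name="Metric", value_name="Num Units")

    fig, ax = plt.subplots()
    sns.barplot(df, x="Label", y="Num Units", hue="Metric", ax=ax)
    sns.despine(ax=ax)
    plt.show()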
From 64dc7e9e720235c381cbddcddd5d0af41d9a6fc0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 21 Aug 2024 14:49:04 +0000
Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/widgets/gtstudy.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index 643563465a..90ccf84285 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -162,7 +162,7 @@ def __init__(
             levels=levels,
             labels=labels,
             cmap=cmap,
-            rotation=rotation
+            rotation=rotation,
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -207,9 +207,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 hue = "Metric"
                 color_list = columns
             else:
-                assert len(columns) == 1, (
-                    f"Multi-levels is not supported when multiple metrics counts are available ({columns})"
-                )
+                assert (
+                    len(columns) == 1
+                ), f"Multi-levels is not supported when multiple metrics counts are available ({columns})"
                 var_name = None
                 x, hue = levels
                 y = columns[0]
@@ -231,7 +231,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         else:
             df = count_units

-        sns.barplot(df, x=x, y=y, hue=hue, ax=self.ax, palette=colors,)
+        sns.barplot(
+            df,
+            x=x,
+            y=y,
+            hue=hue,
+            ax=self.ax,
+            palette=colors,
+        )
         _ = self.ax.set_xticklabels(self.ax.get_xticklabels(), rotation=dp.rotation)
         sns.despine(ax=self.ax)

From 954a0efed9abf4dd45ca3a7b2a29b5844833527c Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 28 Aug 2024 09:54:56 +0200
Subject: [PATCH 09/11] Reduce memory footprint of gt-study

---
 .../comparison/groundtruthstudy.py | 67 +++++++++++++------
 1 file changed, 48 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index d4fecfd8c0..c060fe33de 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -134,10 +134,29 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):

         return cls(study_folder)

-    def scan_folder(self):
+    def load_recording(self, dataset_key):
+        """
+        Load the recording for a given dataset key.
+
+        Parameters
+        ----------
+        dataset_key : str
+            The dataset key.
+
+        Returns
+        -------
+        recording : Recording
+            The recording object.
+        """
+        rec_file = self.folder / "datasets" / "recordings" / f"{dataset_key}.pickle"
+        recording = load_extractor(rec_file)
+        return recording
+
+    def scan_folder(self, load_recordings=False, load_comparisons=False):
         """
         Scan the folder to load or reload the datasets, cases, sortings, and comparisons.
         """
+        print("Scanning folder")
         if not (self.folder / "datasets").exists():
             raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}")

@@ -146,33 +165,37 @@ def scan_folder(self):

         self.levels = self.info["levels"]

-        for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"):
-            key = rec_file.stem
-            rec = load_extractor(rec_file)
-            gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key)
-            self.datasets[key] = (rec, gt_sorting)
-
         with open(self.folder / "cases.pickle", "rb") as f:
             self.cases = pickle.load(f)

         self.sortings = {k: None for k in self.cases}
         self.comparisons = {k: None for k in self.cases}
+        self.datasets = {}
         for key in self.cases:
-            gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
+            dataset_key = self.cases[key]["dataset"]
+            if dataset_key not in self.datasets:
+                if load_recordings:
+                    recording = load_extractor(dataset_key)
+                else:
+                    recording = None
+                gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / dataset_key)
+                self.datasets[dataset_key] = (recording, gt_sorting)
             sorting_folder = self.folder / "sortings" / self.key_to_str(key)
             if sorting_folder.exists():
                 self.sortings[key] = load_extractor(sorting_folder)

-            comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
-            if comparison_file.exists():
-                with open(comparison_file, mode="rb") as f:
-                    try:
-                        self.comparisons[key] = pickle.load(f)
-                        # since we avoided pickling the absolute sorting paths, we need to set them here
-                        self.comparisons[key]._sorting1 = gt_sorting
-                        self.comparisons[key]._sorting2 = self.sortings[key]
-                    except Exception:
-                        pass
+            if load_comparisons:
+                comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+                if comparison_file.exists():
+                    with open(comparison_file, mode="rb") as f:
+                        try:
+                            gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
+                            self.comparisons[key] = pickle.load(f)
+                            # since we avoided pickling the absolute sorting paths, we need to set them here
+                            self.comparisons[key]._sorting1 = gt_sorting
+                            self.comparisons[key]._sorting2 = self.sortings[key]
+                        except Exception:
+                            pass

     def __repr__(self):
         t = f"{self.__class__.__name__} {self.folder.stem} \n"
@@ -293,7 +316,11 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs=None, keep=Tr

             params = self.cases[key]["run_sorter_params"].copy()
             # this ensure that sorter_name is given
-            recording, _ = self.datasets[self.cases[key]["dataset"]]
+            dataset_key = self.cases[key]["dataset"]
+            recording, _ = self.datasets[dataset_key]
+            if recording is None:
+                recording = self.load_recording(dataset_key)
+                self.datasets[dataset_key] = (recording, self.datasets[dataset_key][1])
             sorter_name = params.pop("sorter_name")
             job = dict(
                 sorter_name=sorter_name,
@@ -422,6 +449,8 @@ def create_sorting_analyzer_gt(self, case_keys=None, random_params=None, templat
             # the waveforms depend on the dataset key
             folder = base_folder / self.key_to_str(dataset_key)
             recording, gt_sorting = self.datasets[dataset_key]
+            if recording is None:
+                recording = self.load_recording(dataset_key)
             sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder)
             sorting_analyzer.compute("random_spikes", **random_params)
             sorting_analyzer.compute("templates", **template_params, **job_kwargs)
From b8052edc1c881d2b2f6e4a4b98c8c7134a82b145 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 28 Aug 2024 10:57:47 +0200
Subject: [PATCH 10/11] Reduce memory footprint of gt-study 2

---
 .../comparison/groundtruthstudy.py | 67 +++++++++++--------
 1 file changed, 38 insertions(+), 29 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index c060fe33de..bb1fb59fc8 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -44,13 +44,13 @@ class GroundTruthStudy:
     Note that the underlying folder structure is not backward compatible!
     """

-    def __init__(self, study_folder):
+    def __init__(self, study_folder, datasets=None, cases=None, sortings=None, comparisons=None):
         self.folder = Path(study_folder)

-        self.datasets = {}
-        self.cases = {}
-        self.sortings = {}
-        self.comparisons = {}
+        self.datasets = datasets
+        self.cases = cases
+        self.sortings = sortings
+        self.comparisons = comparisons
         self.colors = None

         self.scan_folder()
@@ -132,7 +132,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
         # cases is dumped to a pickle file, json is not possible because of the tuple key
         (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases))

-        return cls(study_folder)
+        return cls(study_folder, datasets=datasets, cases=cases)

     def load_recording(self, dataset_key):
         """
@@ -165,37 +165,46 @@ def scan_folder(self, load_recordings=False, load_comparisons=False):

         self.levels = self.info["levels"]

-        with open(self.folder / "cases.pickle", "rb") as f:
-            self.cases = pickle.load(f)
+        if self.cases is None:
+            with open(self.folder / "cases.pickle", "rb") as f:
+                self.cases = pickle.load(f)

-        self.sortings = {k: None for k in self.cases}
-        self.comparisons = {k: None for k in self.cases}
-        self.datasets = {}
-        for key in self.cases:
-            dataset_key = self.cases[key]["dataset"]
-            if dataset_key not in self.datasets:
+        # load datasets
+        if self.datasets is None:
+            self.datasets = {}
+            for key in self.cases:
+                dataset_key = self.cases[key]["dataset"]
                 if load_recordings:
-                    recording = load_extractor(dataset_key)
+                    recording = self.load_recording(dataset_key)
                 else:
                     recording = None
                 gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / dataset_key)
                 self.datasets[dataset_key] = (recording, gt_sorting)

-            sorting_folder = self.folder / "sortings" / self.key_to_str(key)
-            if sorting_folder.exists():
-                self.sortings[key] = load_extractor(sorting_folder)
+        # load sortings
+        if self.sortings is None:
+            self.sortings = {}
+            for key in self.cases:
+                sorting_folder = self.folder / "sortings" / self.key_to_str(key)
+                if sorting_folder.exists():
+                    self.sortings[key] = load_extractor(sorting_folder)
+
+        # load comparisons
+        if self.comparisons is None:
+            self.comparisons = {}
             if load_comparisons:
-                comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
-                if comparison_file.exists():
-                    with open(comparison_file, mode="rb") as f:
-                        try:
-                            gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
-                            self.comparisons[key] = pickle.load(f)
-                            # since we avoided pickling the absolute sorting paths, we need to set them here
-                            self.comparisons[key]._sorting1 = gt_sorting
-                            self.comparisons[key]._sorting2 = self.sortings[key]
-                        except Exception:
-                            pass
+                for key in self.cases:
+                    comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+                    if comparison_file.exists():
+                        with open(comparison_file, mode="rb") as f:
+                            try:
+                                gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
+                                self.comparisons[key] = pickle.load(f)
+                                # since we avoided pickling the absolute sorting paths, we need to set them here
+                                self.comparisons[key]._sorting1 = gt_sorting
+                                self.comparisons[key]._sorting2 = self.sortings[key]
+                            except Exception:
+                                pass

     def __repr__(self):
         t = f"{self.__class__.__name__} {self.folder.stem} \n"
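Stripped of the SpikeInterface specifics, PATCH 10 converges on a common lazy-initialization pattern: the constructor accepts pre-built state, the factory hands over its in-memory objects, and scan_folder only reads from disk whatever is still missing. A toy sketch of the same idea, with all names invented:

    # Miniature version of the "fill only what is missing" pattern from PATCH 10.
    from pathlib import Path
    import pickle

    class LazyStore:
        def __init__(self, folder, cases=None):
            self.folder = Path(folder)
            self.cases = cases           # may be pre-populated by the factory
            self.scan_folder()

        def scan_folder(self):
            # only touch the disk for state that was not handed in
            if self.cases is None:
                with open(self.folder / "cases.pickle", "rb") as f:
                    self.cases = pickle.load(f)

        @classmethod
        def create(cls, folder, cases):
            folder = Path(folder)
            folder.mkdir(parents=True, exist_ok=True)
            (folder / "cases.pickle").write_bytes(pickle.dumps(cases))
            return cls(folder, cases=cases)   # skip the immediate re-read

This is why create() above now returns cls(study_folder, datasets=datasets, cases=cases) instead of cls(study_folder): the freshly built datasets and cases never make a round trip through the pickle files.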
From 1214390d24454e02db9c0085ecbd1d4df2b693b3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 28 Aug 2024 11:27:19 +0200
Subject: [PATCH 11/11] Remove optional loading of recordings and comparisons

---
 .../comparison/groundtruthstudy.py | 35 ++++++++-----------
 1 file changed, 14 insertions(+), 21 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index bb1fb59fc8..ceda1b3feb 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -152,7 +152,7 @@ def load_recording(self, dataset_key):
         recording = load_extractor(rec_file)
         return recording

-    def scan_folder(self, load_recordings=False, load_comparisons=False):
+    def scan_folder(self):
         """
         Scan the folder to load or reload the datasets, cases, sortings, and comparisons.
         """
@@ -174,10 +174,7 @@ def scan_folder(self):
             self.datasets = {}
             for key in self.cases:
                 dataset_key = self.cases[key]["dataset"]
-                if load_recordings:
-                    recording = self.load_recording(dataset_key)
-                else:
-                    recording = None
+                recording = self.load_recording(dataset_key)
                 gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / dataset_key)
                 self.datasets[dataset_key] = (recording, gt_sorting)

@@ -192,19 +189,18 @@ def scan_folder(self):
         # load comparisons
         if self.comparisons is None:
             self.comparisons = {}
-            if load_comparisons:
-                for key in self.cases:
-                    comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
-                    if comparison_file.exists():
-                        with open(comparison_file, mode="rb") as f:
-                            try:
-                                gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
-                                self.comparisons[key] = pickle.load(f)
-                                # since we avoided pickling the absolute sorting paths, we need to set them here
-                                self.comparisons[key]._sorting1 = gt_sorting
-                                self.comparisons[key]._sorting2 = self.sortings[key]
-                            except Exception:
-                                pass
+            for key in self.cases:
+                comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+                if comparison_file.exists():
+                    with open(comparison_file, mode="rb") as f:
+                        try:
+                            gt_sorting = self.datasets[self.cases[key]["dataset"]][1]
+                            self.comparisons[key] = pickle.load(f)
+                            # since we avoided pickling the absolute sorting paths, we need to set them here
+                            self.comparisons[key]._sorting1 = gt_sorting
+                            self.comparisons[key]._sorting2 = self.sortings[key]
+                        except Exception:
+                            pass

     def __repr__(self):
         t = f"{self.__class__.__name__} {self.folder.stem} \n"
@@ -327,9 +323,6 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs=None, keep=Tr
             # this ensure that sorter_name is given
             dataset_key = self.cases[key]["dataset"]
             recording, _ = self.datasets[dataset_key]
-            if recording is None:
-                recording = self.load_recording(dataset_key)
-                self.datasets[dataset_key] = (recording, self.datasets[dataset_key][1])
             sorter_name = params.pop("sorter_name")
             job = dict(
                 sorter_name=sorter_name,
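For reference, a hypothetical end-to-end run against the API as it stands after PATCH 11. The dataset, case key, levels, and sorter choice are all made up, and the sketch assumes the chosen sorter is available in the environment:

    # Made-up end-to-end study; "toy", "tdc2", and "toy_study" are illustrative only.
    from spikeinterface import generate_ground_truth_recording
    from spikeinterface.comparison import GroundTruthStudy

    recording, gt_sorting = generate_ground_truth_recording(durations=[30.0], num_units=10, seed=0)
    datasets = {"toy": (recording, gt_sorting)}
    cases = {
        ("tdc2", "toy"): {
            "label": "Tridesclous2 on toy data",
            "dataset": "toy",
            "run_sorter_params": {"sorter_name": "tridesclous2"},
        },
    }

    study = GroundTruthStudy.create("toy_study", datasets=datasets, cases=cases, levels=["sorter", "data"])
    study.run_sorters(verbose=True)   # sortings are copied back into the study folder
    study.run_comparisons()           # pickled per case, reloaded eagerly by scan_folder()
    print(study.get_run_times())

Because create() passes the in-memory datasets and cases straight to the constructor, this whole run never re-reads the recordings it just generated, which is the point of the last three patches.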