diff --git a/doc/api.rst b/doc/api.rst index 6bb9b39091..eb9a61eb9c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -346,6 +346,9 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: load_model + .. autofunction:: auto_label_units + .. autofunction:: train_model Deprecated ~~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index e3d58ca8f2..41659d2e84 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -125,6 +125,7 @@ 'subsection_order': ExplicitOrder([ '../examples/tutorials/core', '../examples/tutorials/extractors', + '../examples/tutorials/curation', '../examples/tutorials/qualitymetrics', '../examples/tutorials/comparison', '../examples/tutorials/widgets', diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst new file mode 100644 index 0000000000..9b1612ec12 --- /dev/null +++ b/doc/how_to/auto_curation_prediction.rst @@ -0,0 +1,43 @@ +How to use a trained model to predict the curation labels +========================================================= + +For a more detailed guide to using trained models, `read our tutorial here +`_). + +There is a Collection of models for automated curation available on the +`SpikeInterface HuggingFace page `_. + +We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` on a SortingAnalyzer +called ``sorting_analyzer``. We assume that the quality and template metrics have +already been computed. + +We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the +repo's URL after huggingface.co/) and that we trust the model. + +.. code:: + + from spikeinterface.curation import auto_label_units + + labels_and_probabilities = auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trust_model = True + ) + +If you have a local directory containing the model in a ``skops`` file you can use this to +create the labels: + +.. code:: + + labels_and_probabilities = si.auto_label_units( + sorting_analyzer = sorting_analyzer, + model_folder = "my_folder_with_a_model_in_it", + ) + +The returned labels are a dictionary of model's predictions and it's confidence. These +are also saved as a property of your ``sorting_analyzer`` and can be accessed like so: + +.. code:: + + labels = sorting_analyzer.sorting.get_property("classifier_label") + probabilities = sorting_analyzer.sorting.get_property("classifier_probability") diff --git a/doc/how_to/auto_curation_training.rst b/doc/how_to/auto_curation_training.rst new file mode 100644 index 0000000000..20ab57d284 --- /dev/null +++ b/doc/how_to/auto_curation_training.rst @@ -0,0 +1,58 @@ +How to train a model to predict curation labels +=============================================== + +A full tutorial for model-based curation can be found `here `_. + +Here, we assume that you have: + +* Two SortingAnalyzers called ``analyzer_1`` and + ``analyzer_2``, and have calculated some template and quality metrics for both +* Manually curated labels for the units in each analyzer, in lists called + ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can + be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``. + +With these objects calculated, you can train a model as follows + +.. 
code:: + + from spikeinterface.curation import train_model + + analyzer_list = [analyzer_1, analyzer_2] + labels_list = [analyzer_1_labels, analyzer_2_labels] + output_folder = "/path/to/output_folder" + + trainer = train_model( + mode="analyzers", + labels=labels_list, + analyzers=analyzer_list, + output_folder=output_folder, + metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics + imputation_strategies=None, # Default is all available imputation strategies + scaling_techniques=None, # Default is all available scaling techniques + classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available + seed=None, # Set a seed for reproducibility + ) + + +The trainer tries several models and chooses the most accurate one. This model and +some metadata are stored in the ``output_folder``, which can later be loaded using the +``load_model`` function (`more details `_). +We can also access the model, which is an sklearn ``Pipeline``, from the trainer object + +.. code:: + + best_model = trainer.best_pipeline + + +The training function can also be run in “csv” mode, if you prefer to +store metrics in as .csv files. If the target labels are stored as a column in +the file, you can point to these with the ``target_label`` parameter + +.. code:: + + trainer = train_model( + mode="csv", + metrics_paths = ["/path/to/csv_file_1", "/path/to/csv_file_2"], + target_label = "my_label", + output_folder=output_folder, + ) diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 5d7eae9003..7f79156a3b 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -15,3 +15,5 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. load_your_data_into_sorting benchmark_with_hybrid_recordings drift_with_lfp + auto_curation_training + auto_curation_prediction diff --git a/doc/images/files_screen.png b/doc/images/files_screen.png new file mode 100644 index 0000000000..ef2b5b0873 Binary files /dev/null and b/doc/images/files_screen.png differ diff --git a/doc/images/hf-logo.svg b/doc/images/hf-logo.svg new file mode 100644 index 0000000000..ab959d165f --- /dev/null +++ b/doc/images/hf-logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/doc/images/initial_model_screen.png b/doc/images/initial_model_screen.png new file mode 100644 index 0000000000..b01c4248a6 Binary files /dev/null and b/doc/images/initial_model_screen.png differ diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index 4c7625d811..82f2c06eed 100644 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -119,8 +119,8 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various .. grid-item-card:: Quality Metrics :link-type: ref - :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py - :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_metrics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png :img-alt: Quality Metrics :class-card: gallery-card :text-align: center @@ -133,6 +133,39 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various :class-card: gallery-card :text-align: center +Automated curation tutorials +---------------------------- + +Learn how to curate your units using a trained machine learning model. 
Or how to create +and share your own model. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Model-based curation + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_1_automated_curation.py + :img-top: /tutorials/curation/images/sphx_glr_plot_1_automated_curation_002.png + :img-alt: Model-based curation + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Train your own model + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_2_train_a_model.py + :img-top: /tutorials/curation/images/thumb/sphx_glr_plot_2_train_a_model_thumb.png + :img-alt: Train your own model + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Upload your model to HuggingFaceHub + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_3_upload_a_model.py + :img-top: /images/hf-logo.svg + :img-alt: Upload your model + :class-card: gallery-card + :text-align: center + Comparison tutorial ------------------- diff --git a/examples/tutorials/curation/README.rst b/examples/tutorials/curation/README.rst new file mode 100644 index 0000000000..0f64179e65 --- /dev/null +++ b/examples/tutorials/curation/README.rst @@ -0,0 +1,5 @@ +Curation tutorials +------------------ + +Learn how to use models to automatically curated your sorted data, or generate models +based on your own curation. diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py new file mode 100644 index 0000000000..e88b0973df --- /dev/null +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -0,0 +1,287 @@ +""" +Model-based curation tutorial +============================= + +Sorters are not perfect. They output excellent units, as well as noisy ones, and ones that +should be split or merged. Hence one should curate the generated units. Historically, this +has been done using laborious manual curation. An alternative is to use automated methods +based on metrics which quantify features of the units. In spikeinterface these are the +quality metrics and the template metrics. A simple approach is to use thresholding: +only accept units whose metrics pass a certain quality threshold. Another approach is to +take one (or more) manually labelled sortings, whose metrics have been computed, and train +a machine learning model to predict labels. + +This notebook provides a step-by-step guide on how to take a machine learning model that +someone else has trained and use it to curate your own spike sorted output. SpikeInterface +also provides the tools to train your own model, +`which you can learn about here `_. + +We'll download a toy model and use it to label our sorted data. We start by importing some packages +""" + +import warnings +warnings.filterwarnings("ignore") +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import spikeinterface.core as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# note: you can use more cores using e.g. +# si.set_global_jobs_kwargs(n_jobs = 8) + +############################################################################## +# Download a pretrained model +# --------------------------- +# +# Let's download a pretrained model from `Hugging Face `_ (HF), +# a model sharing platform focused on AI and ML models and datasets. The +# ``load_model`` function allows us to download directly from HF, or use a model in a local +# folder. 
The function downloads the model and saves it in a temporary folder and returns a +# model and some metadata about the model. + +model, model_info = sc.load_model( + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + + +############################################################################## +# This model was trained on artifically generated tetrode data. There are also models trained +# on real data, like the one discussed `below <#A-model-trained-on-real-Neuropixels-data>`_. +# Each model object has a nice html representation, which will appear if you're using a Jupyter notebook. + +model + +############################################################################## +# This tells us more information about the model. The one we've just downloaded was trained used +# a ``RandomForestClassifier```. You can also discover this information by running +# ``model.get_params()``. The model object (an `sklearn Pipeline `_) also contains information +# about which metrics were used to compute the model. We can access it from the model (or from the model_info) + +print(model.feature_names_in_) + +############################################################################## +# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed. +# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a +# bunch of extensions. Follow these links for more info on `recordings `_, `sortings `_, `sorting analyzers `_ +# and `extensions `_. + +recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10) +sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording) +sorting_analyzer.compute(['noise_levels','random_spikes','waveforms','templates','spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) +sorting_analyzer.compute('template_metrics', include_multi_channel_metrics=True) + +############################################################################## +# This sorting_analyzer now contains the required quality metrics and template metrics. +# We can check that this is true by accessing the extension data. + +all_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys()) + list(sorting_analyzer.get_extension('template_metrics').get_data().keys()) +print(set(all_metric_names) == set(model.feature_names_in_)) + +############################################################################## +# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly +# to the ``auto_label_units`` function. This returns a dictionary containing a label and +# a confidence for each unit contained in the ``sorting_analyzer``. + +labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + +print(labels) + + +############################################################################## +# The model has labelled one unit as bad. Let's look at that one, and also the 'good' unit +# with the highest confidence of being 'good'. + +sw.plot_unit_templates(sorting_analyzer, unit_ids=['7','9']) + +############################################################################## +# Nice! Unit 9 looks more like an expected action potential waveform while unit 7 doesn't, +# and it seems reasonable that unit 7 is labelled as `bad`. 
However, for certain experiments +# or brain areas, unit 7 might be a great small-amplitude unit. This example highlights that +# you should be careful applying models trained on one dataset to your own dataset. You can +# explore the currently available models on the `spikeinterface hugging face hub `_ +# page, or `train your own one `_. +# +# Assess the model performance +# ---------------------------- +# +# To assess the performance of the model relative to labels assigned by a human creator, we can load or generate some +# "human labels", and plot a confusion matrix of predicted vs human labels for all clusters. Here +# we'll be a conservative human, who has labelled several units with small amplitudes as 'bad'. + +human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good'] + +# Note: if you labelled using phy, you can load the labels using: +# human_labels = sorting_analyzer.sorting.get_property('quality') +# We need to load in the `label_conversion` dictionary, which converts integers such +# as '0' and '1' to readable labels such as 'good' and 'bad'. This is stored as +# in `model_info`, which we loaded earlier. + +from sklearn.metrics import confusion_matrix, balanced_accuracy_score + +label_conversion = model_info['label_conversion'] +predictions = labels['prediction'] + +conf_matrix = confusion_matrix(human_labels, predictions) + +# Calculate balanced accuracy for the confusion matrix +balanced_accuracy = balanced_accuracy_score(human_labels, predictions) + +plt.imshow(conf_matrix) +for (index, value) in np.ndenumerate(conf_matrix): + plt.annotate( str(value), xy=index, color="white", fontsize="15") +plt.xlabel('Predicted Label') +plt.ylabel('Human Label') +plt.xticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.yticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.title('Predicted vs Human Label') +plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}") +plt.show() + + +############################################################################## +# Here, there are several false positives (if we consider the human labels to be "the truth"). +# +# Next, we can also see how the model's confidence relates to the probability that the model +# label matches the human label. +# +# This could be used to help decide which units should be auto-curated and which need further +# manual creation. For example, we might accept any unit as 'good' that the model predicts +# as 'good' with confidence over a threshold, say 80%. If the confidence is lower we might decide to take a +# look at this unit manually. Below, we will create a plot that shows how the agreement +# between human and model labels changes as we increase the confidence threshold. We see that +# the agreement increases as the confidence does. So the model gets more accurate with a +# higher confidence threshold, as expceted. 
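# For example, a minimal sketch of the thresholding idea mentioned above: accept units the
# model labels 'good' with probability above 0.80 (an arbitrary example cut-off), and flag
# the rest for manual review.

confident_good_units = labels[(labels["prediction"] == "good") & (labels["probability"] > 0.80)]
units_to_review = labels.drop(confident_good_units.index)
print("Auto-accepted units:", list(confident_good_units.index))
print("Units needing manual review:", list(units_to_review.index))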
+ + +def calculate_moving_avg(label_df, confidence_label, window_size): + + label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop') + # Group by decile and calculate the proportion of correct labels (agreement) + p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean() + # Convert decile to range 0-1 + p_label_grouped.index = p_label_grouped.index / 10 + # Sort the DataFrame by confidence scores + label_df_sorted = label_df.sort_values(by=confidence_label) + + p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean() + + return label_df_sorted[confidence_label], p_label_moving_avg + +confidences = labels['probability'] + +# Make dataframe of human label, model label, and confidence +label_df = pd.DataFrame(data = { + 'human_label': human_labels, + 'decoder_label': predictions, + 'confidence': confidences}, + index = sorting_analyzer.sorting.get_unit_ids()) + +# Calculate the proportion of agreed labels by confidence decile +label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label'] + +p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3) + +# Plot the moving average of agreement +plt.figure(figsize=(6, 6)) +plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average') +plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance') +plt.xlabel('Confidence'); #plt.xlim(0.5, 1) +plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1) +plt.title('Agreement vs Confidence (Moving Average)') +plt.legend(); plt.grid(True); plt.show() + +############################################################################## +# In this case, you might decide to only trust labels which had confidence over above 0.88, +# and manually labels the ones the model isn't so confident about. +# +# A model trained on real Neuropixels data +# ---------------------------------------- +# +# Above, we used a toy model trained on generated data. There are also models on HuggingFace +# trained on real data. +# +# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in +# V1,SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and +# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into +# `noise` or `not-noise` and the other will classify the `not-noise` units into single +# unit activity (sua) units and multi-unit activity (mua) units. +# +# There is more information about the model on the model's HuggingFace page. Take a look! +# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one. 
+# We can do so as follows: +# + +# Apply the noise/not-noise model +noise_neuron_labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "AnoushkaJain3/noise_neural_classifier", + trust_model=True, +) + +noise_units = noise_neuron_labels[noise_neuron_labels['prediction']=='noise'] +analyzer_neural = sorting_analyzer.remove_units(noise_units.index) + +# Apply the sua/mua model +sua_mua_labels = sc.auto_label_units( + sorting_analyzer = analyzer_neural, + repo_id = "AnoushkaJain3/sua_mua_classifier", + trust_model=True, +) + +all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index() +print(all_labels) + +############################################################################## +# If you run this without the ``trust_model=True`` parameter, you will receive an error: +# +# .. code-block:: +# +# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold'] +# +# This is a security warning, which can be overcome by passing the trusted types list +# ``trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']`` +# or by passing the ``trust_model=True``` keyword. +# +# .. dropdown:: More about security +# +# Sharing models, with are Python objects, is complicated. +# We have chosen to use the `skops format `_, instead +# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues +# `here `_). While unpacking the ``.skops`` file, each function +# is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and +# allow the object to be loaded if it only contains these (and no unkown malicious code). But +# when ``skops`` it's not sure, it raises an error. Here, it doesn't recognise +# ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', +# 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', +# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions +# from `sklearn`, and we can happily add them to the ``trusted`` functions to load. +# +# In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos, +# especially from unknown sources. +# +# Directly applying a sklearn Pipeline +# ------------------------------------ +# +# Instead of using ``HuggingFace`` and ``skops``, someone might have given you a model +# in differet way: perhaps by e-mail or a download. If you have the model in a +# folder, you can apply it in a very similar way: +# +# .. code-block:: +# +# labels = sc.auto_label_units( +# sorting_analyzer = sorting_analyzer, +# model_folder = "path/to/model/folder", +# ) + +############################################################################## +# Using this, you lose the advantages of the model metadata: the quality metric parameters +# are not checked and the labels are not converted their original human readable names (like +# 'good' and 'bad'). Hence we advise using the methods discussed above, when possible. 
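##############################################################################
# That said, if the model folder you were given also contains a ``model_info.json``
# file (folders produced by ``train_model`` do), you can recover the metadata yourself
# by loading the model with ``load_model`` and inspecting the returned ``model_info``.
# This is a minimal sketch, reusing the folder path and trusted types from above:
#
# .. code-block::
#
#     model, model_info = sc.load_model(
#         model_folder = "path/to/model/folder",
#         trusted = ['numpy.dtype'],
#     )
#
#     # mapping from the model's integer outputs to labels such as 'good' and 'bad'
#     print(model_info['label_conversion'])
#     # parameters used to compute the metrics the model was trained on
#     print(model_info['metric_params'])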
diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py new file mode 100644 index 0000000000..1a38836527 --- /dev/null +++ b/examples/tutorials/curation/plot_2_train_a_model.py @@ -0,0 +1,168 @@ +""" +Training a model for automated curation +============================= + +If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface. +""" + + +############################################################################## +# Step 1: Generate and label data +# ------------------------------- +# +# First we will import our dependencies +import warnings +warnings.filterwarnings("ignore") +from pathlib import Path +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import spikeinterface.core as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# Note, you can set the number of cores you use using e.g. +# si.set_global_job_kwargs(n_jobs = 8) + +############################################################################## +# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll +# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will +# perfectly match the spikes in the recording. Hence this will contain good units. However, we've +# uncoupled :code:`sorting_2` to the recording and the spike times will not be matched with the spikes in the recording. +# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting +# object using :code:`si.aggregate_units`. +# +# (When making your own model, you should +# `load your own recording `_ +# and `do a sorting `_ on your data.) + +recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5) +_, sorting_2 =si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5) + +both_sortings = si.aggregate_units([sorting_1, sorting_2]) + +############################################################################## +# To do some visualisation and postprocessing, we need to create a sorting analyzer, and +# compute some extensions: + +analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording) +analyzer.compute(['noise_levels','random_spikes','waveforms','templates']) + +############################################################################## +# Now we can plot the templates for the first and fifth units. The first (unit id 0) belongs to +# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2` +# so should look like noise. + +sw.plot_unit_templates(analyzer, unit_ids=["0", "5"]) + +############################################################################## +# This is as expected: great! (Find out more about plotting using widgets `here `_.) +# We've set up our system so that the first five units are 'good' and the next five are 'bad'. +# So we can make a list of labels which contain this information. For real data, you could +# use a manual curation tool to make your own list. + +labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad'] + +############################################################################## +# Step 2: Train our model +# ----------------------- +# +# We'll now train a model, based on our labelled data. 
The model will be trained using properties +# of the units, and then be applied to units from other sortings. The properties we use are the +# `quality metrics `_ +# and `template metrics `_. +# Hence we need to compute these, using some ``sorting_analyzer``` extensions. + +analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) + +############################################################################## +# Now that we have metrics and labels, we're ready to train the model using the +# ``train_model``` function. The trainer will try several classifiers, imputation strategies and +# scaling techniques then save the most accurate. To save time in this tutorial, +# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling +# technique (standard scaler). +# +# We will use a list of one analyzer here, so the model is trained on a single +# session. In reality, we would usually train a model using multiple analyzers from an +# experiment, which should make the model more robust. To do this, you can simply pass +# a list of analyzers and a list of manually curated labels for each +# of these analyzers. Then the model would use all of these data as input. + +trainer = sc.train_model( + mode = "analyzers", # You can supply a labelled csv file instead of an analyzer + labels = [labels], + analyzers = [analyzer], + folder = "my_folder", # Where to save the model and model_info.json file + metric_names = None, # Specify which metrics to use for training: by default uses those already calculted + imputation_strategies = ["median"], # Defaults to all + scaling_techniques = ["standard_scaler"], # Defaults to all + classifiers = None, # Default to Random Forest only. Other classifiers you can try [ "AdaBoostClassifier","GradientBoostingClassifier","LogisticRegression","MLPClassifier"] + overwrite = True, # Whether or not to overwrite `folder` if it already exists. Default is False. + search_kwargs = {'cv': 3} # Parameters used during the model hyperparameter search +) + +best_model = trainer.best_pipeline + +############################################################################## +# +# You can pass many sklearn `classifiers `_ +# `imputation strategies `_ and +# `scalers `_, although the +# documentation is quite overwhelming. You can find the classifiers we've tried out +# using the ``sc.get_default_classifier_search_spaces`` function. +# +# The above code saves the model in ``model.skops``, some metadata in +# ``model_info.json`` and the model accuracies in ``model_accuracies.csv`` +# in the specified ``folder`` (in this case ``'my_folder'``). +# +# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more `_.) +# +# The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the +# tested models. Let's take a look: + +accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0) +accuracies.head() + +############################################################################## +# Our model is perfect!! This is because the task was *very* easy. We had 10 units; where +# half were pure noise and half were not. +# +# The model also contains some more information, such as which features are "important", +# as defined by sklearn (learn about feature importance of a Random Forest Classifier +# `here `_.) 
+# We can plot these: + +# Plot feature importances +importances = best_model.named_steps['classifier'].feature_importances_ +indices = np.argsort(importances)[::-1] + +# The sklearn importances are not computed for inputs whose values are all `nan`. +# Hence, we need to pick out the non-`nan` columns of our metrics +features = best_model.feature_names_in_ +n_features = best_model.n_features_in_ + +metrics = pd.concat([analyzer.get_extension('quality_metrics').get_data(), analyzer.get_extension('template_metrics').get_data()], axis=1) +non_null_metrics = ~(metrics.isnull().all()).values + +features = features[non_null_metrics] +n_features = len(features) + +plt.figure(figsize=(12, 7)) +plt.title("Feature Importances") +plt.bar(range(n_features), importances[indices], align="center") +plt.xticks(range(n_features), features[indices], rotation=90) +plt.xlim([-1, n_features]) +plt.subplots_adjust(bottom=0.3) +plt.show() + +############################################################################## +# Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio" +# but is not using "sync_spike_4" and "rp_contanimation". This is a toy model, so don't +# take these results seriously. But using this information, you could retrain another, +# simpler model using a subset of the metrics, by passing, e.g., +# ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to the ``train_model`` function. +# +# Now that you have a model, you can `apply it to another sorting +# `_ +# or `upload it to HuggingFaceHub `_. diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py new file mode 100644 index 0000000000..0a9ea402db --- /dev/null +++ b/examples/tutorials/curation/plot_3_upload_a_model.py @@ -0,0 +1,139 @@ +""" +Upload a pipeline to Hugging Face Hub +===================================== +""" +############################################################################## +# In this tutorial we will upload a pipeline, trained in SpikeInterface, to the +# `Hugging Face Hub `_ (HFH). +# +# To do this, you first need to train a model. `Learn how here! `_ +# +# Hugging Face Hub? +# ----------------- +# Hugging Face Hub (HFH) is a model sharing platform focused on AI and ML models and datasets. +# To upload your own model to HFH, you need to make an account with them. +# If you do not want to make an account, you can simply share the model folder with colleagues. +# There are also several ways to interaction with HFH: the way we propose here doesn't use +# many of the tools ``skops`` and hugging face have developed such as the ``Card`` and +# ``hub_utils``. Feel free to check those out `here `_. +# +# Prepare your model +# ------------------ +# +# The plan is to make a folder with the following file structure +# +# .. code-block:: +# +# my_model_folder/ +# my_model_name.skops +# model_info.json +# training_data.csv +# labels.csv +# metadata.json +# +# SpikeInterface and HFH don't require you to keep this folder structure, we just advise it as +# best practice. +# +# If you've used SpikeInterface to train your model, the ``train_model`` function auto-generates +# most of this data. The only thing missing is the the ``metadata.json`` file. The purpose of this +# file is to detail how the model was trained, which can help prospective users decide if it +# is relevant for them. For example, taking +# a model trained on mouse data and applying it to a primate is likely a bad idea (or a +# great research paper!). 
And a model trained using tetrode data might have limited application +# on a silcone high-density probes. Hence we suggest saving at least the species, brain areas +# and probe information, as is done in the dictionary below. Note that we format the metadata +# so that the information +# in common with the NWB data format is consistent with it. Since the models can be trained +# on several curations, all the metadata fields are lists: +# +# .. code-block:: +# +# import json +# +# model_metadata = { +# "subject_species": ["Mus musculus"], +# "brain_areas": ["CA1"], +# "probes": +# [{ +# "manufacturer": "IMEc", +# "name": "Neuropixels 2.0" +# }] +# } +# with open("my_model_folder/metadata.json", "w") as file: +# json.dump(model_metadata, file) +# +# Upload to HuggingFaceHub +# ------------------------ +# +# We'll now upload this folder to HFH using the web interface. +# +# First, go to https://huggingface.co/ and make an account. Once you've logged in, press +# ``+`` then ``New model`` or find ``+ New Model`` in the user menu. You will be asked +# to enter a model name, to choose a license for the model and whether the model should +# be public or private. After you have made these choices, press ``Create Model``. +# +# You should be on your model's landing page, whose header looks something like +# +# .. image:: ../../images/initial_model_screen.png +# :width: 550 +# :align: center +# :alt: The page shown on HuggingFaceHub when a user first initialises a model +# +# Click Files, then ``+ Add file`` then ``Upload file(s)``. You can then add your files to the repository. Upload these by pressing ``Commit changes to main``. +# +# You are returned to the Files page, which should look similar to +# +# .. image:: ../../images/files_screen.png +# :width: 700 +# :align: center +# :alt: The file list for a model HuggingFaceHub. +# +# Let's add some information about the model for users to see when they go on your model's +# page. Click on ``Model card`` then ``Edit model card``. Here is a sample model card for +# For a model based on synthetically generated tetrode data, +# +# .. code-block:: +# +# --- +# license: mit +# --- +# +# ## Model description +# +# A toy model, trained on toy data generated from spikeinterface. +# +# # Intended use +# +# Used to try out automated curation in SpikeInterface. +# +# # How to Get Started with the Model +# +# This can be used to automatically label a sorting in spikeinterface. Provided you have a `sorting_analyzer`, it is used as follows +# +# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading) +# +# from spikeinterface.curation import auto_label_units +# labels = auto_label_units( +# sorting_analyzer = sorting_analyzer, +# repo_id = "SpikeInterface/toy_tetrode_model", +# trust_model=True +# ) +# ` ` ` +# +# or you can download the entire repositry to `a_folder_for_a_model`, and use +# +# ` ` ` python +# from spikeinterface.curation import auto_label_units +# +# labels = auto_label_units( +# sorting_analyzer = sorting_analyzer, +# model_folder = "path/to/a_folder_for_a_model", +# trusted = ['numpy.dtype'] +# ) +# ` ` ` +# +# # Authors +# +# Chris Halcrow +# +# You can see the repo with this Model card `here `_. 
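##############################################################################
# Uploading from Python
# ---------------------
#
# If you prefer not to use the web interface, the same upload can be scripted with the
# ``huggingface_hub`` package (included in SpikeInterface's ``full`` installation).
# This is only a rough sketch; replace the repo id and folder path with your own, and
# note that you will need a Hugging Face access token:
#
# .. code-block::
#
#     from huggingface_hub import HfApi, login
#
#     login()  # paste your Hugging Face access token when prompted
#
#     api = HfApi()
#     api.create_repo(repo_id="your_username/my_model", repo_type="model", exist_ok=True)
#     api.upload_folder(
#         folder_path="my_model_folder",
#         repo_id="your_username/my_model",
#         repo_type="model",
#     )
#
# You can then add or edit the model card on the website, as described above.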
diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_mertics.py b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py similarity index 100% rename from examples/tutorials/qualitymetrics/plot_3_quality_mertics.py rename to examples/tutorials/qualitymetrics/plot_3_quality_metrics.py diff --git a/pyproject.toml b/pyproject.toml index 22fbdc7f22..0b2f06049f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,8 @@ full = [ "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", "numba", + "skops", + "huggingface_hub" ] widgets = [ @@ -171,6 +173,10 @@ test = [ "torch", "pynndescent", + # curation + "skops", + "huggingface_hub", + # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", @@ -192,6 +198,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions "networkx", + "skops", # For auotmated curation + "scikit-learn", # For auotmated curation # Download data "pooch>=1.8.2", "datalad>=1.0.2", diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 0302ffe5b7..975f2fe22f 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -15,3 +15,7 @@ from .curation_format import validate_curation_dict, curation_label_to_dataframe, apply_curation from .sortingview_curation import apply_sortingview_curation + +# automated curation +from .model_based_curation import auto_label_units, load_model +from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py new file mode 100644 index 0000000000..93ad03734c --- /dev/null +++ b/src/spikeinterface/curation/model_based_curation.py @@ -0,0 +1,435 @@ +import numpy as np +from pathlib import Path +import json +import warnings +import re + +from spikeinterface.core import SortingAnalyzer +from spikeinterface.curation.train_manual_curation import ( + try_to_get_metrics_from_analyzer, + _get_computed_metrics, + _format_metric_dataframe, +) +from copy import deepcopy + + +class ModelBasedClassification: + """ + Class for performing model-based classification on spike sorting data. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + + Attributes + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + required_metrics : Sequence[str] + The list of required metrics for classification, extracted from the pipeline. + + Methods + ------- + predict_labels() + Predicts the labels for the spike sorting data using the trained model. 
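Examples
--------
A minimal sketch, assuming ``sorting_analyzer`` already has quality and template
metrics computed and that ``model`` and ``model_info`` were returned by ``load_model``::

    classification = ModelBasedClassification(sorting_analyzer, model)
    classified_units = classification.predict_labels(model_info=model_info)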
+ """ + + def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline): + from sklearn.pipeline import Pipeline + + if not isinstance(pipeline, Pipeline): + raise ValueError("The `pipeline` must be an instance of sklearn.pipeline.Pipeline") + + self.sorting_analyzer = sorting_analyzer + self.pipeline = pipeline + self.required_metrics = pipeline.feature_names_in_ + + def predict_labels( + self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False + ): + """ + Predicts the labels for the spike sorting data using the trained model. + Populates the sorting object with the predicted labels and probabilities as unit properties + + Parameters + ---------- + model_info : dict or None, default: None + Model info, generated with model, used to check metric parameters used to train it. + label_conversion : dict or None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to find in `model_info` file. The dictionary should have the format {old_label: new_label}. + input_data : pandas.DataFrame or None, default: None + The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer. + export_to_phy : bool, default: False. + Whether to export the classified units to Phy format. Default is False. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + Returns + ------- + pd.DataFrame + A dataframe containing the classified units and their corresponding predictions and probabilities, + indexed by their `unit_ids`. 
+ """ + import pandas as pd + + # Get metrics DataFrame for classification + if input_data is None: + input_data = _get_computed_metrics(self.sorting_analyzer) + else: + if not isinstance(input_data, pd.DataFrame): + raise ValueError("Input data must be a pandas DataFrame") + + input_data = self._check_required_metrics_are_present(input_data) + + if model_info is not None: + self._check_params_for_classification(enforce_metric_params, model_info=model_info) + + if model_info is not None and label_conversion is None: + try: + string_label_conversion = model_info["label_conversion"] + # json keys are strings; we convert these to ints + label_conversion = {} + for key, value in string_label_conversion.items(): + label_conversion[int(key)] = value + except: + warnings.warn("Could not find `label_conversion` key in `model_info.json` file") + + input_data = _format_metric_dataframe(input_data) + + # Apply classifier + predictions = self.pipeline.predict(input_data) + probabilities = self.pipeline.predict_proba(input_data) + probabilities = np.max(probabilities, axis=1) + + if isinstance(label_conversion, dict): + + if set(predictions).issubset(set(label_conversion.keys())) is False: + raise ValueError("Labels in predictions do not match those in label_conversion") + predictions = [label_conversion[label] for label in predictions] + + classified_units = pd.DataFrame( + zip(predictions, probabilities), columns=["prediction", "probability"], index=self.sorting_analyzer.unit_ids + ) + + # Set predictions and probability as sorting properties + self.sorting_analyzer.sorting.set_property("classifier_label", predictions) + self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities) + + if export_to_phy: + self._export_to_phy(classified_units) + + return classified_units + + def _check_required_metrics_are_present(self, calculated_metrics): + + # Check all the required metrics have been calculated + required_metrics = set(self.required_metrics) + if required_metrics.issubset(set(calculated_metrics)): + input_data = calculated_metrics[self.required_metrics] + else: + raise ValueError( + "Input data does not contain all required metrics for classification", + f"Missing metrics: {required_metrics.difference(calculated_metrics)}", + ) + + return input_data + + def _check_params_for_classification(self, enforce_metric_params=False, model_info=None): + """ + Check that quality and template metrics parameters match those used to train the model + + Parameters + ---------- + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + model_info : dict, default: None + Dictionary of model info containing provenance of the model. 
+ """ + + extension_names = ["quality_metrics", "template_metrics"] + + metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + model_extension_params = model_info["metric_params"].get(extension_name + "_params") + + if metric_extension is not None and model_extension_params is not None: + + metric_params = metric_extension.params["metric_params"] + + inconsistent_metrics = [] + for metric in model_extension_params["metric_names"]: + model_metric_params = model_extension_params.get("metric_params") + if model_metric_params is None or metric not in model_metric_params: + inconsistent_metrics.append(metric) + else: + if metric_params[metric] != model_metric_params[metric]: + warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + if len(inconsistent_metrics) > 0: + warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + def _export_to_phy(self, classified_units): + """Export the classified units to Phy as cluster_prediction.tsv file""" + + import pandas as pd + + # Create a new DataFrame with unit_id, prediction, and probability columns from dict {unit_id: (prediction, probability)} + classified_df = pd.DataFrame.from_dict(classified_units, orient="index", columns=["prediction", "probability"]) + + # Export to Phy format + try: + sorting_path = self.sorting_analyzer.sorting.get_annotation("phy_folder") + assert sorting_path is not None + assert Path(sorting_path).is_dir() + except AssertionError: + raise ValueError("Phy folder not found in sorting annotations, or is not a directory") + + classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id") + + +def auto_label_units( + sorting_analyzer: SortingAnalyzer, + model_folder=None, + model_name=None, + repo_id=None, + label_conversion=None, + trust_model=False, + trusted=None, + export_to_phy=False, + enforce_metric_params=False, +): + """ + Automatically labels units based on a model-based classification, either from a model + hosted on HuggingFaceHub or one available in a local folder. + + This function returns the predicted labels and the prediction probabilities, and populates + the sorting object with the predicted labels and probabilities in the 'classifier_label' and + 'classifier_probability' properties. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting results. + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + label_conversion : dic | None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to extract from `model_info.json` file. 
The dictionary should have the format {old_label: new_label}. + export_to_phy : bool, default: False + Whether to export the results to Phy format. Default is False. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + + Returns + ------- + classified_units : pd.DataFrame + A dataframe containing the classified units, indexed by the `unit_ids`, containing the predicted label + and confidence probability of each labelled unit. + + Raises + ------ + ValueError + If the pipeline is not an instance of sklearn.pipeline.Pipeline. + + """ + from sklearn.pipeline import Pipeline + + model, model_info = load_model( + model_folder=model_folder, repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + if not isinstance(model, Pipeline): + raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline") + + model_based_classification = ModelBasedClassification(sorting_analyzer, model) + + classified_units = model_based_classification.predict_labels( + label_conversion=label_conversion, + export_to_phy=export_to_phy, + model_info=model_info, + enforce_metric_params=enforce_metric_params, + ) + + return classified_units + + +def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a HuggingFaceHub repo or a local folder. + + Parameters + ---------- + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. 
+ + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + if model_folder is None and repo_id is None: + raise ValueError("Please provide a 'model_folder' or a 'repo_id'.") + elif model_folder is not None and repo_id is not None: + raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.") + elif model_folder is not None: + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + else: + model, model_info = _load_model_from_huggingface( + repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_huggingface(repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model from a huggingface repo + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + from huggingface_hub import list_repo_files + from huggingface_hub import hf_hub_download + + # get repo filenames + repo_filenames = list_repo_files(repo_id=repo_id) + + # download all skops and json files to temp directory + for filename in repo_filenames: + if Path(filename).suffix in [".skops", ".json"]: + full_path = hf_hub_download(repo_id=repo_id, filename=filename) + model_folder = Path(full_path).parent + + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_folder(model_folder=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a folder + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + import skops.io as skio + from skops.io.exceptions import UntrustedTypesFoundException + + folder = Path(model_folder) + assert folder.is_dir(), f"The folder {folder}, does not exist." + + # look for any .skops files + skops_files = list(folder.glob("*.skops")) + assert len(skops_files) > 0, f"There are no '.skops' files in the folder {folder}" + + if len(skops_files) > 1: + if model_name is None: + model_names = [f.name for f in skops_files] + raise ValueError( + f"There are more than 1 '.skops' file in folder {folder}. You have to specify " + f"the file using the 'model_name' argument. Available files:\n{model_names}" + ) + else: + skops_file = folder / Path(model_name) + assert skops_file.is_file(), f"Model file {skops_file} not found." + elif len(skops_files) == 1: + skops_file = skops_files[0] + + if trust_model and trusted is None: + try: + model = skio.load(skops_file) + except UntrustedTypesFoundException as e: + exception_msg = str(e) + # the exception message contains the list of untrusted objects. The following + # search assumes it is the only list in the message. + string_list = re.search(r"\[(.*?)\]", exception_msg).group() + trusted = [list_item for list_item in string_list.split("'") if len(list_item) > 2] + + model = skio.load(skops_file, trusted=trusted) + + model_info_path = folder / "model_info.json" + if not model_info_path.is_file(): + warnings.warn("No 'model_info.json' file found in folder. 
No metadata can be checked.") + model_info = None + else: + model_info = json.load(open(model_info_path)) + + model_info = handle_backwards_compatibility_metric_params(model_info) + + return model, model_info + + +def handle_backwards_compatibility_metric_params(model_info): + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("quality_metric_params") is not None + ): + if (qm_params := model_info["metric_params"]["quality_metric_params"].get("qm_params")) is not None: + model_info["metric_params"]["quality_metric_params"]["metric_params"] = qm_params + del model_info["metric_params"]["quality_metric_params"]["qm_params"] + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("template_metric_params") is not None + ): + if (tm_params := model_info["metric_params"]["template_metric_params"].get("metrics_kwargs")) is not None: + metric_params = {} + for metric_name in model_info["metric_params"]["template_metric_params"].get("metric_names"): + metric_params[metric_name] = deepcopy(tm_params) + model_info["metric_params"]["template_metric_params"]["metric_params"] = metric_params + del model_info["metric_params"]["template_metric_params"]["metrics_kwargs"] + + return model_info diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py new file mode 100644 index 0000000000..3683b417df --- /dev/null +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -0,0 +1,167 @@ +import pytest +from pathlib import Path +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.model_based_curation import ModelBasedClassification +from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation.train_manual_curation import _get_computed_metrics + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "curation" +else: + cache_folder = Path("cache_folder") / "curation" + + +@pytest.fixture +def model(): + """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. + It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with + the following labels: [1,0,1,0,1].""" + + model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) + return model + + +@pytest.fixture +def required_metrics(): + """These are the metrics which `model` are trained on.""" + return ["num_spikes", "snr", "half_width"] + + +def test_model_based_classification_init(sorting_analyzer_for_curation, model): + """Test that the ModelBasedClassification attributes are correctly initialised""" + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation + assert model_based_classification.pipeline == model[0] + assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) + + +def test_metric_ordering_independence(sorting_analyzer_for_curation, model): + """The function `auto_label_units` needs the correct metrics to have been computed. However, + it should be independent of the order of computation. 
We test this here.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + prediction_prob_dataframe_1 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) + + prediction_prob_dataframe_2 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2) + + +def test_model_based_classification_get_metrics_for_classification( + sorting_analyzer_for_curation, model, required_metrics +): + """If the user has not computed the required metrics, an error should be returned. + This test checks that an error occurs when the required metrics have not been computed, + and that no error is returned when the required metrics have been computed. + """ + + sorting_analyzer_for_curation.delete_extension("quality_metrics") + sorting_analyzer_for_curation.delete_extension("template_metrics") + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + + # Check that ValueError is returned when no metrics are present in sorting_analyzer + with pytest.raises(ValueError): + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + + # Compute some (but not all) of the required metrics in sorting_analyzer, should still error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + with pytest.raises(ValueError): + model_based_classification._check_required_metrics_are_present(computed_metrics) + + # Compute all of the required metrics in sorting_analyzer, no more error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) + + metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) + assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) + assert set(metrics_data.columns.to_list()) == set(required_metrics) + + +def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): + # Test the _export_to_phy() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = {0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)} + # Function should fail here + with pytest.raises(ValueError): + model_based_classification._export_to_phy(classified_units) + # Make temp output folder and set as phy_folder + phy_folder = cache_folder / "phy_folder" + phy_folder.mkdir(parents=True, exist_ok=True) + + model_based_classification.sorting_analyzer.sorting.annotate(phy_folder=phy_folder) + model_based_classification._export_to_phy(classified_units) + assert (phy_folder / "cluster_prediction.tsv").exists() + + +def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model): + """The model `model` has been trained on the `sorting_analyzer` used in this test with + the labels `[1, 0, 1, 0, 1]`. 
Hence if we apply the model to this `sorting_analyzer` + we expect these labels to be outputted. The test checks this, and also checks + that label conversion works as expected.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + # Test the predict_labels() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = model_based_classification.predict_labels() + predictions = classified_units["prediction"].values + + assert np.all(predictions == np.array([1, 0, 1, 0, 1])) + + conversion = {0: "noise", 1: "good"} + classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion) + predictions_labelled = classified_units_labelled["prediction"] + assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) + + +def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): + """We track whether the metric parameters used to compute the metrics used to train + a model are the same as the parameters used to compute the metrics in the sorting + analyzer which is being curated. If they are different, an error or warning will + be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" + + sorting_analyzer_for_curation.compute( + "quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + ) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + + # an error should be raised if `enforce_metric_params` is True + with pytest.raises(Exception): + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + # but only a warning if `enforce_metric_params` is False + with pytest.warns(UserWarning): + model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) + + # Now test the positive case. 
Recompute using the default parameters + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py new file mode 100644 index 0000000000..f455fbdb9c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -0,0 +1,285 @@ +import pytest +import numpy as np +import tempfile, csv +from pathlib import Path + +from spikeinterface.curation.tests.common import make_sorting_analyzer +from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model + + +@pytest.fixture +def trainer(): + """A simple CurationModelTrainer object is created, which can later by used to + train models using data from `sorting_analyzer`s.""" + + folder = tempfile.mkdtemp() # Create a temporary output folder + imputation_strategies = ["median"] + scaling_techniques = ["standard_scaler"] + classifiers = ["LogisticRegression"] + metric_names = ["metric1", "metric2", "metric3"] + search_kwargs = {"cv": 3} + return CurationModelTrainer( + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + search_kwargs=search_kwargs, + ) + + +def make_temp_training_csv(): + """Create a temporary CSV file with artificially generated quality metrics. + The data is designed to be easy to dicern between units. Even units metric + values are all `0`, while odd units metric values are all `1`. + """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + writer = csv.writer(temp_file) + writer.writerow(["unit_id", "metric1", "metric2", "metric3"]) + for i in range(5): + writer.writerow([i * 2, 0, 0, 0]) + writer.writerow([i * 2 + 1, 1, 1, 1]) + return temp_file.name + + +def test_load_and_preprocess_full(trainer): + """Check that we load and preprocess the csv file from `make_temp_training_csv` + correctly.""" + temp_file_path = make_temp_training_csv() + + # Load and preprocess the data from the temporary CSV file + trainer.load_and_preprocess_csv([temp_file_path]) + + # Assert that the data is loaded and preprocessed correctly + for a, row in trainer.X.iterrows(): + assert np.all(row.values == [float(a % 2)] * 3) + for a, label in enumerate(trainer.y.values): + assert label == a % 2 + for a, row in trainer.testing_metrics.iterrows(): + assert np.all(row.values == [a % 2] * 3) + assert row.name == a + + +def test_apply_scaling_imputation(trainer): + """Take a simple training and test set and check that they are corrected scaled, + using a standard scaler which rescales the training distribution to have mean 0 + and variance 1. Length between each row is 3, so if x0 is the first value in the + column, all other values are scaled as x -> 2/3(x - x0) - 1. 
The y (labled) values + do not get scaled.""" + + from sklearn.impute._knn import KNNImputer + from sklearn.preprocessing._data import StandardScaler + + imputation_strategy = "knn" + scaling_technique = "standard_scaler" + X_train = np.array([[1, 2, 3], [4, 5, 6]]) + X_test = np.array([[7, 8, 9], [10, 11, 12]]) + y_train = np.array([0, 1]) + y_test = np.array([2, 3]) + + X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation( + imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test + ) + + first_row_elements = X_train[0] + for a, row in enumerate(X_train): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a]) + for a, row in enumerate(X_test): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a]) + + assert np.all(y_train == y_train_scaled) + assert np.all(y_test == y_test_scaled) + + assert isinstance(imputer, KNNImputer) + assert isinstance(scaler, StandardScaler) + + +def test_get_classifier_search_space(trainer): + """For each classifier, there is a hyperparameter space we search over to find its + most accurate incarnation. Here, we check that we do indeed load the approprirate + dict of hyperparameter possibilities""" + + from sklearn.linear_model._logistic import LogisticRegression + + classifier = "LogisticRegression" + model, param_space = trainer.get_classifier_search_space(classifier) + + assert isinstance(model, LogisticRegression) + assert len(param_space) > 0 + assert isinstance(param_space, dict) + + +def test_get_custom_classifier_search_space(): + """Check that if a user passes a custom hyperparameter search space, that this is + passed correctly to the trainer.""" + + classifier = { + "LogisticRegression": { + "C": [0.1, 8.0], + "solver": ["lbfgs"], + "max_iter": [100, 400], + } + } + trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]) + + model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0]) + assert param_space == classifier["LogisticRegression"] + + +def test_saved_files(trainer): + """During the trainer's creation, the following files should be created: + - best_model.skops + - labels.csv + - model_accuracies.csv + - model_info.json + - training_data.csv + This test checks that these exist, and checks some properties of the files.""" + + import pandas as pd + import json + + trainer.X = np.random.rand(10, 3) + trainer.y = np.append(np.ones(5), np.zeros(5)) + + trainer.evaluate_model_config() + trainer_folder = Path(trainer.folder) + + assert trainer_folder.is_dir() + + best_model_path = trainer_folder / "best_model.skops" + model_accuracies_path = trainer_folder / "model_accuracies.csv" + training_data_path = trainer_folder / "training_data.csv" + labels_path = trainer_folder / "labels.csv" + model_info_path = trainer_folder / "model_info.json" + + assert (best_model_path).is_file() + + model_accuracies = pd.read_csv(model_accuracies_path) + model_accuracies["classifier name"].values[0] == "LogisticRegression" + assert len(model_accuracies) == 1 + + training_data = pd.read_csv(training_data_path) + assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10)) + + labels = pd.read_csv(labels_path) + assert np.all(labels.values[:, 1] == trainer.y.astype("float")) + + model_info = pd.read_json(model_info_path) + + with open(model_info_path) as f: + model_info = json.load(f) + + assert set(model_info.keys()) == set(["metric_params", "requirements", 
"label_conversion"]) + + +def test_train_model(): + """A simple function test to check that `train_model` doesn't fail with one csv inputs""" + + metrics_path = make_temp_training_csv() + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_model_using_two_csvs(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from csv files.""" + + metrics_path_1 = make_temp_training_csv() + metrics_path_2 = make_temp_training_csv() + + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path_1, metrics_path_2], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_using_two_sorting_analyzers(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from sorting analzyers. It also checks that + an error is raised if the sorting_analyzers have different sets of metrics computed.""" + + sorting_analyzer_1 = make_sorting_analyzer() + sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + labels_1 = [0, 1, 1, 1, 1] + labels_2 = [1, 1, 0, 1, 1] + + folder = tempfile.mkdtemp() + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + assert isinstance(trainer, CurationModelTrainer) + + # Check that there is an error raised if the metric names are different + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}}) + + with pytest.raises(Exception): + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + # Now check that there is an error raised if we demand the same metric params, but don't have them + + sorting_analyzer_2.compute( + { + "quality_metrics": { + "metric_names": ["num_spikes", "snr"], + "metric_params": {"snr": {"peak_mode": "at_index"}}, + } + } + ) + + with pytest.raises(Exception): + train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + 
search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + overwrite=True, + enforce_metric_params=True, + ) diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops new file mode 100644 index 0000000000..362405f917 Binary files /dev/null and b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops differ diff --git a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv new file mode 100644 index 0000000000..46680a9e89 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv @@ -0,0 +1,21 @@ +unit_index,0 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv new file mode 100644 index 0000000000..7f015c380b --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv @@ -0,0 +1,2 @@ +,classifier name,imputation_strategy,scaling_strategy,accuracy,precision,recall,model_id,best_params +0,LogisticRegression,median,StandardScaler(),1.0000,1.0000,1.0000,0,"OrderedDict([('C', 4.811707275233983), ('max_iter', 384), ('solver', 'saga')])" diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json new file mode 100644 index 0000000000..75ced28486 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json @@ -0,0 +1,60 @@ +{ + "metric_params": { + "quality_metric_params": { + "metric_names": [ + "snr", + "num_spikes" + ], + "peak_sign": null, + "seed": null, + "metric_params": { + "num_spikes": {}, + "snr": { + "peak_sign": "neg", + "peak_mode": "extremum" + } + }, + "skip_pc_metrics": false, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "snr", + "num_spikes" + ] + }, + "template_metric_params": { + "metric_names": [ + "half_width" + ], + "sparsity": null, + "peak_sign": "neg", + "upsampling_factor": 10, + "metric_params": { + "half_width": { + "recovery_window_ms": 0.7, + "peak_relative_threshold": 0.2, + "peak_width_ms": 0.1, + "depth_direction": "y", + "min_channels_for_velocity": 5, + "min_r2_velocity": 0.5, + "exp_peak_function": "ptp", + "min_r2_exp_decay": 0.5, + "spread_threshold": 0.2, + "spread_smooth_um": 20, + "column_range": null + } + }, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "half_width" + ] + } + }, + "requirements": { + "spikeinterface": "0.101.1", + "scikit-learn": "1.3.2" + }, + "label_conversion": { + "1": 1, + "0": 0 + } +} diff --git a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv new file mode 100644 index 0000000000..c9efca17ad --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv @@ -0,0 +1,21 @@ +unit_id,snr,num_spikes,half_width +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 
+2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py new file mode 100644 index 0000000000..7b315b0fba --- /dev/null +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -0,0 +1,843 @@ +import os +import warnings +import numpy as np +import json +import spikeinterface +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.qualitymetrics import ( + get_quality_metric_list, + get_quality_pca_metric_list, + qm_compute_name_to_column_names, +) +from spikeinterface.postprocessing import get_template_metric_names +from spikeinterface.postprocessing.template_metrics import tm_compute_name_to_column_names +from pathlib import Path +from copy import deepcopy + + +def get_default_classifier_search_spaces(): + + from scipy.stats import uniform, randint + + default_classifier_search_spaces = { + "RandomForestClassifier": { + "n_estimators": [100, 150], + "criterion": ["gini", "entropy"], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + "class_weight": ["balanced", "balanced_subsample"], + }, + "AdaBoostClassifier": { + "learning_rate": [1, 2], + "n_estimators": [50, 100], + "algorithm": ["SAMME", "SAMME.R"], + }, + "GradientBoostingClassifier": { + "learning_rate": uniform(0.05, 0.1), + "n_estimators": randint(100, 150), + "max_depth": [2, 4], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + }, + "SVC": { + "C": uniform(0.001, 10.0), + "kernel": ["sigmoid", "rbf"], + "gamma": uniform(0.001, 10.0), + "probability": [True], + }, + "LogisticRegression": { + "C": uniform(0.001, 10.0), + "solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"], + "max_iter": [100], + }, + "XGBClassifier": { + "max_depth": [2, 4], + "eta": uniform(0.2, 0.5), + "sampling_method": ["uniform"], + "grow_policy": ["depthwise", "lossguide"], + }, + "CatBoostClassifier": {"depth": [2, 4], "learning_rate": uniform(0.05, 0.15), "n_estimators": [100, 150]}, + "LGBMClassifier": {"learning_rate": uniform(0.05, 0.15), "n_estimators": randint(100, 150)}, + "MLPClassifier": { + "activation": ["tanh", "relu"], + "solver": ["adam"], + "alpha": uniform(1e-7, 1e-1), + "learning_rate": ["constant", "adaptive"], + "n_iter_no_change": [32], + }, + } + + return default_classifier_search_spaces + + +class CurationModelTrainer: + """ + Used to train and evaluate machine learning models for spike sorting curation. + + Parameters + ---------- + labels : list of lists, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + folder : str, default: None + The folder where outputs such as models and evaluation metrics will be saved, if specified. Requires the skops library. If None, output will not be saved on file system. + metric_names : list of str, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. 
+ scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str or dict, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + seed : int, default: None + Random seed for reproducibility. If None, a random seed will be generated. + smote : bool, default: False + Whether to apply SMOTE for class imbalance. Default is False. Requires imbalanced-learn package. + verbose : bool, default: True + If True, useful information is printed during training. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + + Attributes + ---------- + folder : str + The folder where outputs such as models and evaluation metrics will be saved. Requires the skops library. + labels : list of lists, default: None + List of curated labels for each `sorting_analyzer` and each unit; must be in the same order as the metrics data. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str + The list of classifiers to evaluate. + classifier_search_space : dict or None + Dictionary of classifiers and their hyperparameter search spaces, if provided. If None, default search spaces are used. + seed : int + Random seed for reproducibility. + metrics_list : list of str + The list of metrics to use for training. + X : pandas.DataFrame or None + The feature matrix after preprocessing. + y : pandas.Series or None + The target vector after preprocessing. + testing_metrics : dict or None + Dictionary to hold testing metrics data. + label_conversion : dict or None + Dictionary to map string labels to integer codes if target column contains string labels. + + Methods + ------- + get_default_metrics_list() + Returns the default list of metrics. + load_and_preprocess_full(path) + Loads and preprocesses the data from the given path. + load_data_file(path) + Loads the data file from the given path. + process_test_data_for_classification() + Processes the test data for classification. + apply_scaling_imputation(imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val) + Applies the specified imputation and scaling techniques to the data. + get_classifier_instance(classifier_name) + Returns an instance of the specified classifier. + get_classifier_search_space(classifier_name) + Returns the search space for hyperparameter tuning for the specified classifier. + get_classifier_search_space() + Returns the default search spaces for hyperparameter tuning for the classifiers. 
+ evaluate_model_config(imputation_strategies, scaling_techniques, classifiers) + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. + """ + + def __init__( + self, + labels=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + seed=None, + smote=False, + verbose=True, + search_kwargs=None, + **job_kwargs, + ): + + import pandas as pd + + if imputation_strategies is None: + imputation_strategies = ["median", "most_frequent", "knn", "iterative"] + + if scaling_techniques is None: + scaling_techniques = [ + "standard_scaler", + "min_max_scaler", + "robust_scaler", + ] + + if classifiers is None: + self.classifiers = ["RandomForestClassifier"] + self.classifier_search_space = None + elif isinstance(classifiers, dict): + self.classifiers = list(classifiers.keys()) + self.classifier_search_space = classifiers + elif isinstance(classifiers, list): + self.classifiers = classifiers + self.classifier_search_space = None + else: + raise ValueError("classifiers must be a list or dictionary") + + # check if labels is a list of lists + if not all(isinstance(label, list) or isinstance(label, np.ndarray) for label in labels): + raise ValueError("labels must be a list of lists") + + self.folder = Path(folder) if folder is not None else None + self.imputation_strategies = imputation_strategies + self.scaling_techniques = scaling_techniques + self.test_size = test_size + self.seed = seed if seed is not None else np.random.default_rng(seed=None).integers(0, 2**31) + self.metrics_params = {} + self.smote = smote + self.label_conversion = None + self.verbose = verbose + self.search_kwargs = search_kwargs + + self.X = None + self.testing_metrics = None + + self.requirements = {"spikeinterface": spikeinterface.__version__} + + self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels]) + + self.metric_names = metric_names + + if self.folder is not None and not self.folder.is_dir(): + self.folder.mkdir(parents=True, exist_ok=True) + + # update job_kwargs with global ones + job_kwargs = fix_job_kwargs(job_kwargs) + self.n_jobs = job_kwargs["n_jobs"] + + def get_default_metrics_list(self): + """Returns the default list of metrics.""" + return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_names() + + def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): + """ + Loads and preprocesses the quality metrics and labels from the given list of SortingAnalyzer objects. 
+ """ + import pandas as pd + + metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers] + check_metric_names_are_the_same(metrics_for_each_analyzer) + + self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) + + # Set metric names to those calculated if not provided + if self.metric_names is None: + warnings.warn("No metric_names provided, using all metrics calculated by the analyzers") + self.metric_names = self.testing_metrics.columns.tolist() + + conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params) + + self.metrics_params = {} + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + self.metrics_params[extension_name + "_params"] = metric_extension.params + + # Only save metric params which are 1) consistent and 2) exist in metric_names + metric_names = metric_extension.params["metric_names"] + consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics))) + consistent_metric_params = { + metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics + } + self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params + + self.process_test_data_for_classification() + + def _check_metrics_parameters(self, analyzers, enforce_metric_params): + """Checks that the metrics of each analyzer have been calcualted using the same parameters""" + + extension_names = ["quality_metrics", "template_metrics"] + + conflicting_metrics = [] + for analyzer_index_1, analyzer_1 in enumerate(analyzers): + for analyzer_index_2, analyzer_2 in enumerate(analyzers): + + if analyzer_index_1 <= analyzer_index_2: + continue + else: + + metric_params_1 = {} + metric_params_2 = {} + + for extension_name in extension_names: + if (extension_1 := analyzer_1.get_extension(extension_name)) is not None: + metric_params_1.update(extension_1.params["metric_params"]) + if (extension_2 := analyzer_2.get_extension(extension_name)) is not None: + metric_params_2.update(extension_2.params["metric_params"]) + + conflicting_metrics_between_1_2 = [] + # check quality metrics params + for metric, params_1 in metric_params_1.items(): + if params_1 != metric_params_2.get(metric): + conflicting_metrics_between_1_2.append(metric) + + conflicting_metrics += conflicting_metrics_between_1_2 + + if len(conflicting_metrics_between_1_2) > 0: + warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}" + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + unique_conflicting_metrics = set(conflicting_metrics) + return unique_conflicting_metrics + + def load_and_preprocess_csv(self, paths): + self._load_data_files(paths) + self.process_test_data_for_classification() + self.get_metric_params_csv() + + def get_metric_params_csv(self): + + from itertools import chain + + qm_metric_names = list(chain.from_iterable(qm_compute_name_to_column_names.values())) + tm_metric_names = list(chain.from_iterable(tm_compute_name_to_column_names.values())) + + quality_metric_names = [] + template_metric_names = [] + + for metric_name in 
self.metric_names: + if metric_name in qm_metric_names: + quality_metric_names.append(metric_name) + if metric_name in tm_metric_names: + template_metric_names.append(metric_name) + + self.metrics_params = {} + if quality_metric_names != {}: + self.metrics_params["quality_metric_params"] = {"metric_names": quality_metric_names} + if template_metric_names != {}: + self.metrics_params["template_metric_params"] = {"metric_names": template_metric_names} + + return + + def process_test_data_for_classification(self): + """ + Cleans the input data so that it can be used by sklearn. + + Extracts the target variable and features from the loaded dataset. + It handles string labels by converting them to integer codes and reindexes the + feature matrix to match the specified metrics list. Infinite values in the features + are replaced with NaN, and any remaining NaN values are filled with zeros. + + Raises + ------ + ValueError + If the target column specified is not found in the loaded dataset. + + Notes + ----- + If the target column contains string labels, a warning is issued and the labels + are converted to integer codes. The mapping from string labels to integer codes + is stored in the `label_conversion` attribute. + """ + + # Convert string labels to integer codes to allow classification + new_y = self.y.astype("category").cat.codes + self.label_conversion = dict(zip(new_y, self.y)) + self.y = new_y + + # Extract features + try: + if (set(self.metric_names) - set(self.testing_metrics.columns) != set()) and self.verbose is True: + print( + f"Dropped metrics (calculated but not included in metric_names): {set(self.testing_metrics.columns) - set(self.metric_names)}" + ) + self.X = self.testing_metrics[self.metric_names] + except KeyError as e: + raise KeyError(f"{str(e)}, metrics_list contains invalid metric names") + + self.X = self.testing_metrics.reindex(columns=self.metric_names) + self.X = _format_metric_dataframe(self.X) + + def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test): + """Impute and scale the data using the specified techniques.""" + from sklearn.experimental import enable_iterative_imputer + from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer + from sklearn.ensemble import HistGradientBoostingRegressor + from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler + + if imputation_strategy == "knn": + imputer = KNNImputer(n_neighbors=5) + elif imputation_strategy == "iterative": + imputer = IterativeImputer( + estimator=HistGradientBoostingRegressor(random_state=self.seed), random_state=self.seed + ) + else: + imputer = SimpleImputer(strategy=imputation_strategy) + + if scaling_technique == "standard_scaler": + scaler = StandardScaler() + elif scaling_technique == "min_max_scaler": + scaler = MinMaxScaler() + elif scaling_technique == "robust_scaler": + scaler = RobustScaler() + else: + raise ValueError( + f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler." 
+ ) + + y_train_processed = y_train.astype(int) + y_test = y_test.astype(int) + + X_train_imputed = imputer.fit_transform(X_train) + X_test_imputed = imputer.transform(X_test) + X_train_processed = scaler.fit_transform(X_train_imputed) + X_test_processed = scaler.transform(X_test_imputed) + + # Apply SMOTE for class imbalance + if self.smote: + try: + from imblearn.over_sampling import SMOTE + except ModuleNotFoundError: + raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE") + smote = SMOTE(random_state=self.seed) + X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed) + + return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler + + def get_classifier_instance(self, classifier_name): + from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier + from sklearn.svm import SVC + from sklearn.linear_model import LogisticRegression + from sklearn.neural_network import MLPClassifier + + classifier_mapping = { + "RandomForestClassifier": RandomForestClassifier(random_state=self.seed), + "AdaBoostClassifier": AdaBoostClassifier(random_state=self.seed), + "GradientBoostingClassifier": GradientBoostingClassifier(random_state=self.seed), + "SVC": SVC(random_state=self.seed), + "LogisticRegression": LogisticRegression(random_state=self.seed), + "MLPClassifier": MLPClassifier(random_state=self.seed), + } + + # Check lightgbm package install + if classifier_name == "LGBMClassifier": + try: + import lightgbm + + self.requirements["lightgbm"] = lightgbm.__version__ + classifier_mapping["LGBMClassifier"] = lightgbm.LGBMClassifier(random_state=self.seed, verbose=-1) + except ImportError: + raise ImportError("Please install lightgbm package to use LGBMClassifier") + elif classifier_name == "CatBoostClassifier": + try: + import catboost + + self.requirements["catboost"] = catboost.__version__ + classifier_mapping["CatBoostClassifier"] = catboost.CatBoostClassifier( + silent=True, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install catboost package to use CatBoostClassifier") + elif classifier_name == "XGBClassifier": + try: + import xgboost + + self.requirements["xgboost"] = xgboost.__version__ + classifier_mapping["XGBClassifier"] = xgboost.XGBClassifier( + use_label_encoder=False, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install xgboost package to use XGBClassifier") + + if classifier_name not in classifier_mapping: + raise ValueError( + f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + return classifier_mapping[classifier_name] + + def get_classifier_search_space(self, classifier_name): + + default_classifier_search_spaces = get_default_classifier_search_spaces() + + if classifier_name not in default_classifier_search_spaces: + raise ValueError( + f"Unknown classifier: {classifier_name}. 
To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + model = self.get_classifier_instance(classifier_name) + if self.classifier_search_space is not None: + param_space = self.classifier_search_space[classifier_name] + else: + param_space = default_classifier_search_spaces[classifier_name] + return model, param_space + + def evaluate_model_config(self): + """ + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. + + This method splits the preprocessed data into training and testing sets, then evaluates the specified + combinations of imputation strategies, scaling techniques, and classifiers. The evaluation results are + saved to the output folder. + + Raises + ------ + ValueError + If any of the specified classifier names are not recognized. + + Notes + ----- + The method converts the classifier names to actual classifier instances before evaluating them. + The evaluation results, including the best model and its parameters, are saved to the output folder. + """ + from sklearn.model_selection import train_test_split + + X_train, X_test, y_train, y_test = train_test_split( + self.X, self.y, test_size=self.test_size, random_state=self.seed, stratify=self.y + ) + classifier_instances = [self.get_classifier_instance(clf) for clf in self.classifiers] + self._evaluate( + self.imputation_strategies, + self.scaling_techniques, + classifier_instances, + X_train, + X_test, + y_train, + y_test, + self.search_kwargs, + ) + + def _load_data_files(self, paths): + import pandas as pd + + self.testing_metrics = pd.concat([pd.read_csv(path, index_col=0) for path in paths], axis=0) + + def _evaluate( + self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test, search_kwargs + ): + from joblib import Parallel, delayed + from sklearn.pipeline import Pipeline + import pandas as pd + + results = Parallel(n_jobs=self.n_jobs)( + delayed(self._train_and_evaluate)( + imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, idx, search_kwargs + ) + for idx, (imputation_strategy, scaler, classifier) in enumerate( + (imputation_strategy, scaler, classifier) + for imputation_strategy in imputation_strategies + for scaler in scaling_techniques + for classifier in classifiers + ) + ) + + test_accuracies, models = zip(*results) + + if self.search_kwargs is None or self.search_kwargs.get("scoring"): + scoring_method = "balanced_accuracy" + else: + scoring_method = self.search_kwargs.get("scoring") + + self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False) + + best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"]) + best_model, best_imputer, best_scaler = models[best_model_id] + + best_pipeline = Pipeline( + [("imputer", best_imputer), ("scaler", best_scaler), ("classifier", best_model.best_estimator_)] + ) + + self.best_pipeline = best_pipeline + + if self.folder is not None: + self._save() + + def _save(self): + from skops.io import dump + import sklearn + import pandas as pd + + # export training data and labels + pd.DataFrame(self.X).to_csv(self.folder / f"training_data.csv", index_label="unit_id") + pd.DataFrame(self.y).to_csv(self.folder / f"labels.csv", index_label="unit_index") + + self.requirements["scikit-learn"] = sklearn.__version__ + + # Dump to skops if folder is provided + dump(self.best_pipeline, self.folder 
/ f"best_model.skops") + self.test_accuracies_df.to_csv(self.folder / f"model_accuracies.csv", float_format="%.4f") + + model_info = {} + model_info["metric_params"] = self.metrics_params + + model_info["requirements"] = self.requirements + + model_info["label_conversion"] = self.label_conversion + + param_file = self.folder / "model_info.json" + Path(param_file).write_text(json.dumps(model_info, indent=4), encoding="utf8") + + def _train_and_evaluate( + self, imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, model_id, search_kwargs + ): + from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score + + search_kwargs = set_default_search_kwargs(search_kwargs) + + X_train_scaled, X_test_scaled, y_train, y_test, imputer, scaler = self.apply_scaling_imputation( + imputation_strategy, scaler, X_train, X_test, y_train, y_test + ) + if self.verbose is True: + print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}") + model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) + + try: + from skopt import BayesSearchCV + + model = BayesSearchCV( + model, + param_space, + random_state=self.seed, + **search_kwargs, + ) + except: + if self.verbose is True: + print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV") + from sklearn.model_selection import RandomizedSearchCV + + model = RandomizedSearchCV(model, param_space, **search_kwargs) + + model.fit(X_train_scaled, y_train) + y_pred = model.predict(X_test_scaled) + balanced_acc = balanced_accuracy_score(y_test, y_pred) + precision = precision_score(y_test, y_pred, average="macro") + recall = recall_score(y_test, y_pred, average="macro") + return { + "classifier name": classifier.__class__.__name__, + "imputation_strategy": imputation_strategy, + "scaling_strategy": scaler, + "balanced_accuracy": balanced_acc, + "precision": precision, + "recall": recall, + "model_id": model_id, + "best_params": model.best_params_, + }, (model, imputer, scaler) + + +def train_model( + mode="analyzers", + labels=None, + analyzers=None, + metrics_paths=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + overwrite=False, + seed=None, + search_kwargs=None, + verbose=True, + enforce_metric_params=False, + **job_kwargs, +): + """ + Trains and evaluates machine learning models for spike sorting curation. + + This function initializes a `CurationModelTrainer` object, loads and preprocesses the data, + and evaluates the specified combinations of imputation strategies, scaling techniques, and classifiers. + The evaluation results, including the best model and its parameters, are saved to the output folder. + + Parameters + ---------- + mode : "analyzers" | "csv", default: "analyzers" + Mode to use for training. + analyzers : list of SortingAnalyzer | None, default: None + List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode. + labels : list of list | None, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + metrics_paths : list of str or None, default: None + List of paths to the CSV files containing the metrics data if using 'csv' mode. + folder : str | None, default: None + The folder where outputs such as models and evaluation metrics will be saved. 
+ metric_names : list of str | None, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str | dict | None, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + overwrite : bool, default: False + Overwrites the `folder` if it already exists + seed : int | None, default: None + Random seed for reproducibility. If None, a random seed will be generated. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + verbose : bool, default: True + If True, useful information is printed during training. + enforce_metric_params : bool, default: False + If True and metric parameters used to calculate metrics for different `sorting_analyzer`s are + different, an error will be raised. + + + Returns + ------- + CurationModelTrainer + The `CurationModelTrainer` object used for training and evaluation. + + Notes + ----- + This function handles the entire workflow of initializing the trainer, loading and preprocessing the data, + and evaluating the models. The evaluation results are saved to the specified output folder. + """ + + if folder is None: + raise Exception("You must supply a folder for the model to be saved in using `folder='path/to/folder/'`") + + if overwrite is False: + assert not Path(folder).is_dir(), f"folder {folder} already exists, choose another name or use overwrite=True" + + if labels is None: + raise Exception("You must supply a list of lists of curated labels using `labels = [[...],[...],...]`") + + if mode not in ["analyzers", "csv"]: + raise Exception("`mode` must be equal to 'analyzers' or 'csv'.") + + if (test_size > 1.0) or (0.0 > test_size): + raise Exception("`test_size` must be between 0.0 and 1.0") + + trainer = CurationModelTrainer( + labels=labels, + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + test_size=test_size, + seed=seed, + verbose=verbose, + search_kwargs=search_kwargs, + **job_kwargs, + ) + + if mode == "analyzers": + assert analyzers is not None, "Analyzers must be provided as a list for mode 'analyzers'" + trainer.load_and_preprocess_analyzers(analyzers, enforce_metric_params) + + elif mode == "csv": + for metrics_path in metrics_paths: + assert Path(metrics_path).is_file(), f"{metrics_path} is not a file." 
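+        # "csv" mode reads the metric values directly from the supplied files;
+        # the curated labels still come from the `labels` argument passed above.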
+ trainer.load_and_preprocess_csv(metrics_paths) + + trainer.evaluate_model_config() + return trainer + + +def _get_computed_metrics(sorting_analyzer): + """Loads and organises the computed metrics from a sorting_analyzer into a single dataframe""" + + import pandas as pd + + quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer) + calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1) + + # Remove any metrics for non-existent units, raise error if no units are present + calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())] + if calculated_metrics.shape[0] == 0: + raise ValueError("No units present in sorting data") + + return calculated_metrics + + +def try_to_get_metrics_from_analyzer(sorting_analyzer): + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + if any(metric_extensions) is False: + raise ValueError( + "At least one of quality metrics or template metrics must be computed before classification.", + "Compute both using `sorting_analyzer.compute('quality_metrics', 'template_metrics')", + ) + + metric_extensions_data = [] + for metric_extension in metric_extensions: + try: + metric_extensions_data.append(metric_extension.get_data()) + except: + metric_extensions_data.append(None) + + return metric_extensions_data + + +def set_default_search_kwargs(search_kwargs): + + if search_kwargs is None: + search_kwargs = {} + + if search_kwargs.get("cv") is None: + search_kwargs["cv"] = 5 + if search_kwargs.get("scoring") is None: + search_kwargs["scoring"] = "balanced_accuracy" + if search_kwargs.get("n_iter") is None: + search_kwargs["n_iter"] = 25 + + return search_kwargs + + +def check_metric_names_are_the_same(metrics_for_each_analyzer): + """ + Given a list of dataframes, checks that the keys are all equal. + """ + + for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer): + for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer): + if i > j: + metric_names_1 = set(metrics_for_analyzer_1.keys()) + metric_names_2 = set(metrics_for_analyzer_2.keys()) + if metric_names_1 != metric_names_2: + metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2) + metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1) + + error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n" + if metrics_in_1_but_not_2: + error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does." + if metrics_in_2_but_not_1: + error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does." + raise Exception(error_message) + + +def _format_metric_dataframe(input_data): + + input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x) + input_data = input_data.astype("float32") + + return input_data diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index ca21f1e45f..c789d1af82 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -42,6 +42,9 @@ max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" ), silhouette=dict(method=("simplified",)), + isolation_distance=dict(), + l_ratio=dict(), + d_prime=dict(), )
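
A minimal end-to-end sketch of how the functions added in this patch fit together, for orientation only: ``analyzer_1``, ``analyzer_2``, ``analyzer_new``, the two label lists and the folder path are placeholders you must supply, and quality/template metrics are assumed to have already been computed on each analyzer. Keyword values simply mirror the defaults and test settings used above.

.. code::

    from spikeinterface.curation import train_model, load_model, auto_label_units

    # Train on two curated analyzers; outputs (best_model.skops, model_info.json,
    # model_accuracies.csv, training_data.csv, labels.csv) are written to `folder`.
    trainer = train_model(
        mode="analyzers",
        analyzers=[analyzer_1, analyzer_2],
        labels=[labels_1, labels_2],
        folder="my_model_folder",
        classifiers=["LogisticRegression"],
        imputation_strategies=["median"],
        scaling_techniques=["standard_scaler"],
        search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 25},
        overwrite=True,
    )
    best_pipeline = trainer.best_pipeline  # the winning sklearn Pipeline

    # Reload the saved model and its metadata later; `trusted` is the skops
    # allow-list used in the tests above and may differ for other models.
    model, model_info = load_model(model_folder="my_model_folder", trusted=["numpy.dtype"])

    # Apply the model to a new analyzer with the same metrics computed; returns a
    # dataframe of per-unit predictions and their probabilities.
    labels_and_probabilities = auto_label_units(
        sorting_analyzer=analyzer_new,
        model_folder="my_model_folder",
        trusted=["numpy.dtype"],
    )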