diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 27214f32e6..0e70b1f494 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -150,7 +150,6 @@ def _get_data(self): def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): - import sklearn.metrics.pairwise if method == "cosine_similarity": @@ -223,15 +222,17 @@ def compute_similarity_with_templates_array( if method == "l1": norm_i = np.sum(np.abs(src)) norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() distances[count, i, j] /= norm_i + norm_j else: - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( + src, tgt, metric="cosine" + ).item() if same_array: distances[count, j, i] = distances[count, i, j]