From d83824079423874c5c60a72dd3217412fe941322 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 4 Apr 2025 14:16:23 +0200 Subject: [PATCH 01/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2edd38e77a..d03ee29dd0 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -47,6 +47,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "multi_units_only": False, "job_kwargs": {"n_jobs": 0.5}, "seed": 42, + "deterministic": False, "debug": False, } @@ -76,6 +77,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", "seed": "An int to control how chunks are shuffled while detecting peaks", + "deterministic": "a boolean to specify if the sorting should be deterministic or not. If True, then the seed will be used to shuffle the chunks", "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } @@ -187,6 +189,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nafter = int(ms_after * fs / 1000.0) skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform" + skip_peaks = skip_peaks and not params["deterministic"] max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) @@ -195,6 +198,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_folder.mkdir(parents=True, exist_ok=True) np.save(clustering_folder / "noise_levels.npy", noise_levels) + if params["matched_filtering"]: prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( recording_w, @@ -224,7 +228,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w, seed=params["seed"], **job_kwargs ) peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) - + print(prototype.mean()) if not skip_peaks and verbose: print("Found %d peaks in total" % len(peaks)) @@ -234,6 +238,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible selection_params = params["selection"] + selection_params["seed"] = params["seed"] selection_params["n_peaks"] = n_peaks selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks(peaks, **selection_params) From 776e8818b341689d508b65cddc0629b349ca4647 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 4 Apr 2025 16:47:52 +0200 Subject: [PATCH 02/33] WIP --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 - src/spikeinterface/sortingcomponents/tools.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index dcc543c1c3..8e10f45624 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -571,7 +571,6 @@ def check_params( # ) assert peak_sign in ("both", "neg", "pos") - if noise_levels is None: 
noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) abs_thresholds = noise_levels * detect_threshold diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index e4fd3c2539..83ebbe9b3e 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -171,7 +171,7 @@ def get_prototype_and_waveforms_from_recording( pipeline_nodes = [node] recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) - + print(recording_slices, detection_kwargs, n_peaks) res = detect_peaks( recording, pipeline_nodes=pipeline_nodes, @@ -180,9 +180,10 @@ def get_prototype_and_waveforms_from_recording( **detection_kwargs, **job_kwargs, ) - + print(seed, len(res[0])) rng = np.random.RandomState(seed) indices = rng.permutation(np.arange(len(res[0]))) + print("indices", indices.sum()) few_peaks = res[0][indices[:n_peaks]] waveforms = res[1][indices[:n_peaks]] From 28f1623fafba63d1031fbd9c54469d1343f6ccf2 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 4 Apr 2025 16:48:06 +0200 Subject: [PATCH 03/33] WIP --- .../sorters/internal/spyking_circus2.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index d03ee29dd0..3f33e686c1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -161,6 +161,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # TODO add , regularize=True chen ready whitening_kwargs = params["whitening"].copy() whitening_kwargs["dtype"] = "float32" + whitening_kwargs["seed"] = params["seed"] whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False) if num_channels == 1: whitening_kwargs["regularize"] = False @@ -168,7 +169,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) + noise_levels = get_noise_levels(recording_w, + random_slices_kwargs={"seed": params["seed"]}, + return_scaled=False, **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -198,7 +201,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_folder.mkdir(parents=True, exist_ok=True) np.save(clustering_folder / "noise_levels.npy", noise_levels) - if params["matched_filtering"]: prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( recording_w, @@ -216,19 +218,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks - detection_params["recording_slices"] = get_shuffled_recording_slices( - recording_w, seed=params["seed"], **job_kwargs - ) + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + detection_params['random_chunk_kwargs'] = {"num_chunks_per_segment": 5, + "seed" : params["seed"]} + print(prototype.mean()) + import sys + sys.exit() peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: waveforms = None if 
skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks - detection_params["recording_slices"] = get_shuffled_recording_slices( - recording_w, seed=params["seed"], **job_kwargs - ) + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) - print(prototype.mean()) + if not skip_peaks and verbose: print("Found %d peaks in total" % len(peaks)) From 86c1cf7fd79f6ffa6768aafb6fffc40a4875bb9a Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 7 Apr 2025 09:41:16 +0200 Subject: [PATCH 04/33] WIP --- .../sorters/internal/spyking_circus2.py | 46 +++++++++++++------ .../sortingcomponents/clustering/circus.py | 14 +++++- src/spikeinterface/sortingcomponents/tools.py | 3 -- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 3f33e686c1..77a12c6542 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,14 +6,13 @@ import numpy as np from spikeinterface.core import NumpySorting -from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, - get_prototype_and_waveforms_from_recording, get_shuffled_recording_slices, ) from spikeinterface.core.basesorting import minimum_spike_dtype @@ -47,7 +46,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "multi_units_only": False, "job_kwargs": {"n_jobs": 0.5}, "seed": 42, - "deterministic": False, + "deterministic": True, "debug": False, } @@ -202,15 +201,35 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "noise_levels.npy", noise_levels) if params["matched_filtering"]: - prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( - recording_w, - n_peaks=10000, - ms_before=ms_before, - ms_after=ms_after, - seed=params["seed"], - **detection_params, - **job_kwargs, - ) + if not params["deterministic"]: + from spikeinterface.sortingcomponents.tools import ( + get_prototype_and_waveforms_from_recording, + ) + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( + recording_w, + n_peaks=10000, + ms_before=ms_before, + ms_after=ms_after, + seed=params["seed"], + **detection_params, + **job_kwargs, + ) + else: + from spikeinterface.sortingcomponents.tools import ( + get_prototype_and_waveforms_from_peaks, + ) + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) + prototype, waveforms, _ = get_prototype_and_waveforms_from_peaks( + recording_w, + peaks, + n_peaks=10000, + ms_before=ms_before, + ms_after=ms_after, + seed=params["seed"], + **detection_params, + **job_kwargs, + ) + detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before if params["debug"]: @@ -223,9 +242,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) detection_params['random_chunk_kwargs'] = {"num_chunks_per_segment": 5, "seed" : params["seed"]} - 
print(prototype.mean()) - import sys - sys.exit() peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: waveforms = None diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 7bce0800d3..803cadeeab 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -60,6 +60,7 @@ class CircusClustering: "few_waveforms": None, "ms_before": 0.5, "ms_after": 0.5, + "seed" : 42, "noise_threshold": 4, "rank": 5, "noise_levels": None, @@ -91,10 +92,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # SVD for time compression if params["few_waveforms"] is None: few_peaks = select_peaks( - peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter) + peaks, + recording=recording, + method="uniform", + seed=params["seed"], + n_peaks=10000, + margin=(nbefore, nafter) ) few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + recording, + few_peaks, + ms_before=ms_before, + ms_after=ms_after, + **job_kwargs ) wfs = few_wfs[:, :, 0] else: diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 83ebbe9b3e..3c15e56726 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -171,7 +171,6 @@ def get_prototype_and_waveforms_from_recording( pipeline_nodes = [node] recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) - print(recording_slices, detection_kwargs, n_peaks) res = detect_peaks( recording, pipeline_nodes=pipeline_nodes, @@ -180,10 +179,8 @@ def get_prototype_and_waveforms_from_recording( **detection_kwargs, **job_kwargs, ) - print(seed, len(res[0])) rng = np.random.RandomState(seed) indices = rng.permutation(np.arange(len(res[0]))) - print("indices", indices.sum()) few_peaks = res[0][indices[:n_peaks]] waveforms = res[1][indices[:n_peaks]] From 8e33d08ee483549cf10538befc8c3e824293ffe2 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 09:55:30 +0200 Subject: [PATCH 05/33] WIP --- .../sortingcomponents/clustering/circus.py | 182 +++++------------- .../clustering/graph_clustering.py | 4 +- 2 files changed, 49 insertions(+), 137 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 7bce0800d3..0f1a636ee9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -19,16 +19,12 @@ from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances from spikeinterface.sortingcomponents.peak_selection import select_peaks -from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection -from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates -import pickle, json +from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd from spikeinterface.core.node_pipeline import ( run_node_pipeline, - ExtractSparseWaveforms, - PeakRetriever, ) @@ -55,6 
+51,8 @@ class CircusClustering: "recursive_depth": 3, "returns_split_count": True, }, + "split_kwargs": {"projection_mode": "tsvd", + "n_pca_features": 0.9}, "radius_um": 100, "n_svd": 5, "few_waveforms": None, @@ -78,6 +76,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): fs = recording.get_sampling_frequency() ms_before = params["ms_before"] ms_after = params["ms_after"] + radius_um = params["radius_um"] nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) if params["tmp_folder"] is None: @@ -108,136 +107,50 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): valid = np.argmax(np.abs(wfs), axis=1) == nbefore wfs = wfs[valid] - # Perform Hanning filtering - hanning_before = np.hanning(2 * nbefore) - hanning_after = np.hanning(2 * nafter) - hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:])) - wfs *= hanning - from sklearn.decomposition import TruncatedSVD + svd_model = TruncatedSVD(params["n_svd"]) + svd_model.fit(wfs) + features_folder = tmp_folder / "tsvd_features" + features_folder.mkdir(exist_ok=True) + + peaks_svd, sparse_mask, svd_model = extract_peaks_svd(recording, + peaks, + ms_before=ms_before, + ms_after=ms_after, + svd_model=svd_model, + radius_um=radius_um, + folder=features_folder, + **job_kwargs) + + neighbours_mask = get_channel_distances(recording) <= radius_um + + if params["debug"]: + np.save(features_folder / "sparse_mask.npy", sparse_mask) + np.save(features_folder / "peaks.npy", peaks) - tsvd = TruncatedSVD(params["n_svd"]) - tsvd.fit(wfs) - - model_folder = tmp_folder / "tsvd_model" - - model_folder.mkdir(exist_ok=True) - with open(model_folder / "pca_model.pkl", "wb") as f: - pickle.dump(tsvd, f) - - model_params = { - "ms_before": ms_before, - "ms_after": ms_after, - "sampling_frequency": float(fs), - } - - with open(model_folder / "params.json", "w") as f: - json.dump(model_params, f) - - # features - node0 = PeakRetriever(recording, peaks) - - radius_um = params["radius_um"] - node1 = ExtractSparseWaveforms( - recording, - parents=[node0], - return_output=False, - ms_before=ms_before, - ms_after=ms_after, - radius_um=radius_um, - ) - - node2 = HanningFilter(recording, parents=[node0, node1], return_output=False) - - node3 = TemporalPCAProjection( - recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder - ) - - pipeline_nodes = [node0, node1, node2, node3] - - if len(params["recursive_kwargs"]) == 0: - from sklearn.decomposition import PCA + original_labels = peaks["channel_index"] + from spikeinterface.sortingcomponents.clustering.split import split_clusters - all_pc_data = run_node_pipeline( - recording, - pipeline_nodes, - job_kwargs, - job_name="extracting features", - ) + split_kwargs = params["split_kwargs"].copy() + split_kwargs["neighbours_mask"] = neighbours_mask + split_kwargs["waveforms_sparse_mask"] = sparse_mask + split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50) + split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"] - peak_labels = -1 * np.ones(len(peaks), dtype=int) - nb_clusters = 0 - for c in np.unique(peaks["channel_index"]): - mask = peaks["channel_index"] == c - sub_data = all_pc_data[mask] - sub_data = sub_data.reshape(len(sub_data), -1) - - if all_pc_data.shape[1] > params["n_svd"]: - tsvd = PCA(params["n_svd"], whiten=True) - else: - tsvd = PCA(all_pc_data.shape[1], whiten=True) - - hdbscan_data = tsvd.fit_transform(sub_data) - try: - clustering = 
hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) - local_labels = clustering[0] - except Exception: - local_labels = np.zeros(len(hdbscan_data)) - valid_clusters = local_labels > -1 - if np.sum(valid_clusters) > 0: - local_labels[valid_clusters] += nb_clusters - peak_labels[mask] = local_labels - nb_clusters += len(np.unique(local_labels[valid_clusters])) + if params["debug"]: + debug_folder = tmp_folder / "split" else: + debug_folder = None - features_folder = tmp_folder / "tsvd_features" - features_folder.mkdir(exist_ok=True) - - _ = run_node_pipeline( - recording, - pipeline_nodes, - job_kwargs, - job_name="extracting features", - gather_mode="npy", - gather_kwargs=dict(exist_ok=True), - folder=features_folder, - names=["sparse_tsvd"], - ) - - sparse_mask = node1.neighbours_mask - neighbours_mask = get_channel_distances(recording) <= radius_um - - # np.save(features_folder / "sparse_mask.npy", sparse_mask) - np.save(features_folder / "peaks.npy", peaks) - - original_labels = peaks["channel_index"] - from spikeinterface.sortingcomponents.clustering.split import split_clusters - - min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 20) - - if params["debug"]: - debug_folder = tmp_folder / "split" - else: - debug_folder = None - - peak_labels, _ = split_clusters( - original_labels, - recording, - features_folder, - method="local_feature_clustering", - method_kwargs=dict( - clusterer="hdbscan", - feature_name="sparse_tsvd", - neighbours_mask=neighbours_mask, - waveforms_sparse_mask=sparse_mask, - min_size_split=min_size, - clusterer_kwargs=d["hdbscan_kwargs"], - n_pca_features=5, - ), - debug_folder=debug_folder, - **params["recursive_kwargs"], - **job_kwargs, - ) + peak_labels, _ = split_clusters( + original_labels, + recording, + {"peaks" : peaks, "sparse_tsvd" : peaks_svd}, + method="local_feature_clustering", + method_kwargs=split_kwargs, + debug_folder=debug_folder, + **params["recursive_kwargs"], + **job_kwargs) non_noise = peak_labels > -1 labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) @@ -273,11 +186,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] valid_templates = best_snrs_ratio > params["noise_threshold"] - if d["rank"] is not None: - from spikeinterface.sortingcomponents.matching.circus import compress_templates - - _, _, _, templates_array = compress_templates(templates_array, d["rank"]) - templates = Templates( templates_array=templates_array[valid_templates], sampling_frequency=fs, @@ -289,6 +197,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): is_scaled=False, ) + if params["debug"]: + templates_folder = tmp_folder / "dense_templates" + templates.to_zarr(folder_path=templates_folder) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 @@ -314,4 +226,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels + return labels, peak_labels, svd_model, peaks_svd, sparse_mask diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py index 28409c2221..af25267b86 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py +++ 
b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py @@ -70,7 +70,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): elif graph_kwargs["bin_mode"] == "vertical_bins": assert radius_um >= graph_kwargs["bin_um"] * 3 - peaks_svd, sparse_mask, _ = extract_peaks_svd( + peaks_svd, sparse_mask, svd_model = extract_peaks_svd( recording, peaks, radius_um=radius_um, @@ -191,7 +191,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels_set = np.unique(peak_labels) labels_set = labels_set[labels_set >= 0] - return labels_set, peak_labels + return labels_set, peak_labels, svd_model, peaks_svd, sparse_mask def _remove_small_cluster(peak_labels, min_size=1): From e4993023000a9c451df772c8d4f908cd8c1e632c Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 10:25:09 +0200 Subject: [PATCH 06/33] Example of how to use SVD to estimate templates in SC2 --- .../sorters/internal/spyking_circus2.py | 258 +++++++++--------- .../sortingcomponents/clustering/circus.py | 3 - .../clustering/graph_clustering.py | 4 + .../sortingcomponents/clustering/tools.py | 11 +- 4 files changed, 146 insertions(+), 130 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2edd38e77a..b04de700d6 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,10 +6,9 @@ import numpy as np from spikeinterface.core import NumpySorting -from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates -from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, @@ -24,28 +23,32 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 1}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, - "detection": {"peak_sign": "neg", "detect_threshold": 5}, - "selection": { - "method": "uniform", - "n_peaks_per_channel": 5000, - "min_n_peaks": 100000, - "select_per_channel": False, - "seed": 42, - }, + "detection": {"method" : "matched_filtering", + "method_kwargs" : dict( + peak_sign="neg", + detect_threshold=5 + )}, + "selection": {"method": "uniform", + "method_kwargs" : dict( + n_peaks_per_channel=5000, + min_n_peaks=100000, + select_per_channel=False) + }, "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"legacy": True}, - "matching": {"method": "circus-omp-svd"}, + "clustering": {"method": "graph_clustering", + "method_kwargs" : dict()}, + "matching": {"method": "circus-omp-svd", + "method_kwargs" : dict()}, "apply_preprocessing": True, - "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, 
"delete_cache": True}, "multi_units_only": False, - "job_kwargs": {"n_jobs": 0.5}, + "job_kwargs": {"n_jobs": 0.75}, "seed": 42, "debug": False, } @@ -56,15 +59,12 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "general": "A dictionary to describe how templates should be computed. User can define ms_before and ms_after (in ms) \ and also the radius_um used to be considered during clustering", "sparsity": "A dictionary to be passed to all the calls to sparsify the templates", - "filtering": "A dictionary for the high_pass filter to be used during preprocessing", - "whitening": "A dictionary for the whitening option to be used during preprocessing", - "detection": "A dictionary for the peak detection node (locally_exclusive)", - "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\ - and 5000 peaks per electrode on average.", - "clustering": "A dictionary to be provided to the clustering method. By default, random_projections is used, but if legacy is set to\ - True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1", - "matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\ - can be used", + "filtering": "A dictionary for the high_pass filter used during preprocessing", + "whitening": "A dictionary for the whitening used during preprocessing", + "detection": "A dictionary for the peak detection component. Default is matched filtering", + "selection": "A dictionary for the peak selection component. Default is to use uniform", + "clustering": "A dictionary for the clustering component. Default, graph_clustering is used", + "matching": "A dictionary for the matching component. Default circus-omp-svd. Use None to avoid matching", "merging": "A dictionary to specify the final merging param to group cells after template matching (auto_merge_units)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. 
If yes, then high_pass filtering + common\ @@ -90,14 +90,6 @@ def get_sorter_version(cls): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - try: - import hdbscan - - HAVE_HDBSCAN = True - except: - HAVE_HDBSCAN = False - - assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" try: import torch @@ -124,11 +116,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_after = params["general"].get("ms_after", 2) radius_um = params["general"].get("radius_um", 75) peak_sign = params["detection"].get("peak_sign", "neg") + debug = params["debug"] + seed = params["seed"] + apply_preprocessing = params["apply_preprocessing"] + apply_motion_correction = params["apply_motion_correction"] exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after)) ## First, we are filtering the data filtering_params = params["filtering"].copy() - if params["apply_preprocessing"]: + if apply_preprocessing: if verbose: print("Preprocessing the recording (bandpass filtering + CMR + whitening)") recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") @@ -141,7 +137,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f.annotate(is_filtered=True) valid_geometry = check_probe_for_drift_correction(recording_f) - if params["apply_motion_correction"]: + if apply_motion_correction: if not valid_geometry: if verbose: print("Geometry of the probe does not allow 1D drift correction") @@ -176,146 +172,157 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"]) ## Then, we are detecting peaks with a locally_exclusive method - detection_params = params["detection"].copy() - selection_params = params["selection"].copy() + detection_method = params["detection"].get("method", "matched_filtering") + detection_params = params["detection"].get("method_kwargs", dict()) detection_params["radius_um"] = radius_um detection_params["exclude_sweep_ms"] = exclude_sweep_ms detection_params["noise_levels"] = noise_levels - fs = recording_w.get_sampling_frequency() - nbefore = int(ms_before * fs / 1000.0) - nafter = int(ms_after * fs / 1000.0) - - skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform" - max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels - n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) - - if params["debug"]: + selection_method = params["selection"].get("method", "uniform") + selection_params = params["selection"].get("method_kwargs", dict()) + n_peaks_per_channel = selection_params.get("n_peaks_per_channel", 5000) + min_n_peaks = selection_params.get("min_n_peaks", 100000) + skip_peaks = not params["multi_units_only"] and selection_method == "uniform" + max_n_peaks = n_peaks_per_channel * num_channels + n_peaks = max(min_n_peaks, max_n_peaks) + selection_params["n_peaks"] = n_peaks + selection_params["noise_levels"] = noise_levels + + if debug: clustering_folder = sorter_output_folder / "clustering" clustering_folder.mkdir(parents=True, exist_ok=True) np.save(clustering_folder / "noise_levels.npy", noise_levels) - if params["matched_filtering"]: + if detection_method == "matched_filtering": prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( recording_w, n_peaks=10000, ms_before=ms_before, ms_after=ms_after, - seed=params["seed"], + seed=seed, **detection_params, **job_kwargs, ) 
detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before - if params["debug"]: + if debug: np.save(clustering_folder / "waveforms.npy", waveforms) np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks detection_params["recording_slices"] = get_shuffled_recording_slices( - recording_w, seed=params["seed"], **job_kwargs + recording_w, seed=seed, **job_kwargs ) - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) + detection_method = "matched_filtering" else: waveforms = None if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks detection_params["recording_slices"] = get_shuffled_recording_slices( - recording_w, seed=params["seed"], **job_kwargs + recording_w, seed=seed, **job_kwargs ) - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) + detection_method = "locally_exclusive" + + peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs) + + if debug: + np.save(clustering_folder / "peaks.npy", peaks) if not skip_peaks and verbose: print("Found %d peaks in total" % len(peaks)) + sparsity_kwargs = params["sparsity"].copy() + if "peak_sign" not in sparsity_kwargs: + sparsity_kwargs["peak_sign"] = peak_sign + if params["multi_units_only"]: sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) else: ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible - selection_params = params["selection"] - selection_params["n_peaks"] = n_peaks - selection_params.update({"noise_levels": noise_levels}) - selected_peaks = select_peaks(peaks, **selection_params) + selected_peaks = select_peaks(peaks, seed=seed, method=selection_method, **selection_params) if verbose: print("Kept %d peaks for clustering" % len(selected_peaks)) - ## We launch a clustering (using hdbscan) relying on positions and features extracted on - ## the fly from the snippets - clustering_params = params["clustering"].copy() - clustering_params["waveforms"] = {} - sparsity_kwargs = params["sparsity"].copy() - if "peak_sign" not in sparsity_kwargs: - sparsity_kwargs["peak_sign"] = peak_sign - - clustering_params["sparsity"] = sparsity_kwargs - clustering_params["radius_um"] = radius_um - clustering_params["waveforms"]["ms_before"] = ms_before - clustering_params["waveforms"]["ms_after"] = ms_after - clustering_params["few_waveforms"] = waveforms - clustering_params["noise_levels"] = noise_levels - clustering_params["ms_before"] = ms_before - clustering_params["ms_after"] = ms_after - clustering_params["verbose"] = verbose - clustering_params["tmp_folder"] = sorter_output_folder / "clustering" - clustering_params["debug"] = params["debug"] - clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) - - legacy = clustering_params.get("legacy", True) - - if legacy: - clustering_method = "circus" - else: - clustering_method = "random_projections" - - labels, peak_labels = find_cluster_from_peaks( - recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs - ) - - ## We get the labels for our peaks - mask = peak_labels > -1 - - labeled_peaks = np.zeros(np.sum(mask), dtype=minimum_spike_dtype) - labeled_peaks["sample_index"] = selected_peaks[mask]["sample_index"] - labeled_peaks["segment_index"] = selected_peaks[mask]["segment_index"] - for count, l in enumerate(labels): - 
sub_mask = peak_labels[mask] == l - labeled_peaks["unit_index"][sub_mask] = count - unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"]))) - sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids) - - if params["debug"]: - np.save(clustering_folder / "peak_labels", peak_labels) - np.save(clustering_folder / "labels", labels) - np.save(clustering_folder / "peaks", selected_peaks) - - templates_array = estimate_templates( - recording_w, labeled_peaks, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + clustering_method = params["clustering"].get("method", "graph_clustering") + clustering_params = params["clustering"].get("method_kwargs", dict()) + + if clustering_method == "circus": + clustering_params["waveforms"] = {} + clustering_params["sparsity"] = sparsity_kwargs + clustering_params["radius_um"] = radius_um + clustering_params["waveforms"]["ms_before"] = ms_before + clustering_params["waveforms"]["ms_after"] = ms_after + clustering_params["few_waveforms"] = waveforms + clustering_params["noise_levels"] = noise_levels + clustering_params["ms_before"] = ms_before + clustering_params["ms_after"] = ms_after + clustering_params["verbose"] = verbose + clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params["debug"] = debug + clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) + elif clustering_method == "graph_clustering": + clustering_params = {"ms_before" : ms_before, + "ms_after" : ms_after, + "clustering_method": "hdbscan", + "radius_um" : radius_um, + "clustering_kwargs" : dict(min_samples=1, + n_jobs=-1, + min_cluster_size=50, + cluster_selection_method='leaf', + allow_single_cluster=True, + cluster_selection_epsilon=0.1) + } + + outputs = find_cluster_from_peaks( + recording_w, + selected_peaks, + method=clustering_method, + method_kwargs=clustering_params, + extra_outputs=True, + **job_kwargs ) + if len(outputs) == 2: + _, peak_labels = outputs + from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording + templates = get_templates_from_peaks_and_recording( + recording, + peaks, + peak_labels, + ms_before, + ms_after, + **job_kwargs, + ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording_w.channel_ids, - unit_ids=unit_ids, - probe=recording_w.get_probe(), - is_scaled=False, - ) + elif len(outputs) == 3: + _, peak_labels, templates = outputs + elif len(outputs) == 5: + _, peak_labels, svd_model, svd_features, sparsity_mask = outputs + from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd + templates = get_templates_from_peaks_and_svd( + recording_w, + peaks, + peak_labels, + ms_before, + ms_after, + svd_model, + svd_features, + sparsity_mask, + operator="median" + ) sparsity = compute_sparsity(templates, noise_levels, **sparsity_kwargs) templates = templates.to_sparse(sparsity) templates = remove_empty_templates(templates) - if params["debug"]: + if debug: templates.to_zarr(folder_path=clustering_folder / "templates") sorting = sorting.save(folder=clustering_folder / "sorting") ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces - matching_method = params["matching"].pop("method") - matching_params = params["matching"].copy() + matching_method = params["matching"].get("method", "circus-omp_svd") + matching_params = 
params["matching"].get("method_kwargs", dict()) matching_params["templates"] = templates if matching_method is not None: @@ -323,7 +330,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w, matching_method, method_kwargs=matching_params, **job_kwargs ) - if params["debug"]: + if debug: fitting_folder = sorter_output_folder / "fitting" fitting_folder.mkdir(parents=True, exist_ok=True) np.save(fitting_folder / "spikes", spikes) @@ -336,15 +343,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting["sample_index"] = spikes["sample_index"] sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] - sorting = NumpySorting(sorting, sampling_frequency, unit_ids) + sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids) sorting_folder = sorter_output_folder / "sorting" if sorting_folder.exists(): shutil.rmtree(sorting_folder) merging_params = params["merging"].copy() - if params["debug"]: - merging_params["debug_folder"] = sorter_output_folder / "merging" + merging_params["debug_folder"] = sorter_output_folder / "merging" if len(merging_params) > 0: if params["motion_correction"] and motion_folder is not None: @@ -358,7 +364,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): max_distance_um = merging_params.get("max_distance_um", 50) merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) - if params["debug"]: + if debug: curation_folder = sorter_output_folder / "curation" if curation_folder.exists(): shutil.rmtree(curation_folder) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 0f1a636ee9..fea4c57a18 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -23,9 +23,6 @@ from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd -from spikeinterface.core.node_pipeline import ( - run_node_pipeline, -) from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py index af25267b86..95016318f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py @@ -59,6 +59,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): radius_um = params["radius_um"] motion = params["motion"] seed = params["seed"] + ms_before = params["ms_before"] + ms_after = params["ms_after"] clustering_method = params["clustering_method"] clustering_kwargs = params["clustering_kwargs"] graph_kwargs = params["graph_kwargs"] @@ -73,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peaks_svd, sparse_mask, svd_model = extract_peaks_svd( recording, peaks, + ms_before=ms_before, + ms_after=ms_after, radius_um=radius_um, motion_aware=motion_aware, motion=None, diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 693f67305f..73e9211202 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -200,6 
+200,7 @@ def get_templates_from_peaks_and_recording( peak_labels, ms_before, ms_after, + operator="mean", **job_kwargs, ): """ @@ -238,7 +239,15 @@ def get_templates_from_peaks_and_recording( from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( - recording, peaks, labels, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + peaks, + labels, + nbefore, + nafter, + operator=operator, + return_scaled=False, + job_name=None, + **job_kwargs ) templates = Templates( From 5be94daca81ac64a9c0d07f1361381a6d33800d1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 11:25:47 +0200 Subject: [PATCH 07/33] Patching to get a working example --- .../sorters/internal/spyking_circus2.py | 12 +++---- .../sortingcomponents/clustering/circus.py | 4 ++- .../sortingcomponents/clustering/tools.py | 33 +++++++++++++------ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b04de700d6..9e8ff240e9 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -41,9 +41,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "graph_clustering", + "clustering": {"method": "circus", "method_kwargs" : dict()}, - "matching": {"method": "circus-omp-svd", + "matching": {"method": "wobble", "method_kwargs" : dict()}, "apply_preprocessing": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -162,6 +162,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} recording_w = whiten(recording_f, **whitening_kwargs) + noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) if recording_w.check_serializability("json"): @@ -251,6 +252,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clustering_method == "circus": clustering_params["waveforms"] = {} clustering_params["sparsity"] = sparsity_kwargs + clustering_params["neighbors_radius_um"] = 50 clustering_params["radius_um"] = radius_um clustering_params["waveforms"]["ms_before"] = ms_before clustering_params["waveforms"]["ms_after"] = ms_after @@ -280,21 +282,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selected_peaks, method=clustering_method, method_kwargs=clustering_params, - extra_outputs=True, + extra_outputs=False, **job_kwargs ) if len(outputs) == 2: _, peak_labels = outputs from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording templates = get_templates_from_peaks_and_recording( - recording, + recording_w, peaks, peak_labels, ms_before, ms_after, **job_kwargs, ) - elif len(outputs) == 3: _, peak_labels, templates = outputs elif len(outputs) == 5: @@ -318,7 +319,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if debug: templates.to_zarr(folder_path=clustering_folder / "templates") - sorting = sorting.save(folder=clustering_folder / "sorting") ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_method = params["matching"].get("method", "circus-omp_svd") diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py 
b/src/spikeinterface/sortingcomponents/clustering/circus.py index fea4c57a18..3ac8a9b259 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -51,6 +51,7 @@ class CircusClustering: "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9}, "radius_um": 100, + "neighbors_radius_um": 50, "n_svd": 5, "few_waveforms": None, "ms_before": 0.5, @@ -74,6 +75,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_before = params["ms_before"] ms_after = params["ms_after"] radius_um = params["radius_um"] + neighbors_radius_um = params["neighbors_radius_um"] nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) if params["tmp_folder"] is None: @@ -119,7 +121,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): folder=features_folder, **job_kwargs) - neighbours_mask = get_channel_distances(recording) <= radius_um + neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um if params["debug"]: np.save(features_folder / "sparse_mask.npy", sparse_mask) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 73e9211202..b5d53cae58 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -200,7 +200,7 @@ def get_templates_from_peaks_and_recording( peak_labels, ms_before, ms_after, - operator="mean", + operator="average", **job_kwargs, ): """ @@ -220,6 +220,8 @@ def get_templates_from_peaks_and_recording( The time window before the peak in milliseconds. ms_after : float The time window after the peak in milliseconds. + operator : str + The operator to use for template estimation. Can be 'average' or 'median'. job_kwargs : dict Additional keyword arguments for the estimate_templates function. @@ -229,19 +231,25 @@ def get_templates_from_peaks_and_recording( The estimated templates object. """ from spikeinterface.core.template import Templates + from spikeinterface.core.numpyextractors import NumpySorting mask = peak_labels > -1 - labels = np.unique(peak_labels[mask]) + valid_peaks = peaks[mask] + valid_labels = peak_labels[mask] + labels = np.unique(valid_labels) + fs = recording.get_sampling_frequency() nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) + spikes = NumpySorting.from_peaks(valid_peaks, fs, labels) + spikes = spikes.to_spike_vector() from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( recording, - peaks, - labels, + spikes, + np.arange(len(labels)), nbefore, nafter, operator=operator, @@ -273,7 +281,7 @@ def get_templates_from_peaks_and_svd( svd_model, svd_features, sparsity_mask, - operator="mean", + operator="average", ): """ Get templates from recording using the SVD components @@ -296,6 +304,8 @@ def get_templates_from_peaks_and_svd( The SVD features array. sparsity_mask : numpy.ndarray The sparsity mask array. + operator : str + The operator to use for template estimation. Can be 'average' or 'median'. 
Returns ------- @@ -306,7 +316,10 @@ def get_templates_from_peaks_and_svd( assert operator in ["mean", "median"], "operator should be either 'mean' or 'median'" mask = peak_labels > -1 - labels = np.unique(peak_labels[mask]) + valid_peaks = peaks[mask] + valid_labels = peak_labels[mask] + valid_svd_features = svd_features[mask] + labels = np.unique(valid_labels) fs = recording.get_sampling_frequency() nbefore = int(ms_before * fs / 1000.0) @@ -315,14 +328,14 @@ def get_templates_from_peaks_and_svd( templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) for unit_ind, label in enumerate(labels): - mask = peak_labels == label - local_peaks = peaks[mask] - local_svd = svd_features[mask] + mask = valid_labels == label + local_peaks = valid_peaks[mask] + local_svd = valid_svd_features[mask] peak_channels, b = np.unique(local_peaks["channel_index"], return_counts=True) best_channel = peak_channels[np.argmax(b)] sub_mask = local_peaks["channel_index"] == best_channel for count, i in enumerate(np.flatnonzero(sparsity_mask[best_channel])): - if operator == "mean": + if operator == "average": data = np.mean(local_svd[sub_mask, :, count], 0) elif operator == "median": data = np.median(local_svd[sub_mask, :, count], 0) From 2be228be00bd900bb30d88199c7113c3e2b2ebf0 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 12:59:21 +0200 Subject: [PATCH 08/33] WIP --- .../sorters/internal/spyking_circus2.py | 6 ++---- .../sortingcomponents/clustering/tools.py | 12 ++++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9e8ff240e9..0e12a9ac03 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -24,7 +24,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, "detection": {"method" : "matched_filtering", @@ -282,7 +282,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selected_peaks, method=clustering_method, method_kwargs=clustering_params, - extra_outputs=False, + extra_outputs=True, **job_kwargs ) if len(outputs) == 2: @@ -296,8 +296,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_after, **job_kwargs, ) - elif len(outputs) == 3: - _, peak_labels, templates = outputs elif len(outputs) == 5: _, peak_labels, svd_model, svd_features, sparsity_mask = outputs from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index b5d53cae58..9a5fbd89d5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -334,12 +334,12 @@ def get_templates_from_peaks_and_svd( peak_channels, b = np.unique(local_peaks["channel_index"], return_counts=True) best_channel = peak_channels[np.argmax(b)] sub_mask = local_peaks["channel_index"] == best_channel - for count, i in 
enumerate(np.flatnonzero(sparsity_mask[best_channel])): - if operator == "average": - data = np.mean(local_svd[sub_mask, :, count], 0) - elif operator == "median": - data = np.median(local_svd[sub_mask, :, count], 0) - templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) + channel_indices = np.flatnonzero(sparsity_mask[best_channel]) + if operator == "average": + data = np.mean(local_svd[sub_mask], 0) + elif operator == "median": + data = np.median(local_svd[sub_mask], 0) + templates_array[unit_ind, :, channel_indices] = svd_model.inverse_transform(data.T) templates = Templates( templates_array=templates_array, From bd7c7bea372457fcbdf4c024aff7453c0de73490 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 13:09:47 +0200 Subject: [PATCH 09/33] WIP --- .../sorters/internal/spyking_circus2.py | 4 +++- .../sortingcomponents/clustering/tools.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 0e12a9ac03..9656e15fd3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -237,7 +237,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsity_kwargs["peak_sign"] = peak_sign if params["multi_units_only"]: - sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) + sorting = NumpySorting.from_peaks(peaks, + sampling_frequency, + unit_ids=recording_w.channel_ids) else: ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9a5fbd89d5..421727f396 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -264,7 +264,7 @@ def get_templates_from_peaks_and_recording( nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=labels, + unit_ids=np.arange(len(labels)), probe=recording.get_probe(), is_scaled=False, ) @@ -334,12 +334,12 @@ def get_templates_from_peaks_and_svd( peak_channels, b = np.unique(local_peaks["channel_index"], return_counts=True) best_channel = peak_channels[np.argmax(b)] sub_mask = local_peaks["channel_index"] == best_channel - channel_indices = np.flatnonzero(sparsity_mask[best_channel]) - if operator == "average": - data = np.mean(local_svd[sub_mask], 0) - elif operator == "median": - data = np.median(local_svd[sub_mask], 0) - templates_array[unit_ind, :, channel_indices] = svd_model.inverse_transform(data.T) + for count, i in enumerate(np.flatnonzero(sparsity_mask[best_channel])): + if operator == "average": + data = np.mean(local_svd[sub_mask, :, count], 0) + elif operator == "median": + data = np.median(local_svd[sub_mask, :, count], 0) + templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) templates = Templates( templates_array=templates_array, @@ -347,7 +347,7 @@ def get_templates_from_peaks_and_svd( nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=labels, + unit_ids=np.arange(len(labels)), probe=recording.get_probe(), is_scaled=False, ) From 1cd89f3e1176f92aea6905ea2c80724199bb6426 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 13:50:00 +0200 Subject: [PATCH 10/33] WIP --- 
.../sorters/internal/spyking_circus2.py | 74 +++++++++++-------- .../sortingcomponents/clustering/tools.py | 21 +++--- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9656e15fd3..379f9dbe1d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -8,7 +8,6 @@ from spikeinterface.core import NumpySorting from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.template import Templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, @@ -23,9 +22,18 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, + "general": {"ms_before": 2, + "ms_after": 2, + "radius_um": 100}, + "sparsity": {"method": "snr", + "amplitude_mode": + "peak_to_peak", + "threshold": 0.25}, + "filtering": {"freq_min": 150, + "freq_max": 7000, + "ftype": "bessel", + "filter_order": 2, + "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, "detection": {"method" : "matched_filtering", "method_kwargs" : dict( @@ -46,7 +54,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "matching": {"method": "wobble", "method_kwargs" : dict()}, "apply_preprocessing": True, - "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, + "cache_preprocessing": {"mode": "memory", + "memory_limit": 0.5, + "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, "seed": 42, @@ -236,6 +246,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if "peak_sign" not in sparsity_kwargs: sparsity_kwargs["peak_sign"] = peak_sign + sorting_folder = sorter_output_folder / "sorting" + if sorting_folder.exists(): + shutil.rmtree(sorting_folder) + if params["multi_units_only"]: sorting = NumpySorting.from_peaks(peaks, sampling_frequency, @@ -284,7 +298,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selected_peaks, method=clustering_method, method_kwargs=clustering_params, - extra_outputs=True, + extra_outputs=False, **job_kwargs ) if len(outputs) == 2: @@ -343,38 +357,36 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting["sample_index"] = spikes["sample_index"] sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] + order = np.lexsort((spikes["sample_index"], spikes["segment_index"])) + spikes = spikes[order] sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids) - sorting_folder = sorter_output_folder / "sorting" - if sorting_folder.exists(): - shutil.rmtree(sorting_folder) + merging_params = params["merging"].copy() + merging_params["debug_folder"] = sorter_output_folder / "merging" - merging_params = params["merging"].copy() - merging_params["debug_folder"] = sorter_output_folder / "merging" + if len(merging_params) > 0: + if params["motion_correction"] and motion_folder is not None: + from spikeinterface.preprocessing.motion import 
load_motion_info - if len(merging_params) > 0: - if params["motion_correction"] and motion_folder is not None: - from spikeinterface.preprocessing.motion import load_motion_info + motion_info = load_motion_info(motion_folder) + motion = motion_info["motion"] + max_motion = max( + np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) + ) + max_distance_um = merging_params.get("max_distance_um", 50) + merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) - motion_info = load_motion_info(motion_folder) - motion = motion_info["motion"] - max_motion = max( - np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) - ) - max_distance_um = merging_params.get("max_distance_um", 50) - merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) - - if debug: - curation_folder = sorter_output_folder / "curation" - if curation_folder.exists(): - shutil.rmtree(curation_folder) - sorting.save(folder=curation_folder) - # np.save(fitting_folder / "amplitudes", guessed_amplitudes) + if debug: + curation_folder = sorter_output_folder / "curation" + if curation_folder.exists(): + shutil.rmtree(curation_folder) + sorting.save(folder=curation_folder) + # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs) + sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs) - if verbose: - print(f"Kept {len(sorting.unit_ids)} units after final merging") + if verbose: + print(f"Kept {len(sorting.unit_ids)} units after final merging") folder_to_delete = None cache_mode = params["cache_preprocessing"].get("mode", "memory") diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 421727f396..ab7b0dcb2b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -234,22 +234,25 @@ def get_templates_from_peaks_and_recording( from spikeinterface.core.numpyextractors import NumpySorting mask = peak_labels > -1 - valid_peaks = peaks[mask] - valid_labels = peak_labels[mask] - labels = np.unique(valid_labels) + labels = np.unique(peak_labels[mask]) fs = recording.get_sampling_frequency() nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - spikes = NumpySorting.from_peaks(valid_peaks, fs, labels) - spikes = spikes.to_spike_vector() + + sorting = NumpySorting.from_samples_and_labels( + peaks["sample_index"][mask], + peak_labels[mask], + fs, + unit_ids=labels, + ) from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( recording, - spikes, - np.arange(len(labels)), + sorting.to_spike_vector(), + sorting.unit_ids, nbefore, nafter, operator=operator, @@ -264,7 +267,7 @@ def get_templates_from_peaks_and_recording( nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=np.arange(len(labels)), + unit_ids=labels, probe=recording.get_probe(), is_scaled=False, ) @@ -347,7 +350,7 @@ def get_templates_from_peaks_and_svd( nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=np.arange(len(labels)), + unit_ids=labels, probe=recording.get_probe(), is_scaled=False, ) From c28a7b67411ed63ea05d25fad2e447608709d9ab Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 14:29:12 +0200 Subject: [PATCH 11/33] WIP 
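
Among the changes below, detected peaks are now sorted canonically.
`np.lexsort` treats its *last* key as the primary one, so passing
`(sample_index, segment_index)` orders peaks by segment first, then by sample
within each segment. A minimal sketch with toy values:

    import numpy as np

    sample_index = np.array([50, 10, 30, 20])
    segment_index = np.array([1, 0, 1, 0])

    # The last key (segment_index) is the primary sort key.
    order = np.lexsort((sample_index, segment_index))
    # order -> array([1, 3, 2, 0]): segment 0 peaks (samples 10, 20)
    # come first, then segment 1 peaks (samples 30, 50).

Sorting here should make downstream steps see the same peak order regardless
of how the parallel detection chunks were scheduled.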
--- .../sorters/internal/spyking_circus2.py | 5 +- .../sortingcomponents/clustering/circus.py | 148 +++++++++--------- .../sortingcomponents/clustering/tools.py | 20 +-- 3 files changed, 87 insertions(+), 86 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 379f9dbe1d..ec08cdca7e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -235,6 +235,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): detection_method = "locally_exclusive" peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs) + order = np.lexsort((peaks["sample_index"], peaks["segment_index"])) + peaks = peaks[order] if debug: np.save(clustering_folder / "peaks.npy", peaks) @@ -301,6 +303,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): extra_outputs=False, **job_kwargs ) + if len(outputs) == 2: _, peak_labels = outputs from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording @@ -357,8 +360,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting["sample_index"] = spikes["sample_index"] sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] - order = np.lexsort((spikes["sample_index"], spikes["segment_index"])) - spikes = spikes[order] sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids) merging_params = params["merging"].copy() diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 3ac8a9b259..74c3286561 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -151,78 +151,78 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): **params["recursive_kwargs"], **job_kwargs) - non_noise = peak_labels > -1 - labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) - peak_labels[non_noise] = inverse - labels = np.unique(inverse) - - spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype) - spikes["sample_index"] = peaks[non_noise]["sample_index"] - spikes["segment_index"] = peaks[non_noise]["segment_index"] - spikes["unit_index"] = peak_labels[non_noise] - - unit_ids = labels - - nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) - nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) - - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) - - templates_array = estimate_templates( - recording, - spikes, - unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name=None, - **job_kwargs, - ) - - best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) - peak_snrs = np.abs(templates_array[:, nbefore, :]) - best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] - valid_templates = best_snrs_ratio > params["noise_threshold"] - - templates = Templates( - templates_array=templates_array[valid_templates], - sampling_frequency=fs, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=unit_ids[valid_templates], - probe=recording.get_probe(), - is_scaled=False, - ) - - if params["debug"]: - templates_folder = tmp_folder / "dense_templates" - templates.to_zarr(folder_path=templates_folder) - - 
sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) - templates = templates.to_sparse(sparsity) - empty_templates = templates.sparsity_mask.sum(axis=1) == 0 - templates = remove_empty_templates(templates) - - mask = np.isin(peak_labels, np.where(empty_templates)[0]) - peak_labels[mask] = -1 - - mask = np.isin(peak_labels, np.where(~valid_templates)[0]) - peak_labels[mask] = -1 - - if verbose: - print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - - cleaning_job_kwargs = job_kwargs.copy() - cleaning_job_kwargs["progress_bar"] = False - cleaning_params = params["cleaning_kwargs"].copy() - - labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params - ) - - if verbose: - print("Kept %d non-duplicated clusters" % len(labels)) - + # non_noise = peak_labels > -1 + # labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) + # peak_labels[non_noise] = inverse + # labels = np.unique(inverse) + + # spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype) + # spikes["sample_index"] = peaks[non_noise]["sample_index"] + # spikes["segment_index"] = peaks[non_noise]["segment_index"] + # spikes["unit_index"] = peak_labels[non_noise] + + # unit_ids = labels + + # nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) + # nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + + # if params["noise_levels"] is None: + # params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + + # templates_array = estimate_templates( + # recording, + # spikes, + # unit_ids, + # nbefore, + # nafter, + # return_scaled=False, + # job_name=None, + # **job_kwargs, + # ) + + # best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + # peak_snrs = np.abs(templates_array[:, nbefore, :]) + # best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + # valid_templates = best_snrs_ratio > params["noise_threshold"] + + # templates = Templates( + # templates_array=templates_array[valid_templates], + # sampling_frequency=fs, + # nbefore=nbefore, + # sparsity_mask=None, + # channel_ids=recording.channel_ids, + # unit_ids=unit_ids[valid_templates], + # probe=recording.get_probe(), + # is_scaled=False, + # ) + + # if params["debug"]: + # templates_folder = tmp_folder / "dense_templates" + # templates.to_zarr(folder_path=templates_folder) + + # sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) + # templates = templates.to_sparse(sparsity) + # empty_templates = templates.sparsity_mask.sum(axis=1) == 0 + # templates = remove_empty_templates(templates) + + # mask = np.isin(peak_labels, np.where(empty_templates)[0]) + # peak_labels[mask] = -1 + + # mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + # peak_labels[mask] = -1 + + # if verbose: + # print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + + # cleaning_job_kwargs = job_kwargs.copy() + # cleaning_job_kwargs["progress_bar"] = False + # cleaning_params = params["cleaning_kwargs"].copy() + + # labels, peak_labels = remove_duplicates_via_matching( + # templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params + # ) + + # if verbose: + # print("Kept %d non-duplicated clusters" % len(labels)) + labels = np.unique(peak_labels) return labels, peak_labels, svd_model, peaks_svd, sparse_mask diff --git 
a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index ab7b0dcb2b..28f8c023cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -231,28 +231,28 @@ def get_templates_from_peaks_and_recording( The estimated templates object. """ from spikeinterface.core.template import Templates - from spikeinterface.core.numpyextractors import NumpySorting + from spikeinterface.core.basesorting import minimum_spike_dtype mask = peak_labels > -1 - labels = np.unique(peak_labels[mask]) + valid_peaks = peaks[mask] + valid_labels = peak_labels[mask] + labels, indices = np.unique(valid_labels, return_inverse=True) fs = recording.get_sampling_frequency() nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - sorting = NumpySorting.from_samples_and_labels( - peaks["sample_index"][mask], - peak_labels[mask], - fs, - unit_ids=labels, - ) + sorting = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype) + sorting["sample_index"] = valid_peaks["sample_index"] + sorting["unit_index"] = indices + sorting["segment_index"] = valid_peaks["segment_index"] from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( recording, - sorting.to_spike_vector(), - sorting.unit_ids, + sorting, + np.arange(len(labels)), nbefore, nafter, operator=operator, From 8e455f98ecaf453d4dba779a1301282e8a04e4fc Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 14:42:17 +0200 Subject: [PATCH 12/33] WIP --- .../sorters/internal/spyking_circus2.py | 7 +- .../sortingcomponents/clustering/circus.py | 148 +++++++++--------- 2 files changed, 78 insertions(+), 77 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ec08cdca7e..13666b50c7 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -34,7 +34,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, - "whitening": {"mode": "local", "regularize": False}, + "whitening": {"mode": "local", + "regularize": False}, "detection": {"method" : "matched_filtering", "method_kwargs" : dict( peak_sign="neg", @@ -49,9 +50,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "circus", + "clustering": {"method": "graph_clustering", "method_kwargs" : dict()}, - "matching": {"method": "wobble", + "matching": {"method": "circus-omp-svd", "method_kwargs" : dict()}, "apply_preprocessing": True, "cache_preprocessing": {"mode": "memory", diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 74c3286561..3ac8a9b259 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -151,78 +151,78 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): **params["recursive_kwargs"], **job_kwargs) - # non_noise = peak_labels > -1 - # labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) - # peak_labels[non_noise] = inverse - # labels = np.unique(inverse) - - # spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype) - # 
spikes["sample_index"] = peaks[non_noise]["sample_index"] - # spikes["segment_index"] = peaks[non_noise]["segment_index"] - # spikes["unit_index"] = peak_labels[non_noise] - - # unit_ids = labels - - # nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) - # nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) - - # if params["noise_levels"] is None: - # params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) - - # templates_array = estimate_templates( - # recording, - # spikes, - # unit_ids, - # nbefore, - # nafter, - # return_scaled=False, - # job_name=None, - # **job_kwargs, - # ) - - # best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) - # peak_snrs = np.abs(templates_array[:, nbefore, :]) - # best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] - # valid_templates = best_snrs_ratio > params["noise_threshold"] - - # templates = Templates( - # templates_array=templates_array[valid_templates], - # sampling_frequency=fs, - # nbefore=nbefore, - # sparsity_mask=None, - # channel_ids=recording.channel_ids, - # unit_ids=unit_ids[valid_templates], - # probe=recording.get_probe(), - # is_scaled=False, - # ) - - # if params["debug"]: - # templates_folder = tmp_folder / "dense_templates" - # templates.to_zarr(folder_path=templates_folder) - - # sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) - # templates = templates.to_sparse(sparsity) - # empty_templates = templates.sparsity_mask.sum(axis=1) == 0 - # templates = remove_empty_templates(templates) - - # mask = np.isin(peak_labels, np.where(empty_templates)[0]) - # peak_labels[mask] = -1 - - # mask = np.isin(peak_labels, np.where(~valid_templates)[0]) - # peak_labels[mask] = -1 - - # if verbose: - # print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - - # cleaning_job_kwargs = job_kwargs.copy() - # cleaning_job_kwargs["progress_bar"] = False - # cleaning_params = params["cleaning_kwargs"].copy() - - # labels, peak_labels = remove_duplicates_via_matching( - # templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params - # ) - - # if verbose: - # print("Kept %d non-duplicated clusters" % len(labels)) - labels = np.unique(peak_labels) + non_noise = peak_labels > -1 + labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) + peak_labels[non_noise] = inverse + labels = np.unique(inverse) + + spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype) + spikes["sample_index"] = peaks[non_noise]["sample_index"] + spikes["segment_index"] = peaks[non_noise]["segment_index"] + spikes["unit_index"] = peak_labels[non_noise] + + unit_ids = labels + + nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) + nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + + templates_array = estimate_templates( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, + ) + + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + + templates = Templates( + templates_array=templates_array[valid_templates], + 
sampling_frequency=fs, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=unit_ids[valid_templates], + probe=recording.get_probe(), + is_scaled=False, + ) + + if params["debug"]: + templates_folder = tmp_folder / "dense_templates" + templates.to_zarr(folder_path=templates_folder) + + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) + templates = templates.to_sparse(sparsity) + empty_templates = templates.sparsity_mask.sum(axis=1) == 0 + templates = remove_empty_templates(templates) + + mask = np.isin(peak_labels, np.where(empty_templates)[0]) + peak_labels[mask] = -1 + + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + + if verbose: + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False + cleaning_params = params["cleaning_kwargs"].copy() + + labels, peak_labels = remove_duplicates_via_matching( + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params + ) + + if verbose: + print("Kept %d non-duplicated clusters" % len(labels)) + return labels, peak_labels, svd_model, peaks_svd, sparse_mask From 6de831051d171671eabd38f6de754a42a5519190 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 15:29:53 +0200 Subject: [PATCH 13/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 8 ++++---- .../sortingcomponents/clustering/graph_clustering.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 13666b50c7..4447e8ce62 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", - "threshold": 0.25}, + "threshold": 0}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", @@ -52,7 +52,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": {"max_distance_um": 50}, "clustering": {"method": "graph_clustering", "method_kwargs" : dict()}, - "matching": {"method": "circus-omp-svd", + "matching": {"method": "wobble", "method_kwargs" : dict()}, "apply_preprocessing": True, "cache_preprocessing": {"mode": "memory", @@ -289,8 +289,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): "clustering_method": "hdbscan", "radius_um" : radius_um, "clustering_kwargs" : dict(min_samples=1, - n_jobs=-1, min_cluster_size=50, + core_dist_n_jobs=-1, cluster_selection_method='leaf', allow_single_cluster=True, cluster_selection_epsilon=0.1) @@ -301,7 +301,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selected_peaks, method=clustering_method, method_kwargs=clustering_params, - extra_outputs=False, + extra_outputs=True, **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py index 95016318f6..3de3cb5a43 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py @@ -102,7 +102,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # print(distances.shape) # print("sparsity: ", distances.indices.size / 
(distances.shape[0]**2)) - print("clustering_method", clustering_method) + #print("clustering_method", clustering_method) if clustering_method == "networkx-louvain": # using networkx : very slow (possible backend with cude backend="cugraph",) From 3fb5fa6fa8e03561cf0f259dd1919ab934d15f41 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 15:39:36 +0200 Subject: [PATCH 14/33] Cosmetic --- src/spikeinterface/sortingcomponents/clustering/graph_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py index 43f9abe141..409181bcf3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py @@ -118,7 +118,7 @@ def create_graph_from_peak_features( raise ValueError("create_graph_from_peak_features : wrong bin_mode") if progress_bar: - loop = tqdm(loop, desc=f"Construct distance graph looping over {bin_mode}") + loop = tqdm(loop, desc=f"Build distance graph over {bin_mode}") local_graphs = [] row_indices = [] From 5eec5e3fbc4f87957120eadb562fbd6886068e2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 13:46:50 +0000 Subject: [PATCH 15/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 87 ++++++++----------- .../sortingcomponents/clustering/circus.py | 27 +++--- .../clustering/graph_clustering.py | 2 +- .../sortingcomponents/clustering/tools.py | 20 ++--- 4 files changed, 61 insertions(+), 75 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4447e8ce62..4940ec20c2 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -22,42 +22,22 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, - "ms_after": 2, - "radius_um": 100}, - "sparsity": {"method": "snr", - "amplitude_mode": - "peak_to_peak", - "threshold": 0}, - "filtering": {"freq_min": 150, - "freq_max": 7000, - "ftype": "bessel", - "filter_order": 2, - "margin_ms": 10}, - "whitening": {"mode": "local", - "regularize": False}, - "detection": {"method" : "matched_filtering", - "method_kwargs" : dict( - peak_sign="neg", - detect_threshold=5 - )}, - "selection": {"method": "uniform", - "method_kwargs" : dict( - n_peaks_per_channel=5000, - min_n_peaks=100000, - select_per_channel=False) - }, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, + "whitening": {"mode": "local", "regularize": False}, + "detection": {"method": "matched_filtering", "method_kwargs": dict(peak_sign="neg", detect_threshold=5)}, + "selection": { + "method": "uniform", + "method_kwargs": dict(n_peaks_per_channel=5000, min_n_peaks=100000, select_per_channel=False), + }, "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "graph_clustering", - "method_kwargs" : dict()}, - "matching": {"method": "wobble", - "method_kwargs" : 
dict()}, + "clustering": {"method": "graph_clustering", "method_kwargs": dict()}, + "matching": {"method": "wobble", "method_kwargs": dict()}, "apply_preprocessing": True, - "cache_preprocessing": {"mode": "memory", - "memory_limit": 0.5, - "delete_cache": True}, + "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, "seed": 42, @@ -234,7 +214,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w, seed=seed, **job_kwargs ) detection_method = "locally_exclusive" - + peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs) order = np.lexsort((peaks["sample_index"], peaks["segment_index"])) peaks = peaks[order] @@ -254,9 +234,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): shutil.rmtree(sorting_folder) if params["multi_units_only"]: - sorting = NumpySorting.from_peaks(peaks, - sampling_frequency, - unit_ids=recording_w.channel_ids) + sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.channel_ids) else: ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible @@ -284,30 +262,34 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["debug"] = debug clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) elif clustering_method == "graph_clustering": - clustering_params = {"ms_before" : ms_before, - "ms_after" : ms_after, - "clustering_method": "hdbscan", - "radius_um" : radius_um, - "clustering_kwargs" : dict(min_samples=1, - min_cluster_size=50, - core_dist_n_jobs=-1, - cluster_selection_method='leaf', - allow_single_cluster=True, - cluster_selection_epsilon=0.1) + clustering_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "clustering_method": "hdbscan", + "radius_um": radius_um, + "clustering_kwargs": dict( + min_samples=1, + min_cluster_size=50, + core_dist_n_jobs=-1, + cluster_selection_method="leaf", + allow_single_cluster=True, + cluster_selection_epsilon=0.1, + ), } outputs = find_cluster_from_peaks( - recording_w, - selected_peaks, + recording_w, + selected_peaks, method=clustering_method, - method_kwargs=clustering_params, + method_kwargs=clustering_params, extra_outputs=True, - **job_kwargs + **job_kwargs, ) - + if len(outputs) == 2: _, peak_labels = outputs from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording + templates = get_templates_from_peaks_and_recording( recording_w, peaks, @@ -319,6 +301,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): elif len(outputs) == 5: _, peak_labels, svd_model, svd_features, sparsity_mask = outputs from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd + templates = get_templates_from_peaks_and_svd( recording_w, peaks, @@ -328,7 +311,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): svd_model, svd_features, sparsity_mask, - operator="median" + operator="median", ) sparsity = compute_sparsity(templates, noise_levels, **sparsity_kwargs) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 3ac8a9b259..3023e1b143 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -48,8 +48,7 @@ class CircusClustering: "recursive_depth": 3, 
"returns_split_count": True, }, - "split_kwargs": {"projection_mode": "tsvd", - "n_pca_features": 0.9}, + "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9}, "radius_um": 100, "neighbors_radius_um": 50, "n_svd": 5, @@ -107,19 +106,22 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): wfs = wfs[valid] from sklearn.decomposition import TruncatedSVD + svd_model = TruncatedSVD(params["n_svd"]) svd_model.fit(wfs) features_folder = tmp_folder / "tsvd_features" features_folder.mkdir(exist_ok=True) - peaks_svd, sparse_mask, svd_model = extract_peaks_svd(recording, - peaks, - ms_before=ms_before, - ms_after=ms_after, - svd_model=svd_model, - radius_um=radius_um, - folder=features_folder, - **job_kwargs) + peaks_svd, sparse_mask, svd_model = extract_peaks_svd( + recording, + peaks, + ms_before=ms_before, + ms_after=ms_after, + svd_model=svd_model, + radius_um=radius_um, + folder=features_folder, + **job_kwargs, + ) neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um @@ -144,12 +146,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peak_labels, _ = split_clusters( original_labels, recording, - {"peaks" : peaks, "sparse_tsvd" : peaks_svd}, + {"peaks": peaks, "sparse_tsvd": peaks_svd}, method="local_feature_clustering", method_kwargs=split_kwargs, debug_folder=debug_folder, **params["recursive_kwargs"], - **job_kwargs) + **job_kwargs, + ) non_noise = peak_labels > -1 labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py index 3de3cb5a43..a0034e7741 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py @@ -102,7 +102,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # print(distances.shape) # print("sparsity: ", distances.indices.size / (distances.shape[0]**2)) - #print("clustering_method", clustering_method) + # print("clustering_method", clustering_method) if clustering_method == "networkx-louvain": # using networkx : very slow (possible backend with cude backend="cugraph",) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 28f8c023cb..037346980e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -237,11 +237,11 @@ def get_templates_from_peaks_and_recording( valid_peaks = peaks[mask] valid_labels = peak_labels[mask] labels, indices = np.unique(valid_labels, return_inverse=True) - + fs = recording.get_sampling_frequency() nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - + sorting = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype) sorting["sample_index"] = valid_peaks["sample_index"] sorting["unit_index"] = indices @@ -250,15 +250,15 @@ def get_templates_from_peaks_and_recording( from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( - recording, - sorting, - np.arange(len(labels)), - nbefore, - nafter, + recording, + sorting, + np.arange(len(labels)), + nbefore, + nafter, operator=operator, - return_scaled=False, - job_name=None, - **job_kwargs + return_scaled=False, + job_name=None, + **job_kwargs, ) templates = Templates( From 74f52bd29950e4422884e77a644e437a56c31b27 Mon Sep 17 
00:00:00 2001 From: Sebastien Date: Tue, 8 Apr 2025 16:02:04 +0200 Subject: [PATCH 16/33] Patch --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4447e8ce62..df6311f5a3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -171,6 +171,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): whitening_kwargs["regularize"] = False if whitening_kwargs["regularize"]: whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} + whitening_kwargs["apply_mean"] = True recording_w = whiten(recording_f, **whitening_kwargs) From ff104426a161b8c79ee03a863a216d82d57a2584 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 09:34:50 +0200 Subject: [PATCH 17/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6b7833a85a..7efd27cb40 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -37,6 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "clustering": {"method": "graph_clustering", "method_kwargs": dict()}, "matching": {"method": "wobble", "method_kwargs": dict()}, "apply_preprocessing": True, + "templates_from_svd": False, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, @@ -61,6 +62,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ median reference + whitening", "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not", + "templates_from_svd": "Boolean to specify whether templates should be computed from SVD or not.", "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)", "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \ memory_limit will control how much RAM can be used. 
In case of folder or zarr, delete_cache controls if cache is cleaned after sorting", @@ -107,6 +109,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_after = params["general"].get("ms_after", 2) radius_um = params["general"].get("radius_um", 75) peak_sign = params["detection"].get("peak_sign", "neg") + templates_from_svd = params["templates_from_svd"] debug = params["debug"] seed = params["seed"] apply_preprocessing = params["apply_preprocessing"] @@ -283,7 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selected_peaks, method=clustering_method, method_kwargs=clustering_params, - extra_outputs=True, + extra_outputs=templates_from_svd, **job_kwargs, ) From d0333ddcde375cbce7183a29072368bf398b9e3b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 10:16:15 +0200 Subject: [PATCH 18/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 10 +++++----- .../sortingcomponents/matching/circus.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 7efd27cb40..aea6109da4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -23,7 +23,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, "detection": {"method": "matched_filtering", "method_kwargs": dict(peak_sign="neg", detect_threshold=5)}, @@ -34,10 +34,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "graph_clustering", "method_kwargs": dict()}, - "matching": {"method": "wobble", "method_kwargs": dict()}, + "clustering": {"method": "circus", "method_kwargs": dict()}, + "matching": {"method": "circus-omp-svd", "method_kwargs": dict()}, "apply_preprocessing": True, - "templates_from_svd": False, + "templates_from_svd": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, @@ -79,7 +79,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2.0" + return "2.0rc" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 3b97f2dc6a..64a6b2333d 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -31,7 +31,7 @@ def compress_templates( - templates_array, approx_rank, remove_mean=True, return_new_templates=True + templates_array, approx_rank, remove_mean=False, return_new_templates=True ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]: """Compress templates using singular value decomposition. 
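
Side note on the `remove_mean` default flipped above: `compress_templates` is
documented as SVD compression of the templates. A minimal sketch of
rank-limited SVD on a (num_units, num_samples, num_channels) templates array,
as an illustration of the technique rather than the exact implementation (all
inputs below are toy values):

    import numpy as np

    rng = np.random.default_rng(42)
    templates_array = rng.standard_normal((8, 60, 16)).astype("float32")
    approx_rank = 5

    # Batched SVD over each unit's (num_samples, num_channels) matrix,
    # truncated to the first approx_rank components.
    temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
    temporal = temporal[:, :, :approx_rank]
    singular = singular[:, :approx_rank]
    spatial = spatial[:, :approx_rank, :]

    # Low-rank reconstruction of the templates from the kept components.
    approx = np.einsum("utr,ur,urc->utc", temporal, singular, spatial)

With `remove_mean=False`, the decomposition presumably runs on the raw
templates rather than mean-subtracted ones, so a reconstruction like the one
above needs no mean added back.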
From 220fe0368078482d5194c343bd5adb26d81ab344 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 10:21:32 +0200 Subject: [PATCH 19/33] Fix --- src/spikeinterface/sortingcomponents/clustering/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 037346980e..30e0ef1462 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -317,7 +317,7 @@ def get_templates_from_peaks_and_svd( """ from spikeinterface.core.template import Templates - assert operator in ["mean", "median"], "operator should be either 'mean' or 'median'" + assert operator in ["average", "median"], "operator should be either 'average' or 'median'" mask = peak_labels > -1 valid_peaks = peaks[mask] valid_labels = peak_labels[mask] From 97e8c67d93d5991c552583b314f0d2dfa5fd67aa Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 13:59:27 +0200 Subject: [PATCH 20/33] WIP --- .../sorters/internal/spyking_circus2.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index aea6109da4..d197e70fa2 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -296,7 +296,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = get_templates_from_peaks_and_recording( recording_w, - peaks, + selected_peaks, peak_labels, ms_before, ms_after, @@ -308,7 +308,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = get_templates_from_peaks_and_svd( recording_w, - peaks, + selected_peaks, peak_labels, ms_before, ms_after, @@ -349,6 +349,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids) + else: + ## we should have a case to deal with clustering all peaks without matching + ## for small density channel counts + sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype) + sorting["sample_index"] = selected_peaks["sample_index"] + sorting["unit_index"] = peak_labels + sorting["segment_index"] = selected_peaks["segment_index"] + sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids) merging_params = params["merging"].copy() merging_params["debug_folder"] = sorter_output_folder / "merging" From 2cb39ebd0dc6e25ff67da3560765d7435603ed90 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 15:28:28 +0200 Subject: [PATCH 21/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 3 ++- .../sortingcomponents/clustering/tools.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index d197e70fa2..ca29651aab 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -79,7 +79,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2.0rc" + return "2.1" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -352,6 +352,7 @@ def 
_run_from_folder(cls, sorter_output_folder, params, verbose): else: ## we should have a case to deal with clustering all peaks without matching ## for small density channel counts + sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype) sorting["sample_index"] = selected_peaks["sample_index"] sorting["unit_index"] = peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 30e0ef1462..20b8f2c8de 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -242,16 +242,16 @@ def get_templates_from_peaks_and_recording( nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - sorting = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype) - sorting["sample_index"] = valid_peaks["sample_index"] - sorting["unit_index"] = indices - sorting["segment_index"] = valid_peaks["segment_index"] + spikes = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype) + spikes["sample_index"] = valid_peaks["sample_index"] + spikes["unit_index"] = indices + spikes["segment_index"] = valid_peaks["segment_index"] from spikeinterface.core.waveform_tools import estimate_templates templates_array = estimate_templates( recording, - sorting, + spikes, np.arange(len(labels)), nbefore, nafter, From c502f4d32145a75382748055063a9f8e759fa6c7 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 9 Apr 2025 16:40:12 +0200 Subject: [PATCH 22/33] WIP --- .../sorters/internal/spyking_circus2.py | 1 + .../sortingcomponents/clustering/circus.py | 56 ++++++++++--------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ca29651aab..f038fe8b93 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -262,6 +262,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["ms_before"] = ms_before clustering_params["ms_after"] = ms_after clustering_params["verbose"] = verbose + clustering_params["templates_from_svd"] = templates_from_svd clustering_params["tmp_folder"] = sorter_output_folder / "clustering" clustering_params["debug"] = debug clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 3023e1b143..b9225ec434 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -57,6 +57,7 @@ class CircusClustering: "ms_after": 0.5, "noise_threshold": 4, "rank": 5, + "templates_from_svd": False, "noise_levels": None, "tmp_folder": None, "verbose": True, @@ -154,47 +155,48 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): **job_kwargs, ) - non_noise = peak_labels > -1 - labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True) - peak_labels[non_noise] = inverse - labels = np.unique(inverse) - - spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype) - spikes["sample_index"] = peaks[non_noise]["sample_index"] - spikes["segment_index"] = peaks[non_noise]["segment_index"] - spikes["unit_index"] = peak_labels[non_noise] - - unit_ids = labels - - nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) - nafter = 
int(params["waveforms"]["ms_after"] * fs / 1000.0) - if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) - templates_array = estimate_templates( - recording, - spikes, - unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name=None, - **job_kwargs, - ) + if not params["templates_from_svd"]: + from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording + templates = get_templates_from_peaks_and_recording( + recording, + peaks, + peak_labels, + ms_before, + ms_after, + **job_kwargs, + ) + else: + from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd + templates = get_templates_from_peaks_and_svd( + recording, + peaks, + peak_labels, + ms_before, + ms_after, + svd_model, + peaks_svd, + sparse_mask, + operator="median", + ) + + templates_array = templates.templates_array best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) peak_snrs = np.abs(templates_array[:, nbefore, :]) best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] valid_templates = best_snrs_ratio > params["noise_threshold"] + from spikeinterface.core.template import Templates templates = Templates( templates_array=templates_array[valid_templates], sampling_frequency=fs, - nbefore=nbefore, + nbefore=templates.nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids[valid_templates], + unit_ids=templates.unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) From cc4fa8fbedef0f08cfdd414381ebbac53a2ebcbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:41:34 +0000 Subject: [PATCH 23/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index b9225ec434..92e4907837 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -171,6 +171,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ) else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd + templates = get_templates_from_peaks_and_svd( recording, peaks, @@ -182,7 +183,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): sparse_mask, operator="median", ) - + templates_array = templates.templates_array best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) peak_snrs = np.abs(templates_array[:, nbefore, :]) @@ -190,6 +191,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): valid_templates = best_snrs_ratio > params["noise_threshold"] from spikeinterface.core.template import Templates + templates = Templates( templates_array=templates_array[valid_templates], sampling_frequency=fs, From 3ab28529812927d2664902729c2609a8a9aac777 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 10 Apr 2025 08:54:49 +0200 Subject: [PATCH 24/33] WIP --- .../sorters/internal/spyking_circus2.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py 
b/src/spikeinterface/sorters/internal/spyking_circus2.py index caf8f08858..5baa56ed22 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -161,7 +161,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) + noise_levels = get_noise_levels(recording_w, + return_scaled=False, + seed=seed, + **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -193,9 +196,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_folder.mkdir(parents=True, exist_ok=True) np.save(clustering_folder / "noise_levels.npy", noise_levels) - detection_params["recording_slices"] = get_shuffled_recording_slices( - recording_w, seed=params["seed"], **job_kwargs - ) detection_params['random_chunk_kwargs'] = {"num_chunks_per_segment": 5, "seed" : params["seed"]} @@ -228,25 +228,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **detection_params, **job_kwargs, ) - detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before if debug: np.save(clustering_folder / "waveforms.npy", waveforms) np.save(clustering_folder / "prototype.npy", prototype) - if skip_peaks: - detection_params["skip_after_n_peaks"] = n_peaks - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: waveforms = None - if skip_peaks: - detection_params["skip_after_n_peaks"] = n_peaks detection_method = "locally_exclusive" + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs) order = np.lexsort((peaks["sample_index"], peaks["segment_index"])) peaks = peaks[order] + if debug: np.save(clustering_folder / "peaks.npy", peaks) From 2302142d8a5e104ab80b4fbb51cb900db84d7977 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 10 Apr 2025 08:56:00 +0200 Subject: [PATCH 25/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 5baa56ed22..1bcf8347b6 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -69,7 +69,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", "seed": "An int to control how chunks are shuffled while detecting peaks", - "deterministic": "a boolean to specify if the sorting should be deterministic or not. If True, then the seed will be used to shuffle the chunks", + "deterministic": "A boolean to specify if the sorting should be deterministic or not. 
If True, then the seed will be used to shuffle the chunks", "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } From a4807b333df2d5419bbfc50b1ae00a897d57ad20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 07:01:01 +0000 Subject: [PATCH 26/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 15 +++++-------- .../sortingcomponents/clustering/circus.py | 22 +++++++++---------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 1bcf8347b6..7a67d5f9c4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -161,10 +161,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, - return_scaled=False, - seed=seed, - **job_kwargs) + noise_levels = get_noise_levels(recording_w, return_scaled=False, seed=seed, **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -196,14 +193,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_folder.mkdir(parents=True, exist_ok=True) np.save(clustering_folder / "noise_levels.npy", noise_levels) - detection_params['random_chunk_kwargs'] = {"num_chunks_per_segment": 5, - "seed" : params["seed"]} + detection_params["random_chunk_kwargs"] = {"num_chunks_per_segment": 5, "seed": params["seed"]} if detection_method == "matched_filtering": if not deterministic: from spikeinterface.sortingcomponents.tools import ( get_prototype_and_waveforms_from_recording, ) + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( recording_w, n_peaks=10000, @@ -217,6 +214,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.tools import ( get_prototype_and_waveforms_from_peaks, ) + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) prototype, waveforms, _ = get_prototype_and_waveforms_from_peaks( recording_w, @@ -233,21 +231,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if debug: np.save(clustering_folder / "waveforms.npy", waveforms) np.save(clustering_folder / "prototype.npy", prototype) - + else: waveforms = None detection_method = "locally_exclusive" if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks - + detection_params["recording_slices"] = get_shuffled_recording_slices( recording_w, seed=params["seed"], **job_kwargs ) peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs) order = np.lexsort((peaks["sample_index"], peaks["segment_index"])) peaks = peaks[order] - if debug: np.save(clustering_folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 5188b8bc2d..56687c133a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -55,7 +55,7 @@ class CircusClustering: "few_waveforms": None, "ms_before": 0.5, "ms_after": 0.5, - "seed" : 42, + 
"seed": 42, "noise_threshold": 4, "rank": 5, "templates_from_svd": False, @@ -90,19 +90,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # SVD for time compression if params["few_waveforms"] is None: few_peaks = select_peaks( - peaks, - recording=recording, - method="uniform", + peaks, + recording=recording, + method="uniform", seed=params["seed"], - n_peaks=10000, - margin=(nbefore, nafter) + n_peaks=10000, + margin=(nbefore, nafter), ) few_wfs = extract_waveform_at_max_channel( - recording, - few_peaks, - ms_before=ms_before, - ms_after=ms_after, - **job_kwargs + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) wfs = few_wfs[:, :, 0] else: @@ -181,6 +177,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ) else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd + templates = get_templates_from_peaks_and_svd( recording, peaks, @@ -192,7 +189,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): sparse_mask, operator="median", ) - + templates_array = templates.templates_array best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) peak_snrs = np.abs(templates_array[:, nbefore, :]) @@ -200,6 +197,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): valid_templates = best_snrs_ratio > params["noise_threshold"] from spikeinterface.core.template import Templates + templates = Templates( templates_array=templates_array[valid_templates], sampling_frequency=fs, From 571a5626e66cd76fa491f87b1411bf9f3ce3811d Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 10 Apr 2025 13:32:22 +0200 Subject: [PATCH 27/33] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 7a67d5f9c4..be63955013 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -41,7 +41,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, "seed": 42, - "deterministic": True, + "deterministic": False, "debug": False, } @@ -143,6 +143,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" params["motion_correction"].update({"folder": motion_folder}) + noise_levels = get_noise_levels(recording_f, return_scaled=False, + random_slices_kwargs={"seed" : seed}, **job_kwargs) + params["detect_kwargs"] = {"noise_levels" : noise_levels} recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs) else: motion_folder = None @@ -161,7 +164,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, return_scaled=False, seed=seed, **job_kwargs) + noise_levels = get_noise_levels(recording_w, return_scaled=False, + random_slices_kwargs={"seed" : seed}, **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) From 136bb5829d557d46ac59be55fb7e66ab93ba8462 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 11:35:09 
From 136bb5829d557d46ac59be55fb7e66ab93ba8462 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 10 Apr 2025 11:35:09 +0000
Subject: [PATCH 28/33] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sorters/internal/spyking_circus2.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index be63955013..41641e1e7a 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -143,9 +143,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 print("Motion correction activated (probe geometry compatible)")
             motion_folder = sorter_output_folder / "motion"
             params["motion_correction"].update({"folder": motion_folder})
-            noise_levels = get_noise_levels(recording_f, return_scaled=False,
-                                            random_slices_kwargs={"seed" : seed}, **job_kwargs)
-            params["detect_kwargs"] = {"noise_levels" : noise_levels}
+            noise_levels = get_noise_levels(
+                recording_f, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+            )
+            params["detect_kwargs"] = {"noise_levels": noise_levels}
             recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
         else:
             motion_folder = None
@@ -164,8 +165,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         recording_w = whiten(recording_f, **whitening_kwargs)

-        noise_levels = get_noise_levels(recording_w, return_scaled=False,
-                                        random_slices_kwargs={"seed" : seed}, **job_kwargs)
+        noise_levels = get_noise_levels(
+            recording_w, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+        )

         if recording_w.check_serializability("json"):
             recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
"noise_threshold": 4, "rank": 5, "templates_from_svd": False, @@ -114,7 +114,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.decomposition import TruncatedSVD - svd_model = TruncatedSVD(params["n_svd"]) + svd_model = TruncatedSVD(params["n_svd"], random_state=params["seed"]) svd_model.fit(wfs) features_folder = tmp_folder / "tsvd_features" features_folder.mkdir(exist_ok=True) @@ -127,6 +127,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): svd_model=svd_model, radius_um=radius_um, folder=features_folder, + seed=params["seed"], **job_kwargs, ) @@ -142,6 +143,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): split_kwargs = params["split_kwargs"].copy() split_kwargs["neighbours_mask"] = neighbours_mask split_kwargs["waveforms_sparse_mask"] = sparse_mask + split_kwargs["seed"] = params["seed"] split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50) split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"] diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py index a0034e7741..39c6ad9c3f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py @@ -80,6 +80,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): radius_um=radius_um, motion_aware=motion_aware, motion=None, + seed=params["seed"], **params["extract_peaks_svd_kwargs"], # **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py index 409181bcf3..7cfebd3526 100644 --- a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py @@ -20,6 +20,7 @@ def create_graph_from_peak_features( sparse_mode="knn", apply_local_svd=False, n_components=10, + seed=None, normed_distances=False, n_neighbors=20, ensure_symetric=False, @@ -141,12 +142,12 @@ def create_graph_from_peak_features( if apply_local_svd: if isinstance(n_components, int): n_components = min(n_components, flatten_feat.shape[1]) - tsvd = TruncatedSVD(n_components) + tsvd = TruncatedSVD(n_components, random_state=seed) flatten_feat = tsvd.fit_transform(flatten_feat) elif isinstance(n_components, float): assert 0 < n_components < 1, "n_components should be in ]0, 1[" - tsvd = TruncatedSVD(flatten_feat.shape[1]) + tsvd = TruncatedSVD(flatten_feat.shape[1], random_state=seed) flatten_feat = tsvd.fit_transform(flatten_feat) n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_components) + 1 flatten_feat = flatten_feat[:, :n_explain] diff --git a/src/spikeinterface/sortingcomponents/clustering/peak_svd.py b/src/spikeinterface/sortingcomponents/clustering/peak_svd.py index e58d2621c7..f02855dd1f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/peak_svd.py +++ b/src/spikeinterface/sortingcomponents/clustering/peak_svd.py @@ -26,6 +26,7 @@ def extract_peaks_svd( motion_aware=False, motion=None, folder=None, + seed=None, ensure_peak_same_sign=True, **job_kwargs, ): @@ -67,7 +68,7 @@ def extract_peaks_svd( if ensure_peak_same_sign: wfs *= -np.sign(wfs[:, nbefore])[:, np.newaxis] - svd_model = TruncatedSVD(n_components=n_components) + svd_model = TruncatedSVD(n_components=n_components, random_state=seed) svd_model.fit(wfs) need_save_model = True else: diff --git 
From 2f1674b97f1967e0303cbed1aae14b4d79718e42 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 10 Apr 2025 16:54:01 +0200
Subject: [PATCH 30/33] Make remove_mixtures optional

---
 .../sortingcomponents/clustering/circus.py | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 92e4907837..cde61b87f0 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -41,6 +41,7 @@ class CircusClustering:
             "allow_single_cluster": True,
         },
         "cleaning_kwargs": {},
+        "remove_mixtures": False,
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
@@ -184,12 +185,18 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 operator="median",
             )
+
         templates_array = templates.templates_array
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
         peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        old_unit_ids = templates.unit_ids.copy()
         valid_templates = best_snrs_ratio > params["noise_threshold"]
+        mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
+        peak_labels[mask] = -1
+
+
         from spikeinterface.core.template import Templates

         templates = Templates(
@@ -210,12 +217,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
             templates = templates.to_sparse(sparsity)
             empty_templates = templates.sparsity_mask.sum(axis=1) == 0
+            old_unit_ids = templates.unit_ids.copy()
             templates = remove_empty_templates(templates)

-            mask = np.isin(peak_labels, np.where(empty_templates)[0])
-            peak_labels[mask] = -1
-
-            mask = np.isin(peak_labels, np.where(~valid_templates)[0])
+            mask = np.isin(peak_labels, old_unit_ids[empty_templates])
             peak_labels[mask] = -1

         if verbose:
@@ -224,12 +229,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         cleaning_job_kwargs = job_kwargs.copy()
         cleaning_job_kwargs["progress_bar"] = False
         cleaning_params = params["cleaning_kwargs"].copy()
+        labels = np.unique(peak_labels)
+        labels = labels[labels >= 0]

-        labels, peak_labels = remove_duplicates_via_matching(
-            templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
-        )
+        if params["remove_mixtures"]:
+            labels, peak_labels = remove_duplicates_via_matching(
+                templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
+            )

-        if verbose:
-            print("Kept %d non-duplicated clusters" % len(labels))
+            if verbose:
+                print("Kept %d non-duplicated clusters" % len(labels))

         return labels, peak_labels, svd_model, peaks_svd, sparse_mask
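[Review note on PATCH 30] Besides putting the matching-based cleaning behind the new remove_mixtures flag, the hunks above fix the invalidation logic: peak_labels store unit ids, so invalidated units must be resolved through old_unit_ids, not through positional indices from np.where, which go stale as soon as some units have already been filtered out. A small illustration with made-up ids (not SpikeInterface code):

    import numpy as np

    peak_labels = np.array([0, 2, 5, 5, 7, 2])
    old_unit_ids = np.array([0, 2, 5, 7])         # ids that survived an earlier filter
    valid = np.array([True, False, True, False])  # per-unit quality flags

    # stale: np.where(~valid)[0] gives positions [1, 3], but no peak is labeled 1 or 3
    print(np.isin(peak_labels, np.where(~valid)[0]).sum())  # 0 peaks invalidated, wrong

    # correct: map the flags back to ids, so units 2 and 7 are invalidated
    mask = np.isin(peak_labels, old_unit_ids[~valid])
    peak_labels[mask] = -1
    print(peak_labels)  # [ 0 -1  5  5 -1 -1]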
From 872827336cdf603261196e124018cb1b865b4daf Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 10 Apr 2025 16:55:48 +0200
Subject: [PATCH 31/33] Better logs

---
 .../sortingcomponents/clustering/circus.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index cde61b87f0..24fc37e737 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -223,21 +223,25 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             mask = np.isin(peak_labels, old_unit_ids[empty_templates])
             peak_labels[mask] = -1

-        if verbose:
-            print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
-
-        cleaning_job_kwargs = job_kwargs.copy()
-        cleaning_job_kwargs["progress_bar"] = False
-        cleaning_params = params["cleaning_kwargs"].copy()
         labels = np.unique(peak_labels)
         labels = labels[labels >= 0]

         if params["remove_mixtures"]:
+            if verbose:
+                print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
+
+            cleaning_job_kwargs = job_kwargs.copy()
+            cleaning_job_kwargs["progress_bar"] = False
+            cleaning_params = params["cleaning_kwargs"].copy()
+
             labels, peak_labels = remove_duplicates_via_matching(
                 templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
             )

             if verbose:
                 print("Kept %d non-duplicated clusters" % len(labels))
+        else:
+            if verbose:
+                print("Kept %d raw clusters" % len(labels))

         return labels, peak_labels, svd_model, peaks_svd, sparse_mask

From 3c854bb80d94d794a2951dd741d25ad8d197fb32 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 10 Apr 2025 14:56:00 +0000
Subject: [PATCH 32/33] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/clustering/circus.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index cde61b87f0..70756df62e 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -185,7 +185,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 operator="median",
             )

-
         templates_array = templates.templates_array
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
         peak_snrs = np.abs(templates_array[:, nbefore, :])
@@ -196,7 +195,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
         peak_labels[mask] = -1

-
         from spikeinterface.core.template import Templates

         templates = Templates(
From d3dc4e511a1f75db3995459cba11ea391ebdc994 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 10 Apr 2025 15:07:11 +0000
Subject: [PATCH 33/33] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index d3ff4c236f..9401799992 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -239,7 +239,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             cleaning_job_kwargs = job_kwargs.copy()
             cleaning_job_kwargs["progress_bar"] = False
             cleaning_params = params["cleaning_kwargs"].copy()
-
+
             labels, peak_labels = remove_duplicates_via_matching(
                 templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
             )
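[Closing review note] Taken together, patches 27-33 aim to make a run with a fixed seed (and deterministic=True) bitwise-reproducible. A hedged smoke test under that assumption; run_sorter, generate_ground_truth_recording and to_spike_vector are existing spikeinterface APIs but their signatures may drift, and spykingcircus2 needs its optional dependencies (e.g. hdbscan) installed:

    import numpy as np
    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # small synthetic recording, just enough to exercise the pipeline twice
    recording, _ = generate_ground_truth_recording(durations=[30.0], num_channels=16, seed=0)

    sortings = [
        run_sorter("spykingcircus2", recording, folder=f"sc2_run{i}",
                   remove_existing_folder=True, seed=42, deterministic=True)
        for i in (0, 1)
    ]
    sv0, sv1 = (s.to_spike_vector() for s in sortings)
    same = np.array_equal(sv0["sample_index"], sv1["sample_index"]) and np.array_equal(
        sv0["unit_index"], sv1["unit_index"]
    )
    print("runs identical:", same)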