diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index f038fe8b93..93c76ab4e1 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -11,7 +11,6 @@ from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
 from spikeinterface.sortingcomponents.tools import (
     cache_preprocessing,
-    get_prototype_and_waveforms_from_recording,
     get_shuffled_recording_slices,
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
@@ -42,6 +41,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "multi_units_only": False,
         "job_kwargs": {"n_jobs": 0.75},
         "seed": 42,
+        "deterministic": False,
         "debug": False,
     }
@@ -69,6 +69,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
         "seed": "An int to control how chunks are shuffled while detecting peaks",
+        "deterministic": "A boolean to specify whether the sorting should be deterministic. If True, the seed is used to shuffle the chunks",
         "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
     }
@@ -110,6 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         radius_um = params["general"].get("radius_um", 75)
         peak_sign = params["detection"].get("peak_sign", "neg")
         templates_from_svd = params["templates_from_svd"]
+        deterministic = params["deterministic"]
         debug = params["debug"]
         seed = params["seed"]
         apply_preprocessing = params["apply_preprocessing"]
@@ -141,6 +143,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                     print("Motion correction activated (probe geometry compatible)")
                 motion_folder = sorter_output_folder / "motion"
                 params["motion_correction"].update({"folder": motion_folder})
+                noise_levels = get_noise_levels(
+                    recording_f, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+                )
+                params["detect_kwargs"] = {"noise_levels": noise_levels}
                 recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
             else:
                 motion_folder = None
@@ -149,6 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             # TODO add , regularize=True chen ready
             whitening_kwargs = params["whitening"].copy()
             whitening_kwargs["dtype"] = "float32"
+            whitening_kwargs["seed"] = params["seed"]
             whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
             if num_channels == 1:
                 whitening_kwargs["regularize"] = False
@@ -158,7 +165,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             recording_w = whiten(recording_f, **whitening_kwargs)

-        noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)
+        noise_levels = get_noise_levels(
+            recording_w, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+        )

         if recording_w.check_serializability("json"):
             recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
@@ -179,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         n_peaks_per_channel = selection_params.get("n_peaks_per_channel", 5000)
         min_n_peaks = selection_params.get("min_n_peaks", 100000)
         skip_peaks = not params["multi_units_only"] and selection_method == "uniform"
+        skip_peaks = skip_peaks and not params["deterministic"]
         max_n_peaks = n_peaks_per_channel * num_channels
         n_peaks = max(min_n_peaks, max_n_peaks)
         selection_params["n_peaks"] = n_peaks
@@ -189,36 +199,55 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             clustering_folder.mkdir(parents=True, exist_ok=True)
             np.save(clustering_folder / "noise_levels.npy", noise_levels)

+        detection_params["random_chunk_kwargs"] = {"num_chunks_per_segment": 5, "seed": params["seed"]}
+
         if detection_method == "matched_filtering":
-            prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
-                recording_w,
-                n_peaks=10000,
-                ms_before=ms_before,
-                ms_after=ms_after,
-                seed=seed,
-                **detection_params,
-                **job_kwargs,
-            )
+            if not deterministic:
+                from spikeinterface.sortingcomponents.tools import (
+                    get_prototype_and_waveforms_from_recording,
+                )
+
+                prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
+                    recording_w,
+                    n_peaks=10000,
+                    ms_before=ms_before,
+                    ms_after=ms_after,
+                    seed=seed,
+                    **detection_params,
+                    **job_kwargs,
+                )
+            else:
+                from spikeinterface.sortingcomponents.tools import (
+                    get_prototype_and_waveforms_from_peaks,
+                )
+
+                peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)
+                prototype, waveforms, _ = get_prototype_and_waveforms_from_peaks(
+                    recording_w,
+                    peaks,
+                    n_peaks=10000,
+                    ms_before=ms_before,
+                    ms_after=ms_after,
+                    seed=seed,
+                    **detection_params,
+                    **job_kwargs,
+                )
             detection_params["prototype"] = prototype
             detection_params["ms_before"] = ms_before
             if debug:
                 np.save(clustering_folder / "waveforms.npy", waveforms)
                 np.save(clustering_folder / "prototype.npy", prototype)
-            if skip_peaks:
-                detection_params["skip_after_n_peaks"] = n_peaks
-                detection_params["recording_slices"] = get_shuffled_recording_slices(
-                    recording_w, seed=seed, **job_kwargs
-                )
-            detection_method = "matched_filtering"
+
         else:
             waveforms = None
-            if skip_peaks:
-                detection_params["skip_after_n_peaks"] = n_peaks
-                detection_params["recording_slices"] = get_shuffled_recording_slices(
-                    recording_w, seed=seed, **job_kwargs
-                )
             detection_method = "locally_exclusive"

+        if skip_peaks:
+            detection_params["skip_after_n_peaks"] = n_peaks
+
+        detection_params["recording_slices"] = get_shuffled_recording_slices(
+            recording_w, seed=params["seed"], **job_kwargs
+        )
         peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
         order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
         peaks = peaks[order]
@@ -262,6 +291,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         clustering_params["ms_before"] = ms_before
         clustering_params["ms_after"] = ms_after
         clustering_params["verbose"] = verbose
+        clustering_params["seed"] = seed
         clustering_params["templates_from_svd"] = templates_from_svd
         clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
         clustering_params["debug"] = debug
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 0fd58d3011..9401799992 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -56,6 +56,7 @@ class CircusClustering:
         "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
+        "seed": None,
         "noise_threshold": 4,
         "rank": 5,
         "templates_from_svd": False,
@@ -90,7 +91,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         # SVD for time compression
         if params["few_waveforms"] is None:
             few_peaks = select_peaks(
-                peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)
+                peaks,
+                recording=recording,
+                method="uniform",
+                seed=params["seed"],
+                n_peaks=10000,
+                margin=(nbefore, nafter),
             )
             few_wfs = extract_waveform_at_max_channel(
                 recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
@@ -109,7 +115,7 @@
         from sklearn.decomposition import TruncatedSVD

-        svd_model = TruncatedSVD(params["n_svd"])
+        svd_model = TruncatedSVD(params["n_svd"], random_state=params["seed"])
         svd_model.fit(wfs)
         features_folder = tmp_folder / "tsvd_features"
         features_folder.mkdir(exist_ok=True)
@@ -122,6 +128,7 @@
             svd_model=svd_model,
             radius_um=radius_um,
             folder=features_folder,
+            seed=params["seed"],
             **job_kwargs,
         )

@@ -137,6 +144,7 @@
         split_kwargs = params["split_kwargs"].copy()
         split_kwargs["neighbours_mask"] = neighbours_mask
         split_kwargs["waveforms_sparse_mask"] = sparse_mask
+        split_kwargs["seed"] = params["seed"]
         split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
         split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]
diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
index a0034e7741..39c6ad9c3f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
+++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
@@ -80,6 +80,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             radius_um=radius_um,
             motion_aware=motion_aware,
             motion=None,
+            seed=params["seed"],
             **params["extract_peaks_svd_kwargs"],
             # **job_kwargs
         )
diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
index 409181bcf3..7cfebd3526 100644
--- a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
@@ -20,6 +20,7 @@ def create_graph_from_peak_features(
     sparse_mode="knn",
     apply_local_svd=False,
     n_components=10,
+    seed=None,
     normed_distances=False,
     n_neighbors=20,
     ensure_symetric=False,
@@ -141,12 +142,12 @@ def create_graph_from_peak_features(
         if apply_local_svd:
             if isinstance(n_components, int):
                 n_components = min(n_components, flatten_feat.shape[1])
-                tsvd = TruncatedSVD(n_components)
+                tsvd = TruncatedSVD(n_components, random_state=seed)
                 flatten_feat = tsvd.fit_transform(flatten_feat)
             elif isinstance(n_components, float):
                 assert 0 < n_components < 1, "n_components should be in ]0, 1["
-                tsvd = TruncatedSVD(flatten_feat.shape[1])
+                tsvd = TruncatedSVD(flatten_feat.shape[1], random_state=seed)
                 flatten_feat = tsvd.fit_transform(flatten_feat)
                 n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_components) + 1
                 flatten_feat = flatten_feat[:, :n_explain]
diff --git a/src/spikeinterface/sortingcomponents/clustering/peak_svd.py b/src/spikeinterface/sortingcomponents/clustering/peak_svd.py
index e58d2621c7..f02855dd1f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/peak_svd.py
+++ b/src/spikeinterface/sortingcomponents/clustering/peak_svd.py
@@ -26,6 +26,7 @@ def extract_peaks_svd(
     motion_aware=False,
     motion=None,
     folder=None,
+    seed=None,
     ensure_peak_same_sign=True,
     **job_kwargs,
 ):
@@ -67,7 +68,7 @@ def extract_peaks_svd(
         if ensure_peak_same_sign:
             wfs *= -np.sign(wfs[:, nbefore])[:, np.newaxis]

-        svd_model = TruncatedSVD(n_components=n_components)
+        svd_model = TruncatedSVD(n_components=n_components, random_state=seed)
         svd_model.fit(wfs)
         need_save_model = True
     else:
diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py
index f8daa51e5e..ca7812a4a2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/split.py
+++ b/src/spikeinterface/sortingcomponents/clustering/split.py
@@ -211,6 +211,7 @@ def split(
     waveforms_sparse_mask=None,
     min_size_split=25,
     n_pca_features=2,
+    seed=None,
     projection_mode="tsvd",
     minimum_overlap_ratio=0.25,
 ):
@@ -258,7 +259,7 @@ def split(
         elif projection_mode == "tsvd":
             from sklearn.decomposition import TruncatedSVD

-            tsvd = TruncatedSVD(nb_dimensions)
+            tsvd = TruncatedSVD(nb_dimensions, random_state=seed)
             final_features = tsvd.fit_transform(flatten_features)
             n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_pca_features) + 1
             final_features = final_features[:, :n_explain]
@@ -272,7 +273,7 @@ def split(
         elif projection_mode == "tsvd":
             from sklearn.decomposition import TruncatedSVD

-            tsvd = TruncatedSVD(n_pca_features)
+            tsvd = TruncatedSVD(n_pca_features, random_state=seed)
             final_features = tsvd.fit_transform(flatten_features)

     else:
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index dcc543c1c3..8e10f45624 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -571,7 +571,6 @@ def check_params(
         # )

         assert peak_sign in ("both", "neg", "pos")
-
         if noise_levels is None:
             noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs)
         abs_thresholds = noise_levels * detect_threshold
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index e4fd3c2539..3c15e56726 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -171,7 +171,6 @@ def get_prototype_and_waveforms_from_recording(
     pipeline_nodes = [node]

     recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)
-
     res = detect_peaks(
         recording,
         pipeline_nodes=pipeline_nodes,
@@ -180,7 +179,6 @@
         **detection_kwargs,
         **job_kwargs,
     )
-
     rng = np.random.RandomState(seed)
     indices = rng.permutation(np.arange(len(res[0])))
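Context for the `random_state` additions above: scikit-learn's `TruncatedSVD` defaults to the stochastic "randomized" solver, so two fits on identical data can return different components unless a seed is supplied, which is why the patch threads a seed into every `TruncatedSVD` construction. A minimal sketch (not part of the patch, plain scikit-learn API only) illustrating the point:

    import numpy as np
    from sklearn.decomposition import TruncatedSVD

    # Toy feature matrix standing in for the flattened waveform features.
    X = np.random.RandomState(0).randn(200, 50)

    # With random_state fixed, repeated fits give matching projections.
    a = TruncatedSVD(n_components=5, random_state=42).fit_transform(X)
    b = TruncatedSVD(n_components=5, random_state=42).fit_transform(X)
    assert np.allclose(a, b)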
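A hypothetical end-to-end usage sketch of the new flag. It assumes a recent SpikeInterface where `run_sorter` accepts a `folder` argument and forwards extra keyword arguments as sorter params; `seed` and `deterministic` are the SpykingCircus2 parameters defined in the patch, and the generated recording is purely illustrative:

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # Short simulated recording, used only to exercise the sorter.
    recording, _ = generate_ground_truth_recording(durations=[30.0], num_channels=16, seed=0)

    # With deterministic=True, the seed controls chunk shuffling, peak
    # selection, and the TruncatedSVD random_state, so two runs with the
    # same seed should produce the same output.
    sorting = run_sorter(
        "spykingcircus2",
        recording,
        folder="sc2_deterministic",
        seed=42,
        deterministic=True,
    )

Note the trade-off visible in the patch: `deterministic=True` clears `skip_peaks`, so the `skip_after_n_peaks` early exit is disabled and peak detection scans the whole recording.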