Skip to content

Sc2 deterministic mode #3854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
d838240
WIP
yger Apr 4, 2025
776e881
WIP
yger Apr 4, 2025
28f1623
WIP
yger Apr 4, 2025
86c1cf7
WIP
yger Apr 7, 2025
01f97c9
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 7, 2025
8e33d08
WIP
yger Apr 8, 2025
e499302
Example of how to use SVD to estimate templates in SC2
yger Apr 8, 2025
5be94da
Patching to get a working example
yger Apr 8, 2025
2be228b
WIP
yger Apr 8, 2025
bd7c7be
WIP
yger Apr 8, 2025
1cd89f3
WIP
yger Apr 8, 2025
c28a7b6
WIP
yger Apr 8, 2025
8e455f9
WIP
yger Apr 8, 2025
6de8310
WIP
yger Apr 8, 2025
3fb5fa6
Cosmetic
yger Apr 8, 2025
5eec5e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2025
74f52bd
Patch
yger Apr 8, 2025
1388666
Merge branch 'returned_svd' of github.com:yger/spikeinterface into re…
yger Apr 8, 2025
ff10442
WIP
yger Apr 9, 2025
d0333dd
WIP
yger Apr 9, 2025
220fe03
Fix
yger Apr 9, 2025
97e8c67
WIP
yger Apr 9, 2025
83dceec
Merge branch 'main' into returned_svd
yger Apr 9, 2025
2cb39eb
WIP
yger Apr 9, 2025
c502f4d
WIP
yger Apr 9, 2025
cc4fa8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
de0847e
WIP
yger Apr 10, 2025
3ab2852
WIP
yger Apr 10, 2025
2302142
WIP
yger Apr 10, 2025
a4807b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
571a562
WIP
yger Apr 10, 2025
136bb58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
c733e88
Determinism can only be achieved by controlling tsvd
yger Apr 10, 2025
93af5b6
Merge branch 'sc2_deterministic_mode' of github.com:yger/spikeinterfa…
yger Apr 10, 2025
2f1674b
Make delete_mixtures optional
yger Apr 10, 2025
8728273
Better logs
yger Apr 10, 2025
3c854bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
d7a4b1c
Merge branch 'returned_svd' of github.com:yger/spikeinterface into re…
yger Apr 10, 2025
2ef7365
Merge branch 'returned_svd' into sc2_deterministic_mode
yger Apr 10, 2025
d3dc4e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
e08a150
Sync with main
yger Apr 16, 2025
abc1d00
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 22, 2025
8f94e5c
Merge branch 'main' into sc2_deterministic_mode
yger May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
from spikeinterface.sortingcomponents.tools import (
cache_preprocessing,
get_prototype_and_waveforms_from_recording,
get_shuffled_recording_slices,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
Expand Down Expand Up @@ -42,6 +41,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.75},
"seed": 42,
"deterministic": False,
"debug": False,
}

Expand Down Expand Up @@ -69,6 +69,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
"seed": "An int to control how chunks are shuffled while detecting peaks",
"deterministic": "A boolean to specify if the sorting should be deterministic or not. If True, then the seed will be used to shuffle the chunks",
"debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
}

Expand Down Expand Up @@ -110,6 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
radius_um = params["general"].get("radius_um", 75)
peak_sign = params["detection"].get("peak_sign", "neg")
templates_from_svd = params["templates_from_svd"]
deterministic = params["deterministic"]
debug = params["debug"]
seed = params["seed"]
apply_preprocessing = params["apply_preprocessing"]
Expand Down Expand Up @@ -141,6 +143,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["motion_correction"].update({"folder": motion_folder})
noise_levels = get_noise_levels(
recording_f, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
)
params["detect_kwargs"] = {"noise_levels": noise_levels}
recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
else:
motion_folder = None
Expand All @@ -149,6 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# TODO add regularize=True when ready
whitening_kwargs = params["whitening"].copy()
whitening_kwargs["dtype"] = "float32"
whitening_kwargs["seed"] = params["seed"]
whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
if num_channels == 1:
whitening_kwargs["regularize"] = False
Expand All @@ -158,7 +165,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

recording_w = whiten(recording_f, **whitening_kwargs)

noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)
noise_levels = get_noise_levels(
recording_w, return_scaled=False, random_slices_kwargs={"seed": seed}, **job_kwargs
)

if recording_w.check_serializability("json"):
recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
Expand All @@ -179,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
n_peaks_per_channel = selection_params.get("n_peaks_per_channel", 5000)
min_n_peaks = selection_params.get("min_n_peaks", 100000)
skip_peaks = not params["multi_units_only"] and selection_method == "uniform"
skip_peaks = skip_peaks and not params["deterministic"]
max_n_peaks = n_peaks_per_channel * num_channels
n_peaks = max(min_n_peaks, max_n_peaks)
selection_params["n_peaks"] = n_peaks
Expand All @@ -189,36 +199,55 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_folder.mkdir(parents=True, exist_ok=True)
np.save(clustering_folder / "noise_levels.npy", noise_levels)

detection_params["random_chunk_kwargs"] = {"num_chunks_per_segment": 5, "seed": params["seed"]}

if detection_method == "matched_filtering":
prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
recording_w,
n_peaks=10000,
ms_before=ms_before,
ms_after=ms_after,
seed=seed,
**detection_params,
**job_kwargs,
)
if not deterministic:
from spikeinterface.sortingcomponents.tools import (
get_prototype_and_waveforms_from_recording,
)

prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
recording_w,
n_peaks=10000,
ms_before=ms_before,
ms_after=ms_after,
seed=seed,
**detection_params,
**job_kwargs,
)
else:
from spikeinterface.sortingcomponents.tools import (
get_prototype_and_waveforms_from_peaks,
)

peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)
prototype, waveforms, _ = get_prototype_and_waveforms_from_peaks(
recording_w,
peaks,
n_peaks=10000,
ms_before=ms_before,
ms_after=ms_after,
seed=seed,
**detection_params,
**job_kwargs,
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
if debug:
np.save(clustering_folder / "waveforms.npy", waveforms)
np.save(clustering_folder / "prototype.npy", prototype)
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=seed, **job_kwargs
)
detection_method = "matched_filtering"

else:
waveforms = None
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=seed, **job_kwargs
)
detection_method = "locally_exclusive"

if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks

detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
peaks = peaks[order]
Expand Down Expand Up @@ -262,6 +291,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["ms_before"] = ms_before
clustering_params["ms_after"] = ms_after
clustering_params["verbose"] = verbose
clustering_params["seed"] = seed
clustering_params["templates_from_svd"] = templates_from_svd
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params["debug"] = debug
Expand Down
12 changes: 10 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CircusClustering:
"few_waveforms": None,
"ms_before": 0.5,
"ms_after": 0.5,
"seed": None,
"noise_threshold": 4,
"rank": 5,
"templates_from_svd": False,
Expand Down Expand Up @@ -90,7 +91,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
# SVD for time compression
if params["few_waveforms"] is None:
few_peaks = select_peaks(
peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)
peaks,
recording=recording,
method="uniform",
seed=params["seed"],
n_peaks=10000,
margin=(nbefore, nafter),
)
few_wfs = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
Expand All @@ -109,7 +115,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):

from sklearn.decomposition import TruncatedSVD

svd_model = TruncatedSVD(params["n_svd"])
svd_model = TruncatedSVD(params["n_svd"], random_state=params["seed"])
svd_model.fit(wfs)
features_folder = tmp_folder / "tsvd_features"
features_folder.mkdir(exist_ok=True)
Expand All @@ -122,6 +128,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
svd_model=svd_model,
radius_um=radius_um,
folder=features_folder,
seed=params["seed"],
**job_kwargs,
)

Expand All @@ -137,6 +144,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
split_kwargs = params["split_kwargs"].copy()
split_kwargs["neighbours_mask"] = neighbours_mask
split_kwargs["waveforms_sparse_mask"] = sparse_mask
split_kwargs["seed"] = params["seed"]
split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
radius_um=radius_um,
motion_aware=motion_aware,
motion=None,
seed=params["seed"],
**params["extract_peaks_svd_kwargs"],
# **job_kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def create_graph_from_peak_features(
sparse_mode="knn",
apply_local_svd=False,
n_components=10,
seed=None,
normed_distances=False,
n_neighbors=20,
ensure_symetric=False,
Expand Down Expand Up @@ -141,12 +142,12 @@ def create_graph_from_peak_features(
if apply_local_svd:
if isinstance(n_components, int):
n_components = min(n_components, flatten_feat.shape[1])
tsvd = TruncatedSVD(n_components)
tsvd = TruncatedSVD(n_components, random_state=seed)
flatten_feat = tsvd.fit_transform(flatten_feat)

elif isinstance(n_components, float):
assert 0 < n_components < 1, "n_components should be in ]0, 1["
tsvd = TruncatedSVD(flatten_feat.shape[1])
tsvd = TruncatedSVD(flatten_feat.shape[1], random_state=seed)
flatten_feat = tsvd.fit_transform(flatten_feat)
n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_components) + 1
flatten_feat = flatten_feat[:, :n_explain]
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/clustering/peak_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def extract_peaks_svd(
motion_aware=False,
motion=None,
folder=None,
seed=None,
ensure_peak_same_sign=True,
**job_kwargs,
):
Expand Down Expand Up @@ -67,7 +68,7 @@ def extract_peaks_svd(
if ensure_peak_same_sign:
wfs *= -np.sign(wfs[:, nbefore])[:, np.newaxis]

svd_model = TruncatedSVD(n_components=n_components)
svd_model = TruncatedSVD(n_components=n_components, random_state=seed)
svd_model.fit(wfs)
need_save_model = True
else:
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def split(
waveforms_sparse_mask=None,
min_size_split=25,
n_pca_features=2,
seed=None,
projection_mode="tsvd",
minimum_overlap_ratio=0.25,
):
Expand Down Expand Up @@ -258,7 +259,7 @@ def split(
elif projection_mode == "tsvd":
from sklearn.decomposition import TruncatedSVD

tsvd = TruncatedSVD(nb_dimensions)
tsvd = TruncatedSVD(nb_dimensions, random_state=seed)
final_features = tsvd.fit_transform(flatten_features)
n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_pca_features) + 1
final_features = final_features[:, :n_explain]
Expand All @@ -272,7 +273,7 @@ def split(
elif projection_mode == "tsvd":
from sklearn.decomposition import TruncatedSVD

tsvd = TruncatedSVD(n_pca_features)
tsvd = TruncatedSVD(n_pca_features, random_state=seed)

final_features = tsvd.fit_transform(flatten_features)
else:
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,6 @@ def check_params(
# )

assert peak_sign in ("both", "neg", "pos")

if noise_levels is None:
noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs)
abs_thresholds = noise_levels * detect_threshold
Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def get_prototype_and_waveforms_from_recording(
pipeline_nodes = [node]

recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)

res = detect_peaks(
recording,
pipeline_nodes=pipeline_nodes,
Expand All @@ -180,7 +179,6 @@ def get_prototype_and_waveforms_from_recording(
**detection_kwargs,
**job_kwargs,
)

rng = np.random.RandomState(seed)
indices = rng.permutation(np.arange(len(res[0])))

Expand Down