Skip to content

Expose gather mode to tridesclous2 and spykingcircus2 #3719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,14 @@ def run_node_pipeline(
The classical job_kwargs
job_name : str
The name of the pipeline used for the progress_bar
gather_mode : "memory" | "npz"

gather_mode : "memory" | "npy"
How to gather the output of the nodes.
gather_kwargs : dict
Options to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If only one output node then squeeze the tuple
folder : str | Path | None
Used for gather_mode="npz"
Used for gather_mode="npy"
names : list of str
Names of outputs.
verbose : bool, default False
Expand Down
12 changes: 11 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates.to_zarr(folder_path=clustering_folder / "templates")

## We launch an OMP matching pursuit by full convolution of the templates and the raw traces

matching_method = params["matching"].get("method", "circus-omp_svd")
gather_mode = params["matching"].pop("gather_mode", "memory")
matching_params = params["matching"].get("method_kwargs", dict())
matching_params["templates"] = templates

if matching_method is not None:
gather_kwargs = {}
if gather_mode == "npy":
gather_kwargs["folder"] = sorter_output_folder / "matching"
spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **job_kwargs
recording_w,
matching_method,
method_kwargs=matching_params,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
**job_kwargs,
)

if debug:
Expand Down
25 changes: 23 additions & 2 deletions src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
import unittest

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from spikeinterface.sorters import Spykingcircus2Sorter
from spikeinterface.sorters import Spykingcircus2Sorter, run_sorter

from pathlib import Path


class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
    """Common sorter test suite applied to Spykingcircus2Sorter."""

    SorterClass = Spykingcircus2Sorter

    @unittest.skip("performance reason")
    def test_with_numpy_gather(self):
        """Run the sorter with matching gathered in "npy" mode and check the files left on disk."""
        recording = self.recording
        sorter_name = self.SorterClass.sorter_name
        output_folder = self.cache_folder / sorter_name
        sorter_params = self.SorterClass.default_params()

        sorter_params["matching"]["gather_mode"] = "npy"

        # Only the on-disk output is checked below, so the returned sorting is not kept.
        run_sorter(
            sorter_name,
            recording,
            folder=output_folder,
            remove_existing_folder=True,
            delete_output_folder=False,
            verbose=False,
            raise_error=True,
            **sorter_params,
        )

        matching_folder = output_folder / "sorter_output" / "matching"
        assert matching_folder.is_dir()
        assert (matching_folder / "spikes.npy").is_file()


if __name__ == "__main__":
from spikeinterface import set_global_job_kwargs
Expand Down
24 changes: 23 additions & 1 deletion src/spikeinterface/sorters/internal/tests/test_tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,36 @@

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from spikeinterface.sorters import Tridesclous2Sorter
from spikeinterface.sorters import Tridesclous2Sorter, run_sorter

from pathlib import Path


class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
    """Common sorter test suite applied to Tridesclous2Sorter."""

    SorterClass = Tridesclous2Sorter

    @unittest.skip("performance reason")
    def test_with_numpy_gather(self):
        """Run the sorter with matching gathered in "npy" mode and check the files left on disk."""
        recording = self.recording
        sorter_name = self.SorterClass.sorter_name
        output_folder = self.cache_folder / sorter_name
        sorter_params = self.SorterClass.default_params()

        sorter_params["matching"]["gather_mode"] = "npy"

        # Only the on-disk output is checked below, so the returned sorting is not kept.
        run_sorter(
            sorter_name,
            recording,
            folder=output_folder,
            remove_existing_folder=True,
            delete_output_folder=False,
            verbose=False,
            raise_error=True,
            **sorter_params,
        )

        matching_folder = output_folder / "sorter_output" / "matching"
        assert matching_folder.is_dir()
        assert (matching_folder / "spikes.npy").is_file()


if __name__ == "__main__":
test = Tridesclous2SorterCommonTestSuite()
Expand Down
19 changes: 14 additions & 5 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
# "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"},
"job_kwargs": {"n_jobs": -1},
"save_array": True,
}
Expand Down Expand Up @@ -232,13 +232,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates = remove_empty_templates(templates)

## peeler
matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()
matching_method = params["matching"].pop("method")
gather_mode = params["matching"].pop("gather_mode", "memory")
matching_params = params["matching"].get("matching_kwargs", {}).copy()
matching_params["templates"] = templates
if params["matching"]["method"] in ("tdc-peeler",):
if matching_method in ("tdc-peeler",):
matching_params["noise_levels"] = noise_levels
gather_kwargs = {}
if gather_mode == "npy":
gather_kwargs["folder"] = sorter_output_folder / "matching"
spikes = find_spikes_from_templates(
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
recording_for_peeler,
method=matching_method,
method_kwargs=matching_params,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
**job_kwargs,
)

if params["save_array"]:
Expand Down
24 changes: 20 additions & 4 deletions src/spikeinterface/sortingcomponents/matching/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@


def find_spikes_from_templates(
recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
recording,
method="naive",
method_kwargs={},
extra_outputs=False,
gather_mode="memory",
gather_kwargs=None,
verbose=False,
**job_kwargs,
) -> np.ndarray | tuple[np.ndarray, dict]:
"""Find spike from a recording from given templates.

Expand All @@ -25,10 +32,14 @@ def find_spikes_from_templates(
Keyword arguments for the chosen method
extra_outputs : bool
If True then a dict of extra outputs is also returned
**job_kwargs : dict
Parameters for ChunkRecordingExecutor
gather_mode : "memory" | "npy", default: "memory"
If "memory" then the output is gathered in memory, if "npy" then the output is gathered on disk
gather_kwargs : dict, optional
The kwargs for the gather method
verbose : bool, default: False
If True, output is verbose
**job_kwargs : keyword arguments
Parameters for ChunkRecordingExecutor

Returns
-------
Expand All @@ -47,13 +58,18 @@ def find_spikes_from_templates(
node0 = method_class(recording, **method_kwargs)
nodes = [node0]

gather_kwargs = gather_kwargs or {}
names = ["spikes"]

spikes = run_node_pipeline(
recording,
nodes,
job_kwargs,
job_name=f"find spikes ({method})",
gather_mode="memory",
gather_mode=gather_mode,
squeeze_output=True,
names=names,
**gather_kwargs,
)
if extra_outputs:
outputs = node0.get_extra_outputs()
Expand Down