diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 0d245e3783..17a55adf5f 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -112,6 +112,10 @@ PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11}) PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors") +if parse(kilosort.__version__) >= parse("4.0.37"): + PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset") + PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) @@ -254,6 +258,8 @@ def test_initialize_ops_arguments(self): "device", "save_preprocessed_copy", ] + if parse(kilosort.__version__) >= parse("4.0.37"): + expected_arguments += ["gui_mode"] self._check_arguments( initialize_ops, diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 6070961ad4..eb94a34edf 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -438,7 +438,9 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder): chanMap = np.arange(n_chan) xc = positions[:, 0] yc = positions[:, 1] - kcoords = groups.astype(float) + unique_groups = set(groups) + group_map = {group: idx for idx, group in enumerate(unique_groups)} + kcoords = np.array([group_map[group] for group in groups], dtype=int) probe = { "chanMap": chanMap,