From fce1988ee482d2353c35492e3291582caa83ff1d Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 6 May 2025 12:38:51 +0100 Subject: [PATCH 1/2] add detect_bad_channels method Literal --- .../preprocessing/detect_bad_channels.py | 9 ++++++--- .../preprocessing/tests/test_detect_bad_channels.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 2175351f0b..eab6db7081 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -7,10 +7,12 @@ from .filter import highpass_filter from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording +detect_bad_channels_method_names = ("std", "mad", "coherence+psd", "neighborhood_r2") + def detect_bad_channels( recording: BaseRecording, - method: str = "coherence+psd", + method: Literal["std", "mad", "coherence+psd", "neighborhood_r2"] = "coherence+psd", std_mad_threshold: float = 5, psd_hf_threshold: float = 0.02, dead_channel_threshold: float = -0.5, @@ -121,8 +123,9 @@ def detect_bad_channels( """ import scipy.stats - method_list = ("std", "mad", "coherence+psd", "neighborhood_r2") - assert method in method_list, f"{method} is not a valid method. Available methods are {method_list}" + assert ( + method in detect_bad_channels_method_names + ), f"{method} is not a valid method. Available methods are {detect_bad_channels_method_names}." # Get random subset of data to estimate from random_chunk_kwargs = dict( diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 4622be1440..aa335b3ea3 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -6,6 +6,9 @@ from spikeinterface.core import generate_recording from spikeinterface.preprocessing import detect_bad_channels, highpass_filter +from spikeinterface.preprocessing.detect_bad_channels import detect_bad_channels_method_names + +from typing import get_type_hints, get_args try: # WARNING : this is not this package https://pypi.org/project/neurodsp/ @@ -18,6 +21,15 @@ HAVE_NPIX = False +def test_literal_hardcoded_values(): + """ + The possible strings allowed by the `method` argument are hardcoded. Here we check they are consistent + with the methods list. + """ + detect_bad_channels_method_literals = get_args(get_type_hints(detect_bad_channels)["method"]) + assert set(detect_bad_channels_method_literals) == set(detect_bad_channels_method_names) + + def test_detect_bad_channels_std_mad(): num_channels = 4 sampling_frequency = 30000.0 From 5fc50fb6fb2a7bf12fa7ba8c43deacaa73fbde5a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 6 May 2025 13:18:40 +0100 Subject: [PATCH 2/2] update test for 3.9 --- .../preprocessing/tests/test_detect_bad_channels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index aa335b3ea3..beaa1ec310 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -8,7 +8,7 @@ from spikeinterface.preprocessing import detect_bad_channels, highpass_filter from spikeinterface.preprocessing.detect_bad_channels import detect_bad_channels_method_names -from typing import get_type_hints, get_args +from typing import get_args, Literal try: # WARNING : this is not this package https://pypi.org/project/neurodsp/ @@ -26,7 +26,7 @@ def test_literal_hardcoded_values(): The possible strings allowed by the `method` argument are hardcoded. Here we check they are consistent with the methods list. """ - detect_bad_channels_method_literals = get_args(get_type_hints(detect_bad_channels)["method"]) + detect_bad_channels_method_literals = get_args(eval(detect_bad_channels.__annotations__["method"])) assert set(detect_bad_channels_method_literals) == set(detect_bad_channels_method_names)