Skip to content

Add Literals to lots of function arguments #3904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from .filter import highpass_filter
from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording

detect_bad_channels_method_names = ("std", "mad", "coherence+psd", "neighborhood_r2")


def detect_bad_channels(
recording: BaseRecording,
method: str = "coherence+psd",
method: Literal["std", "mad", "coherence+psd", "neighborhood_r2"] = "coherence+psd",
std_mad_threshold: float = 5,
psd_hf_threshold: float = 0.02,
dead_channel_threshold: float = -0.5,
Expand Down Expand Up @@ -121,8 +123,9 @@ def detect_bad_channels(
"""
import scipy.stats

method_list = ("std", "mad", "coherence+psd", "neighborhood_r2")
assert method in method_list, f"{method} is not a valid method. Available methods are {method_list}"
assert (
method in detect_bad_channels_method_names
), f"{method} is not a valid method. Available methods are {detect_bad_channels_method_names}."

# Get random subset of data to estimate from
random_chunk_kwargs = dict(
Expand Down
12 changes: 12 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import detect_bad_channels, highpass_filter
from spikeinterface.preprocessing.detect_bad_channels import detect_bad_channels_method_names

from typing import get_args, Literal

try:
# WARNING : this is not this package https://pypi.org/project/neurodsp/
Expand All @@ -18,6 +21,15 @@
HAVE_NPIX = False


def test_literal_hardcoded_values():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if we think about it what is the value of this test? This is basically saying that you don't trust the type hints of python and want to confirm they exist? We wouldn't add this test to every file right?

What is it that you want to test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to protect against a developer adding a new method and forgetting to add it to the typing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see so this is just a change detector to annoy Sam :P

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So after carefully reading your PR discussion + your edits I think this:

EDIT: The steps 1. 2. 3. might be overkill?? Should I just do 1.?

I think we start with just 1 and then if we decide to add a test later we do that. I'm actually fine with 2 as well because it allows us to easily add in a useful error message (ie you have entered x but only [y,z,a,b] are allowed). But 2 just adds so much work and you're right that if we do 2 then if we want to ensure it is enforced we have to do 3, but then that will cause test failures in refactor potentially. So for me 2 and 3 should really be a 1.0 type move when we are saying the API is 100% stable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I came to the same conclusion while trying to get to sleep last night. Importantly, if we miss a method from a Literal, it doesn't mess any of the code up - it just gives the user bad advice via their typing tool

"""
The possible strings allowed by the `method` argument are hardcoded. Here we check they are consistent
with the methods list.
"""
detect_bad_channels_method_literals = get_args(eval(detect_bad_channels.__annotations__["method"]))
assert set(detect_bad_channels_method_literals) == set(detect_bad_channels_method_names)


def test_detect_bad_channels_std_mad():
num_channels = 4
sampling_frequency = 30000.0
Expand Down