Skip to content

Add multi-segment capability to BaseRasterWidget and children #3805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

jakeswann1
Copy link
Contributor

@jakeswann1 jakeswann1 commented Mar 25, 2025

Adds the option to pass a list of segment indices to the AmplitudesWidget, DriftRasterMapWidget, and RasterWidget to plot across multiple segments, by updating how the base widget handles plotting data. Maintains current default behaviour and SortingView capability. resolves #3801

@jakeswann1
Copy link
Contributor Author

Multi-segment plots would look like this:

image
image
image

@zm711 zm711 added the widgets Related to widgets module label Mar 26, 2025
@jakeswann1
Copy link
Contributor Author

Hey @chrishalcrow - sorry for the delay working on this, I've been away for a while. Think I've addressed your comments, let me know if there's anything else you think could be improved here!

@alejoe91 alejoe91 modified the milestone: 0.102.3 May 2, 2025
st = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True)
if len(st) > 0:
max_time = max(max_time, np.max(st))
duration = max_time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit simpler to do this using the full spike train all at once. You get this using

spikes = sorting.to_spike_vector()

then spikes has a segment_index which can be used to figure out the segment lengths:

segment_slices = [np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) for segment_index in segment_indices]
segment_lengths = [segment_slice[1] - segment_slice[0] for segment_slice in segment_slices

then you gotta convert to time using sorting.sampling_frequency.

Although it's a bit convoluted, I think using this would simplify the code a bit, and I'd be tempted to use it instead of the recording, because then we remove any mention of recordings in this function.

(This is a bit messy because we don't keep track of the segment length in the sorting object )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made the change, but do you not think that removing the option to get segment duration from the recording object introduces inaccuracy because final spike time != recording end time? Obviously in most dense recordings this is a pretty minimal difference, but I'm not sure what the policy is more generally that you guys have been using to handle this

Copy link
Member

@chrishalcrow chrishalcrow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Jake! Looking amazing, and feels much simpler than before - thanks for all the updates.

I think all my feeback is about computing durations of segments. Have a read of the comments in rasters.py first, then the other stuff. Once that's sorted, this should be good to go!

@jakeswann1 jakeswann1 requested a review from chrishalcrow May 12, 2025 16:57
@@ -64,57 +65,75 @@ def __init__(
):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, realised we need to deprecate segment_index in favour of segment_indices. We don't need to do this for BaseRasterWidet, just for the user facing classes.

To do this, we allow the user to pass segment_index for the next few releases. If they do we convert this to segment_indices = [segment_index] and warn the user that the argument will be deprecated in a future version, in this case 1.104.

A clear example of this can be see in pca_metrics.py:

if qm_params is not None and metric_params is None:

The segment index.
segment_indices : list of int or None, default: None
The segment index or indices to use. If None and there are multiple segments, defaults to 0.
If a list of indices is provided, spike trains are concatenated across the specified segments.
unit_ids : list
List of unit ids
time_range : list
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to deprecate segment_index somewhere near here

@@ -160,14 +161,19 @@ def __init__(
backend: str | None = None,
**backend_kwargs,
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to deprecate segment_index somewhere near here

]

# Calculate durations from max sample in each segment
durations = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spike trains are always sorted, so we don't need to find the max sample index, instead can use something like:

 durations = [(filtered_peaks["sample_index"][end-1]+1) / sampling_frequency for (_, end) in segment_boundaries ]

segment_indices = np.unique(spikes["segment_index"])

durations = []
for seg_idx in segment_indices:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can replace with the searchsorted thing, mostly so that we use the same method to estimate segment length from sortings in the whole codebase. It's also twice as fast. So this for loop would replaced with

segment_boundaries = [
    np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
]

durations = [(spikes["sample_index"][end-1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries ]

unit_st = all_spiketrains[unit_id]
all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])]
# Get spikes for this segment and unit
mask = (spikes["segment_index"] == seg_idx) & (spikes["unit_index"] == unit_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, found a gross bug where this doesn't work for sortings with unit_ids which are strings. The internal functions deal with this, so it's probably safest to replace this and the next line with something like

spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
widgets Related to widgets module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Multi-segment support for AmplitudesWidget
4 participants