-
Notifications
You must be signed in to change notification settings - Fork 216
Add multi-segment capability to BaseRasterWidget and children #3805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Hey @chrishalcrow - sorry for the delay working on this, I've been away for a while. Think I've addressed your comments, let me know if there's anything else you think could be improved here! |
st = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) | ||
if len(st) > 0: | ||
max_time = max(max_time, np.max(st)) | ||
duration = max_time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit simpler to do this using the full spike train all at once. You get this using
spikes = sorting.to_spike_vector()
then spikes has a segment_index
which can be used to figure out the segment lengths:
segment_slices = [np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) for segment_index in segment_indices]
segment_lengths = [segment_slice[1] - segment_slice[0] for segment_slice in segment_slices
then you gotta convert to time using sorting.sampling_frequency
.
Although it's a bit convoluted, I think using this would simplify the code a bit, and I'd be tempted to use it instead of the recording, because then we remove any mention of recording
s in this function.
(This is a bit messy because we don't keep track of the segment length in the sorting object )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made the change, but do you not think that removing the option to get segment duration from the recording object introduces inaccuracy because final spike time != recording end time? Obviously in most dense recordings this is a pretty minimal difference, but I'm not sure what the policy is more generally that you guys have been using to handle this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Jake! Looking amazing, and feels much simpler than before - thanks for all the updates.
I think all my feeback is about computing durations of segments. Have a read of the comments in rasters.py
first, then the other stuff. Once that's sorted, this should be good to go!
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
@@ -64,57 +65,75 @@ def __init__( | |||
): | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello, realised we need to deprecate segment_index
in favour of segment_indices
. We don't need to do this for BaseRasterWidet
, just for the user facing classes.
To do this, we allow the user to pass segment_index
for the next few releases. If they do we convert this to segment_indices = [segment_index]
and warn the user that the argument will be deprecated in a future version, in this case 1.104.
A clear example of this can be see in pca_metrics.py
:
if qm_params is not None and metric_params is None: |
The segment index. | ||
segment_indices : list of int or None, default: None | ||
The segment index or indices to use. If None and there are multiple segments, defaults to 0. | ||
If a list of indices is provided, spike trains are concatenated across the specified segments. | ||
unit_ids : list | ||
List of unit ids | ||
time_range : list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also need to deprecate segment_index somewhere near here
@@ -160,14 +161,19 @@ def __init__( | |||
backend: str | None = None, | |||
**backend_kwargs, | |||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also need to deprecate segment_index somewhere near here
] | ||
|
||
# Calculate durations from max sample in each segment | ||
durations = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
spike trains are always sorted, so we don't need to find the max sample index, instead can use something like:
durations = [(filtered_peaks["sample_index"][end-1]+1) / sampling_frequency for (_, end) in segment_boundaries ]
segment_indices = np.unique(spikes["segment_index"]) | ||
|
||
durations = [] | ||
for seg_idx in segment_indices: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can replace with the searchsorted
thing, mostly so that we use the same method to estimate segment length from sortings in the whole codebase. It's also twice as fast. So this for loop would replaced with
segment_boundaries = [
np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
]
durations = [(spikes["sample_index"][end-1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries ]
unit_st = all_spiketrains[unit_id] | ||
all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])] | ||
# Get spikes for this segment and unit | ||
mask = (spikes["segment_index"] == seg_idx) & (spikes["unit_index"] == unit_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello, found a gross bug where this doesn't work for sortings with unit_id
s which are strings. The internal functions deal with this, so it's probably safest to replace this and the next line with something like
spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency
Adds the option to pass a list of segment indices to the
AmplitudesWidget
,DriftRasterMapWidget
, andRasterWidget
to plot across multiple segments, by updating how the base widget handles plotting data. Maintains current default behaviour and SortingView capability. resolves #3801