-
Notifications
You must be signed in to change notification settings - Fork 216
Add multi-segment capability to BaseRasterWidget and children #3805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4e65593
2539a89
11a1845
16e6272
c4fece4
468e564
cbc790c
540db00
afdea78
03676e0
65a5280
c43fa7c
55e773e
b098557
9dfd1a4
0815572
5d42e1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
from spikeinterface.core import BaseRecording, SortingAnalyzer | ||
from .rasters import BaseRasterWidget | ||
from .utils import get_segment_durations | ||
from spikeinterface.core.motion import Motion | ||
|
||
|
||
|
@@ -117,14 +118,14 @@ class DriftRasterMapWidget(BaseRasterWidget): | |
"spike_locations" extension computed. | ||
direction : "x" or "y", default: "y" | ||
The direction to display. "y" is the depth direction. | ||
segment_index : int, default: None | ||
The segment index to display. | ||
recording : RecordingExtractor | None, default: None | ||
The recording extractor object (only used to get "real" times). | ||
segment_index : int, default: 0 | ||
The segment index to display. | ||
sampling_frequency : float, default: None | ||
The sampling frequency (needed if recording is None). | ||
segment_indices : list of int or None, default: None | ||
The segment index or indices to display. If None and there's only one segment, it's used. | ||
If None and there are multiple segments, you must specify which to use. | ||
If a list of indices is provided, peaks and locations are concatenated across the segments. | ||
depth_lim : tuple or None, default: None | ||
The min and max depth to display, if None (min and max of the recording). | ||
scatter_decimate : int, default: None | ||
|
@@ -149,7 +150,7 @@ def __init__( | |
direction: str = "y", | ||
recording: BaseRecording | None = None, | ||
sampling_frequency: float | None = None, | ||
segment_index: int | None = None, | ||
segment_indices: list[int] | None = None, | ||
depth_lim: tuple[float, float] | None = None, | ||
color_amplitude: bool = True, | ||
scatter_decimate: int | None = None, | ||
|
@@ -160,14 +161,19 @@ def __init__( | |
backend: str | None = None, | ||
**backend_kwargs, | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also need to deprecate segment_index somewhere near here |
||
from matplotlib.pyplot import colormaps | ||
from matplotlib.colors import Normalize | ||
|
||
assert peaks is not None or sorting_analyzer is not None | ||
|
||
if peaks is not None: | ||
assert peak_locations is not None | ||
if recording is None: | ||
assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency" | ||
else: | ||
sampling_frequency = recording.sampling_frequency | ||
peak_amplitudes = peaks["amplitude"] | ||
|
||
if sorting_analyzer is not None: | ||
if sorting_analyzer.has_recording(): | ||
recording = sorting_analyzer.recording | ||
|
@@ -190,29 +196,56 @@ def __init__( | |
else: | ||
peak_amplitudes = None | ||
|
||
if segment_index is None: | ||
assert ( | ||
len(np.unique(peaks["segment_index"])) == 1 | ||
), "segment_index must be specified if there are multiple segments" | ||
segment_index = 0 | ||
else: | ||
peak_mask = peaks["segment_index"] == segment_index | ||
peaks = peaks[peak_mask] | ||
peak_locations = peak_locations[peak_mask] | ||
if peak_amplitudes is not None: | ||
peak_amplitudes = peak_amplitudes[peak_mask] | ||
|
||
from matplotlib.pyplot import colormaps | ||
unique_segments = np.unique(peaks["segment_index"]) | ||
|
||
if color_amplitude: | ||
amps = peak_amplitudes | ||
if segment_indices is None: | ||
if len(unique_segments) == 1: | ||
segment_indices = [int(unique_segments[0])] | ||
else: | ||
raise ValueError("segment_indices must be specified if there are multiple segments") | ||
|
||
if not isinstance(segment_indices, list): | ||
raise ValueError("segment_indices must be a list of ints") | ||
|
||
# Validate all segment indices exist in the data | ||
for idx in segment_indices: | ||
if idx not in unique_segments: | ||
raise ValueError(f"segment_index {idx} not found in peaks data") | ||
|
||
# Filter data for the selected segments | ||
# Note: For simplicity, we'll filter all data first, then construct dict of dicts | ||
segment_mask = np.isin(peaks["segment_index"], segment_indices) | ||
filtered_peaks = peaks[segment_mask] | ||
filtered_locations = peak_locations[segment_mask] | ||
if peak_amplitudes is not None: | ||
filtered_amplitudes = peak_amplitudes[segment_mask] | ||
|
||
# Create dict of dicts structure for the base class | ||
spike_train_data = {} | ||
y_axis_data = {} | ||
|
||
# Process each segment separately | ||
for seg_idx in segment_indices: | ||
segment_mask = filtered_peaks["segment_index"] == seg_idx | ||
segment_peaks = filtered_peaks[segment_mask] | ||
segment_locations = filtered_locations[segment_mask] | ||
|
||
# Convert peak times to seconds | ||
spike_times = segment_peaks["sample_index"] / sampling_frequency | ||
|
||
# Store in dict of dicts format (using 0 as the "unit" id) | ||
spike_train_data[seg_idx] = {0: spike_times} | ||
y_axis_data[seg_idx] = {0: segment_locations[direction]} | ||
|
||
if color_amplitude and peak_amplitudes is not None: | ||
amps = filtered_amplitudes | ||
amps_abs = np.abs(amps) | ||
q_95 = np.quantile(amps_abs, 0.95) | ||
cmap = colormaps[cmap] | ||
cmap_obj = colormaps[cmap] | ||
if clim is None: | ||
amps = amps_abs | ||
amps /= q_95 | ||
c = cmap(amps) | ||
c = cmap_obj(amps) | ||
else: | ||
from matplotlib.colors import Normalize | ||
|
||
|
@@ -226,18 +259,31 @@ def __init__( | |
else: | ||
color_kwargs = dict(color=color, c=None, alpha=alpha) | ||
|
||
# convert data into format that `BaseRasterWidget` can take it in | ||
spike_train_data = {0: peaks["sample_index"] / sampling_frequency} | ||
y_axis_data = {0: peak_locations[direction]} | ||
# Calculate segment durations for x-axis limits | ||
if recording is not None: | ||
durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices] | ||
else: | ||
# Find boundaries between segments using searchsorted | ||
segment_boundaries = [ | ||
np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices | ||
] | ||
|
||
# Calculate durations from max sample in each segment | ||
durations = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spike trains are always sorted, so we don't need to find the max sample index, instead can use something like:
|
||
(np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency if start < end else 0 | ||
for (start, end) in segment_boundaries | ||
] | ||
|
||
plot_data = dict( | ||
spike_train_data=spike_train_data, | ||
y_axis_data=y_axis_data, | ||
segment_indices=segment_indices, | ||
y_lim=depth_lim, | ||
color_kwargs=color_kwargs, | ||
scatter_decimate=scatter_decimate, | ||
title="Peak depth", | ||
y_label="Depth [um]", | ||
durations=durations, | ||
) | ||
|
||
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) | ||
|
@@ -370,10 +416,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): | |
dp.recording, | ||
) | ||
|
||
commpon_drift_map_kwargs = dict( | ||
common_drift_map_kwargs = dict( | ||
direction=dp.motion.direction, | ||
recording=dp.recording, | ||
segment_index=dp.segment_index, | ||
segment_indices=[dp.segment_index], | ||
depth_lim=dp.depth_lim, | ||
scatter_decimate=dp.scatter_decimate, | ||
color_amplitude=dp.color_amplitude, | ||
|
@@ -390,15 +436,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): | |
dp.peak_locations, | ||
ax=ax0, | ||
immediate_plot=True, | ||
**commpon_drift_map_kwargs, | ||
**common_drift_map_kwargs, | ||
) | ||
|
||
_ = DriftRasterMapWidget( | ||
dp.peaks, | ||
corrected_location, | ||
ax=ax1, | ||
immediate_plot=True, | ||
**commpon_drift_map_kwargs, | ||
**common_drift_map_kwargs, | ||
) | ||
|
||
ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello, realised we need to deprecate
segment_index
in favour ofsegment_indices
. We don't need to do this forBaseRasterWidet
, just for the user facing classes.To do this, we allow the user to pass
segment_index
for the next few releases. If they do we convert this tosegment_indices = [segment_index]
and warn the user that the argument will be deprecated in a future version, in this case 1.104.A clear example of this can be see in
pca_metrics.py
:spikeinterface/src/spikeinterface/qualitymetrics/pca_metrics.py
Line 95 in 7c99530