Skip to content

Add multi-segment capability to BaseRasterWidget and children #3805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .rasters import BaseRasterWidget
from .base import BaseWidget, to_attr
from .utils import get_some_colors
from .utils import get_some_colors, validate_segment_indices, get_segment_durations

from spikeinterface.core.sortinganalyzer import SortingAnalyzer

Expand All @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget):
unit_colors : dict | None, default: None
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
segment_index : int or None, default: None
The segment index (or None if mono-segment)
segment_indices : list of int or None, default: None
Segment index or indices to plot. If None and there are multiple segments, defaults to 0.
If list, spike trains and amplitudes are concatenated across the specified segments.
max_spikes_per_unit : int or None, default: None
Number of max spikes per unit to display. Use None for all spikes
y_lim : tuple or None, default: None
Expand All @@ -51,7 +52,7 @@ def __init__(
sorting_analyzer: SortingAnalyzer,
unit_ids=None,
unit_colors=None,
segment_index=None,
segment_indices=None,
max_spikes_per_unit=None,
y_lim=None,
scatter_decimate=1,
Expand All @@ -64,57 +65,75 @@ def __init__(
):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, realised we need to deprecate segment_index in favour of segment_indices. We don't need to do this for BaseRasterWidet, just for the user facing classes.

To do this, we allow the user to pass segment_index for the next few releases. If they do we convert this to segment_indices = [segment_index] and warn the user that the argument will be deprecated in a future version, in this case 1.104.

A clear example of this can be see in pca_metrics.py:

if qm_params is not None and metric_params is None:

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

sorting = sorting_analyzer.sorting
self.check_extensions(sorting_analyzer, "spike_amplitudes")

# Get amplitudes by segment
amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit")

if unit_ids is None:
unit_ids = sorting.unit_ids

if sorting.get_num_segments() > 1:
if segment_index is None:
warn("More than one segment available! Using `segment_index = 0`.")
segment_index = 0
else:
segment_index = 0
# Handle segment_index input
segment_indices = validate_segment_indices(segment_indices, sorting)

# Check for SortingView backend
is_sortingview = backend == "sortingview"

# For SortingView, ensure we're only using a single segment
if is_sortingview and len(segment_indices) > 1:
warn("SortingView backend currently supports only single segment. Using first segment.")
segment_indices = [segment_indices[0]]

# Create multi-segment data structure (dict of dicts)
spiketrains_by_segment = {}
amplitudes_by_segment = {}

for idx in segment_indices:
amplitudes_segment = amplitudes[idx]

amplitudes_segment = amplitudes[segment_index]
total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency
# Initialize for this segment
spiketrains_by_segment[idx] = {}
amplitudes_by_segment[idx] = {}

all_spiketrains = {
unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
for unit_id in sorting.unit_ids
}
for unit_id in unit_ids:
# Get spike times for this unit in this segment
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True)
amps = amplitudes_segment[unit_id]

all_amplitudes = amplitudes_segment
# Store data in dict of dicts format
spiketrains_by_segment[idx][unit_id] = spike_times
amplitudes_by_segment[idx][unit_id] = amps

# Apply max_spikes_per_unit limit if specified
if max_spikes_per_unit is not None:
spiketrains_to_plot = dict()
amplitudes_to_plot = dict()
for unit, st in all_spiketrains.items():
amps = all_amplitudes[unit]
if len(st) > max_spikes_per_unit:
random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False)
spiketrains_to_plot[unit] = st[random_idxs]
amplitudes_to_plot[unit] = amps[random_idxs]
else:
spiketrains_to_plot[unit] = st
amplitudes_to_plot[unit] = amps
else:
spiketrains_to_plot = all_spiketrains
amplitudes_to_plot = all_amplitudes
for idx in segment_indices:
for unit_id in unit_ids:
st = spiketrains_by_segment[idx][unit_id]
amps = amplitudes_by_segment[idx][unit_id]
if len(st) > max_spikes_per_unit:
# Scale down the number of spikes proportionally per segment
# to ensure we have max_spikes_per_unit total after concatenation
segment_count = len(segment_indices)
segment_max = max(1, max_spikes_per_unit // segment_count)

if len(st) > segment_max:
random_idxs = np.random.choice(len(st), size=segment_max, replace=False)
spiketrains_by_segment[idx][unit_id] = st[random_idxs]
amplitudes_by_segment[idx][unit_id] = amps[random_idxs]

if plot_histograms and bins is None:
bins = 100

# Calculate durations for all segments for x-axis limits
durations = get_segment_durations(sorting)

# Build the plot data with the full dict of dicts structure
plot_data = dict(
spike_train_data=spiketrains_to_plot,
y_axis_data=amplitudes_to_plot,
unit_colors=unit_colors,
plot_histograms=plot_histograms,
bins=bins,
total_duration=total_duration,
durations=durations,
unit_ids=unit_ids,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
Expand All @@ -123,6 +142,17 @@ def __init__(
scatter_decimate=scatter_decimate,
)

# If using SortingView, extract just the first segment's data as flat dicts
if is_sortingview:
first_segment = segment_indices[0]
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
else:
# Otherwise use the full dict of dicts structure with all segments
plot_data["spike_train_data"] = spiketrains_by_segment
plot_data["y_axis_data"] = amplitudes_by_segment
plot_data["segment_indices"] = segment_indices

BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
Expand All @@ -143,7 +173,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
]

self.view = vv.SpikeAmplitudes(
start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector
start_time_sec=0,
end_time_sec=np.sum(dp.durations),
plots=sa_items,
hide_unit_selector=dp.hide_unit_selector,
)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
104 changes: 75 additions & 29 deletions src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from spikeinterface.core import BaseRecording, SortingAnalyzer
from .rasters import BaseRasterWidget
from .utils import get_segment_durations
from spikeinterface.core.motion import Motion


Expand Down Expand Up @@ -117,14 +118,14 @@ class DriftRasterMapWidget(BaseRasterWidget):
"spike_locations" extension computed.
direction : "x" or "y", default: "y"
The direction to display. "y" is the depth direction.
segment_index : int, default: None
The segment index to display.
recording : RecordingExtractor | None, default: None
The recording extractor object (only used to get "real" times).
segment_index : int, default: 0
The segment index to display.
sampling_frequency : float, default: None
The sampling frequency (needed if recording is None).
segment_indices : list of int or None, default: None
The segment index or indices to display. If None and there's only one segment, it's used.
If None and there are multiple segments, you must specify which to use.
If a list of indices is provided, peaks and locations are concatenated across the segments.
depth_lim : tuple or None, default: None
The min and max depth to display, if None (min and max of the recording).
scatter_decimate : int, default: None
Expand All @@ -149,7 +150,7 @@ def __init__(
direction: str = "y",
recording: BaseRecording | None = None,
sampling_frequency: float | None = None,
segment_index: int | None = None,
segment_indices: list[int] | None = None,
depth_lim: tuple[float, float] | None = None,
color_amplitude: bool = True,
scatter_decimate: int | None = None,
Expand All @@ -160,14 +161,19 @@ def __init__(
backend: str | None = None,
**backend_kwargs,
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to deprecate segment_index somewhere near here

from matplotlib.pyplot import colormaps
from matplotlib.colors import Normalize

assert peaks is not None or sorting_analyzer is not None

if peaks is not None:
assert peak_locations is not None
if recording is None:
assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency"
else:
sampling_frequency = recording.sampling_frequency
peak_amplitudes = peaks["amplitude"]

if sorting_analyzer is not None:
if sorting_analyzer.has_recording():
recording = sorting_analyzer.recording
Expand All @@ -190,29 +196,56 @@ def __init__(
else:
peak_amplitudes = None

if segment_index is None:
assert (
len(np.unique(peaks["segment_index"])) == 1
), "segment_index must be specified if there are multiple segments"
segment_index = 0
else:
peak_mask = peaks["segment_index"] == segment_index
peaks = peaks[peak_mask]
peak_locations = peak_locations[peak_mask]
if peak_amplitudes is not None:
peak_amplitudes = peak_amplitudes[peak_mask]

from matplotlib.pyplot import colormaps
unique_segments = np.unique(peaks["segment_index"])

if color_amplitude:
amps = peak_amplitudes
if segment_indices is None:
if len(unique_segments) == 1:
segment_indices = [int(unique_segments[0])]
else:
raise ValueError("segment_indices must be specified if there are multiple segments")

if not isinstance(segment_indices, list):
raise ValueError("segment_indices must be a list of ints")

# Validate all segment indices exist in the data
for idx in segment_indices:
if idx not in unique_segments:
raise ValueError(f"segment_index {idx} not found in peaks data")

# Filter data for the selected segments
# Note: For simplicity, we'll filter all data first, then construct dict of dicts
segment_mask = np.isin(peaks["segment_index"], segment_indices)
filtered_peaks = peaks[segment_mask]
filtered_locations = peak_locations[segment_mask]
if peak_amplitudes is not None:
filtered_amplitudes = peak_amplitudes[segment_mask]

# Create dict of dicts structure for the base class
spike_train_data = {}
y_axis_data = {}

# Process each segment separately
for seg_idx in segment_indices:
segment_mask = filtered_peaks["segment_index"] == seg_idx
segment_peaks = filtered_peaks[segment_mask]
segment_locations = filtered_locations[segment_mask]

# Convert peak times to seconds
spike_times = segment_peaks["sample_index"] / sampling_frequency

# Store in dict of dicts format (using 0 as the "unit" id)
spike_train_data[seg_idx] = {0: spike_times}
y_axis_data[seg_idx] = {0: segment_locations[direction]}

if color_amplitude and peak_amplitudes is not None:
amps = filtered_amplitudes
amps_abs = np.abs(amps)
q_95 = np.quantile(amps_abs, 0.95)
cmap = colormaps[cmap]
cmap_obj = colormaps[cmap]
if clim is None:
amps = amps_abs
amps /= q_95
c = cmap(amps)
c = cmap_obj(amps)
else:
from matplotlib.colors import Normalize

Expand All @@ -226,18 +259,31 @@ def __init__(
else:
color_kwargs = dict(color=color, c=None, alpha=alpha)

# convert data into format that `BaseRasterWidget` can take it in
spike_train_data = {0: peaks["sample_index"] / sampling_frequency}
y_axis_data = {0: peak_locations[direction]}
# Calculate segment durations for x-axis limits
if recording is not None:
durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices]
else:
# Find boundaries between segments using searchsorted
segment_boundaries = [
np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
]

# Calculate durations from max sample in each segment
durations = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spike trains are always sorted, so we don't need to find the max sample index, instead can use something like:

 durations = [(filtered_peaks["sample_index"][end-1]+1) / sampling_frequency for (_, end) in segment_boundaries ]

(np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency if start < end else 0
for (start, end) in segment_boundaries
]

plot_data = dict(
spike_train_data=spike_train_data,
y_axis_data=y_axis_data,
segment_indices=segment_indices,
y_lim=depth_lim,
color_kwargs=color_kwargs,
scatter_decimate=scatter_decimate,
title="Peak depth",
y_label="Depth [um]",
durations=durations,
)

BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
Expand Down Expand Up @@ -370,10 +416,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
dp.recording,
)

commpon_drift_map_kwargs = dict(
common_drift_map_kwargs = dict(
direction=dp.motion.direction,
recording=dp.recording,
segment_index=dp.segment_index,
segment_indices=[dp.segment_index],
depth_lim=dp.depth_lim,
scatter_decimate=dp.scatter_decimate,
color_amplitude=dp.color_amplitude,
Expand All @@ -390,15 +436,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
dp.peak_locations,
ax=ax0,
immediate_plot=True,
**commpon_drift_map_kwargs,
**common_drift_map_kwargs,
)

_ = DriftRasterMapWidget(
dp.peaks,
corrected_location,
ax=ax1,
immediate_plot=True,
**commpon_drift_map_kwargs,
**common_drift_map_kwargs,
)

ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
Expand Down
Loading