Skip to content

Add multi-segment capability to BaseRasterWidget and children #3805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4e65593
Add multi-segment support for amplitudes widget
jakeswann1 Mar 25, 2025
2539a89
Update base raster widget and children to handle multi-segment
jakeswann1 Mar 25, 2025
11a1845
Retain sortingview compatibility
jakeswann1 Mar 25, 2025
16e6272
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2025
c4fece4
Merge branch 'main' into main
jakeswann1 Mar 25, 2025
468e564
Merge branch 'main' into main
jakeswann1 Mar 29, 2025
cbc790c
Improve segment validation to list only. Add unitls function for vali…
jakeswann1 Apr 29, 2025
540db00
minor fixes
jakeswann1 Apr 30, 2025
afdea78
Update durations to use a list
jakeswann1 Apr 30, 2025
03676e0
Merge branch 'main' into main
jakeswann1 Apr 30, 2025
65a5280
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2025
c43fa7c
add test for validate_segment_indices
jakeswann1 May 12, 2025
55e773e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2025
b098557
simplify segment duration computation
jakeswann1 May 12, 2025
9dfd1a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2025
0815572
Merge branch 'main' into main
jakeswann1 May 12, 2025
5d42e1d
Merge branch 'main' into main
chrishalcrow May 16, 2025
5924f2e
Merge branch 'main' into main
alejoe91 Jun 12, 2025
112b240
Merge branch 'main' into main
jakeswann1 Jul 3, 2025
9684e0a
Address Chris' comments
jakeswann1 Jul 3, 2025
6c20215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 85 additions & 36 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .rasters import BaseRasterWidget
from .base import BaseWidget, to_attr
from .utils import get_some_colors
from .utils import get_some_colors, validate_segment_indices, get_segment_durations

from spikeinterface.core.sortinganalyzer import SortingAnalyzer

Expand All @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget):
unit_colors : dict | None, default: None
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
segment_index : int or None, default: None
The segment index (or None if mono-segment)
segment_indices : list of int or None, default: None
Segment index or indices to plot. If None and there are multiple segments, defaults to 0.
If list, spike trains and amplitudes are concatenated across the specified segments.
max_spikes_per_unit : int or None, default: None
Number of max spikes per unit to display. Use None for all spikes
y_lim : tuple or None, default: None
Expand All @@ -51,70 +52,104 @@ def __init__(
sorting_analyzer: SortingAnalyzer,
unit_ids=None,
unit_colors=None,
segment_index=None,
segment_indices=None,
max_spikes_per_unit=None,
y_lim=None,
scatter_decimate=1,
hide_unit_selector=False,
plot_histograms=False,
bins=None,
plot_legend=True,
segment_index=None,
backend=None,
**backend_kwargs,
):
import warnings

# Handle deprecation of segment_index parameter
if segment_index is not None:
warnings.warn(
"The 'segment_index' parameter is deprecated and will be removed in a future version. "
"Use 'segment_indices' instead.",
DeprecationWarning,
stacklevel=2,
)
if segment_indices is None:
if isinstance(segment_index, int):
segment_indices = [segment_index]
else:
segment_indices = segment_index

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

sorting = sorting_analyzer.sorting
self.check_extensions(sorting_analyzer, "spike_amplitudes")

# Get amplitudes by segment
amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit")

if unit_ids is None:
unit_ids = sorting.unit_ids

if sorting.get_num_segments() > 1:
if segment_index is None:
warn("More than one segment available! Using `segment_index = 0`.")
segment_index = 0
else:
segment_index = 0
# Handle segment_index input
segment_indices = validate_segment_indices(segment_indices, sorting)

# Check for SortingView backend
is_sortingview = backend == "sortingview"

# For SortingView, ensure we're only using a single segment
if is_sortingview and len(segment_indices) > 1:
warn("SortingView backend currently supports only single segment. Using first segment.")
segment_indices = [segment_indices[0]]

amplitudes_segment = amplitudes[segment_index]
total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency
# Create multi-segment data structure (dict of dicts)
spiketrains_by_segment = {}
amplitudes_by_segment = {}

all_spiketrains = {
unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
for unit_id in sorting.unit_ids
}
for idx in segment_indices:
amplitudes_segment = amplitudes[idx]

all_amplitudes = amplitudes_segment
# Initialize for this segment
spiketrains_by_segment[idx] = {}
amplitudes_by_segment[idx] = {}

for unit_id in unit_ids:
# Get spike times for this unit in this segment
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True)
amps = amplitudes_segment[unit_id]

# Store data in dict of dicts format
spiketrains_by_segment[idx][unit_id] = spike_times
amplitudes_by_segment[idx][unit_id] = amps

# Apply max_spikes_per_unit limit if specified
if max_spikes_per_unit is not None:
spiketrains_to_plot = dict()
amplitudes_to_plot = dict()
for unit, st in all_spiketrains.items():
amps = all_amplitudes[unit]
if len(st) > max_spikes_per_unit:
random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False)
spiketrains_to_plot[unit] = st[random_idxs]
amplitudes_to_plot[unit] = amps[random_idxs]
else:
spiketrains_to_plot[unit] = st
amplitudes_to_plot[unit] = amps
else:
spiketrains_to_plot = all_spiketrains
amplitudes_to_plot = all_amplitudes
for idx in segment_indices:
for unit_id in unit_ids:
st = spiketrains_by_segment[idx][unit_id]
amps = amplitudes_by_segment[idx][unit_id]
if len(st) > max_spikes_per_unit:
# Scale down the number of spikes proportionally per segment
# to ensure we have max_spikes_per_unit total after concatenation
segment_count = len(segment_indices)
segment_max = max(1, max_spikes_per_unit // segment_count)

if len(st) > segment_max:
random_idxs = np.random.choice(len(st), size=segment_max, replace=False)
spiketrains_by_segment[idx][unit_id] = st[random_idxs]
amplitudes_by_segment[idx][unit_id] = amps[random_idxs]

if plot_histograms and bins is None:
bins = 100

# Calculate durations for all segments for x-axis limits
durations = get_segment_durations(sorting)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once function is updated, pass segment_indices here


# Build the plot data with the full dict of dicts structure
plot_data = dict(
spike_train_data=spiketrains_to_plot,
y_axis_data=amplitudes_to_plot,
unit_colors=unit_colors,
plot_histograms=plot_histograms,
bins=bins,
total_duration=total_duration,
durations=durations,
unit_ids=unit_ids,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
Expand All @@ -123,6 +158,17 @@ def __init__(
scatter_decimate=scatter_decimate,
)

# If using SortingView, extract just the first segment's data as flat dicts
if is_sortingview:
first_segment = segment_indices[0]
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
else:
# Otherwise use the full dict of dicts structure with all segments
plot_data["spike_train_data"] = spiketrains_by_segment
plot_data["y_axis_data"] = amplitudes_by_segment
plot_data["segment_indices"] = segment_indices

BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
Expand All @@ -143,7 +189,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
]

self.view = vv.SpikeAmplitudes(
start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector
start_time_sec=0,
end_time_sec=np.sum(dp.durations),
plots=sa_items,
hide_unit_selector=dp.hide_unit_selector,
)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
Loading