Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4e65593
Add multi-segment support for amplitudes widget
jakeswann1 Mar 25, 2025
2539a89
Update base raster widget and children to handle multi-segment
jakeswann1 Mar 25, 2025
11a1845
Retain sortingview compatibility
jakeswann1 Mar 25, 2025
16e6272
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2025
c4fece4
Merge branch 'main' into main
jakeswann1 Mar 25, 2025
468e564
Merge branch 'main' into main
jakeswann1 Mar 29, 2025
cbc790c
Improve segment validation to list only. Add unitls function for vali…
jakeswann1 Apr 29, 2025
540db00
minor fixes
jakeswann1 Apr 30, 2025
afdea78
Update durations to use a list
jakeswann1 Apr 30, 2025
03676e0
Merge branch 'main' into main
jakeswann1 Apr 30, 2025
65a5280
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2025
c43fa7c
add test for validate_segment_indices
jakeswann1 May 12, 2025
55e773e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2025
b098557
simplify segment duration computation
jakeswann1 May 12, 2025
9dfd1a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2025
0815572
Merge branch 'main' into main
jakeswann1 May 12, 2025
5d42e1d
Merge branch 'main' into main
chrishalcrow May 16, 2025
5924f2e
Merge branch 'main' into main
alejoe91 Jun 12, 2025
112b240
Merge branch 'main' into main
jakeswann1 Jul 3, 2025
9684e0a
Address Chris' comments
jakeswann1 Jul 3, 2025
6c20215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
032bdb5
address Chris' comments
jakeswann1 Jul 6, 2025
6b15da8
Merge branch 'main' into main
jakeswann1 Jul 6, 2025
d889d41
oops
jakeswann1 Jul 6, 2025
260a951
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 85 additions & 36 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .rasters import BaseRasterWidget
from .base import BaseWidget, to_attr
from .utils import get_some_colors
from .utils import get_some_colors, validate_segment_indices, get_segment_durations

from spikeinterface.core.sortinganalyzer import SortingAnalyzer

Expand All @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget):
unit_colors : dict | None, default: None
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
segment_index : int or None, default: None
The segment index (or None if mono-segment)
segment_indices : list of int or None, default: None
Segment index or indices to plot. If None and there are multiple segments, defaults to 0.
If list, spike trains and amplitudes are concatenated across the specified segments.
max_spikes_per_unit : int or None, default: None
Number of max spikes per unit to display. Use None for all spikes
y_lim : tuple or None, default: None
Expand All @@ -51,70 +52,104 @@ def __init__(
sorting_analyzer: SortingAnalyzer,
unit_ids=None,
unit_colors=None,
segment_index=None,
segment_indices=None,
max_spikes_per_unit=None,
y_lim=None,
scatter_decimate=1,
hide_unit_selector=False,
plot_histograms=False,
bins=None,
plot_legend=True,
segment_index=None,
backend=None,
**backend_kwargs,
):
import warnings

# Handle deprecation of segment_index parameter
if segment_index is not None:
warnings.warn(
"The 'segment_index' parameter is deprecated and will be removed in a future version. "
"Use 'segment_indices' instead.",
DeprecationWarning,
stacklevel=2,
)
if segment_indices is None:
if isinstance(segment_index, int):
segment_indices = [segment_index]
else:
segment_indices = segment_index

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

sorting = sorting_analyzer.sorting
self.check_extensions(sorting_analyzer, "spike_amplitudes")

# Get amplitudes by segment
amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit")

if unit_ids is None:
unit_ids = sorting.unit_ids

if sorting.get_num_segments() > 1:
if segment_index is None:
warn("More than one segment available! Using `segment_index = 0`.")
segment_index = 0
else:
segment_index = 0
# Handle segment_index input
segment_indices = validate_segment_indices(segment_indices, sorting)

# Check for SortingView backend
is_sortingview = backend == "sortingview"

# For SortingView, ensure we're only using a single segment
if is_sortingview and len(segment_indices) > 1:
warn("SortingView backend currently supports only single segment. Using first segment.")
segment_indices = [segment_indices[0]]

amplitudes_segment = amplitudes[segment_index]
total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency
# Create multi-segment data structure (dict of dicts)
spiketrains_by_segment = {}
amplitudes_by_segment = {}

all_spiketrains = {
unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
for unit_id in sorting.unit_ids
}
for idx in segment_indices:
amplitudes_segment = amplitudes[idx]

all_amplitudes = amplitudes_segment
# Initialize for this segment
spiketrains_by_segment[idx] = {}
amplitudes_by_segment[idx] = {}

for unit_id in unit_ids:
# Get spike times for this unit in this segment
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True)
amps = amplitudes_segment[unit_id]

# Store data in dict of dicts format
spiketrains_by_segment[idx][unit_id] = spike_times
amplitudes_by_segment[idx][unit_id] = amps

# Apply max_spikes_per_unit limit if specified
if max_spikes_per_unit is not None:
spiketrains_to_plot = dict()
amplitudes_to_plot = dict()
for unit, st in all_spiketrains.items():
amps = all_amplitudes[unit]
if len(st) > max_spikes_per_unit:
random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False)
spiketrains_to_plot[unit] = st[random_idxs]
amplitudes_to_plot[unit] = amps[random_idxs]
else:
spiketrains_to_plot[unit] = st
amplitudes_to_plot[unit] = amps
else:
spiketrains_to_plot = all_spiketrains
amplitudes_to_plot = all_amplitudes
for idx in segment_indices:
for unit_id in unit_ids:
st = spiketrains_by_segment[idx][unit_id]
amps = amplitudes_by_segment[idx][unit_id]
if len(st) > max_spikes_per_unit:
# Scale down the number of spikes proportionally per segment
# to ensure we have max_spikes_per_unit total after concatenation
segment_count = len(segment_indices)
segment_max = max(1, max_spikes_per_unit // segment_count)

if len(st) > segment_max:
random_idxs = np.random.choice(len(st), size=segment_max, replace=False)
spiketrains_by_segment[idx][unit_id] = st[random_idxs]
amplitudes_by_segment[idx][unit_id] = amps[random_idxs]

if plot_histograms and bins is None:
bins = 100

# Calculate durations for all segments for x-axis limits
durations = get_segment_durations(sorting, segment_indices)

# Build the plot data with the full dict of dicts structure
plot_data = dict(
spike_train_data=spiketrains_to_plot,
y_axis_data=amplitudes_to_plot,
unit_colors=unit_colors,
plot_histograms=plot_histograms,
bins=bins,
total_duration=total_duration,
durations=durations,
unit_ids=unit_ids,
hide_unit_selector=hide_unit_selector,
plot_legend=plot_legend,
Expand All @@ -123,6 +158,17 @@ def __init__(
scatter_decimate=scatter_decimate,
)

# If using SortingView, extract just the first segment's data as flat dicts
if is_sortingview:
first_segment = segment_indices[0]
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
else:
# Otherwise use the full dict of dicts structure with all segments
plot_data["spike_train_data"] = spiketrains_by_segment
plot_data["y_axis_data"] = amplitudes_by_segment
plot_data["segment_indices"] = segment_indices

BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
Expand All @@ -143,7 +189,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
]

self.view = vv.SpikeAmplitudes(
start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector
start_time_sec=0,
end_time_sec=np.sum(dp.durations),
plots=sa_items,
hide_unit_selector=dp.hide_unit_selector,
)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
118 changes: 89 additions & 29 deletions src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ class DriftRasterMapWidget(BaseRasterWidget):
"spike_locations" extension computed.
direction : "x" or "y", default: "y"
The direction to display. "y" is the depth direction.
segment_index : int, default: None
The segment index to display.
recording : RecordingExtractor | None, default: None
The recording extractor object (only used to get "real" times).
segment_index : int, default: 0
The segment index to display.
sampling_frequency : float, default: None
The sampling frequency (needed if recording is None).
segment_indices : list of int or None, default: None
The segment index or indices to display. If None and there's only one segment, it's used.
If None and there are multiple segments, you must specify which to use.
If a list of indices is provided, peaks and locations are concatenated across the segments.
depth_lim : tuple or None, default: None
The min and max depth to display, if None (min and max of the recording).
scatter_decimate : int, default: None
Expand All @@ -149,25 +149,46 @@ def __init__(
direction: str = "y",
recording: BaseRecording | None = None,
sampling_frequency: float | None = None,
segment_index: int | None = None,
segment_indices: list[int] | None = None,
depth_lim: tuple[float, float] | None = None,
color_amplitude: bool = True,
scatter_decimate: int | None = None,
cmap: str = "inferno",
color: str = "Gray",
clim: tuple[float, float] | None = None,
alpha: float = 1,
segment_index: int | list[int] | None = None,
backend: str | None = None,
**backend_kwargs,
):
import warnings
from matplotlib.pyplot import colormaps
from matplotlib.colors import Normalize

# Handle deprecation of segment_index parameter
if segment_index is not None:
warnings.warn(
"The 'segment_index' parameter is deprecated and will be removed in a future version. "
"Use 'segment_indices' instead.",
DeprecationWarning,
stacklevel=2,
)
if segment_indices is None:
if isinstance(segment_index, int):
segment_indices = [segment_index]
else:
segment_indices = segment_index

assert peaks is not None or sorting_analyzer is not None

if peaks is not None:
assert peak_locations is not None
if recording is None:
assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency"
else:
sampling_frequency = recording.sampling_frequency
peak_amplitudes = peaks["amplitude"]

if sorting_analyzer is not None:
if sorting_analyzer.has_recording():
recording = sorting_analyzer.recording
Expand All @@ -190,29 +211,56 @@ def __init__(
else:
peak_amplitudes = None

if segment_index is None:
assert (
len(np.unique(peaks["segment_index"])) == 1
), "segment_index must be specified if there are multiple segments"
segment_index = 0
else:
peak_mask = peaks["segment_index"] == segment_index
peaks = peaks[peak_mask]
peak_locations = peak_locations[peak_mask]
if peak_amplitudes is not None:
peak_amplitudes = peak_amplitudes[peak_mask]

from matplotlib.pyplot import colormaps
unique_segments = np.unique(peaks["segment_index"])

if color_amplitude:
amps = peak_amplitudes
if segment_indices is None:
if len(unique_segments) == 1:
segment_indices = [int(unique_segments[0])]
else:
raise ValueError("segment_indices must be specified if there are multiple segments")

if not isinstance(segment_indices, list):
raise ValueError("segment_indices must be a list of ints")

# Validate all segment indices exist in the data
for idx in segment_indices:
if idx not in unique_segments:
raise ValueError(f"segment_index {idx} not found in peaks data")

# Filter data for the selected segments
# Note: For simplicity, we'll filter all data first, then construct dict of dicts
segment_mask = np.isin(peaks["segment_index"], segment_indices)
filtered_peaks = peaks[segment_mask]
filtered_locations = peak_locations[segment_mask]
if peak_amplitudes is not None:
filtered_amplitudes = peak_amplitudes[segment_mask]

# Create dict of dicts structure for the base class
spike_train_data = {}
y_axis_data = {}

# Process each segment separately
for seg_idx in segment_indices:
segment_mask = filtered_peaks["segment_index"] == seg_idx
segment_peaks = filtered_peaks[segment_mask]
segment_locations = filtered_locations[segment_mask]

# Convert peak times to seconds
spike_times = segment_peaks["sample_index"] / sampling_frequency

# Store in dict of dicts format (using 0 as the "unit" id)
spike_train_data[seg_idx] = {0: spike_times}
y_axis_data[seg_idx] = {0: segment_locations[direction]}

if color_amplitude and peak_amplitudes is not None:
amps = filtered_amplitudes
amps_abs = np.abs(amps)
q_95 = np.quantile(amps_abs, 0.95)
cmap = colormaps[cmap]
cmap_obj = colormaps[cmap]
if clim is None:
amps = amps_abs
amps /= q_95
c = cmap(amps)
c = cmap_obj(amps)
else:
from matplotlib.colors import Normalize

Expand All @@ -226,18 +274,30 @@ def __init__(
else:
color_kwargs = dict(color=color, c=None, alpha=alpha)

# convert data into format that `BaseRasterWidget` can take it in
spike_train_data = {0: peaks["sample_index"] / sampling_frequency}
y_axis_data = {0: peak_locations[direction]}
# Calculate segment durations for x-axis limits
if recording is not None:
durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices]
else:
# Find boundaries between segments using searchsorted
segment_boundaries = [
np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
]

# Calculate durations from max sample in each segment
durations = [
(filtered_peaks["sample_index"][end - 1] + 1) / sampling_frequency for (_, end) in segment_boundaries
]

plot_data = dict(
spike_train_data=spike_train_data,
y_axis_data=y_axis_data,
segment_indices=segment_indices,
y_lim=depth_lim,
color_kwargs=color_kwargs,
scatter_decimate=scatter_decimate,
title="Peak depth",
y_label="Depth [um]",
durations=durations,
)

BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
Expand Down Expand Up @@ -370,10 +430,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
dp.recording,
)

commpon_drift_map_kwargs = dict(
common_drift_map_kwargs = dict(
direction=dp.motion.direction,
recording=dp.recording,
segment_index=dp.segment_index,
segment_indices=[dp.segment_index],
depth_lim=dp.depth_lim,
scatter_decimate=dp.scatter_decimate,
color_amplitude=dp.color_amplitude,
Expand All @@ -390,15 +450,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
dp.peak_locations,
ax=ax0,
immediate_plot=True,
**commpon_drift_map_kwargs,
**common_drift_map_kwargs,
)

_ = DriftRasterMapWidget(
dp.peaks,
corrected_location,
ax=ax1,
immediate_plot=True,
**commpon_drift_map_kwargs,
**common_drift_map_kwargs,
)

ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
Expand Down
Loading
Loading