diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 197fefbab2..6f36baf521 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -5,7 +5,7 @@ from .rasters import BaseRasterWidget from .base import BaseWidget, to_attr -from .utils import get_some_colors +from .utils import get_some_colors, validate_segment_indices, get_segment_durations from spikeinterface.core.sortinganalyzer import SortingAnalyzer @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget): unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. - segment_index : int or None, default: None - The segment index (or None if mono-segment) + segment_indices : list of int or None, default: None + Segment index or indices to plot. If None and there are multiple segments, defaults to 0. + If list, spike trains and amplitudes are concatenated across the specified segments. max_spikes_per_unit : int or None, default: None Number of max spikes per unit to display. Use None for all spikes y_lim : tuple or None, default: None @@ -51,7 +52,7 @@ def __init__( sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, - segment_index=None, + segment_indices=None, max_spikes_per_unit=None, y_lim=None, scatter_decimate=1, @@ -64,57 +65,75 @@ def __init__( ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting = sorting_analyzer.sorting self.check_extensions(sorting_analyzer, "spike_amplitudes") + # Get amplitudes by segment amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: unit_ids = sorting.unit_ids - if sorting.get_num_segments() > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 + # Handle segment_index input + segment_indices = validate_segment_indices(segment_indices, sorting) + + # Check for SortingView backend + is_sortingview = backend == "sortingview" + + # For SortingView, ensure we're only using a single segment + if is_sortingview and len(segment_indices) > 1: + warn("SortingView backend currently supports only single segment. Using first segment.") + segment_indices = [segment_indices[0]] + + # Create multi-segment data structure (dict of dicts) + spiketrains_by_segment = {} + amplitudes_by_segment = {} + + for idx in segment_indices: + amplitudes_segment = amplitudes[idx] - amplitudes_segment = amplitudes[segment_index] - total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency + # Initialize for this segment + spiketrains_by_segment[idx] = {} + amplitudes_by_segment[idx] = {} - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in sorting.unit_ids - } + for unit_id in unit_ids: + # Get spike times for this unit in this segment + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True) + amps = amplitudes_segment[unit_id] - all_amplitudes = amplitudes_segment + # Store data in dict of dicts format + spiketrains_by_segment[idx][unit_id] = spike_times + amplitudes_by_segment[idx][unit_id] = amps + + # Apply max_spikes_per_unit limit if specified if max_spikes_per_unit is not None: - spiketrains_to_plot = dict() - amplitudes_to_plot = dict() - for unit, st in all_spiketrains.items(): - amps = all_amplitudes[unit] - if len(st) > max_spikes_per_unit: - random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False) - spiketrains_to_plot[unit] = st[random_idxs] - amplitudes_to_plot[unit] = amps[random_idxs] - else: - spiketrains_to_plot[unit] = st - amplitudes_to_plot[unit] = amps - else: - spiketrains_to_plot = all_spiketrains - amplitudes_to_plot = all_amplitudes + for idx in segment_indices: + for unit_id in unit_ids: + st = spiketrains_by_segment[idx][unit_id] + amps = amplitudes_by_segment[idx][unit_id] + if len(st) > max_spikes_per_unit: + # Scale down the number of spikes proportionally per segment + # to ensure we have max_spikes_per_unit total after concatenation + segment_count = len(segment_indices) + segment_max = max(1, max_spikes_per_unit // segment_count) + + if len(st) > segment_max: + random_idxs = np.random.choice(len(st), size=segment_max, replace=False) + spiketrains_by_segment[idx][unit_id] = st[random_idxs] + amplitudes_by_segment[idx][unit_id] = amps[random_idxs] if plot_histograms and bins is None: bins = 100 + # Calculate durations for all segments for x-axis limits + durations = get_segment_durations(sorting) + + # Build the plot data with the full dict of dicts structure plot_data = dict( - spike_train_data=spiketrains_to_plot, - y_axis_data=amplitudes_to_plot, unit_colors=unit_colors, plot_histograms=plot_histograms, bins=bins, - total_duration=total_duration, + durations=durations, unit_ids=unit_ids, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, @@ -123,6 +142,17 @@ def __init__( scatter_decimate=scatter_decimate, ) + # If using SortingView, extract just the first segment's data as flat dicts + if is_sortingview: + first_segment = segment_indices[0] + plot_data["spike_train_data"] = spiketrains_by_segment[first_segment] + plot_data["y_axis_data"] = amplitudes_by_segment[first_segment] + else: + # Otherwise use the full dict of dicts structure with all segments + plot_data["spike_train_data"] = spiketrains_by_segment + plot_data["y_axis_data"] = amplitudes_by_segment + plot_data["segment_indices"] = segment_indices + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): @@ -143,7 +173,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ] self.view = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + start_time_sec=0, + end_time_sec=np.sum(dp.durations), + plots=sa_items, + hide_unit_selector=dp.hide_unit_selector, ) self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index f28560fcd6..dbc271f305 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -6,6 +6,7 @@ from spikeinterface.core import BaseRecording, SortingAnalyzer from .rasters import BaseRasterWidget +from .utils import get_segment_durations from spikeinterface.core.motion import Motion @@ -117,14 +118,14 @@ class DriftRasterMapWidget(BaseRasterWidget): "spike_locations" extension computed. direction : "x" or "y", default: "y" The direction to display. "y" is the depth direction. - segment_index : int, default: None - The segment index to display. recording : RecordingExtractor | None, default: None The recording extractor object (only used to get "real" times). - segment_index : int, default: 0 - The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None). + segment_indices : list of int or None, default: None + The segment index or indices to display. If None and there's only one segment, it's used. + If None and there are multiple segments, you must specify which to use. + If a list of indices is provided, peaks and locations are concatenated across the segments. depth_lim : tuple or None, default: None The min and max depth to display, if None (min and max of the recording). scatter_decimate : int, default: None @@ -149,7 +150,7 @@ def __init__( direction: str = "y", recording: BaseRecording | None = None, sampling_frequency: float | None = None, - segment_index: int | None = None, + segment_indices: list[int] | None = None, depth_lim: tuple[float, float] | None = None, color_amplitude: bool = True, scatter_decimate: int | None = None, @@ -160,7 +161,11 @@ def __init__( backend: str | None = None, **backend_kwargs, ): + from matplotlib.pyplot import colormaps + from matplotlib.colors import Normalize + assert peaks is not None or sorting_analyzer is not None + if peaks is not None: assert peak_locations is not None if recording is None: @@ -168,6 +173,7 @@ def __init__( else: sampling_frequency = recording.sampling_frequency peak_amplitudes = peaks["amplitude"] + if sorting_analyzer is not None: if sorting_analyzer.has_recording(): recording = sorting_analyzer.recording @@ -190,29 +196,56 @@ def __init__( else: peak_amplitudes = None - if segment_index is None: - assert ( - len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there are multiple segments" - segment_index = 0 - else: - peak_mask = peaks["segment_index"] == segment_index - peaks = peaks[peak_mask] - peak_locations = peak_locations[peak_mask] - if peak_amplitudes is not None: - peak_amplitudes = peak_amplitudes[peak_mask] - - from matplotlib.pyplot import colormaps + unique_segments = np.unique(peaks["segment_index"]) - if color_amplitude: - amps = peak_amplitudes + if segment_indices is None: + if len(unique_segments) == 1: + segment_indices = [int(unique_segments[0])] + else: + raise ValueError("segment_indices must be specified if there are multiple segments") + + if not isinstance(segment_indices, list): + raise ValueError("segment_indices must be a list of ints") + + # Validate all segment indices exist in the data + for idx in segment_indices: + if idx not in unique_segments: + raise ValueError(f"segment_index {idx} not found in peaks data") + + # Filter data for the selected segments + # Note: For simplicity, we'll filter all data first, then construct dict of dicts + segment_mask = np.isin(peaks["segment_index"], segment_indices) + filtered_peaks = peaks[segment_mask] + filtered_locations = peak_locations[segment_mask] + if peak_amplitudes is not None: + filtered_amplitudes = peak_amplitudes[segment_mask] + + # Create dict of dicts structure for the base class + spike_train_data = {} + y_axis_data = {} + + # Process each segment separately + for seg_idx in segment_indices: + segment_mask = filtered_peaks["segment_index"] == seg_idx + segment_peaks = filtered_peaks[segment_mask] + segment_locations = filtered_locations[segment_mask] + + # Convert peak times to seconds + spike_times = segment_peaks["sample_index"] / sampling_frequency + + # Store in dict of dicts format (using 0 as the "unit" id) + spike_train_data[seg_idx] = {0: spike_times} + y_axis_data[seg_idx] = {0: segment_locations[direction]} + + if color_amplitude and peak_amplitudes is not None: + amps = filtered_amplitudes amps_abs = np.abs(amps) q_95 = np.quantile(amps_abs, 0.95) - cmap = colormaps[cmap] + cmap_obj = colormaps[cmap] if clim is None: amps = amps_abs amps /= q_95 - c = cmap(amps) + c = cmap_obj(amps) else: from matplotlib.colors import Normalize @@ -226,18 +259,31 @@ def __init__( else: color_kwargs = dict(color=color, c=None, alpha=alpha) - # convert data into format that `BaseRasterWidget` can take it in - spike_train_data = {0: peaks["sample_index"] / sampling_frequency} - y_axis_data = {0: peak_locations[direction]} + # Calculate segment durations for x-axis limits + if recording is not None: + durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices] + else: + # Find boundaries between segments using searchsorted + segment_boundaries = [ + np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices + ] + + # Calculate durations from max sample in each segment + durations = [ + (np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency if start < end else 0 + for (start, end) in segment_boundaries + ] plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, + segment_indices=segment_indices, y_lim=depth_lim, color_kwargs=color_kwargs, scatter_decimate=scatter_decimate, title="Peak depth", y_label="Depth [um]", + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) @@ -370,10 +416,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.recording, ) - commpon_drift_map_kwargs = dict( + common_drift_map_kwargs = dict( direction=dp.motion.direction, recording=dp.recording, - segment_index=dp.segment_index, + segment_indices=[dp.segment_index], depth_lim=dp.depth_lim, scatter_decimate=dp.scatter_decimate, color_amplitude=dp.color_amplitude, @@ -390,7 +436,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.peak_locations, ax=ax0, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) _ = DriftRasterMapWidget( @@ -398,7 +444,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): corrected_location, ax=ax1, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 398ae4d728..4219b34c3d 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,7 +4,7 @@ from warnings import warn from .base import BaseWidget, to_attr, default_backend_kwargs -from .utils import get_some_colors +from .utils import get_some_colors, validate_segment_indices, get_segment_durations class BaseRasterWidget(BaseWidget): @@ -15,14 +15,19 @@ class BaseRasterWidget(BaseWidget): Parameters ---------- - spike_train_data : dict - A dict of spike trains, indexed by the unit_id - y_axis_data : dict - A dict of the y-axis data, indexed by the unit_id + spike_train_data : dict of dicts + A dict of dicts where the structure is spike_train_data[segment_index][unit_id]. + y_axis_data : dict of dicts + A dict of dicts where the structure is y_axis_data[segment_index][unit_id]. + For backwards compatibility, a flat dict indexed by unit_id will be internally + converted to a dict of dicts with segment 0. unit_ids : array-like | None, default: None List of unit_ids to plot - total_duration : int | None, default: None - Duration of spike_train_data in seconds. + segment_indices : list | None, default: None + For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. + For single-segment data, this parameter is ignored. + durations : list | None, default: None + List of durations per segment of spike_train_data in seconds. plot_histograms : bool, default: False Plot histogram of y-axis data in another subplot bins : int | None, default: None @@ -48,6 +53,8 @@ class BaseRasterWidget(BaseWidget): Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used. hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed + segment_boundary_kwargs : dict | None, default: None + Additional arguments for the segment boundary lines, passed to `matplotlib.axvline` backend : str | None, default None Which plotting backend to use e.g. 'matplotlib', 'ipywidgets'. If None, uses default from `get_default_plotter_backend`. @@ -58,7 +65,8 @@ def __init__( spike_train_data: dict, y_axis_data: dict, unit_ids: list | None = None, - total_duration: int | None = None, + segment_indices: list | None = None, + durations: list | None = None, plot_histograms: bool = False, bins: int | None = None, scatter_decimate: int = 1, @@ -71,13 +79,72 @@ def __init__( y_label: str | None = None, y_ticks: bool = False, hide_unit_selector: bool = True, + segment_boundary_kwargs: dict | None = None, backend: str | None = None, **backend_kwargs, ): + # Set default segment boundary kwargs if not provided + if segment_boundary_kwargs is None: + segment_boundary_kwargs = {"color": "gray", "linestyle": "--", "alpha": 0.7} + + # Process the data + available_segments = list(spike_train_data.keys()) + available_segments.sort() # Ensure consistent ordering + + # Determine which segments to use + if segment_indices is None: + # Use all segments by default + segments_to_use = available_segments + elif isinstance(segment_indices, list): + # Multiple segments specified + for idx in segment_indices: + if idx not in available_segments: + raise ValueError(f"segment_index {idx} not found in avialable segments {available_segments}") + segments_to_use = segment_indices + else: + raise ValueError("segment_index must be `list` or `None`") + + # Get all unit IDs present in any segment if not specified + if unit_ids is None: + all_units = set() + for seg_idx in segments_to_use: + all_units.update(spike_train_data[seg_idx].keys()) + unit_ids = list(all_units) + + # Calculate cumulative durations for segment boundaries + segment_boundaries = np.cumsum(durations) + cumulative_durations = np.concatenate([[0], segment_boundaries]) + + # Concatenate data across segments with proper time offsets + concatenated_spike_trains = {unit_id: np.array([]) for unit_id in unit_ids} + concatenated_y_axis = {unit_id: np.array([]) for unit_id in unit_ids} + + for offset, spike_train_segment, y_axis_segment in zip( + cumulative_durations, + [spike_train_data[idx] for idx in segments_to_use], + [y_axis_data[idx] for idx in segments_to_use], + ): + # Process each unit in the current segment + for unit_id, spike_times in spike_train_segment.items(): + if unit_id not in unit_ids: + continue + + # Get y-axis values for this unit + y_values = y_axis_segment[unit_id] + + # Apply offset to spike times + adjusted_times = spike_times + offset + + # Add to concatenated data + concatenated_spike_trains[unit_id] = np.concatenate( + [concatenated_spike_trains[unit_id], adjusted_times] + ) + concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values]) + plot_data = dict( - spike_train_data=spike_train_data, - y_axis_data=y_axis_data, + spike_train_data=concatenated_spike_trains, + y_axis_data=concatenated_y_axis, unit_ids=unit_ids, plot_histograms=plot_histograms, y_lim=y_lim, @@ -87,11 +154,13 @@ def __init__( unit_colors=unit_colors, y_label=y_label, title=title, - total_duration=total_duration, + durations=durations, plot_legend=plot_legend, bins=bins, y_ticks=y_ticks, hide_unit_selector=hide_unit_selector, + segment_boundaries=segment_boundaries, + segment_boundary_kwargs=segment_boundary_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -134,6 +203,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): y_axis_data = dp.y_axis_data for unit_id in unit_ids: + if unit_id not in spike_train_data: + continue # Skip this unit if not in data unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] @@ -155,6 +226,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): count, bins = np.histogram(unit_y_data, bins=bins) ax_hist.plot(count, bins[:-1], color=unit_colors[unit_id], alpha=0.8) + # Add segment boundary lines if provided + if getattr(dp, "segment_boundaries", None) is not None: + for boundary in dp.segment_boundaries: + scatter_ax.axvline(boundary, **dp.segment_boundary_kwargs) + if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) @@ -171,7 +247,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): scatter_ax.set_ylim(*dp.y_lim) x_lim = dp.x_lim if x_lim is None: - x_lim = [0, dp.total_duration] + x_lim = [0, np.sum(dp.durations)] scatter_ax.set_xlim(x_lim) if dp.y_ticks: @@ -282,8 +358,9 @@ class RasterWidget(BaseRasterWidget): A sorting object sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object - segment_index : None or int - The segment index. + segment_indices : list of int or None, default: None + The segment index or indices to use. If None and there are multiple segments, defaults to 0. + If a list of indices is provided, spike trains are concatenated across the specified segments. unit_ids : list List of unit ids time_range : list @@ -296,7 +373,7 @@ def __init__( self, sorting=None, sorting_analyzer=None, - segment_index=None, + segment_indices=None, unit_ids=None, time_range=None, color="k", @@ -312,30 +389,42 @@ def __init__( sorting = self.ensure_sorting(sorting) - if sorting.get_num_segments() > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 + segment_indices = validate_segment_indices(segment_indices, sorting) if unit_ids is None: unit_ids = sorting.unit_ids - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in unit_ids - } + # Create dict of dicts structure + spike_train_data = {} + y_axis_data = {} - if time_range is not None: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + # Create a lookup dictionary for unit indices + unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} + + # Get all spikes at once + spikes = sorting.to_spike_vector() + + # Estimate segment duration from max spike time in each segment + durations = get_segment_durations(sorting) + + # Extract spike data for all segments and units at once + spike_train_data = {seg_idx: {} for seg_idx in segment_indices} + y_axis_data = {seg_idx: {} for seg_idx in segment_indices} + + for seg_idx in segment_indices: for unit_id in unit_ids: - unit_st = all_spiketrains[unit_id] - all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])] + # Get spikes for this segment and unit + mask = (spikes["segment_index"] == seg_idx) & (spikes["unit_index"] == unit_id) + spike_times = spikes["sample_index"][mask] / sorting.sampling_frequency - raster_locations = { - unit_id: unit_index * np.ones(len(all_spiketrains[unit_id])) for unit_index, unit_id in enumerate(unit_ids) - } + # Store data + spike_train_data[seg_idx][unit_id] = spike_times + y_axis_data[seg_idx][unit_id] = unit_indices_map[unit_id] * np.ones(len(spike_times)) + + # Apply time range filtering if specified + if time_range is not None: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + # Let BaseRasterWidget handle the filtering unit_indices = list(range(len(unit_ids))) @@ -346,14 +435,16 @@ def __init__( y_ticks = {"ticks": unit_indices, "labels": unit_ids} plot_data = dict( - spike_train_data=all_spiketrains, - y_axis_data=raster_locations, + spike_train_data=spike_train_data, + y_axis_data=y_axis_data, + segment_indices=segment_indices, x_lim=time_range, y_label="Unit id", unit_ids=unit_ids, unit_colors=unit_colors, plot_histograms=None, y_ticks=y_ticks, + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index 2131969c2c..ff4bfd957c 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -1,4 +1,7 @@ -from spikeinterface.widgets.utils import get_some_colors +import pytest + +from spikeinterface import generate_sorting +from spikeinterface.widgets.utils import get_some_colors, validate_segment_indices, get_segment_durations def test_get_some_colors(): @@ -19,5 +22,74 @@ def test_get_some_colors(): # print(colors) +def test_validate_segment_indices(): + # Setup + sorting_single = generate_sorting(durations=[5]) # 1 segment + sorting_multiple = generate_sorting(durations=[5, 10, 15, 20, 25]) # 5 segments + + # Test None with single segment + assert validate_segment_indices(None, sorting_single) == [0] + + # Test None with multiple segments + with pytest.warns(UserWarning): + assert validate_segment_indices(None, sorting_multiple) == [0] + + # Test valid indices + assert validate_segment_indices([0], sorting_single) == [0] + assert validate_segment_indices([0, 1, 4], sorting_multiple) == [0, 1, 4] + + # Test invalid type + with pytest.raises(TypeError): + validate_segment_indices(0, sorting_multiple) + + # Test invalid index type + with pytest.raises(ValueError): + validate_segment_indices([0, "1"], sorting_multiple) + + # Test out of range + with pytest.raises(ValueError): + validate_segment_indices([5], sorting_multiple) + + +def test_get_segment_durations(): + from spikeinterface import generate_sorting + + # Test with a normal multi-segment sorting + durations = [5.0, 10.0, 15.0] + + # Create sorting with high fr to ensure spikes near the end segments + sorting = generate_sorting( + durations=durations, + firing_rates=15.0, + ) + + # Calculate durations + calculated_durations = get_segment_durations(sorting) + + # Check results + assert len(calculated_durations) == len(durations) + # Durations should be approximately correct + for calculated_duration, expected_duration in zip(calculated_durations, durations): + # Duration should be <= expected (spikes can't be after the end) + assert calculated_duration <= expected_duration + # And reasonably close + tolerance = max(0.1 * expected_duration, 0.1) + assert expected_duration - calculated_duration < tolerance + + # Test with single-segment sorting + sorting_single = generate_sorting( + durations=[7.0], + firing_rates=15.0, + ) + + single_duration = get_segment_durations(sorting_single)[0] + + # Test that the calculated duration is reasonable + assert single_duration <= 7.0 + assert 7.0 - single_duration < 0.7 # Within 10% + + if __name__ == "__main__": test_get_some_colors() + test_validate_segment_indices() + test_get_segment_durations() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index a1ac9d4af9..9c5892a937 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -3,6 +3,8 @@ from warnings import warn import numpy as np +from spikeinterface.core import BaseSorting + def get_some_colors( keys, @@ -349,3 +351,78 @@ def make_units_table_from_analyzer( ) return units_table + + +def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting): + """ + Validate a list of segment indices for a sorting object. + + Parameters + ---------- + segment_indices : list of int + The segment index or indices to validate. + sorting : BaseSorting + The sorting object to validate against. + + Returns + ------- + list of int + A list of valid segment indices. + + Raises + ------ + ValueError + If the segment indices are not valid. + """ + num_segments = sorting.get_num_segments() + + # Handle segment_indices input + if segment_indices is None: + if num_segments > 1: + warn("Segment indices not specified. Using first available segment only.") + return [0] + + # Convert segment_index to list for consistent processing + if not isinstance(segment_indices, list): + raise ValueError( + "segment_indices must be a list of ints - available segments are: " + list(range(num_segments)) + ) + + # Validate segment indices + for idx in segment_indices: + if not isinstance(idx, int): + raise ValueError(f"Each segment index must be an integer, got {type(idx)}") + if idx < 0 or idx >= num_segments: + raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") + + return segment_indices + + +def get_segment_durations(sorting: BaseSorting) -> list[float]: + """ + Calculate the duration of each segment in a sorting object. + + Parameters + ---------- + sorting : BaseSorting + The sorting object containing spike data + + Returns + ------- + list[float] + List of segment durations in seconds + """ + spikes = sorting.to_spike_vector() + segment_indices = np.unique(spikes["segment_index"]) + + durations = [] + for seg_idx in segment_indices: + segment_mask = spikes["segment_index"] == seg_idx + if np.any(segment_mask): + max_sample = np.max(spikes["sample_index"][segment_mask]) + duration = max_sample / sorting.sampling_frequency + else: + duration = 0 + durations.append(duration) + + return durations