
basesorting.py get_spike_trains method #3946

Open · wants to merge 2 commits into main

Conversation

pas-calc (Contributor)

Convenient function to get spike trains of all units as a dict.

An alternative method could be:

spike_vector = sorting.to_spike_vector()
for unit_id in sorting.unit_ids:
    # select this unit's spikes by matching its unit index in the spike vector
    spikes = spike_vector[spike_vector["unit_index"] == sorting.id_to_index(unit_id)]
    # np.unique(spike_vector["unit_index"], return_counts=True)  # number of spikes per unit
    ...
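
For concreteness, here is the same idea wrapped as the kind of helper this PR proposes. This is a minimal sketch, assuming the standard spike-vector fields "sample_index" and "unit_index"; the get_spike_trains name and signature are the proposal here, not an existing API:

def get_spike_trains(sorting):
    """Sketch: build a {unit_id: spike sample indices} dict from the spike vector.
    Note: this ignores multi-segment sortings; see the next comment."""
    spike_vector = sorting.to_spike_vector()
    spike_trains = {}
    for unit_id in sorting.unit_ids:
        # boolean mask selecting this unit's rows in the structured spike vector
        mask = spike_vector["unit_index"] == sorting.id_to_index(unit_id)
        spike_trains[unit_id] = spike_vector["sample_index"][mask]
    return spike_trains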

@chrishalcrow (Member)

Hey @pas-calc, I think this is a great idea! I suspect there will be some debate about how to implement this. Just to note: there's another way to do this, as follows:

spike_trains_samples = si.spike_vector_to_spike_trains(
    sorting.to_spike_vector(concatenated=False), unit_ids=sorting.unit_ids
)

# go from sample index to seconds
spike_trains = {
    segment_id: {
        unit_id: spike_train / sorting.sampling_frequency
        for unit_id, spike_train in spike_trains_segment.items()
    }
    for segment_id, spike_trains_segment in spike_trains_samples.items()
}

spike_vector_to_spike_trains uses a numba implementation, so it should be a lot faster. It might be worth benchmarking on some real data. And you also need to worry about segments, I'm afraid.
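
To make the benchmarking concrete, here is a rough sketch (not a rigorous benchmark): the synthetic sorting from generate_sorting and its parameters are placeholders, and spike_vector_to_spike_trains is assumed importable from spikeinterface.core as in the snippet above.

import time
import spikeinterface.core as si

sorting = si.generate_sorting(num_units=100, durations=[300.0], sampling_frequency=30000.0)

spike_vectors = sorting.to_spike_vector(concatenated=False)  # list of per-segment structured arrays
spike_vector = sorting.to_spike_vector()

t0 = time.perf_counter()
fast = si.spike_vector_to_spike_trains(spike_vectors, unit_ids=sorting.unit_ids)
t1 = time.perf_counter()

# naive alternative: one boolean mask per unit, O(num_units * num_spikes)
naive = {
    unit_id: spike_vector["sample_index"][spike_vector["unit_index"] == sorting.id_to_index(unit_id)]
    for unit_id in sorting.unit_ids
}
t2 = time.perf_counter()

print(f"spike_vector_to_spike_trains: {t1 - t0:.4f} s")
print(f"per-unit masking: {t2 - t1:.4f} s")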

I would also vote to call it to_spike_trains to match the to_spike_vector method.

@h-mayorquin (Collaborator)

The existing get_unit_spike_train method is basically that dict with some extra functionality:

def get_unit_spike_train(
    self,
    unit_id: str | int,
    segment_index: Union[int, None] = None,
    start_frame: Union[int, None] = None,
    end_frame: Union[int, None] = None,
    return_times: bool = False,
    use_cache: bool = True,
):
    segment_index = self._check_segment_index(segment_index)
    if use_cache:
        if segment_index not in self._cached_spike_trains:
            self._cached_spike_trains[segment_index] = {}
        if unit_id not in self._cached_spike_trains[segment_index]:
            segment = self._sorting_segments[segment_index]
            spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype(
                "int64", copy=False
            )
            self._cached_spike_trains[segment_index][unit_id] = spike_frames
        else:
            spike_frames = self._cached_spike_trains[segment_index][unit_id]
        if start_frame is not None:
            start = np.searchsorted(spike_frames, start_frame)
            spike_frames = spike_frames[start:]
        if end_frame is not None:
            end = np.searchsorted(spike_frames, end_frame)
            spike_frames = spike_frames[:end]
    else:
        segment = self._sorting_segments[segment_index]
        spike_frames = segment.get_unit_spike_train(
            unit_id=unit_id, start_frame=start_frame, end_frame=end_frame
        ).astype("int64")
    if return_times:
        if self.has_recording():
            times = self.get_times(segment_index=segment_index)
            return times[spike_frames]
        else:
            segment = self._sorting_segments[segment_index]
            t_start = segment._t_start if segment._t_start is not None else 0
            spike_times = spike_frames / self.get_sampling_frequency()
            return t_start + spike_times
    else:
        return spike_frames
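
In other words, the proposed dict can already be assembled with a comprehension over the existing method; a minimal sketch using only the API shown above, plus get_num_segments for the segment loop:

spike_trains = {
    segment_index: {
        unit_id: sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
        for unit_id in sorting.unit_ids
    }
    for segment_index in range(sorting.get_num_segments())
}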
