Skip to content

Commit bbe80ec

Browse files
authored
Merge pull request #15 from ruuskas/vol_src
Add support for volumetric source spaces
2 parents d8913e1 + c7f97ba commit bbe80ec

File tree

2 files changed

+77
-34
lines changed

2 files changed

+77
-34
lines changed

conpy/forward.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def select_vertices_in_sensor_range(
8484
"You need to specify an Info object with "
8585
"information about the channels."
8686
)
87+
n_src = len(src)
8788

8889
# Load the head<->MRI transform if necessary
8990
if src[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI:
@@ -149,11 +150,10 @@ def select_vertices_in_sensor_range(
149150
if indices:
150151
return np.flatnonzero(src_sel)
151152
else:
152-
n_lh_verts = src[0]["nuse"]
153-
lh_sel, rh_sel = src_sel[:n_lh_verts], src_sel[n_lh_verts:]
154-
vert_lh = src[0]["vertno"][lh_sel]
155-
vert_rh = src[1]["vertno"][rh_sel]
156-
return [vert_lh, vert_rh]
153+
n_verts = np.cumsum([0] + [s["nuse"] for s in src])
154+
sel = [src_sel[n_verts[i]:n_verts[i+1]] for i in range(n_src)]
155+
verts = [src[i]["vertno"][sel[i]] for i in range(n_src)]
156+
return verts
157157

158158

159159
@verbose
@@ -201,34 +201,38 @@ def restrict_forward_to_vertices(
201201
else:
202202
fwd_out = fwd
203203

204-
lh_vertno, rh_vertno = [src["vertno"] for src in fwd["src"]]
205-
204+
n_src = len(fwd["src"])
205+
vertno = [s["vertno"] for s in fwd["src"]]
206+
n_vertno = [len(hemi_vertno) for hemi_vertno in vertno]
206207
if isinstance(vertno_or_idx[0], int):
207208
logger.info("Interpreting given vertno_or_idx as vertex indices.")
208209
vertno_or_idx = np.asarray(vertno_or_idx)
209210

210211
# Make sure the vertices are in sequential order
211212
fwd_idx = np.sort(vertno_or_idx)
212213

213-
n_vert_lh = len(lh_vertno)
214-
sel_lh_idx = vertno_or_idx[fwd_idx < n_vert_lh]
215-
sel_rh_idx = vertno_or_idx[fwd_idx >= n_vert_lh] - n_vert_lh
216-
sel_lh_vertno = lh_vertno[sel_lh_idx]
217-
sel_rh_vertno = rh_vertno[sel_rh_idx]
214+
vert_idx = np.cumsum([0] + n_vertno)
215+
sel_idx = [
216+
vertno_or_idx[(fwd_idx >= vert_idx[i])
217+
& (fwd_idx < vert_idx[i+1])] - vert_idx[i]
218+
for i in range(n_src)]
219+
sel_vertno = [hemi_vertno[sel] for hemi_vertno, sel in zip(vertno, sel_idx)]
218220
else:
219221
logger.info("Interpreting given vertno_or_idx as vertex numbers.")
220222

221223
# Make sure vertno_or_idx is sorted
222224
vertno_or_idx = [np.sort(v) for v in vertno_or_idx]
225+
sel_vertno = vertno_or_idx
223226

224-
sel_lh_vertno, sel_rh_vertno = vertno_or_idx
225-
src_lh_idx = _find_indices_1d(lh_vertno, sel_lh_vertno, check_vertno)
226-
src_rh_idx = _find_indices_1d(rh_vertno, sel_rh_vertno, check_vertno)
227-
fwd_idx = np.hstack((src_lh_idx, src_rh_idx + len(lh_vertno)))
227+
src_idx = [
228+
_find_indices_1d(hemi_vertno, sel, check_vertno) + sum(n_vertno[:i])
229+
for i, (hemi_vertno, sel) in enumerate(zip(vertno, sel_vertno))
230+
]
231+
fwd_idx = np.hstack(src_idx)
228232

229233
logger.info(
230234
"Restricting forward solution to %d out of %d vertices."
231-
% (len(fwd_idx), len(lh_vertno) + len(rh_vertno))
235+
% (len(fwd_idx), sum(n_vertno))
232236
)
233237

234238
n_orient = fwd["sol"]["ncol"] // fwd["nsource"]
@@ -260,7 +264,7 @@ def _reshape_select(X, dim3, sel):
260264
# Restrict the SourceSpaces inside the forward operator
261265
fwd_out["src"] = restrict_src_to_vertices(
262266
fwd_out["src"],
263-
[sel_lh_vertno, sel_rh_vertno],
267+
sel_vertno,
264268
check_vertno=False,
265269
verbose=False,
266270
)
@@ -307,37 +311,40 @@ def restrict_src_to_vertices(
307311
else:
308312
src_out = src
309313

314+
n_src = len(src)
310315
if vertno_or_idx:
311316
if isinstance(vertno_or_idx[0], int):
312317
logger.info("Interpreting given vertno_or_idx as vertex indices.")
313318
vertno_or_idx = np.asarray(vertno_or_idx)
314-
n_vert_lh = src[0]["nuse"]
315-
ind_lh = vertno_or_idx[vertno_or_idx < n_vert_lh]
316-
ind_rh = vertno_or_idx[vertno_or_idx >= n_vert_lh] - n_vert_lh
317-
vert_no_lh = src[0]["vertno"][ind_lh]
318-
vert_no_rh = src[1]["vertno"][ind_rh]
319+
vert_idx = np.cumsum([0] + [s["nuse"] for s in src])
320+
ind = [
321+
vertno_or_idx[
322+
(vertno_or_idx >= vert_idx[i]) & (vertno_or_idx < vert_idx[i+1])
323+
] - vert_idx[i] for i in range(n_src)
324+
]
325+
vertno = [s["vertno"][inds] for s, inds in zip(src, ind)]
319326
else:
320327
logger.info("Interpreting given vertno_or_idx as vertex numbers.")
321-
vert_no_lh, vert_no_rh = vertno_or_idx
328+
vertno = vertno_or_idx
322329
if check_vertno:
323-
if not (
324-
np.all(np.isin(vert_no_lh, src[0]["vertno"]))
325-
and np.all(np.isin(vert_no_rh, src[1]["vertno"]))
326-
):
327-
raise ValueError(
328-
"One or more vertices were not present in SourceSpaces."
329-
)
330+
for s, verts in zip(src, vertno):
331+
if not np.all(np.isin(verts, s["vertno"])):
332+
raise ValueError(
333+
"One or more vertices were not present in SourceSpaces."
334+
)
330335

331336
else:
332337
# Empty list
333-
vert_no_lh, vert_no_rh = [], []
338+
vertno = [[] for i in range(n_src)]
334339

340+
nuse = sum([s["nuse"] for s in src])
341+
n_vertno = sum([len(verts) for verts in vertno])
335342
logger.info(
336343
"Restricting source space to %d out of %d vertices."
337-
% (len(vert_no_lh) + len(vert_no_rh), src[0]["nuse"] + src[1]["nuse"])
344+
% (n_vertno, nuse)
338345
)
339346

340-
for hemi, verts in zip(src_out, (vert_no_lh, vert_no_rh)):
347+
for hemi, verts in zip(src_out, vertno):
341348
# Ensure vertices are in sequential order
342349
verts = np.sort(verts)
343350

tests/test_forward.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ def src():
3333
)
3434

3535

36+
@pytest.fixture
37+
def vol_src():
38+
"""Load a volume source space."""
39+
path = mne.datasets.sample.data_path()
40+
return mne.read_source_spaces(
41+
op.join(path, "subjects", "sample", "bem", "volume-7mm-src.fif")
42+
)
43+
44+
45+
@pytest.fixture
46+
def vol_fwd():
47+
"""Load a volume forward solution."""
48+
path = mne.datasets.sample.data_path()
49+
return mne.read_forward_solution(
50+
op.join(path, "MEG", "sample", "sample_audvis-meg-vol-7-fwd.fif")
51+
)
52+
53+
3654
def _trans():
3755
path = mne.datasets.sample.data_path()
3856
return op.join(path, "MEG", "sample", "sample_audvis_raw-trans.fif")
@@ -213,6 +231,24 @@ def test_select_vertices_in_sensor_range(fwd, src):
213231
assert_array_equal(verts2[1], np.array([2159]))
214232

215233

234+
def test_select_vertices_in_sensor_range_volume(vol_fwd):
235+
"""Test selecting vertices in sensor range with volumetric source space."""
236+
fwd_r = restrict_forward_to_vertices(vol_fwd, ([[1273, 1312]]))
237+
assert_array_equal(fwd_r["src"][0]["vertno"], np.array([1273, 1312]))
238+
239+
verts = select_vertices_in_sensor_range(fwd_r, 0.08)
240+
assert_array_equal(verts[0], np.array([1273]))
241+
242+
# Test indices
243+
verts = select_vertices_in_sensor_range(fwd_r, 0.08, indices=True)
244+
assert_array_equal(verts, np.array([0]))
245+
246+
# Test restricting
247+
fwd_rs = restrict_forward_to_sensor_range(fwd_r, 0.08)
248+
assert_array_equal(fwd_rs["src"][0]["vertno"], np.array([1273]))
249+
assert len(fwd_rs["src"]) == 1 # No second source space
250+
251+
216252
# FIXME: disabled until we can make a proper test
217253
# def test_radial_coord_system():
218254
# """Test making a radial coordinate system."""

0 commit comments

Comments
 (0)