Skip to content

Commit 8d8ed2c

Browse files
authored
Merge pull request #189 from neurolib-dev/fix/multimodel_exploration
Quick fixes for BoxSearch xr() when exploring MultiModel
2 parents 19dbe93 + 27a2d37 commit 8d8ed2c

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

neurolib/optimize/exploration/exploration.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ def xr(self, bold=False):
488488
:param bold: if True, will load and return only BOLD output
489489
:type bold: bool
490490
"""
491+
492+
def _sanitize_nc_key(k):
493+
return k.replace("*", "_").replace(".", "_").replace("|", "_")
494+
491495
assert self.results is not None, "Run `loadResults()` first to populate the results"
492496
assert len(self.results) == len(self.dfResults)
493497
# create intrisinsic dims for one run
@@ -510,6 +514,8 @@ def xr(self, bold=False):
510514
expand_coords = {}
511515
# iterate exploration coordinates
512516
for k, v in expl_coords.items():
517+
# sanitize keys in the case of stars etc
518+
k = _sanitize_nc_key(k)
513519
# if single values, just assing
514520
if isinstance(v, (str, float, int)):
515521
expand_coords[k] = [v]
@@ -525,10 +531,14 @@ def xr(self, bold=False):
525531
dataarrays.append(data_temp.expand_dims(expand_coords))
526532

527533
# finally, combine all arrays into one
528-
combined = xr.combine_by_coords(dataarrays)["exploration"]
534+
# sometimes combining xr.DataArrays does not work, see https://github.com/pydata/xarray/issues/3248#issuecomment-531511177
535+
# resolved by casting them explicitely to xr.Dataset
536+
combined = xr.combine_by_coords([da.to_dataset() for da in dataarrays])["exploration"]
529537
if self.parameterSpace.star:
530-
combined.attrs = {k: list(self.model.params[k].keys()) for k in orig_search_coords.keys()}
531-
538+
# if we explored over star params, unwrap them into attributes
539+
combined.attrs = {
540+
_sanitize_nc_key(k): list(self.model.params[k].keys()) for k in orig_search_coords.keys() if "*" in k
541+
}
532542
return combined
533543

534544
def getRun(self, runId, filename=None, trajectoryName=None, pypetShortNames=True):

tests/test_exploration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ def test_multimodel_explore(self):
188188
dataarray = search.xr()
189189
self.assertTrue(isinstance(dataarray, xr.DataArray))
190190
self.assertTrue(isinstance(dataarray.attrs, dict))
191-
self.assertListEqual(list(dataarray.attrs.keys()), list(parameters.dict().keys()))
191+
self.assertListEqual(
192+
list(dataarray.attrs.keys()),
193+
[k.replace("*", "_").replace(".", "_").replace("|", "_") for k in parameters.dict().keys()],
194+
)
192195

193196
end = time.time()
194197
logging.info("\t > Done in {:.2f} s".format(end - start))

0 commit comments

Comments
 (0)