Skip to content

Commit 032f725

Browse files
committed
Implement lazy masks in the ToyHF.
1 parent 3791ba1 commit 032f725

File tree

4 files changed

+116
-23
lines changed

4 files changed

+116
-23
lines changed

swiftgalaxy/demo_data.py

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
import numpy as np
88
import unyt as u
99
import subprocess
10-
from typing import Optional, Callable, Any, Union, Sequence
10+
from types import EllipsisType
11+
from typing import Optional, Callable, Any, Union, List, Sequence
12+
from numpy.typing import NDArray
1113
from astropy.cosmology import LambdaCDM
1214
from astropy import units as U
1315
from swiftsimio.objects import cosmo_array
1416
import swiftsimio
1517
from swiftsimio import Writer, SWIFTMask
16-
from swiftgalaxy import MaskCollection, SWIFTGalaxy
18+
from swiftgalaxy import SWIFTGalaxy
19+
from swiftgalaxy.masks import MaskCollection, LazyMask
1720
from swiftgalaxy.halo_catalogues import _HaloCatalogue
1821
from swiftsimio.units import cosmo_units
1922

@@ -314,7 +317,7 @@ class ToyHF(_HaloCatalogue):
314317
def __init__(
315318
self,
316319
snapfile: Union[str, Path] = _toysnap_filename,
317-
index: Union[int, Sequence[int]] = 0,
320+
index: Union[int, List[int]] = 0,
318321
) -> None:
319322
self.snapfile = snapfile
320323
if isinstance(index, Sequence):
@@ -394,23 +397,87 @@ def _generate_bound_only_mask(self, sg: SWIFTGalaxy) -> MaskCollection:
394397
out : :class:`~swiftgalaxy.masks.MaskCollection`
395398
The extra mask.
396399
"""
397-
# the two objects are in different cells, remember we're masking cell particles
398-
if self.index == 0:
399-
extra_mask = MaskCollection(
400-
gas=np.s_[-_n_g_1:],
401-
dark_matter=np.s_[-_n_dm_1:],
402-
stars=np.s_[...],
403-
black_holes=np.s_[...],
404-
)
405-
else: # self.index == 1
406-
extra_mask = MaskCollection(
407-
gas=np.s_[-_n_g_2:],
408-
dark_matter=np.s_[-_n_dm_2:],
409-
stars=np.s_[...],
410-
black_holes=np.s_[...],
411-
)
412400

413-
return extra_mask
401+
def generate_lazy_mask(group_name: str) -> LazyMask:
402+
"""
403+
Generate a function that evaluates a mask for bound particles of a specified
404+
particle type. The generated function should have one parameter, accepting a
405+
boolean, that toggles masking the data loaded during the construction of the
406+
mask on and off.
407+
408+
Parameters
409+
----------
410+
group_name : :obj:`str`
411+
The particle type to evaluate a mask for.
412+
413+
Returns
414+
-------
415+
out : Callable
416+
The generated function that evaluates a mask.
417+
"""
418+
419+
def lazy_mask(
420+
mask_loaded_data: bool = True,
421+
) -> Union[NDArray, slice, EllipsisType]:
422+
"""
423+
"Evaluate" a mask that selects bound particles. In reality we know what
424+
the mask is a priori. We pretend that we need to load the particle ids
425+
so that we can test the behaviour of a dataset loaded while constructing
426+
the mask.
427+
428+
This function must optionally mask the data (``particle_ids``) that it
429+
has loaded.
430+
431+
Parameters
432+
----------
433+
mask_loaded_data : :obj:`bool`, default ``True``
434+
If ``True``, data loaded while constructing the mask is masked
435+
during this function call. Set ``False`` when called from
436+
a :class:`~swiftgalaxy.iterator.SWIFTGalaxies` "server".
437+
438+
Returns
439+
-------
440+
out : :class:`~numpy.ndarray`, :obj:`slice` or :obj:`Ellipsis`
441+
The mask that selects bound particles.
442+
"""
443+
getattr(
444+
getattr(sg, group_name)._particle_dataset,
445+
sg.id_particle_dataset_name,
446+
) # load the ids
447+
assert isinstance(self._mask_index, int) # placate mypy
448+
mask = {
449+
"gas": (np.s_[-_n_g_1:], np.s_[-_n_g_2:])[self._mask_index],
450+
"dark_matter": (np.s_[-_n_dm_1:], np.s_[-_n_dm_2:])[
451+
self._mask_index
452+
],
453+
"stars": np.s_[...],
454+
"black_holes": np.s_[...],
455+
}[group_name]
456+
if mask_loaded_data:
457+
# mask the particle_ids
458+
setattr(
459+
getattr(sg, group_name)._particle_dataset,
460+
f"_{sg.id_particle_dataset_name}",
461+
getattr(
462+
getattr(sg, group_name)._particle_dataset,
463+
f"_{sg.id_particle_dataset_name}",
464+
)[mask],
465+
)
466+
assert (
467+
isinstance(mask, np.ndarray)
468+
or isinstance(mask, slice)
469+
or (mask is Ellipsis)
470+
) # placate mypy
471+
return mask
472+
473+
return LazyMask(mask_function=lazy_mask)
474+
475+
return MaskCollection(
476+
**{
477+
group_name: generate_lazy_mask(group_name)
478+
for group_name in sg.metadata.present_group_names
479+
}
480+
)
414481

415482
@property
416483
def centre(self) -> cosmo_array:

swiftgalaxy/halo_catalogues.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ def lazy_mask(mask_loaded_data: bool = True) -> NDArray:
806806
Evaluate a mask that selects bound particles by comparing the particle
807807
group membership dataset ``group_nr_bound`` to the halo catalogue index.
808808
809-
This function must optionally mask the data (``group_nr_bound``) that is
809+
This function must optionally mask the data (``group_nr_bound``) that it
810810
has loaded.
811811
812812
Parameters
@@ -1178,7 +1178,7 @@ def lazy_mask(mask_loaded_data: bool = True) -> NDArray:
11781178
Evaluate a mask that selects bound particles by comparing the
11791179
``particle_ids`` to the list of bound particle IDs.
11801180
1181-
This function must optionally mask the data (``particle_ids``) that is has
1181+
This function must optionally mask the data (``particle_ids``) that it has
11821182
loaded.
11831183
11841184
Parameters
@@ -1799,7 +1799,7 @@ def lazy_mask(mask_loaded_data: bool = True) -> Union[NDArray, slice]:
17991799
Evaluate a mask that selects bound particles by comparing the lists of
18001800
bound particle indices to the ranges read in the spatial mask.
18011801
1802-
This function must optionally mask the data that is has loaded, but it
1802+
This function must optionally mask the data that it has loaded, but it
18031803
loads nothing.
18041804
18051805
Parameters

swiftgalaxy/reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,12 @@ def _mask_dataset(self, mask: LazyMask) -> None:
844844
particle_metadata = getattr(
845845
self._particle_dataset.metadata, f"{particle_name}_properties"
846846
)
847+
# force old mask evaluation to ensure any data loaded during evaluation
848+
# are in memory and have the old mask applied, if any:
849+
old_mask = getattr(self._swiftgalaxy._extra_mask, particle_name)
850+
if old_mask is not None:
851+
old_mask._evaluate()
852+
# apply the new mask to any data already in memory:
847853
for field_name in particle_metadata.field_names:
848854
if self._is_namedcolumns(field_name):
849855
for named_column in getattr(self, field_name).named_columns:
@@ -861,6 +867,7 @@ def _mask_dataset(self, mask: LazyMask) -> None:
861867
)
862868
elif getattr(self._particle_dataset, f"_{field_name}") is not None:
863869
setattr(self, field_name, getattr(self, field_name)[mask.mask])
870+
# also the derived coordinates, if any:
864871
self._mask_derived_coordinates(mask)
865872
if getattr(self._swiftgalaxy._extra_mask, particle_name) is None:
866873
setattr(self._swiftgalaxy._extra_mask, particle_name, mask)
@@ -875,7 +882,6 @@ def _mask_dataset(self, mask: LazyMask) -> None:
875882
particle_name
876883
]
877884
)
878-
old_mask = getattr(self._swiftgalaxy._extra_mask, particle_name)
879885
# need to convert to an integer mask to combine
880886
# (boolean is insufficient in case of re-ordering masks)
881887
setattr(

tests/test_masking.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def test_reordering_slice_mask(self, sg, particle_name, before_load):
3535
ids_before = getattr(sg, particle_name).particle_ids
3636
if before_load:
3737
getattr(sg, particle_name)._particle_dataset._particle_ids = None
38+
del getattr(sg._extra_mask, particle_name)._mask
39+
getattr(sg._extra_mask, particle_name)._evaluated = False
3840
sg.mask_particles(MaskCollection(**{particle_name: mask}))
3941
ids = getattr(sg, particle_name).particle_ids
4042
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -54,6 +56,8 @@ def test_reordering_int_mask(self, sg, particle_name, before_load):
5456
mask = mask[: mask.size // 2]
5557
if before_load:
5658
getattr(sg, particle_name)._particle_dataset._particle_ids = None
59+
del getattr(sg._extra_mask, particle_name)._mask
60+
getattr(sg._extra_mask, particle_name)._evaluated = False
5761
sg.mask_particles(MaskCollection(**{particle_name: mask}))
5862
ids = getattr(sg, particle_name).particle_ids
5963
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -69,6 +73,8 @@ def test_bool_mask(self, sg, particle_name, before_load):
6973
mask = np.random.rand(ids_before.size) > 0.5
7074
if before_load:
7175
getattr(sg, particle_name)._particle_dataset._particle_ids = None
76+
del getattr(sg._extra_mask, particle_name)._mask
77+
getattr(sg._extra_mask, particle_name)._evaluated = False
7278
sg.mask_particles(MaskCollection(**{particle_name: mask}))
7379
ids = getattr(sg, particle_name).particle_ids
7480
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -82,6 +88,8 @@ def test_namedcolumn_masked(self, sg, before_load):
8288
mask = np.random.rand(neutral_before.size) > 0.5
8389
if before_load:
8490
sg.gas.hydrogen_ionization_fractions._named_column_dataset._neutral = None
91+
del sg._extra_mask.gas._mask
92+
sg._extra_mask.gas._evaluated = False
8593
sg.mask_particles(MaskCollection(**{"gas": mask}))
8694
neutral = sg.gas.hydrogen_ionization_fractions.neutral
8795
assert_allclose_units(
@@ -135,6 +143,8 @@ def test_reordering_slice_mask(self, sg, particle_name, before_load):
135143
ids_before = getattr(sg, particle_name).particle_ids
136144
if before_load:
137145
getattr(sg, particle_name)._particle_dataset._particle_ids = None
146+
del getattr(sg._extra_mask, particle_name)._mask
147+
getattr(sg._extra_mask, particle_name)._evaluated = False
138148
masked_dataset = getattr(sg, particle_name)[mask]
139149
ids = masked_dataset.particle_ids
140150
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -154,6 +164,8 @@ def test_reordering_int_mask(self, sg, particle_name, before_load):
154164
mask = mask[: mask.size // 2]
155165
if before_load:
156166
getattr(sg, particle_name)._particle_dataset._particle_ids = None
167+
del getattr(sg._extra_mask, particle_name)._mask
168+
getattr(sg._extra_mask, particle_name)._evaluated = False
157169
masked_dataset = getattr(sg, particle_name)[mask]
158170
ids = masked_dataset.particle_ids
159171
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -169,6 +181,8 @@ def test_bool_mask(self, sg, particle_name, before_load):
169181
mask = np.random.rand(ids_before.size) > 0.5
170182
if before_load:
171183
getattr(sg, particle_name)._particle_dataset._particle_ids = None
184+
del getattr(sg._extra_mask, particle_name)._mask
185+
getattr(sg._extra_mask, particle_name)._evaluated = False
172186
masked_dataset = getattr(sg, particle_name)[mask]
173187
ids = masked_dataset.particle_ids
174188
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
@@ -184,6 +198,8 @@ def test_reordering_slice_mask(self, sg, before_load):
184198
fractions_before = sg.gas.hydrogen_ionization_fractions.neutral
185199
if before_load:
186200
sg.gas.hydrogen_ionization_fractions._neutral = None
201+
del sg._extra_mask.gas._mask
202+
sg._extra_mask.gas._evaluated = False
187203
masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask]
188204
fractions = masked_namedcolumnsdataset.neutral
189205
assert_allclose_units(
@@ -204,6 +220,8 @@ def test_reordering_int_mask(self, sg, before_load):
204220
mask = mask[: mask.size // 2]
205221
if before_load:
206222
sg.gas.hydrogen_ionization_fractions._neutral = None
223+
del sg._extra_mask.gas._mask
224+
sg._extra_mask.gas._mask = False
207225
masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask]
208226
fractions = masked_namedcolumnsdataset.neutral
209227
assert_allclose_units(
@@ -220,6 +238,8 @@ def test_bool_mask(self, sg, before_load):
220238
mask = np.random.rand(fractions_before.size) > 0.5
221239
if before_load:
222240
sg.gas.hydrogen_ionization_fractions._neutral = None
241+
del sg._extra_mask.gas._mask
242+
sg._extra_mask.gas._evaluated = False
223243
masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask]
224244
fractions = masked_namedcolumnsdataset.neutral
225245
assert_allclose_units(

0 commit comments

Comments
 (0)