@@ -413,12 +413,12 @@ def __init__(
413413 var_query : Optional [soma .AxisQuery ] = None ,
414414 obs_column_names : Sequence [str ] = (),
415415 batch_size : int = 1 ,
416- shuffle : bool = False ,
416+ shuffle : bool = True ,
417417 seed : Optional [int ] = None ,
418418 return_sparse_X : bool = False ,
419- soma_chunk_size : Optional [int ] = None ,
419+ soma_chunk_size : Optional [int ] = 64 ,
420420 use_eager_fetch : bool = True ,
421- shuffle_chunk_count : Optional [int ] = None ,
421+ shuffle_chunk_count : Optional [int ] = 2000 ,
422422 ) -> None :
423423 r"""Construct a new ``ExperimentDataPipe``.
424424
@@ -443,18 +443,15 @@ def __init__(
443443 ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will
444444 result in :class:`torch.Tensor`\ s of rank 2 (multiple rows).
445445 shuffle:
446- Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``False`` (no shuffling).
447- For performance reasons, shuffling is performed in two steps: 1) a global shuffling, where contiguous
448- rows are grouped into chunks and the order of the chunks is randomized, and then 2) a local
449- shuffling, where the rows within each chunk are shuffled. Since this class must retrieve data
450- in chunks (to keep memory requirements to a fixed size), global shuffling ensures that a given row in
451- the shuffled result can originate from any position in the non-shuffled result ordering. If shuffling
452- only occurred within each chunk (i.e. "local" shuffling), the first chunk's rows would always be
453- returned first, the second chunk's rows would always be returned second, and so on. The chunk size is
454- determined by the ``soma_chunk_size`` parameter. Note that rows within a chunk will maintain
455- proximity, even after shuffling, so some experimentation may be required to ensure the shuffling is
456- sufficient for the model training process. To this end, the ``soma_chunk_size`` can be treated as a
457- hyperparameter that can be tuned.
446+ Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
447+ For performance reasons, shuffling is not performed globally across all rows, but rather in chunks.
448+ More specifically, we select ``shuffle_chunk_count`` non-contiguous chunks across all the observations
449+ in the query, concatenate the chunks and shuffle the associated observations.
450+ The randomness of the shuffling is therefore determined by the
451+ (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have been determined
452+ to yield a good trade-off between randomness and performance. Further tuning may be required for
453+ different types of models. Note that memory usage is correlated to the product
454+ ``soma_chunk_size * shuffle_chunk_count``.
458455 seed:
459456 The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using
460457 :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
@@ -468,10 +465,8 @@ def __init__(
468465 The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
469466 this class's behavior: 1) The maximum memory utilization, with larger values providing
470467 better read performance, but also requiring more memory; 2) The granularity of the global shuffling
471- step (see ``shuffle`` parameter for details). If not specified, the value is set to utilize ~1 GiB of
472- RAM per SOMA chunk read, based upon the number of ``var`` columns (cells/features) being requested
473- and assuming X data sparsity of 95%; the number of rows per chunk will depend on the number of
474- ``var`` columns being read.
468+ step (see ``shuffle`` parameter for details). The default value of 64 works well in conjunction
469+ with the default ``shuffle_chunk_count`` value.
475470 use_eager_fetch:
476471 Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
477472 available for processing via the iterator. This allows network (or filesystem) requests to be made in
@@ -480,6 +475,7 @@ def __init__(
480475 shuffle_chunk_count:
480475 The number of contiguous blocks (chunks) of rows that are sampled, then concatenated and shuffled.
482477 Larger numbers correspond to more randomness per training batch.
478+ If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``.
483479
484480 Lifecycle:
485481 experimental
@@ -499,7 +495,7 @@ def __init__(
499495 self ._encoders = None
500496 self ._obs_joinids = None
501497 self ._var_joinids = None
502- self ._shuffle_chunk_count = ( shuffle_chunk_count or 1 ) if shuffle else None
498+ self ._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
503499 self ._shuffle_rng = np .random .default_rng (seed ) if shuffle else None
504500 self ._initialized = False
505501
0 commit comments