
Commit cfca649

[pytorch] Set shuffle as default in pytorch, use new algorithm (#1188)
* Set shuffle as default in pytorch, use new algorithm
* Doc change
* Add memory note
1 parent ef62987 commit cfca649

2 files changed: +27 -20 lines

api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py

Lines changed: 16 additions & 20 deletions
@@ -413,12 +413,12 @@ def __init__(
         var_query: Optional[soma.AxisQuery] = None,
         obs_column_names: Sequence[str] = (),
         batch_size: int = 1,
-        shuffle: bool = False,
+        shuffle: bool = True,
         seed: Optional[int] = None,
         return_sparse_X: bool = False,
-        soma_chunk_size: Optional[int] = None,
+        soma_chunk_size: Optional[int] = 64,
         use_eager_fetch: bool = True,
-        shuffle_chunk_count: Optional[int] = None,
+        shuffle_chunk_count: Optional[int] = 2000,
     ) -> None:
         r"""Construct a new ``ExperimentDataPipe``.
 
@@ -443,18 +443,15 @@ def __init__(
                 ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will
                 result in :class:`torch.Tensor`\ s of rank 2 (multiple rows).
             shuffle:
-                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``False`` (no shuffling).
-                For performance reasons, shuffling is performed in two steps: 1) a global shuffling, where contiguous
-                rows are grouped into chunks and the order of the chunks is randomized, and then 2) a local
-                shuffling, where the rows within each chunk are shuffled. Since this class must retrieve data
-                in chunks (to keep memory requirements to a fixed size), global shuffling ensures that a given row in
-                the shuffled result can originate from any position in the non-shuffled result ordering. If shuffling
-                only occurred within each chunk (i.e. "local" shuffling), the first chunk's rows would always be
-                returned first, the second chunk's rows would always be returned second, and so on. The chunk size is
-                determined by the ``soma_chunk_size`` parameter. Note that rows within a chunk will maintain
-                proximity, even after shuffling, so some experimentation may be required to ensure the shuffling is
-                sufficient for the model training process. To this end, the ``soma_chunk_size`` can be treated as a
-                hyperparameter that can be tuned.
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
+                For performance reasons, shuffling is not performed globally across all rows, but rather in chunks.
+                More specifically, we select ``shuffle_chunk_count`` non-contiguous chunks across all the observations
+                in the query, concatenate the chunks and shuffle the associated observations.
+                The randomness of the shuffling is therefore determined by the
+                (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have been determined
+                to yield a good trade-off between randomness and performance. Further tuning may be required for
+                different type of models. Note that memory usage is correlated to the product
+                ``soma_chunk_size * shuffle_chunk_count``.
             seed:
                 The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using
                 :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
@@ -468,10 +465,8 @@ def __init__(
                 The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
                 this class's behavior: 1) The maximum memory utilization, with larger values providing
                 better read performance, but also requiring more memory; 2) The granularity of the global shuffling
-                step (see ``shuffle`` parameter for details). If not specified, the value is set to utilize ~1 GiB of
-                RAM per SOMA chunk read, based upon the number of ``var`` columns (cells/features) being requested
-                and assuming X data sparsity of 95%; the number of rows per chunk will depend on the number of
-                ``var`` columns being read.
+                step (see ``shuffle`` parameter for details). The default value of 64 works well in conjunction
+                with the default ``shuffle_chunk_count`` value.
             use_eager_fetch:
                 Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
                 available for processing via the iterator. This allows network (or filesystem) requests to be made in
@@ -480,6 +475,7 @@ def __init__(
             shuffle_chunk_count:
                 The number of contiguous blocks (chunks) of rows sampled to then concatenate and shuffle.
                 Larger numbers correspond to more randomness per training batch.
+                If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``.
 
         Lifecycle:
             experimental
@@ -499,7 +495,7 @@ def __init__(
         self._encoders = None
         self._obs_joinids = None
         self._var_joinids = None
-        self._shuffle_chunk_count = (shuffle_chunk_count or 1) if shuffle else None
+        self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
         self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
         self._initialized = False
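
For context on the behavior the revised docstring describes, below is a minimal NumPy sketch of the chunked shuffle: contiguous chunks of ``soma_chunk_size`` rows are drawn in random order, grouped ``shuffle_chunk_count`` at a time, and each group is concatenated and shuffled, so peak memory scales with ``soma_chunk_size * shuffle_chunk_count`` rows. This is illustrative only; the function and argument names are hypothetical stand-ins, not the module's actual implementation.

# Illustrative sketch only -- not the code in pytorch.py. `chunked_shuffle`,
# `chunk_size`, and `chunk_count` are hypothetical stand-ins for the module's
# soma_chunk_size / shuffle_chunk_count parameters.
import numpy as np


def chunked_shuffle(obs_joinids: np.ndarray, chunk_size: int, chunk_count: int, seed=None) -> np.ndarray:
    rng = np.random.default_rng(seed)
    # 1) Split the ordered obs ids into contiguous chunks of `chunk_size` rows.
    chunks = [obs_joinids[i : i + chunk_size] for i in range(0, len(obs_joinids), chunk_size)]
    # 2) Randomize the chunk order, so a shuffled row can originate anywhere in the query result.
    order = rng.permutation(len(chunks))
    chunks = [chunks[j] for j in order]
    # 3) Take `chunk_count` (now non-contiguous) chunks at a time, concatenate, and shuffle the
    #    rows within each group; only chunk_size * chunk_count rows are held in memory at once.
    groups = []
    for i in range(0, len(chunks), chunk_count):
        group = np.concatenate(chunks[i : i + chunk_count])
        rng.shuffle(group)
        groups.append(group)
    return np.concatenate(groups)


# Example: 10 rows, chunks of 3 rows, 2 chunks per shuffle group.
print(chunked_shuffle(np.arange(10), chunk_size=3, chunk_count=2, seed=0))

Larger ``chunk_count`` values mix rows from more distant parts of the query, at the cost of holding more rows in memory per group, which is why the docstring treats the (``soma_chunk_size``, ``shuffle_chunk_count``) pair as a tunable trade-off.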

api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py

Lines changed: 11 additions & 0 deletions
@@ -142,6 +142,7 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None
         measurement_name="RNA",
         X_name="raw",
         obs_column_names=["label"],
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     row_iter = iter(exp_data_pipe)
@@ -164,6 +165,7 @@ def test_batching__all_batches_full_size(soma_experiment: Experiment, use_eager_
         X_name="raw",
         obs_column_names=["label"],
         batch_size=3,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -214,6 +216,7 @@ def test_batching__partial_final_batch_size(soma_experiment: Experiment, use_eag
         X_name="raw",
         obs_column_names=["label"],
         batch_size=3,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -239,6 +242,7 @@ def test_batching__exactly_one_batch(soma_experiment: Experiment, use_eager_fetc
         X_name="raw",
         obs_column_names=["label"],
         batch_size=3,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -286,6 +290,7 @@ def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch
         X_name="raw",
         obs_column_names=["label"],
         return_sparse_X=True,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -309,6 +314,7 @@ def test_sparse_output__batched(soma_experiment: Experiment, use_eager_fetch: bo
         obs_column_names=["label"],
         batch_size=3,
         return_sparse_X=True,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     batch_iter = iter(exp_data_pipe)
@@ -350,6 +356,7 @@ def test_encoders(soma_experiment: Experiment) -> None:
         measurement_name="RNA",
         X_name="raw",
         obs_column_names=["label"],
+        shuffle=False,
         batch_size=3,
     )
     batch_iter = iter(exp_data_pipe)
@@ -413,6 +420,7 @@ def test_distributed__returns_data_partition_for_rank(
         X_name="raw",
         obs_column_names=["label"],
         soma_chunk_size=2,
+        shuffle=False,
     )
     full_result = list(iter(dp))
 
@@ -451,6 +459,7 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
         X_name="raw",
         obs_column_names=["label"],
         soma_chunk_size=2,
+        shuffle=False,
     )
 
     full_result = list(iter(dp))
@@ -475,6 +484,7 @@ def test_experiment_dataloader__non_batched(soma_experiment: Experiment, use_eag
         measurement_name="RNA",
         X_name="raw",
         obs_column_names=["label"],
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     dl = experiment_dataloader(dp)
@@ -498,6 +508,7 @@ def test_experiment_dataloader__batched(soma_experiment: Experiment, use_eager_f
         X_name="raw",
         obs_column_names=["label"],
         batch_size=3,
+        shuffle=False,
         use_eager_fetch=use_eager_fetch,
     )
     dl = experiment_dataloader(dp)
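
Because shuffling is now on by default, downstream code that relied on the previous ordered iteration must opt out explicitly, which is exactly what the test changes above do. The following is a hedged usage sketch, assuming the documented cellxgene_census API for opening the Census; the query, column names, and batch size are arbitrary examples, not values from this commit.

# Usage sketch, not part of this commit. Assumes the documented cellxgene_census
# API; the obs_query value filter and column names below are illustrative only.
import cellxgene_census
import cellxgene_census.experimental.ml as census_ml
import tiledbsoma as soma

census = cellxgene_census.open_soma()
experiment = census["census_data"]["homo_sapiens"]

# New defaults after this commit: shuffle=True, soma_chunk_size=64, shuffle_chunk_count=2000.
shuffled_dp = census_ml.ExperimentDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue'"),
    obs_column_names=["cell_type"],
    batch_size=16,
    seed=42,  # set a seed to make the (now default) shuffling reproducible
)

# Pass shuffle=False to recover the previous deterministic ordering, as the updated tests do.
ordered_dp = census_ml.ExperimentDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue'"),
    obs_column_names=["cell_type"],
    batch_size=16,
    shuffle=False,
)

loader = census_ml.experiment_dataloader(shuffled_dp)
first_batch = next(iter(loader))
census.close()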
