fix: unify dense tensor padding convention (dim-0 == batch_size) #362
Conversation
Greptile Summary

This PR fixes a real NCCL AllGather deadlock by unifying the dense tensor padding convention across

Confidence Score: 5/5. Safe to merge: all findings are P2 (documentation and an unexplained incidental test change); no logic or correctness issues. The core fix (removing

Findings:
- `examples/commons/distributed/batch_all2all.py` — world_size==1 docstring section is stale
- `examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py` — unexplained `progress()` arity change

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant DL as Dataloader
    participant BS as BaseBatch
    participant AG as pad_and_allgather_batch
    participant A2A as pad_and_all2all_batch
    participant SH as BatchShuffler
    DL->>BS: dense tensor pre-padded to [batch_size * eps]<br/>(dim-0 == batch_size, padding rows = 0)
    Note over BS: batch.actual_batch_size <= batch.batch_size
    alt AllGather path
        SH->>AG: pad_and_allgather_batch(batch)
        Note over AG: No padding needed (already batch_size rows)
        AG->>AG: gather_along_first_dim(dense_tensor)
        AG->>SH: global_batch [batch_size x W]
        SH->>SH: index_select(indices_this_rank) + reshape(batch_size, -1)
        SH->>SH: _strip_dense_padding(new_batch, actual_bs)
    else All2All path
        SH->>A2A: pad_and_all2all_batch(batch, recv_ids)
        Note over A2A: No padding needed (already batch_size rows)
        A2A->>A2A: _all2all_dense_tensor(tensor.reshape(bs,-1))
        A2A->>SH: redistributed batch
        SH->>SH: _strip_dense_padding(new_batch, actual_bs)
    end
    SH->>SH: new_batch.actual_batch_size = actual_bs
```
Reviews (3): Last reviewed commit: "fix: update test_tp_ranking_gr to match ..."
/build
All dense tensor fields in BaseBatch now have dim-0 == batch_size (zero-padded by the dataloader). This eliminates the ambiguity in pad_and_allgather_batch, where _elems_per_sample computed the wrong eps for pre-padded tensors (e.g. num_candidates), causing an NCCL AllGather size mismatch and deadlock on incomplete batches.

Changes:
- Remove _elems_per_sample and dense padding logic from allgather/all2all
- Simplify BaseBatch.index_select to use batch_size (not actual_batch_size)
- Update _strip_dense_padding docstring
- Update test generate_batch to pad dense labels to batch_size

Closes NVIDIA#361

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
progress() now returns (loss, global_tokens, extras) instead of (loss, extras). Update all 3 call sites in the TP test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
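The call-site change the commit describes can be sketched as follows; the `progress` body and its values here are illustrative stand-ins, not the trainer's real implementation:

```python
# Hypothetical stand-in for the trainer's progress() step, which now
# returns a 3-tuple (loss, global_tokens, extras) instead of 2.
def progress():
    # illustrative values only
    return 0.25, 1024, {"lr": 1e-3}

# Call sites must unpack three values now:
loss, global_tokens, extras = progress()
print(loss, global_tokens, extras["lr"])
```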
Force-pushed 9363bae to 6625a3d
/build
❌ Pipeline #48748914 -- failed. Result: 11/14 jobs passed
Summary
Fix NCCL AllGather deadlock caused by inconsistent dense tensor padding in `pad_and_allgather_batch`. All dense tensor fields in `BaseBatch` now have dim-0 == `batch_size` (zero-padded by the dataloader), eliminating the ambiguity that caused `_elems_per_sample` to compute wrong padding sizes for pre-padded tensors.

Closes #361
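The failure mode is plain integer arithmetic, using the numbers from the Root Cause section (batch_size=128, actual_batch_size=11, and a `num_candidates` field the dataloader had already padded to `[batch_size]`); variable names here are illustrative:

```python
# Reproduce the _elems_per_sample miscalculation described in this PR.
batch_size = 128
actual_batch_size = 11
num_candidates_numel = 128  # already padded to [batch_size] by the dataloader

# Old (buggy) assumption: the tensor holds actual_batch_size samples.
eps_wrong = num_candidates_numel // actual_batch_size  # 128 // 11 = 11
padded_wrong = eps_wrong * batch_size                  # 11 * 128 = 1408

# New convention: dense dim-0 is always batch_size, so eps is exact.
eps_right = num_candidates_numel // batch_size         # 128 // 128 = 1
padded_right = eps_right * batch_size                  # 128

# One rank sends a [1408] tensor while others send [128] -> NCCL mismatch.
print(padded_wrong, padded_right)  # 1408 128
```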
Changes
- `batch_allgather.py`: Remove `_elems_per_sample` and dense padding logic from `allgather_field` — dense tensors are already at `batch_size`, just AllGather directly
- `batch_all2all.py`: Remove dense padding logic from `all2all_field` and world_size==1 fast path
- `batch.py`: Update `BaseBatch` docstring (dense dim-0 == `batch_size`), fix `index_select_dense_tensor` to use `batch_size` instead of `actual_batch_size`
- `batch_shuffler.py`: Update `_strip_dense_padding` docstring
- `test_distributed.py`: Update `generate_batch` to pad dense labels to `[batch_size]`

Net: -108 lines of padding/reshaping logic removed.
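The `index_select` change can be illustrated with a torch-free sketch; the helper name and list-based tensors here are hypothetical stand-ins for the real torch code in `batch.py`:

```python
# Sketch of the post-fix convention: reshape a flat dense tensor as
# [batch_size, -1] (never [actual_batch_size, -1]) before indexing.
def index_select_dense(flat, batch_size, indices):
    eps = len(flat) // batch_size  # elements per sample (exact by invariant)
    rows = [flat[i * eps:(i + 1) * eps] for i in range(batch_size)]
    picked = [rows[i] for i in indices]
    return [x for row in picked for x in row]

# batch_size=4, eps=2; rows 2 and 3 are zero padding.
flat = [1, 2, 3, 4, 0, 0, 0, 0]
print(index_select_dense(flat, 4, [1, 0, 2, 3]))  # [3, 4, 1, 2, 0, 0, 0, 0]
```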
Root Cause
`pad_and_allgather_batch` assumed dense tensors had `actual_batch_size` elements and computed `eps = numel // actual_batch_size` to determine padding. But `num_candidates` was already padded to `[batch_size]` by the dataloader, so `128 // 11 = 11` (wrong, should be 1), producing a `[1408]` tensor while other ranks sent `[128]` → NCCL size mismatch → deadlock.

Design Document
Batch Padding & Shuffle Design
This document describes how padding, shuffling, and load balancing work together in the distributed training pipeline.
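For orientation, the two kinds of batch fields described in the sections that follow can be sketched with a toy stand-in; the class below is hypothetical (the real `BaseBatch` in `batch.py` holds torch tensors and KJTs, not Python lists):

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ToyBatch:
    batch_size: int                      # padded batch dimension, fixed per rank
    actual_batch_size: int               # leading real (non-padding) samples
    kjt_lengths: Dict[str, List[int]]    # per-key lengths, len == batch_size
    dense: Dict[str, List[float]]        # flat, len == batch_size * eps

# batch_size=4 but only 3 real samples; padding sample has length 0 and
# a zero dense row, per the invariant described below.
b = ToyBatch(4, 3, {"item": [2, 1, 3, 0]}, {"label": [1.0, 0.0, 1.0, 0.0]})
print(b.actual_batch_size, len(b.kjt_lengths["item"]))  # 3 4
```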
1. Batch Structure
`BaseBatch` holds two types of data:
- KJT fields: `batch_size * num_keys` lengths; padding samples have `length=0`
- Dense tensors: `batch_size` in dim-0 (flat: `batch_size * eps` elements)

Key invariant: Both KJT and dense tensors always have `batch_size` in their batch dimension. `actual_batch_size` only records how many leading samples are real (non-padding).

2. Why KJT Padding Is Required
KJT lengths must be padded to `batch_size` with zeros because:
- `keyed_jagged_tensor_list_allgather` requires all ranks to contribute the same number of length entries.
- `len(workloads) % world_size == 0` must hold. Workloads are derived from KJT lengths, so the global batch must be `batch_size * world_size`.
- `keyed_jagged_index_select_dim1` requires the `batch_size` argument to match the lengths dimension.
- `index_select` uses global indices (0 to `global_batch_size - 1`). Without padding, indices would be invalid.

Padding samples have `length=0`, producing zero values in the KJT. The HSTU attention kernel naturally skips these (zero-length sequences produce no attention output).

3. Why Dense Padding Is Required
Dense tensors must have dim-0 == `batch_size` because:
- `gather_along_first_dim` / `all_gather_into_tensor` requires all ranks to send the same tensor shape. If one rank has fewer rows, NCCL hangs.
- `index_select` correctness: `BaseBatch.index_select` reshapes dense tensors as `[batch_size, -1]` before indexing. If dim-0 != `batch_size`, the reshape fails.
- `_strip_dense_padding` uses `reshape(batch_size, -1)[:actual_bs]` to extract real rows. This requires dim-0 == `batch_size`.

The dataloader is responsible for padding dense tensors to `batch_size` (e.g., `pad_tensor()` in `hstu_sequence_dataset.py`).

4. Shuffle Pipeline
4.1 Two-Phase Async Shuffle
The shuffle has two phases, designed to overlap with forward/backward:
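The overlap pattern can be sketched with plain Python concurrency; this is an illustrative stand-in (no NCCL, invented function names), showing only the issue-then-wait structure:

```python
# Phase 1 launches the shuffle communication asynchronously, the trainer
# runs forward/backward in the meantime, and phase 2 waits on the handle.
from concurrent.futures import ThreadPoolExecutor

def shuffle_phase1(batch):
    # stand-in for launching allgather/all2all + partitioning
    return sorted(batch)

pool = ThreadPoolExecutor(max_workers=1)
handle = pool.submit(shuffle_phase1, [3, 1, 2])  # phase 1: issue comm
train_step_result = sum([3, 1, 2])               # overlapped fwd/bwd work
shuffled = handle.result()                       # phase 2: wait + consume
pool.shutdown()
print(shuffled, train_step_result)               # [1, 2, 3] 6
```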
4.2 Karmarkar-Karp (KK) Algorithm
KK partitions the global batch (`batch_size * world_size` samples, including padding) into `world_size` equal-size groups, minimizing the max workload across groups.
- Input: workloads `[W0, W1, ..., W_{global_bs-1}]`, where padding samples have `W=0`
- Output: `partitions[rank] = [idx_0, idx_1, ..., idx_{batch_size-1}]` — each rank receives exactly `batch_size` indices (some may point to padding samples)

4.3 _sort_partitions_padding_last

After KK, each rank's indices may interleave real and padding samples:
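A torch-free sketch of this reordering and the slice it enables; the function and variable names are illustrative, not the repo's actual implementation:

```python
# Reorder one rank's KK partition so real samples (workload > 0) come
# first and padding samples (workload == 0) go last.
def sort_partition_padding_last(partition, workloads):
    real = [i for i in partition if workloads[i] > 0]
    pad = [i for i in partition if workloads[i] == 0]
    return real + pad

workloads = [5, 0, 3, 0, 7, 2]   # global indices 1 and 3 are padding (W=0)
partition = [1, 4, 3, 0]         # interleaved real/padding from KK
ordered = sort_partition_padding_last(partition, workloads)
print(ordered)                   # [4, 0, 1, 3]: real samples first

# With real samples first, stripping padding is a plain slice:
actual_bs = sum(1 for i in ordered if workloads[i] > 0)
print(ordered[:actual_bs])       # [4, 0]
```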
This ensures `[:actual_bs]` slicing extracts exactly the real samples, enabling `_strip_dense_padding` to work with a simple slice.

4.4 _strip_dense_padding

After shuffle + index_select, the batch has `batch_size` samples (some padding). `_strip_dense_padding` removes trailing padding rows from dense tensors. KJT fields are not stripped — they keep `batch_size` lengths (with zeros for padding). This is safe because downstream ops (attention, embedding) handle zero-length sequences.

5. Communication Pattern
Both paths produce a batch where:
- KJT fields: `batch_size` samples (padded)
- Dense tensors: `actual_bs` rows (stripped)
- `actual_batch_size` is set to `actual_bs`

6. Loss Computation
After shuffle, the model processes
`actual_bs` real samples. The loss is computed only on these samples: padding rows are stripped by `_strip_dense_padding` before entering the model.

Test Plan
- `test_distributed.py` — 52 cases (4 GPU), all passed
- `test_batch_shuffler_factory` — passed (4 GPU)
- `pretrain_gr_ranking` 4-GPU movielen integration test — 1000 iters + eval, AUC=0.80, no hang

CI