
fix: unify dense tensor padding convention (dim-0 == batch_size) #362

Merged
shijieliu merged 2 commits into NVIDIA:main from JacoCheung:junzhang/fix-dense-tensor-padding on Apr 17, 2026

Conversation

Collaborator

@JacoCheung JacoCheung commented Apr 14, 2026

Summary

Fix NCCL AllGather deadlock caused by inconsistent dense tensor padding in pad_and_allgather_batch. All dense tensor fields in BaseBatch now have dim-0 == batch_size (zero-padded by the dataloader), eliminating the ambiguity that caused _elems_per_sample to compute wrong padding sizes for pre-padded tensors.

Closes #361

Changes

  • batch_allgather.py: Remove _elems_per_sample and dense padding logic from allgather_field — dense tensors are already at batch_size, just AllGather directly
  • batch_all2all.py: Remove dense padding logic from all2all_field and world_size==1 fast path
  • batch.py: Update BaseBatch docstring (dense dim-0 == batch_size), fix index_select_dense_tensor to use batch_size instead of actual_batch_size
  • batch_shuffler.py: Update _strip_dense_padding docstring
  • test_distributed.py: Update generate_batch to pad dense labels to [batch_size]

Net: -108 lines of padding/reshaping logic removed.

Root Cause

pad_and_allgather_batch assumed dense tensors had actual_batch_size elements and computed eps = numel // actual_batch_size to determine padding. But num_candidates was already padded to [batch_size] by the dataloader, so 128 // 11 = 11 (wrong, should be 1), producing a [1408] tensor while other ranks sent [128] → NCCL size mismatch → deadlock.
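The arithmetic behind the deadlock can be reproduced in plain Python, using the numbers quoted above (batch_size=128, actual_batch_size=11); this is an illustrative sketch, not the repository's code:

```python
# Illustrative reproduction of the bad eps computation described above.
batch_size = 128
actual_batch_size = 11

num_candidates = [0] * batch_size          # already padded to [128] by the dataloader
numel = len(num_candidates)                # 128

# Old logic: elements-per-sample inferred from actual_batch_size.
eps_wrong = numel // actual_batch_size     # 128 // 11 = 11, should be 1
padded_len_wrong = batch_size * eps_wrong  # 128 * 11 = 1408

# A rank with a full batch computes eps = 128 // 128 = 1 and sends [128],
# so AllGather sees mismatched sizes across ranks -> NCCL hang.
print(eps_wrong, padded_len_wrong)         # 11 1408
```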

Design Document

Batch Padding & Shuffle Design

This document describes how padding, shuffling, and load balancing work together in the distributed training pipeline.

1. Batch Structure

BaseBatch holds two types of data:

| Field Type | Dimension | Padding |
|---|---|---|
| KJT (KeyedJaggedTensor) | batch_size * num_keys lengths | Padding samples have length=0 |
| Dense Tensor | batch_size in dim-0 (flat: batch_size * eps elements) | Padding rows are zeros |

Key invariant: Both KJT and dense tensors always have batch_size in their batch dimension. actual_batch_size only records how many leading samples are real (non-padding).

batch_size = 128          # full size including padding
actual_batch_size = 100   # first 100 samples are real

KJT lengths:   [L0, L1, ..., L99, 0, 0, ..., 0]   (128 entries, last 28 are zero)
Dense tensor:  [V0, V1, ..., V99, 0, 0, ..., 0]   (128 entries, last 28 are zero)
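The invariant above can be constructed concretely; this sketch uses NumPy and smaller numbers (8/5 instead of 128/100) purely for illustration:

```python
import numpy as np

batch_size, actual_batch_size = 8, 5    # smaller stand-ins for 128/100

# KJT lengths: real samples keep their lengths, padding samples get 0.
real_lengths = np.array([3, 1, 4, 2, 5])
lengths = np.zeros(batch_size, dtype=np.int64)
lengths[:actual_batch_size] = real_lengths

# Dense tensor: one row per sample, padding rows are zeros.
eps = 2                                  # elements per sample
dense = np.zeros((batch_size, eps), dtype=np.float32)
dense[:actual_batch_size] = np.arange(actual_batch_size * eps).reshape(actual_batch_size, eps)

# Key invariant: dim-0 == batch_size for both, trailing entries are zero.
assert lengths.shape[0] == batch_size and dense.shape[0] == batch_size
assert (lengths[actual_batch_size:] == 0).all() and (dense[actual_batch_size:] == 0).all()
```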

2. Why KJT Padding Is Required

KJT lengths must be padded to batch_size with zeros because:

  1. AllGather symmetry: keyed_jagged_tensor_list_allgather requires all ranks to contribute the same number of length entries.
  2. KK partitioning: The Karmarkar-Karp algorithm requires len(workloads) % world_size == 0. Workloads are derived from KJT lengths, so the global batch must be batch_size * world_size.
  3. FBGEMM ops: keyed_jagged_index_select_dim1 requires the batch_size argument to match the lengths dimension.
  4. Global indexing: After AllGather, index_select uses global indices (0 to global_batch_size - 1). Without padding, indices would be invalid.

Padding samples have length=0, producing zero values in the KJT. The HSTU attention kernel naturally skips these (zero-length sequences produce no attention output).
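The AllGather-symmetry and KK-divisibility requirements can be simulated without NCCL by concatenating per-rank length tensors; this is a sketch with made-up per-rank fill counts:

```python
import numpy as np

world_size, batch_size = 4, 8

# Simulated per-rank KJT lengths, each padded to batch_size with zeros.
per_rank = [np.zeros(batch_size, dtype=np.int64) for _ in range(world_size)]
for rank, n_real in enumerate([8, 5, 7, 2]):    # varying actual_batch_size per rank
    per_rank[rank][:n_real] = 1

# AllGather is only well-defined because every rank contributes the same
# shape; the global batch is batch_size * world_size, satisfying the
# Karmarkar-Karp requirement len(workloads) % world_size == 0.
global_lengths = np.concatenate(per_rank)
assert global_lengths.shape[0] == batch_size * world_size
assert global_lengths.shape[0] % world_size == 0
```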

3. Why Dense Padding Is Required

Dense tensors must have dim-0 == batch_size because:

  1. AllGather symmetry: gather_along_first_dim / all_gather_into_tensor requires all ranks to send the same tensor shape. If one rank has fewer rows, NCCL hangs.
  2. index_select correctness: BaseBatch.index_select reshapes dense tensors as [batch_size, -1] before indexing. If dim-0 != batch_size, the reshape fails.
  3. Consistency with KJT: After shuffle, _strip_dense_padding uses reshape(batch_size, -1)[:actual_bs] to extract real rows. This requires dim-0 == batch_size.

The dataloader is responsible for padding dense tensors to batch_size (e.g., pad_tensor() in hstu_sequence_dataset.py).
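A dataloader-side padding helper might look like the following; this is a hypothetical stand-in (the real pad_tensor() lives in hstu_sequence_dataset.py and may differ), shown with NumPy for illustration:

```python
import numpy as np

def pad_tensor(t: np.ndarray, batch_size: int) -> np.ndarray:
    """Zero-pad a dense per-sample tensor so that dim-0 == batch_size.

    Hypothetical sketch of the dataloader-side helper mentioned above;
    not the actual implementation.
    """
    actual_bs = t.shape[0]
    if actual_bs == batch_size:
        return t
    pad = np.zeros((batch_size - actual_bs,) + t.shape[1:], dtype=t.dtype)
    return np.concatenate([t, pad], axis=0)

labels = np.ones((100, 1), dtype=np.float32)   # actual_batch_size = 100
padded = pad_tensor(labels, 128)
assert padded.shape[0] == 128 and (padded[100:] == 0).all()
```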

4. Shuffle Pipeline

4.1 Two-Phase Async Shuffle

The shuffle has two phases, designed to overlap with forward/backward:

Phase 1 (on _memcpy_stream, during forward):
  1. Compute per-sample workloads from KJT lengths
  2. AllGather workloads across all ranks
  3. Submit Karmarkar-Karp algorithm to background thread

Phase 2 (on _memcpy_stream, after forward):
  4. Wait for KK result (partition indices)
  5. AllGather batch data (KJT + dense)
  6. index_select to pick this rank's assigned samples
  7. Strip dense padding
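The two-phase structure can be sketched with a background thread standing in for the KK worker; all names here (phase1, phase2, solve_kk) are illustrative, and the real code additionally runs the communication steps on _memcpy_stream:

```python
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)

def solve_kk(workloads):
    # Placeholder for the Karmarkar-Karp partitioner: here we just return
    # sample indices ordered by workload so the sketch is runnable.
    return sorted(range(len(workloads)), key=workloads.__getitem__)

def phase1(workloads):
    # Steps 1-3: workloads computed (and, in the real code, AllGathered),
    # then the CPU-side KK solve is submitted to a background thread.
    return _executor.submit(solve_kk, workloads)

def phase2(kk_future):
    # Steps 4-7: wait for the partition result, then (in the real code)
    # AllGather the batch, index_select, and strip dense padding.
    return kk_future.result()

fut = phase1([3.0, 1.0, 2.0, 0.0])   # overlaps with forward()
parts = phase2(fut)                  # runs after forward()
```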

4.2 Karmarkar-Karp (KK) Algorithm

KK partitions the global batch (batch_size * world_size samples, including padding) into world_size equal-size groups, minimizing the max workload across groups.

  • Input: [W0, W1, ..., W_{global_bs-1}] where padding samples have W=0
  • Output: partitions[rank] = [idx_0, idx_1, ..., idx_{batch_size-1}]
  • Each rank gets exactly batch_size indices (some may point to padding samples)
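The behaviour described above (equal-size groups, balanced workload) can be illustrated with a greedy heuristic; note this is a simple stand-in for the actual Karmarkar-Karp implementation, not a reproduction of it:

```python
import heapq

def greedy_equal_partition(workloads, world_size):
    """Greedy stand-in for the KK step: assign each sample (heaviest first)
    to the currently lightest group that still has capacity, so every group
    ends up with exactly batch_size indices."""
    n = len(workloads)
    assert n % world_size == 0            # the divisibility constraint above
    cap = n // world_size                 # == batch_size
    heap = [(0.0, rank) for rank in range(world_size)]  # (group total, rank)
    heapq.heapify(heap)
    groups = [[] for _ in range(world_size)]
    for idx in sorted(range(n), key=lambda i: -workloads[i]):
        total, rank = heapq.heappop(heap)
        groups[rank].append(idx)
        if len(groups[rank]) < cap:       # full groups are retired from the heap
            heapq.heappush(heap, (total + workloads[idx], rank))
    return groups

# Padding samples carry workload 0, so they are spread across the groups last.
groups = greedy_equal_partition([5, 4, 3, 2, 1, 0, 0, 0], world_size=2)
```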

4.3 _sort_partitions_padding_last

After KK, each rank's indices may interleave real and padding samples:

Before sort: [3, 17, 0, 12, 8, 15, 2, 20]
                      ^padding     ^padding

After sort:  [3, 17, 12, 8, 2, 20, 0, 15]
              |--- real ---|  |-padding-|

This ensures [:actual_bs] slicing extracts exactly the real samples, enabling _strip_dense_padding to work with a simple slice.
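The sort is a stable partition, which a short sketch makes concrete (using the example above, where indices 0 and 15 point at padding samples):

```python
def sort_padding_last(indices, is_padding):
    """Stable partition: real-sample indices first (original order kept),
    padding indices last -- a sketch of _sort_partitions_padding_last."""
    real = [i for i in indices if not is_padding(i)]
    pad = [i for i in indices if is_padding(i)]
    return real + pad

padding = {0, 15}   # indices that point at padding samples in the example
out = sort_padding_last([3, 17, 0, 12, 8, 15, 2, 20], lambda i: i in padding)
# out == [3, 17, 12, 8, 2, 20, 0, 15]
```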

4.4 _strip_dense_padding

After shuffle + index_select, the batch has batch_size samples (some padding). _strip_dense_padding removes trailing padding rows from dense tensors:

t.reshape(batch_size, -1)[:actual_bs].reshape(-1)

KJT fields are not stripped — they keep batch_size lengths (with zeros for padding). This is safe because downstream ops (attention, embedding) handle zero-length sequences.
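The reshape/slice/flatten from the snippet above can be demonstrated with NumPy (the real code operates on torch tensors, but the shape logic is identical):

```python
import numpy as np

batch_size, actual_bs = 4, 3

# Flat dense tensor with eps=2 elements per sample and one zero padding row.
t = np.array([1, 2, 3, 4, 5, 6, 0, 0], dtype=np.float32)

# View as [batch_size, eps], keep the first actual_bs rows, flatten back.
stripped = t.reshape(batch_size, -1)[:actual_bs].reshape(-1)
# stripped == [1, 2, 3, 4, 5, 6]
```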

5. Communication Pattern

AllGather path:
  Local batch [bs] → AllGather → Global batch [bs * W] → index_select [bs] → strip [actual_bs]

All2All path:
  Local batch [bs] → All2All (send to target ranks) → Local batch [bs] → strip [actual_bs]

Both paths produce a batch where:

  • KJT has batch_size samples (padded)
  • Dense tensors have actual_bs rows (stripped)
  • actual_batch_size is set to actual_bs

6. Loss Computation

After shuffle, the model processes actual_bs real samples. The loss is computed only on these samples:

# Pipeline handles this via num_loss_tokens():
global_tokens = batch.num_loss_tokens()  # counts only real samples
torch.distributed.all_reduce(global_tokens)
loss = local_loss_sum * dp_size / global_tokens  # normalize by real token count

Padding samples either:

  • Produce zero loss (because their features/labels are zeros), or
  • Are excluded by _strip_dense_padding before entering the model
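The normalization in the snippet above can be checked with a small simulation of two data-parallel ranks (no torch.distributed; the all_reduce is replaced by a plain sum, and all numbers are illustrative):

```python
# Two simulated data-parallel ranks.
dp_size = 2
local_loss_sums = [12.0, 6.0]   # per-rank sum of per-token losses
local_tokens = [40, 20]         # per-rank count of real (non-padding) tokens

# What all_reduce(global_tokens) produces on every rank.
global_tokens = sum(local_tokens)   # 60

# Each rank computes local_loss_sum * dp_size / global_tokens; averaging
# these (which the gradient all-reduce effectively does) recovers the true
# mean loss over real tokens only.
per_rank_loss = [s * dp_size / global_tokens for s in local_loss_sums]
mean_loss = sum(per_rank_loss) / dp_size   # == 18.0 / 60 == 0.3
```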

Test Plan

  • test_distributed.py — 52 cases (4 GPU), all passed
  • test_batch_shuffler_factory — passed (4 GPU)
  • pretrain_gr_ranking 4-GPU movielen integration test — 1000 iters + eval, AUC=0.80, no hang
  • CI pipeline


Contributor

greptile-apps Bot commented Apr 14, 2026

Greptile Summary

This PR fixes a real NCCL AllGather deadlock by unifying the dense tensor padding convention across BaseBatch: dense fields now always have dim-0 == batch_size (pre-padded by the dataloader), removing the ambiguous _elems_per_sample logic that computed wrong pad sizes for already-padded tensors. The net removal of ~108 lines of fragile padding/reshaping logic in batch_allgather.py and batch_all2all.py is a clear improvement in correctness and simplicity.

Confidence Score: 5/5

Safe to merge — all findings are P2 (documentation and an unexplained incidental test change); no logic or correctness issues.

The core fix (removing _elems_per_sample and inline padding) is correct. index_select_dense_tensor is properly updated to use batch_size. Tests cover allgather and all2all paths including incomplete batches. Only remaining issues are a stale docstring paragraph and an unexplained test arity change, neither of which affects runtime behavior.

examples/commons/distributed/batch_all2all.py — world_size==1 docstring section is stale; examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py — unexplained progress() arity change

Important Files Changed

| Filename | Overview |
|---|---|
| examples/commons/distributed/batch_allgather.py | Removes _elems_per_sample and all dense-padding logic; allgather_field now calls gather_along_first_dim directly. world_size==1 fast path simplified to a direct return. Logic is correct under the new convention. |
| examples/commons/distributed/batch_all2all.py | Removes dense padding from both the world_size==1 fast path and Phase 2. Implementation is correct, but the docstring for the world_size==1 fast path (lines 403–410) still describes old behavior (explicit padding + actual_batch_size update). |
| examples/commons/sequence_batch/batch.py | Docstring updated to reflect the unified convention; index_select_dense_tensor correctly changed from self.actual_batch_size to self.batch_size for the reshape, fixing the root cause of potential reshape failures. |
| examples/commons/distributed/batch_shuffler.py | Docstring of _strip_dense_padding updated to describe the unified convention. No logic changes; the existing reshape(full_bs, -1)[:actual_bs].reshape(-1) is already correct. |
| examples/tests/commons/test_distributed.py | generate_batch now pads dense labels to [batch_size * num_features] with zeros; the test assertion is updated to compare against batch.labels[:actual_bs * num_features]. Both changes correctly reflect the new convention. |
| examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py | Unpacking of pipeline.progress() changed from 2 return values to 3 (_, _, (...) instead of _, (...)). This change is unrelated to the dense-padding fix and is not mentioned in the PR description, making it hard to verify independently. |

Sequence Diagram

sequenceDiagram
    participant DL as Dataloader
    participant BS as BaseBatch
    participant AG as pad_and_allgather_batch
    participant A2A as pad_and_all2all_batch
    participant SH as BatchShuffler

    DL->>BS: dense tensor pre-padded to [batch_size * eps]<br/>(dim-0 == batch_size, padding rows = 0)
    Note over BS: batch.actual_batch_size ≤ batch.batch_size

    alt AllGather path
        SH->>AG: pad_and_allgather_batch(batch)
        Note over AG: No padding needed (already batch_size rows)
        AG->>AG: gather_along_first_dim(dense_tensor)
        AG->>SH: global_batch [batch_size x W]
        SH->>SH: index_select(indices_this_rank) reshape(batch_size, -1)
        SH->>SH: _strip_dense_padding(new_batch, actual_bs)
    else All2All path
        SH->>A2A: pad_and_all2all_batch(batch, recv_ids)
        Note over A2A: No padding needed (already batch_size rows)
        A2A->>A2A: _all2all_dense_tensor(tensor.reshape(bs,-1))
        A2A->>SH: redistributed batch
        SH->>SH: _strip_dense_padding(new_batch, actual_bs)
    end

    SH->>SH: new_batch.actual_batch_size = actual_bs


@JacoCheung
Collaborator Author

/build

3 similar comments
@JacoCheung
Collaborator Author

/build

@JacoCheung
Collaborator Author

/build

@JacoCheung
Collaborator Author

/build

JacoCheung and others added 2 commits April 17, 2026 09:54
All dense tensor fields in BaseBatch now have dim-0 == batch_size
(zero-padded by the dataloader). This eliminates the ambiguity in
pad_and_allgather_batch where _elems_per_sample computed wrong eps
for pre-padded tensors (e.g. num_candidates), causing NCCL AllGather
size mismatch and deadlock on incomplete batches.

Changes:
- Remove _elems_per_sample and dense padding logic from allgather/all2all
- Simplify BaseBatch.index_select to use batch_size (not actual_batch_size)
- Update _strip_dense_padding docstring
- Update test generate_batch to pad dense labels to batch_size

Closes NVIDIA#361

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
progress() now returns (loss, global_tokens, extras) instead of
(loss, extras). Update all 3 call sites in the TP test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@shijieliu shijieliu force-pushed the junzhang/fix-dense-tensor-padding branch from 9363bae to 6625a3d Compare April 17, 2026 01:54
@JacoCheung
Collaborator Author

/build

@JacoCheung
Collaborator Author

JacoCheung commented Apr 17, 2026

Pipeline #48748914 -- failed

| Job | Status |
|---|---|
| pre_check | ✅ success |
| train_build | ✅ success |
| inference_build | ✅ success |
| tritonserver_build | ✅ success |
| build_whl | ✅ success |
| dynamicemb_test_fwd_bwd_8gpus | ✅ success |
| dynamicemb_test_load_dump_8gpus | ✅ success |
| unit_test_1gpu_a100 | ❌ failed |
| unit_test_1gpu_h100 | ❌ failed |
| unit_test_4gpu | ✅ success |
| unit_test_tp_4gpu | ❌ failed |
| L20_unit_test_1gpu | ✅ success |
| inference_unit_test_1gpu | ✅ success |
| inference_test_1gpu | ✅ success |

Result: 11/14 jobs passed


@shijieliu shijieliu merged commit 3db734f into NVIDIA:main Apr 17, 2026
0 of 3 checks passed


Development

Successfully merging this pull request may close these issues.

BaseBatch dense tensor shape convention inconsistency with padded fields (num_candidates)
