Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 6 additions & 38 deletions examples/commons/distributed/batch_all2all.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
from commons.sequence_batch.batch import BaseBatch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

Expand Down Expand Up @@ -430,29 +429,9 @@ def pad_and_all2all_batch(
local_batch_size = batch.batch_size

if world_size == 1:
if batch.actual_batch_size >= local_batch_size:
return batch

# Still need to pad dense tensors for incomplete batches so that
# downstream reshape(batch_size, -1) is valid — same as allgather path.
def _pad_dense(
tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor],
) -> Union[torch.Tensor, KeyedJaggedTensor]:
if isinstance(tensor_or_kjt, KeyedJaggedTensor):
return tensor_or_kjt
elif isinstance(tensor_or_kjt, torch.Tensor):
t = tensor_or_kjt
pad_size = (
local_batch_size * (t.numel() // batch.actual_batch_size)
- t.numel()
)
return F.pad(t, (0, pad_size)) if pad_size > 0 else t
else:
raise ValueError(f"Unsupported type: {type(tensor_or_kjt)}")

new_batch = batch._apply_to_tensors_or_kjt(_pad_dense, inplace=False)
new_batch.actual_batch_size = recv_ids.numel()
return new_batch
# Under the unified convention, dense tensors already have
# dim-0 == batch_size, so no padding is needed.
return batch

# ---- Phase 1: KJT fields — fused all-to-all via _all2all_kjt_list ----
kjt_field_names: List[str] = []
Expand All @@ -472,27 +451,16 @@ def _pad_dense(
)

# ---- Phase 2: Dense tensor fields — all-to-all via _all2all_dense_tensor ----
# When actual_batch_size < batch_size, dense tensors have fewer rows
# than local_batch_size. Pad them to local_batch_size before sending
# so that _all2all_dense_tensor's reshape is valid and padding samples
# (assigned by KK) carry zero values.
pad_dense = batch.actual_batch_size < local_batch_size

# Under the unified convention, dense tensors already have dim-0 ==
# batch_size (== local_batch_size), so no padding is needed.
def all2all_field(
tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor],
) -> Union[torch.Tensor, KeyedJaggedTensor]:
if isinstance(tensor_or_kjt, KeyedJaggedTensor):
return tensor_or_kjt # already handled in Phase 1
elif isinstance(tensor_or_kjt, torch.Tensor):
t = tensor_or_kjt
if pad_dense:
pad_size = (
local_batch_size * (t.numel() // batch.actual_batch_size)
- t.numel()
)
t = F.pad(t, (0, pad_size))
return _all2all_dense_tensor(
t,
tensor_or_kjt,
dst_rank,
recv_counts,
local_batch_size,
Expand Down
66 changes: 10 additions & 56 deletions examples/commons/distributed/batch_allgather.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import math
from dataclasses import fields
from typing import Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from commons.ops.collective_ops import (
gather_along_first_dim,
keyed_jagged_tensor_list_allgather,
Expand All @@ -12,17 +10,6 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def _elems_per_sample(t: torch.Tensor, actual_batch_size: int) -> int:
"""Return the number of flat elements per sample in *t*.

When *actual_batch_size* > 0 we can simply divide; when the batch is
empty we fall back to the product of all dimensions after dim-0.
"""
if actual_batch_size > 0:
return t.numel() // actual_batch_size
return math.prod(t.shape[1:]) if t.dim() > 1 else 1


def pad_and_allgather_batch(
batch: BaseBatch,
pg_group: torch.distributed.ProcessGroup = torch.distributed.group.WORLD,
Expand All @@ -35,17 +22,9 @@ def pad_and_allgather_batch(
(1 for lengths, 1 for values) via :func:`keyed_jagged_tensor_list_allgather`.
Dense tensor fields are gathered separately.

If ``actual_batch_size < batch_size`` on any rank, dense tensors
are zero-padded to ``batch_size`` before gathering so that all
ranks contribute the same dim-0 and global sample indices remain
valid for the subsequent ``index_select``.

**world_size == 1 fast-path**: When the process group contains only
one rank, no collective communication is performed. Dense tensors
are still zero-padded to ``batch_size`` when ``actual_batch_size <
batch_size`` (incomplete batch) and ``actual_batch_size`` is set to
``batch_size`` for consistency with the multi-rank code path. KJT
fields are returned as-is.
Under the unified dense-padding convention, all dense tensor fields
have dim-0 == ``batch_size`` (zero-padded by the dataloader), so no
additional padding is needed before communication.

Args:
return_padding_flag: When True, an extra AllGather is performed
Expand All @@ -62,35 +41,14 @@ def pad_and_allgather_batch(
device = batch.features.values().device
global_batch_size = batch.batch_size * world_size

# ---- Fast path: world_size == 1 — only pad dense tensors, no collectives ----
# ---- Fast path: world_size == 1 — no collectives needed ----
if world_size == 1:
if batch.actual_batch_size < batch.batch_size:
orig_actual_bs = batch.actual_batch_size

def _pad_dense(
tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor],
) -> Union[torch.Tensor, KeyedJaggedTensor]:
if isinstance(tensor_or_kjt, KeyedJaggedTensor):
return tensor_or_kjt
elif isinstance(tensor_or_kjt, torch.Tensor):
t = tensor_or_kjt
eps = _elems_per_sample(t, orig_actual_bs)
pad_size = batch.batch_size * eps - t.numel()
return F.pad(t, (0, pad_size)) if pad_size > 0 else t
else:
raise ValueError(f"Unsupported type: {type(tensor_or_kjt)}")

new_batch = batch._apply_to_tensors_or_kjt(_pad_dense, inplace=False)
new_batch.actual_batch_size = global_batch_size
else:
new_batch = batch

if not return_padding_flag:
return new_batch
return batch
is_padding = (
torch.arange(batch.batch_size, device=device) >= batch.actual_batch_size
)
return new_batch, is_padding
return batch, is_padding

# ---- Phase 1: collect KJT fields and fused AllGather them ----
kjt_field_names: List[str] = []
Expand All @@ -106,18 +64,14 @@ def _pad_dense(
zip(kjt_field_names, kjt_outputs)
)

# ---- Phase 2: gather dense tensors (pad if needed) ----
pad_dense = batch.actual_batch_size < batch.batch_size

# ---- Phase 2: gather dense tensors ----
# Under the unified convention, dense tensors already have dim-0 ==
# batch_size (zero-padded by the dataloader), so no additional padding
# is needed — just AllGather directly.
def allgather_field(tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor]):
if isinstance(tensor_or_kjt, KeyedJaggedTensor):
return tensor_or_kjt
elif isinstance(tensor_or_kjt, torch.Tensor):
if pad_dense:
eps = _elems_per_sample(tensor_or_kjt, batch.actual_batch_size)
pad_size = batch.batch_size * eps - tensor_or_kjt.numel()
padded = F.pad(tensor_or_kjt, (0, pad_size))
return gather_along_first_dim(padded, pg_group)
return gather_along_first_dim(tensor_or_kjt, pg_group)
else:
raise ValueError(f"Unsupported type: {type(tensor_or_kjt)}")
Expand Down
7 changes: 6 additions & 1 deletion examples/commons/distributed/batch_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def _sort_partitions_padding_last(
def _strip_dense_padding(batch: BaseBatch, actual_bs: int) -> BaseBatch:
"""Remove trailing padding rows from dense tensors, keep KJTs intact.

Under the unified dense-padding convention, every dense tensor has
``batch.batch_size`` in dim-0 (stored flat as ``batch_size * eps``
elements). This function reshapes each dense tensor to
``[batch_size, eps]``, slices ``[:actual_bs]``, and flattens back.

Relies on ``_sort_partitions_padding_last`` having placed real samples
before padding samples so that a simple ``[:actual_bs]`` slice suffices.

Expand All @@ -91,7 +96,7 @@ def _strip_dense_padding(batch: BaseBatch, actual_bs: int) -> BaseBatch:

Returns:
A new ``BaseBatch`` where each dense tensor has ``actual_bs`` rows
and KJT fields are unchanged.
(flat: ``actual_bs * eps`` elements) and KJT fields are unchanged.
"""
full_bs = batch.batch_size

Expand Down
15 changes: 8 additions & 7 deletions examples/commons/sequence_batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ class BaseBatch(Pipelineable):

Invariants (especially for incomplete / padded batches):

* ``batch_size`` — full size **including** padding samples. This is
the KJT dimension: ``len(kjt.lengths()) == batch_size * num_keys``.
* ``batch_size`` — full size **including** padding samples. Both KJTs
and dense tensors use this dimension:
- KJT: ``len(kjt.lengths()) == batch_size * num_keys``
- Dense: dim-0 == ``batch_size`` (trailing rows are zero-padding)
* ``actual_batch_size`` — number of **real** (non-padding) samples.
This is the dense-tensor dimension: each dense tensor has
``actual_batch_size`` rows (dim-0).
Samples at indices ``[actual_batch_size, batch_size)`` are padding.
* For complete batches the two values are equal.
"""

features: KeyedJaggedTensor
batch_size: int # KJT dimension (includes padding)
batch_size: int # both KJT and dense dimension (includes padding)
feature_to_max_seqlen: Dict[str, int]

contextual_feature_names: List[str] = field(default_factory=list)
labels: Optional[KeyedJaggedTensor] = None
actual_batch_size: Optional[int] = None # dense dimension (real samples only)
actual_batch_size: Optional[int] = None # number of real (non-padding) samples

def __post_init__(self):
if len(set(self.features.keys())) != len(list(self.features.keys())):
Expand Down Expand Up @@ -130,7 +131,7 @@ def index_select_dense_tensor(
tensor: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
return (
tensor.reshape(self.actual_batch_size, -1)
tensor.reshape(self.batch_size, -1)
.index_select(dim=0, index=indices)
.reshape(-1)
)
Expand Down
10 changes: 7 additions & 3 deletions examples/hstu/test/tensor_parallel/test_tp_ranking_gr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,15 @@ def test_tp_gr_ranking_forward_backward_update(
tp_ranking_gr, debug_ranking_gr, debug_ranking_gr_fp32
)
for i, batch in enumerate(history_batches):
_, (losses, logits, _, _) = debug_pipeline.progress(debug_pipeline_batches)
_, (losses_fp32, logits_fp32, _, _) = debug_pipeline_fp32.progress(
_, _, (losses, logits, _, _) = debug_pipeline.progress(
debug_pipeline_batches
)
_, _, (losses_fp32, logits_fp32, _, _) = debug_pipeline_fp32.progress(
debug_pipeline_batches_fp32
)
_, (tp_losses, tp_logits, _, _) = tp_pipeline.progress(iter_history_batches)
_, _, (tp_losses, tp_logits, _, _) = tp_pipeline.progress(
iter_history_batches
)
torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
compare_tpN_to_debug_weights(
tp_ranking_gr, debug_ranking_gr, debug_ranking_gr_fp32
Expand Down
15 changes: 9 additions & 6 deletions examples/tests/commons/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,13 @@ def generate_batch(

feature_values = torch.randint(0, 100000, (feature_lengths.sum().item(),)).cuda()
if dense_label:
labels = (
torch.arange(
actual_batch_size * num_features, device=torch.device("cuda")
).view(-1)
# Dense labels have dim-0 == batch_size (unified convention).
# Real samples get meaningful values; padding rows are zeros.
labels = torch.zeros(
batch_size * num_features, device=torch.device("cuda"), dtype=torch.long
)
labels[: actual_batch_size * num_features] = (
torch.arange(actual_batch_size * num_features, device=torch.device("cuda"))
// num_features
)
else:
Expand Down Expand Up @@ -161,8 +164,8 @@ def test_batch_allgather(
f"got {stripped.labels.numel()}"
)
assert torch.equal(
stripped.labels, batch.labels
), "Stripped dense labels should match original unpadded labels"
stripped.labels, batch.labels[: actual_bs * num_features]
), "Stripped dense labels should match original real labels"


@pytest.mark.parametrize("batch_size", [128])
Expand Down