Skip to content

Commit b9cc7e8

Browse files
authored
[Trainer][Megatron] Sequence packing + Context Parallel for Megatron (#274)
# Overview Adds sequence packing + context parallel support for megatron backend. Note that context parallel without sequence packing is not supported. ## Correctness Check ### CP + TP + PP <img width="368" height="275" alt="image" src="https://github.com/user-attachments/assets/53fdd009-3af9-4352-8e63-7604b2dfdeee" /> ### Just Sequence Packing <img width="366" height="278" alt="image" src="https://github.com/user-attachments/assets/9a40dfdf-af8c-44e8-bc54-78e13d187daa" /> ### Just CP + Sequence Packing <img width="364" height="281" alt="image" src="https://github.com/user-attachments/assets/c69522e8-52b1-4581-8a66-a579b29bbb0d" /> ### Timing Adding CP is slower as expected, adding just sequence packing is also slightly slower for tp=2,pp=2. <img width="362" height="286" alt="image" src="https://github.com/user-attachments/assets/9109ce98-0740-46ce-8a92-de5cd8cf2ec2" /> This seems to be because of overhead in computing rotary positional embeddings - without sequence packing, it's a batched call for a well-formed tensor, while with sequence packing, it iterates over sequences one by one: #274 (comment)
1 parent c2bfe61 commit b9cc7e8

File tree

5 files changed

+282
-58
lines changed

5 files changed

+282
-58
lines changed

skyrl-train/examples/training_backends/megatron/run_megatron.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron
1515

1616
MEGATRON_TP=2
1717
MEGATRON_PP=2
18+
MEGATRON_CP=1
1819

1920
uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entrypoints.main_base \
2021
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -29,9 +30,11 @@ uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entryp
2930
generator.inference_engine_tensor_parallel_size=1 \
3031
megatron_config.policy.tensor_model_parallel_size=$MEGATRON_TP \
3132
megatron_config.policy.pipeline_model_parallel_size=$MEGATRON_PP \
33+
megatron_config.policy.context_parallel_size=$MEGATRON_CP \
3234
megatron_config.ref.tensor_model_parallel_size=$MEGATRON_TP \
35+
megatron_config.ref.context_parallel_size=$MEGATRON_CP \
3336
megatron_config.ref.pipeline_model_parallel_size=$MEGATRON_PP \
34-
trainer.use_sample_packing=false \
37+
trainer.use_sample_packing=true \
3538
trainer.epochs=20 \
3639
trainer.eval_batch_size=1024 \
3740
trainer.eval_before_train=false \
@@ -56,7 +59,7 @@ uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entryp
5659
generator.gpu_memory_utilization=0.6 \
5760
trainer.logger="$LOGGER" \
5861
trainer.project_name="gsm8k_megatron" \
59-
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_${MODEL_NAME}" \
62+
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" \
6063
trainer.resume_mode=null \
6164
trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
6265
$@

skyrl-train/skyrl_train/distributed/megatron/megatron_utils.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Utils ported from Verl
22
# https://github.com/volcengine/verl/blob/e1603dc97f3c20c58feed1f5be34acd5c72a830c/verl/utils/megatron_utils.py#L4
3+
# https://github.com/volcengine/verl/blob/dfa3933ac44b545fca1f6a8519fd07394a2cde1c/verl/models/mcore/util.py
34
# The original copyright is reproduced below:
45

56
# Copyright 2024 Bytedance Ltd. and/or its affiliates
@@ -27,6 +28,7 @@
2728
from megatron.core.optimizer import ChainedOptimizer
2829
from megatron.core import parallel_state as mpu
2930
from megatron.core.utils import get_attr_wrapped_model
31+
from megatron.core.packed_seq_params import PackedSeqParams
3032

3133
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
3234

@@ -291,6 +293,148 @@ def _iter_opts(opt):
291293
torch.cuda.empty_cache()
292294

293295

296+
def preprocess_packed_seqs(
    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True
) -> tuple[torch.Tensor, PackedSeqParams]:
    """Pack a padded batch into a single THD-format sequence for Megatron.

    With context parallelism (CP), each sequence is split into ``cp_size * 2``
    chunks and every CP rank stores 2 of them (rank 0 gets the first and last
    chunk, rank 1 the second and second-to-last, and so on). This balances
    causal-attention work across CP ranks.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368

    Args:
        input_ids: (batch_size, seq_len) token ids.
        attention_mask: (batch_size, seq_len) boolean mask of valid tokens.
        pre_process: whether this pipeline stage consumes input embeddings.
            When False, the packed tensor is not materialized and ``input_ids``
            is returned unchanged (only the PackedSeqParams are needed).

    Returns:
        Tuple of the packed input of shape ``(1, total_padded_len // cp_size, ...)``
        (or the untouched ``input_ids`` when ``pre_process`` is False) and the
        ``PackedSeqParams`` describing per-sequence boundaries in THD format.
    """
    batch_size = input_ids.shape[0]

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    # Each per-sequence length must be divisible by the TP size (for sequence
    # parallelism) and, when CP is on, by 2 * cp_size so the sequence can be
    # cut into 2 * cp_size equal chunks.
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    # NOTE: only the padded cumulative lengths are needed downstream; the
    # unpadded cu_seqlens previously computed here was dead code and is removed.
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    # ----------------------------------------------------------------------------
    # Move the index information needed in the subsequent loop to the CPU at once,
    # to avoid frequent .item() calls in the loop that cause D2H synchronization
    # ----------------------------------------------------------------------------
    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths
    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding
    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)

    # Pure Python int calculation to avoid further synchronization
    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

    shape = list(input_ids.shape[1:])
    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            # Use Python ints, so no GPU->CPU sync in the loop.
            if cp_size <= 1:
                # No CP: copy the valid tokens into this sequence's padded slot.
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
                continue

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size  # this rank's share of sequence i
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size
            # This rank stores 2 chunks: chunk cp_rank from the head of the
            # sequence, and its mirror chunk from the tail.
            d = input_ids[i, attention_mask[i]]
            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
            ]

            # The tail chunk may overlap the padding region; clamp to the number
            # of real tokens (the rest of the slot stays zero, i.e. padding).
            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
                    remain_start:remain_end
                ]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    else:
        return input_ids, packed_seq_params
377+
378+
def postprocess_packed_seqs(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> torch.Tensor:
    """Unpack THD-format model output back into a padded (batch, seq_len, ...) tensor.

    Inverse of ``preprocess_packed_seqs``: when cp_size > 1, gathers the
    per-CP-rank chunks, reassembles each sequence in original token order, and
    scatters valid tokens back to the positions given by ``attention_mask``.

    Args:
        output: packed model output of shape (1, packed_len, ...).
        packed_seq_params: sequence boundary metadata from ``preprocess_packed_seqs``.
        attention_mask: (batch_size, seq_len) boolean mask of valid tokens.
        batch_size: original batch size.
        seq_len: original (padded) sequence length.
        post_process: whether this pipeline stage produces final outputs; when
            False, ``output`` is returned unchanged.

    Returns:
        Tensor of shape (batch_size, seq_len, ...) with zeros at masked positions.
    """
    if not post_process:
        return output

    # -------------------------------------------------------------------------
    # Move the lengths and offsets needed for subsequent Python-level indexing
    # to the CPU in advance, to avoid a large number of .item() calls in the loop
    # -------------------------------------------------------------------------
    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
    seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()

    shape = [batch_size, seq_len] + list(output.shape[2:])  # (1, packed, dim) -> (batch, seq_len, dim)
    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

    cp_size = mpu.get_context_parallel_world_size()
    # All-gather output across the context parallel group
    if cp_size > 1:
        # output shape: [1, packed_len, hidden_dim]
        # need to gather across cp group and concatenate in sequence dimension
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        # Put the local, autograd-connected tensor back into this rank's slot
        # (the gathered copies are detached).
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]
    for i in range(batch_size):
        if cp_size <= 1:
            s = seq_lens_cpu[i]
            start_idx = cu_padded_cpu[i]
            output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]
            continue
        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
        half_seqlen = s_len_padded_chunk // 2
        s_len = seq_lens_cpu[i]
        s_len_padded = s_len_padded_chunk * cp_size
        # Fix: allocate with the output's dtype — a bare torch.empty defaults to
        # float32 and would force an implicit cast on every chunk assignment.
        tmp = torch.empty(s_len_padded, *output.shape[2:], dtype=output.dtype, device=output.device)
        # Loop-invariant: the start of sequence i inside each rank's packed buffer.
        packed_start_idx = cu_padded_cpu[i] // cp_size
        for j in range(cp_size):
            o = output_list[j][0]
            # Rank j holds 2 chunks of sequence i: chunk j from the head and its
            # mirror chunk from the tail (load-balanced causal split).
            o0, o1 = (
                o[packed_start_idx : packed_start_idx + half_seqlen],
                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
            )
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
        # Drop the alignment padding and scatter back to padded positions.
        output_new[i, attention_mask[i]] = tmp[:s_len]

    return output_new
436+
437+
294438
def remove_left_padding(
295439
input_ids: torch.Tensor,
296440
attention_mask: torch.Tensor,

skyrl-train/skyrl_train/utils/utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,19 @@ def validate_megatron_cfg(cfg: DictConfig):
120120
assert cfg.generator.backend == "vllm", "only vllm is supported for with megatron"
121121
assert cfg.trainer.placement.colocate_all, "only colocate_all=True is supported for megatron training"
122122
assert cfg.trainer.critic.model.path is None, "only GRPO training is currently supported for megatron"
123-
assert not cfg.trainer.use_sample_packing, "sample packing is not yet supported for megatron"
123+
124+
if cfg.trainer.flash_attn:
125+
import flash_attn
126+
127+
version = flash_attn.__version__
128+
if version > "2.7.4.post1":
129+
raise ValueError("flash_attn <= 2.7.4.post1 is required for using the megatron backend with flash_attn")
124130

125131
worker_configs = [(cfg.trainer.policy, "policy"), (cfg.trainer.ref, "ref")]
126132
for config, worker_type in worker_configs:
127-
# context, expert, and export tensor parallel are not yet supported for megatron
128-
assert (
129-
config.megatron_config.context_parallel_size == 1
130-
), f"found {worker_type}.context_parallel_size > 1, context parallel is not yet supported for megatron"
133+
# context, expert, and expert tensor parallel are not yet supported for megatron
134+
if config.megatron_config.context_parallel_size > 1:
135+
assert cfg.trainer.use_sample_packing, "context parallel is only supported with sample packing"
131136
assert (
132137
config.megatron_config.expert_model_parallel_size == 1
133138
), f"found {worker_type}.expert_model_parallel_size > 1, expert model parallel is not yet supported for megatron"

skyrl-train/skyrl_train/workers/megatron/megatron_policy.py

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from skyrl_train.distributed.megatron.megatron_utils import (
1515
make_batch_generator,
16+
preprocess_packed_seqs,
17+
postprocess_packed_seqs,
1618
remove_left_padding,
1719
recover_left_padding,
1820
)
@@ -34,6 +36,7 @@ def __init__(
3436
self.actor_module = actor_module
3537
self.actor_optimizer = actor_optimizer
3638
self.policy_loss_fn = policy_loss_fn
39+
self.use_sample_packing = self.cfg.trainer.use_sample_packing
3740

3841
config = get_model_config(self.actor_module[0])
3942
# This is set to None by default: https://github.com/NVIDIA/Megatron-LM/blob/07b22a05136a3cb08ece05f7de38cf6aeeb165fb/megatron/core/model_parallel_config.py#L95
@@ -86,6 +89,7 @@ def collection_func(logits, data):
8689
vocab_end_index=(tp_rank + 1) * logits.shape[-1],
8790
tp_group=tp_grp,
8891
inference_only=True,
92+
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
8993
chunk_size=None,
9094
)
9195
return 0.0, {"log_probs": token_logprobs}
@@ -96,27 +100,48 @@ def forward_step(batch_iter, model):
96100
attention_mask = batch["attention_mask"].to(bool)
97101
position_ids = batch["position_ids"]
98102

99-
new_sequences, new_attention_mask, new_position_ids = remove_left_padding(
100-
sequences,
101-
attention_mask,
102-
position_ids,
103-
self.tf_config.sequence_parallel,
104-
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
105-
)
103+
if self.use_sample_packing:
104+
new_sequences, packed_seq_params = preprocess_packed_seqs(
105+
sequences,
106+
attention_mask,
107+
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
108+
)
109+
new_attention_mask = None
110+
new_position_ids = None
111+
else:
112+
new_sequences, new_attention_mask, new_position_ids = remove_left_padding(
113+
sequences,
114+
attention_mask,
115+
position_ids,
116+
self.tf_config.sequence_parallel,
117+
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
118+
)
119+
packed_seq_params = None
106120

107121
outputs = model(
108122
new_sequences,
109123
new_position_ids,
110124
new_attention_mask,
125+
packed_seq_params=packed_seq_params,
111126
)
112127

113-
outputs = recover_left_padding(
114-
outputs,
115-
new_attention_mask,
116-
attention_mask,
117-
seq_len,
118-
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
119-
)
128+
if self.use_sample_packing:
129+
outputs = postprocess_packed_seqs(
130+
outputs,
131+
packed_seq_params,
132+
attention_mask,
133+
micro_batch_size,
134+
seq_len,
135+
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
136+
)
137+
else:
138+
outputs = recover_left_padding(
139+
outputs,
140+
new_attention_mask,
141+
attention_mask,
142+
seq_len,
143+
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
144+
)
120145

121146
return outputs, partial(collection_func, data=batch)
122147

@@ -192,6 +217,7 @@ def loss_func(logits, data):
192217
vocab_end_index=(tp_rank + 1) * logits.shape[-1],
193218
tp_group=tp_grp,
194219
inference_only=False,
220+
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
195221
chunk_size=None,
196222
)
197223

@@ -240,27 +266,48 @@ def forward_step(batch_iter, model):
240266
attention_mask = batch["attention_mask"].to(bool)
241267
position_ids = batch["position_ids"]
242268

243-
new_sequences, new_attention_mask, new_position_ids = remove_left_padding(
244-
sequences,
245-
attention_mask,
246-
position_ids,
247-
self.tf_config.sequence_parallel,
248-
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
249-
)
269+
if self.use_sample_packing:
270+
new_sequences, packed_seq_params = preprocess_packed_seqs(
271+
sequences,
272+
attention_mask,
273+
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
274+
)
275+
new_attention_mask = None
276+
new_position_ids = None
277+
else:
278+
new_sequences, new_attention_mask, new_position_ids = remove_left_padding(
279+
sequences,
280+
attention_mask,
281+
position_ids,
282+
self.tf_config.sequence_parallel,
283+
pre_process=mpu.is_pipeline_first_stage(ignore_virtual=True),
284+
)
285+
packed_seq_params = None
250286

251287
outputs = model(
252288
new_sequences,
253289
new_position_ids,
254290
new_attention_mask,
291+
packed_seq_params=packed_seq_params,
255292
)
256293

257-
outputs = recover_left_padding(
258-
outputs,
259-
new_attention_mask,
260-
attention_mask,
261-
seq_len,
262-
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
263-
)
294+
if self.use_sample_packing:
295+
outputs = postprocess_packed_seqs(
296+
outputs,
297+
packed_seq_params,
298+
attention_mask,
299+
micro_batch_size,
300+
seq_len,
301+
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
302+
)
303+
else:
304+
outputs = recover_left_padding(
305+
outputs,
306+
new_attention_mask,
307+
attention_mask,
308+
seq_len,
309+
post_process=mpu.is_pipeline_last_stage(ignore_virtual=True),
310+
)
264311

265312
return outputs, partial(loss_func, data=batch)
266313

0 commit comments

Comments
 (0)