Skip to content

Commit e04e1c2

Browse files
committed
fix: port wrap_data_iterator pattern from PR #4659 to fix DCP test
The Phase 1 merge of #4716 took main's version of training.py per the skill's "Files to Override from Main" rule, which uses the HybridCPDataLoaderWrapper class wrapped once outside train_step. That broke the DCP test (gpt3_mcore_te_tp2_pp1_cp4_dcp) in two ways: 1. RuntimeError: Trying to resize storage that is not resizable - fixed by c3dbea7 (rename args.hybrid_context_parallel -> args.dynamic_context_parallel). 2. AssertionError: data iterator is not wrapped with RerunDataIterator - the outside-train_step wrap converted train_data_iterator from a RerunDataIterator to a plain iterator, but rerun_state_machine's should_run_forward_backward asserts the wrap. PR #4659 resolved this by keeping dev's wrap_data_iterator pattern instead of main's HybridCPDataLoaderWrapper, calling wrap_data_iterator INSIDE train_step (after should_run_forward_backward) and inside the eval loop. That keeps the original RerunDataIterator visible to the assertion and only swaps in the packed iterator for the forward_backward_func call. Port that pattern verbatim from PR #4659's training.py: - Replace HybridCPDataLoaderWrapper import with wrap_data_iterator - Remove the outside-train_step wrap (was at line 3000-3001) - Inside train_step: add the if config.sequence_packing_scheduler is not None block before forward_backward_func, unpacking (data_iterator, num_microbatches, seqlen_sum_this_global_batch, seqlen_squared_sum_this_global_batch); pass num_microbatches= num_microbatches to forward_backward_func - Inside eval loop: add the same wrap with try/except StopIteration, using packed_data_iterator and scheduled_eval_num_microbatches Note: this leaves HybridCPDataLoaderWrapper and its imports (Any, List, BalancedCPScheduler) as dead code in megatron/core/datasets/data_schedule.py. Cleanup of that file (and of the remaining structural diff in training.py / data_samplers.py / utils.py vs PR #4659's tree) is left to follow-up.
1 parent c3dbea7 commit e04e1c2

1 file changed

Lines changed: 44 additions & 7 deletions

File tree

megatron/training/training.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def set_startup_timestamps(program_start=None, main_entry=None):
181181
except ImportError:
182182
HAVE_FSDP2 = False
183183

184-
from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper
184+
from megatron.core.datasets.data_schedule import wrap_data_iterator
185185
from megatron.core.distributed import finalize_model_grads
186186
from megatron.core.enums import ModelType
187187
from megatron.core.inference.symmetric_memory import SymmetricMemoryManager
@@ -2030,6 +2030,27 @@ def train_step(
20302030
if isinstance(optim_instance, DistributedOptimizer):
20312031
optim_instance._copy_main_params_to_param_buffer()
20322032

2033+
if config.sequence_packing_scheduler is not None:
2034+
# This wrapper is designed to support DP-balanced THD and dynamic-CP.
2035+
# Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence.
2036+
# The wrapper is responsible for:
2037+
# 1. scheduling the sequences across ranks
2038+
# 2. packing them into THD format
2039+
# 3. broadcast flops parametes and num_microbatches to TP ranks to support unfixed num_microbatches
2040+
# 4. broadcast metadata(cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks to
2041+
# 5. returning the packed data iterator and the FLOPs parameters
2042+
(
2043+
data_iterator,
2044+
num_microbatches,
2045+
seqlen_sum_this_global_batch,
2046+
seqlen_squared_sum_this_global_batch,
2047+
) = wrap_data_iterator(data_iterator, config, get_num_microbatches())
2048+
else:
2049+
# data_iterator unchanged
2050+
num_microbatches = get_num_microbatches()
2051+
seqlen_sum_this_global_batch = args.seq_length * args.global_batch_size
2052+
seqlen_squared_sum_this_global_batch = args.seq_length**2 * args.global_batch_size
2053+
20332054
# Forward pass.
20342055
if save_activations_in_this_iteration:
20352056
enable_activation_logging(model, args.save)
@@ -2041,7 +2062,7 @@ def train_step(
20412062
forward_step_func=forward_step_func,
20422063
data_iterator=data_iterator,
20432064
model=model,
2044-
num_microbatches=get_num_microbatches(),
2065+
num_microbatches=num_microbatches,
20452066
seq_length=args.seq_length,
20462067
micro_batch_size=args.micro_batch_size,
20472068
decoder_seq_length=args.decoder_seq_length,
@@ -2997,9 +3018,6 @@ def train(
29973018
energy_monitor = get_energy_monitor()
29983019
one_logger = get_one_logger()
29993020

3000-
if args.dynamic_context_parallel:
3001-
train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config))
3002-
30033021
if args.run_workload_inspector_server:
30043022
try:
30053023
import threading
@@ -3699,11 +3717,30 @@ def evaluate(
36993717
# Don't care about timing during evaluation
37003718
config.timers = None
37013719
ft_integration.on_eval_step_start()
3720+
if config.sequence_packing_scheduler is not None:
3721+
# This wrapper is designed to support DP-balanced THD and dynamic-CP.
3722+
# Before wrapping, the data_iterator returns either a single sequence per get_item call, or a list where each element is a sequence.
3723+
# The wrapper is responsible for:
3724+
# 1. scheduling the sequences across ranks
3725+
# 2. packing them into THD format
3726+
# 3. broadcast flops parametes and num_microbatches to TP ranks to support unfixed num_microbatches
3727+
# 4. broadcast metadata(cu_seqlens, cu_seqlens_padded, max_seqlen, etc.) to PP ranks to
3728+
# 5. returning the packed data iterator and the FLOPs parameters
3729+
try:
3730+
(packed_data_iterator, scheduled_eval_num_microbatches, _, _) = (
3731+
wrap_data_iterator(data_iterator, config, eval_num_microbatches)
3732+
)
3733+
except StopIteration:
3734+
# Validation data iterator exhausted, stop evaluation early.
3735+
break
3736+
else:
3737+
packed_data_iterator = data_iterator
3738+
scheduled_eval_num_microbatches = eval_num_microbatches
37023739
loss_dicts = forward_backward_func(
37033740
forward_step_func=forward_step_func,
3704-
data_iterator=data_iterator,
3741+
data_iterator=packed_data_iterator,
37053742
model=model,
3706-
num_microbatches=eval_num_microbatches,
3743+
num_microbatches=scheduled_eval_num_microbatches,
37073744
seq_length=args.seq_length,
37083745
micro_batch_size=eval_micro_batch_size,
37093746
decoder_seq_length=args.decoder_seq_length,

0 commit comments

Comments
 (0)