
Commit 57d9d61

[Refactor] Unify TrainEngine by moving model-specific logic to model layer
Previously, we had two separate train engines:

- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models

This duplication led to:

- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between engine and model-specific details
- Difficulty in extending to new model types

- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
  - `pre_micro_batch_forward()`: Compute data batch statistics before forward pass
  - `post_micro_batch_forward()`: Aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system

- **BaseModel**:
  - Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for return types
  - Implement default `pre_micro_batch_forward()` to compute token statistics
  - Implement default `post_micro_batch_forward()` to aggregate losses and extra info
  - Add overload type hints for `__call__` to improve type inference
- **MoE Model**:
  - Override `post_micro_batch_forward()` to handle MoE-specific logic:
    - Compute maxvio for router load balancing
    - Update router bias based on expert load
  - Add `need_update_bias` property for cleaner code
  - Properly scale balancing_loss and z_loss by batch_size
- **ComposeModel**:
  - Override `pre_micro_batch_forward()` to compute image token statistics
  - Add `ComposeDataBatchInfo` with `step_consumed_img_tokens` field
- **TrainEngine**:
  - Simplify `train_step()` to delegate statistics to model hooks
  - Replace `LossLog` and `OtherLog` with unified `TrainStepInfo`
  - Add `_get_total_loss()` to aggregate all losses (with TODO for future refactor)
  - Remove all model-specific branching logic
- **EngineConfig**:
  - Remove conditional logic for VisionComposeTrainEngine
  - Use single TrainEngine.build() path

- Update to use `TrainStepInfo` instead of separate `LossLog` and `OtherLog`
- Simplify hook signatures (from 2 params to 1)

- Remove conditional engine instantiation logic
- Replace `VisionComposeTrainEngine` imports with `TrainEngine`
- Update test assertions to use new `TrainStepInfo` structure
- Remove TypeAdapter validation for deprecated types

Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating through ModelOutputs fields. This is pragmatic but not ideal:

- **Pros**: Avoids large-scale changes to model forward() logic
- **Cons**: Engine knows about loss field names (coupling)
- **Future**: Model should return total_loss directly (see TODO comment)

`loss_ctx.batch_size` represents the full gradient accumulation batch size, not intra_layer_micro_batch. This is correctly set in `CELossContext.build_batches()` and used for scaling balancing_loss and z_loss.

The pre/post hooks provide clean extension points:

- Subclasses can override to add model-specific logic
- Default implementations in BaseModel handle common cases
- No conditional logic needed in engine layer

1. **Code Reduction**: -242 lines (VisionComposeTrainEngine removed)
2. **Better Separation of Concerns**: Engine focuses on training orchestration, models handle their own statistics
3. **Extensibility**: New model types can override hooks without changing engine
4. **Type Safety**: Unified TrainStepInfo with clear field definitions
5. **Maintainability**: Single engine implementation to maintain

- Loss reduce logic still needs clarification (minor issue, doesn't affect training)
- TODO added for future refactor: move total_loss aggregation to model layer
- All format issues (extra blank lines, class formatting) fixed

ghstack-source-id: ba87e63
Pull-Request: #1518
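
The hook system described above can be sketched roughly as follows. This is a minimal, torch-free illustration and not the XTuner implementation: the hook names, `DataBatchInfo`, `ComposeDataBatchInfo`, and `step_consumed_img_tokens` come from the commit message, while `MicroBatch`, the free function `train_step`, the hook signatures, and the field types are assumptions made only for this example.

```python
from dataclasses import dataclass
from typing import TypedDict


class DataBatchInfo(TypedDict):
    # Name from the commit message; the field shown here is an assumption.
    step_consumed_tokens: int


class ComposeDataBatchInfo(DataBatchInfo):
    # Field name from the commit message; its type is an assumption.
    step_consumed_img_tokens: int


@dataclass
class MicroBatch:
    """Hypothetical stand-in for one micro-batch (not an XTuner type)."""

    num_tokens: int
    num_img_tokens: int = 0
    loss: float = 0.0


class BaseModel:
    def pre_micro_batch_forward(self, batches: list[MicroBatch]) -> DataBatchInfo:
        # Default hook: token statistics that every model type needs.
        return {"step_consumed_tokens": sum(b.num_tokens for b in batches)}

    def forward(self, batch: MicroBatch) -> float:
        # Placeholder forward pass: returns the loss stored on the toy batch.
        return batch.loss

    def post_micro_batch_forward(self, losses: list[float]) -> dict[str, float]:
        # Default hook: aggregate per-micro-batch losses into one logged value.
        return {"local_loss": sum(losses)}


class ComposeModel(BaseModel):
    def pre_micro_batch_forward(self, batches: list[MicroBatch]) -> ComposeDataBatchInfo:
        # Override: reuse the default statistics, then add image-token counts.
        info = super().pre_micro_batch_forward(batches)
        img_tokens = sum(b.num_img_tokens for b in batches)
        return ComposeDataBatchInfo(**info, step_consumed_img_tokens=img_tokens)


def train_step(model: BaseModel, batches: list[MicroBatch]) -> dict:
    # The engine only orchestrates; it never checks the concrete model type.
    data_info = model.pre_micro_batch_forward(batches)
    losses = [model.forward(b) for b in batches]
    logs_info = model.post_micro_batch_forward(losses)
    return {**data_info, "logs_info": logs_info}
```

In this sketch, supporting a new model type means overriding one or both hooks; `train_step` itself stays unchanged, which is the separation of concerns the commit message argues for.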
1 parent 3d9944c commit 57d9d61

17 files changed

Lines changed: 316 additions & 536 deletions

tests/engine/test_dense_train_engine.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def warmup_fn(x):
 seq_ctx = seq_ctx_list[0]
 loss_ctx = loss_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+loss_log = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
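
For callers that still expect the old two-element return value, a small shim along these lines could ease migration. It is a hypothetical helper, not part of this commit: `split_legacy_logs` and `LEGACY_OTHER_KEYS` are invented names, and the key set is taken from the mocked `train_step()` in `tests/train/test_trainer.py`, so the real `TrainStepInfo` may carry more fields.

```python
from typing import Any

# Keys the old second tuple element (`other_log`) carried in the test mock.
LEGACY_OTHER_KEYS = ("step_consumed_tokens", "grad_norm", "efficient_attn_ratio")


def split_legacy_logs(
    train_step_info: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical shim: rebuild the old (loss_log, other_log) pair."""
    # The former loss_log now lives under the "logs_info" key.
    loss_log = dict(train_step_info["logs_info"])
    # The former other_log fields are now top-level keys of the same dict.
    other_log = {
        key: train_step_info[key]
        for key in LEGACY_OTHER_KEYS
        if key in train_step_info
    }
    return loss_log, other_log
```

With such a shim, `loss_log, _ = split_legacy_logs(engine.train_step(engine_input))` would behave like the removed line in the hunk above, assuming the keys match.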

tests/engine/test_moe_train_engine.py

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+loss_log = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
@@ -190,7 +190,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+loss_log = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()

tests/engine/test_moe_train_engine_float8.py

Lines changed: 4 additions & 4 deletions
@@ -93,7 +93,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+loss_log = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
@@ -171,7 +171,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+loss_log = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
@@ -270,11 +270,11 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+logs_info = engine.train_step(engine_input)["logs_info"]
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
-losses.append(loss_log["reduced_llm_loss"])
+losses.append(logs_info["reduced_llm_loss"])
 losses_ref = torch.tensor([2.41, 2.41, 2.47, 2.42, 2.44, 2.44, 2.42, 2.38, 2.31, 2.30])
 losses = torch.tensor(losses)
 self._check_loss_curve(losses, losses_ref)

tests/model/test_qwen3_tile_embedding.py

Lines changed: 3 additions & 4 deletions
@@ -19,7 +19,6 @@
 from xtuner.v1.loss.ce_loss import CELossConfig
 from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from torch.optim.lr_scheduler import LambdaLR
 from xtuner.v1.utils import pad_to_max_length
 from xtuner.v1.utils.device import get_device
@@ -85,7 +84,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+engine.train_step(engine_input)
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()
@@ -116,7 +115,7 @@ def test_qwen3vl_tie_embedding(self, device, tp_size):
     cpu_offload=False,
     tp_size=tp_size
 )
-engine = VisionComposeTrainEngine(
+engine = TrainEngine(
     model_cfg=dense_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg
 )
 engine.from_hf(hf_path=QWEN3_VL_DENSE_PATH)
@@ -160,7 +159,7 @@ def warmup_fn(x):
 loss_ctx = loss_ctx_list[0]
 seq_ctx = seq_ctx_list[0]
 engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-loss_log, _ = engine.train_step(engine_input)
+engine.train_step(engine_input)
 grad_norm = engine.clip_grad_norm()
 engine.step_optimizer(grad_norm)
 lr_scheduler.step()

tests/train/test_trainer.py

Lines changed: 5 additions & 12 deletions
@@ -26,14 +26,14 @@
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
-from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
 from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
 from xtuner.v1.utils.device import get_device
 from xtuner.v1.datasets.dataloader import Dataloader
 from torch.optim.lr_scheduler import SequentialLR
+from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 
 
 DEVICE = get_device()
@@ -55,10 +55,8 @@ def grad_accumulation_steps(self, *args, **kwargs):
 
 def train_step(self, *args, **kwargs):
     self.train_step_calls += 1
-    return (
-        {"local_loss": 1.0, "reduced_llm_loss": 0.8},
-        {"step_consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
-    )
+    return {"total_loss": 1.8, "step_consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5, "logs_info": {"local_loss": 1.0, "reduced_llm_loss": 0.8}, "extra_info": ModelForwardExtraLogInfo()}
+
 
 def save_hf(self, hf_path):
     self.save_hf_calls.append(hf_path)
@@ -647,14 +645,12 @@ def test_hooks_config(self):
 self.create_pg(DEVICE)
 checkpoint_function_call_times = 0
 train_step_function_call_times = 0
-losslog_adapater = TypeAdapter(LossLog)
-otherlog_adapter = TypeAdapter(OtherLog)
 
 def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
     nonlocal checkpoint_function_call_times
     checkpoint_function_call_times += 1
 
-def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
+def train_step_hook(train_step_info, step, epoch, total_step, total_epoch):
     nonlocal train_step_function_call_times
     train_step_function_call_times += 1
 
@@ -673,10 +669,7 @@ def connect_trainer(self, trainer: Trainer):
 def __init__(self) -> None:
     self.count = 0
 
-def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
-    losslog_adapater.validate_python(loss_log)
-    otherlog_adapter.validate_python(other_log)
-
+def __call__(self, train_step_info, step, epoch, total_step, total_epoch):
     assert self.trainer().cur_step == step
     assert self.trainer().cur_epoch == epoch
     assert self.trainer().total_step == total_step
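
The new single-argument hook signature can be illustrated with a small callable hook. This is a sketch under stated assumptions: the `TrainStepInfo` keys mirror the mocked `train_step()` return value in this test file, but the value types and `total=False` are guesses, and `LossRecorder` is a hypothetical hook written only for this example.

```python
from typing import Any, TypedDict


class TrainStepInfo(TypedDict, total=False):
    # Keys mirror the mocked train_step() return value in this test file;
    # the value types are assumptions made for illustration.
    total_loss: float
    step_consumed_tokens: int
    grad_norm: float
    efficient_attn_ratio: float
    logs_info: dict[str, float]
    extra_info: Any


class LossRecorder:
    """Hypothetical hook using the new one-argument log signature."""

    def __init__(self) -> None:
        self.history: list[float] = []

    def __call__(self, train_step_info, step, epoch, total_step, total_epoch) -> None:
        # What used to arrive as loss_log is now nested under "logs_info".
        self.history.append(train_step_info["logs_info"]["reduced_llm_loss"])


hook = LossRecorder()
info = TrainStepInfo(
    total_loss=1.8,
    step_consumed_tokens=100,
    grad_norm=1.0,
    efficient_attn_ratio=0.5,
    logs_info={"local_loss": 1.0, "reduced_llm_loss": 0.8},
    extra_info=None,
)
hook(info, step=1, epoch=0, total_step=10, total_epoch=1)
```

Hooks written against the old `(loss_log, other_log, ...)` signature would need the same change as `train_step_hook` in the hunk above: one `train_step_info` argument in place of the two log arguments.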

xtuner/v1/engine/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -1,13 +1,9 @@
 from xtuner.v1.engine.config import EngineConfig
 
-from .train_engine import LossLog, OtherLog, TrainEngine
-from .vision_compose_train_engine import VisionComposeTrainEngine
+from .train_engine import TrainEngine
 
 
 __all__ = [
     "TrainEngine",
     "EngineConfig",
-    "VisionComposeTrainEngine",
-    "LossLog",
-    "OtherLog",
 ]

xtuner/v1/engine/config.py

Lines changed: 1 addition & 6 deletions
@@ -4,9 +4,7 @@
 
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.model.base import BaseModel, ConfigDict
-from xtuner.v1.model.compose.base import BaseComposeConfig
 
 
 @runtime_checkable
@@ -27,7 +25,4 @@ class EngineConfig(PydanticBaseModel):
 model_cfg: ModelConfigProto
 
 def build(self):
-    if isinstance(self.model_cfg, BaseComposeConfig):
-        return VisionComposeTrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
-    else:
-        return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
+    return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
