Commit 57d9d61
committed
[Refactor] Unify TrainEngine by moving model-specific logic to model layer
Previously, we had two separate train engines:
- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models
This duplication led to:
- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between engine and model-specific details
- Difficulty in extending to new model types
- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
- `pre_micro_batch_forward()`: Compute data batch statistics before forward pass
- `post_micro_batch_forward()`: Aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system
- **BaseModel**:
- Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for return types
- Implement default `pre_micro_batch_forward()` to compute token statistics
- Implement default `post_micro_batch_forward()` to aggregate losses and extra info
- Add overload type hints for `__call__` to improve type inference
- **MoE Model**:
- Override `post_micro_batch_forward()` to handle MoE-specific logic:
- Compute maxvio for router load balancing
- Update router bias based on expert load
- Add `need_update_bias` property for cleaner code
- Properly scale balancing_loss and z_loss by batch_size
- **ComposeModel**:
- Override `pre_micro_batch_forward()` to compute image token statistics
- Add `ComposeDataBatchInfo` with `step_consumed_img_tokens` field
- **TrainEngine**:
- Simplify `train_step()` to delegate statistics to model hooks
- Replace `LossLog` and `OtherLog` with unified `TrainStepInfo`
- Add `_get_total_loss()` to aggregate all losses (with TODO for future refactor)
- Remove all model-specific branching logic
- **EngineConfig**:
- Remove conditional logic for VisionComposeTrainEngine
- Use single TrainEngine.build() path
- Update to use `TrainStepInfo` instead of separate `LossLog` and `OtherLog`
- Simplify hook signatures (from 2 params to 1)
- Remove conditional engine instantiation logic
- Replace `VisionComposeTrainEngine` imports with `TrainEngine`
- Update test assertions to use new `TrainStepInfo` structure
- Remove TypeAdapter validation for deprecated types
Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating
through ModelOutputs fields. This is pragmatic but not ideal:
- **Pros**: Avoids large-scale changes to model forward() logic
- **Cons**: Engine knows about loss field names (coupling)
- **Future**: Model should return total_loss directly (see TODO comment)
`loss_ctx.batch_size` represents the full gradient accumulation batch size,
not intra_layer_micro_batch. This is correctly set in `CELossContext.build_batches()`
and used for scaling balancing_loss and z_loss.
The pre/post hooks provide clean extension points:
- Subclasses can override to add model-specific logic
- Default implementations in BaseModel handle common cases
- No conditional logic needed in engine layer
1. **Code Reduction**: -242 lines (VisionComposeTrainEngine removed)
2. **Better Separation of Concerns**: Engine focuses on training orchestration, models handle their own statistics
3. **Extensibility**: New model types can override hooks without changing engine
4. **Type Safety**: Unified TrainStepInfo with clear field definitions
5. **Maintainability**: Single engine implementation to maintain
- Loss reduce logic still needs clarification (minor issue, doesn't affect training)
- TODO added for future refactor: move total_loss aggregation to model layer
- All format issues (extra blank lines, class formatting) fixed
ghstack-source-id: ba87e63
Pull-Request: #15181 parent 3d9944c commit 57d9d61
17 files changed
Lines changed: 316 additions & 536 deletions
File tree
- tests
- engine
- model
- train
- xtuner/v1
- engine
- loss
- model
- compose
- moe
- rl/base
- train
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
93 | | - | |
| 93 | + | |
94 | 94 | | |
95 | 95 | | |
96 | 96 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
99 | 99 | | |
100 | 100 | | |
101 | 101 | | |
102 | | - | |
| 102 | + | |
103 | 103 | | |
104 | 104 | | |
105 | 105 | | |
| |||
190 | 190 | | |
191 | 191 | | |
192 | 192 | | |
193 | | - | |
| 193 | + | |
194 | 194 | | |
195 | 195 | | |
196 | 196 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
93 | 93 | | |
94 | 94 | | |
95 | 95 | | |
96 | | - | |
| 96 | + | |
97 | 97 | | |
98 | 98 | | |
99 | 99 | | |
| |||
171 | 171 | | |
172 | 172 | | |
173 | 173 | | |
174 | | - | |
| 174 | + | |
175 | 175 | | |
176 | 176 | | |
177 | 177 | | |
| |||
270 | 270 | | |
271 | 271 | | |
272 | 272 | | |
273 | | - | |
| 273 | + | |
274 | 274 | | |
275 | 275 | | |
276 | 276 | | |
277 | | - | |
| 277 | + | |
278 | 278 | | |
279 | 279 | | |
280 | 280 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
22 | | - | |
23 | 22 | | |
24 | 23 | | |
25 | 24 | | |
| |||
85 | 84 | | |
86 | 85 | | |
87 | 86 | | |
88 | | - | |
| 87 | + | |
89 | 88 | | |
90 | 89 | | |
91 | 90 | | |
| |||
116 | 115 | | |
117 | 116 | | |
118 | 117 | | |
119 | | - | |
| 118 | + | |
120 | 119 | | |
121 | 120 | | |
122 | 121 | | |
| |||
160 | 159 | | |
161 | 160 | | |
162 | 161 | | |
163 | | - | |
| 162 | + | |
164 | 163 | | |
165 | 164 | | |
166 | 165 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
29 | | - | |
30 | 29 | | |
31 | 30 | | |
32 | 31 | | |
33 | 32 | | |
34 | 33 | | |
35 | 34 | | |
36 | 35 | | |
| 36 | + | |
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
| |||
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
| 58 | + | |
| 59 | + | |
62 | 60 | | |
63 | 61 | | |
64 | 62 | | |
| |||
647 | 645 | | |
648 | 646 | | |
649 | 647 | | |
650 | | - | |
651 | | - | |
652 | 648 | | |
653 | 649 | | |
654 | 650 | | |
655 | 651 | | |
656 | 652 | | |
657 | | - | |
| 653 | + | |
658 | 654 | | |
659 | 655 | | |
660 | 656 | | |
| |||
673 | 669 | | |
674 | 670 | | |
675 | 671 | | |
676 | | - | |
677 | | - | |
678 | | - | |
679 | | - | |
| 672 | + | |
680 | 673 | | |
681 | 674 | | |
682 | 675 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
| 3 | + | |
5 | 4 | | |
6 | 5 | | |
7 | 6 | | |
8 | 7 | | |
9 | 8 | | |
10 | | - | |
11 | | - | |
12 | | - | |
13 | 9 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | | - | |
8 | 7 | | |
9 | | - | |
10 | 8 | | |
11 | 9 | | |
12 | 10 | | |
| |||
27 | 25 | | |
28 | 26 | | |
29 | 27 | | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
| 28 | + | |
0 commit comments