[FlexCheckPoint] adapt fc to sharding stage3 #76538
From00 merged 7 commits into PaddlePaddle:develop from …adapt_stage_v3_with_fc
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!

Corresponding PaddleFormers adaptation: PaddlePaddle/PaddleFormers#2987
Codecov Report ❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##           develop     #76538   +/-   ##
==========================================
  Coverage         ?     95.74%
==========================================
  Files            ?          2
  Lines            ?         94
  Branches         ?          0
==========================================
  Hits             ?         90
  Misses           ?          4
  Partials         ?          0
self._optim.clear_grad = MethodType(_opt_clear, self._optim)
def init_slice_param(self):
- After the model is wrapped by ShardingStage3, the data of every parameter that needs slicing is immediately moved into param.fw_storage and the parameter itself is cleared, leaving it uninitialized. When we resume training, we need to load the parameter data from the checkpoint, and that data is the complete (un-sliced) data, because ShardingStage3 calls allgather before the final save to restore each parameter to its full state. Since the sliced parameters are uninitialized at resume time, they must be re-initialized before loading the data. After loading, the data has to be moved back onto param.fw_storage and the parameter cleared once more.
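The resume flow described above can be sketched as follows. This is a minimal, hypothetical illustration: the `Param` class and the slicing via `np.array_split` are stand-ins for Paddle's actual stage-3 storage logic, and only the init → load-full → re-slice → clear sequence mirrors the comment.

```python
import numpy as np

class Param:
    """Toy stand-in for a stage-3 sliced parameter (not Paddle's tensor)."""
    def __init__(self, shape):
        self.shape = shape
        self.data = None          # cleared by the stage-3 wrapper
        self.fw_storage = None    # this rank's slice lives here

def init_slice_param(param, full_ckpt, rank, world_size):
    # 1) Re-initialize the cleared parameter buffer to its full shape.
    param.data = np.empty(param.shape, dtype=full_ckpt.dtype)
    # 2) Load the full (un-sliced) checkpoint data, since stage 3
    #    all-gathers parameters back to full shape before saving.
    param.data[...] = full_ckpt
    # 3) Move this rank's slice onto fw_storage, then clear param again.
    flat = param.data.reshape(-1)
    shard = np.array_split(flat, world_size)[rank]
    param.fw_storage = shard.copy()
    param.data = None

full = np.arange(8, dtype=np.float32).reshape(2, 4)
p = Param(full.shape)
init_slice_param(p, full, rank=1, world_size=2)
print(p.fw_storage)  # this rank's half of the flattened parameter
print(p.data)        # None: cleared again after the transfer
```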
Is the data in param.fw_storage complete or sliced? If it is complete, is part of it released after the forward pass finishes?
It is sliced. allgather is called only when the full parameter is needed for forward and backward computation. Once the computation finishes, the parameter is cleared again and only the portion that this rank needs to update is kept in param.fw_storage.
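That per-step lifecycle can be sketched like this. The `all_gather` and `training_step` functions below are hypothetical stand-ins (a single-process simulation of the collective), illustrating only the pattern from the answer: the full parameter exists transiently, and each rank retains just its own shard afterwards.

```python
import numpy as np

def all_gather(shards):
    # Stand-in for the collective: concatenate every rank's slice.
    return np.concatenate(shards)

def training_step(fw_storages, rank):
    full = all_gather(fw_storages)     # materialize the full parameter
    out = float(full.sum())            # forward/backward would run here
    local = np.array_split(full, len(fw_storages))[rank]
    fw_storages[rank] = local.copy()   # keep only this rank's shard
    return out                         # the full parameter is dropped

shards = [np.array([0.0, 1.0]), np.array([2.0, 3.0])]
out = training_step(shards, rank=0)
print(out)                # 6.0
print(shards[0].tolist()) # rank 0 still holds only its own slice
```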
/re-run all-failed
force_gc = []

-for param_name, tgt_shard in load_dict.items():
+for param_name, tgt_shard in sorted(load_dict.items()):
This fixes an error in the scenario where model and optimizer parameters are loaded at the same time. When loading both together, an optimizer parameter may appear before its model parameter in load_dict, which causes src.dtype to be updated to float32. Under multi_precision (e.g. fp16), the model parameter's src.dtype should be fp16. To prevent the model parameter's dtype from being overwritten in the combined-load scenario, the model parameters should be processed first, and a simple sorted() is enough to achieve that.
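A toy reproduction of the ordering fix. This assumes, hypothetically, that optimizer-state keys extend the parameter name with a suffix (the key names below are invented), so plain lexicographic sorting places each model parameter before its optimizer states.

```python
# Insertion order puts the optimizer state first, which is the buggy case:
# its float32 dtype would overwrite src.dtype before the fp16 model
# parameter is processed.
load_dict = {
    "linear_0.w_0.moment1_0": "float32 shard",  # optimizer state
    "linear_0.w_0": "float16 shard",            # model parameter
}

# sorted() restores a safe order: a key is a strict prefix of its
# optimizer-state keys, so it sorts first.
seen = [name for name, _ in sorted(load_dict.items())]
print(seen)  # model parameter first, then its optimizer state
```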
for param in params:
    param_shape = param.shape
    origin_state = param.stop_gradient
    param.stop_gradient = True
This follows the approach used in _param_storage. I believe the intent is to disconnect the parameter from the computation graph, so that the in-place flatten_() operation cannot affect gradient computation during backpropagation; after flatten_(), the original stop_gradient value is restored.
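The disable-mutate-restore pattern described here can be sketched without Paddle. The `Param` class below is a stand-in for a real tensor (its `flatten_()` just reshapes an array and does not model autograd), so only the stop_gradient bookkeeping around the in-place call mirrors the quoted code.

```python
import numpy as np

class Param:
    """Toy stand-in for a tensor with a stop_gradient flag."""
    def __init__(self, data, stop_gradient=False):
        self.data = data
        self.stop_gradient = stop_gradient
    def flatten_(self):
        # In-place flatten; on a real tensor this mutation could
        # interfere with backward if the graph were still attached.
        self.data = self.data.reshape(-1)

def flatten_params(params):
    for param in params:
        origin_state = param.stop_gradient
        param.stop_gradient = True          # detach for the mutation
        param.flatten_()
        param.stop_gradient = origin_state  # restore the original flag

p = Param(np.zeros((2, 3)))
flatten_params([p])
print(p.data.shape, p.stop_gradient)  # (6,) False
```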
* adapt fc to sharding stage3
* add test and fix bug
* fix bug
* fix bug
* add test

* 【FlexCheckpoint】Aoa config reverse (#76437)
  * aoa_config_reverse
  * fix the bug
  * add test
  * fix dtype style and add test
  * adapt full param update
* [FlexCheckPoint]adapt fc to sharding stage3 (#76538)
  * adapt fc to sharding stage3
  * add test and fix bug
  * fix bug
  * fix bug
  * add test
* fc comm using grouped send/recv (#76779)
  * fix
  * fix
  * fix
  * fix

Co-authored-by: Tianyu Zheng <129518799+zty-king@users.noreply.github.com>
PR Category
Operator Mechanism
PR Types
Improvements
Description