
[FlexCheckPoint] adapt fc to sharding stage3 #76538

Merged: From00 merged 7 commits into PaddlePaddle:develop from zty-king:adapt_stage_v3_with_fc on Nov 27, 2025

Conversation

@zty-king (Contributor) commented Nov 21, 2025

PR Category

Operator Mechanism

PR Types

Improvements

Description

  • Add FlexCheckpoint (fc) framework support for Sharding Stage3
  • Fix a dtype bug: in multi_precision mode (e.g. fp16), optimizer state parameters and model parameters coexist in the target; if an optimizer state is visited first, it converts src_dtype to float32, which then makes the dtype wrong when model parameters are assigned afterwards.

@paddle-bot bot commented Nov 21, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot bot added the "contributor" (External developers) label on Nov 21, 2025
@zty-king (Contributor Author) commented:

Corresponding PaddleFormers adaptation: PaddlePaddle/PaddleFormers#2987

@codecov-commenter commented Nov 21, 2025

Codecov Report

❌ Patch coverage is 95.74468% with 4 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@f9062d5). Learn more about missing BASE report.

Files with missing lines                              | Patch % | Lines
...eet/meta_parallel/sharding/group_sharded_stage3.py | 95.69%  | 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #76538   +/-   ##
==========================================
  Coverage           ?   95.74%           
==========================================
  Files              ?        2           
  Lines              ?       94           
  Branches           ?        0           
==========================================
  Hits               ?       90           
  Misses             ?        4           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

@zty-king changed the title from "adapt fc to sharding stage3" to "[FlexCheckPoint] adapt fc to sharding stage3" on Nov 24, 2025

self._optim.clear_grad = MethodType(_opt_clear, self._optim)

def init_slice_param(self):
Contributor:

Why is initialization needed here?

Contributor Author:

  • After the model is wrapped by ShardingStage3, the data of every param that needs slicing is immediately moved into param.fw_storage and the param itself is cleared, leaving it uninitialized. When resuming training we need to load the parameter data from the checkpoint, and note that this data is the complete (un-sliced) tensor, because ShardingStage3 calls allgather at save time to restore each param to its full state. Since the sliced params are uninitialized at resume time, they must be re-initialized before the data can be loaded; after loading, the data is moved back onto param.fw_storage and the param is cleared again.
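The resume flow described above can be sketched in plain Python. This is a toy mock, not Paddle's actual ShardingStage3 internals: `MockParam`, `resume_sliced_param`, and the even flattened split are all illustrative assumptions.

```python
class MockParam:
    """Toy stand-in for a ShardingStage3 parameter (hypothetical, for illustration)."""
    def __init__(self, numel):
        self.numel = numel
        self.data = None        # cleared once Stage3 wraps the model
        self.fw_storage = None  # holds only this rank's slice between steps

def resume_sliced_param(param, full_ckpt_value, rank, world_size):
    # 1) Re-initialize: the wrapped param is uninitialized, but the checkpoint
    #    stores the *complete* (un-sliced) tensor, so allocate full size first.
    param.data = [0.0] * param.numel
    # 2) Load the complete checkpoint value into the re-initialized buffer.
    param.data[:] = full_ckpt_value
    # 3) Keep only this rank's slice in fw_storage (even split, for simplicity).
    shard_len = param.numel // world_size
    param.fw_storage = param.data[rank * shard_len:(rank + 1) * shard_len]
    # 4) Clear the param again, returning it to the Stage3 "released" state.
    param.data = None

p = MockParam(numel=8)
resume_sliced_param(p, [float(i) for i in range(8)], rank=1, world_size=2)
print(p.fw_storage)  # [4.0, 5.0, 6.0, 7.0]
print(p.data)        # None
```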

Contributor:

Is the data in param.fw_storage complete or sliced? If it is complete, is part of it released after the forward pass finishes?

Contributor Author:

It is sliced. Only when the complete parameter is needed for forward and backward does Stage3 call allgather to obtain it; after the computation, the param is cleared again and only the slice this rank needs to update is kept in param.fw_storage.
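That lifecycle (gather the full tensor, compute, then release back to a slice) can be mocked in a few lines of plain Python. This is a toy sketch: `allgather` here is a stand-in for the real collective, and the even split is an assumption.

```python
def allgather(shards):
    """Toy stand-in for the allgather collective: concatenates every rank's
    slice back into the complete flattened parameter."""
    full = []
    for s in shards:
        full.extend(s)
    return full

# Between steps, each rank keeps only its own slice in fw_storage.
fw_storage_per_rank = [[0.0, 1.0], [2.0, 3.0]]

# Before forward/backward: rebuild the complete parameter via allgather.
full_param = allgather(fw_storage_per_rank)
activations = [2.0 * w for w in full_param]  # compute with the full tensor

# After the computation: release, keeping only this rank's update slice.
rank, shard_len = 0, len(full_param) // len(fw_storage_per_rank)
fw_storage = full_param[rank * shard_len:(rank + 1) * shard_len]
print(full_param)  # [0.0, 1.0, 2.0, 3.0]
print(fw_storage)  # [0.0, 1.0]
```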

Contributor:

OK

@zty-king (Contributor Author) commented:

/re-run all-failed

    force_gc = []

-   for param_name, tgt_shard in load_dict.items():
+   for param_name, tgt_shard in sorted(load_dict.items()):
Contributor:

Why the sort here?

Contributor Author:

This fixes a failure when model and optimizer parameters are loaded together. In that case, an optimizer parameter may appear before a model parameter in load_dict, which updates src.dtype to float32; but under multi_precision (e.g. fp16), the model parameter's src.dtype should be fp16. To keep the model parameters' dtype from being overwritten in the combined-load case, the model parameters should be processed first, and a sorted() is enough to achieve that.
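A minimal sketch of this ordering bug in pure Python: `restore_params`, the entry names, and the single shared src_dtype are illustrative assumptions, not Paddle internals; only the use of sorted() mirrors the actual change.

```python
def restore_params(items):
    """Toy loader mimicking the bug: a single src_dtype is fixed by whichever
    entry happens to be visited first (names and logic are hypothetical)."""
    src_dtype = None
    out = {}
    for name, dtype in items:
        if src_dtype is None:
            src_dtype = dtype      # first entry seen pins the shared dtype
        out[name] = src_dtype      # every later entry is cast to src_dtype
    return out

entries = [("moment1_w", "float32"), ("linear.weight", "float16")]

# Unsorted: the optimizer state (always fp32) comes first, so the fp16
# model weight is restored as float32 under multi_precision -- the bug.
print(restore_params(entries)["linear.weight"])          # float32

# sorted(): "linear.weight" < "moment1_w", so the model param is visited
# first and keeps its fp16 dtype.
print(restore_params(sorted(entries))["linear.weight"])  # float16
```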

for param in params:
param_shape = param.shape
origin_state = param.stop_gradient
param.stop_gradient = True
Contributor:

Why set stop_gradient separately here?

Contributor Author:

This follows the approach in _param_storage. I believe it is meant to disconnect the computation graph, preventing the in-place flatten_() operation from affecting gradient computation in the backward pass; after flatten_(), the original stop_gradient state is restored.
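The pattern can be illustrated with a small self-contained mock. `FakeParam` and `guarded_flatten` are hypothetical names; in the real code, param is a Paddle tensor and flatten_() is its in-place flatten.

```python
class FakeParam:
    """Minimal stand-in for a trainable parameter (illustrative only)."""
    def __init__(self, rows):
        self.data = rows            # nested list standing in for a 2-D tensor
        self.stop_gradient = False  # i.e. gradients are being tracked

    def flatten_(self):
        # In-place flatten. In a real framework, an in-place op on a tensor
        # still tracked by autograd can corrupt backward computation.
        self.data = [x for row in self.data for x in row]

def guarded_flatten(param):
    # Detach from the graph, perform the in-place op, then restore the
    # original tracking state -- the pattern described above.
    origin_state = param.stop_gradient
    param.stop_gradient = True
    param.flatten_()
    param.stop_gradient = origin_state

p = FakeParam([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
guarded_flatten(p)
print(p.data)           # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(p.stop_gradient)  # False: original tracking state restored
```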

@From00 From00 merged commit 7235a00 into PaddlePaddle:develop Nov 27, 2025
98 of 106 checks passed
xingmingyyj pushed a commit to xingmingyyj/Paddle that referenced this pull request Dec 5, 2025
* adapt fc to sharding stage3

* add test and fix bug

* fix bug

* fix bug

* add test
swgu98 pushed a commit that referenced this pull request Dec 6, 2025
* [FlexCheckpoint] Aoa config reverse (#76437)

* aoa_config_reverse

* fix the bug

* add test

* fix dtype style and add test

* adapt full param update

* [FlexCheckPoint]adapt fc to sharding stage3 (#76538)

* adapt fc to sharding stage3

* add test and fix bug

* fix bug

* fix bug

* add test

* fc comm using grouped send/recv (#76779)

fix

fix

fix

fix

---------

Co-authored-by: Tianyu Zheng <129518799+zty-king@users.noreply.github.com>
@zty-king zty-king deleted the adapt_stage_v3_with_fc branch January 8, 2026 08:35