[Auto-parallel] Improve stage support for inter-layer passing of stop_grad=True parameters #73459
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
# We assume we always send to stage + 1
if not self.is_last:
    self.act_send_info[idx] = [self.stage_index + 1]
    if not outputs_meta[idx].stop_gradient:
TensorMeta does not have this attribute.
Done
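For reference, a rough sketch of the kind of check discussed in this thread, not the actual change in this PR: the stop_gradient flag is read from the real output tensors (TensorMeta only describes shape and dtype), and the helper name, arguments, and return layout below are all hypothetical.

import paddle

def build_send_info(output_tensors, stage_index, is_last):
    # Illustrative only: every output is still sent to stage + 1, but only
    # outputs with stop_gradient == False are marked as expecting a gradient
    # back during backward. The flag comes from the real tensor, not TensorMeta.
    act_send_info = {}
    expects_grad = {}
    for idx, out in enumerate(output_tensors):
        act_send_info[idx] = [] if is_last else [stage_index + 1]
        expects_grad[idx] = (
            not is_last
            and isinstance(out, paddle.Tensor)
            and not out.stop_gradient
        )
    return act_send_info, expects_grad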
Sorry to inform you that ea80cdb's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

@@            Coverage Diff             @@
##           develop     #73459   +/-   ##
===========================================
  Coverage         ?   100.00%
===========================================
  Files            ?         2
  Lines            ?        12
  Branches         ?         0
===========================================
  Hits             ?        12
  Misses           ?         0
  Partials         ?         0

☔ View full report in Codecov by Sentry.
@@ -710,9 +709,19 @@ def forward_one_chunk(
flat_args = _flatten_args(input_args)
flat_kwargs = _flatten_args(composite_kwargs)
flatten_input_tensors = flat_args + flat_kwargs
grad_required_output_tuple = tuple(
The naming seems off: the grad prefix suggests these hold gradient data. requires_grad_output_tuple would be a clearer name.
Done
    for out in output_tuple
    if isinstance(out, paddle.Tensor) and not out.stop_gradient
)
grad_required_flatten_input_tensors = [
Same as above.
Done
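Putting the two comments above together, the filtering in forward_one_chunk presumably ends up along these lines; this is a sketch using the suggested renamed variables, wrapped in a standalone helper for illustration rather than copied from the PR.

import paddle

def filter_requires_grad(output_tuple, flatten_input_tensors):
    # Keep only the tensors that take part in backward; stop_gradient tensors
    # are forwarded between layers but excluded from gradient exchange.
    requires_grad_output_tuple = tuple(
        out
        for out in output_tuple
        if isinstance(out, paddle.Tensor) and not out.stop_gradient
    )
    requires_grad_flatten_input_tensors = [
        t
        for t in flatten_input_tensors
        if isinstance(t, paddle.Tensor) and not t.stop_gradient
    ]
    return requires_grad_output_tuple, requires_grad_flatten_input_tensors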
/re-run all-failed
LGTM
PR Category
Auto Parallel
PR Types
Improvements
Description
When use_flash_attention is set to false, running llama2_13b_hybrid_pp fails. The root cause is that the number of outputs the forward pass sends to the next stage does not match the number of gradients received from the next stage during backward, so the backward pass cannot be computed correctly.
The tensors computed in EmbeddingLayer need to be passed along throughout the forward pass and used in the computations of each DecoderLayer. Note, however, that apart from hidden_states, the other tensors are only passed between layers after EmbeddingLayer computes them and serve purely as auxiliary inputs.
When collecting each layer's outputs, filter out the tensors whose stop_grad is True and adapt the related code so that they are only passed between layers and excluded from backward.
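A minimal toy illustration of the scenario described above, with made-up tensors standing in for the real EmbeddingLayer/DecoderLayer outputs: the auxiliary stop_gradient tensor travels with the forward outputs, but only the grad-requiring tensors are registered for backward, so output and gradient counts match.

import paddle

# hidden_states participates in backward; position_ids is auxiliary and
# forward-only (stop_gradient=True), mirroring the description above.
hidden_states = paddle.randn([2, 8, 16])
hidden_states.stop_gradient = False
position_ids = paddle.arange(8).reshape([1, 8])
position_ids.stop_gradient = True

stage_outputs = (hidden_states, position_ids)

# Both tensors are sent to the next stage, but only the grad-requiring ones
# are registered for backward, so the number of gradients received later
# matches the number of filtered outputs.
backward_outputs = tuple(
    out for out in stage_outputs
    if isinstance(out, paddle.Tensor) and not out.stop_gradient
)
assert len(backward_outputs) == 1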