[API compatibility] add tensor.requires_grad_ method#76439
[API compatibility] add tensor.requires_grad_ method#76439zhwesky2010 merged 5 commits intoPaddlePaddle:developfrom
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
zhwesky2010
left a comment
There was a problem hiding this comment.
加单测跑一下吧,这个需要实际运行下。同事覆盖率CI也需要实际跑到了才行
| self.stop_gradient = not value | ||
|
|
||
| @requires_grad.setter | ||
| def requires_grad_(self, value: bool) -> None: |
zhwesky2010
left a comment
There was a problem hiding this comment.
eager_method.cc、framework.py这两个文件不需要改动
python/paddle/base/framework.py
Outdated
| self.desc.set_stop_gradient(s) | ||
|
|
||
| @property | ||
| def requires_grad(self) -> bool: |
python/paddle/base/framework.py
Outdated
| return not self.desc.stop_gradient() | ||
|
|
||
| @requires_grad.setter | ||
| def requires_grad_(self, value) -> None: |
| @@ -0,0 +1,63 @@ | |||
| # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | |||
There was a problem hiding this comment.
单测里还需要测 dygraph、老IR分支下的。
统一放到test/legacy_test/test_tensor_requires_grad.py里来测吧。
对于老IR:需要加with paddle.pir_utils.OldIrGuard(),才能测试到layers/math_op_patch.py
对于新IR:需要加with paddle.pir_utils.IrGuard(), 才能测到pir/math_op_patch.py
对于动态图:直接测试
| """Test requires_grad behavior in static mode""" | ||
| paddle.enable_static() | ||
|
|
||
| try: |
There was a problem hiding this comment.
这个不是检测异常抛出,不需要try finally
paddle/fluid/pybind/eager_method.cc
Outdated
| EAGER_CATCH_AND_THROW_RETURN_NULL | ||
| } | ||
|
|
||
| PyDoc_STRVAR(tensor_requires_grad___doc__, // NOLINT |
There was a problem hiding this comment.
eager_mothod.cc不需要改动了,这个的逻辑会被dygraph/math_op_patch.py覆盖掉
| self.assertFalse(y.requires_grad) | ||
| self.assertTrue(y.stop_gradient) | ||
|
|
||
| def test_requires_grad_static_mode(self): |
There was a problem hiding this comment.
静态图pir跑到了,老IR应该没跑到吧,这样layers/match_op_patch.py里面的代码是没有覆盖到的
Codecov Report❌ Patch coverage is
❌ Your patch status has failed because the patch coverage (83.33%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #76439 +/- ##
==========================================
Coverage ? 83.33%
==========================================
Files ? 3
Lines ? 12
Branches ? 0
==========================================
Hits ? 10
Misses ? 2
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
/re-run all-failed |


PR Category
User Experience
PR Types
New features
Description
add tensor.requires_grad_ method , API corresponding to PyTorch:[https://docs.pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html#torch.Tensor.requires_grad_]