Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/paddle/base/dygraph/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,19 @@ def requires_grad(self: Tensor, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool = True) -> None:
    """
    Set in-place whether this Tensor requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True so the common call form
            ``x.requires_grad_()`` enables gradients.

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # requires_grad is the logical inverse of stop_gradient.
    self.stop_gradient = not value

@property
def itemsize(self: Tensor) -> int:
"""
Expand Down Expand Up @@ -625,6 +638,7 @@ def _reduce_ex_(self: Tensor, proto):
('new_ones', _new_ones_),
('new_zeros', _new_zeros_),
("requires_grad", requires_grad),
("requires_grad_", requires_grad_),
# for logical compare
('__array_ufunc__', None),
('itemsize', itemsize),
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/base/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,19 @@ def requires_grad(self, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool = True) -> None:
    """
    Set in-place whether this Tensor requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True so the common call form
            ``x.requires_grad_()`` enables gradients.

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # requires_grad is the logical inverse of stop_gradient.
    self.stop_gradient = not value

def _scalar_add_(var, value):
return _scalar_op_(var, 1.0, value)

Expand Down Expand Up @@ -849,6 +862,7 @@ def to_dense(var):
('ndimension', ndimension),
('ndim', _ndim),
("requires_grad", requires_grad),
("requires_grad_", requires_grad_),
(
'__add__',
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,19 @@ def requires_grad(self, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool = True) -> None:
    """
    Set in-place whether this Tensor requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True so the common call form
            ``x.requires_grad_()`` enables gradients.

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # requires_grad is the logical inverse of stop_gradient.
    self.stop_gradient = not value

@property
def itemsize(self) -> int:
"""
Expand Down Expand Up @@ -1496,6 +1509,7 @@ def get_device(self) -> None:
('new_ones', _new_ones_),
('new_zeros', _new_zeros_),
("requires_grad", requires_grad),
("requires_grad_", requires_grad_),
('clone', clone),
('clear_gradient', clear_gradient),
('append', append),
Expand Down
253 changes: 253 additions & 0 deletions test/legacy_test/test_tensor_requires_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,258 @@ def test_requires_grad_edge_cases(self):
self.assertTrue(x.requires_grad)


class TestTensorRequiresGrad_(unittest.TestCase):
    """Tests for the in-place ``Tensor.requires_grad_()`` setter.

    ``requires_grad_(value)`` writes ``stop_gradient = not value``, so every
    check below asserts that the two flags remain exact logical opposites —
    across dygraph mode, static mode, parameters, and assorted tensor types.
    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Run in dygraph mode; seed numpy so any random data is reproducible.
        paddle.disable_static()
        np.random.seed(1919)

    def tearDown(self):
        """Clean up after each test method."""
        # Some tests switch to static mode; restore dygraph for the next test.
        paddle.disable_static()

    def test_basic_requires_grad_property(self):
        """Test basic requires_grad property functionality"""
        # Test default behavior - new tensors have stop_gradient=True by default
        x = paddle.randn([2, 3])
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

        # Test setting requires_grad to True
        x.requires_grad_(True)
        self.assertTrue(x.requires_grad)
        self.assertFalse(x.stop_gradient)

        # Test setting requires_grad to False
        x.requires_grad_(False)
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

    def test_requires_grad_consistency_with_stop_gradient(self):
        """Test that requires_grad is always the opposite of stop_gradient"""
        x = paddle.randn([3, 4])

        # Test multiple state changes
        states = [True, False, True, False]
        for requires_grad_state in states:
            x.requires_grad_(requires_grad_state)
            self.assertEqual(x.requires_grad, requires_grad_state)
            self.assertEqual(x.stop_gradient, not requires_grad_state)

            # Also test setting stop_gradient directly
            # (the inverse relationship must hold from either side).
            x.stop_gradient = requires_grad_state
            self.assertEqual(x.requires_grad, not requires_grad_state)
            self.assertEqual(x.stop_gradient, requires_grad_state)

    def test_requires_grad_type_checking(self):
        """Test type checking for requires_grad setter"""
        x = paddle.randn([2, 2])

        # Valid boolean values should work
        x.requires_grad_(True)
        x.requires_grad_(False)

        # Invalid types should raise TypeError
        # NOTE: 1 and 0 are rejected too — the setter demands a real bool,
        # not merely a truthy/falsy value.
        invalid_values = ["true", 1, 0, None, [], {}]
        for invalid_value in invalid_values:
            with self.assertRaises(TypeError) as cm:
                x.requires_grad_(invalid_value)
            self.assertIn("requires_grad must be bool", str(cm.exception))

    def test_requires_grad_with_parameter(self):
        """Test requires_grad behavior with Parameter tensors"""
        # Create a parameter - Parameters have stop_gradient=False by default (trainable)
        param = paddle.create_parameter([3, 4], dtype='float32')
        self.assertTrue(
            param.requires_grad
        )  # Parameters require grad by default
        self.assertFalse(
            param.stop_gradient
        )  # Parameters are trainable by default

        # Test changing requires_grad on parameter
        param.requires_grad_(False)
        self.assertFalse(param.requires_grad)
        self.assertTrue(param.stop_gradient)

    def test_requires_grad_in_gradient_computation(self):
        """Test requires_grad behavior in actual gradient computation"""
        x = paddle.randn([2, 3])
        y = paddle.randn([2, 3])

        # Set both tensors to require grad
        x.requires_grad_(True)
        y.requires_grad_(True)

        z = x * y + x.sum()
        z.backward()

        self.assertIsNotNone(x.grad)
        self.assertIsNotNone(y.grad)

        # Clear gradients and test with requires_grad=False
        # NOTE(review): _clear_data is a private API; presumably it resets the
        # grad buffer so x.grad reads as None after the next backward — verify.
        x.grad._clear_data()
        y.grad._clear_data()

        x.requires_grad_(False)
        y.requires_grad_(True)

        z = x * y + x.sum()
        z.backward()

        self.assertIsNone(x.grad)  # x doesn't require grad
        self.assertIsNotNone(y.grad)  # y requires grad

    def test_requires_grad_with_different_tensor_types(self):
        """Test requires_grad with different tensor creation methods"""
        # Test with different tensor creation functions
        tensor_creators = [
            lambda: paddle.randn([2, 3]),
            lambda: paddle.zeros([2, 3]),
            lambda: paddle.ones([2, 3]),
            lambda: paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
            lambda: paddle.arange(6, dtype='float32').reshape([2, 3]),
        ]

        for creator in tensor_creators:
            x = creator()
            # All newly created tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test modification
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_with_tensor_operations(self):
        """Test requires_grad preservation through tensor operations"""
        x = paddle.randn([3, 3])
        y = paddle.randn([3, 3])

        x.requires_grad_(True)
        y.requires_grad_(False)

        # Operations should preserve requires_grad appropriately
        z1 = x + y  # Should require grad (x requires grad)
        z2 = x * 2.0  # Should require grad (x requires grad)
        z3 = y.sin()  # Should not require grad (y doesn't require grad)

        self.assertTrue(z1.requires_grad)
        self.assertTrue(z2.requires_grad)
        self.assertFalse(z3.requires_grad)

    def test_requires_grad_with_detach(self):
        """Test requires_grad behavior with detach operation"""
        x = paddle.randn([2, 3])
        x.requires_grad_(True)

        y = x.detach()

        # Detached tensor should not require grad; the source is unaffected.
        self.assertTrue(x.requires_grad)
        self.assertFalse(y.requires_grad)
        self.assertTrue(y.stop_gradient)

    def test_requires_grad_old_static_mode(self):
        """Test requires_grad behavior in static mode"""
        # OldIrGuard exercises the legacy (pre-PIR) static graph path,
        # i.e. the patch in base/layers/math_op_patch.py.
        paddle.enable_static()
        with paddle.pir_utils.OldIrGuard():
            x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # In static mode, variables also have stop_gradient=True by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test setting requires_grad in static mode
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_static_mode(self):
        """Test requires_grad behavior in static mode"""
        # Default static mode exercises the PIR path (pir/math_op_patch.py).
        paddle.enable_static()

        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # In static mode, variables also have stop_gradient=True by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test setting requires_grad in static mode
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_edge_cases(self):
        """Test edge cases for requires_grad"""
        # Test with scalar tensor
        scalar = paddle.to_tensor(3.14)
        self.assertFalse(scalar.requires_grad)  # False
        scalar.requires_grad_(True)
        self.assertTrue(scalar.requires_grad)

        # Test with empty tensor
        empty = paddle.empty([0, 3])
        self.assertFalse(empty.requires_grad)  # False
        empty.requires_grad_(True)
        self.assertTrue(empty.requires_grad)

        # Test with different dtypes
        dtypes = [paddle.float32, paddle.float64, paddle.int32, paddle.int64]
        for dtype in dtypes:
            x = paddle.ones([2, 2], dtype=dtype)
            # All tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)

            # Float tensors should support requires_grad
            # (integer tensors are intentionally left untouched here).
            if dtype in [paddle.float32, paddle.float64]:
                x.requires_grad_(True)
                self.assertTrue(x.requires_grad)

class TestAPI(unittest.TestCase):
    """Static-mode checks that ``requires_grad_`` flips ``stop_gradient``."""

    def setUp(self):
        paddle.enable_static()

    def assert_api(self, api_func, default_stop_gradient):
        """Verify an API output's default stop_gradient, then invert it.

        Args:
            api_func (callable): Zero-argument factory returning a static
                graph variable/value.
            default_stop_gradient (bool): Expected ``stop_gradient`` of the
                freshly created output. The same flag is then passed to
                ``requires_grad_``, which must set ``stop_gradient`` to its
                negation (requires_grad == not stop_gradient).
        """
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = api_func()
            self.assertEqual(x.stop_gradient, default_stop_gradient)
            # test for setter: requires_grad_(v) writes stop_gradient = not v
            x.requires_grad_(default_stop_gradient)
            self.assertEqual(x.stop_gradient, not default_stop_gradient)

    def test_full(self):
        # paddle.full outputs are non-trainable (stop_gradient=True) by default.
        api = lambda: paddle.full(shape=[2, 3], fill_value=1.0)
        self.assert_api(api, True)

    def test_data(self):
        # Data placeholders are non-trainable by default as well.
        api = lambda: paddle.static.data('x', [4, 4], dtype='float32')
        self.assert_api(api, True)

    # TODO(Aurelius84): Add more test cases after API is migrated.


class TestParameters(unittest.TestCase):
    """Parameters start trainable and persistable; both flags are mutable."""

    def setUp(self):
        paddle.enable_static()

    def test_create_param(self):
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            param = paddle.create_parameter(shape=[784, 200], dtype='float32')
            # A fresh parameter is trainable (stop_gradient=False) and persistable.
            self.assertEqual(param.stop_gradient, False)
            self.assertEqual(param.persistable, True)

            # test for setter: flip both flags and confirm they stuck.
            param.requires_grad_(False)
            param.persistable = False
            self.assertEqual(param.stop_gradient, True)
            self.assertEqual(param.persistable, False)


# Allow running this test module directly, e.g. `python test_tensor_requires_grad.py`.
if __name__ == '__main__':
    unittest.main()
Loading