Skip to content

Commit daa67ab

Browse files
[API compatibility] add tensor.requires_grad_ method (#76439)
* add tensor.requires_grad_ method * add tensor.requires_grad_ tests * fix ffn_hidden_size is None when init gpt_mlp * remove eager_method
1 parent ff34d8b commit daa67ab

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

python/paddle/base/dygraph/math_op_patch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,19 @@ def requires_grad(self: Tensor, value: bool) -> None:
574574
)
575575
self.stop_gradient = not value
576576

577+
def requires_grad_(self, value: bool = True) -> None:
    """
    Set in place whether this Tensor requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True, matching the default of
            ``torch.Tensor.requires_grad_`` (this method exists for
            API compatibility with it).

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # Paddle tracks the inverse flag: stop_gradient=True means
    # "do not compute gradients for this Tensor".
    self.stop_gradient = not value
589+
577590
@property
578591
def itemsize(self: Tensor) -> int:
579592
"""
@@ -625,6 +638,7 @@ def _reduce_ex_(self: Tensor, proto):
625638
('new_ones', _new_ones_),
626639
('new_zeros', _new_zeros_),
627640
("requires_grad", requires_grad),
641+
("requires_grad_", requires_grad_),
628642
# for logical compare
629643
('__array_ufunc__', None),
630644
('itemsize', itemsize),

python/paddle/base/layers/math_op_patch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,19 @@ def requires_grad(self, value: bool) -> None:
598598
)
599599
self.stop_gradient = not value
600600

601+
def requires_grad_(self, value: bool = True) -> None:
    """
    Set in place whether this Variable requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True for API compatibility
            with ``torch.Tensor.requires_grad_``, whose argument also
            defaults to True.

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # stop_gradient is the inverse of requires_grad in Paddle.
    self.stop_gradient = not value
613+
601614
def _scalar_add_(var, value):
602615
return _scalar_op_(var, 1.0, value)
603616

@@ -849,6 +862,7 @@ def to_dense(var):
849862
('ndimension', ndimension),
850863
('ndim', _ndim),
851864
("requires_grad", requires_grad),
865+
("requires_grad_", requires_grad_),
852866
(
853867
'__add__',
854868
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),

python/paddle/pir/math_op_patch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,19 @@ def requires_grad(self, value: bool) -> None:
14571457
)
14581458
self.stop_gradient = not value
14591459

1460+
def requires_grad_(self, value: bool = True) -> None:
    """
    Set in place whether this Value requires gradient computation.

    Args:
        value (bool, optional): True to enable gradient computation,
            False to disable. Defaults to True so the call mirrors
            ``torch.Tensor.requires_grad_()``, whose argument defaults
            to True.

    Raises:
        TypeError: If ``value`` is not a bool.
    """
    if not isinstance(value, bool):
        raise TypeError(
            f"requires_grad must be bool, but got {type(value)}"
        )
    # requires_grad is modeled as the negation of stop_gradient.
    self.stop_gradient = not value
1472+
14601473
@property
14611474
def itemsize(self) -> int:
14621475
"""
@@ -1508,6 +1521,7 @@ def get_device(self) -> None:
15081521
('new_ones', _new_ones_),
15091522
('new_zeros', _new_zeros_),
15101523
("requires_grad", requires_grad),
1524+
("requires_grad_", requires_grad_),
15111525
('clone', clone),
15121526
('clear_gradient', clear_gradient),
15131527
('append', append),

test/legacy_test/test_tensor_requires_grad.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,5 +219,258 @@ def test_requires_grad_edge_cases(self):
219219
self.assertTrue(x.requires_grad)
220220

221221

222+
class TestTensorRequiresGrad_(unittest.TestCase):
    """Dynamic-mode tests for the in-place ``Tensor.requires_grad_`` method.

    Each test verifies that ``requires_grad_`` keeps ``requires_grad`` and
    ``stop_gradient`` as exact logical inverses of each other.
    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Run in dynamic (eager) mode; seed numpy for determinism.
        paddle.disable_static()
        np.random.seed(1919)

    def tearDown(self):
        """Clean up after each test method."""
        # Restore dynamic mode in case a test enabled static mode.
        paddle.disable_static()

    def test_basic_requires_grad_property(self):
        """Test basic requires_grad property functionality"""
        # Test default behavior - new tensors have stop_gradient=True by default
        x = paddle.randn([2, 3])
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

        # Test setting requires_grad to True
        x.requires_grad_(True)
        self.assertTrue(x.requires_grad)
        self.assertFalse(x.stop_gradient)

        # Test setting requires_grad to False
        x.requires_grad_(False)
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

    def test_requires_grad_consistency_with_stop_gradient(self):
        """Test that requires_grad is always the opposite of stop_gradient"""
        x = paddle.randn([3, 4])

        # Test multiple state changes
        states = [True, False, True, False]
        for requires_grad_state in states:
            x.requires_grad_(requires_grad_state)
            self.assertEqual(x.requires_grad, requires_grad_state)
            self.assertEqual(x.stop_gradient, not requires_grad_state)

            # Also test setting stop_gradient directly
            x.stop_gradient = requires_grad_state
            self.assertEqual(x.requires_grad, not requires_grad_state)
            self.assertEqual(x.stop_gradient, requires_grad_state)

    def test_requires_grad_type_checking(self):
        """Test type checking for requires_grad setter"""
        x = paddle.randn([2, 2])

        # Valid boolean values should work
        x.requires_grad_(True)
        x.requires_grad_(False)

        # Invalid types should raise TypeError.
        # Note: 1 and 0 are rejected too — the check is isinstance(value, bool),
        # not truthiness.
        invalid_values = ["true", 1, 0, None, [], {}]
        for invalid_value in invalid_values:
            with self.assertRaises(TypeError) as cm:
                x.requires_grad_(invalid_value)
            self.assertIn("requires_grad must be bool", str(cm.exception))

    def test_requires_grad_with_parameter(self):
        """Test requires_grad behavior with Parameter tensors"""
        # Create a parameter - Parameters have stop_gradient=False by default (trainable)
        param = paddle.create_parameter([3, 4], dtype='float32')
        self.assertTrue(
            param.requires_grad
        )  # Parameters require grad by default
        self.assertFalse(
            param.stop_gradient
        )  # Parameters are trainable by default

        # Test changing requires_grad on parameter
        param.requires_grad_(False)
        self.assertFalse(param.requires_grad)
        self.assertTrue(param.stop_gradient)

    def test_requires_grad_in_gradient_computation(self):
        """Test requires_grad behavior in actual gradient computation"""
        x = paddle.randn([2, 3])
        y = paddle.randn([2, 3])

        # Set both tensors to require grad
        x.requires_grad_(True)
        y.requires_grad_(True)

        z = x * y + x.sum()
        z.backward()

        self.assertIsNotNone(x.grad)
        self.assertIsNotNone(y.grad)

        # Clear gradients and test with requires_grad=False
        x.grad._clear_data()
        y.grad._clear_data()

        x.requires_grad_(False)
        y.requires_grad_(True)

        z = x * y + x.sum()
        z.backward()

        self.assertIsNone(x.grad)  # x doesn't require grad
        self.assertIsNotNone(y.grad)  # y requires grad

    def test_requires_grad_with_different_tensor_types(self):
        """Test requires_grad with different tensor creation methods"""
        # Test with different tensor creation functions
        tensor_creators = [
            lambda: paddle.randn([2, 3]),
            lambda: paddle.zeros([2, 3]),
            lambda: paddle.ones([2, 3]),
            lambda: paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
            lambda: paddle.arange(6, dtype='float32').reshape([2, 3]),
        ]

        for creator in tensor_creators:
            x = creator()
            # All newly created tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test modification
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_with_tensor_operations(self):
        """Test requires_grad preservation through tensor operations"""
        x = paddle.randn([3, 3])
        y = paddle.randn([3, 3])

        x.requires_grad_(True)
        y.requires_grad_(False)

        # Operations should preserve requires_grad appropriately
        z1 = x + y  # Should require grad (x requires grad)
        z2 = x * 2.0  # Should require grad (x requires grad)
        z3 = y.sin()  # Should not require grad (y doesn't require grad)

        self.assertTrue(z1.requires_grad)
        self.assertTrue(z2.requires_grad)
        self.assertFalse(z3.requires_grad)

    def test_requires_grad_with_detach(self):
        """Test requires_grad behavior with detach operation"""
        x = paddle.randn([2, 3])
        x.requires_grad_(True)

        y = x.detach()

        # Detached tensor should not require grad
        self.assertTrue(x.requires_grad)
        self.assertFalse(y.requires_grad)
        self.assertTrue(y.stop_gradient)

    def test_requires_grad_old_static_mode(self):
        """Test requires_grad behavior in old-IR static mode."""
        paddle.enable_static()
        # OldIrGuard routes the test through the legacy (pre-PIR) static graph.
        with paddle.pir_utils.OldIrGuard():
            x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # In static mode, variables also have stop_gradient=True by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test setting requires_grad in static mode
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_static_mode(self):
        """Test requires_grad behavior in static mode"""
        paddle.enable_static()

        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # In static mode, variables also have stop_gradient=True by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test setting requires_grad in static mode
            x.requires_grad_(True)
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_edge_cases(self):
        """Test edge cases for requires_grad"""
        # Test with scalar tensor
        scalar = paddle.to_tensor(3.14)
        self.assertFalse(scalar.requires_grad)  # False
        scalar.requires_grad_(True)
        self.assertTrue(scalar.requires_grad)

        # Test with empty tensor
        empty = paddle.empty([0, 3])
        self.assertFalse(empty.requires_grad)  # False
        empty.requires_grad_(True)
        self.assertTrue(empty.requires_grad)

        # Test with different dtypes
        dtypes = [paddle.float32, paddle.float64, paddle.int32, paddle.int64]
        for dtype in dtypes:
            x = paddle.ones([2, 2], dtype=dtype)
            # All tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)

            # Float tensors should support requires_grad
            if dtype in [paddle.float32, paddle.float64]:
                x.requires_grad_(True)
                self.assertTrue(x.requires_grad)
class TestAPI(unittest.TestCase):
    """Static-mode checks that requires_grad_ flips stop_gradient on
    variables created by factory APIs (paddle.full, paddle.static.data)."""

    def setUp(self):
        # Every case in this class runs under static graph mode.
        paddle.enable_static()

    def assert_api(self, api_func, require_grad):
        """Create a variable via api_func and exercise requires_grad_.

        Args:
            api_func: Zero-argument callable returning a static-graph variable.
            require_grad (bool): NOTE(review): despite the name, this value is
                first compared against the variable's initial ``stop_gradient``
                and then passed to ``requires_grad_``, expecting
                ``stop_gradient`` to become its negation — the naming is
                confusing; confirm intent with the author.
        """
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = api_func()
            self.assertEqual(x.stop_gradient, require_grad)
            # test for setter
            x.requires_grad_(require_grad)
            self.assertEqual(x.stop_gradient, not require_grad)

    def test_full(self):
        # paddle.full outputs start with stop_gradient=True.
        api = lambda: paddle.full(shape=[2, 3], fill_value=1.0)
        self.assert_api(api, True)

    def test_data(self):
        # paddle.static.data placeholders start with stop_gradient=True.
        api = lambda: paddle.static.data('x', [4, 4], dtype='float32')
        self.assert_api(api, True)

    # TODO(Aurelius84): Add more test cases after API is migrated.
457+
class TestParameters(unittest.TestCase):
    """Static-mode check that requires_grad_ works on Parameters, which
    (unlike ordinary variables) start with stop_gradient=False."""

    def setUp(self):
        # Run under static graph mode.
        paddle.enable_static()

    def test_create_param(self):
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            # Parameters are trainable (stop_gradient=False) and
            # persistable (saved with the program) by default.
            w = paddle.create_parameter(shape=[784, 200], dtype='float32')
            self.assertEqual(w.stop_gradient, False)
            self.assertEqual(w.persistable, True)

            # test for setter
            w.requires_grad_(False)
            w.persistable = False
            self.assertEqual(w.stop_gradient, True)
            self.assertEqual(w.persistable, False)
474+
222475
if __name__ == '__main__':
223476
unittest.main()

0 commit comments

Comments
 (0)