@@ -219,5 +219,258 @@ def test_requires_grad_edge_cases(self):
219219 self .assertTrue (x .requires_grad )
220220
221221
class TestTensorRequiresGrad_(unittest.TestCase):
    """Tests for the ``Tensor.requires_grad`` property and the in-place
    ``requires_grad_()`` setter, which mirror ``stop_gradient``."""

    def setUp(self):
        """Switch to dynamic-graph mode and fix the NumPy RNG seed."""
        paddle.disable_static()
        np.random.seed(1919)

    def tearDown(self):
        """Restore dynamic-graph mode (some tests enable static mode)."""
        paddle.disable_static()

    def test_basic_requires_grad_property(self):
        """requires_grad defaults to False and toggles via requires_grad_()."""
        t = paddle.randn([2, 3])
        # Freshly created tensors have stop_gradient=True by default.
        self.assertFalse(t.requires_grad)
        self.assertTrue(t.stop_gradient)

        t.requires_grad_(True)
        self.assertTrue(t.requires_grad)
        self.assertFalse(t.stop_gradient)

        t.requires_grad_(False)
        self.assertFalse(t.requires_grad)
        self.assertTrue(t.stop_gradient)

    def test_requires_grad_consistency_with_stop_gradient(self):
        """requires_grad must always read as the negation of stop_gradient."""
        t = paddle.randn([3, 4])

        for flag in (True, False, True, False):
            t.requires_grad_(flag)
            self.assertEqual(t.requires_grad, flag)
            self.assertEqual(t.stop_gradient, not flag)

            # Writing stop_gradient directly must be reflected by
            # requires_grad as well.
            t.stop_gradient = flag
            self.assertEqual(t.requires_grad, not flag)
            self.assertEqual(t.stop_gradient, flag)

    def test_requires_grad_type_checking(self):
        """Non-bool arguments to requires_grad_() raise TypeError."""
        t = paddle.randn([2, 2])

        # Plain booleans are accepted.
        t.requires_grad_(True)
        t.requires_grad_(False)

        # Everything else is rejected with a descriptive message.
        for bad in ("true", 1, 0, None, [], {}):
            with self.assertRaises(TypeError) as ctx:
                t.requires_grad_(bad)
            self.assertIn("requires_grad must be bool", str(ctx.exception))

    def test_requires_grad_with_parameter(self):
        """Parameters are trainable (requires_grad=True) out of the box."""
        param = paddle.create_parameter([3, 4], dtype='float32')
        self.assertTrue(param.requires_grad)
        self.assertFalse(param.stop_gradient)

        # The in-place setter works on parameters as well.
        param.requires_grad_(False)
        self.assertFalse(param.requires_grad)
        self.assertTrue(param.stop_gradient)

    def test_requires_grad_in_gradient_computation(self):
        """Only tensors with requires_grad=True accumulate a .grad."""
        a = paddle.randn([2, 3])
        b = paddle.randn([2, 3])

        a.requires_grad_(True)
        b.requires_grad_(True)

        out = a * b + a.sum()
        out.backward()
        self.assertIsNotNone(a.grad)
        self.assertIsNotNone(b.grad)

        # Reset gradients, then repeat with `a` excluded from autograd.
        a.grad._clear_data()
        b.grad._clear_data()

        a.requires_grad_(False)
        b.requires_grad_(True)

        out = a * b + a.sum()
        out.backward()
        self.assertIsNone(a.grad)  # a opted out of autograd
        self.assertIsNotNone(b.grad)  # b still receives a gradient

    def test_requires_grad_with_different_tensor_types(self):
        """Every creation API yields requires_grad=False and accepts the setter."""
        creators = (
            lambda: paddle.randn([2, 3]),
            lambda: paddle.zeros([2, 3]),
            lambda: paddle.ones([2, 3]),
            lambda: paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
            lambda: paddle.arange(6, dtype='float32').reshape([2, 3]),
        )

        for make in creators:
            t = make()
            # Defaults: no gradient tracking.
            self.assertFalse(t.requires_grad)
            self.assertTrue(t.stop_gradient)

            # Setter flips both views of the flag.
            t.requires_grad_(True)
            self.assertTrue(t.requires_grad)
            self.assertFalse(t.stop_gradient)

    def test_requires_grad_with_tensor_operations(self):
        """requires_grad propagates through an op iff some input requires grad."""
        a = paddle.randn([3, 3])
        b = paddle.randn([3, 3])

        a.requires_grad_(True)
        b.requires_grad_(False)

        r_add = a + b  # a participates -> grad required
        r_mul = a * 2.0  # a participates -> grad required
        r_sin = b.sin()  # only b -> no grad required

        self.assertTrue(r_add.requires_grad)
        self.assertTrue(r_mul.requires_grad)
        self.assertFalse(r_sin.requires_grad)

    def test_requires_grad_with_detach(self):
        """detach() produces a tensor outside the autograd graph."""
        src = paddle.randn([2, 3])
        src.requires_grad_(True)

        detached = src.detach()

        self.assertTrue(src.requires_grad)  # source is untouched
        self.assertFalse(detached.requires_grad)
        self.assertTrue(detached.stop_gradient)

    def test_requires_grad_old_static_mode(self):
        """requires_grad works on variables in legacy (old IR) static mode."""
        paddle.enable_static()
        with paddle.pir_utils.OldIrGuard():
            var = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # Static variables also default to stop_gradient=True.
            self.assertFalse(var.requires_grad)
            self.assertTrue(var.stop_gradient)

            var.requires_grad_(True)
            self.assertTrue(var.requires_grad)
            self.assertFalse(var.stop_gradient)

    def test_requires_grad_static_mode(self):
        """requires_grad works on variables in (PIR) static mode."""
        paddle.enable_static()

        with paddle.static.program_guard(paddle.static.Program()):
            var = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

            # Static variables also default to stop_gradient=True.
            self.assertFalse(var.requires_grad)
            self.assertTrue(var.stop_gradient)

            var.requires_grad_(True)
            self.assertTrue(var.requires_grad)
            self.assertFalse(var.stop_gradient)

    def test_requires_grad_edge_cases(self):
        """Scalar tensors, empty tensors, and assorted dtypes."""
        # 0-D (scalar) tensor.
        scalar = paddle.to_tensor(3.14)
        self.assertFalse(scalar.requires_grad)
        scalar.requires_grad_(True)
        self.assertTrue(scalar.requires_grad)

        # Tensor with a zero-sized dimension.
        empty = paddle.empty([0, 3])
        self.assertFalse(empty.requires_grad)
        empty.requires_grad_(True)
        self.assertTrue(empty.requires_grad)

        # All dtypes start with requires_grad=False; only the floating
        # dtypes are exercised with the setter here.
        for dtype in (paddle.float32, paddle.float64, paddle.int32, paddle.int64):
            t = paddle.ones([2, 2], dtype=dtype)
            self.assertFalse(t.requires_grad)

            if dtype in (paddle.float32, paddle.float64):
                t.requires_grad_(True)
                self.assertTrue(t.requires_grad)
432+
class TestAPI(unittest.TestCase):
    """Check the default ``stop_gradient`` of tensor-creating APIs in static
    mode, and that ``requires_grad_()`` flips it."""

    def setUp(self):
        paddle.enable_static()

    def assert_api(self, api_func, expect_stop_gradient):
        """Build a tensor via *api_func* inside a fresh program and verify:

        1. its default ``stop_gradient`` equals *expect_stop_gradient*;
        2. calling ``requires_grad_()`` flips ``stop_gradient``.

        NOTE(review): this parameter was previously named ``require_grad``
        even though it is compared against ``stop_gradient`` — the opposite
        flag — which made the assertions read backwards. Renamed for
        clarity; all call sites pass it positionally, so behavior is
        unchanged.
        """
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = api_func()
            self.assertEqual(x.stop_gradient, expect_stop_gradient)
            # Exercise the in-place setter: requires_grad=True must clear
            # stop_gradient (and vice versa).
            x.requires_grad_(expect_stop_gradient)
            self.assertEqual(x.stop_gradient, not expect_stop_gradient)

    def test_full(self):
        # paddle.full output is a plain tensor: stop_gradient defaults to True.
        def api():
            return paddle.full(shape=[2, 3], fill_value=1.0)

        self.assert_api(api, True)

    def test_data(self):
        # Static data variables also default to stop_gradient=True.
        def api():
            return paddle.static.data('x', [4, 4], dtype='float32')

        self.assert_api(api, True)

    # TODO(Aurelius84): Add more test cases after API is migrated.
455+
456+
class TestParameters(unittest.TestCase):
    """Check trainability and persistable flags of created parameters in
    static mode."""

    def setUp(self):
        paddle.enable_static()

    def test_create_param(self):
        prog = paddle.static.Program()
        with paddle.static.program_guard(prog):
            w = paddle.create_parameter(shape=[784, 200], dtype='float32')
            # Fresh parameters are trainable and persistable by default.
            self.assertEqual(w.stop_gradient, False)
            self.assertEqual(w.persistable, True)

            # Flip both flags through their respective setters and re-check.
            w.requires_grad_(False)
            w.persistable = False
            self.assertEqual(w.stop_gradient, True)
            self.assertEqual(w.persistable, False)
473+
474+
# Run the whole suite when executed as a script.
if __name__ == '__main__':
    unittest.main()
0 commit comments