@@ -62,6 +62,22 @@ def _calculate_gain(nonlinearity, param):
6262 return recommended_gain [nonlinearity ]
6363
6464
65+ def _calculate_fan_in_and_fan_out (var : paddle .Tensor ) -> tuple [int , int ]:
66+ shape = var .shape
67+ if not shape or len (shape ) == 0 :
68+ fan_in = fan_out = 1
69+ elif len (shape ) == 1 :
70+ fan_in = fan_out = shape [0 ]
71+ elif len (shape ) == 2 :
72+ fan_in = shape [0 ]
73+ fan_out = shape [1 ]
74+ else :
75+ receptive_field_size = np .prod (shape [2 :])
76+ fan_in = shape [1 ] * receptive_field_size
77+ fan_out = shape [0 ] * receptive_field_size
78+ return (fan_in , fan_out )
79+
80+
6581class Test_calculate_gain (unittest .TestCase ):
6682 def test (self ):
6783 for nonlinearity in [
@@ -87,6 +103,27 @@ def test(self):
87103 )
88104
89105
106+ class TestCAlFanINOUT (unittest .TestCase ):
107+ def test_cal_fan_in_and_out (self ):
108+ x = paddle .tensor .randn ([10 ])
109+ self .assertEqual (
110+ _calculate_fan_in_and_fan_out (x ),
111+ paddle .nn .init ._calculate_fan_in_and_fan_out (x ),
112+ )
113+
114+ y = paddle .tensor .randn ([10 , 10 ])
115+ self .assertEqual (
116+ _calculate_fan_in_and_fan_out (y ),
117+ paddle .nn .init ._calculate_fan_in_and_fan_out (y ),
118+ )
119+
120+ z = paddle .randn ([10 , 10 , 10 ])
121+ self .assertEqual (
122+ _calculate_fan_in_and_fan_out (z ),
123+ paddle .nn .init ._calculate_fan_in_and_fan_out (z ),
124+ )
125+
126+
90127class Test_kaiming_uniform_ (unittest .TestCase ):
91128 def check_kaiming_uniform (
92129 self , tensor , a = 0 , mode = 'fan_in' , nonlinearity = 'leaky_relu'
0 commit comments