@@ -21,41 +21,37 @@ def conv2d_forward_naive(input, filter, group, conv_param):
2121 for i in range (out_h ):
2222 for j in range (out_w ):
2323 for g in range (group ):
24- input_pad_masked = input_pad [:, g * f_c :(
25- g + 1 ) * f_c , i * stride [0 ]:i * stride [0 ] + f_h , j * stride [
26- 1 ]:j * stride [1 ] + f_w ]
24+ input_pad_masked = \
25+ input_pad [:, g * f_c :(g + 1 ) * f_c ,
26+ i * stride [0 ]:i * stride [0 ] + f_h ,
27+ j * stride [1 ]:j * stride [1 ] + f_w ]
28+
2729 f_sub = filter [g * sub_out_c :(g + 1 ) * sub_out_c , :, :, :]
2830 for k in range (sub_out_c ):
29- out [:, g * sub_out_c + k , i , j ] = np . sum ( input_pad_masked *
30- f_sub [k , :, :, :],
31- axis = (1 , 2 , 3 ))
31+ out [:, g * sub_out_c + k , i , j ] = \
32+ np . sum ( input_pad_masked * f_sub [k , :, :, :],
33+ axis = (1 , 2 , 3 ))
3234
3335 return out
3436
3537
3638class TestConv2dOp (OpTest ):
3739 def setUp (self ):
38- self .init_groups ()
39- self .init_optype ()
40- pad = [0 , 0 ]
41- stride = [1 , 1 ]
42- input_size = [2 , 3 , 5 , 5 ] # NCHW
43- assert np .mod (input_size [1 ], self .groups ) == 0
44- f_c = input_size [1 ] / self .groups
45- filter_size = [6 , f_c , 3 , 3 ]
46-
47- conv2d_param = {'stride' : stride , 'pad' : pad }
48- input = np .random .random (input_size ).astype ("float32" )
49- filter = np .random .random (filter_size ).astype ("float32" )
40+ self .init_op_type ()
41+ self .init_group ()
42+ self .init_test_case ()
5043
44+ conv2d_param = {'stride' : self .stride , 'pad' : self .pad }
45+ input = np .random .random (self .input_size ).astype ("float32" )
46+ filter = np .random .random (self .filter_size ).astype ("float32" )
5147 output = conv2d_forward_naive (input , filter , self .groups , conv2d_param )
5248
5349 self .inputs = {'Input' : input , 'Filter' : filter }
5450 self .attrs = {
55- 'strides' : stride ,
56- 'paddings' : pad ,
51+ 'strides' : self . stride ,
52+ 'paddings' : self . pad ,
5753 'groups' : self .groups ,
58- 'dilations' : [ 1 , 1 ]
54+ 'dilations' : self . dilations
5955 }
6056 self .outputs = {'Output' : output }
6157
@@ -80,30 +76,47 @@ def test_check_grad_no_input(self):
8076 max_relative_error = 0.05 ,
8177 no_grad_set = set (['Input' ]))
8278
83- def init_groups (self ):
79+ def init_test_case (self ):
80+ self .groups = 1
81+ self .op_type = "conv2d"
82+ self .pad = [0 , 0 ]
83+ self .stride = [1 , 1 ]
84+ self .dilations = [1 , 1 ]
85+ self .input_size = [2 , 3 , 5 , 5 ] # NCHW
86+ assert np .mod (self .input_size [1 ], self .groups ) == 0
87+ f_c = self .input_size [1 ] / self .groups
88+ self .filter_size = [6 , f_c , 3 , 3 ]
89+
90+ def init_group (self ):
8491 self .groups = 1
8592
86- def init_optype (self ):
93+ def init_op_type (self ):
8794 self .op_type = "conv2d"
8895
8996
9097class TestWithGroup (TestConv2dOp ):
91- def init_groups (self ):
98+ def init_group (self ):
9299 self .groups = 3
93100
101+ def init_op_type (self ):
102+ self .op_type = "conv2d"
94103
95- class TestCudnn2d (TestConv2dOp ):
96- def init_optype (self ):
97- self .op_type = "conv_cudnn"
98104
105+ class TestCudnn (TestConv2dOp ):
106+ def init_group (self ):
107+ self .groups = 1
99108
100- class TestCudnn2dWithGroup (TestConv2dOp ):
101- def init_optype (self ):
109+ def init_op_type (self ):
102110 self .op_type = "conv_cudnn"
103111
104- def init_groups (self ):
112+
113+ class TestCudnnWithGroup (TestConv2dOp ):
114+ def init_group (self ):
105115 self .groups = 3
106116
117+ def init_op_type (self ):
118+ self .op_type = "conv_cudnn"
119+
107120
108121if __name__ == '__main__' :
109122 unittest .main ()
0 commit comments