@@ -11,19 +11,19 @@ def conv2d_forward_naive(input, filter, group, conv_param):
1111 sub_out_c = out_c / group
1212
1313 stride , pad = conv_param ['stride' ], conv_param ['pad' ]
14- out_h = 1 + (in_h + 2 * pad - f_h ) / stride
15- out_w = 1 + (in_w + 2 * pad - f_w ) / stride
14+ out_h = 1 + (in_h + 2 * pad [ 0 ] - f_h ) / stride [ 0 ]
15+ out_w = 1 + (in_w + 2 * pad [ 1 ] - f_w ) / stride [ 1 ]
1616 out = np .zeros ((in_n , out_c , out_h , out_w ))
1717
18- input_pad = np .pad (input , ((0 , ), (0 , ), (pad , ), (pad , )),
18+ input_pad = np .pad (input , ((0 , ), (0 , ), (pad [ 0 ] , ), (pad [ 1 ] , )),
1919 mode = 'constant' ,
2020 constant_values = 0 )
2121 for i in range (out_h ):
2222 for j in range (out_w ):
2323 for g in range (group ):
2424 input_pad_masked = input_pad [:, g * f_c :(
25- g + 1 ) * f_c , i * stride :i * stride + f_h , j * stride : j *
26- stride + f_w ]
25+ g + 1 ) * f_c , i * stride [ 0 ] :i * stride [ 0 ] + f_h , j * stride [
26+ 1 ]: j * stride [ 1 ] + f_w ]
2727 f_sub = filter [g * sub_out_c :(g + 1 ) * sub_out_c , :, :, :]
2828 for k in range (sub_out_c ):
2929 out [:, g * sub_out_c + k , i , j ] = np .sum (input_pad_masked *
@@ -37,23 +37,22 @@ class TestConv2dOp(OpTest):
3737 def setUp (self ):
3838 self .init_groups ()
3939 self .op_type = "conv2d"
40+ pad = [0 , 0 ]
41+ stride = [1 , 1 ]
4042 input_size = [2 , 3 , 5 , 5 ] # NCHW
4143 assert np .mod (input_size [1 ], self .groups ) == 0
4244 f_c = input_size [1 ] / self .groups
4345 filter_size = [6 , f_c , 3 , 3 ]
44- conv2d_param = {'stride' : 1 , 'pad' : 0 }
46+
47+ conv2d_param = {'stride' : stride , 'pad' : pad }
4548
4649 input = np .random .random (input_size ).astype ("float32" )
4750 filter = np .random .random (filter_size ).astype ("float32" )
4851
4952 output = conv2d_forward_naive (input , filter , self .groups , conv2d_param )
5053
5154 self .inputs = {'Input' : input , 'Filter' : filter }
52- self .attrs = {
53- 'strides' : [1 , 1 ],
54- 'paddings' : [0 , 0 ],
55- 'groups' : self .groups
56- }
55+ self .attrs = {'strides' : stride , 'paddings' : pad , 'groups' : self .groups }
5756 self .outputs = {'Output' : output }
5857
5958 def test_check_output (self ):
0 commit comments