Skip to content

Commit 5fe6893

Browse files
committed
fix code struce
1 parent 8ad67da commit 5fe6893

File tree

1 file changed

+43
-30
lines changed

1 file changed

+43
-30
lines changed

python/paddle/v2/framework/tests/test_conv2d_op.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,41 +21,37 @@ def conv2d_forward_naive(input, filter, group, conv_param):
2121
for i in range(out_h):
2222
for j in range(out_w):
2323
for g in range(group):
24-
input_pad_masked = input_pad[:, g * f_c:(
25-
g + 1) * f_c, i * stride[0]:i * stride[0] + f_h, j * stride[
26-
1]:j * stride[1] + f_w]
24+
input_pad_masked = \
25+
input_pad[:, g * f_c:(g + 1) * f_c,
26+
i * stride[0]:i * stride[0] + f_h,
27+
j * stride[1]:j * stride[1] + f_w]
28+
2729
f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :]
2830
for k in range(sub_out_c):
29-
out[:, g * sub_out_c + k, i, j] = np.sum(input_pad_masked *
30-
f_sub[k, :, :, :],
31-
axis=(1, 2, 3))
31+
out[:, g * sub_out_c + k, i, j] = \
32+
np.sum(input_pad_masked * f_sub[k, :, :, :],
33+
axis=(1, 2, 3))
3234

3335
return out
3436

3537

3638
class TestConv2dOp(OpTest):
3739
def setUp(self):
38-
self.init_groups()
39-
self.init_optype()
40-
pad = [0, 0]
41-
stride = [1, 1]
42-
input_size = [2, 3, 5, 5] # NCHW
43-
assert np.mod(input_size[1], self.groups) == 0
44-
f_c = input_size[1] / self.groups
45-
filter_size = [6, f_c, 3, 3]
46-
47-
conv2d_param = {'stride': stride, 'pad': pad}
48-
input = np.random.random(input_size).astype("float32")
49-
filter = np.random.random(filter_size).astype("float32")
40+
self.init_op_type()
41+
self.init_group()
42+
self.init_test_case()
5043

44+
conv2d_param = {'stride': self.stride, 'pad': self.pad}
45+
input = np.random.random(self.input_size).astype("float32")
46+
filter = np.random.random(self.filter_size).astype("float32")
5147
output = conv2d_forward_naive(input, filter, self.groups, conv2d_param)
5248

5349
self.inputs = {'Input': input, 'Filter': filter}
5450
self.attrs = {
55-
'strides': stride,
56-
'paddings': pad,
51+
'strides': self.stride,
52+
'paddings': self.pad,
5753
'groups': self.groups,
58-
'dilations': [1, 1]
54+
'dilations': self.dilations
5955
}
6056
self.outputs = {'Output': output}
6157

@@ -80,30 +76,47 @@ def test_check_grad_no_input(self):
8076
max_relative_error=0.05,
8177
no_grad_set=set(['Input']))
8278

83-
def init_groups(self):
79+
def init_test_case(self):
80+
self.groups = 1
81+
self.op_type = "conv2d"
82+
self.pad = [0, 0]
83+
self.stride = [1, 1]
84+
self.dilations = [1, 1]
85+
self.input_size = [2, 3, 5, 5] # NCHW
86+
assert np.mod(self.input_size[1], self.groups) == 0
87+
f_c = self.input_size[1] / self.groups
88+
self.filter_size = [6, f_c, 3, 3]
89+
90+
def init_group(self):
8491
self.groups = 1
8592

86-
def init_optype(self):
93+
def init_op_type(self):
8794
self.op_type = "conv2d"
8895

8996

9097
class TestWithGroup(TestConv2dOp):
91-
def init_groups(self):
98+
def init_group(self):
9299
self.groups = 3
93100

101+
def init_op_type(self):
102+
self.op_type = "conv2d"
94103

95-
class TestCudnn2d(TestConv2dOp):
96-
def init_optype(self):
97-
self.op_type = "conv_cudnn"
98104

105+
class TestCudnn(TestConv2dOp):
106+
def init_group(self):
107+
self.groups = 1
99108

100-
class TestCudnn2dWithGroup(TestConv2dOp):
101-
def init_optype(self):
109+
def init_op_type(self):
102110
self.op_type = "conv_cudnn"
103111

104-
def init_groups(self):
112+
113+
class TestCudnnWithGroup(TestConv2dOp):
114+
def init_group(self):
105115
self.groups = 3
106116

117+
def init_op_type(self):
118+
self.op_type = "conv_cudnn"
119+
107120

108121
if __name__ == '__main__':
109122
unittest.main()

0 commit comments

Comments
 (0)