Skip to content

Commit 6ef2da2

Browse files
committed
finetune conv2d navie func
1 parent 92c3944 commit 6ef2da2

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

python/paddle/v2/framework/tests/test_conv2d_op.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@ def conv2d_forward_naive(input, filter, group, conv_param):
1111
sub_out_c = out_c / group
1212

1313
stride, pad = conv_param['stride'], conv_param['pad']
14-
out_h = 1 + (in_h + 2 * pad - f_h) / stride
15-
out_w = 1 + (in_w + 2 * pad - f_w) / stride
14+
out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0]
15+
out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1]
1616
out = np.zeros((in_n, out_c, out_h, out_w))
1717

18-
input_pad = np.pad(input, ((0, ), (0, ), (pad, ), (pad, )),
18+
input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )),
1919
mode='constant',
2020
constant_values=0)
2121
for i in range(out_h):
2222
for j in range(out_w):
2323
for g in range(group):
2424
input_pad_masked = input_pad[:, g * f_c:(
25-
g + 1) * f_c, i * stride:i * stride + f_h, j * stride:j *
26-
stride + f_w]
25+
g + 1) * f_c, i * stride[0]:i * stride[0] + f_h, j * stride[
26+
1]:j * stride[1] + f_w]
2727
f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :]
2828
for k in range(sub_out_c):
2929
out[:, g * sub_out_c + k, i, j] = np.sum(input_pad_masked *
@@ -37,23 +37,22 @@ class TestConv2dOp(OpTest):
3737
def setUp(self):
3838
self.init_groups()
3939
self.op_type = "conv2d"
40+
pad = [0, 0]
41+
stride = [1, 1]
4042
input_size = [2, 3, 5, 5] # NCHW
4143
assert np.mod(input_size[1], self.groups) == 0
4244
f_c = input_size[1] / self.groups
4345
filter_size = [6, f_c, 3, 3]
44-
conv2d_param = {'stride': 1, 'pad': 0}
46+
47+
conv2d_param = {'stride': stride, 'pad': pad}
4548

4649
input = np.random.random(input_size).astype("float32")
4750
filter = np.random.random(filter_size).astype("float32")
4851

4952
output = conv2d_forward_naive(input, filter, self.groups, conv2d_param)
5053

5154
self.inputs = {'Input': input, 'Filter': filter}
52-
self.attrs = {
53-
'strides': [1, 1],
54-
'paddings': [0, 0],
55-
'groups': self.groups
56-
}
55+
self.attrs = {'strides': stride, 'paddings': pad, 'groups': self.groups}
5756
self.outputs = {'Output': output}
5857

5958
def test_check_output(self):

0 commit comments

Comments
 (0)