@@ -21,19 +21,17 @@ type ChannelPoolingLayerState <: LayerState
2121 etc :: Any
2222end
2323
24- function setup_etc (backend:: CPUBackend , layer:: ChannelPoolingLayer , inputs, pooled_chann )
24+ function setup_etc (backend:: CPUBackend , layer:: ChannelPoolingLayer , inputs, blobs )
2525 if isa (layer. pooling, Pooling. Max)
2626 masks = Array (Array, length (inputs))
2727 for i = 1 : length (inputs)
28- masks[i] = Array (Csize_t, get_width (inputs[i]), get_height (inputs[i]),
29- pooled_chann[i], get_num (inputs[i]))
28+ masks[i] = Array (Csize_t, size (blobs[i]))
3029 end
3130 etc = masks
3231 elseif isa (layer. pooling, Pooling. Mean)
3332 integrals = Array (Array, length (inputs))
3433 for i = 1 : length (inputs)
35- integrals[i] = Array (eltype (inputs[i]), get_width (inputs[i]), get_height (inputs[i]),
36- get_chann (inputs[i]))
34+ integrals[i] = Array (eltype (inputs[i]), size (inputs[i])[1 : end - 1 ])
3735 end
3836 etc = integrals
3937 else
@@ -43,40 +41,43 @@ function setup_etc(backend::CPUBackend, layer::ChannelPoolingLayer, inputs, pool
4341end
4442
4543function setup (backend:: Backend , layer:: ChannelPoolingLayer , inputs:: Vector{Blob} , diffs:: Vector{Blob} )
46- for i = 1 : length (inputs)
47- # currently we only handle 4D-tensor
48- @assert ndims (inputs[i]) == 4
49- end
50-
5144 pooled_chann_all = Array (Int, length (inputs))
5245 blobs = Array (Blob, length (inputs))
5346 blobs_diff = Array (Blob, length (inputs))
5447 op_dims = Array (Int, length (inputs))
5548
5649 for i = 1 : length (inputs)
5750 dim_total = ndims (inputs[i])
58- dim = layer. dim < 0 ? layer. dim + dim_total+ 1 : layer. dim
59- op_dims[i] = dim
51+ op_dim = layer. dim < 0 ? layer. dim + dim_total+ 1 : layer. dim
52+ @assert 1 <= op_dim <= dim_total
53+ @assert op_dim != dim_total
6054
61- width, height, channels, num = size (inputs[i])
62- pooled_chann = int (ceil (float (channels + layer. pad[1 ]+ layer. pad[2 ] - layer. kernel) / layer. stride)) + 1
55+ op_dims[i] = op_dim
56+
57+ dims = [size (inputs[i])... ]
58+ pool_dim = dims[op_dim]
59+ pooled_dim = int (ceil (float (pool_dim + layer. pad[1 ]+ layer. pad[2 ] - layer. kernel) / layer. stride)) + 1
6360
6461 # make sure the last pooling is not purely pooling padded area
65- if ((pooled_chann - 1 )* layer. stride >= channels + layer. pad[1 ])
66- pooled_chann -= 1
62+ if ((pooled_dim - 1 )* layer. stride >= pool_dim + layer. pad[1 ])
63+ pooled_dim -= 1
6764 end
68- pooled_chann_all[i] = pooled_chann
65+ pooled_chann_all[i] = pooled_dim
66+
67+ output_dims = copy (dims)
68+ output_dims[op_dim] = pooled_dim
69+ output_dims = tuple (output_dims... )
6970
7071 data_type = eltype (inputs[i])
71- blobs[i] = make_blob (backend, data_type, (width,height,pooled_chann_all[i],num) )
72+ blobs[i] = make_blob (backend, data_type, output_dims )
7273 if isa (diffs[i], NullBlob)
7374 blobs_diff[i] = NullBlob ()
7475 else
75- blobs_diff[i] = make_blob (backend, data_type, (width,height,pooled_chann_all[i],num) )
76+ blobs_diff[i] = make_blob (backend, data_type, output_dims )
7677 end
7778 end
7879
79- etc = setup_etc (backend, layer, inputs, pooled_chann_all )
80+ etc = setup_etc (backend, layer, inputs, blobs )
8081 state = ChannelPoolingLayerState (layer, blobs, blobs_diff, op_dims, etc)
8182end
8283function shutdown_etc (backend:: CPUBackend , state:: ChannelPoolingLayerState )
0 commit comments