Skip to content

Commit c1fd5af

Browse files
committed
refactor CUDA kernel defs
1 parent 7d662a0 commit c1fd5af

File tree

3 files changed

+108
-154
lines changed

3 files changed

+108
-154
lines changed

src/cuda/backend.jl

Lines changed: 105 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -1,158 +1,115 @@
11
export GPUBackend
22

# Generate the `MochaKernels` type and its zero-argument constructor from a
# flat list of kernel names.  For every name `k` in `kernels` the macro emits
#
#   * a field declaration `k :: CUDA.CuFunction` in the type body, and
#   * a constructor statement `kernels.k = CUDA.CuFunction(mod, "k")`
#
# so registering a new CUDA kernel only requires appending its name to the
# `@defkernels` call below — no hand-written field/init pair to keep in sync.
#
# NOTE(review): `kernels` and `mod` are spelled the same way here and inside
# the quoted constructor on purpose; macro hygiene renames them consistently
# across the whole returned expression, keeping the interpolated init block
# bound to the constructor's locals.
macro defkernels(kernels...)
  # One field declaration per kernel name.
  field_defs = [:($kname :: CUDA.CuFunction) for kname in kernels]
  type_body = Expr(:block, field_defs...)

  # One constructor assignment per kernel name, looking the function up in
  # the compiled PTX module by its string name.
  field_inits = [:(kernels.$kname = CUDA.CuFunction(mod, $(string(kname)))) for kname in kernels]
  field_init_block = Expr(:block, field_inits...)

  quote
    type $(esc(:MochaKernels))
      mod :: CUDA.CuModule

      $type_body

      $(esc(:MochaKernels))() = begin
        mod_dir = joinpath(dirname(@__FILE__), "kernels")
        mod_path = joinpath(mod_dir, "kernels.ptx")

        # Refuse to run with a missing or stale kernel module: the PTX file
        # must exist and be newer than every *.impl source it was built from.
        if !isfile(mod_path)
          error("Mocha CUDA kernels not found, see the documents of BACKEND on how to compile the kernels")
        else
          mod_mtime = stat(mod_path).mtime
          impl_files = glob(mod_dir, r".*.impl$")
          for impl_file in impl_files
            if stat(joinpath(mod_dir, impl_file)).mtime > mod_mtime
              error("Mocha CUDA kernels not up-to-date. Please re-compile (see documents of BACKEND)")
            end
          end
        end

        mod = CUDA.CuModule(mod_path)
        kernels = new(mod)

        $field_init_block

        return kernels
      end
    end
  end
end
# Instantiate the MochaKernels type.  Names appear in float/double pairs;
# the order below fixes the field order of the generated type.
@defkernels(
  # loss, activation and accuracy kernels
  logistic_loss_forward_float, logistic_loss_forward_double,
  softmax_loss_backward_float, softmax_loss_backward_double,
  relu_forward_float, relu_forward_double,
  relu_backward_float, relu_backward_double,
  sigmoid_forward_float, sigmoid_forward_double,
  sigmoid_backward_float, sigmoid_backward_double,
  accuracy_forward_float, accuracy_forward_double,
  argmax_forward_float, argmax_forward_double,

  # element-wise arithmetic kernels
  add_scal_float, add_scal_double,
  mul_scal_float, mul_scal_double,
  elem_add_float, elem_add_double,
  elem_mul_float, elem_mul_double,
  elem_sub_float, elem_sub_double,
  elem_div_float, elem_div_double,
  elem_div2_float, elem_div2_double,
  elem_pow_fi, elem_pow_di,
  elem_pow_ff, elem_pow_dd,

  # channel pooling kernels
  max_channel_pooling_forward_float, max_channel_pooling_forward_double,
  max_channel_pooling_backward_float, max_channel_pooling_backward_double,

  # padding / layout conversion kernels
  dense_to_padded_float, dense_to_padded_double,
  padded_to_dense_float, padded_to_dense_double,
  copy_to_shifted_float, copy_to_shifted_double,
  copy_from_shifted_float, copy_from_shifted_double,

  # dropout kernels
  dropout_init, dropout_alloc_size,
  dropout_forward_float, dropout_forward_double,
  dropout_backward_float, dropout_backward_double,

  # L1 regularization kernels
  l1_forward_float, l1_forward_double,
  l1_backward_float, l1_backward_double,
)
112+
156113
# Tear down the kernel state: unload the CUDA module backing all of the
# CuFunction handles held by `mocha`.
function shutdown(mocha :: MochaKernels)
  CUDA.unload(mocha.mod)
end

test/layers/channel-pooling.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction,
2222
payloads = Array(Any, n_input)
2323
for i = 1:n_input
2424
expected_output, payloads[i] = channel_pooling_forward(state, i, input[i])
25-
got_output = similar(expected_output)
26-
copy!(got_output, state.blobs[i])
25+
got_output = to_array(state.blobs[i])
2726
@test all(-eps .< expected_output-got_output .< eps)
2827
end
2928

@@ -36,8 +35,7 @@ function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction,
3635

3736
for i = 1:n_input
3837
expected_output = channel_pooling_backward(state, i, input[i], top_diff[i], payloads[i])
39-
got_output = similar(expected_output)
40-
copy!(got_output, diffs[i])
38+
got_output = to_array(diffs[i])
4139
@test all(-eps .< expected_output - got_output .< eps)
4240
end
4341

test/layers/concat.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ function test_concat_layer(backend::Backend, dim, T, eps)
1919
forward(backend, state, input_blobs)
2020

2121
expected_output = cat(dim, inputs...)
22-
got_output = similar(expected_output)
23-
copy!(got_output, state.blobs[1])
22+
got_output = to_array(state.blobs[1])
2423
@test all(abs(expected_output-got_output) .< eps)
2524

2625
println(" > Backward")

0 commit comments

Comments
 (0)