
Commit b9319cc

simplify CuVec implementations and interface for purpose of ND-tensor
1 parent 8f9e223

File tree

5 files changed: +74 −95 lines

src/cuda/constraints.jl

Lines changed: 8 additions & 9 deletions
@@ -2,14 +2,14 @@
 # apply L2 constraint
 ############################################################
 
-function apply_l2_cons!{T <: FloatingPoint}(backend::GPUBackend, blob::CuTensorBlob{T},
+function apply_l2_cons!{T <: FloatingPoint}(backend::GPUBackend, blob::CuTensorBlob{T},
     coef::FloatingPoint, ninputs::Integer, nunits::Integer)
   # we allocate a bit of temporary memory here
   # we could instead also store this in the cons type
   # but that would double the memory footprint of a network
   # which is prohibitive for large models!
-  # --
-  # NOTE stokasto:
+  # --
+  # NOTE stokasto:
   # an even better alternative would be to write
   # a dedicated kernel for normalization
   # but since the weight matrices are usually small
@@ -18,24 +18,23 @@ function apply_l2_cons!{T <: FloatingPoint}(backend::GPUBackend, blob::CuTensorB
   # I also tested using cublas cublasSnorm2 but that was way slower
   # than computing all norms using gemm
   @assert(ninputs*nunits == length(blob))
-  width, height, channels, num = size(blob)
   # allocate
   tmpA = make_blob(backend, T, size(blob)...)
   onesv = make_blob(backend, ones(T, ninputs, 1, 1, 1))
   tmp_norm = make_blob(backend, T, (nunits, 1, 1, 1))
   tmp_norm_host = zeros(T, nunits)
-  # copy blob so that it stays intact
+  # copy blob so that it stays intact
   copy!(tmpA, blob)
 
   # we compute the squared norm of all colums of matrix A as:
   # ||A||^2 = transpose(A .* A) * ones(size(A))
   # square blob inplace
-  CuVec.mul!(backend, T, tmpA.ptr.p, tmpA.ptr.p, width*height, channels, num)
+  CuVec.mul!(backend, T, tmpA.ptr.p, tmpA.ptr.p, length(blob))
   # and reduce via gemv to get the sum
-  CuBLAS.gemm(backend.cublas_ctx, CuBLAS.OP_T, CuBLAS.OP_N, nunits, 1, ninputs,
+  CuBLAS.gemm(backend.cublas_ctx, CuBLAS.OP_T, CuBLAS.OP_N, nunits, 1, ninputs,
       convert(T, 1), tmpA.ptr, ninputs, onesv.ptr, ninputs, convert(T, 0), tmp_norm.ptr, nunits)
-  # copy back for doing the norm size check on the cpu
-  copy!(tmp_norm_host, tmp_norm)
+  # copy back for doing the norm size check on the cpu
+  copy!(tmp_norm_host, tmp_norm)
 
   for i = 1:nunits
     # calculate offset in blob vector
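
The gemm trick in this file is worth spelling out: the squared L2 norm of every column of a matrix A is obtained by squaring elementwise and then reducing each column with a single matrix multiply, ||A[:,j]||^2 = (A .* A)' * ones. A minimal CPU sketch of that reduction, written in modern Julia syntax with hypothetical sizes (the diff itself targets 0.3-era Julia):

    # Squared column norms via one gemm: square elementwise, then reduce.
    ninputs, nunits = 4, 3          # hypothetical sizes
    A = rand(ninputs, nunits)
    sq = A .* A                     # what CuVec.mul! does in place on the GPU
    norms = sq' * ones(ninputs)     # one gemm reduces every column at once
    @assert maximum(abs.(norms .- vec(sum(A .^ 2, dims=1)))) < 1e-12

This is also why the call site can now pass length(blob) instead of (width*height, channels, num): the elementwise square does not depend on the tensor shape at all.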

src/cuda/kernels/elementwise.impl

Lines changed: 28 additions & 31 deletions
@@ -1,25 +1,22 @@
 #define ELEMWISE_BOUNDS_AND_INDEX \
-  int s = threadIdx.x + blockIdx.x * blockDim.x; \
-  int k = threadIdx.y + blockIdx.y * blockDim.y; \
-  int n = threadIdx.z + blockIdx.z * blockDim.z; \
-  if (s >= spatial_dim || k >= channels || n >= num) \
-    return; \
-  int idx = s + spatial_dim * (k + channels * n)
+  int idx = threadIdx.x + blockIdx.x * blockDim.x; \
+  if (idx >= len) \
+    return
 
 template <typename T>
-__device__ void add_scal(T *array, T scal, int spatial_dim, int channels, int num) {
+__device__ void add_scal(T *array, T scal, int len) {
   ELEMWISE_BOUNDS_AND_INDEX;
   array[idx] += scal;
 }
 template <typename T>
-__device__ void mul_scal(T *array, T scal, int spatial_dim, int channels, int num) {
+__device__ void mul_scal(T *array, T scal, int len) {
   ELEMWISE_BOUNDS_AND_INDEX;
   array[idx] *= scal;
 }
 
 #define DEF_ELEMWISE_OP(NAME, OP) \
 template <typename T> \
-__device__ void elem_ ## NAME(T *X, T *Y, int spatial_dim, int channels, int num) { \
+__device__ void elem_ ## NAME(T *X, T *Y, int len) { \
   ELEMWISE_BOUNDS_AND_INDEX; \
   X[idx] = X[idx] OP Y[idx]; \
 }
@@ -30,23 +27,23 @@ DEF_ELEMWISE_OP(sub, -)
 DEF_ELEMWISE_OP(div, /)
 
 template <typename T>
-__device__ void elem_div2(T *X, T *Y, int spatial_dim, int channels, int num) {
+__device__ void elem_div2(T *X, T *Y, int len) {
   ELEMWISE_BOUNDS_AND_INDEX;
   Y[idx] = X[idx] / Y[idx];
 }
 
 template <typename T1, typename T2>
-__device__ void elem_pow(T1 *X, T2 p, int spatial_dim, int channels, int num) {
+__device__ void elem_pow(T1 *X, T2 p, int len) {
   ELEMWISE_BOUNDS_AND_INDEX;
   X[idx] = pow(X[idx], p);
 }
 
 #define DEF_ELEMWISE_API(NAME) \
-__global__ void elem_ ## NAME ## _float(float *X, float *Y, int spatial_dim, int channels, int num) { \
-  elem_##NAME(X, Y, spatial_dim, channels, num); \
+__global__ void elem_ ## NAME ## _float(float *X, float *Y, int len) { \
+  elem_##NAME(X, Y, len); \
 } \
-__global__ void elem_ ## NAME ## _double(double *X, double *Y, int spatial_dim, int channels, int num) { \
-  elem_##NAME(X, Y, spatial_dim, channels, num); \
+__global__ void elem_ ## NAME ## _double(double *X, double *Y, int len) { \
+  elem_##NAME(X, Y, len); \
 }
 
 extern "C" {
@@ -58,31 +55,31 @@ DEF_ELEMWISE_API(div)
 DEF_ELEMWISE_API(div2)
 
 
-__global__ void add_scal_float(float *X, float Y, int spatial_dim, int channels, int num) {
-  add_scal(X, Y, spatial_dim, channels, num);
+__global__ void add_scal_float(float *X, float Y, int len) {
+  add_scal(X, Y, len);
 }
-__global__ void add_scal_double(double *X, double Y, int spatial_dim, int channels, int num) {
-  add_scal(X, Y, spatial_dim, channels, num);
+__global__ void add_scal_double(double *X, double Y, int len) {
+  add_scal(X, Y, len);
 }
 
-__global__ void mul_scal_float(float *X, float Y, int spatial_dim, int channels, int num) {
-  mul_scal(X, Y, spatial_dim, channels, num);
+__global__ void mul_scal_float(float *X, float Y, int len) {
+  mul_scal(X, Y, len);
 }
-__global__ void mul_scal_double(double *X, double Y, int spatial_dim, int channels, int num) {
-  mul_scal(X, Y, spatial_dim, channels, num);
+__global__ void mul_scal_double(double *X, double Y, int len) {
+  mul_scal(X, Y, len);
 }
 
-__global__ void elem_pow_fi(float *X, int p, int spatial_dim, int channels, int num) {
-  elem_pow(X, p, spatial_dim, channels, num);
+__global__ void elem_pow_fi(float *X, int p, int len) {
+  elem_pow(X, p, len);
 }
-__global__ void elem_pow_di(double *X, int p, int spatial_dim, int channels, int num) {
-  elem_pow(X, p, spatial_dim, channels, num);
+__global__ void elem_pow_di(double *X, int p, int len) {
+  elem_pow(X, p, len);
 }
-__global__ void elem_pow_ff(float *X, float p, int spatial_dim, int channels, int num) {
-  elem_pow(X, p, spatial_dim, channels, num);
+__global__ void elem_pow_ff(float *X, float p, int len) {
+  elem_pow(X, p, len);
 }
-__global__ void elem_pow_dd(double *X, double p, int spatial_dim, int channels, int num) {
-  elem_pow(X, p, spatial_dim, channels, num);
+__global__ void elem_pow_dd(double *X, double p, int len) {
+  elem_pow(X, p, len);
 }
 
 } // extern "C"
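
The macro change above is the heart of the simplification: instead of a 3-D thread grid indexed by (spatial_dim, channels, num), every kernel now sees the tensor as a flat array of len elements, which works for any number of dimensions. Nothing is lost in the flattening, since the old index formula enumerates exactly the flat range 0:len-1; a quick Julia check with hypothetical dimensions:

    # The old 3-D index s + spatial_dim*(k + channels*n) hits each flat
    # index in 0:len-1 exactly once, so a 1-D kernel reaches the same elements.
    spatial_dim, channels, num = 30, 7, 8   # hypothetical shape
    len = spatial_dim * channels * num
    idxs = [s + spatial_dim*(k + channels*n)
            for s in 0:spatial_dim-1, k in 0:channels-1, n in 0:num-1]
    @assert sort(vec(idxs)) == collect(0:len-1)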

src/cuda/layers/power.jl

Lines changed: 11 additions & 18 deletions
@@ -7,8 +7,7 @@ function forward(backend::GPUBackend, state::PowerLayerState, inputs::Vector{Blo
     input = inputs[i]
     output = state.blobs[i]
 
-    width, height, channels, num = size(input)
-    spatial_dim = width*height
+    len = length(input)
     data_type = eltype(input)
 
     # output = input
@@ -22,18 +21,17 @@ function forward(backend::GPUBackend, state::PowerLayerState, inputs::Vector{Blo
 
     if state.layer.shift != 0
       # output += shift
-      CuVec.add_scal!(backend, data_type, output.ptr.p, convert(data_type, state.layer.shift),
-          spatial_dim, channels, num)
+      CuVec.add_scal!(backend, data_type, output.ptr.p, convert(data_type, state.layer.shift), len)
     end
 
     # output = output ^ power
     if state.layer.power != 1
       if state.layer.power == 2
-        CuVec.mul!(backend, data_type, output.ptr.p, output.ptr.p, spatial_dim, channels, num)
+        CuVec.mul!(backend, data_type, output.ptr.p, output.ptr.p, len)
       else
         CuVec.pow!(backend, data_type, output.ptr.p,
            isinteger(state.layer.power) ? int(state.layer.power) : convert(data_type, state.layer.power),
-           spatial_dim, channels, num)
+           len)
       end
     end
   end
@@ -45,8 +43,7 @@ function backward(backend::GPUBackend, state::PowerLayerState,
   data_type = eltype(inputs[1])
   pow_scale = convert(data_type,state.layer.power * state.layer.scale)
   for i = 1:length(inputs)
-    width, height, channels, num = size(inputs[i])
-    spatial_dim = width*height
+    len = length(inputs[i])
 
    diff = diffs[i]
    if state.layer.power == 1 || state.layer.scale == 0
@@ -64,32 +61,28 @@ function backward(backend::GPUBackend, state::PowerLayerState,
        CuBLAS.axpy(backend.cublas_ctx, length(input), convert(data_type, pow_scale*state.layer.scale),
            input.ptr, 1, diff.ptr, 1)
        if state.layer.shift != 0
-          CuVec.add_scal!(backend, data_type, diff.ptr.p, pow_scale * state.layer.shift,
-              spatial_dim, channels, num)
+          CuVec.add_scal!(backend, data_type, diff.ptr.p, pow_scale * state.layer.shift, len)
        end
      elseif state.layer.shift == 0
        # dO/dI = power * scale * (scale * I) ^ (power - 1)
        #       = power * O / I
        CuBLAS.axpy(backend.cublas_ctx, length(input), convert(data_type,state.layer.power),
            output.ptr, 1, diff.ptr, 1)
-        CuVec.div!(backend, data_type, diff.ptr.p, input.ptr.p, spatial_dim, channels, num)
+        CuVec.div!(backend, data_type, diff.ptr.p, input.ptr.p, len)
      else
        # general case
        # dO/dI = power * scale * (scale * I + shift) ^ (power - 1)
        #       = power * scale * O / (scale * I + shift)
        copy!(diff, input)
        if state.layer.scale != 1
-          CuBLAS.scal(backend.cublas_ctx, length(diff),
+          CuBLAS.scal(backend.cublas_ctx, length(diff),
              convert(data_type,state.layer.scale), diff.ptr, 1)
        end
-        CuVec.add_scal!(backend, data_type, diff.ptr.p, state.layer.shift,
-            spatial_dim, channels, num)
-        CuVec.div2!(backend, data_type, output.ptr.p, diff.ptr.p,
-            spatial_dim, channels, num)
+        CuVec.add_scal!(backend, data_type, diff.ptr.p, state.layer.shift, len)
+        CuVec.div2!(backend, data_type, output.ptr.p, diff.ptr.p, len)
        CuBLAS.scal(backend.cublas_ctx, length(diff), pow_scale, diff.ptr, 1)
      end
    end
-    CuVec.mul!(backend, data_type, diff.ptr.p, state.blobs_diff[i].ptr.p,
-        spatial_dim, channels, num)
+    CuVec.mul!(backend, data_type, diff.ptr.p, state.blobs_diff[i].ptr.p, len)
  end
 end
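
The backward pass leans on the chain rule for O = (scale*I + shift)^power: dO/dI = power*scale*(scale*I + shift)^(power-1) = power*scale*O/(scale*I + shift), which is why diff can be assembled from the stored output and the shifted input without recomputing the power. A scalar finite-difference check of that identity in Julia (all values hypothetical):

    # Verify dO/dI = power*scale*O / (scale*I + shift) against central differences.
    scale, shift, power, I = 1.5, 0.3, 3.0, 0.7   # hypothetical values
    O = (scale*I + shift)^power
    analytic = power * scale * O / (scale*I + shift)
    h = 1e-6
    numeric = ((scale*(I+h) + shift)^power - (scale*(I-h) + shift)^power) / (2h)
    @assert abs(analytic - numeric) < 1e-5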

src/cuda/utils/math.jl

Lines changed: 22 additions & 31 deletions
@@ -2,50 +2,45 @@ export CuVec
 module CuVec
 using ..Mocha
 
-function cuda_geometry(sp_dim::Int, chann::Int, num::Int)
-  x_block = int(ceil(float64(sp_dim)/CUDA.THREADS_PER_BLOCK_X))
-  y_block = int(ceil(float64(chann)/CUDA.THREADS_PER_BLOCK_Y))
-  z_block = int(ceil(float64(num)/CUDA.THREADS_PER_BLOCK_Z))
-  return ((x_block,y_block,z_block),
-      (CUDA.THREADS_PER_BLOCK_X,CUDA.THREADS_PER_BLOCK_Y,CUDA.THREADS_PER_BLOCK_Z))
+const THREADS_PER_BLOCK = 128
+function cuda_geometry(len::Int)
+  x_block = int(ceil(float64(len)/THREADS_PER_BLOCK))
+  return (x_block, THREADS_PER_BLOCK)
 end
 
 for (ctype, dtype) in [(:float, Float32), (:double, Float64)]
   # define add!, sub!, mul!, div!, div2!
   for name in [:add, :sub, :mul, :div, :div2]
     @eval begin
-      function $(symbol("$(name)!"))(backend::GPUBackend, ::Type{$dtype}, X, Y,
-          spatial_dim::Int, channels::Int, num::Int)
+      function $(symbol("$(name)!"))(backend::GPUBackend, ::Type{$dtype}, X, Y, len::Int)
        X = convert(Ptr{Void},X)
        Y = convert(Ptr{Void},Y)
-        cuda_dim = cuda_geometry(spatial_dim, channels, num)
+        cuda_dim = cuda_geometry(len)
        kernel = backend.mocha.$(symbol("elem_$(name)_$ctype"))
-        CUDA.launch(kernel, cuda_dim..., (X, Y, spatial_dim, channels, num))
+        CUDA.launch(kernel, cuda_dim..., (X, Y, len))
      end
    end
  end
 
  # define add_scal!
  @eval begin
-    function add_scal!(backend::GPUBackend, ::Type{$dtype}, X, Y,
-        spatial_dim::Int, channels::Int, num::Int)
+    function add_scal!(backend::GPUBackend, ::Type{$dtype}, X, Y, len::Int)
      X = convert(Ptr{Void}, X)
      Y = convert($dtype, Y)
-      cuda_dim = cuda_geometry(spatial_dim, channels, num)
+      cuda_dim = cuda_geometry(len)
      kernel = backend.mocha.$(symbol("add_scal_$ctype"))
-      CUDA.launch(kernel, cuda_dim..., (X,Y,spatial_dim,channels,num))
+      CUDA.launch(kernel, cuda_dim..., (X,Y,len))
    end
  end
 
  # define mul_scal!
  @eval begin
-    function mul_scal!(backend::GPUBackend, ::Type{$dtype}, X, Y,
-        spatial_dim::Int, channels::Int, num::Int)
+    function mul_scal!(backend::GPUBackend, ::Type{$dtype}, X, Y, len::Int)
      X = convert(Ptr{Void}, X)
      Y = convert($dtype, Y)
-      cuda_dim = cuda_geometry(spatial_dim, channels, num)
+      cuda_dim = cuda_geometry(len)
      kernel = backend.mocha.$(symbol("mul_scal_$ctype"))
-      CUDA.launch(kernel, cuda_dim..., (X,Y,spatial_dim,channels,num))
+      CUDA.launch(kernel, cuda_dim..., (X,Y,len))
    end
  end
 end
@@ -54,34 +49,30 @@ end
 for name in [:add, :sub, :mul, :div, :div2]
  @eval begin
    function $(symbol("$(name)!")){T}(backend::GPUBackend, X::CuTensorBlob{T}, Y::CuTensorBlob{T})
-      width, height, channels, num = get_whcn(X)
-      sp_dim = width*height
-      $(symbol("$(name)!"))(backend, T, X.ptr.p, Y.ptr.p, sp_dim, channels, num)
+      len = length(X)
+      $(symbol("$(name)!"))(backend, T, X.ptr.p, Y.ptr.p, len)
    end
  end
 end
 function add_scal!{T}(backend::GPUBackend, X::CuTensorBlob{T}, Y)
  Y = convert(T, Y)
-  width, height, channels, num = get_whcn(X)
-  sp_dim = width*height
-  add_scal!(backend, T, X.ptr.p, Y, sp_dim, channels, num)
+  len = length(X)
+  add_scal!(backend, T, X.ptr.p, Y, len)
 end
 function mul_scal!{T}(backend::GPUBackend, X::CuTensorBlob{T}, Y)
  Y = convert(T, Y)
-  width, height, channels, num = get_whcn(X)
-  sp_dim = width*height
-  mul_scal!(backend, T, X.ptr.p, Y, sp_dim, channels, num)
+  len = length(X)
+  mul_scal!(backend, T, X.ptr.p, Y, len)
 end
 
 for (postfix, dt1, dt2) in [(:fi, Float32, Int), (:di, Float64, Int),
    (:ff, Float32, Float32), (:dd, Float64, Float64)]
  @eval begin
-    function pow!(backend::GPUBackend, ::Type{$dt1}, X, Y::$dt2,
-        spatial_dim::Int, channels::Int, num::Int)
+    function pow!(backend::GPUBackend, ::Type{$dt1}, X, Y::$dt2, len::Int)
      X = convert(Ptr{Void}, X)
-      cuda_dim = cuda_geometry(spatial_dim, channels, num)
+      cuda_dim = cuda_geometry(len)
      kernel = backend.mocha.$(symbol("elem_pow_$postfix"))
-      CUDA.launch(kernel, cuda_dim..., (X,Y,spatial_dim,channels,num))
+      CUDA.launch(kernel, cuda_dim..., (X,Y,len))
    end
  end
 end
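
With a flat length, the launch geometry collapses to one ceil-division: enough 128-thread blocks to cover len, with the kernels' idx >= len guard discarding the overhang in the last block. A sketch of the same computation in modern Julia (cld replaces the 0.3-era int(ceil(float64(...))); the 128 matches the constant in the diff):

    # 1-D launch geometry: cld(len, block) blocks of `block` threads each.
    const THREADS_PER_BLOCK = 128
    cuda_geometry(len::Int) = (cld(len, THREADS_PER_BLOCK), THREADS_PER_BLOCK)

    blocks, threads = cuda_geometry(1000)   # (8, 128): 1024 threads cover 1000 elements
    @assert blocks * threads >= 1000 && (blocks - 1) * threads < 1000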

test/cuda/cuvec.jl

Lines changed: 5 additions & 6 deletions
@@ -1,8 +1,7 @@
 function test_cuvec(backend::Backend, T)
   println("-- Testing CuVec Utilities{$T}")
-  width, height, channels, num = (5,6,7,8)
-  spatial_dim = width*height
-  dims = (width, height, channels, num)
+  dims = (5,6,7,8)
+  len = prod(dims)
   eps = 1e-5
 
   X = rand(T, dims)
@@ -12,19 +11,19 @@ function test_cuvec(backend::Backend, T)
 
   println(" > mul!")
   Vec.mul!(X, Y)
-  CuVec.mul!(backend, T, X_blob.ptr.p, Y_blob.ptr.p, spatial_dim, channels, num)
+  CuVec.mul!(backend, T, X_blob.ptr.p, Y_blob.ptr.p, len)
   X2 = similar(X)
   copy!(X2, X_blob)
   @test all(abs(X-X2) .< eps)
 
   println(" > pow!")
   Vec.pow!(X, 2)
-  CuVec.pow!(backend, T, X_blob.ptr.p, 2, spatial_dim, channels, num)
+  CuVec.pow!(backend, T, X_blob.ptr.p, 2, len)
   copy!(X2, X_blob)
   @test all(abs(X-X2) .< eps)
 
   Vec.pow!(X, convert(T, 0.75))
-  CuVec.pow!(backend, T, X_blob.ptr.p, convert(T, 0.75), spatial_dim, channels, num)
+  CuVec.pow!(backend, T, X_blob.ptr.p, convert(T, 0.75), len)
   copy!(X2, X_blob)
   @test all(abs(X-X2) .< eps)
 end

0 commit comments
