Skip to content

Commit 6374831

Browse files
committed
fix conflict
2 parents a357bd6 + 1e6c992 commit 6374831

26 files changed

+2102
-29
lines changed

paddle/cuda/include/hl_matrix.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,80 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
224224
extern void hl_matrix_rotate(
225225
real* mat, real* matRot, int dimM, int dimN, bool clockWise);
226226

227+
/**
228+
* @brief Matrix vol2Col: Convert 3D volume into col matrix
229+
*
230+
* @param[in] matSrc input matrix.
231+
* @param[in] channel channel of matSrc.
232+
* @param[in] depth depth of matSrc.
233+
* @param[in] height height of matSrc.
234+
* @param[in] width width of matSrc.
235+
* @param[in] filterD depth of filter.
236+
* @param[in] filterH height of filter.
237+
* @param[in] filterW width of filter.
238+
* @param[in] strideD stride in the depth.
239+
* @param[in] strideH stride in the height.
240+
* @param[in] strideW stride in the width.
241+
* @param[in] paddingD padding in the depth.
242+
* @param[in] paddingH padding in the height.
243+
* @param[in] paddingW padding in the width.
244+
* @param[out] dataDst output matrix.
245+
*
246+
*/
247+
extern void hl_matrix_vol2Col(const real* dataSrc,
248+
int channels,
249+
int depth,
250+
int height,
251+
int width,
252+
int filterD,
253+
int filterH,
254+
int filterW,
255+
int strideD,
256+
int strideH,
257+
int strideW,
258+
int paddingD,
259+
int paddingH,
260+
int paddingW,
261+
real* dataDst);
262+
263+
/**
264+
* @brief Matrix col2Vol: Convert col matrix into 3D volume
265+
*
266+
* @param[out] matDst output matrix.
267+
* @param[in] channel channel of matDst.
268+
* @param[in] depth depth of matDst.
269+
* @param[in] height height of matDst.
270+
* @param[in] width width of matDst.
271+
* @param[in] filterD depth of filter.
272+
* @param[in] filterH height of filter.
273+
* @param[in] filterW width of filter.
274+
* @param[in] strideD stride in the depth.
275+
* @param[in] strideH stride in the height.
276+
* @param[in] strideW stride in the width.
277+
* @param[in] paddingD padding in the depth.
278+
* @param[in] paddingH padding in the height.
279+
* @param[in] paddingW padding in the width.
280+
* @param[in] matSrc input matrix.
281+
* @param[in] beta input
282+
* @param[in] alpha input
283+
*
284+
*/
285+
extern void hl_matrix_col2Vol(real* dataDst,
286+
int channels,
287+
int depth,
288+
int height,
289+
int width,
290+
int filterD,
291+
int filterH,
292+
int filterW,
293+
int strideD,
294+
int strideH,
295+
int strideW,
296+
int paddingD,
297+
int paddingH,
298+
int paddingW,
299+
const real* dataSrc,
300+
real alpha,
301+
real beta);
302+
227303
#endif /* HL_MATRIX_H_ */

paddle/cuda/include/stub/hl_matrix_stub.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
9999
inline void hl_matrix_rotate(
100100
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
101101

102+
inline void hl_matrix_vol2Col(const real* dataSrc,
103+
int channels,
104+
int depth,
105+
int height,
106+
int width,
107+
int filterD,
108+
int filterH,
109+
int filterW,
110+
int strideD,
111+
int strideH,
112+
int strideW,
113+
int paddingD,
114+
int paddingH,
115+
int paddingW,
116+
real* dataDst) {}
117+
118+
inline void hl_matrix_col2Vol(real* dataDst,
119+
int channels,
120+
int depth,
121+
int height,
122+
int width,
123+
int filterD,
124+
int filterH,
125+
int filterW,
126+
int strideD,
127+
int strideH,
128+
int strideW,
129+
int paddingD,
130+
int paddingH,
131+
int paddingW,
132+
const real* dataSrc,
133+
real alpha,
134+
real beta) {}
135+
102136
#endif // HL_MATRIX_STUB_H_

paddle/cuda/src/hl_cuda_matrix.cu

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,204 @@ void hl_matrix_rotate(
592592
mat, matRot, dimM, dimN, clockWise);
593593
CHECK_SYNC("hl_matrix_rotate failed");
594594
}
595+
596+
__global__ void keMatrixVol2Col(int num_kernels,
597+
const real* dataSrc,
598+
real* dataDst,
599+
int depth,
600+
int height,
601+
int width,
602+
int filterD,
603+
int filterH,
604+
int filterW,
605+
int strideD,
606+
int strideH,
607+
int strideW,
608+
int paddingD,
609+
int paddingH,
610+
int paddingW,
611+
int depth_col,
612+
int height_col,
613+
int width_col) {
614+
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
615+
index += blockDim.x * gridDim.x) {
616+
int w_out = index % width_col;
617+
int h_out = (index / width_col) % height_col;
618+
int d_out = (index / width_col / height_col) % depth_col;
619+
int channel_in = index / width_col / height_col / depth_col;
620+
int channel_out = channel_in * filterD * filterH * filterW;
621+
int w_in = w_out * strideW - paddingW;
622+
int h_in = h_out * strideH - paddingH;
623+
int d_in = d_out * strideD - paddingD;
624+
625+
dataDst +=
626+
((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
627+
w_out;
628+
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
629+
for (int k = 0; k < filterD; ++k) {
630+
for (int i = 0; i < filterH; ++i) {
631+
for (int j = 0; j < filterW; ++j) {
632+
int d = d_in + k;
633+
int h = h_in + i;
634+
int w = w_in + j;
635+
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
636+
w < width)
637+
? dataSrc[(k * height + i) * width + j]
638+
: 0;
639+
dataDst += depth_col * height_col * width_col;
640+
}
641+
}
642+
}
643+
}
644+
}
645+
646+
void hl_matrix_vol2Col(const real* dataSrc,
647+
int channels,
648+
int depth,
649+
int height,
650+
int width,
651+
int filterD,
652+
int filterH,
653+
int filterW,
654+
int strideD,
655+
int strideH,
656+
int strideW,
657+
int paddingD,
658+
int paddingH,
659+
int paddingW,
660+
real* dataDst) {
661+
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
662+
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
663+
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
664+
int num_kernels = channels * depth_col * height_col * width_col;
665+
666+
const int threads = 512;
667+
const int blocks = DIVUP(num_kernels, threads);
668+
669+
keMatrixVol2Col<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
670+
dataSrc,
671+
dataDst,
672+
depth,
673+
height,
674+
width,
675+
filterD,
676+
filterH,
677+
filterW,
678+
strideD,
679+
strideH,
680+
strideW,
681+
paddingD,
682+
paddingH,
683+
paddingW,
684+
depth_col,
685+
height_col,
686+
width_col);
687+
CHECK_SYNC("hl_matrix_vol2Col failed");
688+
}
689+
690+
__global__ void keMatrixCol2Vol(int num_kernels,
691+
real* dataDst,
692+
const real* dataSrc,
693+
int depth,
694+
int height,
695+
int width,
696+
int filterD,
697+
int filterH,
698+
int filterW,
699+
int strideD,
700+
int strideH,
701+
int strideW,
702+
int paddingD,
703+
int paddingH,
704+
int paddingW,
705+
int depth_col,
706+
int height_col,
707+
int width_col,
708+
real alpha,
709+
real beta) {
710+
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
711+
index += blockDim.x * gridDim.x) {
712+
real srcVal = 0;
713+
real dstVal = dataDst[index];
714+
int w = index % width + paddingW;
715+
int h = (index / width) % height + paddingH;
716+
int d = (index / width / height) % depth + paddingD;
717+
int c = index / width / height / depth;
718+
// compute the start and end of the output
719+
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
720+
int w_col_end = min(w / strideW + 1, width_col);
721+
int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1;
722+
int h_col_end = min(h / strideH + 1, height_col);
723+
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
724+
int d_col_end = min(d / strideD + 1, depth_col);
725+
726+
int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
727+
h * filterW + w) *
728+
depth_col * height_col * width_col;
729+
730+
int coeff_d_col =
731+
(1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
732+
int coeff_h_col =
733+
(1 - strideH * filterW * depth_col * height_col) * width_col;
734+
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
735+
736+
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
737+
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
738+
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
739+
srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
740+
w_col * coeff_w_col];
741+
}
742+
}
743+
}
744+
dataDst[index] = alpha * srcVal + beta * dstVal;
745+
}
746+
}
747+
748+
void hl_matrix_col2Vol(real* dataDst,
749+
int channels,
750+
int depth,
751+
int height,
752+
int width,
753+
int filterD,
754+
int filterH,
755+
int filterW,
756+
int strideD,
757+
int strideH,
758+
int strideW,
759+
int paddingD,
760+
int paddingH,
761+
int paddingW,
762+
const real* dataSrc,
763+
real alpha,
764+
real beta) {
765+
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
766+
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
767+
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
768+
int num_kernels = channels * depth * height * width;
769+
770+
const int threads = 512;
771+
const int blocks = DIVUP(num_kernels, threads);
772+
773+
keMatrixCol2Vol<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
774+
dataDst,
775+
dataSrc,
776+
depth,
777+
height,
778+
width,
779+
filterD,
780+
filterH,
781+
filterW,
782+
strideD,
783+
strideH,
784+
strideW,
785+
paddingD,
786+
paddingH,
787+
paddingW,
788+
depth_col,
789+
height_col,
790+
width_col,
791+
alpha,
792+
beta);
793+
794+
CHECK_SYNC("hl_matrix_col2Vol failed");
795+
}

0 commit comments

Comments
 (0)