@@ -592,3 +592,204 @@ void hl_matrix_rotate(
592592 mat, matRot, dimM, dimN, clockWise);
593593 CHECK_SYNC (" hl_matrix_rotate failed" );
594594}
595+
596+ __global__ void keMatrixVol2Col (int num_kernels,
597+ const real* dataSrc,
598+ real* dataDst,
599+ int depth,
600+ int height,
601+ int width,
602+ int filterD,
603+ int filterH,
604+ int filterW,
605+ int strideD,
606+ int strideH,
607+ int strideW,
608+ int paddingD,
609+ int paddingH,
610+ int paddingW,
611+ int depth_col,
612+ int height_col,
613+ int width_col) {
614+ for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < num_kernels;
615+ index += blockDim .x * gridDim .x ) {
616+ int w_out = index % width_col;
617+ int h_out = (index / width_col) % height_col;
618+ int d_out = (index / width_col / height_col) % depth_col;
619+ int channel_in = index / width_col / height_col / depth_col;
620+ int channel_out = channel_in * filterD * filterH * filterW;
621+ int w_in = w_out * strideW - paddingW;
622+ int h_in = h_out * strideH - paddingH;
623+ int d_in = d_out * strideD - paddingD;
624+
625+ dataDst +=
626+ ((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
627+ w_out;
628+ dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
629+ for (int k = 0 ; k < filterD; ++k) {
630+ for (int i = 0 ; i < filterH; ++i) {
631+ for (int j = 0 ; j < filterW; ++j) {
632+ int d = d_in + k;
633+ int h = h_in + i;
634+ int w = w_in + j;
635+ *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
636+ w < width)
637+ ? dataSrc[(k * height + i) * width + j]
638+ : 0 ;
639+ dataDst += depth_col * height_col * width_col;
640+ }
641+ }
642+ }
643+ }
644+ }
645+
646+ void hl_matrix_vol2Col (const real* dataSrc,
647+ int channels,
648+ int depth,
649+ int height,
650+ int width,
651+ int filterD,
652+ int filterH,
653+ int filterW,
654+ int strideD,
655+ int strideH,
656+ int strideW,
657+ int paddingD,
658+ int paddingH,
659+ int paddingW,
660+ real* dataDst) {
661+ int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1 ;
662+ int height_col = (height + 2 * paddingH - filterH) / strideH + 1 ;
663+ int width_col = (width + 2 * paddingW - filterW) / strideW + 1 ;
664+ int num_kernels = channels * depth_col * height_col * width_col;
665+
666+ const int threads = 512 ;
667+ const int blocks = DIVUP (num_kernels, threads);
668+
669+ keMatrixVol2Col<<<blocks, threads, 0 , STREAM_DEFAULT>>> (num_kernels,
670+ dataSrc,
671+ dataDst,
672+ depth,
673+ height,
674+ width,
675+ filterD,
676+ filterH,
677+ filterW,
678+ strideD,
679+ strideH,
680+ strideW,
681+ paddingD,
682+ paddingH,
683+ paddingW,
684+ depth_col,
685+ height_col,
686+ width_col);
687+ CHECK_SYNC (" hl_matrix_vol2Col failed" );
688+ }
689+
690+ __global__ void keMatrixCol2Vol (int num_kernels,
691+ real* dataDst,
692+ const real* dataSrc,
693+ int depth,
694+ int height,
695+ int width,
696+ int filterD,
697+ int filterH,
698+ int filterW,
699+ int strideD,
700+ int strideH,
701+ int strideW,
702+ int paddingD,
703+ int paddingH,
704+ int paddingW,
705+ int depth_col,
706+ int height_col,
707+ int width_col,
708+ real alpha,
709+ real beta) {
710+ for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < num_kernels;
711+ index += blockDim .x * gridDim .x ) {
712+ real srcVal = 0 ;
713+ real dstVal = dataDst[index];
714+ int w = index % width + paddingW;
715+ int h = (index / width) % height + paddingH;
716+ int d = (index / width / height) % depth + paddingD;
717+ int c = index / width / height / depth;
718+ // compute the start and end of the output
719+ int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1 ;
720+ int w_col_end = min (w / strideW + 1 , width_col);
721+ int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1 ;
722+ int h_col_end = min (h / strideH + 1 , height_col);
723+ int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1 ;
724+ int d_col_end = min (d / strideD + 1 , depth_col);
725+
726+ int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
727+ h * filterW + w) *
728+ depth_col * height_col * width_col;
729+
730+ int coeff_d_col =
731+ (1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
732+ int coeff_h_col =
733+ (1 - strideH * filterW * depth_col * height_col) * width_col;
734+ int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
735+
736+ for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
737+ for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
738+ for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
739+ srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
740+ w_col * coeff_w_col];
741+ }
742+ }
743+ }
744+ dataDst[index] = alpha * srcVal + beta * dstVal;
745+ }
746+ }
747+
748+ void hl_matrix_col2Vol (real* dataDst,
749+ int channels,
750+ int depth,
751+ int height,
752+ int width,
753+ int filterD,
754+ int filterH,
755+ int filterW,
756+ int strideD,
757+ int strideH,
758+ int strideW,
759+ int paddingD,
760+ int paddingH,
761+ int paddingW,
762+ const real* dataSrc,
763+ real alpha,
764+ real beta) {
765+ int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1 ;
766+ int height_col = (height + 2 * paddingH - filterH) / strideH + 1 ;
767+ int width_col = (width + 2 * paddingW - filterW) / strideW + 1 ;
768+ int num_kernels = channels * depth * height * width;
769+
770+ const int threads = 512 ;
771+ const int blocks = DIVUP (num_kernels, threads);
772+
773+ keMatrixCol2Vol<<<blocks, threads, 0 , STREAM_DEFAULT>>> (num_kernels,
774+ dataDst,
775+ dataSrc,
776+ depth,
777+ height,
778+ width,
779+ filterD,
780+ filterH,
781+ filterW,
782+ strideD,
783+ strideH,
784+ strideW,
785+ paddingD,
786+ paddingH,
787+ paddingW,
788+ depth_col,
789+ height_col,
790+ width_col,
791+ alpha,
792+ beta);
793+
794+ CHECK_SYNC (" hl_matrix_col2Vol failed" );
795+ }
0 commit comments