Skip to content

Commit 30486a3

Browse files
author
Tim Moon
committed
Reduce number of data partitions in n.
1 parent 9de52b4 commit 30486a3

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

driver/level3/level3_thread.c

Lines changed: 20 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -525,7 +525,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
525525
BLASLONG range_M_buffer[MAX_CPU_NUMBER + 2];
526526
BLASLONG range_N_buffer[MAX_CPU_NUMBER + 2];
527527
BLASLONG *range_M, *range_N;
528-
BLASLONG num_cpu_m, num_cpu_n;
528+
BLASLONG num_parts;
529529

530530
BLASLONG nthreads = args -> nthreads;
531531

@@ -596,16 +596,16 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
596596
}
597597

598598
/* Partition m into nthreads_m regions */
599-
num_cpu_m = 0;
599+
num_parts = 0;
600600
while (m > 0){
601-
width = blas_quickdivide(m + nthreads_m - num_cpu_m - 1, nthreads_m - num_cpu_m);
601+
width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts);
602602
m -= width;
603603
if (m < 0) width = width + m;
604-
range_M[num_cpu_m + 1] = range_M[num_cpu_m] + width;
605-
num_cpu_m ++;
604+
range_M[num_parts + 1] = range_M[num_parts] + width;
605+
num_parts ++;
606606
}
607-
for (i = num_cpu_m; i < MAX_CPU_NUMBER; i++) {
608-
range_M[i + 1] = range_M[num_cpu_m];
607+
for (i = num_parts; i < MAX_CPU_NUMBER; i++) {
608+
range_M[i + 1] = range_M[num_parts];
609609
}
610610

611611
/* Initialize parameters for parallel execution */
@@ -637,16 +637,19 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
637637

638638
/* Partition (a step of) n into nthreads regions */
639639
range_N[0] = js;
640-
num_cpu_n = 0;
640+
num_parts = 0;
641641
while (n > 0){
642-
width = blas_quickdivide(n + nthreads - num_cpu_n - 1, nthreads - num_cpu_n);
642+
width = blas_quickdivide(n + nthreads - num_parts - 1, nthreads - num_parts);
643+
if (width < SWITCH_RATIO) {
644+
width = SWITCH_RATIO;
645+
}
643646
n -= width;
644647
if (n < 0) width = width + n;
645-
range_N[num_cpu_n + 1] = range_N[num_cpu_n] + width;
646-
num_cpu_n ++;
648+
range_N[num_parts + 1] = range_N[num_parts] + width;
649+
num_parts ++;
647650
}
648-
for (j = num_cpu_n; j < MAX_CPU_NUMBER; j++) {
649-
range_N[j + 1] = range_N[num_cpu_n];
651+
for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
652+
range_N[j + 1] = range_N[num_parts];
650653
}
651654

652655
/* Clear synchronization flags */
@@ -683,7 +686,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
683686
n = range_n[1] - range_n[0];
684687
}
685688

686-
/* CPU partitions in m should have at least SWITCH_RATIO rows */
689+
/* Partitions in m should have at least SWITCH_RATIO rows */
687690
if (m < 2 * SWITCH_RATIO) {
688691
nthreads_m = 1;
689692
} else {
@@ -693,11 +696,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
693696
}
694697
}
695698

696-
/* At most one CPU partition in n should have less than nthreads_m columns */
697-
if (n < nthreads_m) {
699+
/* Partitions in n should have at most SWITCH_RATIO * nthreads_m columns */
700+
if (n < SWITCH_RATIO * nthreads_m) {
698701
nthreads_n = 1;
699702
} else {
700-
nthreads_n = blas_quickdivide(n + nthreads_m - 1, nthreads_m);
703+
nthreads_n = (n + SWITCH_RATIO * nthreads_m - 1) / (SWITCH_RATIO * nthreads_m);
701704
if (nthreads_m * nthreads_n > args -> nthreads) {
702705
nthreads_n = blas_quickdivide(args -> nthreads, nthreads_m);
703706
}

0 commit comments

Comments
 (0)