@@ -525,7 +525,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
525
525
BLASLONG range_M_buffer [MAX_CPU_NUMBER + 2 ];
526
526
BLASLONG range_N_buffer [MAX_CPU_NUMBER + 2 ];
527
527
BLASLONG * range_M , * range_N ;
528
- BLASLONG num_cpu_m , num_cpu_n ;
528
+ BLASLONG num_parts ;
529
529
530
530
BLASLONG nthreads = args -> nthreads ;
531
531
@@ -596,16 +596,16 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
596
596
}
597
597
598
598
/* Partition m into nthreads_m regions */
599
- num_cpu_m = 0 ;
599
+ num_parts = 0 ;
600
600
while (m > 0 ){
601
- width = blas_quickdivide (m + nthreads_m - num_cpu_m - 1 , nthreads_m - num_cpu_m );
601
+ width = blas_quickdivide (m + nthreads_m - num_parts - 1 , nthreads_m - num_parts );
602
602
m -= width ;
603
603
if (m < 0 ) width = width + m ;
604
- range_M [num_cpu_m + 1 ] = range_M [num_cpu_m ] + width ;
605
- num_cpu_m ++ ;
604
+ range_M [num_parts + 1 ] = range_M [num_parts ] + width ;
605
+ num_parts ++ ;
606
606
}
607
- for (i = num_cpu_m ; i < MAX_CPU_NUMBER ; i ++ ) {
608
- range_M [i + 1 ] = range_M [num_cpu_m ];
607
+ for (i = num_parts ; i < MAX_CPU_NUMBER ; i ++ ) {
608
+ range_M [i + 1 ] = range_M [num_parts ];
609
609
}
610
610
611
611
/* Initialize parameters for parallel execution */
@@ -637,16 +637,19 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
637
637
638
638
/* Partition (a step of) n into at most nthreads regions; every region except possibly the last is at least SWITCH_RATIO columns wide */
639
639
range_N [0 ] = js ;
640
- num_cpu_n = 0 ;
640
+ num_parts = 0 ;
641
641
while (n > 0 ){
642
- width = blas_quickdivide (n + nthreads - num_cpu_n - 1 , nthreads - num_cpu_n );
642
+ width = blas_quickdivide (n + nthreads - num_parts - 1 , nthreads - num_parts );
643
+ if (width < SWITCH_RATIO ) {
644
+ width = SWITCH_RATIO ;
645
+ }
643
646
n -= width ;
644
647
if (n < 0 ) width = width + n ;
645
- range_N [num_cpu_n + 1 ] = range_N [num_cpu_n ] + width ;
646
- num_cpu_n ++ ;
648
+ range_N [num_parts + 1 ] = range_N [num_parts ] + width ;
649
+ num_parts ++ ;
647
650
}
648
- for (j = num_cpu_n ; j < MAX_CPU_NUMBER ; j ++ ) {
649
- range_N [j + 1 ] = range_N [num_cpu_n ];
651
+ for (j = num_parts ; j < MAX_CPU_NUMBER ; j ++ ) {
652
+ range_N [j + 1 ] = range_N [num_parts ];
650
653
}
651
654
652
655
/* Clear synchronization flags */
@@ -683,7 +686,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
683
686
n = range_n [1 ] - range_n [0 ];
684
687
}
685
688
686
- /* CPU partitions in m should have at least SWITCH_RATIO rows */
689
+ /* Partitions in m should have at least SWITCH_RATIO rows */
687
690
if (m < 2 * SWITCH_RATIO ) {
688
691
nthreads_m = 1 ;
689
692
} else {
@@ -693,11 +696,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
693
696
}
694
697
}
695
698
696
- /* At most one CPU partition in n should have less than nthreads_m columns */
697
- if (n < nthreads_m ) {
699
+ /* Partitions in n should have at most SWITCH_RATIO * nthreads_m columns */
700
+ if (n < SWITCH_RATIO * nthreads_m ) {
698
701
nthreads_n = 1 ;
699
702
} else {
700
- nthreads_n = blas_quickdivide (n + nthreads_m - 1 , nthreads_m );
703
+ nthreads_n = (n + SWITCH_RATIO * nthreads_m - 1 ) / ( SWITCH_RATIO * nthreads_m );
701
704
if (nthreads_m * nthreads_n > args -> nthreads ) {
702
705
nthreads_n = blas_quickdivide (args -> nthreads , nthreads_m );
703
706
}
0 commit comments