@@ -219,12 +219,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
219
219
FLOAT * buffer [DIVIDE_RATE ];
220
220
221
221
BLASLONG k , lda , ldb , ldc ;
222
- BLASLONG m_from , m_to , n_from , n_to , N_from , N_to ;
222
+ BLASLONG m_from , m_to , n_from , n_to ;
223
223
224
224
FLOAT * alpha , * beta ;
225
225
FLOAT * a , * b , * c ;
226
226
job_t * job = (job_t * )args -> common ;
227
+ BLASLONG nthreads_m ;
227
228
BLASLONG xxx , bufferside ;
229
+ BLASLONG mypos_m , mypos_n ;
228
230
229
231
BLASLONG ls , min_l , jjs , min_jj ;
230
232
BLASLONG is , min_i , div_n ;
@@ -259,26 +261,28 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
259
261
alpha = (FLOAT * )args -> alpha ;
260
262
beta = (FLOAT * )args -> beta ;
261
263
264
+ nthreads_m = args -> nthreads ;
265
+ if (range_m ) {
266
+ nthreads_m = range_m [-1 ];
267
+ }
268
+
269
+ mypos_m = mypos % nthreads_m ;
270
+ mypos_n = mypos / nthreads_m ;
271
+
262
272
m_from = 0 ;
263
273
m_to = M ;
264
274
265
275
if (range_m ) {
266
- m_from = range_m [0 ];
267
- m_to = range_m [1 ];
276
+ m_from = range_m [mypos_m + 0 ];
277
+ m_to = range_m [mypos_m + 1 ];
268
278
}
269
279
270
280
n_from = 0 ;
271
281
n_to = N ;
272
282
273
- N_from = 0 ;
274
- N_to = N ;
275
-
276
283
if (range_n ) {
277
284
n_from = range_n [mypos + 0 ];
278
285
n_to = range_n [mypos + 1 ];
279
-
280
- N_from = range_n [0 ];
281
- N_to = range_n [args -> nthreads ];
282
286
}
283
287
284
288
if (beta ) {
@@ -287,7 +291,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
287
291
#else
288
292
if ((beta [0 ] != ONE ) || (beta [1 ] != ZERO ))
289
293
#endif
290
- BETA_OPERATION (m_from , m_to , N_from , N_to , beta , c , ldc );
294
+ BETA_OPERATION (m_from , m_to , range_n [ mypos_n * nthreads_m ], range_n [( mypos_n + 1 ) * nthreads_m ] , beta , c , ldc );
291
295
}
292
296
293
297
if ((k == 0 ) || (alpha == NULL )) return 0 ;
@@ -299,8 +303,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
299
303
) return 0 ;
300
304
301
305
#if 0
302
- fprintf (stderr , "Thread[%ld] m_from : %ld m_to : %ld n_from : %ld n_to : %ld N_from : %ld N_to : %ld \n" ,
303
- mypos , m_from , m_to , n_from , n_to , N_from , N_to );
306
+ fprintf (stderr , "Thread[%ld] m_from : %ld m_to : %ld n_from : %ld n_to : %ld\n" ,
307
+ mypos , m_from , m_to , n_from , n_to );
304
308
305
309
fprintf (stderr , "GEMM: P = %4ld Q = %4ld R = %4ld\n" , (BLASLONG )GEMM_P , (BLASLONG )GEMM_Q , (BLASLONG )GEMM_R );
306
310
@@ -394,21 +398,22 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
394
398
}
395
399
#endif
396
400
397
- for (i = 0 ; i < args -> nthreads ; i ++ ) job [mypos ].working [i ][CACHE_LINE_SIZE * bufferside ] = (BLASLONG )buffer [bufferside ];
401
+ for (i = mypos_n * nthreads_m ; i < (mypos_n + 1 ) * nthreads_m ; i ++ )
402
+ job [mypos ].working [i ][CACHE_LINE_SIZE * bufferside ] = (BLASLONG )buffer [bufferside ];
398
403
WMB ;
399
404
}
400
405
401
406
current = mypos ;
402
407
403
408
do {
404
409
current ++ ;
405
- if (current >= args -> nthreads ) current = 0 ;
410
+ if (current >= ( mypos_n + 1 ) * nthreads_m ) current = mypos_n * nthreads_m ;
406
411
407
412
div_n = (range_n [current + 1 ] - range_n [current ] + DIVIDE_RATE - 1 ) / DIVIDE_RATE ;
408
413
409
414
for (xxx = range_n [current ], bufferside = 0 ; xxx < range_n [current + 1 ]; xxx += div_n , bufferside ++ ) {
410
415
411
- if (current != mypos ) {
416
+ if (current != mypos ) {
412
417
413
418
START_RPCC ();
414
419
@@ -479,7 +484,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
479
484
}
480
485
481
486
current ++ ;
482
- if (current >= args -> nthreads ) current = 0 ;
487
+ if (current >= ( mypos_n + 1 ) * nthreads_m ) current = mypos_n * nthreads_m ;
483
488
484
489
} while (current != mypos );
485
490
@@ -525,7 +530,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
525
530
}
526
531
527
532
static int gemm_driver (blas_arg_t * args , BLASLONG * range_m , BLASLONG
528
- * range_n , FLOAT * sa , FLOAT * sb , BLASLONG mypos ){
533
+ * range_n , FLOAT * sa , FLOAT * sb ,
534
+ BLASLONG nthreads_m , BLASLONG nthreads_n ) {
529
535
530
536
blas_arg_t newarg ;
531
537
@@ -537,8 +543,10 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
537
543
538
544
blas_queue_t queue [MAX_CPU_NUMBER ];
539
545
540
- BLASLONG range_M [MAX_CPU_NUMBER + 1 ];
541
- BLASLONG range_N [MAX_CPU_NUMBER + 1 ];
546
+ BLASLONG range_M_buffer [MAX_CPU_NUMBER + 2 ];
547
+ BLASLONG range_N_buffer [MAX_CPU_NUMBER + 2 ];
548
+ BLASLONG * range_M = range_M_buffer + 1 ;
549
+ BLASLONG * range_N = range_N_buffer + 1 ;
542
550
543
551
BLASLONG num_cpu_m , num_cpu_n ;
544
552
@@ -595,6 +603,9 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
595
603
newarg .gemm_r = args -> gemm_r ;
596
604
#endif
597
605
606
+ range_M [-1 ] = nthreads_m ;
607
+ range_N [-1 ] = nthreads_n ;
608
+
598
609
if (!range_m ) {
599
610
range_M [0 ] = 0 ;
600
611
m = args -> m ;
@@ -607,7 +618,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
607
618
608
619
while (m > 0 ){
609
620
610
- width = blas_quickdivide (m + nthreads - num_cpu_m - 1 , nthreads - num_cpu_m );
621
+ width = blas_quickdivide (m + nthreads_m - num_cpu_m - 1 , nthreads_m - num_cpu_m );
611
622
612
623
m -= width ;
613
624
if (m < 0 ) width = width + m ;
@@ -617,12 +628,16 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
617
628
num_cpu_m ++ ;
618
629
}
619
630
620
- for (i = 0 ; i < num_cpu_m ; i ++ ) {
631
+ for (i = num_cpu_m ; i < MAX_CPU_NUMBER ; i ++ ) {
632
+ range_M [i + 1 ] = range_M [num_cpu_m ];
633
+ }
634
+
635
+ for (i = 0 ; i < nthreads ; i ++ ) {
621
636
queue [i ].mode = mode ;
622
637
queue [i ].routine = inner_thread ;
623
638
queue [i ].args = & newarg ;
624
- queue [i ].range_m = & range_M [ i ] ;
625
- queue [i ].range_n = & range_N [ 0 ] ;
639
+ queue [i ].range_m = range_M ;
640
+ queue [i ].range_n = range_N ;
626
641
queue [i ].sa = NULL ;
627
642
queue [i ].sb = NULL ;
628
643
queue [i ].next = & queue [i + 1 ];
@@ -659,17 +674,21 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
659
674
num_cpu_n ++ ;
660
675
}
661
676
662
- for (j = 0 ; j < num_cpu_m ; j ++ ) {
663
- for (i = 0 ; i < num_cpu_m ; i ++ ) {
677
+ for (j = num_cpu_n ; j < MAX_CPU_NUMBER ; j ++ ) {
678
+ range_N [j + 1 ] = range_N [num_cpu_n ];
679
+ }
680
+
681
+ for (j = 0 ; j < MAX_CPU_NUMBER ; j ++ ) {
682
+ for (i = 0 ; i < MAX_CPU_NUMBER ; i ++ ) {
664
683
for (k = 0 ; k < DIVIDE_RATE ; k ++ ) {
665
684
job [j ].working [i ][CACHE_LINE_SIZE * k ] = 0 ;
666
685
}
667
686
}
668
687
}
669
688
670
- queue [num_cpu_m - 1 ].next = NULL ;
689
+ queue [nthreads - 1 ].next = NULL ;
671
690
672
- exec_blas (num_cpu_m , queue );
691
+ exec_blas (nthreads , queue );
673
692
}
674
693
675
694
#ifdef USE_ALLOC_HEAP
@@ -684,6 +703,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
684
703
BLASLONG m = args -> m ;
685
704
BLASLONG n = args -> n ;
686
705
BLASLONG nthreads = args -> nthreads ;
706
+ BLASLONG nthreads_m , nthreads_n ;
687
707
688
708
if (nthreads == 1 ) {
689
709
GEMM_LOCAL (args , range_m , range_n , sa , sb , 0 );
@@ -704,21 +724,31 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
704
724
n = n_to - n_from ;
705
725
}
706
726
707
- if ((m < 2 * SWITCH_RATIO ) || (n < 2 * SWITCH_RATIO )) {
727
+ nthreads_m = nthreads ;
728
+ while (m < nthreads_m * SWITCH_RATIO ) {
729
+ nthreads_m = nthreads_m / 2 ;
730
+ }
731
+
732
+ if (nthreads_m < 1 ) {
708
733
GEMM_LOCAL (args , range_m , range_n , sa , sb , 0 );
709
734
return 0 ;
710
735
}
711
736
712
- if (m < nthreads * SWITCH_RATIO ) {
713
- nthreads = blas_quickdivide (m , SWITCH_RATIO );
737
+ nthreads_n = nthreads / nthreads_m ;
738
+ if (n < nthreads_m * (nthreads_n - 1 )) {
739
+ nthreads_n = (n + nthreads_m - 1 ) / nthreads_m ;
714
740
}
715
- if (n < nthreads * SWITCH_RATIO ) {
716
- nthreads = blas_quickdivide (n , SWITCH_RATIO );
741
+
742
+ nthreads = nthreads_m * nthreads_n ;
743
+
744
+ if (nthreads <= 1 ) {
745
+ GEMM_LOCAL (args , range_m , range_n , sa , sb , 0 );
746
+ return 0 ;
717
747
}
718
748
719
749
args -> nthreads = nthreads ;
720
750
721
- gemm_driver (args , range_m , range_n , sa , sb , 0 );
751
+ gemm_driver (args , range_m , range_n , sa , sb , nthreads_m , nthreads_n );
722
752
723
753
return 0 ;
724
754
}
0 commit comments