@@ -480,3 +480,98 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
480
480
// CHECK: scf.reduce
481
481
// CHECK: }
482
482
// CHECK: memref.dealloc [[SUM]]
483
+
484
+ // -----
485
+
486
// Two scf.parallel loops over the same 2x2 iteration space. The first copies
// %B[row, col] into %scratch[row, row + col]; the second reads that same
// %scratch slot back, multiplies it by %A[row, col] and stores the product to
// %B[row, col]. Each loop derives the second %scratch index from an identical
// affine.apply over its own induction variables, and the FileCheck
// expectations that follow this function match one merged scf.parallel.
func.func @fuse_same_indices_by_affine_apply(%A: memref<2x2xf32>,
                                             %B: memref<2x2xf32>) {
  // Loop bounds and step shared by both loops.
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = arith.constant 2 : index

  // Intermediate buffer; row + col takes the values 0..2, hence 3 columns.
  %scratch = memref.alloc() : memref<2x3xf32>

  // Producer: %scratch[row, row + col] = %B[row, col].
  scf.parallel (%row, %col) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    %b_val = memref.load %B[%row, %col] : memref<2x2xf32>
    %diag = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%row, %col)
    memref.store %b_val, %scratch[%row, %diag] : memref<2x3xf32>
    scf.reduce
  }

  // Consumer: %B[row, col] = %scratch[row, row + col] * %A[row, col].
  scf.parallel (%row, %col) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    %diag = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%row, %col)
    %staged = memref.load %scratch[%row, %diag] : memref<2x3xf32>
    %a_val = memref.load %A[%row, %col] : memref<2x2xf32>
    %scaled = arith.mulf %staged, %a_val : f32
    memref.store %scaled, %B[%row, %col] : memref<2x2xf32>
    scf.reduce
  }

  memref.dealloc %scratch : memref<2x3xf32>
  return
}
509
+ // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
510
+ // CHECK-LABEL: fuse_same_indices_by_affine_apply
511
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
512
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
513
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
514
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
515
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
516
+ // CHECK-NEXT: scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
517
+ // CHECK-NEXT: %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
518
+ // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
519
+ // CHECK-NEXT: memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
520
+ // CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
521
+ // CHECK-NEXT: %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
522
+ // CHECK-NEXT: %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
523
+ // CHECK-NEXT: %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
524
+ // CHECK-NEXT: memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
525
+ // CHECK-NEXT: scf.reduce
526
+ // CHECK-NEXT: }
527
+ // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
528
+ // CHECK-NEXT: return
529
+
530
+ // -----
531
+
532
// Same producer/consumer shape as a store-then-load pair through %scratch,
// but here the second %scratch index is computed by affine.apply from the row
// induction variable and a function argument: %OffsetA in the first loop and
// %OffsetB in the second. The two loops therefore index %scratch with
// different, non-induction operands, and the FileCheck expectations that
// follow this function match two separate scf.parallel loops.
func.func @do_not_fuse_affine_apply_to_non_ind_var(%A: memref<2x2xf32>,
                                                   %B: memref<2x2xf32>,
                                                   %OffsetA: index,
                                                   %OffsetB: index) {
  // Loop bounds and step shared by both loops.
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = arith.constant 2 : index

  // Intermediate buffer written by the first loop and read by the second.
  %scratch = memref.alloc() : memref<2x3xf32>

  // Producer: %scratch[row, row + %OffsetA] = %B[row, col].
  scf.parallel (%row, %col) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    %b_val = memref.load %B[%row, %col] : memref<2x2xf32>
    %dst_col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%row, %OffsetA)
    memref.store %b_val, %scratch[%row, %dst_col] : memref<2x3xf32>
    scf.reduce
  }

  // Consumer: %B[row, col] = %scratch[row, row + %OffsetB] * %A[row, col].
  scf.parallel (%row, %col) = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
    %src_col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%row, %OffsetB)
    %staged = memref.load %scratch[%row, %src_col] : memref<2x3xf32>
    %a_val = memref.load %A[%row, %col] : memref<2x2xf32>
    %scaled = arith.mulf %staged, %a_val : f32
    memref.store %scaled, %B[%row, %col] : memref<2x2xf32>
    scf.reduce
  }

  memref.dealloc %scratch : memref<2x3xf32>
  return
}
555
+ // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
556
+ // CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
557
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
558
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
559
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
560
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
561
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
562
+ // CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
563
+ // CHECK-NEXT: %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
564
+ // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
565
+ // CHECK-NEXT: memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
566
+ // CHECK-NEXT: scf.reduce
567
+ // CHECK-NEXT: }
568
+ // CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
569
+ // CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
570
+ // CHECK-NEXT: %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
571
+ // CHECK-NEXT: %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
572
+ // CHECK-NEXT: %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
573
+ // CHECK-NEXT: memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
574
+ // CHECK-NEXT: scf.reduce
575
+ // CHECK-NEXT: }
576
+ // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
577
+ // CHECK-NEXT: return
0 commit comments