
Commit c3eb297

[mlir][scf] Considering defining operators of indices when fusing scf::ParallelOp (#80145)
When checking whether the load indices of the second loop coincide with the store indices of the first loop, the pass previously only compared the index values for equality. However, in some cases the index values are defined by other operations, and the pass then treated them as different even when the results of those defining operations are the same. Since isFusionLegal() already verifies that the two loops have the same iteration space, it is enough, when comparing the operands of the defining operations, to check that they come from the same induction variables; if they do, the defining operations produce the same results.
1 parent 5d41788 commit c3eb297
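As a minimal illustration of the case this enables (condensed from the fuse_same_indices_by_affine_apply test added below; the function and value names here are illustrative), the store index of the first loop and the load index of the second loop are distinct SSA values, but their defining affine.apply operations are equivalent once their operands are mapped to the corresponding induction variables, so the two loops can now be fused:

func.func @illustration(%B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %sum = memref.alloc() : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %b = memref.load %B[%i, %j] : memref<2x2xf32>
    // Store index computed from the induction variables.
    %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    memref.store %b, %sum[%i, %idx] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    // A different SSA value than the %idx above, but its defining affine.apply
    // is equivalent once the induction variables of the two loops are mapped.
    %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
    %s = memref.load %sum[%i, %idx] : memref<2x3xf32>
    memref.store %s, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x3xf32>
  return
}

The CHECK lines of the new test show the expected result: a single scf.parallel loop containing both the store and the load.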

2 files changed: +120 −2 lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 25 additions & 2 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
@@ -102,8 +103,30 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i])
-        return WalkResult::interrupt();
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
     }
     return WalkResult::advance();
   });

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 95 additions & 0 deletions
@@ -480,3 +480,98 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
 // CHECK:      scf.reduce
 // CHECK:    }
 // CHECK:    memref.dealloc [[SUM]]
+
+// -----
+
+func.func @fuse_same_indices_by_affine_apply(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc() : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fuse_same_indices_by_affine_apply
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>) {
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:      %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT: scf.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:   %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:   memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:   %[[S2:.*]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK-NEXT:   %[[S3:.*]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
+// CHECK-NEXT:   %[[S4:.*]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = arith.mulf %[[S3]], %[[S4]] : f32
+// CHECK-NEXT:   memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
+// CHECK-NEXT:   scf.reduce
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT: return
+
+// -----
+
+func.func @do_not_fuse_affine_apply_to_non_ind_var(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %OffsetA: index, %OffsetB: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %sum = memref.alloc() : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetA)
+    memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %OffsetB)
+    %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x3xf32>
+  return
+}
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x2xf32>, %[[ARG1:.*]]: memref<2x2xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) {
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+// CHECK:      %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
+// CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:   %[[S0:.*]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG2]])
+// CHECK-NEXT:   memref.store %[[S0]], %[[ALLOC]][%[[ARG4]], %[[S1]]] : memref<2x3xf32>
+// CHECK-NEXT:   scf.reduce
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.parallel (%[[ARG4:.*]], %[[ARG5:.*]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
+// CHECK-NEXT:   %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]], %[[ARG3]])
+// CHECK-NEXT:   %[[S1:.*]] = memref.load %[[ALLOC]][%[[ARG4]], %[[S0]]] : memref<2x3xf32>
+// CHECK-NEXT:   %[[S2:.*]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = arith.mulf %[[S1]], %[[S2]] : f32
+// CHECK-NEXT:   memref.store %[[S3]], %[[ARG1]][%[[ARG4]], %[[ARG5]]] : memref<2x2xf32>
+// CHECK-NEXT:   scf.reduce
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
+// CHECK-NEXT: return
