Skip to content

Commit 85ea2c2

Browse files
committed
[mlir][bufferization]-Replace only one use in TensorEmptyElimination
This MR handles the second case, where we want to replace only the specific use that was visited in the `use-def` chain (when traversing from the tensor.insert_slice's source). Replacing all uses of the tensor.empty may lead to additional read effects after bufferization of the specific subset extract/subview, which should not be the case; restricting the replacement to the tracked use thus eliminates potential copies.
1 parent aa2a621 commit 85ea2c2

File tree

5 files changed

+50
-37
lines changed

5 files changed

+50
-37
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ class AnalysisState {
459459
/// Starting from `value`, follow the use-def chain in reverse, always
460460
/// selecting the aliasing OpOperands. Find and return Values for which
461461
/// `condition` evaluates to true. OpOperands of such matching Values are not
462-
/// traversed any further.
462+
/// traversed any further, the visited aliasing opOperands will be preserved
463+
/// through `visitedOpOperands`.
463464
///
464465
/// When reaching the end of a chain, also return the last Value of that
465466
/// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
484485
/// `config`.
485486
SetVector<Value> findValueInReverseUseDefChain(
486487
Value value, llvm::function_ref<bool(Value)> condition,
487-
TraversalConfig config = TraversalConfig()) const;
488+
TraversalConfig config = TraversalConfig(),
489+
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
488490

489491
/// Find the values that may define the contents of the given value at
490492
/// runtime. A block argument is always a definition. An OpResult is a

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
483483
// Starting from `value`, follow the use-def chain in reverse, always selecting
484484
// the aliasing OpOperands. Find and return Values for which `condition`
485485
// evaluates to true. OpOperands of such matching Values are not traversed any
486-
// further.
486+
// further, the visited aliasing opOperands will be preserved through
487+
// `visitedOpOperands`.
487488
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
488489
Value value, llvm::function_ref<bool(Value)> condition,
489-
TraversalConfig config) const {
490+
TraversalConfig config,
491+
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
490492
llvm::DenseSet<Value> visited;
491493
llvm::SetVector<Value> result, workingSet;
492494
workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
553555
}
554556

555557
workingSet.insert(a.opOperand->get());
558+
if (visitedOpOperands)
559+
visitedOpOperands->insert(a.opOperand);
556560
}
557561
}
558562

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
4848
return true;
4949
}
5050

51-
/// Return true if the given `insertionPoint` dominates all uses of
52-
/// `emptyTensorOp`.
53-
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
54-
Operation *insertionPoint,
55-
Operation *emptyTensorOp) {
56-
return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
57-
return domInfo.dominates(insertionPoint, user);
58-
});
59-
}
60-
61-
/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
62-
/// that the replacement may use any value from `neededValues`.
51+
/// Find a valid insertion point for a replacement of `useToBeEliminated`,
52+
/// assuming that the replacement may use any value from `neededValues`.
6353
static Operation *
64-
findValidInsertionPoint(Operation *emptyTensorOp,
54+
findValidInsertionPoint(OpOperand *useToBeEliminated,
6555
const SmallVector<Value> &neededValues) {
6656
DominanceInfo domInfo;
57+
assert(isa<OpResult>(useToBeEliminated->get()) && "expected a result value");
58+
Operation *candidateInsertionPoint = useToBeEliminated->get().getDefiningOp();
6759

68-
// Gather all possible insertion points: the location of `emptyTensorOp` and
69-
// right after the definition of each value in `neededValues`.
60+
// Gather all possible insertion points: the location of
61+
// `candidateInsertionPoint` and right after the definition of each value in
62+
// `neededValues`.
7063
SmallVector<Operation *> insertionPointCandidates;
71-
insertionPointCandidates.push_back(emptyTensorOp);
64+
insertionPointCandidates.push_back(candidateInsertionPoint);
7265
for (Value val : neededValues) {
7366
// Note: The anchor op is using all of `neededValues`, so:
7467
// * in case of a block argument: There must be at least one op in the block
@@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
9083
if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
9184
neededValues))
9285
continue;
93-
// Check if the insertion point is before all uses.
94-
if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
86+
// Check if the insertion point is before the use to be replaced.
87+
if (!domInfo.dominates(insertionPoint, useToBeEliminated->getOwner()))
9588
continue;
9689
return insertionPoint;
9790
}
@@ -103,8 +96,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
10396
LogicalResult mlir::bufferization::eliminateEmptyTensors(
10497
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
10598
OpBuilder::InsertionGuard g(rewriter);
106-
99+
llvm::DenseSet<OpOperand *> visitedOpOperands;
107100
op->walk([&](SubsetInsertionOpInterface op) {
101+
visitedOpOperands.clear();
108102
OpOperand &source = op.getSourceOperand();
109103
// Skip operands that do not bufferize inplace. "tensor.empty" could still
110104
// be replaced, but the transformation may not be beneficial.
@@ -131,16 +125,25 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
131125
config.followSameTypeOrCastsOnly = true;
132126
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
133127
source.get(), /*condition=*/
134-
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
135-
config);
128+
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
129+
&visitedOpOperands);
136130

137131
for (Value v : emptyTensors) {
138132
Operation *emptyTensorOp = v.getDefiningOp();
139133

134+
// Find the use to be replaced from the use-def chain
135+
auto iter = llvm::find_if(
136+
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
137+
return llvm::count(emptyTensorOp->getUses(), *opOperand);
138+
});
139+
if (iter == visitedOpOperands.end())
140+
continue;
141+
OpOperand *useToBeReplaced = *iter;
142+
140143
// Find a suitable insertion point. If no suitable insertion point for
141144
// the replacement can be found, skip this replacement.
142145
Operation *insertionPoint =
143-
findValidInsertionPoint(emptyTensorOp, neededValues);
146+
findValidInsertionPoint(useToBeReplaced, neededValues);
144147
if (!insertionPoint)
145148
continue;
146149

@@ -159,8 +162,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
159162
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
160163
replacement);
161164
}
162-
// Replace the tensor::EmptyOp.
163-
rewriter.replaceOp(emptyTensorOp, replacement);
165+
// Replace the specific use of the tensor::EmptyOp.
166+
useToBeReplaced->getOwner()->setOperand(
167+
useToBeReplaced->getOperandNumber(), replacement);
164168
state.resetCache();
165169
}
166170

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
5555
// CHECK: tensor.extract_slice
5656
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
5757
%cst = arith.constant 0.000000e+00 : f32
58+
// CHECK: bufferization.alloc_tensor(%arg1)
5859
%0 = tensor.empty(%arg1) : tensor<?xf32>
5960

6061
// CHECK: bufferization.alloc_tensor(%arg1)

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,9 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
396396
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
397397
%cst_1 = arith.constant 1.0 : f32
398398
%cst_2 = arith.constant 2.0 : f32
399+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
399400
// CHECK: memref.alloc
400-
// CHECK: memref.alloc
401+
// CHECK-NOT: memref.alloc
401402
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
402403
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
403404
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
@@ -413,10 +414,9 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
413414

414415
// -----
415416

416-
// `EmptyTensorElimination` replaces all of the uses of the tensor
417-
// empty with the new injected `SubsetExtraction`, without to consider
418-
// the specific use has been tracked, sometimes creating a non existent
419-
// bufferization conflicts.
417+
// `EmptyTensorElimination` will replace the specific use of the tensor
418+
// empty with the new injected `SubsetExtraction`, i.e. the specific use
419+
// which has been tracked.
420420

421421
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
422422
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
@@ -427,13 +427,13 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
427427
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
428428
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
429429
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
430-
// CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
430+
// CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
431431
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
432432
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
433433
// CHECK: memref.copy
434434
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
435435
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
436-
// CHECK: memref.copy
436+
// CHECK-NOT: memref.copy
437437
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
438438
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
439439
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -442,11 +442,13 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
442442
// -----
443443

444444
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
445+
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
445446
func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>)
446447
-> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
447448
%cst_1 = arith.constant 1.0 : f32
448449
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
449-
// CHECK: memref.alloc
450+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
451+
// CHECK-NOT: memref.alloc
450452
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
451453
%res_2 = linalg.generic{
452454
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
@@ -458,7 +460,7 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
458460
%res = arith.addf %in, %in : f32
459461
linalg.yield %res : f32
460462
} -> tensor<5x6x64xf32>
461-
// CHECK: memref.copy
463+
// CHECK-NOT: memref.copy
462464
%inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
463465
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
464466
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>

0 commit comments

Comments
 (0)