Skip to content

Commit 8087fc0

Browse files
committed
[mlir][bufferization]-Replace only one use in TensorEmptyElimination
This MR handles the second case, where we want to replace only the specific use that was visited in the `use-def` chain (when traversing from the `tensor.insert_slice`'s source). Replacing all uses of the `tensor.empty` in this scenario may introduce additional read effects after bufferization of the specific subset extract/subview, which should not be the case — replacing only the tracked use thus eliminates potential copies.
1 parent 4a30806 commit 8087fc0

File tree

5 files changed

+65
-45
lines changed

5 files changed

+65
-45
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ class AnalysisState {
459459
/// Starting from `value`, follow the use-def chain in reverse, always
460460
/// selecting the aliasing OpOperands. Find and return Values for which
461461
/// `condition` evaluates to true. OpOperands of such matching Values are not
462-
/// traversed any further.
462+
/// traversed any further, The visited aliasing opOperands will be preserved
463+
/// through `visitedOpOperands`.
463464
///
464465
/// When reaching the end of a chain, also return the last Value of that
465466
/// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
484485
/// `config`.
485486
SetVector<Value> findValueInReverseUseDefChain(
486487
Value value, llvm::function_ref<bool(Value)> condition,
487-
TraversalConfig config = TraversalConfig()) const;
488+
TraversalConfig config = TraversalConfig(),
489+
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
488490

489491
/// Find the values that may define the contents of the given value at
490492
/// runtime. A block argument is always a definition. An OpResult is a

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
483483
// Starting from `value`, follow the use-def chain in reverse, always selecting
484484
// the aliasing OpOperands. Find and return Values for which `condition`
485485
// evaluates to true. OpOperands of such matching Values are not traversed any
486-
// further.
486+
// further, The visited aliasing opOperands will be preserved through
487+
// `visitedOpOperands`.
487488
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
488489
Value value, llvm::function_ref<bool(Value)> condition,
489-
TraversalConfig config) const {
490+
TraversalConfig config,
491+
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
490492
llvm::DenseSet<Value> visited;
491493
llvm::SetVector<Value> result, workingSet;
492494
workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
553555
}
554556

555557
workingSet.insert(a.opOperand->get());
558+
if (visitedOpOperands)
559+
visitedOpOperands->insert(a.opOperand);
556560
}
557561
}
558562

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -57,44 +57,43 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
5757
return true;
5858
}
5959

60-
/// Return true if the given `insertionPoint` dominates all uses of
61-
/// `emptyTensorOp`.
62-
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
63-
Operation *insertionPoint,
64-
Operation *emptyTensorOp) {
65-
return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
66-
return domInfo.dominates(insertionPoint, user);
67-
});
68-
}
69-
70-
/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
71-
/// that the replacement may use any value from `neededValues`.
60+
/// Find a valid insertion point for a replacement of `useToBeEliminated`,
61+
/// assuming that the replacement may use any value from `neededValues`.
7262
static Operation *
73-
findValidInsertionPoint(Operation *emptyTensorOp,
63+
findValidInsertionPoint(OpOperand *useToBeEliminated,
7464
const SmallVector<Value> &neededValues) {
7565
DominanceInfo domInfo;
7666

67+
Operation *candidateInsertionPoint = useToBeEliminated->getOwner();
68+
assert(isa<OpResult>(useToBeEliminated->get()) && "expected a result value");
69+
// Both `tensor.empty` and its user are within different blocks.
70+
if (useToBeEliminated->getOwner()->getBlock() !=
71+
useToBeEliminated->get().getDefiningOp()->getBlock())
72+
candidateInsertionPoint = useToBeEliminated->get().getDefiningOp();
73+
7774
// Trying to move the needed values before the `emptyTensorOp`.
7875
for (Value val : neededValues) {
79-
if (valueDominateInsertionPoint(domInfo, emptyTensorOp, val))
76+
if (valueDominateInsertionPoint(domInfo, candidateInsertionPoint, val))
8077
continue;
8178
Operation *definingOp = val.getDefiningOp();
8279
if (!definingOp)
8380
continue;
8481

8582
bool isItSafeToMoveOp =
8683
llvm::all_of(definingOp->getOperands(), [&](Value operand) {
87-
return valueDominateInsertionPoint(domInfo, emptyTensorOp, operand);
84+
return valueDominateInsertionPoint(domInfo, candidateInsertionPoint,
85+
operand);
8886
});
8987

9088
if (isItSafeToMoveOp)
91-
definingOp->moveBefore(emptyTensorOp);
89+
definingOp->moveBefore(candidateInsertionPoint);
9290
}
9391

94-
// Gather all possible insertion points: the location of `emptyTensorOp` and
95-
// right after the definition of each value in `neededValues`.
92+
// Gather all possible insertion points: the location of
93+
// `candidateInsertionPoint` and right after the definition of each value in
94+
// `neededValues`.
9695
SmallVector<Operation *> insertionPointCandidates;
97-
insertionPointCandidates.push_back(emptyTensorOp);
96+
insertionPointCandidates.push_back(candidateInsertionPoint);
9897
for (Value val : neededValues) {
9998
// Note: The anchor op is using all of `neededValues`, so:
10099
// * in case of a block argument: There must be at least one op in the block
@@ -116,8 +115,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
116115
if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
117116
neededValues))
118117
continue;
119-
// Check if the insertion point is before all uses.
120-
if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
118+
// Check if the insertion point is before the use to be replaced.
119+
if (!domInfo.dominates(insertionPoint, useToBeEliminated->getOwner()))
121120
continue;
122121
return insertionPoint;
123122
}
@@ -129,8 +128,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
129128
LogicalResult mlir::bufferization::eliminateEmptyTensors(
130129
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
131130
OpBuilder::InsertionGuard g(rewriter);
132-
131+
llvm::DenseSet<OpOperand *> visitedOpOperands;
133132
op->walk([&](SubsetInsertionOpInterface op) {
133+
visitedOpOperands.clear();
134134
OpOperand &source = op.getSourceOperand();
135135
// Skip operands that do not bufferize inplace. "tensor.empty" could still
136136
// be replaced, but the transformation may not be beneficial.
@@ -157,16 +157,25 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
157157
config.followSameTypeOrCastsOnly = true;
158158
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
159159
source.get(), /*condition=*/
160-
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
161-
config);
160+
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
161+
&visitedOpOperands);
162162

163163
for (Value v : emptyTensors) {
164164
Operation *emptyTensorOp = v.getDefiningOp();
165165

166+
// Find the use to be replaced from the use-def chain
167+
auto iter = llvm::find_if(
168+
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
169+
return llvm::count(emptyTensorOp->getUses(), *opOperand);
170+
});
171+
if (iter == visitedOpOperands.end())
172+
continue;
173+
OpOperand *useToBeReplaced = *iter;
174+
166175
// Find a suitable insertion point. If no suitable insertion point for
167176
// the replacement can be found, skip this replacement.
168177
Operation *insertionPoint =
169-
findValidInsertionPoint(emptyTensorOp, neededValues);
178+
findValidInsertionPoint(useToBeReplaced, neededValues);
170179
if (!insertionPoint)
171180
continue;
172181

@@ -185,8 +194,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
185194
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
186195
replacement);
187196
}
188-
// Replace the tensor::EmptyOp.
189-
rewriter.replaceOp(emptyTensorOp, replacement);
197+
// Replace the specific use of the tensor::EmptyOp.
198+
useToBeReplaced->getOwner()->setOperand(
199+
useToBeReplaced->getOperandNumber(), replacement);
190200
state.resetCache();
191201
}
192202

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
5252

5353
// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
5454
func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
55-
// CHECK: tensor.extract_slice
56-
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
5755
%cst = arith.constant 0.000000e+00 : f32
56+
// CHECK: bufferization.alloc_tensor(%arg1)
5857
%0 = tensor.empty(%arg1) : tensor<?xf32>
5958

6059
// CHECK: bufferization.alloc_tensor(%arg1)
@@ -64,6 +63,10 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
6463
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
6564
%2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
6665

66+
67+
// CHECK: tensor.extract_slice
68+
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
69+
6770
// CHECK: linalg.copy
6871
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
6972
%3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
396396
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
397397
%cst_1 = arith.constant 1.0 : f32
398398
%cst_2 = arith.constant 2.0 : f32
399-
// CHECK: memref.alloc
399+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
400400
// CHECK-NOT: memref.alloc
401401
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
402402
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
@@ -413,10 +413,9 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
413413

414414
// -----
415415

416-
// `EmptyTensorElimination` replaces all of the uses of the tensor
417-
// empty with the new injected `SubsetExtraction`, without to consider
418-
// the specific use has been tracked, sometimes creating a non existent
419-
// bufferization conflicts.
416+
// `EmptyTensorElimination` will replace the specific use of the tensor
417+
// empty with the new injected `SubsetExtraction`, i.e. the specific use
418+
// which has been tracked.
420419

421420
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
422421
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
@@ -425,15 +424,16 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
425424
%cst_2 = arith.constant 2.0 : f32
426425
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
427426
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
428-
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
429-
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
430-
// CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
427+
// CHECK-ELIM: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_2:.*]]
428+
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_4]] : tensor<5x6x64xf32>)
429+
// CHECK-ELIM: %[[VAL_6:.*]] = tensor.insert_slice
431430
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
431+
// CHECK-ELIM: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_6]]
432+
// CHECK-ELIM: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_7]] : tensor<5x6x64xf32>)
432433
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
433-
// CHECK: memref.copy
434+
// CHECK-NOT: memref.copy
434435
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
435436
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
436-
// CHECK: memref.copy
437437
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
438438
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
439439
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -446,7 +446,8 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
446446
-> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
447447
%cst_1 = arith.constant 1.0 : f32
448448
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
449-
// CHECK: memref.alloc
449+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
450+
// CHECK-NOT: memref.alloc
450451
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
451452
%res_2 = linalg.generic{
452453
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
@@ -458,7 +459,7 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
458459
%res = arith.addf %in, %in : f32
459460
linalg.yield %res : f32
460461
} -> tensor<5x6x64xf32>
461-
// CHECK: memref.copy
462+
// CHECK-NOT: memref.copy
462463
%inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
463464
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
464465
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>

0 commit comments

Comments
 (0)