diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 4866e31b19d5d..983f7a29cb220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -459,7 +459,8 @@ class AnalysisState {
   /// Starting from `value`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
-  /// traversed any further.
+  /// traversed any further. All visited aliasing OpOperands are collected in
+  /// `visitedOpOperands`, if provided.
   ///
   /// When reaching the end of a chain, also return the last Value of that
   /// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition,
-      TraversalConfig config = TraversalConfig()) const;
+      TraversalConfig config = TraversalConfig(),
+      llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
   /// Find the values that may define the contents of the given value at
   /// runtime. A block argument is always a definition. An OpResult is a
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 065739ea8e595..f8a7a22787404 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
-// further.
+// further. All visited aliasing OpOperands are collected in
+// `visitedOpOperands`, if provided.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    TraversalConfig config) const {
+    TraversalConfig config,
+    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
       }
 
       workingSet.insert(a.opOperand->get());
+      if (visitedOpOperands)
+        visitedOpOperands->insert(a.opOperand);
     }
   }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index cb2efef5c038b..abc0635a2cdff 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
   return true;
 }
 
-/// Return true if the given `insertionPoint` dominates all uses of
-/// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
-                                        Operation *emptyTensorOp) {
-  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
-    return domInfo.dominates(insertionPoint, user);
-  });
-}
-
-/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
-/// that the replacement may use any value from `neededValues`.
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use in the `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp,
+findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
+  Operation *candidateInsertionPoint = emptyTensorOp;
 
-  // Gather all possible insertion points: the location of `emptyTensorOp` and
-  // right after the definition of each value in `neededValues`.
+  // Gather all possible insertion points: the location of
+  // `candidateInsertionPoint` and right after the definition of each value in
+  // `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(emptyTensorOp);
+  insertionPointCandidates.push_back(candidateInsertionPoint);
   for (Value val : neededValues) {
     // Note: The anchor op is using all of `neededValues`, so:
     // * in case of a block argument: There must be at least one op in the block
@@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                             neededValues))
       continue;
-    // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
+    // Check if the insertion point is before the use to be replaced.
+    if (!domInfo.dominates(insertionPoint, user))
       continue;
     return insertionPoint;
   }
@@ -103,8 +96,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
-
+  llvm::DenseSet<OpOperand *> visitedOpOperands;
   op->walk([&](SubsetInsertionOpInterface op) {
+    visitedOpOperands.clear();
     OpOperand &source = op.getSourceOperand();
     // Skip operands that do not bufferize inplace. "tensor.empty" could still
     // be replaced, but the transformation may not be beneficial.
@@ -131,16 +125,28 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         source.get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-        config);
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
+        &visitedOpOperands);
 
     for (Value v : emptyTensors) {
       Operation *emptyTensorOp = v.getDefiningOp();
 
+      // Find the use to be replaced from the use-def chain.
+      auto iter = llvm::find_if(
+          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
+            return llvm::count(emptyTensorOp->getUses(), *opOperand);
+          });
+      // The use may not be found if `emptyTensorOp` is consumed directly by
+      // the `SubsetInsertionOpInterface`'s source.
+      if (iter == visitedOpOperands.end())
+        continue;
+      OpOperand *useToBeReplaced = *iter;
+      Operation *user = useToBeReplaced->getOwner();
+
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
       Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, neededValues);
+          findValidInsertionPoint(emptyTensorOp, user, neededValues);
       if (!insertionPoint)
         continue;
@@ -159,8 +165,10 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
       }
-      // Replace the tensor::EmptyOp.
-      rewriter.replaceOp(emptyTensorOp, replacement);
+      // Replace the specific use of the tensor::EmptyOp.
+      rewriter.modifyOpInPlace(user, [&]() {
+        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
+      });
       state.resetCache();
     }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 2ba8246a8d525..9150986f4c2a2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -55,6 +55,7 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
   // CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
   %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: bufferization.alloc_tensor(%arg1)
   %0 = tensor.empty(%arg1) : tensor<?xf32>
   // CHECK: bufferization.alloc_tensor(%arg1)
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index efe59af97d964..26434774730e1 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -365,3 +365,103 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// `EmptyTensorElimination` fails to find a valid insertion
+// point for the newly injected `SubsetExtraction`.
+// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
+func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
+func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+  // CHECK: memref.alloc
+  // CHECK-NOT: memref.alloc
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// `EmptyTensorElimination` will replace the specific use of the tensor
+// empty with the newly injected `SubsetExtraction`, i.e. the specific use
+// that has been tracked.
+
+// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32>, %arg2: tensor<5x6x64xf32>)
+  -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
+  %cst_1 = arith.constant 1.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
+  // CHECK-NOT: memref.alloc
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  ins(%empty_1 : tensor<5x6x64xf32>)
+  outs(%arg2 : tensor<5x6x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %res = arith.addf %in, %in : f32
+    linalg.yield %res : f32
+  } -> tensor<5x6x64xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
+}
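
For illustration, below is a minimal sketch of the behavior the new tests pin down; it is not taken from the patch or its tests, and names such as @sketch_multi_use, %dest, %fill_1 and %fill_2 are purely illustrative. When a tensor.empty has several uses, only the use that lies on the tracked use-def chain leading to the SubsetInsertionOpInterface source is rewired to the injected subset extraction; every other use keeps the original tensor.empty.

// Input (sketch): one tensor.empty with two users, only %fill_1 feeds the
// insert_slice source.
func.func @sketch_multi_use(%dest: tensor<5x6x128xf32>)
    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
  %c1 = arith.constant 1.0 : f32
  %c2 = arith.constant 2.0 : f32
  %empty = tensor.empty() : tensor<5x6x64xf32>
  %fill_1 = linalg.fill ins(%c1 : f32) outs(%empty : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %fill_2 = linalg.fill ins(%c2 : f32) outs(%empty : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %ins = tensor.insert_slice %fill_1 into %dest[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %ins, %fill_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}

// -----

// Expected shape after empty tensor elimination (sketch): %fill_1 now writes
// into the extracted subset of %dest, while %fill_2 still consumes the
// original tensor.empty, so that allocation remains.
func.func @sketch_multi_use(%dest: tensor<5x6x128xf32>)
    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
  %c1 = arith.constant 1.0 : f32
  %c2 = arith.constant 2.0 : f32
  %slice = tensor.extract_slice %dest[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x128xf32> to tensor<5x6x64xf32>
  %empty = tensor.empty() : tensor<5x6x64xf32>
  %fill_1 = linalg.fill ins(%c1 : f32) outs(%slice : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %fill_2 = linalg.fill ins(%c2 : f32) outs(%empty : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %ins = tensor.insert_slice %fill_1 into %dest[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %ins, %fill_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}

This mirrors what the CHECK-ELIM lines in the last two new tests assert: one linalg.fill is rewritten to write into the extracted slice, while the other keeps the original allocation.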