From 95ffa642bc65f3090928967f37d9cbc6b16abd85 Mon Sep 17 00:00:00 2001 From: lipracer Date: Wed, 4 Oct 2023 00:09:25 +0800 Subject: [PATCH] [mlir]: fix a issue and refine some code (#67977) 1) fix empty-tensor-elimination pass crash 2) improve linlg.copy op's canonicalization pattern 3) add indentation when emit regionBuilder func --- .../Transforms/EmptyTensorElimination.cpp | 2 ++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 ++++++++------- ...e-shot-bufferize-empty-tensor-elimination.mlir | 11 +++++++++++ .../mlir-linalg-ods-yaml-gen.cpp | 14 +++++++------- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 77ad13dacaa98..4c5789306ad75 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -149,6 +149,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); if (!replacement) continue; + if (emptyTensorOp == replacement.getDefiningOp()) + continue; if (replacement.getType() != v.getType()) { rewriter.setInsertionPointAfterValue(replacement); replacement = rewriter.create(v.getLoc(), v.getType(), diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 491f4a6657461..5457d51db1cc1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -545,16 +545,17 @@ class RegionBuilderHelper { namespace { -struct EraseSelfCopyOnBuffers : OpRewritePattern { +struct EraseSelfCopy : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CopyOp copyOp, PatternRewriter &rewriter) const override { - if (!copyOp.hasBufferSemantics()) - return rewriter.notifyMatchFailure(copyOp, - "does not have buffer semantics"); - if (copyOp.getInputs().front() != copyOp.getOutputs().front()) + if (copyOp.getInputs() != copyOp.getOutputs()) return rewriter.notifyMatchFailure(copyOp, "not a self copy"); - rewriter.eraseOp(copyOp); + if (copyOp.hasBufferSemantics()) + rewriter.eraseOp(copyOp); + else + rewriter.replaceOp(copyOp, copyOp.getInputs()); + return success(); } }; @@ -563,7 +564,7 @@ struct EraseSelfCopyOnBuffers : OpRewritePattern { void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 41e43047657da..b68682a459ed2 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -317,3 +317,14 @@ func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> { %1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32> return %1 : tensor<5xf32> } + +// ----- + +// CHECK-LABEL: func @linalg_copy_empty( +// CHECK: %[[ret:.*]] = memref.alloc() +// CHECK-NEXT: return %[[ret]] +func.func @linalg_copy_empty() -> tensor<26xi32> { + %0 = tensor.empty() : tensor<26xi32> + %1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32> + return %1 : tensor<26xi32> +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 664167e4f6c34..5898b0f7d69e8 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -1029,13 +1029,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, // {1}: attribute name // {2}: default type function name static const char attrDef[] = R"FMT( -{0} {1}Val = {0}::{2}; -auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ - return attr.getName() == "{1}"; }); -if ({1}Iter != attrs.end()) {{ - if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) - {1}Val = attr.getValue(); -} + {0} {1}Val = {0}::{2}; + auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ + return attr.getName() == "{1}"; }); + if ({1}Iter != attrs.end()) {{ + if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) + {1}Val = attr.getValue(); + } )FMT"; std::string enumName = convertOperandKindToEnumName(arg.kind); attrs.push_back(