Skip to content

Commit dbb782d

Browse files
[mlir][shape] Turn ShapeOfOp folding into canonicalization pattern (#74438)
The `ShapeOfOp` folder used to generate invalid IR. Input: ``` %0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex> ``` Output: ``` %0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex> error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>' ``` This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xindex>` to be valid. This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (#74270).
1 parent 20da662 commit dbb782d

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
566566
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
567567

568568
let hasCanonicalizer = 1;
569-
let hasFolder = 1;
570569
let hasVerifier = 1;
571570
}
572571

mlir/lib/Dialect/Shape/IR/Shape.cpp

+25-9
Original file line numberDiff line numberDiff line change
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
16781678
// ShapeOfOp
16791679
//===----------------------------------------------------------------------===//
16801680

1681-
OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
1682-
auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
1683-
if (!type || !type.hasStaticShape())
1684-
return nullptr;
1685-
Builder builder(getContext());
1686-
return builder.getIndexTensorAttr(type.getShape());
1687-
}
1688-
16891681
namespace {
1682+
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
1683+
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1684+
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1685+
1686+
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1687+
PatternRewriter &rewriter) const override {
1688+
auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1689+
if (!type || !type.hasStaticShape())
1690+
return failure();
1691+
Location loc = op.getLoc();
1692+
Value constShape =
1693+
rewriter
1694+
.create<ConstShapeOp>(loc,
1695+
rewriter.getIndexTensorAttr(type.getShape()))
1696+
.getResult();
1697+
if (constShape.getType() != op.getResult().getType())
1698+
constShape = rewriter.create<tensor::CastOp>(
1699+
loc, op.getResult().getType(), constShape);
1700+
rewriter.replaceOp(op, constShape);
1701+
return success();
1702+
}
1703+
};
1704+
16901705
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
16911706
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
16921707

@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
17391754
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
17401755
MLIRContext *context) {
17411756
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1742-
ExtractFromShapeOfExtentTensor>(context);
1757+
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1758+
context);
17431759
}
17441760

17451761
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(

mlir/test/Dialect/Shape/canonicalize.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
14921492
%result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
14931493
return %result : !shape.size
14941494
}
1495+
1496+
// -----
1497+
1498+
// CHECK-LABEL: func @shape_of_0d(
1499+
// CHECK-SAME: %[[arg0:.*]]: tensor<f32>
1500+
// CHECK: %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
1501+
// CHECK: %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
1502+
// CHECK: return %[[cast]]
1503+
func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
1504+
%0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
1505+
return %0 : tensor<?xindex>
1506+
}

0 commit comments

Comments
 (0)