Skip to content

Commit 6506355

Browse files
authored
[mlir][arith] Only fold splats for static shape result types (#93102)
This prevents an assertion when constructing the DenseElementsAttr result, where the passed-in type is expected to have a static shape. Fixes #92057
1 parent 264aaa5 commit 6506355

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,10 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
298298
calculate(op.getSplatValue<ElementValueT>(), castStatus);
299299
if (!castStatus)
300300
return {};
301-
return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
301+
auto shapedResType = cast<ShapedType>(resType);
302+
if (!shapedResType.hasStaticShape())
303+
return {};
304+
return DenseElementsAttr::get(shapedResType, elementResult);
302305
}
303306
if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
304307
// Operand is ElementsAttr-derived; perform an element-wise fold by

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2950,6 +2950,14 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
29502950
return %ext : tensor<i16>
29512951
}
29522952

2953+
// Just checks that this doesn't crash.
2954+
// CHECK-LABEL: @signedExtendSplatAsDynamicShape
2955+
func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
2956+
%splat = arith.constant dense<5> : tensor<2xi16>
2957+
%extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
2958+
return %extsplat : tensor<?xi64>
2959+
}
2960+
29532961
// CHECK-LABEL: @extsi_i0
29542962
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
29552963
// CHECK: return %[[ZERO]] : i16

0 commit comments

Comments
 (0)