Skip to content

Commit 192439d

Browse files
[mlir][Complex] Fix bug in MergeComplexBitcast (#74271)
When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid. Also remove a pattern that convertes non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input. Note: This bug can only be triggered by running with `-debug` (which will should intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).
1 parent c3a9c90 commit 192439d

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

mlir/lib/Dialect/Complex/IR/ComplexOps.cpp

+12-19
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
100100
}
101101

102102
if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
103-
return emitOpError("requires input or output is a complex type");
103+
return emitOpError(
104+
"requires that either input or output has a complex type");
104105
}
105106

106107
if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
125126
LogicalResult matchAndRewrite(BitcastOp op,
126127
PatternRewriter &rewriter) const override {
127128
if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
128-
rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
129-
defining.getOperand());
129+
if (isa<ComplexType>(op.getType()) ||
130+
isa<ComplexType>(defining.getOperand().getType())) {
131+
// complex.bitcast requires that input or output is complex.
132+
rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
133+
defining.getOperand());
134+
} else {
135+
rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
136+
defining.getOperand());
137+
}
130138
return success();
131139
}
132140

@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
155163
}
156164
};
157165

158-
struct ArithBitcast final : OpRewritePattern<BitcastOp> {
159-
using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
160-
161-
LogicalResult matchAndRewrite(BitcastOp op,
162-
PatternRewriter &rewriter) const override {
163-
if (isa<ComplexType>(op.getType()) ||
164-
isa<ComplexType>(op.getOperand().getType()))
165-
return failure();
166-
167-
rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
168-
op.getOperand());
169-
return success();
170-
}
171-
};
172-
173166
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
174167
MLIRContext *context) {
175-
results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
168+
results.add<MergeComplexBitcast, MergeArithBitcast>(context);
176169
}
177170

178171
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Complex/invalid.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
2525
// -----
2626

2727
func.func @complex_bitcast_i64(%arg0 : i64) {
28-
// expected-error @+1 {{op requires input or output is a complex type}}
28+
// expected-error @+1 {{op requires that either input or output has a complex type}}
2929
%0 = complex.bitcast %arg0: i64 to f64
3030
return
3131
}

0 commit comments

Comments
 (0)