Skip to content

Commit 2d94312

Browse files
[mlir][IR] Fix bug in AffineExpr simplifier lhs % rhs where lhs = lhs floordiv rhs
Fixes an issue where the `SimpleAffineExprFlattener` would simplify `lhs % rhs` to just `-(lhs floordiv rhs)` instead of `lhs - (lhs floordiv rhs)` if `lhs` happened to be equal to `lhs floordiv rhs`. The reported failure case was `(d0, d1) -> (((d1 - (d1 + 2)) floordiv 8) % 8)` from #114654. Note that many paths that simplify AffineMaps (e.g. the AffineApplyOp folder and canonicalization) would not observe this bug because of of slightly different paths taken by the code. Slightly different grouping of the terms could also result in avoiding the bug. The way to reproduce was by constructing the map directly, replacing `d1` with `1` and calling `mlir::simplifyAffineExpr`. Resolves #114654.
1 parent 708a478 commit 2d94312

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1385,7 +1385,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13851385
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
13861386
} else {
13871387
// Reuse the existing local id.
1388-
lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1388+
lhs[getLocalVarStartIndex() + loc] -= rhsConst;
13891389
}
13901390
return success();
13911391
}

mlir/unittests/IR/AffineExprTest.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,24 @@ TEST(AffineExprTest, d0PlusD0FloorDivNeg2) {
129129
auto sum = d0 + d0.floorDiv(-2) * 2;
130130
ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2");
131131
}
132+
133+
TEST(AffineExprTEst, simpleAffineExprFlattenerRegression) {
134+
135+
// Regression test for a bug where mod simplification was not handled
136+
// properly when `lhs % rhs` was happened to have the property that `lhs
137+
// floordiv rhs = lhs`.
138+
MLIRContext ctx;
139+
OpBuilder b(&ctx);
140+
141+
auto d0 = b.getAffineDimExpr(0);
142+
auto d1 = b.getAffineDimExpr(1);
143+
144+
// Manually replace variables by constants to avoid constant folding.
145+
AffineExpr expr = (d0 - (d1 + 2)).floorDiv(8) % 8;
146+
expr = expr.replaceDims(
147+
{b.getAffineConstantExpr(1), b.getAffineConstantExpr(1)});
148+
AffineExpr result = mlir::simplifyAffineExpr(expr, 2, 0);
149+
150+
ASSERT_TRUE(isa<AffineConstantExpr>(result));
151+
ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
152+
}

0 commit comments

Comments
 (0)