diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 2291d64c50a56..59df0cd6833db 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -1385,7 +1385,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; } else { // Reuse the existing local id. - lhs[getLocalVarStartIndex() + loc] = -rhsConst; + lhs[getLocalVarStartIndex() + loc] -= rhsConst; } return success(); } diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp index 9e89a5b79e2e2..8a2d697540d5c 100644 --- a/mlir/unittests/IR/AffineExprTest.cpp +++ b/mlir/unittests/IR/AffineExprTest.cpp @@ -129,3 +129,21 @@ TEST(AffineExprTest, d0PlusD0FloorDivNeg2) { auto sum = d0 + d0.floorDiv(-2) * 2; ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2"); } + +TEST(AffineExprTest, simpleAffineExprFlattenerRegression) { + + // Regression test for a bug where mod simplification was not handled + // properly when `lhs % rhs` was happened to have the property that `lhs + // floordiv rhs = lhs`. + MLIRContext ctx; + OpBuilder b(&ctx); + + auto d0 = b.getAffineDimExpr(0); + + // Manually replace variables by constants to avoid constant folding. + AffineExpr expr = (d0 - (d0 + 2)).floorDiv(8) % 8; + AffineExpr result = mlir::simplifyAffineExpr(expr, 1, 0); + + ASSERT_TRUE(isa(result)); + ASSERT_EQ(cast(result).getValue(), 7); +}