Skip to content

Commit d8da229

Browse files
authored
[mlir][scf] Interpret trip counts as unsigned integers (#178060)
Trip counts represent iteration counts and are always non-negative. This PR fixes all call sites to correctly use `getZExtValue()` instead of `getSExtValue()` when extracting trip count values from `APInt`. It also updates the documentation to clarify that the results are unsigned.
1 parent 507d185 commit d8da229

File tree

7 files changed

+129
-14
lines changed

7 files changed

+129
-14
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
251251
/// Get constant trip counts for each of the induction variables of the given
252252
/// loop operation. If any of the loop's trip counts is not constant, return an
253253
/// empty vector.
254+
/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
254255
llvm::SmallVector<int64_t>
255256
getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
256257

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,11 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
211211
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
212212

213213
/// Return the number of iterations for a loop with a lower bound `lb`, upper
214-
/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
215-
/// comparison between lb and ub is signed or unsigned. A negative step or a
216-
/// lower bound greater than the upper bound are considered invalid and will
217-
/// yield a zero trip count.
214+
/// bound `ub` and step `step`, as an unsigned integer. The `isSigned` flag
215+
/// indicates whether the loop comparison between lb and ub is signed or
216+
/// unsigned. (The result of this function must be interpreted as an unsigned
217+
/// integer.) A lower bound greater than the upper bound is considered invalid
218+
/// and will yield a zero trip count.
218219
/// The `computeUbMinusLb` callback is invoked to compute the difference between
219220
/// the upper and lower bound when not constant. It can be used by the client
220221
/// to compute a static difference when the bounds are not constant.

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
443443
LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
444444
<< " for loop "
445445
<< OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
446-
if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
446+
if (!tripCount.has_value() || tripCount->getZExtValue() > 1)
447447
return failure();
448448

449449
if (*tripCount == 0) {

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,12 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
397397
return UnrolledLoopInfo{forOp, std::nullopt};
398398
}
399399

400+
// TODO(#178506): This may overflow for large trip counts. Should use
401+
// uint64_t.
400402
int64_t tripCountEvenMultiple =
401-
constTripCount->getSExtValue() -
402-
(constTripCount->getSExtValue() % unrollFactor);
403+
constTripCount->getZExtValue() -
404+
(constTripCount->getZExtValue() % unrollFactor);
405+
// TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
403406
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
404407
int64_t stepUnrolledCst = stepCst * unrollFactor;
405408

@@ -501,9 +504,9 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
501504
const APInt &tripCount = *mayBeConstantTripCount;
502505
if (tripCount.isZero())
503506
return success();
504-
if (tripCount.getSExtValue() == 1)
507+
if (tripCount.getZExtValue() == 1)
505508
return forOp.promoteIfSingleIteration(rewriter);
506-
return loopUnrollByFactor(forOp, tripCount.getSExtValue());
509+
return loopUnrollByFactor(forOp, tripCount.getZExtValue());
507510
}
508511

509512
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -554,7 +557,7 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
554557
"trip "
555558
"count";
556559
unrollJamFactor = tripCount->getZExtValue();
557-
} else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
560+
} else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
558561
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
559562
"multiple of unroll jam factor";
560563
return failure();
@@ -1567,13 +1570,16 @@ mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
15671570
std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
15681571
if (!loBnds || !upBnds || !steps)
15691572
return {};
1573+
// TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
1574+
// for trip counts.
15701575
llvm::SmallVector<int64_t> tripCounts;
15711576
for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1577+
// TODO(#178506): Signedness is not handled correctly here.
15721578
std::optional<llvm::APInt> numIter = constantTripCount(
15731579
lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
15741580
if (!numIter)
15751581
return {};
1576-
tripCounts.push_back(numIter->getSExtValue());
1582+
tripCounts.push_back(numIter->getZExtValue());
15771583
}
15781584
return tripCounts;
15791585
}

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,6 @@ std::optional<APInt> constantTripCount(
337337
// case applies, so the static trip count is unknown.
338338
return std::nullopt;
339339
}
340-
if (stepCst.isNegative())
341-
return APInt(bitwidth, 0);
342340
}
343341

344342
if (isIndex) {
@@ -392,6 +390,14 @@ std::optional<APInt> constantTripCount(
392390
return std::nullopt;
393391
}
394392
auto &stepCst = maybeStepCst->first;
393+
// For signed loops, a negative step size could indicate an infinite number of
394+
// iterations.
395+
if (isSigned && stepCst.isSignBitSet()) {
396+
LDBG() << "constantTripCount is infinite because step is negative";
397+
return std::nullopt;
398+
}
399+
400+
// Compute the quotient and remainder with a signed or unsigned division, matching the signedness of the loop.
395401
llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst);
396402
llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst);
397403
if (!remainder.isZero())

mlir/test/Dialect/SCF/trip_count.mlir

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,4 +699,105 @@ func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i3
699699
scf.yield %arg0 : i32
700700
}
701701
return %1 : i32
702+
}
703+
704+
// -----
705+
706+
// CHECK-LABEL:func.func @trip_count_i8_unsigned_full_range(
707+
func.func @trip_count_i8_unsigned_full_range(%a : i32, %b : i32) -> i32 {
708+
%c0 = arith.constant 0 : i8
709+
%c255 = arith.constant 255 : i8
710+
%c1 = arith.constant 1 : i8
711+
712+
// Unsigned i8 from 0 to 255: trip count is 255
713+
// Trip counts are returned in their natural bitwidth and printed as signed.
714+
// 255 in i8 is represented as -1 when printed in signed format.
715+
// CHECK: "test.trip-count" = -1 : i8
716+
%r = scf.for unsigned %i = %c0 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 {
717+
scf.yield %b : i32
718+
}
719+
return %r : i32
720+
}
721+
722+
// -----
723+
724+
// CHECK-LABEL:func.func @trip_count_i8_unsigned_partial_range(
725+
func.func @trip_count_i8_unsigned_partial_range(%a : i32, %b : i32) -> i32 {
726+
%c0 = arith.constant 0 : i8
727+
%c200 = arith.constant 200 : i8
728+
%c1 = arith.constant 1 : i8
729+
730+
// Unsigned i8 from 0 to 200: trip count is 200
731+
// 200 in i8 is represented as -56 when printed in signed format.
732+
// CHECK: "test.trip-count" = -56 : i8
733+
%r = scf.for unsigned %i = %c0 to %c200 step %c1 iter_args(%0 = %a) -> i32 : i8 {
734+
scf.yield %b : i32
735+
}
736+
return %r : i32
737+
}
738+
739+
// -----
740+
741+
// CHECK-LABEL:func.func @trip_count_i8_unsigned_high_range(
742+
func.func @trip_count_i8_unsigned_high_range(%a : i32, %b : i32) -> i32 {
743+
%c200 = arith.constant 200 : i8
744+
%c255 = arith.constant 255 : i8
745+
%c1 = arith.constant 1 : i8
746+
747+
// Unsigned i8 from 200 to 255: trip count is 55
748+
// CHECK: "test.trip-count" = 55 : i8
749+
%r = scf.for unsigned %i = %c200 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 {
750+
scf.yield %b : i32
751+
}
752+
return %r : i32
753+
}
754+
755+
// -----
756+
757+
// CHECK-LABEL:func.func @trip_count_i8_signed_crossing_zero(
758+
func.func @trip_count_i8_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
759+
%c-128 = arith.constant -128 : i32
760+
%c127 = arith.constant 127 : i32
761+
%c1 = arith.constant 1 : i32
762+
763+
// Signed i32 from -128 to 127, crossing zero
764+
// CHECK: "test.trip-count" = 255
765+
%r = scf.for %i = %c-128 to %c127 step %c1 iter_args(%0 = %a) -> i32 : i32 {
766+
scf.yield %b : i32
767+
}
768+
return %r : i32
769+
}
770+
771+
// -----
772+
773+
// CHECK-LABEL:func.func @trip_count_i16_unsigned_full_range(
774+
func.func @trip_count_i16_unsigned_full_range(%a : i32, %b : i32) -> i32 {
775+
%c0 = arith.constant 0 : i16
776+
%c65535 = arith.constant 65535 : i16
777+
%c1 = arith.constant 1 : i16
778+
779+
// Unsigned i16 from 0 to 65535: trip count is 65535
780+
// 65535 in i16 is represented as -1 when printed in signed format.
781+
// CHECK: "test.trip-count" = -1 : i16
782+
%r = scf.for unsigned %i = %c0 to %c65535 step %c1 iter_args(%0 = %a) -> i32 : i16 {
783+
scf.yield %b : i32
784+
}
785+
return %r : i32
786+
}
787+
788+
// -----
789+
790+
// CHECK-LABEL:func.func @trip_count_i8_unsigned_step_2(
791+
func.func @trip_count_i8_unsigned_step_2(%a : i32, %b : i32) -> i32 {
792+
%c0 = arith.constant 0 : i8
793+
%c255 = arith.constant 255 : i8
794+
%c2 = arith.constant 2 : i8
795+
796+
// Unsigned i8 from 0 to 255 step 2: trip count is 128 (255/2 rounded up)
797+
// 128 in i8 is represented as -128 when printed in signed format.
798+
// CHECK: "test.trip-count" = -128 : i8
799+
%r = scf.for unsigned %i = %c0 to %c255 step %c2 iter_args(%0 = %a) -> i32 : i8 {
800+
scf.yield %b : i32
801+
}
802+
return %r : i32
702803
}

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct TestSCFForUtilsPass
5050
"test.trip-count",
5151
IntegerAttr::get(IntegerType::get(&getContext(),
5252
tripCount.value().getBitWidth()),
53-
tripCount.value().getSExtValue()));
53+
tripCount.value().getZExtValue()));
5454
else
5555
loopOp->setDiscardableAttr("test.trip-count",
5656
StringAttr::get(&getContext(), "none"));

0 commit comments

Comments
 (0)