diff --git a/llvm/include/llvm/IR/ConstantRangeList.h b/llvm/include/llvm/IR/ConstantRangeList.h index 46edaff19e73f..44d1daebe49e4 100644 --- a/llvm/include/llvm/IR/ConstantRangeList.h +++ b/llvm/include/llvm/IR/ConstantRangeList.h @@ -72,6 +72,8 @@ class [[nodiscard]] ConstantRangeList { APInt(64, Upper, /*isSigned=*/true))); } + void subtract(const ConstantRange &SubRange); + /// Return the range list that results from the union of this /// ConstantRangeList with another ConstantRangeList, "CRL". ConstantRangeList unionWith(const ConstantRangeList &CRL) const; diff --git a/llvm/lib/IR/ConstantRangeList.cpp b/llvm/lib/IR/ConstantRangeList.cpp index 0373524a09f10..0856f79bb9191 100644 --- a/llvm/lib/IR/ConstantRangeList.cpp +++ b/llvm/lib/IR/ConstantRangeList.cpp @@ -81,6 +81,65 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) { } } +void ConstantRangeList::subtract(const ConstantRange &SubRange) { + if (SubRange.isEmptySet() || empty()) + return; + assert(!SubRange.isFullSet() && "Do not support full set"); + assert(SubRange.getLower().slt(SubRange.getUpper())); + assert(getBitWidth() == SubRange.getBitWidth()); + // Handle common cases. + if (Ranges.back().getUpper().sle(SubRange.getLower())) + return; + if (SubRange.getUpper().sle(Ranges.front().getLower())) + return; + + SmallVector Result; + auto AppendRangeIfNonEmpty = [&Result](APInt Start, APInt End) { + if (Start.slt(End)) + Result.push_back(ConstantRange(Start, End)); + }; + for (auto &Range : Ranges) { + if (SubRange.getUpper().sle(Range.getLower()) || + Range.getUpper().sle(SubRange.getLower())) { + // "Range" and "SubRange" do not overlap. + // L---U : Range + // L---U : SubRange (Case1) + // L---U : SubRange (Case2) + Result.push_back(Range); + } else if (Range.getLower().sle(SubRange.getLower()) && + SubRange.getUpper().sle(Range.getUpper())) { + // "Range" contains "SubRange". + // L---U : Range + // L-U : SubRange + // Note that ConstantRange::contains(ConstantRange) checks unsigned, + // but we need signed checking here. + AppendRangeIfNonEmpty(Range.getLower(), SubRange.getLower()); + AppendRangeIfNonEmpty(SubRange.getUpper(), Range.getUpper()); + } else if (SubRange.getLower().sle(Range.getLower()) && + Range.getUpper().sle(SubRange.getUpper())) { + // "SubRange" contains "Range". + // L-U : Range + // L---U : SubRange + continue; + } else if (Range.getLower().sge(SubRange.getLower()) && + Range.getLower().sle(SubRange.getUpper())) { + // "Range" and "SubRange" overlap at the left. + // L---U : Range + // L---U : SubRange + AppendRangeIfNonEmpty(SubRange.getUpper(), Range.getUpper()); + } else { + // "Range" and "SubRange" overlap at the right. + // L---U : Range + // L---U : SubRange + assert(SubRange.getLower().sge(Range.getLower()) && + SubRange.getLower().sle(Range.getUpper())); + AppendRangeIfNonEmpty(Range.getLower(), SubRange.getLower()); + } + } + + Ranges = Result; +} + ConstantRangeList ConstantRangeList::unionWith(const ConstantRangeList &CRL) const { assert(getBitWidth() == CRL.getBitWidth() && diff --git a/llvm/unittests/IR/ConstantRangeListTest.cpp b/llvm/unittests/IR/ConstantRangeListTest.cpp index b679dd3a33d5d..d00e0a8ff2a97 100644 --- a/llvm/unittests/IR/ConstantRangeListTest.cpp +++ b/llvm/unittests/IR/ConstantRangeListTest.cpp @@ -101,6 +101,58 @@ ConstantRangeList GetCRL(ArrayRef> Pairs) { return ConstantRangeList(Ranges); } +TEST_F(ConstantRangeListTest, Subtract) { + APInt AP0 = APInt(64, 0, /*isSigned=*/true); + APInt AP2 = APInt(64, 2, /*isSigned=*/true); + APInt AP3 = APInt(64, 3, /*isSigned=*/true); + APInt AP4 = APInt(64, 4, /*isSigned=*/true); + APInt AP8 = APInt(64, 8, /*isSigned=*/true); + APInt AP10 = APInt(64, 10, /*isSigned=*/true); + APInt AP11 = APInt(64, 11, /*isSigned=*/true); + APInt AP12 = APInt(64, 12, /*isSigned=*/true); + ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}}); + + // Execute ConstantRangeList::subtract(ConstantRange) and check the result + // is expected. Pass "CRL" by value so that subtract() does not affect the + // argument in caller. + auto SubtractAndCheck = [](ConstantRangeList CRL, + const std::pair &Range, + const ConstantRangeList &ExpectedCRL) { + CRL.subtract(ConstantRange(APInt(64, Range.first, /*isSigned=*/true), + APInt(64, Range.second, /*isSigned=*/true))); + EXPECT_EQ(CRL, ExpectedCRL); + }; + + // No overlap + SubtractAndCheck(CRL, {-4, 0}, CRL); + SubtractAndCheck(CRL, {4, 8}, CRL); + SubtractAndCheck(CRL, {12, 16}, CRL); + + // Overlap (left, right, or both) + SubtractAndCheck(CRL, {-4, 2}, GetCRL({{AP2, AP4}, {AP8, AP12}})); + SubtractAndCheck(CRL, {-4, 4}, GetCRL({{AP8, AP12}})); + SubtractAndCheck(CRL, {-4, 8}, GetCRL({{AP8, AP12}})); + SubtractAndCheck(CRL, {0, 2}, GetCRL({{AP2, AP4}, {AP8, AP12}})); + SubtractAndCheck(CRL, {0, 4}, GetCRL({{AP8, AP12}})); + SubtractAndCheck(CRL, {0, 8}, GetCRL({{AP8, AP12}})); + SubtractAndCheck(CRL, {10, 12}, GetCRL({{AP0, AP4}, {AP8, AP10}})); + SubtractAndCheck(CRL, {8, 12}, GetCRL({{AP0, AP4}})); + SubtractAndCheck(CRL, {6, 12}, GetCRL({{AP0, AP4}})); + SubtractAndCheck(CRL, {10, 16}, GetCRL({{AP0, AP4}, {AP8, AP10}})); + SubtractAndCheck(CRL, {8, 16}, GetCRL({{AP0, AP4}})); + SubtractAndCheck(CRL, {6, 16}, GetCRL({{AP0, AP4}})); + SubtractAndCheck(CRL, {2, 10}, GetCRL({{AP0, AP2}, {AP10, AP12}})); + + // Subset + SubtractAndCheck(CRL, {2, 3}, GetCRL({{AP0, AP2}, {AP3, AP4}, {AP8, AP12}})); + SubtractAndCheck(CRL, {10, 11}, + GetCRL({{AP0, AP4}, {AP8, AP10}, {AP11, AP12}})); + + // Superset + SubtractAndCheck(CRL, {0, 12}, GetCRL({})); + SubtractAndCheck(CRL, {-4, 16}, GetCRL({})); +} + TEST_F(ConstantRangeListTest, Union) { APInt APN4 = APInt(64, -4, /*isSigned=*/true); APInt APN2 = APInt(64, -2, /*isSigned=*/true);