Skip to content

Commit 1001f90

Browse files
committed
[InstCombine] Optimize and of icmps with power-of-2 and contiguous masks
Add an InstCombine optimization for expressions of the form:

    (%arg u< C1) & ((%arg & C2) != C2) -> %arg u< C2

where C1 is a power of 2 and C2 is a contiguous mask starting 1 bit below C1. This commit resolves GitHub missed-optimization issue #54856.

Validation of scalar tests:
- https://alive2.llvm.org/ce/z/JfKjiU
- https://alive2.llvm.org/ce/z/AruHY_
- https://alive2.llvm.org/ce/z/JAiR6t
- https://alive2.llvm.org/ce/z/S2X2e5
- https://alive2.llvm.org/ce/z/4cycdE
- https://alive2.llvm.org/ce/z/NcDiLP

Validation of vector tests:
- https://alive2.llvm.org/ce/z/ABY6tE
- https://alive2.llvm.org/ce/z/BTJi3s
- https://alive2.llvm.org/ce/z/3BKWpu
- https://alive2.llvm.org/ce/z/RrAbkj
- https://alive2.llvm.org/ce/z/nM6fsN

Reviewed By: goldstein.w.n

Differential Revision: https://reviews.llvm.org/D125717
1 parent ea868d5 commit 1001f90

File tree

3 files changed

+137
-96
lines changed

3 files changed

+137
-96
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,14 @@ inline cst_pred_ty<is_any_apint> m_AnyIntegralConstant() {
445445
return cst_pred_ty<is_any_apint>();
446446
}
447447

448+
/// Constant predicate that accepts an APInt exactly when
/// APInt::isShiftedMask() is true for it. It is the predicate type used by
/// the m_ShiftedMask() matcher.
struct is_shifted_mask {
449+
// Pure delegation to APInt; no additional conditions are applied here.
bool isValue(const APInt &C) { return C.isShiftedMask(); }
450+
};
451+
452+
/// Build a matcher for integer constants that satisfy is_shifted_mask, i.e.
/// constants for which APInt::isShiftedMask() returns true.
/// NOTE(review): presumably, like the other cst_pred_ty based matchers, this
/// also accepts vector constants element-wise -- confirm against cst_pred_ty.
inline cst_pred_ty<is_shifted_mask> m_ShiftedMask() {
453+
return cst_pred_ty<is_shifted_mask>();
454+
}
455+
448456
/// Constant predicate that accepts an APInt exactly when APInt::isAllOnes()
/// is true for it.
struct is_all_ones {
449457
// Pure delegation to APInt; no additional conditions are applied here.
bool isValue(const APInt &C) { return C.isAllOnes(); }
450458
};

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,108 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
955955
return nullptr;
956956
}
957957

958+
/// Try to fold (icmp(A & B) == 0) & (icmp(A & D) != E) into (icmp A u< D) iff
959+
/// B is a contiguous set of ones starting from the most significant bit
960+
/// (negative power of 2), D and E are equal, and D is a contiguous set of ones
961+
/// starting at the most significant zero bit in B. Parameter B supports masking
962+
/// using undef/poison in either scalar or vector values.
///
/// PredL is the predicate of the (A & B) compare and PredR the predicate of
/// the (A & D) compare; both must be equality predicates (asserted), and only
/// the PredL == ICMP_EQ / PredR == ICMP_NE pairing is folded.
///
/// Returns the `icmp ult A, D` created through Builder on success, or nullptr
/// when any precondition is not met (nothing is created in that case).
963+
static Value *foldNegativePower2AndShiftedMask(
964+
Value *A, Value *B, Value *D, Value *E, ICmpInst::Predicate PredL,
965+
ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) {
966+
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
967+
"Expected equality predicates for masked type of icmps.");
968+
// Of the four equality-predicate combinations only (EQ, NE) is handled.
if (PredL != ICmpInst::ICMP_EQ || PredR != ICmpInst::ICMP_NE)
969+
return nullptr;
970+
971+
// Structural filter on the three constants before any bit counting is done.
if (!match(B, m_NegatedPower2()) || !match(D, m_ShiftedMask()) ||
972+
!match(E, m_ShiftedMask()))
973+
return nullptr;
974+
975+
// Test scalar arguments for conversion. B has been validated earlier to be a
976+
// negative power of two and thus is guaranteed to have one or more contiguous
977+
// ones starting from the MSB followed by zero or more contiguous zeros. D has
978+
// been validated earlier to be a shifted set of one or more contiguous ones.
979+
// In order to match, B leading ones and D leading zeros should be equal. The
980+
// predicate that B be a negative power of 2 prevents the condition of there
981+
// ever being zero leading ones. Thus 0 == 0 cannot occur. The predicate that
982+
// D always be a shifted mask prevents the condition of D equaling 0. This
983+
// prevents matching the condition where B contains the maximum number of
984+
// leading one bits (-1) and D contains the maximum number of leading zero
985+
// bits (0).
//
// The && chain is order-sensitive: BCst, DCst and ECst are uninitialized
// until their respective match() succeeds, and they are only dereferenced to
// the right of those matches.
// NOTE(review): isa<UndefValue>(B) is evaluated only after
// match(B, m_APIntAllowUndef(BCst)) has already succeeded; it looks like a
// plain undef scalar/lane would fail that match first, which would make the
// isa<> arm unreachable -- confirm against m_APIntAllowUndef.
986+
auto isReducible = [](const Value *B, const Value *D, const Value *E) {
987+
const APInt *BCst, *DCst, *ECst;
988+
return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) &&
989+
match(E, m_APInt(ECst)) && *DCst == *ECst &&
990+
(isa<UndefValue>(B) ||
991+
(BCst->countLeadingOnes() == DCst->countLeadingZeros()));
992+
};
993+
994+
// Test vector type arguments for conversion.
995+
if (const auto *BVTy = dyn_cast<VectorType>(B->getType())) {
996+
// Vector types that are not FixedVectorType, and operands that are not
// Constants, are rejected by the null check below.
const auto *BFVTy = dyn_cast<FixedVectorType>(BVTy);
997+
const auto *BConst = dyn_cast<Constant>(B);
998+
const auto *DConst = dyn_cast<Constant>(D);
999+
const auto *EConst = dyn_cast<Constant>(E);
1000+
1001+
if (!BFVTy || !BConst || !DConst || !EConst)
1002+
return nullptr;
1003+
1004+
// Every lane has to be reducible on its own; the lane count is taken from
// B's type and the same index is used for D and E.
for (unsigned I = 0; I != BFVTy->getNumElements(); ++I) {
1005+
const auto *BElt = BConst->getAggregateElement(I);
1006+
const auto *DElt = DConst->getAggregateElement(I);
1007+
const auto *EElt = EConst->getAggregateElement(I);
1008+
1009+
// A null element pointer for any operand rejects the whole fold.
if (!BElt || !DElt || !EElt)
1010+
return nullptr;
1011+
if (!isReducible(BElt, DElt, EElt))
1012+
return nullptr;
1013+
}
1014+
} else {
1015+
// Test scalar type arguments for conversion.
1016+
if (!isReducible(B, D, E))
1017+
return nullptr;
1018+
}
1019+
// All checks passed: replace the pair of compares with one unsigned
// less-than of A against D.
return Builder.CreateICmp(ICmpInst::ICMP_ULT, A, D);
1020+
}
1021+
1022+
/// Try to fold ((icmp X u< P) & (icmp(X & M) != M)) or ((icmp X s> -1) &
1023+
/// (icmp(X & M) != M)) into (icmp X u< M). Where P is a power of 2, M < P, and
1024+
/// M is a contiguous shifted mask starting at the right most significant zero
1025+
/// bit in P. SGT is supported as when P is the largest representable power of
1026+
/// 2, an earlier optimization converts the expression into (icmp X s> -1).
1027+
/// Parameter P supports masking using undef/poison in either scalar or vector
1028+
/// values.
1029+
static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
1030+
bool JoinedByAnd,
1031+
InstCombiner::BuilderTy &Builder) {
1032+
if (!JoinedByAnd)
1033+
return nullptr;
1034+
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
1035+
ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
1036+
CmpPred1 = Cmp1->getPredicate();
1037+
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
1038+
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
1039+
// SignMask) == 0).
1040+
std::optional<std::pair<unsigned, unsigned>> MaskPair =
1041+
getMaskedTypeForICmpPair(A, B, C, D, E, Cmp0, Cmp1, CmpPred0, CmpPred1);
1042+
if (!MaskPair)
1043+
return nullptr;
1044+
1045+
const auto compareBMask = BMask_NotMixed | BMask_NotAllOnes;
1046+
unsigned CmpMask0 = MaskPair->first;
1047+
unsigned CmpMask1 = MaskPair->second;
1048+
if ((CmpMask0 & Mask_AllZeros) && (CmpMask1 == compareBMask)) {
1049+
if (Value *V = foldNegativePower2AndShiftedMask(A, B, D, E, CmpPred0,
1050+
CmpPred1, Builder))
1051+
return V;
1052+
} else if ((CmpMask0 == compareBMask) && (CmpMask1 & Mask_AllZeros)) {
1053+
if (Value *V = foldNegativePower2AndShiftedMask(A, D, B, C, CmpPred1,
1054+
CmpPred0, Builder))
1055+
return V;
1056+
}
1057+
return nullptr;
1058+
}
1059+
9581060
/// Commuted variants are assumed to be handled by calling this function again
9591061
/// with the parameters swapped.
9601062
static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
@@ -2925,6 +3027,9 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
29253027
if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder))
29263028
return V;
29273029

3030+
if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder))
3031+
return V;
3032+
29283033
// TODO: Verify whether this is safe for logical and/or.
29293034
if (!IsLogical) {
29303035
if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder))

0 commit comments

Comments
 (0)