From 349ab902ffa3d07ef33ea78069e62cdfc4ce9ea8 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Tue, 24 Sep 2024 15:53:44 +0200 Subject: [PATCH 1/4] [InstCombine] Decompose more icmps into masks Extend decomposeBitTestICmp() to handle cases where the resulting comparison is of the form `icmp (X & Mask) pred Cmp` with non-zero `Cmp`. Add a flag to allow code to opt-in to this behavior and use it in the "log op of icmp" fold infrastructure. This addresses regressions from #97289. Proofs: https://alive2.llvm.org/ce/z/hUhdbU --- llvm/include/llvm/Analysis/CmpInstAnalysis.h | 9 ++- llvm/lib/Analysis/CmpInstAnalysis.cpp | 60 +++++++++++++++---- .../InstCombine/InstCombineAndOrXor.cpp | 5 +- .../InstCombine/InstCombineCompares.cpp | 21 +------ .../Transforms/InstCombine/and-or-icmps.ll | 18 ++---- 5 files changed, 66 insertions(+), 47 deletions(-) diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h index 406dacd930605..79b325c620f60 100644 --- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h +++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h @@ -92,18 +92,21 @@ namespace llvm { Constant *getPredForFCmpCode(unsigned Code, Type *OpTy, CmpInst::Predicate &Pred); - /// Represents the operation icmp (X & Mask) pred 0, where pred can only be + /// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be /// eq or ne. struct DecomposedBitTest { Value *X; CmpInst::Predicate Pred; APInt Mask; + APInt Cmp; }; - /// Decompose an icmp into the form ((X & Mask) pred 0) if possible. + /// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible. + /// Unless \p AllowNonZeroCmp is true, Cmp will always be 0. 
std::optional<DecomposedBitTest> decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThroughTrunc = true); + bool LookThroughTrunc = true, + bool AllowNonZeroCmp = false); } // end namespace llvm diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index ad111559b0d85..2fe3aebe8190c 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, std::optional<DecomposedBitTest> llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThruTrunc) { + bool LookThruTrunc, bool AllowNonZeroCmp) { using namespace PatternMatch; const APInt *OrigC; @@ -100,22 +100,57 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, switch (Pred) { default: llvm_unreachable("Unexpected predicate"); - case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLT: { // X < 0 is equivalent to (X & SignMask) != 0. - if (!C.isZero()) - return std::nullopt; - Result.Mask = APInt::getSignMask(C.getBitWidth()); - Result.Pred = ICmpInst::ICMP_NE; - break; + if (C.isZero()) { + Result.Mask = APInt::getSignMask(C.getBitWidth()); + Result.Cmp = APInt::getZero(C.getBitWidth()); + Result.Pred = ICmpInst::ICMP_NE; + break; + } + + APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth()); + if (FlippedSign.isPowerOf2()) { + // X s< 10000100 is equivalent to (X & 11111100 == 10000000) + Result.Mask = -FlippedSign; + Result.Cmp = APInt::getSignMask(C.getBitWidth()); + Result.Pred = ICmpInst::ICMP_EQ; + break; + } + + if (FlippedSign.isNegatedPowerOf2()) { + // X s< 01111100 is equivalent to (X & 11111100 != 01111100) + Result.Mask = FlippedSign; + Result.Cmp = C; + Result.Pred = ICmpInst::ICMP_NE; + break; + } + + return std::nullopt; + } case ICmpInst::ICMP_ULT: // X getType()->getScalarSizeInBits()); + Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits()); } else { Result.X = LHS; } diff --git 
a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index e8c0b00661654..52eda8cdee46b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) { // Adapts the external decomposeBitTestICmp for local use. static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, Value *&X, Value *&Y, Value *&Z) { - auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred); + auto Res = llvm::decomposeBitTestICmp( + LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true); if (!Res) return false; Pred = Res->Pred; X = Res->X; Y = ConstantInt::get(X->getType(), Res->Mask); - Z = ConstantInt::get(X->getType(), 0); + Z = ConstantInt::get(X->getType(), Res->Cmp); return true; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d0aa63ef06ba8..e9e3dd5124a92 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5934,29 +5934,14 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { // This matches patterns corresponding to tests of the signbit as well as: // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) 
- if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) { + if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true, + /*AllowNonZeroCmp=*/true)) { Value *And = Builder.CreateAnd(Res->X, Res->Mask); - Constant *Zero = ConstantInt::getNullValue(Res->X->getType()); + Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp); return new ICmpInst(Res->Pred, And, Zero); } unsigned SrcBits = X->getType()->getScalarSizeInBits(); - if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) { - // If C is a negative power-of-2 (high-bit mask): - // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?) - Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits)); - Value *And = Builder.CreateAnd(X, MaskC); - return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC); - } - - if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) { - // If C is not-of-power-of-2 (one clear bit): - // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?) 
- Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits)); - Value *And = Builder.CreateAnd(X, MaskC); - return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); - } - if (auto *II = dyn_cast<IntrinsicInst>(X)) { if (II->getIntrinsicID() == Intrinsic::cttz || II->getIntrinsicID() == Intrinsic::ctlz) { diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll index 26f708dc787c7..ad28ad980de5b 100644 --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -3335,10 +3335,8 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) { define i1 @and_slt_to_mask(i8 %x) { ; CHECK-LABEL: @and_slt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp slt i8 %x, -124 @@ -3365,10 +3363,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) { define i1 @and_sgt_to_mask(i8 %x) { ; CHECK-LABEL: @and_sgt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp sgt i8 %x, 123 @@ -3395,10 +3391,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) { define i1 @and_ugt_to_mask(i8 %x) { ; CHECK-LABEL: @and_ugt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 
[[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp ugt i8 %x, -5 From e97dceb848aeaecd07ae461bba586d54f0ead364 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Wed, 2 Oct 2024 18:38:38 +0200 Subject: [PATCH 2/4] Rename Cmp -> C --- llvm/include/llvm/Analysis/CmpInstAnalysis.h | 10 +++++----- llvm/lib/Analysis/CmpInstAnalysis.cpp | 16 ++++++++-------- .../InstCombine/InstCombineAndOrXor.cpp | 4 ++-- .../InstCombine/InstCombineCompares.cpp | 4 ++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h index 79b325c620f60..c7862a6d39d07 100644 --- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h +++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h @@ -92,21 +92,21 @@ namespace llvm { Constant *getPredForFCmpCode(unsigned Code, Type *OpTy, CmpInst::Predicate &Pred); - /// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be + /// Represents the operation icmp (X & Mask) pred C, where pred can only be /// eq or ne. struct DecomposedBitTest { Value *X; CmpInst::Predicate Pred; APInt Mask; - APInt Cmp; + APInt C; }; - /// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible. - /// Unless \p AllowNonZeroCmp is true, Cmp will always be 0. + /// Decompose an icmp into the form ((X & Mask) pred C) if possible. + /// Unless \p AllowNonZeroC is true, C will always be 0. 
std::optional<DecomposedBitTest> decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, bool LookThroughTrunc = true, - bool AllowNonZeroCmp = false); + bool AllowNonZeroC = false); } // end namespace llvm diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index 2fe3aebe8190c..2580ea7e97248 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, std::optional<DecomposedBitTest> llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThruTrunc, bool AllowNonZeroCmp) { + bool LookThruTrunc, bool AllowNonZeroC) { using namespace PatternMatch; const APInt *OrigC; @@ -104,7 +104,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, // X < 0 is equivalent to (X & SignMask) != 0. if (C.isZero()) { Result.Mask = APInt::getSignMask(C.getBitWidth()); - Result.Cmp = APInt::getZero(C.getBitWidth()); + Result.C = APInt::getZero(C.getBitWidth()); Result.Pred = ICmpInst::ICMP_NE; break; } @@ -113,7 +113,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, if (FlippedSign.isPowerOf2()) { // X s< 10000100 is equivalent to (X & 11111100 == 10000000) Result.Mask = -FlippedSign; - Result.Cmp = APInt::getSignMask(C.getBitWidth()); + Result.C = APInt::getSignMask(C.getBitWidth()); Result.Pred = ICmpInst::ICMP_EQ; break; } @@ -121,7 +121,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, if (FlippedSign.isNegatedPowerOf2()) { // X s< 01111100 is equivalent to (X & 11111100 != 01111100) Result.Mask = FlippedSign; - Result.Cmp = C; + Result.C = C; Result.Pred = ICmpInst::ICMP_NE; break; } @@ -132,7 +132,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, // X getType()->getScalarSizeInBits()); - Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits()); + Result.C = 
Result.C.zext(X->getType()->getScalarSizeInBits()); } else { Result.X = LHS; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 52eda8cdee46b..688601a8ffa54 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -182,14 +182,14 @@ static unsigned conjugateICmpMask(unsigned Mask) { static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, Value *&X, Value *&Y, Value *&Z) { auto Res = llvm::decomposeBitTestICmp( - LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true); + LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true); if (!Res) return false; Pred = Res->Pred; X = Res->X; Y = ConstantInt::get(X->getType(), Res->Mask); - Z = ConstantInt::get(X->getType(), Res->Cmp); + Z = ConstantInt::get(X->getType(), Res->C); return true; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index e9e3dd5124a92..91bb87bbb2f63 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5935,9 +5935,9 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) 
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true, - /*AllowNonZeroCmp=*/true)) { + /*AllowNonZeroC=*/true)) { Value *And = Builder.CreateAnd(Res->X, Res->Mask); - Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp); + Constant *Zero = ConstantInt::get(Res->X->getType(), Res->C); return new ICmpInst(Res->Pred, And, Zero); } From f318e557f8f75e4266cc35ef676dea0b549abab2 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Fri, 4 Oct 2024 09:58:03 +0200 Subject: [PATCH 3/4] Update variable name and comment --- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 91bb87bbb2f63..12eec99e7d22b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5932,13 +5932,12 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { return nullptr; // This matches patterns corresponding to tests of the signbit as well as: - // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) - // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) 
+ // (trunc X) pred C2 --> (X & Mask) == C if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true, /*AllowNonZeroC=*/true)) { Value *And = Builder.CreateAnd(Res->X, Res->Mask); - Constant *Zero = ConstantInt::get(Res->X->getType(), Res->C); - return new ICmpInst(Res->Pred, And, Zero); + Constant *C = ConstantInt::get(Res->X->getType(), Res->C); + return new ICmpInst(Res->Pred, And, C); } unsigned SrcBits = X->getType()->getScalarSizeInBits(); From 3d948de60329325c347f9afcbccf7213f04fbc44 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Fri, 4 Oct 2024 10:11:49 +0200 Subject: [PATCH 4/4] Update test after rebase --- llvm/test/Transforms/InstCombine/and-or-icmps.ll | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll index ad28ad980de5b..9ddc628aa4769 100644 --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -3335,8 +3335,7 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) { define i1 @and_slt_to_mask(i8 %x) { ; CHECK-LABEL: @and_slt_to_mask( -; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 -; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128 +; CHECK-NEXT: [[AND2:%.*]] = icmp slt i8 [[X:%.*]], -126 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp slt i8 %x, -124