
Commit 463ea28

[InstCombine] Fold comparison of integers by parts
Let's say you represent (i32, i32) as an i64 from which the parts are extracted with lshr/trunc. Then, if you compare two tuples by parts, you get something like A[0] == B[0] && A[1] == B[1], except that the part extraction happens via lshr/trunc rather than a narrow load or similar.

The fold implemented here reduces such equality comparisons by converting them into a comparison on a larger part of the integer (which might be the whole integer). It handles both the "and of eq" and the conjugated "or of ne" case.

I'm being conservative with one-use for now, though this could be relaxed if profitable (the base pattern converts 11 instructions into 5, but there are quite a few variations on how it can play out).

Differential Revision: https://reviews.llvm.org/D101232
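For illustration, here is a hypothetical reduced form of the base "and of eq" pattern (function and value names invented, not taken from the commit's tests), comparing the two i32 halves of i64 values %a and %b:

  define i1 @eq_by_parts(i64 %a, i64 %b) {
    ; low parts: plain trunc, i.e. StartBit = 0
    %a.lo = trunc i64 %a to i32
    %b.lo = trunc i64 %b to i32
    %lo.eq = icmp eq i32 %a.lo, %b.lo
    ; high parts: lshr + trunc, i.e. StartBit = 32
    %a.sh = lshr i64 %a, 32
    %a.hi = trunc i64 %a.sh to i32
    %b.sh = lshr i64 %b, 32
    %b.hi = trunc i64 %b.sh to i32
    %hi.eq = icmp eq i32 %a.hi, %b.hi
    %eq = and i1 %lo.eq, %hi.eq
    ret i1 %eq
  }

The two parts are adjacent and together cover all 64 bits, so the fold reduces the whole chain to a single full-width compare:

  define i1 @eq_by_parts(i64 %a, i64 %b) {
    %eq = icmp eq i64 %a, %b
    ret i1 %eq
  }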

2 files changed: +179, -240 lines

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

+87
@@ -1076,6 +1076,87 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
   return nullptr;
 }
 
+struct IntPart {
+  Value *From;
+  unsigned StartBit;
+  unsigned NumBits;
+};
+
+/// Match an extraction of bits from an integer.
+static Optional<IntPart> matchIntPart(Value *V) {
+  Value *X;
+  if (!match(V, m_OneUse(m_Trunc(m_Value(X)))))
+    return None;
+
+  unsigned NumOriginalBits = X->getType()->getScalarSizeInBits();
+  unsigned NumExtractedBits = V->getType()->getScalarSizeInBits();
+  Value *Y;
+  const APInt *Shift;
+  // For a trunc(lshr Y, Shift) pattern, make sure we're only extracting bits
+  // from Y, not any shifted-in zeroes.
+  if (match(X, m_OneUse(m_LShr(m_Value(Y), m_APInt(Shift)))) &&
+      Shift->ule(NumOriginalBits - NumExtractedBits))
+    return {{Y, (unsigned)Shift->getZExtValue(), NumExtractedBits}};
+  return {{X, 0, NumExtractedBits}};
+}
+
+/// Materialize an extraction of bits from an integer in IR.
+static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
+  Value *V = P.From;
+  if (P.StartBit)
+    V = Builder.CreateLShr(V, P.StartBit);
+  Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
+  if (TruncTy != V->getType())
+    V = Builder.CreateTrunc(V, TruncTy);
+  return V;
+}
+
+/// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01
+/// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01
+/// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer.
+static Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
+                            InstCombiner::BuilderTy &Builder) {
+  if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
+    return nullptr;
+
+  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
+  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
+    return nullptr;
+
+  Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
+  Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
+  Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
+  Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+  if (!L0 || !R0 || !L1 || !R1)
+    return nullptr;
+
+  // Make sure the LHS/RHS compare a part of the same value, possibly after
+  // an operand swap.
+  if (L0->From != L1->From || R0->From != R1->From) {
+    if (L0->From != R1->From || R0->From != L1->From)
+      return nullptr;
+    std::swap(L1, R1);
+  }
+
+  // Make sure the extracted parts are adjacent, canonicalizing to L0/R0 being
+  // the low part and L1/R1 being the high part.
+  if (L0->StartBit + L0->NumBits != L1->StartBit ||
+      R0->StartBit + R0->NumBits != R1->StartBit) {
+    if (L1->StartBit + L1->NumBits != L0->StartBit ||
+        R1->StartBit + R1->NumBits != R0->StartBit)
+      return nullptr;
+    std::swap(L0, L1);
+    std::swap(R0, R1);
+  }
+
+  // We can simplify to a comparison of these larger parts of the integers.
+  IntPart L = {L0->From, L0->StartBit, L0->NumBits + L1->NumBits};
+  IntPart R = {R0->From, R0->StartBit, R0->NumBits + R1->NumBits};
+  Value *LValue = extractIntPart(L, Builder);
+  Value *RValue = extractIntPart(R, Builder);
+  return Builder.CreateICmp(Pred, LValue, RValue);
+}
+
 /// Reduce logic-of-compares with equality to a constant by substituting a
 /// common operand with the constant. Callers are expected to call this with
 /// Cmp0/Cmp1 switched to handle logic op commutativity.
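
To make the helpers concrete: on the pattern in the example above, matchIntPart returns {%a, 0, 32} for %a.lo and {%a, 32, 32} for %a.hi, the adjacency check merges them into {%a, 0, 64}, and extractIntPart then emits no instructions at all because the combined part is the whole integer. When the combined part is narrower, extractIntPart re-emits the lshr/trunc for the merged range. A hypothetical sketch (names invented) with two adjacent i8 parts at bits 8 and 16:

  define i1 @eq_middle_parts(i64 %a, i64 %b) {
    %a.sh0 = lshr i64 %a, 8
    %a.p0 = trunc i64 %a.sh0 to i8    ; IntPart {%a, 8, 8}
    %b.sh0 = lshr i64 %b, 8
    %b.p0 = trunc i64 %b.sh0 to i8    ; IntPart {%b, 8, 8}
    %eq0 = icmp eq i8 %a.p0, %b.p0
    %a.sh1 = lshr i64 %a, 16
    %a.p1 = trunc i64 %a.sh1 to i8    ; IntPart {%a, 16, 8}
    %b.sh1 = lshr i64 %b, 16
    %b.p1 = trunc i64 %b.sh1 to i8    ; IntPart {%b, 16, 8}
    %eq1 = icmp eq i8 %a.p1, %b.p1
    %and = and i1 %eq0, %eq1
    ret i1 %and
  }

folds to a compare of the merged 16-bit parts {%a, 8, 16} and {%b, 8, 16}:

  define i1 @eq_middle_parts(i64 %a, i64 %b) {
    %a.sh = lshr i64 %a, 8
    %a.p = trunc i64 %a.sh to i16
    %b.sh = lshr i64 %b, 8
    %b.p = trunc i64 %b.sh to i16
    %eq = icmp eq i16 %a.p, %b.p
    ret i1 %eq
  }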
@@ -1181,6 +1262,9 @@ Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
           foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder))
     return X;
 
+  if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true, Builder))
+    return X;
+
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
 
@@ -2411,6 +2495,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
           foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder))
     return X;
 
+  if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false, Builder))
+    return X;
+
   // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
   // TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors.
   if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) &&
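
The hook in foldOrOfICmps handles the conjugated "or of ne" form of the same pattern, e.g. (hypothetical, names invented):

  ; parts extracted with lshr/trunc as in the first example above
  %lo.ne = icmp ne i32 %a.lo, %b.lo
  %hi.ne = icmp ne i32 %a.hi, %b.hi
  %ne = or i1 %lo.ne, %hi.ne

which folds to a single "icmp ne i64 %a, %b".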
