Skip to content

Commit cf9d1c1

Browse files
authored
[SDAG] Simplify SDNodeFlags with bitwise logic (#114061)
This patch allows using enumeration values directly and simplifies the implementation with bitwise logic. It addresses the comment in #113808 (comment).
1 parent 36b7915 commit cf9d1c1

16 files changed

+90
-172
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,7 @@ struct BinaryOpc_match {
533533
if (!Flags.has_value())
534534
return true;
535535

536-
SDNodeFlags TmpFlags = *Flags;
537-
TmpFlags.intersectWith(N->getFlags());
538-
return TmpFlags == *Flags;
536+
return (*Flags & N->getFlags()) == *Flags;
539537
}
540538

541539
return false;
@@ -668,9 +666,7 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
668666
template <typename LHS, typename RHS>
669667
inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
670668
const RHS &R) {
671-
SDNodeFlags Flags;
672-
Flags.setDisjoint(true);
673-
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
669+
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
674670
}
675671

676672
template <typename LHS, typename RHS>
@@ -813,9 +809,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
813809
if (!Flags.has_value())
814810
return true;
815811

816-
SDNodeFlags TmpFlags = *Flags;
817-
TmpFlags.intersectWith(N->getFlags());
818-
return TmpFlags == *Flags;
812+
return (*Flags & N->getFlags()) == *Flags;
819813
}
820814

821815
return false;
@@ -848,9 +842,7 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
848842

849843
template <typename Opnd>
850844
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
851-
SDNodeFlags Flags;
852-
Flags.setNonNeg(true);
853-
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
845+
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
854846
}
855847

856848
template <typename Opnd> inline auto m_SExt(const Opnd &Op) {

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,17 +1064,13 @@ class SelectionDAG {
10641064
/// addressing some offset of an object. i.e. if a load is split into multiple
10651065
/// components, create an add nuw from the base pointer to the offset.
10661066
SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, TypeSize Offset) {
1067-
SDNodeFlags Flags;
1068-
Flags.setNoUnsignedWrap(true);
1069-
return getMemBasePlusOffset(Ptr, Offset, SL, Flags);
1067+
return getMemBasePlusOffset(Ptr, Offset, SL, SDNodeFlags::NoUnsignedWrap);
10701068
}
10711069

10721070
SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, SDValue Offset) {
10731071
// The object itself can't wrap around the address space, so it shouldn't be
10741072
// possible for the adds of the offsets to the split parts to overflow.
1075-
SDNodeFlags Flags;
1076-
Flags.setNoUnsignedWrap(true);
1077-
return getMemBasePlusOffset(Ptr, Offset, SL, Flags);
1073+
return getMemBasePlusOffset(Ptr, Offset, SL, SDNodeFlags::NoUnsignedWrap);
10781074
}
10791075

10801076
/// Return a new CALLSEQ_START node, that starts new call frame, in which

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ struct SDNodeFlags {
391391
None = 0,
392392
NoUnsignedWrap = 1 << 0,
393393
NoSignedWrap = 1 << 1,
394+
NoWrap = NoUnsignedWrap | NoSignedWrap,
394395
Exact = 1 << 2,
395396
Disjoint = 1 << 3,
396397
NonNeg = 1 << 4,
@@ -419,7 +420,7 @@ struct SDNodeFlags {
419420
};
420421

421422
/// Default constructor turns off all optimization flags.
422-
SDNodeFlags() : Flags(0) {}
423+
SDNodeFlags(unsigned Flags = SDNodeFlags::None) : Flags(Flags) {}
423424

424425
/// Propagate the fast-math-flags from an IR FPMathOperator.
425426
void copyFMF(const FPMathOperator &FPMO) {
@@ -467,15 +468,23 @@ struct SDNodeFlags {
467468
bool operator==(const SDNodeFlags &Other) const {
468469
return Flags == Other.Flags;
469470
}
470-
471-
/// Clear any flags in this flag set that aren't also set in Flags. All
472-
/// flags will be cleared if Flags are undefined.
473-
void intersectWith(const SDNodeFlags Flags) { this->Flags &= Flags.Flags; }
471+
void operator&=(const SDNodeFlags &OtherFlags) { Flags &= OtherFlags.Flags; }
472+
void operator|=(const SDNodeFlags &OtherFlags) { Flags |= OtherFlags.Flags; }
474473
};
475474

476475
LLVM_DECLARE_ENUM_AS_BITMASK(decltype(SDNodeFlags::None),
477476
SDNodeFlags::Unpredictable);
478477

478+
inline SDNodeFlags operator|(SDNodeFlags LHS, SDNodeFlags RHS) {
479+
LHS |= RHS;
480+
return LHS;
481+
}
482+
483+
inline SDNodeFlags operator&(SDNodeFlags LHS, SDNodeFlags RHS) {
484+
LHS &= RHS;
485+
return LHS;
486+
}
487+
479488
/// Represents one node in the SelectionDAG.
480489
///
481490
class SDNode : public FoldingSetNode, public ilist_node<SDNode> {
@@ -1013,6 +1022,7 @@ END_TWO_BYTE_PACK()
10131022

10141023
SDNodeFlags getFlags() const { return Flags; }
10151024
void setFlags(SDNodeFlags NewFlags) { Flags = NewFlags; }
1025+
void dropFlags(unsigned Mask) { Flags &= ~Mask; }
10161026

10171027
/// Clear any flags in this node that aren't also set in Flags.
10181028
/// If Flags is not in a defined state then this has no effect.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
12101210
SDNodeFlags NewFlags;
12111211
if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
12121212
Flags.hasNoUnsignedWrap())
1213-
NewFlags.setNoUnsignedWrap(true);
1213+
NewFlags |= SDNodeFlags::NoUnsignedWrap;
12141214

12151215
if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
12161216
// Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
@@ -2892,11 +2892,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
28922892
if (N->getFlags().hasNoUnsignedWrap() &&
28932893
N0->getFlags().hasNoUnsignedWrap() &&
28942894
N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
2895-
Flags.setNoUnsignedWrap(true);
2895+
Flags |= SDNodeFlags::NoUnsignedWrap;
28962896
if (N->getFlags().hasNoSignedWrap() &&
28972897
N0->getFlags().hasNoSignedWrap() &&
28982898
N0.getOperand(0)->getFlags().hasNoSignedWrap())
2899-
Flags.setNoSignedWrap(true);
2899+
Flags |= SDNodeFlags::NoSignedWrap;
29002900
}
29012901
SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
29022902
DAG.getConstant(CM, DL, VT), Flags);
@@ -2920,12 +2920,12 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
29202920
N0->getFlags().hasNoUnsignedWrap() &&
29212921
OMul->getFlags().hasNoUnsignedWrap() &&
29222922
OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
2923-
Flags.setNoUnsignedWrap(true);
2923+
Flags |= SDNodeFlags::NoUnsignedWrap;
29242924
if (N->getFlags().hasNoSignedWrap() &&
29252925
N0->getFlags().hasNoSignedWrap() &&
29262926
OMul->getFlags().hasNoSignedWrap() &&
29272927
OMul.getOperand(0)->getFlags().hasNoSignedWrap())
2928-
Flags.setNoSignedWrap(true);
2928+
Flags |= SDNodeFlags::NoSignedWrap;
29292929
}
29302930
SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
29312931
DAG.getConstant(CM, DL, VT), Flags);
@@ -2987,11 +2987,8 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
29872987

29882988
// fold (a+b) -> (a|b) iff a and b share no bits.
29892989
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2990-
DAG.haveNoCommonBitsSet(N0, N1)) {
2991-
SDNodeFlags Flags;
2992-
Flags.setDisjoint(true);
2993-
return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
2994-
}
2990+
DAG.haveNoCommonBitsSet(N0, N1))
2991+
return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
29952992

29962993
// Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
29972994
if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
@@ -9556,11 +9553,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
95569553

95579554
// fold (a^b) -> (a|b) iff a and b share no bits.
95589555
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
9559-
DAG.haveNoCommonBitsSet(N0, N1)) {
9560-
SDNodeFlags Flags;
9561-
Flags.setDisjoint(true);
9562-
return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
9563-
}
9556+
DAG.haveNoCommonBitsSet(N0, N1))
9557+
return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
95649558

95659559
// look for 'add-like' folds:
95669560
// XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
@@ -10210,7 +10204,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
1021010204
SDNodeFlags Flags;
1021110205
// Preserve the disjoint flag for Or.
1021210206
if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10213-
Flags.setDisjoint(true);
10207+
Flags |= SDNodeFlags::Disjoint;
1021410208
return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
1021510209
}
1021610210
}
@@ -13922,11 +13916,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
1392213916
// fold (sext x) -> (zext x) if the sign bit is known zero.
1392313917
if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
1392413918
(!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
13925-
DAG.SignBitIsZero(N0)) {
13926-
SDNodeFlags Flags;
13927-
Flags.setNonNeg(true);
13928-
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, Flags);
13929-
}
13919+
DAG.SignBitIsZero(N0))
13920+
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, SDNodeFlags::NonNeg);
1393013921

1393113922
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
1393213923
return NewVSel;
@@ -14807,10 +14798,9 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
1480714798
uint64_t PtrOff = PtrAdjustmentInBits / 8;
1480814799
SDLoc DL(LN0);
1480914800
// The original load itself didn't wrap, so an offset within it doesn't.
14810-
SDNodeFlags Flags;
14811-
Flags.setNoUnsignedWrap(true);
14812-
SDValue NewPtr = DAG.getMemBasePlusOffset(
14813-
LN0->getBasePtr(), TypeSize::getFixed(PtrOff), DL, Flags);
14801+
SDValue NewPtr =
14802+
DAG.getMemBasePlusOffset(LN0->getBasePtr(), TypeSize::getFixed(PtrOff),
14803+
DL, SDNodeFlags::NoUnsignedWrap);
1481414804
AddToWorklist(NewPtr.getNode());
1481514805

1481614806
SDValue Load;

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,12 +1697,9 @@ SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {
16971697
SignBit = DAG.getNode(ISD::TRUNCATE, DL, MagVT, SignBit);
16981698
}
16991699

1700-
SDNodeFlags Flags;
1701-
Flags.setDisjoint(true);
1702-
17031700
// Store the part with the modified sign and convert back to float.
1704-
SDValue CopiedSign =
1705-
DAG.getNode(ISD::OR, DL, MagVT, ClearedSign, SignBit, Flags);
1701+
SDValue CopiedSign = DAG.getNode(ISD::OR, DL, MagVT, ClearedSign, SignBit,
1702+
SDNodeFlags::Disjoint);
17061703

17071704
return modifySignAsInt(MagAsInt, DL, CopiedSign);
17081705
}

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4674,9 +4674,9 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
46744674
DAG.getNode(ISD::SHL, dl, ShAmtVT, SrlTmp,
46754675
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT));
46764676

4677-
Flags.setExact(true);
4678-
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
4679-
DAG.getConstant(3, dl, ShAmtVT), Flags);
4677+
SDValue ByteOffset =
4678+
DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
4679+
DAG.getConstant(3, dl, ShAmtVT), SDNodeFlags::Exact);
46804680
// And clamp it, because OOB load is an immediate UB,
46814681
// while shift overflow would have *just* been poison.
46824682
ByteOffset = DAG.getNode(ISD::AND, dl, ShAmtVT, ByteOffset,

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,11 +1700,8 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
17001700
SDValue ClearedSign =
17011701
DAG.getNode(ISD::VP_AND, DL, IntVT, Mag, ClearSignMask, Mask, EVL);
17021702

1703-
SDNodeFlags Flags;
1704-
Flags.setDisjoint(true);
1705-
17061703
SDValue CopiedSign = DAG.getNode(ISD::VP_OR, DL, IntVT, ClearedSign, SignBit,
1707-
Mask, EVL, Flags);
1704+
Mask, EVL, SDNodeFlags::Disjoint);
17081705

17091706
return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
17101707
}
@@ -1886,11 +1883,8 @@ SDValue VectorLegalizer::ExpandFCOPYSIGN(SDNode *Node) {
18861883
APInt::getSignedMaxValue(IntVT.getScalarSizeInBits()), DL, IntVT);
18871884
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, IntVT, Mag, ClearSignMask);
18881885

1889-
SDNodeFlags Flags;
1890-
Flags.setDisjoint(true);
1891-
1892-
SDValue CopiedSign =
1893-
DAG.getNode(ISD::OR, DL, IntVT, ClearedSign, SignBit, Flags);
1886+
SDValue CopiedSign = DAG.getNode(ISD::OR, DL, IntVT, ClearedSign, SignBit,
1887+
SDNodeFlags::Disjoint);
18941888

18951889
return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
18961890
}

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,14 @@ void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
13811381
unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinValue() / 8;
13821382

13831383
if (MemVT.isScalableVector()) {
1384-
SDNodeFlags Flags;
13851384
SDValue BytesIncrement = DAG.getVScale(
13861385
DL, Ptr.getValueType(),
13871386
APInt(Ptr.getValueSizeInBits().getFixedValue(), IncrementSize));
13881387
MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
1389-
Flags.setNoUnsignedWrap(true);
13901388
if (ScaledOffset)
13911389
*ScaledOffset += IncrementSize;
13921390
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
1393-
Flags);
1391+
SDNodeFlags::NoUnsignedWrap);
13941392
} else {
13951393
MPI = N->getPointerInfo().getWithOffset(IncrementSize);
13961394
// Increment the pointer to the other half.

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12377,7 +12377,7 @@ bool SDNode::hasPredecessor(const SDNode *N) const {
1237712377
}
1237812378

1237912379
void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {
12380-
this->Flags.intersectWith(Flags);
12380+
this->Flags &= Flags;
1238112381
}
1238212382

1238312383
SDValue

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4318,7 +4318,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
43184318
SDNodeFlags Flags;
43194319
if (NW.hasNoUnsignedWrap() ||
43204320
(int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap()))
4321-
Flags.setNoUnsignedWrap(true);
4321+
Flags |= SDNodeFlags::NoUnsignedWrap;
43224322

43234323
N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N,
43244324
DAG.getConstant(Offset, dl, N.getValueType()), Flags);
@@ -4484,10 +4484,9 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
44844484
// Round the size of the allocation up to the stack alignment size
44854485
// by add SA-1 to the size. This doesn't overflow because we're computing
44864486
// an address inside an alloca.
4487-
SDNodeFlags Flags;
4488-
Flags.setNoUnsignedWrap(true);
44894487
AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
4490-
DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
4488+
DAG.getConstant(StackAlignMask, dl, IntPtr),
4489+
SDNodeFlags::NoUnsignedWrap);
44914490

44924491
// Mask out the low bits for alignment purposes.
44934492
AllocSize = DAG.getNode(ISD::AND, dl, AllocSize.getValueType(), AllocSize,
@@ -11224,15 +11223,13 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
1122411223

1122511224
// An aggregate return value cannot wrap around the address space, so
1122611225
// offsets to its parts don't wrap either.
11227-
SDNodeFlags Flags;
11228-
Flags.setNoUnsignedWrap(true);
11229-
1123011226
MachineFunction &MF = CLI.DAG.getMachineFunction();
1123111227
Align HiddenSRetAlign = MF.getFrameInfo().getObjectAlign(DemoteStackIdx);
1123211228
for (unsigned i = 0; i < NumValues; ++i) {
11233-
SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
11234-
CLI.DAG.getConstant(Offsets[i], CLI.DL,
11235-
PtrVT), Flags);
11229+
SDValue Add =
11230+
CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
11231+
CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT),
11232+
SDNodeFlags::NoUnsignedWrap);
1123611233
SDValue L = CLI.DAG.getLoad(
1123711234
RetTys[i], CLI.DL, CLI.Chain, Add,
1123811235
MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),

llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4224,11 +4224,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
42244224

42254225
// Set the NoFPExcept flag when no original matched node could
42264226
// raise an FP exception, but the new node potentially might.
4227-
if (!MayRaiseFPException && mayRaiseFPException(Res)) {
4228-
SDNodeFlags Flags = Res->getFlags();
4229-
Flags.setNoFPExcept(true);
4230-
Res->setFlags(Flags);
4231-
}
4227+
if (!MayRaiseFPException && mayRaiseFPException(Res))
4228+
Res->setFlags(Res->getFlags() | SDNodeFlags::NoFPExcept);
42324229

42334230
// If the node had chain/glue results, update our notion of the current
42344231
// chain and glue.

0 commit comments

Comments
 (0)