Skip to content

[WIP] Enhance 3D register allocation strategy #442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: aie-public
Choose a base branch
from
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/LiveIntervals.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class VirtRegMap;
LiveIntervals();
~LiveIntervals() override;

const TargetInstrInfo &getTargetInstrInfo() const { return *TII; }

/// Calculate the spill weight to assign to a single instruction.
static float getSpillWeight(bool isDef, bool isUse,
const MachineBlockFrequencyInfo *MBFI,
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ namespace llvm {
///
FunctionPass *createGreedyRegisterAllocator();
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F,
LiveIntervalFilterFunc LIF);

/// PBQPRegisterAllocation Pass - This pass implements the Partitioned Boolean
/// Quadratic Prograaming (PBQP) based register allocator.
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/CodeGen/RegAllocCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ namespace llvm {
class TargetRegisterClass;
class TargetRegisterInfo;

class MachineRegisterInfo;
class TargetInstrInfo;
class LiveInterval;

typedef std::function<bool(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC)> RegClassFilterFunc;

Expand All @@ -26,6 +30,17 @@ static inline bool allocateAllRegClasses(const TargetRegisterInfo &,
return true;
}

typedef std::function<bool(MachineRegisterInfo &MRI, const TargetInstrInfo &TII,
const LiveInterval *LI)>
LiveIntervalFilterFunc;
/// Default live interval filter function for register allocation. All live
/// intervals should be allocated.
static inline bool allocateAllLiveIntervals(MachineRegisterInfo &,
const TargetInstrInfo &,
const LiveInterval *) {
return true;
}

} // namespace llvm

#endif // LLVM_CODEGEN_REGALLOCCOMMON_H
9 changes: 7 additions & 2 deletions llvm/lib/CodeGen/RegAllocBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,13 @@ void RegAllocBase::enqueue(const LiveInterval *LI) {

const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
if (ShouldAllocateClass(*TRI, RC)) {
LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
enqueueImpl(LI);
if (ShouldAllocateLiveInterval(*MRI, LIS->getTargetInstrInfo(), LI)) {
LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
enqueueImpl(LI);
} else {
LLVM_DEBUG(dbgs() << "Not enqueueing " << printReg(Reg, TRI)
<< " in skipped live interval\n");
}
} else {
LLVM_DEBUG(dbgs() << "Not enqueueing " << printReg(Reg, TRI)
<< " in skipped register class\n");
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/RegAllocBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,17 @@ class RegAllocBase {
LiveRegMatrix *Matrix = nullptr;
RegisterClassInfo RegClassInfo;
const RegClassFilterFunc ShouldAllocateClass;
const LiveIntervalFilterFunc ShouldAllocateLiveInterval;

/// Inst which is a def of an original reg and whose defs are already all
/// dead after remat is saved in DeadRemats. The deletion of such inst is
/// postponed till all the allocations are done, so its remat expr is
/// always available for the remat of all the siblings of the original reg.
SmallPtrSet<MachineInstr *, 32> DeadRemats;

RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses) :
ShouldAllocateClass(F) {}
RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses,
const LiveIntervalFilterFunc LIF = allocateAllLiveIntervals)
: ShouldAllocateClass(F), ShouldAllocateLiveInterval(LIF) {}

virtual ~RegAllocBase() = default;

Expand Down
10 changes: 7 additions & 3 deletions llvm/lib/CodeGen/RegAllocGreedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ FunctionPass *llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor) {
return new RAGreedy(Ftor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
FunctionPass *
llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor,
LiveIntervalFilterFunc LIFtor) {
return new RAGreedy(Ftor, LIFtor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F, LiveIntervalFilterFunc LIF)
: MachineFunctionPass(ID), RegAllocBase(F, LIF) {}

void RAGreedy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<MachineBlockFrequencyInfo>();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/RegAllocGreedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
bool ReverseLocalAssignment = false;

public:
RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses);
RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses,
const LiveIntervalFilterFunc LIF = allocateAllLiveIntervals);

/// Return the pass name.
StringRef getPassName() const override { return "Greedy Register Allocator"; }
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/AIE/AIE2InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,6 @@ AIE2InstrInfo::getSpillPseudoExpandInfo(const MachineInstr &MI) const {
{AIE2::LDA_dms_spill, AIE2::sub_dim_size},
{AIE2::LDA_dms_spill, AIE2::sub_dim_stride},
{AIE2::LDA_dms_spill, AIE2::sub_dim_count},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_mod},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_size},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_stride},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_count}};
Expand All @@ -844,7 +843,6 @@ AIE2InstrInfo::getSpillPseudoExpandInfo(const MachineInstr &MI) const {
{AIE2::ST_dms_spill, AIE2::sub_dim_size},
{AIE2::ST_dms_spill, AIE2::sub_dim_stride},
{AIE2::ST_dms_spill, AIE2::sub_dim_count},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_mod},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_size},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_stride},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_count}};
Expand Down Expand Up @@ -1205,7 +1203,6 @@ AIE2InstrInfo::getTiedRegInfo(unsigned Opcode) const {
SubRegSplit(AIE2::sub_dim_size),
SubRegSplit(AIE2::sub_dim_stride),
SubRegSplit(AIE2::sub_dim_count),
SubRegSplit(AIE2::sub_hi_dim_then_sub_mod, /*IsUndef=*/true),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_size),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_stride),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_count)};
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AIE/AIE2InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ foreach instr = [VST_2D_SRS_D8_S32, VST_2D_SRS_D16_S64, VST_2D_SRS_D16_S32,
// Define _split variants for instructions using 3D registers
class Split3DInstr<Instruction RealInst, int opidx> : SplitPseudo<RealInst,
opidx, (ins eM:$mod1, eDN:$dim_size1, eDJ:$dim_stride1, eDC:$dim_count1,
eM:$mod2, eDN:$dim_size2, eDJ:$dim_stride2, eDC:$dim_count2)> {}
eDN:$dim_size2, eDJ:$dim_stride2, eDC:$dim_count2)> {}
foreach instr = [VLDA_3D_dmw_lda_w, VLDA_3D_dmw_lda_am, VLDA_3D_CONV_FP32_BF16,
VLDB_3D, VLDB_3D_128, LDA_3D_dmv_lda_q, VLDB_3D_UNPACK_S8_S4,
VLDB_3D_UNPACK_S16_S8, VLDB_3D_UNPACK_D8_D4, VLDB_3D_UNPACK_D16_D8,
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/AIE/AIE2RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ const std::set<int> &AIE2RegisterInfo::getSubRegSplit(int RegClassId) const {
AIE2::sub_dim_size,
AIE2::sub_dim_stride,
AIE2::sub_dim_count,
AIE2::sub_hi_dim_then_sub_mod,
AIE2::sub_hi_dim_then_sub_dim_size,
AIE2::sub_hi_dim_then_sub_dim_stride,
AIE2::sub_hi_dim_then_sub_dim_count};
Expand Down
30 changes: 22 additions & 8 deletions llvm/lib/Target/AIE/AIEBaseInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,29 @@ void AIEBaseInstrInfo::copyThroughSubRegs(MachineBasicBlock &MBB,
MCRegister SrcReg,
bool KillSrc) const {
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();

SmallSet<MCRegister, 8> SrcSubRegs;
collectSubRegs(SrcReg, SrcSubRegs, TRI);
auto &TRI =
*static_cast<const AIEBaseRegisterInfo *>(MRI.getTargetRegisterInfo());

const auto *RC = Register::isPhysicalRegister(SrcReg.id())
? TRI.getMinimalPhysRegClass(SrcReg)
: MRI.getRegClass(SrcReg);
auto &SubRegSplit = TRI.getSubRegSplit(RC->getID());

if (SubRegSplit.size() > 1) {
for (const auto &SubRegIdx : SubRegSplit) {
MCRegister SrcSubReg = TRI.getSubReg(SrcReg, SubRegIdx);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
}
} else {
SmallSet<MCRegister, 8> SrcSubRegs;
collectSubRegs(SrcReg, SrcSubRegs, TRI);

for (MCRegister SrcSubReg : SrcSubRegs) {
unsigned SubRegIdx = TRI.getSubRegIndex(SrcReg, SrcSubReg);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
for (MCRegister SrcSubReg : SrcSubRegs) {
unsigned SubRegIdx = TRI.getSubRegIndex(SrcReg, SrcSubReg);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
}
}
}

Expand Down
Loading
Loading