Skip to content

[llvm] Support save/restore point splitting in shrink-wrap #119359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/include/llvm/CodeGen/MIRYamlMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,15 +612,19 @@ namespace yaml {

struct SRPEntry {
StringValue Point;
std::vector<StringValue> Registers;

bool operator==(const SRPEntry &Other) const { return Point == Other.Point; }
bool operator==(const SRPEntry &Other) const {
return Point == Other.Point && Registers == Other.Registers;
}
};

using SaveRestorePoints = std::vector<SRPEntry>;

template <> struct MappingTraits<SRPEntry> {
static void mapping(IO &YamlIO, SRPEntry &Entry) {
YamlIO.mapRequired("point", Entry.Point);
YamlIO.mapRequired("registers", Entry.Registers);
}
};

Expand Down
140 changes: 131 additions & 9 deletions llvm/include/llvm/CodeGen/MachineFrameInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ class MachineBasicBlock;
class BitVector;
class AllocaInst;

using SaveRestorePoints = DenseMap<MachineBasicBlock *, std::vector<Register>>;

class CalleeSavedInfoPerBB {
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> Map;

public:
std::vector<CalleeSavedInfo> get(MachineBasicBlock *MBB) const {
return Map.lookup(MBB);
}

void set(DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
Map = std::move(CSI);
}
};

/// The CalleeSavedInfo class tracks the information need to locate where a
/// callee saved register is in the current frame.
/// Callee saved reg can also be saved to a different register rather than
Expand All @@ -37,6 +52,8 @@ class CalleeSavedInfo {
int FrameIdx;
unsigned DstReg;
};
std::vector<MachineBasicBlock *> SpilledIn;
std::vector<MachineBasicBlock *> RestoredIn;
/// Flag indicating whether the register is actually restored in the epilog.
/// In most cases, if a register is saved, it is also restored. There are
/// some situations, though, when this is not the case. For example, the
Expand All @@ -58,9 +75,9 @@ class CalleeSavedInfo {
explicit CalleeSavedInfo(unsigned R, int FI = 0) : Reg(R), FrameIdx(FI) {}

// Accessors.
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
void setFrameIdx(int FI) {
FrameIdx = FI;
SpilledToReg = false;
Expand All @@ -72,6 +89,16 @@ class CalleeSavedInfo {
bool isRestored() const { return Restored; }
void setRestored(bool R) { Restored = R; }
bool isSpilledToReg() const { return SpilledToReg; }
ArrayRef<MachineBasicBlock *> spilledIn() const { return SpilledIn; }
ArrayRef<MachineBasicBlock *> restoredIn() const { return RestoredIn; }
void addSpilledIn(MachineBasicBlock *MBB) { SpilledIn.push_back(MBB); }
void addRestoredIn(MachineBasicBlock *MBB) { RestoredIn.push_back(MBB); }
void setSpilledIn(std::vector<MachineBasicBlock *> BBV) {
SpilledIn = std::move(BBV);
}
void setRestoredIn(std::vector<MachineBasicBlock *> BBV) {
RestoredIn = std::move(BBV);
}
};

/// The MachineFrameInfo class represents an abstract stack frame until
Expand Down Expand Up @@ -295,6 +322,10 @@ class MachineFrameInfo {
/// Has CSInfo been set yet?
bool CSIValid = false;

CalleeSavedInfoPerBB CSInfoPerSave;

CalleeSavedInfoPerBB CSInfoPerRestore;

/// References to frame indices which are mapped
/// into the local frame allocation block. <FrameIdx, LocalOffset>
SmallVector<std::pair<int, int64_t>, 32> LocalFrameObjects;
Expand Down Expand Up @@ -331,9 +362,16 @@ class MachineFrameInfo {
bool HasTailCall = false;

/// Not null, if shrink-wrapping found a better place for the prologue.
MachineBasicBlock *Save = nullptr;
MachineBasicBlock *Prolog = nullptr;
/// Not null, if shrink-wrapping found a better place for the epilogue.
MachineBasicBlock *Restore = nullptr;
MachineBasicBlock *Epilog = nullptr;

/// Not empty, if shrink-wrapping found a better place for saving callee
/// saves.
SaveRestorePoints SavePoints;
/// Not empty, if shrink-wrapping found a better place for restoring callee
/// saves.
SaveRestorePoints RestorePoints;

/// Size of the UnsafeStack Frame
uint64_t UnsafeStackSize = 0;
Expand Down Expand Up @@ -809,21 +847,105 @@ class MachineFrameInfo {
/// \copydoc getCalleeSavedInfo()
std::vector<CalleeSavedInfo> &getCalleeSavedInfo() { return CSInfo; }

/// Returns callee saved info vector for provided save point in
/// the current function.
std::vector<CalleeSavedInfo> getCSInfoPerSave(MachineBasicBlock *MBB) const {
return CSInfoPerSave.get(MBB);
}

/// Returns callee saved info vector for provided restore point
/// in the current function.
std::vector<CalleeSavedInfo>
getCSInfoPerRestore(MachineBasicBlock *MBB) const {
return CSInfoPerRestore.get(MBB);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information.
void setCalleeSavedInfo(std::vector<CalleeSavedInfo> CSI) {
CSInfo = std::move(CSI);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information for particular save point.
void setCSInfoPerSave(
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
CSInfoPerSave.set(CSI);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information for particular restore point.
void setCSInfoPerRestore(
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
CSInfoPerRestore.set(CSI);
}

/// Has the callee saved info been calculated yet?
bool isCalleeSavedInfoValid() const { return CSIValid; }

void setCalleeSavedInfoValid(bool v) { CSIValid = v; }

MachineBasicBlock *getSavePoint() const { return Save; }
void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
MachineBasicBlock *getRestorePoint() const { return Restore; }
void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
const SaveRestorePoints &getRestorePoints() const { return RestorePoints; }

const SaveRestorePoints &getSavePoints() const { return SavePoints; }

std::pair<MachineBasicBlock *, std::vector<Register>>
getRestorePoint(MachineBasicBlock *MBB) const {
if (auto It = RestorePoints.find(MBB); It != RestorePoints.end())
return *It;

std::vector<Register> Regs = {};
return std::make_pair(nullptr, Regs);
}

std::pair<MachineBasicBlock *, std::vector<Register>>
getSavePoint(MachineBasicBlock *MBB) const {
if (auto It = SavePoints.find(MBB); It != SavePoints.end())
return *It;

std::vector<Register> Regs = {};
return std::make_pair(nullptr, Regs);
}

void setSavePoints(SaveRestorePoints NewSavePoints) {
SavePoints = std::move(NewSavePoints);
}

void setRestorePoints(SaveRestorePoints NewRestorePoints) {
RestorePoints = std::move(NewRestorePoints);
}

void setSavePoint(MachineBasicBlock *MBB, std::vector<Register> &Regs) {
if (SavePoints.contains(MBB))
SavePoints[MBB] = Regs;
else
SavePoints.insert(std::make_pair(MBB, Regs));
}

static const SaveRestorePoints constructSaveRestorePoints(
const SaveRestorePoints &SRP,
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &BBMap) {
SaveRestorePoints Pts{};
for (auto &Src : SRP) {
Pts.insert(std::make_pair(BBMap.find(Src.first)->second, Src.second));
}
return Pts;
}

void setRestorePoint(MachineBasicBlock *MBB, std::vector<Register> &Regs) {
if (RestorePoints.contains(MBB))
RestorePoints[MBB] = Regs;
else
RestorePoints.insert(std::make_pair(MBB, Regs));
}

MachineBasicBlock *getProlog() const { return Prolog; }
void setProlog(MachineBasicBlock *BB) { Prolog = BB; }
MachineBasicBlock *getEpilog() const { return Epilog; }
void setEpilog(MachineBasicBlock *BB) { Epilog = BB; }

void clearSavePoints() { SavePoints.clear(); }
void clearRestorePoints() { RestorePoints.clear(); }

uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }
Expand Down
24 changes: 18 additions & 6 deletions llvm/lib/CodeGen/MIRParser/MIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1056,20 +1056,32 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
bool MIRParserImpl::initializeSaveRestorePoints(
PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRP,
bool IsSavePoints) {
SMDiagnostic Error;
MachineFunction &MF = PFS.MF;
MachineFrameInfo &MFI = MF.getFrameInfo();
llvm::SaveRestorePoints SRPoints;

if (!YamlSRP.empty()) {
const auto &Entry = YamlSRP.front();
for (const auto &Entry : YamlSRP) {
const auto &MBBSource = Entry.Point;
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, MBBSource.Value))
return true;
if (IsSavePoints)
MFI.setSavePoint(MBB);
else
MFI.setRestorePoint(MBB);

std::vector<Register> Registers{};
for (auto &RegStr : Entry.Registers) {
Register Reg;
if (parseNamedRegisterReference(PFS, Reg, RegStr.Value, Error))
return error(Error, RegStr.SourceRange);

Registers.push_back(Reg);
}
SRPoints.insert(std::make_pair(MBB, Registers));
}

if (IsSavePoints)
MFI.setSavePoints(SRPoints);
else
MFI.setRestorePoints(SRPoints);
return false;
}

Expand Down
45 changes: 29 additions & 16 deletions llvm/lib/CodeGen/MIRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ class MIRPrinter {
const MachineRegisterInfo &RegInfo,
const TargetRegisterInfo *TRI);
void convert(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
const MachineFrameInfo &MFI);
const MachineFrameInfo &MFI, const TargetRegisterInfo *TRI);
void convert(ModuleSlotTracker &MST, yaml::SaveRestorePoints &YamlSRP,
MachineBasicBlock *SaveRestorePoint);
const DenseMap<MachineBasicBlock *, std::vector<Register>> &SRP,
const TargetRegisterInfo *TRI);
void convert(yaml::MachineFunction &MF,
const MachineConstantPool &ConstantPool);
void convert(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
Expand Down Expand Up @@ -237,7 +238,8 @@ void MIRPrinter::print(const MachineFunction &MF) {
convert(YamlMF, MF, MF.getRegInfo(), MF.getSubtarget().getRegisterInfo());
MachineModuleSlotTracker MST(MMI, &MF);
MST.incorporateFunction(MF.getFunction());
convert(MST, YamlMF.FrameInfo, MF.getFrameInfo());
convert(MST, YamlMF.FrameInfo, MF.getFrameInfo(),
MF.getSubtarget().getRegisterInfo());
convertStackObjects(YamlMF, MF, MST);
convertEntryValueObjects(YamlMF, MF, MST);
convertCallSiteObjects(YamlMF, MF, MST);
Expand Down Expand Up @@ -374,7 +376,8 @@ void MIRPrinter::convert(yaml::MachineFunction &YamlMF,

void MIRPrinter::convert(ModuleSlotTracker &MST,
yaml::MachineFrameInfo &YamlMFI,
const MachineFrameInfo &MFI) {
const MachineFrameInfo &MFI,
const TargetRegisterInfo *TRI) {
YamlMFI.IsFrameAddressTaken = MFI.isFrameAddressTaken();
YamlMFI.IsReturnAddressTaken = MFI.isReturnAddressTaken();
YamlMFI.HasStackMap = MFI.hasStackMap();
Expand All @@ -394,10 +397,10 @@ void MIRPrinter::convert(ModuleSlotTracker &MST,
YamlMFI.HasTailCall = MFI.hasTailCall();
YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
if (MFI.getSavePoint())
convert(MST, YamlMFI.SavePoints, MFI.getSavePoint());
if (MFI.getRestorePoint())
convert(MST, YamlMFI.RestorePoints, MFI.getRestorePoint());
if (!MFI.getSavePoints().empty())
convert(MST, YamlMFI.SavePoints, MFI.getSavePoints(), TRI);
if (!MFI.getRestorePoints().empty())
convert(MST, YamlMFI.RestorePoints, MFI.getRestorePoints(), TRI);
}

void MIRPrinter::convertEntryValueObjects(yaml::MachineFunction &YMF,
Expand Down Expand Up @@ -618,14 +621,24 @@ void MIRPrinter::convert(yaml::MachineFunction &MF,

void MIRPrinter::convert(ModuleSlotTracker &MST,
yaml::SaveRestorePoints &YamlSRP,
MachineBasicBlock *SRP) {
std::string Str;
yaml::SRPEntry Entry;
raw_string_ostream StrOS(Str);
StrOS << printMBBReference(*SRP);
Entry.Point = StrOS.str();
Str.clear();
YamlSRP.push_back(Entry);
const SaveRestorePoints &SRP,
const TargetRegisterInfo *TRI) {
for (const auto &MBBEntry : SRP) {
std::string Str;
yaml::SRPEntry Entry;
raw_string_ostream StrOS(Str);
StrOS << printMBBReference(*MBBEntry.first);
Entry.Point = StrOS.str();
Str.clear();
for (auto &Reg : MBBEntry.second) {
if (Reg != MCRegister::NoRegister) {
StrOS << printReg(Reg, TRI);
Entry.Registers.push_back(StrOS.str());
Str.clear();
}
}
YamlSRP.push_back(Entry);
}
}

void MIRPrinter::convert(ModuleSlotTracker &MST,
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/CodeGen/MachineFrameInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,23 @@ void MachineFrameInfo::print(const MachineFunction &MF, raw_ostream &OS) const{
}
OS << "\n";
}

OS << "save/restore points:\n";

if (!SavePoints.empty()) {
OS << "save points:\n";

for (auto &item : SavePoints)
OS << printMBBReference(*item.first) << "\n";
} else
OS << "save points are empty\n";

if (!RestorePoints.empty()) {
OS << "restore points:\n";
for (auto &item : RestorePoints)
OS << printMBBReference(*item.first) << "\n";
} else
OS << "restore points are empty\n";
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
Expand Down
Loading
Loading