Skip to content

Commit c4bac7f

Browse files
[LLVM][AArch64]Use load/store with consecutive registers in SME2 or S… (#77665)
…VE2.1 for spill/fill When possible the spill/fill register in Frame Lowering uses the ld/st consecutive pairs available in sme or sve2.1.
1 parent b27eb0a commit c4bac7f

File tree

5 files changed

+1815
-1932
lines changed

5 files changed

+1815
-1932
lines changed

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 160 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
15081508
switch (I->getOpcode()) {
15091509
default:
15101510
return false;
1511+
case AArch64::PTRUE_C_B:
1512+
case AArch64::LD1B_2Z_IMM:
1513+
case AArch64::ST1B_2Z_IMM:
15111514
case AArch64::STR_ZXI:
15121515
case AArch64::STR_PXI:
15131516
case AArch64::LDR_ZXI:
@@ -2781,6 +2784,16 @@ struct RegPairInfo {
27812784

27822785
} // end anonymous namespace
27832786

2787+
unsigned findFreePredicateReg(BitVector &SavedRegs) {
2788+
for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) {
2789+
if (SavedRegs.test(PReg)) {
2790+
unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0;
2791+
return PNReg;
2792+
}
2793+
}
2794+
return AArch64::NoRegister;
2795+
}
2796+
27842797
static void computeCalleeSaveRegisterPairs(
27852798
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
27862799
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2859,7 +2872,11 @@ static void computeCalleeSaveRegisterPairs(
28592872
RPI.Reg2 = NextReg;
28602873
break;
28612874
case RegPairInfo::PPR:
2875+
break;
28622876
case RegPairInfo::ZPR:
2877+
if (AFI->getPredicateRegForFillSpill() != 0)
2878+
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
2879+
RPI.Reg2 = NextReg;
28632880
break;
28642881
}
28652882
}
@@ -2897,14 +2914,13 @@ static void computeCalleeSaveRegisterPairs(
28972914
if (NeedsWinCFI &&
28982915
RPI.isPaired()) // RPI.FrameIdx must be the lower index of the pair
28992916
RPI.FrameIdx = CSI[i + RegInc].getFrameIdx();
2900-
29012917
int Scale = RPI.getScale();
29022918

29032919
int OffsetPre = RPI.isScalable() ? ScalableByteOffset : ByteOffset;
29042920
assert(OffsetPre % Scale == 0);
29052921

29062922
if (RPI.isScalable())
2907-
ScalableByteOffset += StackFillDir * Scale;
2923+
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
29082924
else
29092925
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
29102926

@@ -2915,9 +2931,6 @@ static void computeCalleeSaveRegisterPairs(
29152931
(IsWindows && RPI.Reg2 == AArch64::LR)))
29162932
ByteOffset += StackFillDir * 8;
29172933

2918-
assert(!(RPI.isScalable() && RPI.isPaired()) &&
2919-
"Paired spill/fill instructions don't exist for SVE vectors");
2920-
29212934
// Round up size of non-pair to pair size if we need to pad the
29222935
// callee-save area to ensure 16-byte alignment.
29232936
if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3004,6 +3017,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30043017
}
30053018
return true;
30063019
}
3020+
bool PTrueCreated = false;
30073021
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) {
30083022
unsigned Reg1 = RPI.Reg1;
30093023
unsigned Reg2 = RPI.Reg2;
@@ -3038,10 +3052,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30383052
Alignment = Align(16);
30393053
break;
30403054
case RegPairInfo::ZPR:
3041-
StrOpc = AArch64::STR_ZXI;
3042-
Size = 16;
3043-
Alignment = Align(16);
3044-
break;
3055+
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3056+
Size = 16;
3057+
Alignment = Align(16);
3058+
break;
30453059
case RegPairInfo::PPR:
30463060
StrOpc = AArch64::STR_PXI;
30473061
Size = 2;
@@ -3065,33 +3079,79 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30653079
std::swap(Reg1, Reg2);
30663080
std::swap(FrameIdxReg1, FrameIdxReg2);
30673081
}
3068-
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3069-
if (!MRI.isReserved(Reg1))
3070-
MBB.addLiveIn(Reg1);
3071-
if (RPI.isPaired()) {
3082+
3083+
if (RPI.isPaired() && RPI.isScalable()) {
3084+
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
3085+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3086+
unsigned PnReg = AFI->getPredicateRegForFillSpill();
3087+
assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) &&
3088+
"Expects SVE2.1 or SME2 target and a predicate register");
3089+
#ifdef EXPENSIVE_CHECKS
3090+
auto IsPPR = [](const RegPairInfo &c) {
3091+
return c.Reg1 == RegPairInfo::PPR;
3092+
};
3093+
auto PPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsPPR);
3094+
auto IsZPR = [](const RegPairInfo &c) {
3095+
return c.Type == RegPairInfo::ZPR;
3096+
};
3097+
auto ZPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsZPR);
3098+
assert(!(PPRBegin < ZPRBegin) &&
3099+
"Expected callee save predicate to be handled first");
3100+
#endif
3101+
if (!PTrueCreated) {
3102+
PTrueCreated = true;
3103+
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
3104+
.setMIFlags(MachineInstr::FrameSetup);
3105+
}
3106+
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3107+
if (!MRI.isReserved(Reg1))
3108+
MBB.addLiveIn(Reg1);
30723109
if (!MRI.isReserved(Reg2))
30733110
MBB.addLiveIn(Reg2);
3074-
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
3111+
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0));
30753112
MIB.addMemOperand(MF.getMachineMemOperand(
30763113
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
30773114
MachineMemOperand::MOStore, Size, Alignment));
3115+
MIB.addReg(PnReg);
3116+
MIB.addReg(AArch64::SP)
3117+
.addImm(RPI.Offset) // [sp, #offset*scale],
3118+
// where factor*scale is implicit
3119+
.setMIFlag(MachineInstr::FrameSetup);
3120+
MIB.addMemOperand(MF.getMachineMemOperand(
3121+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3122+
MachineMemOperand::MOStore, Size, Alignment));
3123+
if (NeedsWinCFI)
3124+
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
3125+
} else { // The code when the pair of ZReg is not present
3126+
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3127+
if (!MRI.isReserved(Reg1))
3128+
MBB.addLiveIn(Reg1);
3129+
if (RPI.isPaired()) {
3130+
if (!MRI.isReserved(Reg2))
3131+
MBB.addLiveIn(Reg2);
3132+
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
3133+
MIB.addMemOperand(MF.getMachineMemOperand(
3134+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
3135+
MachineMemOperand::MOStore, Size, Alignment));
3136+
}
3137+
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
3138+
.addReg(AArch64::SP)
3139+
.addImm(RPI.Offset) // [sp, #offset*scale],
3140+
// where factor*scale is implicit
3141+
.setMIFlag(MachineInstr::FrameSetup);
3142+
MIB.addMemOperand(MF.getMachineMemOperand(
3143+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3144+
MachineMemOperand::MOStore, Size, Alignment));
3145+
if (NeedsWinCFI)
3146+
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
30783147
}
3079-
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
3080-
.addReg(AArch64::SP)
3081-
.addImm(RPI.Offset) // [sp, #offset*scale],
3082-
// where factor*scale is implicit
3083-
.setMIFlag(MachineInstr::FrameSetup);
3084-
MIB.addMemOperand(MF.getMachineMemOperand(
3085-
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3086-
MachineMemOperand::MOStore, Size, Alignment));
3087-
if (NeedsWinCFI)
3088-
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
3089-
30903148
// Update the StackIDs of the SVE stack slots.
30913149
MachineFrameInfo &MFI = MF.getFrameInfo();
3092-
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3093-
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);
3094-
3150+
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3151+
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
3152+
if (RPI.isPaired())
3153+
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
3154+
}
30953155
}
30963156
return true;
30973157
}
@@ -3109,7 +3169,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31093169
DL = MBBI->getDebugLoc();
31103170

31113171
computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));
3112-
31133172
if (homogeneousPrologEpilog(MF, &MBB)) {
31143173
auto MIB = BuildMI(MBB, MBBI, DL, TII.get(AArch64::HOM_Epilog))
31153174
.setMIFlag(MachineInstr::FrameDestroy);
@@ -3130,6 +3189,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31303189
auto ZPREnd = std::find_if_not(ZPRBegin, RegPairs.end(), IsZPR);
31313190
std::reverse(ZPRBegin, ZPREnd);
31323191

3192+
bool PTrueCreated = false;
31333193
for (const RegPairInfo &RPI : RegPairs) {
31343194
unsigned Reg1 = RPI.Reg1;
31353195
unsigned Reg2 = RPI.Reg2;
@@ -3162,7 +3222,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31623222
Alignment = Align(16);
31633223
break;
31643224
case RegPairInfo::ZPR:
3165-
LdrOpc = AArch64::LDR_ZXI;
3225+
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
31663226
Size = 16;
31673227
Alignment = Align(16);
31683228
break;
@@ -3187,25 +3247,58 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31873247
std::swap(Reg1, Reg2);
31883248
std::swap(FrameIdxReg1, FrameIdxReg2);
31893249
}
3190-
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3191-
if (RPI.isPaired()) {
3192-
MIB.addReg(Reg2, getDefRegState(true));
3250+
3251+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3252+
if (RPI.isPaired() && RPI.isScalable()) {
3253+
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
3254+
unsigned PnReg = AFI->getPredicateRegForFillSpill();
3255+
assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) &&
3256+
"Expects SVE2.1 or SME2 target and a predicate register");
3257+
#ifdef EXPENSIVE_CHECKS
3258+
assert(!(PPRBegin < ZPRBegin) &&
3259+
"Expected callee save predicate to be handled first");
3260+
#endif
3261+
if (!PTrueCreated) {
3262+
PTrueCreated = true;
3263+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
3264+
.setMIFlags(MachineInstr::FrameDestroy);
3265+
}
3266+
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3267+
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0),
3268+
getDefRegState(true));
31933269
MIB.addMemOperand(MF.getMachineMemOperand(
31943270
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
31953271
MachineMemOperand::MOLoad, Size, Alignment));
3272+
MIB.addReg(PnReg);
3273+
MIB.addReg(AArch64::SP)
3274+
.addImm(RPI.Offset) // [sp, #offset*scale]
3275+
// where factor*scale is implicit
3276+
.setMIFlag(MachineInstr::FrameDestroy);
3277+
MIB.addMemOperand(MF.getMachineMemOperand(
3278+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3279+
MachineMemOperand::MOLoad, Size, Alignment));
3280+
if (NeedsWinCFI)
3281+
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
3282+
} else {
3283+
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3284+
if (RPI.isPaired()) {
3285+
MIB.addReg(Reg2, getDefRegState(true));
3286+
MIB.addMemOperand(MF.getMachineMemOperand(
3287+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
3288+
MachineMemOperand::MOLoad, Size, Alignment));
3289+
}
3290+
MIB.addReg(Reg1, getDefRegState(true));
3291+
MIB.addReg(AArch64::SP)
3292+
.addImm(RPI.Offset) // [sp, #offset*scale]
3293+
// where factor*scale is implicit
3294+
.setMIFlag(MachineInstr::FrameDestroy);
3295+
MIB.addMemOperand(MF.getMachineMemOperand(
3296+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3297+
MachineMemOperand::MOLoad, Size, Alignment));
3298+
if (NeedsWinCFI)
3299+
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
31963300
}
3197-
MIB.addReg(Reg1, getDefRegState(true))
3198-
.addReg(AArch64::SP)
3199-
.addImm(RPI.Offset) // [sp, #offset*scale]
3200-
// where factor*scale is implicit
3201-
.setMIFlag(MachineInstr::FrameDestroy);
3202-
MIB.addMemOperand(MF.getMachineMemOperand(
3203-
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3204-
MachineMemOperand::MOLoad, Size, Alignment));
3205-
if (NeedsWinCFI)
3206-
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
32073301
}
3208-
32093302
return true;
32103303
}
32113304

@@ -3234,6 +3327,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
32343327

32353328
unsigned ExtraCSSpill = 0;
32363329
bool HasUnpairedGPR64 = false;
3330+
bool HasPairZReg = false;
32373331
// Figure out which callee-saved registers to save/restore.
32383332
for (unsigned i = 0; CSRegs[i]; ++i) {
32393333
const unsigned Reg = CSRegs[i];
@@ -3287,6 +3381,28 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
32873381
!RegInfo->isReservedReg(MF, PairedReg))
32883382
ExtraCSSpill = PairedReg;
32893383
}
3384+
// Check if there is a pair of ZRegs, so it can select PReg for spill/fill
3385+
HasPairZReg |= (AArch64::ZPRRegClass.contains(Reg, CSRegs[i ^ 1]) &&
3386+
SavedRegs.test(CSRegs[i ^ 1]));
3387+
}
3388+
3389+
if (HasPairZReg && (Subtarget.hasSVE2p1() || Subtarget.hasSME2())) {
3390+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3391+
// Find a suitable predicate register for the multi-vector spill/fill
3392+
// instructions.
3393+
unsigned PnReg = findFreePredicateReg(SavedRegs);
3394+
if (PnReg != AArch64::NoRegister)
3395+
AFI->setPredicateRegForFillSpill(PnReg);
3396+
// If no free callee-save has been found assign one.
3397+
if (!AFI->getPredicateRegForFillSpill() &&
3398+
MF.getFunction().getCallingConv() ==
3399+
CallingConv::AArch64_SVE_VectorCall) {
3400+
SavedRegs.set(AArch64::P8);
3401+
AFI->setPredicateRegForFillSpill(AArch64::PN8);
3402+
}
3403+
3404+
assert(!RegInfo->isReservedReg(MF, AFI->getPredicateRegForFillSpill()) &&
3405+
"Predicate cannot be a reserved register");
32903406
}
32913407

32923408
if (MF.getFunction().getCallingConv() == CallingConv::Win64 &&

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
212212
// on function entry to record the initial pstate of a function.
213213
Register PStateSMReg = MCRegister::NoRegister;
214214

215+
// Has the PNReg used to build PTRUE instruction.
216+
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
217+
unsigned PredicateRegForFillSpill = 0;
218+
215219
public:
216220
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
217221

@@ -220,6 +224,13 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
220224
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
221225
const override;
222226

227+
void setPredicateRegForFillSpill(unsigned Reg) {
228+
PredicateRegForFillSpill = Reg;
229+
}
230+
unsigned getPredicateRegForFillSpill() const {
231+
return PredicateRegForFillSpill;
232+
}
233+
223234
Register getPStateSMReg() const { return PStateSMReg; };
224235
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
225236

0 commit comments

Comments
 (0)