@@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
1508
1508
switch (I->getOpcode ()) {
1509
1509
default :
1510
1510
return false ;
1511
+ case AArch64::PTRUE_C_B:
1512
+ case AArch64::LD1B_2Z_IMM:
1513
+ case AArch64::ST1B_2Z_IMM:
1511
1514
case AArch64::STR_ZXI:
1512
1515
case AArch64::STR_PXI:
1513
1516
case AArch64::LDR_ZXI:
@@ -2781,6 +2784,16 @@ struct RegPairInfo {
2781
2784
2782
2785
} // end anonymous namespace
2783
2786
2787
// Pick a predicate-as-counter register (PNn) for the caller to use.
//
// Scans P8 through P15 in ascending order and, for the first register whose
// bit is set in SavedRegs, returns the PN register with the same index
// (P8 -> PN8, ..., P15 -> PN15).
// NOTE(review): "free" presumably means the register is already marked to be
// saved, so it may be overwritten without an additional spill -- confirm
// against the caller.
//
// \param SavedRegs bit vector indexed by register number; only tested here,
//        never modified.
// \returns the matching PN register, or AArch64::NoRegister when none of
//          P8..P15 is set in SavedRegs.
+ unsigned findFreePredicateReg (BitVector &SavedRegs) {
2788
+ for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) {
2789
+ if (SavedRegs.test (PReg)) {
2790
// Translate the P register into the PN register sharing its index.
+ unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0;
2791
+ return PNReg;
2792
+ }
2793
+ }
2794
+ return AArch64::NoRegister;
2795
+ }
2796
+
2784
2797
static void computeCalleeSaveRegisterPairs (
2785
2798
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
2786
2799
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2859,7 +2872,11 @@ static void computeCalleeSaveRegisterPairs(
2859
2872
RPI.Reg2 = NextReg;
2860
2873
break ;
2861
2874
case RegPairInfo::PPR:
2875
+ break ;
2862
2876
case RegPairInfo::ZPR:
2877
+ if (AFI->getPredicateRegForFillSpill () != 0 )
2878
+ if (((RPI.Reg1 - AArch64::Z0) & 1 ) == 0 && (NextReg == RPI.Reg1 + 1 ))
2879
+ RPI.Reg2 = NextReg;
2863
2880
break ;
2864
2881
}
2865
2882
}
@@ -2897,14 +2914,13 @@ static void computeCalleeSaveRegisterPairs(
2897
2914
if (NeedsWinCFI &&
2898
2915
RPI.isPaired ()) // RPI.FrameIdx must be the lower index of the pair
2899
2916
RPI.FrameIdx = CSI[i + RegInc].getFrameIdx ();
2900
-
2901
2917
int Scale = RPI.getScale ();
2902
2918
2903
2919
int OffsetPre = RPI.isScalable () ? ScalableByteOffset : ByteOffset;
2904
2920
assert (OffsetPre % Scale == 0 );
2905
2921
2906
2922
if (RPI.isScalable ())
2907
- ScalableByteOffset += StackFillDir * Scale;
2923
+ ScalableByteOffset += StackFillDir * (RPI. isPaired () ? 2 * Scale : Scale) ;
2908
2924
else
2909
2925
ByteOffset += StackFillDir * (RPI.isPaired () ? 2 * Scale : Scale);
2910
2926
@@ -2915,9 +2931,6 @@ static void computeCalleeSaveRegisterPairs(
2915
2931
(IsWindows && RPI.Reg2 == AArch64::LR)))
2916
2932
ByteOffset += StackFillDir * 8 ;
2917
2933
2918
- assert (!(RPI.isScalable () && RPI.isPaired ()) &&
2919
- " Paired spill/fill instructions don't exist for SVE vectors" );
2920
-
2921
2934
// Round up size of non-pair to pair size if we need to pad the
2922
2935
// callee-save area to ensure 16-byte alignment.
2923
2936
if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3004,6 +3017,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3004
3017
}
3005
3018
return true ;
3006
3019
}
3020
+ bool PTrueCreated = false ;
3007
3021
for (const RegPairInfo &RPI : llvm::reverse (RegPairs)) {
3008
3022
unsigned Reg1 = RPI.Reg1 ;
3009
3023
unsigned Reg2 = RPI.Reg2 ;
@@ -3038,10 +3052,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3038
3052
Alignment = Align (16 );
3039
3053
break ;
3040
3054
case RegPairInfo::ZPR:
3041
- StrOpc = AArch64::STR_ZXI;
3042
- Size = 16 ;
3043
- Alignment = Align (16 );
3044
- break ;
3055
+ StrOpc = RPI. isPaired () ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3056
+ Size = 16 ;
3057
+ Alignment = Align (16 );
3058
+ break ;
3045
3059
case RegPairInfo::PPR:
3046
3060
StrOpc = AArch64::STR_PXI;
3047
3061
Size = 2 ;
@@ -3065,33 +3079,79 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3065
3079
std::swap (Reg1, Reg2);
3066
3080
std::swap (FrameIdxReg1, FrameIdxReg2);
3067
3081
}
3068
- MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3069
- if (!MRI.isReserved (Reg1))
3070
- MBB.addLiveIn (Reg1);
3071
- if (RPI.isPaired ()) {
3082
+
3083
+ if (RPI.isPaired () && RPI.isScalable ()) {
3084
+ const AArch64Subtarget &Subtarget = MF.getSubtarget <AArch64Subtarget>();
3085
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3086
+ unsigned PnReg = AFI->getPredicateRegForFillSpill ();
3087
+ assert (((Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ()) && PnReg != 0 ) &&
3088
+ " Expects SVE2.1 or SME2 target and a predicate register" );
3089
+ #ifdef EXPENSIVE_CHECKS
3090
+ auto IsPPR = [](const RegPairInfo &c) {
3091
+ return c.Reg1 == RegPairInfo::PPR;
3092
+ };
3093
+ auto PPRBegin = std::find_if (RegPairs.begin (), RegPairs.end (), IsPPR);
3094
+ auto IsZPR = [](const RegPairInfo &c) {
3095
+ return c.Type == RegPairInfo::ZPR;
3096
+ };
3097
+ auto ZPRBegin = std::find_if (RegPairs.begin (), RegPairs.end (), IsZPR);
3098
+ assert (!(PPRBegin < ZPRBegin) &&
3099
+ " Expected callee save predicate to be handled first" );
3100
+ #endif
3101
+ if (!PTrueCreated) {
3102
+ PTrueCreated = true ;
3103
+ BuildMI (MBB, MI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3104
+ .setMIFlags (MachineInstr::FrameSetup);
3105
+ }
3106
+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3107
+ if (!MRI.isReserved (Reg1))
3108
+ MBB.addLiveIn (Reg1);
3072
3109
if (!MRI.isReserved (Reg2))
3073
3110
MBB.addLiveIn (Reg2);
3074
- MIB.addReg (Reg2, getPrologueDeath (MF, Reg2 ));
3111
+ MIB.addReg (/* PairRegs */ AArch64::Z0_Z1 + (RPI. Reg1 - AArch64::Z0 ));
3075
3112
MIB.addMemOperand (MF.getMachineMemOperand (
3076
3113
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3077
3114
MachineMemOperand::MOStore, Size, Alignment));
3115
+ MIB.addReg (PnReg);
3116
+ MIB.addReg (AArch64::SP)
3117
+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3118
+ // where factor*scale is implicit
3119
+ .setMIFlag (MachineInstr::FrameSetup);
3120
+ MIB.addMemOperand (MF.getMachineMemOperand (
3121
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3122
+ MachineMemOperand::MOStore, Size, Alignment));
3123
+ if (NeedsWinCFI)
3124
+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3125
+ } else { // The code when the pair of ZReg is not present
3126
+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3127
+ if (!MRI.isReserved (Reg1))
3128
+ MBB.addLiveIn (Reg1);
3129
+ if (RPI.isPaired ()) {
3130
+ if (!MRI.isReserved (Reg2))
3131
+ MBB.addLiveIn (Reg2);
3132
+ MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3133
+ MIB.addMemOperand (MF.getMachineMemOperand (
3134
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3135
+ MachineMemOperand::MOStore, Size, Alignment));
3136
+ }
3137
+ MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3138
+ .addReg (AArch64::SP)
3139
+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3140
+ // where factor*scale is implicit
3141
+ .setMIFlag (MachineInstr::FrameSetup);
3142
+ MIB.addMemOperand (MF.getMachineMemOperand (
3143
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3144
+ MachineMemOperand::MOStore, Size, Alignment));
3145
+ if (NeedsWinCFI)
3146
+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3078
3147
}
3079
- MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3080
- .addReg (AArch64::SP)
3081
- .addImm (RPI.Offset ) // [sp, #offset*scale],
3082
- // where factor*scale is implicit
3083
- .setMIFlag (MachineInstr::FrameSetup);
3084
- MIB.addMemOperand (MF.getMachineMemOperand (
3085
- MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3086
- MachineMemOperand::MOStore, Size, Alignment));
3087
- if (NeedsWinCFI)
3088
- InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3089
-
3090
3148
// Update the StackIDs of the SVE stack slots.
3091
3149
MachineFrameInfo &MFI = MF.getFrameInfo ();
3092
- if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3093
- MFI.setStackID (RPI.FrameIdx , TargetStackID::ScalableVector);
3094
-
3150
+ if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3151
+ MFI.setStackID (FrameIdxReg1, TargetStackID::ScalableVector);
3152
+ if (RPI.isPaired ())
3153
+ MFI.setStackID (FrameIdxReg2, TargetStackID::ScalableVector);
3154
+ }
3095
3155
}
3096
3156
return true ;
3097
3157
}
@@ -3109,7 +3169,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3109
3169
DL = MBBI->getDebugLoc ();
3110
3170
3111
3171
computeCalleeSaveRegisterPairs (MF, CSI, TRI, RegPairs, hasFP (MF));
3112
-
3113
3172
if (homogeneousPrologEpilog (MF, &MBB)) {
3114
3173
auto MIB = BuildMI (MBB, MBBI, DL, TII.get (AArch64::HOM_Epilog))
3115
3174
.setMIFlag (MachineInstr::FrameDestroy);
@@ -3130,6 +3189,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3130
3189
auto ZPREnd = std::find_if_not (ZPRBegin, RegPairs.end (), IsZPR);
3131
3190
std::reverse (ZPRBegin, ZPREnd);
3132
3191
3192
+ bool PTrueCreated = false ;
3133
3193
for (const RegPairInfo &RPI : RegPairs) {
3134
3194
unsigned Reg1 = RPI.Reg1 ;
3135
3195
unsigned Reg2 = RPI.Reg2 ;
@@ -3162,7 +3222,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3162
3222
Alignment = Align (16 );
3163
3223
break ;
3164
3224
case RegPairInfo::ZPR:
3165
- LdrOpc = AArch64::LDR_ZXI;
3225
+ LdrOpc = RPI. isPaired () ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
3166
3226
Size = 16 ;
3167
3227
Alignment = Align (16 );
3168
3228
break ;
@@ -3187,25 +3247,58 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3187
3247
std::swap (Reg1, Reg2);
3188
3248
std::swap (FrameIdxReg1, FrameIdxReg2);
3189
3249
}
3190
- MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3191
- if (RPI.isPaired ()) {
3192
- MIB.addReg (Reg2, getDefRegState (true ));
3250
+
3251
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3252
+ if (RPI.isPaired () && RPI.isScalable ()) {
3253
+ const AArch64Subtarget &Subtarget = MF.getSubtarget <AArch64Subtarget>();
3254
+ unsigned PnReg = AFI->getPredicateRegForFillSpill ();
3255
+ assert (((Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ()) && PnReg != 0 ) &&
3256
+ " Expects SVE2.1 or SME2 target and a predicate register" );
3257
+ #ifdef EXPENSIVE_CHECKS
3258
+ assert (!(PPRBegin < ZPRBegin) &&
3259
+ " Expected callee save predicate to be handled first" );
3260
+ #endif
3261
+ if (!PTrueCreated) {
3262
+ PTrueCreated = true ;
3263
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3264
+ .setMIFlags (MachineInstr::FrameDestroy);
3265
+ }
3266
+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3267
+ MIB.addReg (/* PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0),
3268
+ getDefRegState (true ));
3193
3269
MIB.addMemOperand (MF.getMachineMemOperand (
3194
3270
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3195
3271
MachineMemOperand::MOLoad, Size, Alignment));
3272
+ MIB.addReg (PnReg);
3273
+ MIB.addReg (AArch64::SP)
3274
+ .addImm (RPI.Offset ) // [sp, #offset*scale]
3275
+ // where factor*scale is implicit
3276
+ .setMIFlag (MachineInstr::FrameDestroy);
3277
+ MIB.addMemOperand (MF.getMachineMemOperand (
3278
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3279
+ MachineMemOperand::MOLoad, Size, Alignment));
3280
+ if (NeedsWinCFI)
3281
+ InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3282
+ } else {
3283
+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3284
+ if (RPI.isPaired ()) {
3285
+ MIB.addReg (Reg2, getDefRegState (true ));
3286
+ MIB.addMemOperand (MF.getMachineMemOperand (
3287
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3288
+ MachineMemOperand::MOLoad, Size, Alignment));
3289
+ }
3290
+ MIB.addReg (Reg1, getDefRegState (true ));
3291
+ MIB.addReg (AArch64::SP)
3292
+ .addImm (RPI.Offset ) // [sp, #offset*scale]
3293
+ // where factor*scale is implicit
3294
+ .setMIFlag (MachineInstr::FrameDestroy);
3295
+ MIB.addMemOperand (MF.getMachineMemOperand (
3296
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3297
+ MachineMemOperand::MOLoad, Size, Alignment));
3298
+ if (NeedsWinCFI)
3299
+ InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3196
3300
}
3197
- MIB.addReg (Reg1, getDefRegState (true ))
3198
- .addReg (AArch64::SP)
3199
- .addImm (RPI.Offset ) // [sp, #offset*scale]
3200
- // where factor*scale is implicit
3201
- .setMIFlag (MachineInstr::FrameDestroy);
3202
- MIB.addMemOperand (MF.getMachineMemOperand (
3203
- MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3204
- MachineMemOperand::MOLoad, Size, Alignment));
3205
- if (NeedsWinCFI)
3206
- InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3207
3301
}
3208
-
3209
3302
return true ;
3210
3303
}
3211
3304
@@ -3234,6 +3327,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
3234
3327
3235
3328
unsigned ExtraCSSpill = 0 ;
3236
3329
bool HasUnpairedGPR64 = false ;
3330
+ bool HasPairZReg = false ;
3237
3331
// Figure out which callee-saved registers to save/restore.
3238
3332
for (unsigned i = 0 ; CSRegs[i]; ++i) {
3239
3333
const unsigned Reg = CSRegs[i];
@@ -3287,6 +3381,28 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
3287
3381
!RegInfo->isReservedReg (MF, PairedReg))
3288
3382
ExtraCSSpill = PairedReg;
3289
3383
}
3384
+ // Check if there is a pair of ZRegs, so it can select PReg for spill/fill
3385
+ HasPairZReg |= (AArch64::ZPRRegClass.contains (Reg, CSRegs[i ^ 1 ]) &&
3386
+ SavedRegs.test (CSRegs[i ^ 1 ]));
3387
+ }
3388
+
3389
+ if (HasPairZReg && (Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ())) {
3390
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3391
+ // Find a suitable predicate register for the multi-vector spill/fill
3392
+ // instructions.
3393
+ unsigned PnReg = findFreePredicateReg (SavedRegs);
3394
+ if (PnReg != AArch64::NoRegister)
3395
+ AFI->setPredicateRegForFillSpill (PnReg);
3396
+ // If no free callee-save has been found assign one.
3397
+ if (!AFI->getPredicateRegForFillSpill () &&
3398
+ MF.getFunction ().getCallingConv () ==
3399
+ CallingConv::AArch64_SVE_VectorCall) {
3400
+ SavedRegs.set (AArch64::P8);
3401
+ AFI->setPredicateRegForFillSpill (AArch64::PN8);
3402
+ }
3403
+
3404
+ assert (!RegInfo->isReservedReg (MF, AFI->getPredicateRegForFillSpill ()) &&
3405
+ " Predicate cannot be a reserved register" );
3290
3406
}
3291
3407
3292
3408
if (MF.getFunction ().getCallingConv () == CallingConv::Win64 &&
0 commit comments