Skip to content

[NFC][SPIRV] Cleanup selectOpWithSrc functions #117077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 23 additions & 39 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectGlobalValue(Register ResVReg, MachineInstr &I,
const MachineInstr *Init = nullptr) const;

bool selectNAryOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, std::vector<Register> SrcRegs,
unsigned Opcode) const;
bool selectOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, std::vector<Register> SrcRegs,
unsigned Opcode) const;

bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, Register SrcReg,
unsigned Opcode) const;
bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
unsigned Opcode) const;

Expand Down Expand Up @@ -859,11 +856,11 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
return false;
}

bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
std::vector<Register> Srcs,
unsigned Opcode) const {
bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
std::vector<Register> Srcs,
unsigned Opcode) const {
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
Expand All @@ -873,18 +870,6 @@ bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
Register SrcReg,
unsigned Opcode) const {
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
Expand Down Expand Up @@ -920,8 +905,8 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
}
return selectUnOpWithSrc(ResVReg, ResType, I, I.getOperand(1).getReg(),
Opcode);
return selectOpWithSrcs(ResVReg, ResType, I, {I.getOperand(1).getReg()},
Opcode);
}

bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
Expand Down Expand Up @@ -1066,7 +1051,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
selectUnOpWithSrc(SrcReg, SourceTy, I, VarReg, SPIRV::OpBitcast);
selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
}
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized))
.addUse(I.getOperand(0).getReg())
Expand Down Expand Up @@ -1111,7 +1096,7 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
if (NegateOpcode != 0) {
// Translation with negative value operand is requested
Register TmpReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
Result &= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
Result &= selectOpWithSrcs(TmpReg, ResType, I, {ValueReg}, NegateOpcode);
ValueReg = TmpReg;
}

Expand Down Expand Up @@ -2374,7 +2359,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
SrcReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
selectSelect(SrcReg, TmpType, I, false);
}
return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode);
return selectOpWithSrcs(ResVReg, ResType, I, {SrcReg}, Opcode);
}

bool SPIRVInstructionSelector::selectExt(Register ResVReg,
Expand Down Expand Up @@ -3068,7 +3053,7 @@ bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
// zero or sign extend
Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
bool Result =
selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
selectOpWithSrcs(ExtReg, ResType, I, {I.getOperand(2).getReg()}, Opcode);
return Result && selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
}

Expand Down Expand Up @@ -3100,7 +3085,7 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
bool Result =
selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg, SPIRV::OpBitcast);
selectOpWithSrcs(bitcastReg, postCastT, I, {OpReg}, SPIRV::OpBitcast);

// 2. call firstbithigh
Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
Expand All @@ -3114,11 +3099,11 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
if (isScalarRes) {
// if scalar do a vector extract
Result &= selectNAryOpWithSrcs(
Result &= selectOpWithSrcs(
HighReg, ResType, I,
{FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
SPIRV::OpVectorExtractDynamic);
Result &= selectNAryOpWithSrcs(
Result &= selectOpWithSrcs(
LowReg, ResType, I,
{FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
SPIRV::OpVectorExtractDynamic);
Expand Down Expand Up @@ -3176,21 +3161,20 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,

// check if the high bits are == -1; true if -1
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
SPIRV::OpIEqual);
Result &= selectOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
SPIRV::OpIEqual);

// Select low bits if true in BReg, otherwise high bits
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
selectOp);
Result &=
selectOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg}, selectOp);

// Add 32 for high bits, 0 for low bits
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result &=
selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
Result &= selectOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);

return Result &&
selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg}, addOp);
selectOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg}, addOp);
}

bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
Expand Down
Loading