Skip to content

[SPIRV] Implement log10 for logical SPIR-V #66921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,24 +244,27 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
auto &MF = MIRBuilder.getMF();
const Type *LLVMFPTy;
if (SpvType) {
LLVMFPTy = getTypeForSPIRVType(SpvType);
assert(LLVMFPTy->isFloatingPointTy());
} else {
LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
auto &Ctx = MF.getFunction().getContext();
if (!SpvType) {
const Type *LLVMFPTy = Type::getFloatTy(Ctx);
SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
}
// Find a constant in DT or build a new one.
const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
const auto ConstFP = ConstantFP::get(Ctx, Val);
Register Res = DT.find(ConstFP, &MF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);
MIRBuilder.buildFConstant(Res, *ConstFP);

MachineInstrBuilder MIB;
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
}

return Res;
}

Expand Down
56 changes: 55 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, const ExtInstList &ExtInsts) const;

bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Register buildI32Constant(uint32_t Val, MachineInstr &I,
const SPIRVType *ResType = nullptr) const;

Expand Down Expand Up @@ -362,7 +365,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FLOG2:
return selectExtInst(ResVReg, ResType, I, CL::log2, GL::Log2);
case TargetOpcode::G_FLOG10:
return selectExtInst(ResVReg, ResType, I, CL::log10);
return selectLog10(ResVReg, ResType, I);

case TargetOpcode::G_FABS:
return selectExtInst(ResVReg, ResType, I, CL::fabs, GL::FAbs);
Expand Down Expand Up @@ -1562,6 +1565,57 @@ bool SPIRVInstructionSelector::selectGlobalValue(
return Reg.isValid();
}

bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
return selectExtInst(ResVReg, ResType, I, CL::log10);
}

// There is no log10 instruction in the GLSL Extended Instruction set, so it
// is implemented as:
// log10(x) = log2(x) * (1 / log2(10))
// = log2(x) * 0.30103

MachineIRBuilder MIRBuilder(I);
MachineBasicBlock &BB = *I.getParent();

// Build log2(x).
Register VarReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
.addDef(VarReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
.addImm(GL::Log2)
.add(I.getOperand(1))
.constrainAllUses(TII, TRI, RBI);

// Build 0.30103.
assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
ResType->getOpcode() == SPIRV::OpTypeFloat);
// TODO: Add matrix implementation once supported by the HLSL frontend.
const SPIRVType *SpirvScalarType =
ResType->getOpcode() == SPIRV::OpTypeVector
? GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg())
: ResType;
Register ScaleReg =
GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType);

// Multiply log2(x) by 0.30103 to get log10(x) result.
auto Opcode = ResType->getOpcode() == SPIRV::OpTypeVector
? SPIRV::OpVectorTimesScalar
: SPIRV::OpFMulS;
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(VarReg)
.addUse(ScaleReg)
.constrainAllUses(TII, TRI, RBI);

return Result;
}

namespace llvm {
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

// TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
// tighten these requirements. Many of these math functions are only legal on
// specific bitwidths, so they are not selectable for
// allFloatScalarsAndVectors.
getActionDefinitionsBuilder({G_FPOW,
G_FEXP,
G_FEXP2,
G_FLOG,
G_FLOG2,
G_FLOG10,
G_FABS,
G_FMINNUM,
G_FMAXNUM,
Expand All @@ -259,8 +264,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
allFloatScalarsAndVectors, allIntScalarsAndVectors);

if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);

getActionDefinitionsBuilder(
{G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
.legalForCartesianProduct(allIntScalarsAndVectors,
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s

; CHECK: %[[#extinst:]] = OpExtInstImport "GLSL.std.450"

; CHECK: %[[#float:]] = OpTypeFloat 32
; CHECK: %[[#v4float:]] = OpTypeVector %[[#float]] 4
; CHECK: %[[#float_0_30103001:]] = OpConstant %[[#float]] 0.30103000998497009
; CHECK: %[[#_ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
; CHECK: %[[#_ptr_Function_float:]] = OpTypePointer Function %[[#float]]

define void @main() {
entry:
; CHECK: %[[#f:]] = OpVariable %[[#_ptr_Function_float]] Function
; CHECK: %[[#logf:]] = OpVariable %[[#_ptr_Function_float]] Function
; CHECK: %[[#f4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
; CHECK: %[[#logf4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
%f = alloca float, align 4
%logf = alloca float, align 4
%f4 = alloca <4 x float>, align 16
%logf4 = alloca <4 x float>, align 16

; CHECK: %[[#load:]] = OpLoad %[[#float]] %[[#f]] Aligned 4
; CHECK: %[[#log2:]] = OpExtInst %[[#float]] %[[#extinst]] Log2 %[[#load]]
; CHECK: %[[#res:]] = OpFMul %[[#float]] %[[#log2]] %[[#float_0_30103001]]
; CHECK: OpStore %[[#logf]] %[[#res]] Aligned 4
%0 = load float, ptr %f, align 4
%elt.log10 = call float @llvm.log10.f32(float %0)
store float %elt.log10, ptr %logf, align 4

; CHECK: %[[#load:]] = OpLoad %[[#v4float]] %[[#f4]] Aligned 16
; CHECK: %[[#log2:]] = OpExtInst %[[#v4float]] %[[#extinst]] Log2 %[[#load]]
; CHECK: %[[#res:]] = OpVectorTimesScalar %[[#v4float]] %[[#log2]] %[[#float_0_30103001]]
; CHECK: OpStore %[[#logf4]] %[[#res]] Aligned 16
%1 = load <4 x float>, ptr %f4, align 16
%elt.log101 = call <4 x float> @llvm.log10.v4f32(<4 x float> %1)
store <4 x float> %elt.log101, ptr %logf4, align 16

ret void
}

declare float @llvm.log10.f32(float)
declare <4 x float> @llvm.log10.v4f32(<4 x float>)