Skip to content

Commit 0a21da2

Browse files
committed
[Backport to 17] Add __builtin_spirv_ internal builtins (KhronosGroup#3374)
The way they are implemented is described in KhronosGroup#3221. The PR also adds the SPV_EXT_float8 extension and packed conversions for SPV_INTEL_int4. Currently only conversion instructions (and internal builtins) are covered. TODO: the Saturation decoration will be added in a follow-up PR. Signed-off-by: Sidorov, Dmitry <[email protected]> --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 3f4d6e8 commit 0a21da2

File tree

14 files changed

+979
-14
lines changed

14 files changed

+979
-14
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,6 @@ EXT(SPV_INTEL_ternary_bitwise_function)
8181
EXT(SPV_INTEL_int4)
8282
EXT(SPV_INTEL_function_variants)
8383
EXT(SPV_INTEL_shader_atomic_bfloat16)
84+
EXT(SPV_EXT_float8)
8485
EXT(SPV_INTEL_predicated_io)
8586
EXT(SPV_INTEL_sigmoid)

lib/SPIRV/SPIRVInternal.h

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
373373
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
374374
const static char ConvertHandleToSampledImageINTEL[] =
375375
"ConvertHandleToSampledImageINTEL";
376+
const static char InternalBuiltinPrefix[] = "__builtin_spirv_";
376377
} // namespace kSPIRVName
377378

378379
namespace kSPIRVPostfix {
@@ -659,7 +660,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl<std::string> *Dec = nullptr);
659660
bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin);
660661

661662
/// \param Name LLVM function name
662-
/// \param DemangledName demanged name of the OpenCL built-in function
663+
/// \param DemangledName demangled name of the OpenCL built-in function
663664
/// \returns true if Name is the name of the OpenCL built-in function,
664665
/// false for other functions
665666
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false);
@@ -722,6 +723,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
722723
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
723724
bool TakeFuncName = true);
724725

726+
/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
727+
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);
728+
725729
/// Add a call instruction for SPIR-V builtin function.
726730
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
727731
ArrayRef<Value *> Args, AttributeList *Attrs,
@@ -1040,8 +1044,84 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);
10401044

10411045
bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);
10421046

1043-
template <typename T>
1044-
MetadataAsValue *map2MDString(LLVMContext &C, SPIRVValue *V);
1047+
/// \param MangledName LLVM function name.
1048+
/// \param DemangledName demangled name of the input function if it is the
1049+
/// translator's internal built-in function.
1050+
/// \returns true if MangledName is the name of the translator's internal
1051+
/// built-in function, false for other functions.
1052+
/// Used for 'mini'-floats conversion functions
1053+
bool isInternalSPIRVBuiltin(StringRef MangledName, StringRef &DemangledName);
1054+
1055+
// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
1056+
// descriptor
1057+
// Encoding tag stored in FPConversionDesc. The BF16/E4M3/E5M2 enumerators
// reuse the numeric values of the matching FPEncoding entries, so a raw
// encoding value can be static_cast to this enum (the reader's
// transConvertInst does exactly that).
enum FPEncodingWrap {
1058+
// Tag for integer operands; it has no FPEncoding counterpart and is placed
// just below FPEncodingMax.
// NOTE(review): presumably chosen so it cannot collide with a real FPEncoding
// value - confirm.
Integer = FPEncoding::FPEncodingMax - 1,
1059+
// NOTE(review): FPEncodingMax appears to stand for "no explicit encoding",
// i.e. a regular IEEE 754 float - confirm against SPIRVTypeFloat.
IEEE754 = FPEncoding::FPEncodingMax,
1060+
BF16 = FPEncoding::FPEncodingBFloat16KHR,
1061+
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
1062+
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1063+
};
1064+
1065+
// Structure describing non-trivial conversions (FP8 and int4)
1066+
struct FPConversionDesc {
1067+
FPEncodingWrap SrcEncoding;
1068+
FPEncodingWrap DstEncoding;
1069+
SPIRVWord ConvOpCode;
1070+
1071+
// To use as a key in std::map
1072+
bool operator==(const FPConversionDesc &Other) const {
1073+
return SrcEncoding == Other.SrcEncoding &&
1074+
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
1075+
}
1076+
1077+
bool operator<(const FPConversionDesc &Other) const {
1078+
if (ConvOpCode != Other.ConvOpCode)
1079+
return ConvOpCode < Other.ConvOpCode;
1080+
if (SrcEncoding != Other.SrcEncoding)
1081+
return SrcEncoding < Other.SrcEncoding;
1082+
return DstEncoding < Other.DstEncoding;
1083+
}
1084+
};
1085+
1086+
// Maps internal builtin name to conversion descriptor
1087+
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
1088+
1089+
// clang-format off
1090+
template <> inline void FPConvertToEncodingMap::init() {
1091+
// 8-bit conversions
1092+
add("ConvertE4M3ToFP16EXT",
1093+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1094+
add("ConvertE5M2ToFP16EXT",
1095+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1096+
add("ConvertE4M3ToBF16EXT",
1097+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1098+
add("ConvertE5M2ToBF16EXT",
1099+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1100+
add("ConvertFP16ToE4M3EXT",
1101+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1102+
add("ConvertFP16ToE5M2EXT",
1103+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1104+
add("ConvertBF16ToE4M3EXT",
1105+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1106+
add("ConvertBF16ToE5M2EXT",
1107+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1108+
1109+
add("ConvertInt4ToE4M3INTEL",
1110+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1111+
add("ConvertInt4ToE5M2INTEL",
1112+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1113+
add("ConvertInt4ToFP16INTEL",
1114+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1115+
add("ConvertInt4ToBF16INTEL",
1116+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1117+
add("ConvertFP16ToInt4INTEL",
1118+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1119+
add("ConvertBF16ToInt4INTEL",
1120+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1121+
}
1122+
1123+
// clang-format on
1124+
10451125
} // namespace SPIRV
10461126

10471127
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
313313

314314
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
315315
switch (T->getFloatBitWidth()) {
316+
case 8:
317+
// No LLVM IR counterpart for FP8 - map it to i8
318+
return Type::getIntNTy(*Context, 8);
316319
case 16:
317320
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
318321
return Type::getBFloatTy(*Context);
@@ -1057,6 +1060,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10571060
CastInst::CastOps CO = Instruction::BitCast;
10581061
bool IsExt =
10591062
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();
1063+
1064+
auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
1065+
if (Ty->isTypeFloat()) {
1066+
unsigned Enc =
1067+
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
1068+
return static_cast<FPEncodingWrap>(Enc);
1069+
}
1070+
if (Ty->isTypeInt())
1071+
return FPEncodingWrap::Integer;
1072+
return FPEncodingWrap::IEEE754;
1073+
};
1074+
1075+
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1076+
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1077+
};
1078+
10601079
switch (BC->getOpCode()) {
10611080
case OpPtrCastToGeneric:
10621081
case OpGenericCastToPtr:
@@ -1078,10 +1097,58 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10781097
case OpUConvert:
10791098
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10801099
break;
1081-
case OpFConvert:
1082-
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1100+
case OpConvertSToF:
1101+
case OpConvertFToS:
1102+
case OpConvertUToF:
1103+
case OpConvertFToU:
1104+
case OpFConvert: {
1105+
const auto OC = BC->getOpCode();
1106+
{
1107+
auto SPVOps = BC->getOperands();
1108+
auto *SPVSrcTy = SPVOps[0]->getType();
1109+
auto *SPVDstTy = BC->getType();
1110+
1111+
auto GetEncodingAndUpdateType =
1112+
[GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap {
1113+
if (SPVTy->isTypeVector()) {
1114+
SPVTy = SPVTy->getVectorComponentType();
1115+
} else if (SPVTy->isTypeCooperativeMatrixKHR()) {
1116+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVTy);
1117+
SPVTy = MT->getCompType();
1118+
}
1119+
return GetFPEncoding(SPVTy);
1120+
};
1121+
1122+
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
1123+
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1124+
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1125+
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
1126+
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
1127+
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
1128+
std::vector<Value *> Ops = {Src};
1129+
std::vector<Type *> OpsTys = {Src->getType()};
1130+
1131+
std::string BuiltinName =
1132+
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
1133+
BuiltinFuncMangleInfo Info;
1134+
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1135+
1136+
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
1137+
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
1138+
return CallInst::Create(Func, Ops, "", BB);
1139+
}
1140+
}
1141+
1142+
if (OC == OpFConvert) {
1143+
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1144+
break;
1145+
}
1146+
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
10831147
break;
1148+
}
10841149
case OpBitcast:
1150+
if (!Dst->isPointerTy() && Dst == Src->getType())
1151+
return Src;
10851152
// OpBitcast need to be handled as a special-case when the source is a
10861153
// pointer and the destination is not a pointer, and where the source is not
10871154
// a pointer and the destination is a pointer. This is supported by the
@@ -2890,11 +2957,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
28902957
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
28912958
auto BI = static_cast<SPIRVInstruction *>(BV);
28922959
Value *Inst = nullptr;
2893-
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
2894-
BI->getType()->isTypeCooperativeMatrixKHR())
2960+
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
28952961
Inst = transSPIRVBuiltinFromInst(BI, BB);
2896-
else
2962+
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
2963+
// For cooperative matrix conversions generate __builtin_spirv
2964+
// conversions instead of __spirv_FConvert in case of mini-float
2965+
// type element type.
2966+
auto *OutMatrixElementTy =
2967+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
2968+
->getCompType();
2969+
auto *InMatrixElementTy =
2970+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
2971+
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
2972+
->getCompType();
2973+
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2974+
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
2975+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
2976+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
2977+
Inst = transConvertInst(BV, F, BB);
2978+
else
2979+
Inst = transSPIRVBuiltinFromInst(BI, BB);
2980+
} else {
28972981
Inst = transConvertInst(BV, F, BB);
2982+
}
28982983
return mapValue(BV, Inst);
28992984
}
29002985
return mapValue(

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
// This file needs to be included before anything that declares
4242
// llvm::PointerType to avoid a compilation bug on MSVC.
43+
#include "llvm/Demangle/Demangle.h"
4344
#include "llvm/Demangle/ItaniumDemangle.h"
4445

4546
#include "FunctionDescriptor.h"
@@ -265,6 +266,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
265266
return false;
266267
}
267268

269+
/// Returns true when \p Ty is a target extension type whose name is
/// "spirv.CooperativeMatrixKHR"; any other type yields false.
bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
  const auto *ExtTy = dyn_cast<TargetExtType>(Ty);
  return ExtTy && ExtTy->getName() == "spirv.CooperativeMatrixKHR";
}
274+
268275
Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
269276
StringRef Name, BuiltinFuncMangleInfo *Mangle,
270277
AttributeList *Attrs, bool TakeName) {
@@ -464,7 +471,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
464471
return getByName(R.str(), B);
465472
}
466473

467-
// Demangled name is a substring of the name. The DemangledName is updated only
474+
// DemangledName is a substring of Name. The DemangledName is updated only
468475
// if true is returned
469476
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
470477
if (Name == "printf") {
@@ -509,6 +516,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
509516
return false;
510517
}
511518

519+
// DemangledName is a substring of Name. The DemangledName is updated only
520+
// if true is returned.
521+
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) {
522+
if (!Name.starts_with("_Z"))
523+
return false;
524+
constexpr unsigned DemangledNameLenStart = 2;
525+
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
526+
if (!Name.substr(Start, Name.size() - 1)
527+
.starts_with(kSPIRVName::InternalBuiltinPrefix))
528+
return false;
529+
DemangledName = llvm::itaniumDemangle(Name.data());
530+
DemangledName.consume_front(kSPIRVName::InternalBuiltinPrefix);
531+
return true;
532+
}
533+
512534
// Check if a mangled type Name is unsigned
513535
bool isMangledTypeUnsigned(char Mangled) {
514536
return Mangled == 'h' /* uchar */

0 commit comments

Comments
 (0)