Skip to content

Commit 524c334

Browse files
committed
Add __builtin_spirv_ internal builtins
The way they are implemented is described in KhronosGroup#3221. The PR also adds the SPV_EXT_float8 extension and packed conversions for SPV_INTEL_int4. Currently only conversion instructions (and internal builtins) are covered. TODO: the Saturation decoration will be added in a following PR. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 095a8c2 commit 524c334

File tree

14 files changed

+974
-10
lines changed

14 files changed

+974
-10
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,4 @@ EXT(SPV_INTEL_ternary_bitwise_function)
8282
EXT(SPV_INTEL_int4)
8383
EXT(SPV_INTEL_function_variants)
8484
EXT(SPV_INTEL_shader_atomic_bfloat16)
85+
EXT(SPV_EXT_float8)

lib/SPIRV/SPIRVInternal.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
373373
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
374374
const static char ConvertHandleToSampledImageINTEL[] =
375375
"ConvertHandleToSampledImageINTEL";
376+
const static char InternalPrefix[] = "__builtin_spirv_";
376377
} // namespace kSPIRVName
377378

378379
namespace kSPIRVPostfix {
@@ -728,6 +729,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
728729
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
729730
bool TakeFuncName = true);
730731

732+
/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
733+
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);
734+
731735
/// Add a call instruction for SPIR-V builtin function.
732736
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
733737
ArrayRef<Value *> Args, AttributeList *Attrs,
@@ -1029,6 +1033,85 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);
10291033

10301034
bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);
10311035

1036+
/// \param Name LLVM function name
1037+
/// \param DemangledName demangled name of the translator's internal built-in
1038+
/// function.
1039+
/// \returns true if Name is the name of the translator's internal built-in
1040+
/// function, false for other functions
1041+
/// Used for 'mini'-floats conversion functions
1042+
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName);
1043+
1044+
// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
1045+
// descriptor. On top of the encodings taken from FPEncoding it defines two
// extra values, Integer and IEEE754, both derived from FPEncodingMax.
1046+
enum FPEncodingWrap {
1047+
// Pseudo-encoding for integer element types; sits one below FPEncodingMax.
Integer = FPEncoding::FPEncodingMax - 1,
1048+
// Shares its value with FPEncodingMax.
// NOTE(review): presumably this is what a float type without an explicit
// encoding reports - confirm against SPIRVTypeFloat.
IEEE754 = FPEncoding::FPEncodingMax,
1049+
// The remaining values mirror the corresponding FPEncoding enumerators.
BF16 = FPEncoding::FPEncodingBFloat16KHR,
1050+
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
1051+
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1052+
};
1053+
1054+
// Structure describing non-trivial conversions (FP8 and int4)
1055+
struct FPConversionDesc {
1056+
FPEncodingWrap SrcEncoding;
1057+
FPEncodingWrap DstEncoding;
1058+
SPIRVWord ConvOpCode;
1059+
1060+
// To use as a key in std::map
1061+
bool operator==(const FPConversionDesc &Other) const {
1062+
return SrcEncoding == Other.SrcEncoding &&
1063+
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
1064+
}
1065+
1066+
bool operator<(const FPConversionDesc &Other) const {
1067+
if (ConvOpCode != Other.ConvOpCode)
1068+
return ConvOpCode < Other.ConvOpCode;
1069+
if (SrcEncoding != Other.SrcEncoding)
1070+
return SrcEncoding < Other.SrcEncoding;
1071+
return DstEncoding < Other.DstEncoding;
1072+
}
1073+
};
1074+
1075+
// Maps internal builtin name to conversion descriptor
1076+
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
1077+
1078+
// clang-format off
1079+
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
1080+
template <> inline void FPConvertToEncodingMap::init() {
1081+
// 8-bit conversions
1082+
add("ConvertE4M3ToFP16EXT",
1083+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1084+
add("ConvertE5M2ToFP16EXT",
1085+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1086+
add("ConvertE4M3ToBF16EXT",
1087+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1088+
add("ConvertE5M2ToBF16EXT",
1089+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1090+
add("ConvertFP16ToE4M3EXT",
1091+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1092+
add("ConvertFP16ToE5M2EXT",
1093+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1094+
add("ConvertBF16ToE4M3EXT",
1095+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1096+
add("ConvertBF16ToE5M2EXT",
1097+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1098+
1099+
add("ConvertInt4ToE4M3INTEL",
1100+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1101+
add("ConvertInt4ToE5M2INTEL",
1102+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1103+
add("ConvertInt4ToFP16INTEL",
1104+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1105+
add("ConvertInt4ToBF16INTEL",
1106+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1107+
add("ConvertFP16ToInt4INTEL",
1108+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1109+
add("ConvertBF16ToInt4INTEL",
1110+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1111+
}
1112+
1113+
// clang-format on
1114+
10321115
} // namespace SPIRV
10331116

10341117
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
314314

315315
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
316316
switch (T->getFloatBitWidth()) {
317+
case 8:
318+
// No LLVM IR counterpart for FP8 - map it to i8
319+
return Type::getIntNTy(*Context, 8);
317320
case 16:
318321
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
319322
return Type::getBFloatTy(*Context);
@@ -1066,6 +1069,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10661069
CastInst::CastOps CO = Instruction::BitCast;
10671070
bool IsExt =
10681071
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();
1072+
1073+
auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
1074+
if (Ty->isTypeFloat()) {
1075+
unsigned Enc =
1076+
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
1077+
return static_cast<FPEncodingWrap>(Enc);
1078+
}
1079+
if (Ty->isTypeInt())
1080+
return FPEncodingWrap::Integer;
1081+
return FPEncodingWrap::IEEE754;
1082+
};
1083+
1084+
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1085+
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1086+
};
1087+
10691088
switch (BC->getOpCode()) {
10701089
case OpPtrCastToGeneric:
10711090
case OpGenericCastToPtr:
@@ -1087,10 +1106,61 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10871106
case OpUConvert:
10881107
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10891108
break;
1090-
case OpFConvert:
1091-
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1109+
case OpConvertSToF:
1110+
case OpConvertFToS:
1111+
case OpConvertUToF:
1112+
case OpConvertFToU:
1113+
case OpFConvert: {
1114+
const auto OC = BC->getOpCode();
1115+
{
1116+
auto SPVOps = BC->getOperands();
1117+
auto *SPVSrcTy = SPVOps[0]->getType();
1118+
auto *SPVDstTy = BC->getType();
1119+
if (SPVSrcTy->isTypeVector()) {
1120+
SPVSrcTy = SPVSrcTy->getVectorComponentType();
1121+
} else if (SPVSrcTy->isTypeCooperativeMatrixKHR()) {
1122+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVSrcTy);
1123+
SPVSrcTy = MT->getCompType();
1124+
}
1125+
if (SPVDstTy->isTypeVector()) {
1126+
SPVDstTy = SPVDstTy->getVectorComponentType();
1127+
} else if (SPVDstTy->isTypeCooperativeMatrixKHR()) {
1128+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVDstTy);
1129+
SPVDstTy = MT->getCompType();
1130+
}
1131+
FPEncodingWrap SrcEnc = GetFPEncoding(SPVSrcTy);
1132+
FPEncodingWrap DstEnc = GetFPEncoding(SPVDstTy);
1133+
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1134+
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
1135+
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
1136+
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
1137+
std::vector<Value *> Ops = {Src};
1138+
std::vector<Type *> OpsTys = {Src->getType()};
1139+
1140+
std::string BuiltinName =
1141+
kSPIRVName::InternalPrefix + std::string(Conv);
1142+
BuiltinFuncMangleInfo Info;
1143+
std::string MangledName;
1144+
1145+
if (MangledName.empty())
1146+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1147+
1148+
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
1149+
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
1150+
return CallInst::Create(Func, Ops, "", BB);
1151+
}
1152+
}
1153+
1154+
if (OC == OpFConvert) {
1155+
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1156+
break;
1157+
}
1158+
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
10921159
break;
1160+
}
10931161
case OpBitcast:
1162+
if (!Dst->isPointerTy() && Dst == Src->getType())
1163+
return Src;
10941164
// OpBitcast need to be handled as a special-case when the source is a
10951165
// pointer and the destination is not a pointer, and where the source is not
10961166
// a pointer and the destination is a pointer. This is supported by the
@@ -2990,11 +3060,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
29903060
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
29913061
auto *BI = static_cast<SPIRVInstruction *>(BV);
29923062
Value *Inst = nullptr;
2993-
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
2994-
BI->getType()->isTypeCooperativeMatrixKHR())
3063+
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
29953064
Inst = transSPIRVBuiltinFromInst(BI, BB);
2996-
else
3065+
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
3066+
// For cooperative matrix conversions generate __builtin_spirv
3067+
// conversions instead of __spirv_FConvert in case of mini-float
3068+
// element type.
3069+
auto *OutMatrixElementTy =
3070+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
3071+
->getCompType();
3072+
auto *InMatrixElementTy =
3073+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
3074+
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
3075+
->getCompType();
3076+
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3077+
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3078+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3079+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3080+
Inst = transConvertInst(BV, F, BB);
3081+
else
3082+
Inst = transSPIRVBuiltinFromInst(BI, BB);
3083+
} else {
29973084
Inst = transConvertInst(BV, F, BB);
3085+
}
29983086
return mapValue(BV, Inst);
29993087
}
30003088
return mapValue(

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
// This file needs to be included before anything that declares
4242
// llvm::PointerType to avoid a compilation bug on MSVC.
43+
#include "llvm/Demangle/Demangle.h"
4344
#include "llvm/Demangle/ItaniumDemangle.h"
4445

4546
#include "FunctionDescriptor.h"
@@ -267,6 +268,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
267268
return false;
268269
}
269270

271+
// Returns true when Ty is a target extension type whose name is
// "spirv.CooperativeMatrixKHR"; any other type yields false.
bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
  const auto *ExtTy = dyn_cast<TargetExtType>(Ty);
  return ExtTy && ExtTy->getName() == "spirv.CooperativeMatrixKHR";
}
276+
270277
Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
271278
StringRef Name, BuiltinFuncMangleInfo *Mangle,
272279
AttributeList *Attrs, bool TakeName) {
@@ -484,6 +491,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
484491
return false;
485492
}
486493

494+
// Demangled name is a substring of the name. The DemangledName is updated only
495+
// if true is returned
496+
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) {
497+
if (!Name.starts_with("_Z"))
498+
return false;
499+
constexpr unsigned DemangledNameLenStart = 2;
500+
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
501+
if (!Name.substr(Start, Name.size() - 1)
502+
.starts_with(kSPIRVName::InternalPrefix))
503+
return false;
504+
DemangledName = llvm::itaniumDemangle(Name.data(), false);
505+
DemangledName.consume_front(kSPIRVName::InternalPrefix);
506+
return true;
507+
}
508+
487509
// Check if a mangled type Name is unsigned
488510
bool isMangledTypeUnsigned(char Mangled) {
489511
return Mangled == 'h' /* uchar */

0 commit comments

Comments
 (0)