diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index a871fac46b9fd..ce8c25233ed56 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>; +def int_dx_barrier : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>; def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">, Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>; diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h index da4bea8fc46e3..fdf140d125f87 100644 --- a/llvm/include/llvm/Support/DXILABI.h +++ b/llvm/include/llvm/Support/DXILABI.h @@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t { DXILHandle, }; +enum OverloadKind : uint16_t { + Invalid = 0, + Void = 1, + Half = 1 << 1, + Float = 1 << 2, + Double = 1 << 3, + I1 = 1 << 4, + I8 = 1 << 5, + I16 = 1 << 6, + I32 = 1 << 7, + I64 = 1 << 8, + UserDefineType = 1 << 9, + ObjectType = 1 << 10, +}; + /// The kind of resource for an SRV or UAV resource. Sometimes referred to as /// "Shape" in the DXIL docs. enum class ResourceKind : uint32_t { diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 1fd6f3ed044ec..fa4cf446554cc 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -240,18 +240,23 @@ class DXILOpMappingBase { DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation. Intrinsic LLVMIntrinsic = ?; // LLVM Intrinsic DXIL Operation maps to string Doc = ""; // A short description of the operation - list OpTypes = ?; // Valid types of DXIL Operation in the - // format [returnTy, param1ty, ...] + // The following fields denote the same semantics as those of Intrinsic class + // and are initialized with the same values as those of LLVMIntrinsic unless + // overridden in the definition of a record. + list OpRetTypes = ?; // Valid return types of DXIL Operation + list OpParamTypes = ?; // Valid parameter types of DXIL Operation } class DXILOpMapping opTys = []> : DXILOpMappingBase { + list retTys = [], + list paramTys = []> : DXILOpMappingBase { int OpCode = opCode; // Opcode corresponding to DXIL Operation DXILOpClass OpClass = opClass; // Class of DXIL Operation. Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps string Doc = doc; // to a short description of the operation - list OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys); + list OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys); + list OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys); } // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic @@ -259,39 +264,39 @@ def Abs : DXILOpMapping<6, unary, int_fabs, "Returns the absolute value of the input.">; def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf, "Determines if the specified value is infinite.", - [llvm_i1_ty, llvm_halforfloat_ty]>; + [llvm_i1_ty], [llvm_halforfloat_ty]>; def Cos : DXILOpMapping<12, unary, int_cos, "Returns cosine(theta) for theta in radians.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Sin : DXILOpMapping<13, unary, int_sin, "Returns sine(theta) for theta in radians.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Exp2 : DXILOpMapping<21, unary, int_exp2, "Returns the base 2 exponential, or 2**x, of the specified value." "exp2(x) = 2**x.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Frac : DXILOpMapping<22, unary, int_dx_frac, "Returns a fraction from 0 to 1 that represents the " "decimal part of the input.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Log2 : DXILOpMapping<23, unary, int_log2, "Returns the base-2 logarithm of the specified value.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Sqrt : DXILOpMapping<24, unary, int_sqrt, "Returns the square root of the specified floating-point" "value, per component.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt, "Returns the reciprocal of the square root of the specified value." "rsqrt(x) = 1 / sqrt(x).", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Round : DXILOpMapping<26, unary, int_round, "Returns the input rounded to the nearest integer" "within a floating-point type.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def Floor : DXILOpMapping<27, unary, int_floor, "Returns the largest integer that is less than or equal to the input.", - [llvm_halforfloat_ty, LLVMMatchType<0>]>; + [llvm_halforfloat_ty], [LLVMMatchType<0>]>; def FMax : DXILOpMapping<35, binary, int_maxnum, "Float maximum. FMax(a,b) = a > b ? a : b">; def FMin : DXILOpMapping<36, binary, int_minnum, @@ -305,20 +310,28 @@ def UMax : DXILOpMapping<39, binary, int_umax, def UMin : DXILOpMapping<40, binary, int_umin, "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">; def FMad : DXILOpMapping<46, tertiary, int_fmuladd, - "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">; + "Floating point arithmetic multiply/add operation. " + "fmad(m,a,b) = m * a + b.">; def IMad : DXILOpMapping<48, tertiary, int_dx_imad, - "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">; + "Signed integer arithmetic multiply/add operation. " + "imad(m,a,b) = m * a + b.">; def UMad : DXILOpMapping<49, tertiary, int_dx_umad, - "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">; -let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in - def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">; -let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in - def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">; -let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in - def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">; + "Unsigned integer arithmetic multiply/add operation. " + "umad(m,a,b) = m * a + b.">; +def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, + "dot product of two float vectors Dot(a,b) = a[0]*b[0]" + " + ... + a[n]*b[n] where n is between 0 and 1", + [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)>; +def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, + "dot product of two float vectors Dot(a,b) = a[0]*b[0]" + " + ... + a[n]*b[n] where n is between 0 and 2", + [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)>; +def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, + "dot product of two float vectors Dot(a,b) = a[0]*b[0]" + " + ... + a[n]*b[n] where n is between 0 and 3", + [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)>; +def Barrier : DXILOpMapping<80, barrier, int_dx_barrier, + "Inserts a memory barrier in the shader">; def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id, "Reads the thread ID">; def GroupId : DXILOpMapping<94, groupId, int_dx_group_id, diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 0b3982ea0f438..22bbda461f1ae 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -21,31 +21,13 @@ using namespace llvm::dxil; constexpr StringLiteral DXILOpNamePrefix = "dx.op."; -namespace { - -enum OverloadKind : uint16_t { - VOID = 1, - HALF = 1 << 1, - FLOAT = 1 << 2, - DOUBLE = 1 << 3, - I1 = 1 << 4, - I8 = 1 << 5, - I16 = 1 << 6, - I32 = 1 << 7, - I64 = 1 << 8, - UserDefineType = 1 << 9, - ObjectType = 1 << 10, -}; - -} // namespace - static const char *getOverloadTypeName(OverloadKind Kind) { switch (Kind) { - case OverloadKind::HALF: + case OverloadKind::Half: return "f16"; - case OverloadKind::FLOAT: + case OverloadKind::Float: return "f32"; - case OverloadKind::DOUBLE: + case OverloadKind::Double: return "f64"; case OverloadKind::I1: return "i1"; @@ -57,12 +39,15 @@ static const char *getOverloadTypeName(OverloadKind Kind) { return "i32"; case OverloadKind::I64: return "i64"; - case OverloadKind::VOID: + case OverloadKind::Void: case OverloadKind::ObjectType: case OverloadKind::UserDefineType: break; + case OverloadKind::Invalid: + report_fatal_error("Invalid Overload Type for type name lookup", + /* gen_crash_diag=*/false); } - llvm_unreachable("invalid overload type for name"); + llvm_unreachable("Unhandled Overload Type specified for type name lookup"); return "void"; } @@ -70,13 +55,13 @@ static OverloadKind getOverloadKind(Type *Ty) { Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: - return OverloadKind::VOID; + return OverloadKind::Void; case Type::HalfTyID: - return OverloadKind::HALF; + return OverloadKind::Half; case Type::FloatTyID: - return OverloadKind::FLOAT; + return OverloadKind::Float; case Type::DoubleTyID: - return OverloadKind::DOUBLE; + return OverloadKind::Double; case Type::IntegerTyID: { IntegerType *ITy = cast(Ty); unsigned Bits = ITy->getBitWidth(); @@ -93,7 +78,7 @@ static OverloadKind getOverloadKind(Type *Ty) { return OverloadKind::I64; default: llvm_unreachable("invalid overload type"); - return OverloadKind::VOID; + return OverloadKind::Void; } } case Type::PointerTyID: @@ -102,7 +87,7 @@ static OverloadKind getOverloadKind(Type *Ty) { return OverloadKind::ObjectType; default: llvm_unreachable("invalid overload type"); - return OverloadKind::VOID; + return OverloadKind::Void; } } @@ -147,7 +132,7 @@ struct OpCodeProperty { static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop) { - if (Kind == OverloadKind::VOID) { + if (Kind == OverloadKind::Void) { return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); } return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + @@ -157,7 +142,7 @@ static std::string constructOverloadName(OverloadKind Kind, Type *Ty, static std::string constructOverloadTypeName(OverloadKind Kind, StringRef TypeName) { - if (Kind == OverloadKind::VOID) + if (Kind == OverloadKind::Void) return TypeName.str(); assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); @@ -284,13 +269,13 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { if (Prop->OverloadParamIndex < 0) { auto &Ctx = FT->getContext(); switch (Prop->OverloadTys) { - case OverloadKind::VOID: + case OverloadKind::Void: return Type::getVoidTy(Ctx); - case OverloadKind::HALF: + case OverloadKind::Half: return Type::getHalfTy(Ctx); - case OverloadKind::FLOAT: + case OverloadKind::Float: return Type::getFloatTy(Ctx); - case OverloadKind::DOUBLE: + case OverloadKind::Double: return Type::getDoubleTy(Ctx); case OverloadKind::I1: return Type::getInt1Ty(Ctx); diff --git a/llvm/test/CodeGen/DirectX/barrier.ll b/llvm/test/CodeGen/DirectX/barrier.ll new file mode 100644 index 0000000000000..8be4aac1f782b --- /dev/null +++ b/llvm/test/CodeGen/DirectX/barrier.ll @@ -0,0 +1,11 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Argument of llvm.dx.barrier is expected to be a mask of +; DXIL::BarrierMode values. Chose an int value for testing. + +define void @test_barrier() #0 { +entry: + ; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9) + call void @llvm.dx.barrier(i32 noundef 9) + ret void +} diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp index f2504775d557f..b3772b92f2c23 100644 --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -39,8 +39,8 @@ struct DXILOperationDesc { int OpCode; // ID of DXIL operation StringRef OpClass; // name of the opcode class StringRef Doc; // the documentation description of this instruction - SmallVector OpTypes; // Vector of operand type records - - // return type is at index 0 + SmallVector OpOverloadTys; // Vector of operand overload types - + // return type is at index 0 SmallVector OpAttributes; // operation attribute represented as strings StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which @@ -65,41 +65,167 @@ struct DXILOperationDesc { }; } // end anonymous namespace -/// Return dxil::ParameterKind corresponding to input LLVMType record +/// Return dxil::ParameterKind corresponding to input Overload Kind /// -/// \param R TableGen def record of class LLVMType +/// \param OLKind Overload Kind /// \return ParameterKind As defined in llvm/Support/DXILABI.h -static ParameterKind getParameterKind(const Record *R) { +static ParameterKind getParameterKind(const dxil::OverloadKind OLKind) { + switch (OLKind) { + case OverloadKind::Void: + return ParameterKind::Void; + case OverloadKind::Half: + return ParameterKind::Half; + case OverloadKind::Float: + return ParameterKind::Float; + case OverloadKind::Double: + return ParameterKind::Double; + case OverloadKind::I1: + return ParameterKind::I1; + case OverloadKind::I8: + return ParameterKind::I8; + case OverloadKind::I16: + return ParameterKind::I16; + case OverloadKind::I32: + return ParameterKind::I32; + case OverloadKind::I64: + return ParameterKind::I64; + default: + if ((OLKind == + (OverloadKind::Half | OverloadKind::Float | OverloadKind::Double)) || + (OLKind == (OverloadKind::Half | OverloadKind::Float)) || + (OLKind == (OverloadKind::I1 | OverloadKind::I8 | OverloadKind::I16 | + OverloadKind::I32 | OverloadKind::I64)) || + (OLKind == (OverloadKind::I16 | OverloadKind::I32))) { + return ParameterKind::Overload; + } else { + report_fatal_error("Unsupported Overload Type encountered", + /* gen_crash_diag=*/false); + } + } +} + +/// Return a string representation of ParameterKind enum +/// \param Kind Parameter Kind enum value +/// \return std::string string representation of input Kind +static std::string getParameterKindStr(ParameterKind Kind) { + switch (Kind) { + case ParameterKind::Invalid: + return "Invalid"; + case ParameterKind::Void: + return "Void"; + case ParameterKind::Half: + return "Half"; + case ParameterKind::Float: + return "Float"; + case ParameterKind::Double: + return "Double"; + case ParameterKind::I1: + return "I1"; + case ParameterKind::I8: + return "I8"; + case ParameterKind::I16: + return "I16"; + case ParameterKind::I32: + return "I32"; + case ParameterKind::I64: + return "I64"; + case ParameterKind::Overload: + return "Overload"; + case ParameterKind::CBufferRet: + return "CBufferRet"; + case ParameterKind::ResourceRet: + return "ResourceRet"; + case ParameterKind::DXILHandle: + return "DXILHandle"; + } + llvm_unreachable("Unknown llvm::dxil::ParameterKind enum"); +} + +static dxil::OverloadKind getOverloadKind(const Record *R) { auto VTRec = R->getValueAsDef("VT"); switch (getValueType(VTRec)) { case MVT::isVoid: - return ParameterKind::Void; + return OverloadKind::Void; case MVT::f16: - return ParameterKind::Half; + return OverloadKind::Half; case MVT::f32: - return ParameterKind::Float; + return OverloadKind::Float; case MVT::f64: - return ParameterKind::Double; + return OverloadKind::Double; case MVT::i1: - return ParameterKind::I1; + return OverloadKind::I1; case MVT::i8: - return ParameterKind::I8; + return OverloadKind::I8; case MVT::i16: - return ParameterKind::I16; + return OverloadKind::I16; case MVT::i32: - return ParameterKind::I32; - case MVT::fAny: + return OverloadKind::I32; + case MVT::i64: + return OverloadKind::I64; case MVT::iAny: - return ParameterKind::Overload; + return static_cast( + OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64); + case MVT::fAny: + return static_cast( + OverloadKind::Half | OverloadKind::Float | OverloadKind::Double); case MVT::Other: // Handle DXIL-specific overload types - if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) { - return ParameterKind::Overload; + { + if (R->getValueAsInt("isHalfOrFloat")) { + return static_cast(OverloadKind::Half | + OverloadKind::Float); + } else if (R->getValueAsInt("isI16OrI32")) { + return static_cast(OverloadKind::I16 | + OverloadKind::I32); + } } LLVM_FALLTHROUGH; default: - llvm_unreachable("Support for specified DXIL Type not yet implemented"); + report_fatal_error( + "Support for specified parameter OverloadKind not yet implemented", + /* gen_crash_diag=*/false); + } +} + +/// Return a string representation of OverloadKind enum +/// \param OLKind Overload Kind +/// \return std::string string representation of OverloadKind + +static std::string getOverloadKindStr(const dxil::OverloadKind OLKind) { + switch (OLKind) { + case OverloadKind::Void: + return "OverloadKind::Void"; + case OverloadKind::Half: + return "OverloadKind::Half"; + case OverloadKind::Float: + return "OverloadKind::Float"; + case OverloadKind::Double: + return "OverloadKind::Double"; + case OverloadKind::I1: + return "OverloadKind::I1"; + case OverloadKind::I8: + return "OverloadKind::I8"; + case OverloadKind::I16: + return "OverloadKind::I16"; + case OverloadKind::I32: + return "OverloadKind::I32"; + case OverloadKind::I64: + return "OverloadKind::I64"; + default: + if (OLKind == (OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64)) { + return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64"; + } else if (OLKind == (OverloadKind::Half | OverloadKind::Float | + OverloadKind::Double)) { + return "OverloadKind::Half | OverloadKind::Float | OverloadKind::Double"; + } else if (OLKind == (OverloadKind::Half | OverloadKind::Float)) { + return "OverloadKind::Half | OverloadKind::Float"; + } else if (OLKind == (OverloadKind::I16 | OverloadKind::I32)) { + return "OverloadKind::I16 | OverloadKind::I32"; + } else { + report_fatal_error("Unsupported OverloadKind specified", + /* gen_crash_diag=*/false); + } } } @@ -114,9 +240,25 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { Doc = R->getValueAsString("Doc"); - auto TypeRecs = R->getValueAsListOfDefs("OpTypes"); + // Populate OpOverloadTys with return type and parameter types + auto RetTypeRecs = R->getValueAsListOfDefs("OpRetTypes"); + auto ParamTypeRecs = R->getValueAsListOfDefs("OpParamTypes"); + unsigned RetTypeRecSize = RetTypeRecs.size(); + unsigned ParamTypeRecSize = ParamTypeRecs.size(); + // A vector with return type and parameter type records + std::vector TypeRecs; + TypeRecs.reserve(RetTypeRecSize + ParamTypeRecSize); + // If return type lust is empty, the return type is void + if (RetTypeRecSize == 0) { + OpOverloadTys.emplace_back(OverloadKind::Void); + } else { + // Append RetTypeRecs to TypeRecs + TypeRecs.insert(TypeRecs.end(), RetTypeRecs.begin(), RetTypeRecs.end()); + } + // Append RetTypeRecs to TypeRecs + TypeRecs.insert(TypeRecs.end(), ParamTypeRecs.begin(), ParamTypeRecs.end()); + unsigned TypeRecsSize = TypeRecs.size(); - // Populate OpTypes with return type and parameter types // Parameter indices of overloaded parameters. // This vector contains overload parameters in the order used to @@ -146,13 +288,13 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { if (!knownType) { report_fatal_error("Specification of multiple differing overload " "parameter types not yet supported", - false); + /* gen_crash_diag=*/false); } } else { OverloadParamIndices.push_back(i); } } - // Populate OpTypes array according to the type specification + // Populate OpOverloadTys array according to the type specification if (TR->isAnonymous()) { // Check prior overload types exist assert(!OverloadParamIndices.empty() && @@ -160,10 +302,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { // Get the parameter index of anonymous type, TR, references auto OLParamIndex = TR->getValueAsInt("Number"); // Resolve and insert the type to that at OLParamIndex - OpTypes.emplace_back(TypeRecs[OLParamIndex]); + OpOverloadTys.emplace_back(getOverloadKind(TypeRecs[OLParamIndex])); } else { - // A non-anonymous type. Just record it in OpTypes - OpTypes.emplace_back(TR); + // A non-anonymous type. Just record it in OpOverloadTys + OpOverloadTys.emplace_back(getOverloadKind(TR)); } } @@ -172,7 +314,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { if (!OverloadParamIndices.empty()) { if (OverloadParamIndices.size() > 1) report_fatal_error("Multiple overload type specification not supported", - false); + /* gen_crash_diag=*/false); OverloadParamIndex = OverloadParamIndices[0]; } // Get the operation class @@ -196,89 +338,6 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { } } -/// Return a string representation of ParameterKind enum -/// \param Kind Parameter Kind enum value -/// \return std::string string representation of input Kind -static std::string getParameterKindStr(ParameterKind Kind) { - switch (Kind) { - case ParameterKind::Invalid: - return "Invalid"; - case ParameterKind::Void: - return "Void"; - case ParameterKind::Half: - return "Half"; - case ParameterKind::Float: - return "Float"; - case ParameterKind::Double: - return "Double"; - case ParameterKind::I1: - return "I1"; - case ParameterKind::I8: - return "I8"; - case ParameterKind::I16: - return "I16"; - case ParameterKind::I32: - return "I32"; - case ParameterKind::I64: - return "I64"; - case ParameterKind::Overload: - return "Overload"; - case ParameterKind::CBufferRet: - return "CBufferRet"; - case ParameterKind::ResourceRet: - return "ResourceRet"; - case ParameterKind::DXILHandle: - return "DXILHandle"; - } - llvm_unreachable("Unknown llvm::dxil::ParameterKind enum"); -} - -/// Return a string representation of OverloadKind enum that maps to -/// input LLVMType record -/// \param R TableGen def record of class LLVMType -/// \return std::string string representation of OverloadKind - -static std::string getOverloadKindStr(const Record *R) { - auto VTRec = R->getValueAsDef("VT"); - switch (getValueType(VTRec)) { - case MVT::isVoid: - return "OverloadKind::VOID"; - case MVT::f16: - return "OverloadKind::HALF"; - case MVT::f32: - return "OverloadKind::FLOAT"; - case MVT::f64: - return "OverloadKind::DOUBLE"; - case MVT::i1: - return "OverloadKind::I1"; - case MVT::i8: - return "OverloadKind::I8"; - case MVT::i16: - return "OverloadKind::I16"; - case MVT::i32: - return "OverloadKind::I32"; - case MVT::i64: - return "OverloadKind::I64"; - case MVT::iAny: - return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64"; - case MVT::fAny: - return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE"; - case MVT::Other: - // Handle DXIL-specific overload types - { - if (R->getValueAsInt("isHalfOrFloat")) { - return "OverloadKind::HALF | OverloadKind::FLOAT"; - } else if (R->getValueAsInt("isI16OrI32")) { - return "OverloadKind::I16 | OverloadKind::I32"; - } - } - LLVM_FALLTHROUGH; - default: - llvm_unreachable( - "Support for specified parameter OverloadKind not yet implemented"); - } -} - /// Emit Enums of DXIL Ops /// \param A vector of DXIL Ops /// \param Output stream @@ -376,8 +435,8 @@ static void emitDXILOperationTable(std::vector &Ops, OpClassStrings.add(Op.OpClass.data()); SmallVector ParamKindVec; // ParamKindVec is a vector of parameters. Skip return type at index 0 - for (unsigned i = 1; i < Op.OpTypes.size(); i++) { - ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i])); + for (unsigned i = 1; i < Op.OpOverloadTys.size(); i++) { + ParamKindVec.emplace_back(getParameterKind(Op.OpOverloadTys[i])); } ParameterMap[Op.OpClass] = ParamKindVec; Parameters.add(ParamKindVec); @@ -391,7 +450,7 @@ static void emitDXILOperationTable(std::vector &Ops, // Emit the DXIL operation table. //{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::unary, // OpCodeClassNameIndex, - // OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0, + // OverloadKind::Float | OverloadKind::Half, Attribute::AttrKind::ReadNone, 0, // 3, ParameterTableOffset}, OS << "static const OpCodeProperty *getOpCodeProperty(dxil::OpCode Op) " "{\n"; @@ -406,14 +465,14 @@ static void emitDXILOperationTable(std::vector &Ops, // return type - as overload parameter to emit the appropriate overload kind // enum. if (OLParamIdx < 0) { - OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0; + OLParamIdx = (Op.OpOverloadTys.size() > 1) ? 1 : 0; } OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName) << ", OpCodeClass::" << Op.OpClass << ", " << OpClassStrings.get(Op.OpClass.data()) << ", " - << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", " + << getOverloadKindStr(Op.OpOverloadTys[OLParamIdx]) << ", " << emitDXILOperationAttr(Op.OpAttributes) << ", " - << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", " + << Op.OverloadParamIndex << ", " << Op.OpOverloadTys.size() - 1 << ", " << Parameters.get(ParameterMap[Op.OpClass]) << " },\n"; } OS << " };\n";