Skip to content

[DirectX][DXIL] Align type spec of TableGen DXIL Op and LLVM Intrinsic #86311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW
def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
def int_dx_barrier : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>;

def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/Support/DXILABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t {
DXILHandle,
};

/// Overload kinds of a DXIL operation. Every valid kind is encoded as its own
/// single bit of the 16-bit underlying type; Invalid is the all-zero value.
enum OverloadKind : uint16_t {
  Invalid = 0x0000,
  Void = 0x0001,
  Half = 0x0002,
  Float = 0x0004,
  Double = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

/// The kind of resource for an SRV or UAV resource. Sometimes referred to as
/// "Shape" in the DXIL docs.
enum class ResourceKind : uint32_t {
Expand Down
65 changes: 39 additions & 26 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -240,58 +240,63 @@ class DXILOpMappingBase {
DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
Intrinsic LLVMIntrinsic = ?; // LLVM Intrinsic DXIL Operation maps to
string Doc = ""; // A short description of the operation
list<LLVMType> OpTypes = ?; // Valid types of DXIL Operation in the
// format [returnTy, param1ty, ...]
// The following fields denote the same semantics as those of Intrinsic class
// and are initialized with the same values as those of LLVMIntrinsic unless
// overridden in the definition of a record.
list<LLVMType> OpRetTypes = ?; // Valid return types of DXIL Operation
list<LLVMType> OpParamTypes = ?; // Valid parameter types of DXIL Operation
}

class DXILOpMapping<int opCode, DXILOpClass opClass,
Intrinsic intrinsic, string doc,
list<LLVMType> opTys = []> : DXILOpMappingBase {
list<LLVMType> retTys = [],
list<LLVMType> paramTys = []> : DXILOpMappingBase {
int OpCode = opCode; // Opcode corresponding to DXIL Operation
DXILOpClass OpClass = opClass; // Class of DXIL Operation.
Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps to
string Doc = doc; // A short description of the operation
list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
list<LLVMType> OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys);
list<LLVMType> OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys);
}

// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
def Abs : DXILOpMapping<6, unary, int_fabs,
"Returns the absolute value of the input.">;
def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
"Determines if the specified value is infinite.",
[llvm_i1_ty, llvm_halforfloat_ty]>;
[llvm_i1_ty], [llvm_halforfloat_ty]>;
def Cos : DXILOpMapping<12, unary, int_cos,
"Returns cosine(theta) for theta in radians.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Sin : DXILOpMapping<13, unary, int_sin,
"Returns sine(theta) for theta in radians.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Exp2 : DXILOpMapping<21, unary, int_exp2,
"Returns the base 2 exponential, or 2**x, of the specified value. "
"exp2(x) = 2**x.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Log2 : DXILOpMapping<23, unary, int_log2,
"Returns the base-2 logarithm of the specified value.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Sqrt : DXILOpMapping<24, unary, int_sqrt,
"Returns the square root of the specified floating-point "
"value, per component.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
"Returns the reciprocal of the square root of the specified value. "
"rsqrt(x) = 1 / sqrt(x).",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Round : DXILOpMapping<26, unary, int_round,
"Returns the input rounded to the nearest integer "
"within a floating-point type.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def Floor : DXILOpMapping<27, unary, int_floor,
"Returns the largest integer that is less than or equal to the input.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
[llvm_halforfloat_ty], [LLVMMatchType<0>]>;
def FMax : DXILOpMapping<35, binary, int_maxnum,
"Float maximum. FMax(a,b) = a > b ? a : b">;
def FMin : DXILOpMapping<36, binary, int_minnum,
Expand All @@ -305,20 +310,28 @@ def UMax : DXILOpMapping<39, binary, int_umax,
def UMin : DXILOpMapping<40, binary, int_umin,
"Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
"Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
"Floating point arithmetic multiply/add operation. "
"fmad(m,a,b) = m * a + b.">;
def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
"Signed integer arithmetic multiply/add operation. "
"imad(m,a,b) = m * a + b.">;
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
"Unsigned integer arithmetic multiply/add operation. "
"umad(m,a,b) = m * a + b.">;
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
" + ... + a[n]*b[n] where n is between 0 and 1",
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)>;
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
" + ... + a[n]*b[n] where n is between 0 and 2",
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)>;
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
"dot product of two float vectors Dot(a,b) = a[0]*b[0]"
" + ... + a[n]*b[n] where n is between 0 and 3",
[llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)>;
def Barrier : DXILOpMapping<80, barrier, int_dx_barrier,
"Inserts a memory barrier in the shader">;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
Expand Down
55 changes: 20 additions & 35 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,13 @@ using namespace llvm::dxil;

constexpr StringLiteral DXILOpNamePrefix = "dx.op.";

namespace {

// File-local overload kinds. Each enumerator is a distinct single bit of the
// 16-bit underlying type.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace

static const char *getOverloadTypeName(OverloadKind Kind) {
switch (Kind) {
case OverloadKind::HALF:
case OverloadKind::Half:
return "f16";
case OverloadKind::FLOAT:
case OverloadKind::Float:
return "f32";
case OverloadKind::DOUBLE:
case OverloadKind::Double:
return "f64";
case OverloadKind::I1:
return "i1";
Expand All @@ -57,26 +39,29 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
return "i32";
case OverloadKind::I64:
return "i64";
case OverloadKind::VOID:
case OverloadKind::Void:
case OverloadKind::ObjectType:
case OverloadKind::UserDefineType:
break;
case OverloadKind::Invalid:
report_fatal_error("Invalid Overload Type for type name lookup",
/* gen_crash_diag=*/false);
}
llvm_unreachable("invalid overload type for name");
llvm_unreachable("Unhandled Overload Type specified for type name lookup");
return "void";
}

static OverloadKind getOverloadKind(Type *Ty) {
Type::TypeID T = Ty->getTypeID();
switch (T) {
case Type::VoidTyID:
return OverloadKind::VOID;
return OverloadKind::Void;
case Type::HalfTyID:
return OverloadKind::HALF;
return OverloadKind::Half;
case Type::FloatTyID:
return OverloadKind::FLOAT;
return OverloadKind::Float;
case Type::DoubleTyID:
return OverloadKind::DOUBLE;
return OverloadKind::Double;
case Type::IntegerTyID: {
IntegerType *ITy = cast<IntegerType>(Ty);
unsigned Bits = ITy->getBitWidth();
Expand All @@ -93,7 +78,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
return OverloadKind::I64;
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
return OverloadKind::Void;
}
}
case Type::PointerTyID:
Expand All @@ -102,7 +87,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
return OverloadKind::ObjectType;
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
return OverloadKind::Void;
}
}

Expand Down Expand Up @@ -147,7 +132,7 @@ struct OpCodeProperty {

static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
const OpCodeProperty &Prop) {
if (Kind == OverloadKind::VOID) {
if (Kind == OverloadKind::Void) {
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
}
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
Expand All @@ -157,7 +142,7 @@ static std::string constructOverloadName(OverloadKind Kind, Type *Ty,

static std::string constructOverloadTypeName(OverloadKind Kind,
StringRef TypeName) {
if (Kind == OverloadKind::VOID)
if (Kind == OverloadKind::Void)
return TypeName.str();

assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
Expand Down Expand Up @@ -284,13 +269,13 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
if (Prop->OverloadParamIndex < 0) {
auto &Ctx = FT->getContext();
switch (Prop->OverloadTys) {
case OverloadKind::VOID:
case OverloadKind::Void:
return Type::getVoidTy(Ctx);
case OverloadKind::HALF:
case OverloadKind::Half:
return Type::getHalfTy(Ctx);
case OverloadKind::FLOAT:
case OverloadKind::Float:
return Type::getFloatTy(Ctx);
case OverloadKind::DOUBLE:
case OverloadKind::Double:
return Type::getDoubleTy(Ctx);
case OverloadKind::I1:
return Type::getInt1Ty(Ctx);
Expand Down
11 changes: 11 additions & 0 deletions llvm/test/CodeGen/DirectX/barrier.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s

; The argument of llvm.dx.barrier is expected to be a mask of
; DXIL::BarrierMode values. An arbitrary integer value (9) is used for testing.

define void @test_barrier() #0 {
entry:
; The intrinsic call below is expected to be rewritten into a call to
; dx.op.barrier.i32 whose first operand is 80 (the opcode given to the Barrier
; op mapping in DXIL.td) and whose second operand is the original argument.
; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9)
call void @llvm.dx.barrier(i32 noundef 9)
ret void
}
Loading