diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index bf49ec6f6c649..0d24790b4ad58 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -107,4 +107,6 @@ def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0> def int_dx_discard : DefaultAttrsIntrinsic<[], [llvm_i1_ty], []>; def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>; def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>; + +def int_dx_group_memory_barrier_with_group_sync : DefaultAttrsIntrinsic<[], [], []>; } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 7cc08b2fe7cc4..cff6cdce813de 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -294,6 +294,60 @@ class Attributes attrs> { list op_attrs = attrs; } +defvar BarrierMode_DeviceMemoryBarrier = 2; +defvar BarrierMode_DeviceMemoryBarrierWithGroupSync = 3; +defvar BarrierMode_GroupMemoryBarrier = 8; +defvar BarrierMode_GroupMemoryBarrierWithGroupSync = 9; +defvar BarrierMode_AllMemoryBarrier = 10; +defvar BarrierMode_AllMemoryBarrierWithGroupSync = 11; + +// Intrinsic arg selection +class IntrinArgSelectType; +def IntrinArgSelect_Index : IntrinArgSelectType; +def IntrinArgSelect_I8 : IntrinArgSelectType; +def IntrinArgSelect_I32 : IntrinArgSelectType; + +class IntrinArgSelect { + IntrinArgSelectType type = type_; + int value = value_; +} +class IntrinArgIndex : IntrinArgSelect; +class IntrinArgI8 : IntrinArgSelect; +class IntrinArgI32 : IntrinArgSelect; + +// Select which intrinsic to lower from for a DXILOp. +// If the intrinsic is the only argument given to IntrinSelect, then the +// arguments of the intrinsic will be copied in the same order. 
Example: +// let intrinsics = [ +// IntrinSelect, +// IntrinSelect, +// ] +//========================================================================================= +// Using IntrinArgIndex<>, arguments of the intrinsic can be copied to the DXIL +// OP in specific order: +// let intrinsics = [ +// IntrinSelect, IntrinArgIndex<1>, IntrinArgIndex<0>> ] +// >, +// ] +//========================================================================================= +// Using IntrinArgI8<> and IntrinArgI32<>, integer constants can be added +// directly to the dxil op. This can be used in conjunction with +// IntrinArgIndex: +// let intrinsics = [ +// IntrinSelect, IntrinArgI8<0>, IntrinArgI8<1> ] +// >, +// IntrinSelect, IntrinArgI8<0>, IntrinArgI8<0> ] +// >, +// ] +// +class IntrinSelect arg_selects_=[]> { + Intrinsic intrinsic = intrinsic_; + list arg_selects = arg_selects_; +} + // Abstraction DXIL Operation class DXILOp { // A short description of the operation @@ -305,8 +359,8 @@ class DXILOp { // Class of DXIL Operation. 
DXILOpClass OpClass = opclass; - // LLVM Intrinsic DXIL Operation maps to - Intrinsic LLVMIntrinsic = ?; + // LLVM Intrinsics DXIL Operation maps from + list intrinsics = []; // Result type of the op DXILOpParamType result; @@ -328,7 +382,7 @@ class DXILOp { def Abs : DXILOp<6, unary> { let Doc = "Returns the absolute value of the input."; - let LLVMIntrinsic = int_fabs; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -338,7 +392,7 @@ def Abs : DXILOp<6, unary> { def Saturate : DXILOp<7, unary> { let Doc = "Clamps a single or double precision floating point value to [0.0f...1.0f]."; - let LLVMIntrinsic = int_dx_saturate; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -348,7 +402,7 @@ def Saturate : DXILOp<7, unary> { def IsInf : DXILOp<9, isSpecialFloat> { let Doc = "Determines if the specified value is infinite."; - let LLVMIntrinsic = int_dx_isinf; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = Int1Ty; let overloads = [Overloads]; @@ -358,7 +412,7 @@ def IsInf : DXILOp<9, isSpecialFloat> { def Cos : DXILOp<12, unary> { let Doc = "Returns cosine(theta) for theta in radians."; - let LLVMIntrinsic = int_cos; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -368,7 +422,7 @@ def Cos : DXILOp<12, unary> { def Sin : DXILOp<13, unary> { let Doc = "Returns sine(theta) for theta in radians."; - let LLVMIntrinsic = int_sin; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -378,7 +432,7 @@ def Sin : DXILOp<13, unary> { def Tan : DXILOp<14, unary> { let Doc = "Returns tangent(theta) for theta in radians."; - let LLVMIntrinsic = int_tan; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = 
[Overloads]; @@ -388,7 +442,7 @@ def Tan : DXILOp<14, unary> { def ACos : DXILOp<15, unary> { let Doc = "Returns the arccosine of the specified value."; - let LLVMIntrinsic = int_acos; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -398,7 +452,7 @@ def ACos : DXILOp<15, unary> { def ASin : DXILOp<16, unary> { let Doc = "Returns the arcsine of the specified value."; - let LLVMIntrinsic = int_asin; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -408,7 +462,7 @@ def ASin : DXILOp<16, unary> { def ATan : DXILOp<17, unary> { let Doc = "Returns the arctangent of the specified value."; - let LLVMIntrinsic = int_atan; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -418,7 +472,7 @@ def ATan : DXILOp<17, unary> { def HCos : DXILOp<18, unary> { let Doc = "Returns the hyperbolic cosine of the specified value."; - let LLVMIntrinsic = int_cosh; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -428,7 +482,7 @@ def HCos : DXILOp<18, unary> { def HSin : DXILOp<19, unary> { let Doc = "Returns the hyperbolic sine of the specified value."; - let LLVMIntrinsic = int_sinh; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -438,7 +492,7 @@ def HSin : DXILOp<19, unary> { def HTan : DXILOp<20, unary> { let Doc = "Returns the hyperbolic tan of the specified value."; - let LLVMIntrinsic = int_tanh; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -449,7 +503,7 @@ def HTan : DXILOp<20, unary> { def Exp2 : DXILOp<21, unary> { let Doc = "Returns the base 2 exponential, or 2**x, of the specified value. 
" "exp2(x) = 2**x."; - let LLVMIntrinsic = int_exp2; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -460,7 +514,7 @@ def Exp2 : DXILOp<21, unary> { def Frac : DXILOp<22, unary> { let Doc = "Returns a fraction from 0 to 1 that represents the decimal part " "of the input."; - let LLVMIntrinsic = int_dx_frac; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -470,7 +524,7 @@ def Frac : DXILOp<22, unary> { def Log2 : DXILOp<23, unary> { let Doc = "Returns the base-2 logarithm of the specified value."; - let LLVMIntrinsic = int_log2; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -481,7 +535,7 @@ def Log2 : DXILOp<23, unary> { def Sqrt : DXILOp<24, unary> { let Doc = "Returns the square root of the specified floating-point value, " "per component."; - let LLVMIntrinsic = int_sqrt; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -492,7 +546,7 @@ def Sqrt : DXILOp<24, unary> { def RSqrt : DXILOp<25, unary> { let Doc = "Returns the reciprocal of the square root of the specified value. 
" "rsqrt(x) = 1 / sqrt(x)."; - let LLVMIntrinsic = int_dx_rsqrt; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -503,7 +557,7 @@ def RSqrt : DXILOp<25, unary> { def Round : DXILOp<26, unary> { let Doc = "Returns the input rounded to the nearest integer within a " "floating-point type."; - let LLVMIntrinsic = int_roundeven; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -514,7 +568,7 @@ def Round : DXILOp<26, unary> { def Floor : DXILOp<27, unary> { let Doc = "Returns the largest integer that is less than or equal to the input."; - let LLVMIntrinsic = int_floor; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -525,7 +579,7 @@ def Floor : DXILOp<27, unary> { def Ceil : DXILOp<28, unary> { let Doc = "Returns the smallest integer that is greater than or equal to the " "input."; - let LLVMIntrinsic = int_ceil; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -535,7 +589,7 @@ def Ceil : DXILOp<28, unary> { def Trunc : DXILOp<29, unary> { let Doc = "Returns the specified value truncated to the integer component."; - let LLVMIntrinsic = int_trunc; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -545,7 +599,7 @@ def Trunc : DXILOp<29, unary> { def Rbits : DXILOp<30, unary> { let Doc = "Returns the specified value with its bits reversed."; - let LLVMIntrinsic = int_bitreverse; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = @@ -567,7 +621,7 @@ def CountBits : DXILOp<31, unaryBits> { def FirstbitHi : DXILOp<33, unaryBits> { let Doc = "Returns the location of the first set bit starting from " "the highest order bit and working downward."; - let 
LLVMIntrinsic = int_dx_firstbituhigh; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = Int32Ty; let overloads = @@ -579,7 +633,7 @@ def FirstbitHi : DXILOp<33, unaryBits> { def FirstbitSHi : DXILOp<34, unaryBits> { let Doc = "Returns the location of the first set bit from " "the highest order bit based on the sign."; - let LLVMIntrinsic = int_dx_firstbitshigh; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = Int32Ty; let overloads = @@ -590,7 +644,7 @@ def FirstbitSHi : DXILOp<34, unaryBits> { def FMax : DXILOp<35, binary> { let Doc = "Float maximum. FMax(a,b) = a > b ? a : b"; - let LLVMIntrinsic = int_maxnum; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -601,7 +655,7 @@ def FMax : DXILOp<35, binary> { def FMin : DXILOp<36, binary> { let Doc = "Float minimum. FMin(a,b) = a < b ? a : b"; - let LLVMIntrinsic = int_minnum; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -612,7 +666,7 @@ def FMin : DXILOp<36, binary> { def SMax : DXILOp<37, binary> { let Doc = "Signed integer maximum. SMax(a,b) = a > b ? a : b"; - let LLVMIntrinsic = int_smax; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -623,7 +677,7 @@ def SMax : DXILOp<37, binary> { def SMin : DXILOp<38, binary> { let Doc = "Signed integer minimum. SMin(a,b) = a < b ? a : b"; - let LLVMIntrinsic = int_smin; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -634,7 +688,7 @@ def SMin : DXILOp<38, binary> { def UMax : DXILOp<39, binary> { let Doc = "Unsigned integer maximum. UMax(a,b) = a > b ? 
a : b"; - let LLVMIntrinsic = int_umax; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -645,7 +699,7 @@ def UMax : DXILOp<39, binary> { def UMin : DXILOp<40, binary> { let Doc = "Unsigned integer minimum. UMin(a,b) = a < b ? a : b"; - let LLVMIntrinsic = int_umin; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -657,7 +711,7 @@ def UMin : DXILOp<40, binary> { def FMad : DXILOp<46, tertiary> { let Doc = "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m " "* a + b."; - let LLVMIntrinsic = int_fmuladd; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -669,7 +723,7 @@ def FMad : DXILOp<46, tertiary> { def IMad : DXILOp<48, tertiary> { let Doc = "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m " "* a + b."; - let LLVMIntrinsic = int_dx_imad; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -681,7 +735,7 @@ def IMad : DXILOp<48, tertiary> { def UMad : DXILOp<49, tertiary> { let Doc = "Unsigned integer arithmetic multiply/add operation. umad(m,a, = m " "* a + b."; - let LLVMIntrinsic = int_dx_umad; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = @@ -693,7 +747,7 @@ def UMad : DXILOp<49, tertiary> { def Dot2 : DXILOp<54, dot2> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... 
+ " "a[n]*b[n] where n is 0 to 1 inclusive"; - let LLVMIntrinsic = int_dx_dot2; + let intrinsics = [ IntrinSelect ]; let arguments = !listsplat(OverloadTy, 4); let result = OverloadTy; let overloads = [Overloads]; @@ -704,7 +758,7 @@ def Dot2 : DXILOp<54, dot2> { def Dot3 : DXILOp<55, dot3> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " "a[n]*b[n] where n is 0 to 2 inclusive"; - let LLVMIntrinsic = int_dx_dot3; + let intrinsics = [ IntrinSelect ]; let arguments = !listsplat(OverloadTy, 6); let result = OverloadTy; let overloads = [Overloads]; @@ -715,7 +769,7 @@ def Dot3 : DXILOp<55, dot3> { def Dot4 : DXILOp<56, dot4> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " "a[n]*b[n] where n is 0 to 3 inclusive"; - let LLVMIntrinsic = int_dx_dot4; + let intrinsics = [ IntrinSelect ]; let arguments = !listsplat(OverloadTy, 8); let result = OverloadTy; let overloads = [Overloads]; @@ -772,7 +826,7 @@ def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> { def Discard : DXILOp<82, discard> { let Doc = "discard the current pixel"; - let LLVMIntrinsic = int_dx_discard; + let intrinsics = [ IntrinSelect ]; let arguments = [Int1Ty]; let result = VoidTy; let stages = [Stages]; @@ -780,7 +834,7 @@ def Discard : DXILOp<82, discard> { def ThreadId : DXILOp<93, threadId> { let Doc = "Reads the thread ID"; - let LLVMIntrinsic = int_dx_thread_id; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -790,7 +844,7 @@ def ThreadId : DXILOp<93, threadId> { def GroupId : DXILOp<94, groupId> { let Doc = "Reads the group ID (SV_GroupID)"; - let LLVMIntrinsic = int_dx_group_id; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -800,7 +854,7 @@ def GroupId : DXILOp<94, groupId> { def ThreadIdInGroup : DXILOp<95, threadIdInGroup> { let Doc = "Reads the thread ID within the group 
(SV_GroupThreadID)"; - let LLVMIntrinsic = int_dx_thread_id_in_group; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy]; let result = OverloadTy; let overloads = [Overloads]; @@ -811,7 +865,7 @@ def ThreadIdInGroup : DXILOp<95, threadIdInGroup> { def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { let Doc = "Provides a flattened index for a given thread within a given " "group (SV_GroupIndex)"; - let LLVMIntrinsic = int_dx_flattened_thread_id_in_group; + let intrinsics = [ IntrinSelect ]; let result = OverloadTy; let overloads = [Overloads]; let stages = [Stages]; @@ -820,7 +874,7 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { def MakeDouble : DXILOp<101, makeDouble> { let Doc = "creates a double value"; - let LLVMIntrinsic = int_dx_asdouble; + let intrinsics = [ IntrinSelect ]; let arguments = [Int32Ty, Int32Ty]; let result = DoubleTy; let stages = [Stages]; @@ -839,7 +893,7 @@ def SplitDouble : DXILOp<102, splitDouble> { def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { let Doc = "signed dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; - let LLVMIntrinsic = int_dx_dot4add_i8packed; + let intrinsics = [ IntrinSelect ]; let arguments = [Int32Ty, Int32Ty, Int32Ty]; let result = Int32Ty; let attributes = [Attributes]; @@ -849,7 +903,7 @@ def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> { let Doc = "unsigned dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; - let LLVMIntrinsic = int_dx_dot4add_u8packed; + let intrinsics = [ IntrinSelect ]; let arguments = [Int32Ty, Int32Ty, Int32Ty]; let result = Int32Ty; let attributes = [Attributes]; @@ -872,7 +926,7 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> { def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> { let Doc = "returns true if the expression is true in any of the active lanes in the current wave"; - let LLVMIntrinsic = int_dx_wave_any; 
+ let intrinsics = [ IntrinSelect ]; let arguments = [Int1Ty]; let result = Int1Ty; let stages = [Stages]; @@ -880,7 +934,7 @@ def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> { def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> { let Doc = "returns 1 for the first lane in the wave"; - let LLVMIntrinsic = int_dx_wave_is_first_lane; + let intrinsics = [ IntrinSelect ]; let arguments = []; let result = Int1Ty; let stages = [Stages]; @@ -889,7 +943,7 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> { def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> { let Doc = "returns the value from the specified lane"; - let LLVMIntrinsic = int_dx_wave_readlane; + let intrinsics = [ IntrinSelect ]; let arguments = [OverloadTy, Int32Ty]; let result = OverloadTy; let overloads = [Overloads]; @@ -899,7 +953,7 @@ def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> { def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> { let Doc = "returns the index of the current lane in the wave"; - let LLVMIntrinsic = int_dx_wave_getlaneindex; + let intrinsics = [ IntrinSelect ]; let arguments = []; let result = Int32Ty; let stages = [Stages]; @@ -908,9 +962,23 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> { def WaveAllBitCount : DXILOp<135, waveAllOp> { let Doc = "returns the count of bits set to 1 across the wave"; - let LLVMIntrinsic = int_dx_wave_active_countbits; + let intrinsics = [ IntrinSelect ]; let arguments = [Int1Ty]; let result = Int32Ty; let stages = [Stages]; let attributes = [Attributes]; } + +def Barrier : DXILOp<80, barrier> { + let Doc = "inserts a memory barrier in the shader"; + let intrinsics = [ + IntrinSelect< + int_dx_group_memory_barrier_with_group_sync, + [ IntrinArgI32 ]>, + ]; + + let arguments = [Int32Ty]; + let result = VoidTy; + let stages = [Stages]; + let attributes = [Attributes]; +} diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index b5cc209493ed1..a0d46efd1763e 100644 --- 
a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -106,17 +106,50 @@ class OpLowerer {
     return false;
   }
 
-  [[nodiscard]]
-  bool replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
+  /// One argument of the DXIL op call being built: an operand copied from the
+  /// intrinsic call (Index) or an integer constant (I8 / I32).
+  struct IntrinArgSelect {
+    enum class Type {
+#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
+#include "DXILOperation.inc"
+    };
+    // Deliberately not named `Type`: a member with the enum's name changes the
+    // meaning of `Type` inside this class, which GCC rejects.
+    Type Kind;
+    int Value;
+  };
+
+  /// Lower \p F to \p DXILOp through replaceFunction(). With no \p ArgSelects
+  /// the call's arguments are forwarded (flattened first when \p F needs
+  /// vector arg expansion); otherwise the op arguments come from the selects.
+  [[nodiscard]] bool
+  replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
+                        ArrayRef ArgSelects) {
     bool IsVectorArgExpansion = isVectorArgExpansion(F);
+    assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
+           "Can't do vector arg expansion when using arg selects.");
     return replaceFunction(F, [&](CallInst *CI) -> Error {
-      SmallVector Args;
       OpBuilder.getIRB().SetInsertPoint(CI);
-      if (IsVectorArgExpansion) {
-        SmallVector NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
-        Args.append(NewArgs.begin(), NewArgs.end());
-      } else
+      SmallVector Args;
+      if (ArgSelects.size()) {
+        for (const IntrinArgSelect &A : ArgSelects) {
+          switch (A.Kind) {
+          case IntrinArgSelect::Type::Index:
+            Args.push_back(CI->getArgOperand(A.Value));
+            break;
+          case IntrinArgSelect::Type::I8:
+            Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
+            break;
+          case IntrinArgSelect::Type::I32:
+            Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
+            break;
+          }
+        }
+      } else if (IsVectorArgExpansion) {
+        Args = argVectorFlatten(CI, OpBuilder.getIRB());
+      } else {
         Args.append(CI->arg_begin(), CI->arg_end());
+      }
 
       Expected OpCall =
           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
@@ -641,9 +674,10 @@ class OpLowerer {
       switch (ID) {
       default:
         continue;
-#define DXIL_OP_INTRINSIC(OpCode, Intrin)                                      \
+#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...)                                 \
   case Intrin:                                                                 \
-    HasErrors |= replaceFunctionWithOp(F, OpCode);                             \
+    HasErrors |= replaceFunctionWithOp(                                        \
+        F, OpCode, ArrayRef{__VA_ARGS__});                                     \
     break;
 #include "DXILOperation.inc"
       case Intrinsic::dx_handle_fromBinding:
diff --git a/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll b/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll
new file mode 100644
index 0000000000000..baf93d4e177f0
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll
@@ -0,0 +1,8 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+define void @test_group_memory_barrier_with_group_sync() {
+entry:
+  ; CHECK: call void @dx.op.barrier(i32 80, i32 9)
+  call void @llvm.dx.group.memory.barrier.with.group.sync()
+  ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index e74fc00015b40..a0c93bed5ad83 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -32,6 +32,20 @@ using namespace llvm::dxil;
 
 namespace {
 
+struct DXILIntrinsicSelect {
+  StringRef Intrinsic;
+  SmallVector ArgSelectRecords;
+};
+
+static StringRef StripIntrinArgSelectTypePrefix(StringRef Type) {
+  StringRef Prefix = "IntrinArgSelect_";
+  if (!Type.starts_with(Prefix)) {
+    PrintFatalError("IntrinArgSelectType definition must be prefixed with "
+                    "'IntrinArgSelect_'");
+  }
+  return Type.substr(Prefix.size());
+}
+
 struct DXILOperationDesc {
   std::string OpName; // name of DXIL operation
   int OpCode;         // ID of DXIL operation
@@ -42,8 +56,7 @@ struct DXILOperationDesc {
   SmallVector OverloadRecs;
   SmallVector StageRecs;
   SmallVector AttrRecs;
-  StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
-                       // means no map exists
+  SmallVector IntrinsicSelects;
   SmallVector
       ShaderStages; // shader stages to which this applies, empty for all.
int OverloadParamIndex; // Index of parameter with overload type. @@ -71,6 +84,21 @@ static void ascendingSortByVersion(std::vector &Recs) { }); } +/// Take a `int_{intrinsic_name}` and return just the intrinsic_name part if +/// available. Otherwise return the empty string. +static StringRef GetIntrinsicName(const RecordVal *RV) { + if (RV && RV->getValue()) { + if (const DefInit *DI = dyn_cast(RV->getValue())) { + auto *IntrinsicDef = DI->getDef(); + auto DefName = IntrinsicDef->getName(); + assert(DefName.starts_with("int_") && "invalid intrinsic name"); + // Remove the int_ from intrinsic name. + return DefName.substr(4); + } + } + return ""; +} + /// Construct an object using the DXIL Operation records specified /// in DXIL.td. This serves as the single source of reference of /// the information extracted from the specified Record R, for @@ -157,14 +185,16 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { OpName); } - const RecordVal *RV = R->getValue("LLVMIntrinsic"); - if (RV && RV->getValue()) { - if (const DefInit *DI = dyn_cast(RV->getValue())) { - auto *IntrinsicDef = DI->getDef(); - auto DefName = IntrinsicDef->getName(); - assert(DefName.starts_with("int_") && "invalid intrinsic name"); - // Remove the int_ from intrinsic name. 
-        Intrinsic = DefName.substr(4);
+  auto IntrinsicSelectRecords = R->getValueAsListOfDefs("intrinsics");
+  if (IntrinsicSelectRecords.size()) {
+    for (const Record *SelRec : IntrinsicSelectRecords) {
+      DXILIntrinsicSelect IntrSelect;
+      IntrSelect.Intrinsic = GetIntrinsicName(SelRec->getValue("intrinsic"));
+      auto Args = SelRec->getValueAsListOfDefs("arg_selects");
+      for (const Record *ArgSelect : Args) {
+        IntrSelect.ArgSelectRecords.emplace_back(ArgSelect);
+      }
+      IntrinsicSelects.emplace_back(std::move(IntrSelect));
     }
   }
 }
@@ -374,19 +404,49 @@ static void emitDXILOpFunctionTypes(ArrayRef Ops,
 /// \param Output stream
 static void emitDXILIntrinsicMap(ArrayRef Ops,
                                  raw_ostream &OS) {
+
   OS << "#ifdef DXIL_OP_INTRINSIC\n";
   OS << "\n";
   for (const auto &Op : Ops) {
-    if (Op.Intrinsic.empty())
+    if (Op.IntrinsicSelects.empty()) {
       continue;
-    OS << "DXIL_OP_INTRINSIC(dxil::OpCode::" << Op.OpName
-       << ", Intrinsic::" << Op.Intrinsic << ")\n";
+    }
+    for (const DXILIntrinsicSelect &MappedIntr : Op.IntrinsicSelects) {
+      // Every arg select is followed by ", ", so the list ends in a trailing
+      // comma; with no arg selects the variadic part of the macro is empty.
+      OS << "DXIL_OP_INTRINSIC(dxil::OpCode::" << Op.OpName
+         << ", Intrinsic::" << MappedIntr.Intrinsic << ", ";
+      for (const Record *ArgSelect : MappedIntr.ArgSelectRecords) {
+        std::string Type =
+            ArgSelect->getValueAsDef("type")->getNameInitAsString();
+        int Value = ArgSelect->getValueAsInt("value");
+        OS << "(IntrinArgSelect{"
+           << "IntrinArgSelect::Type::" << StripIntrinArgSelectTypePrefix(Type)
+           << "," << Value << "}), ";
+      }
+      OS << ")\n";
+    }
   }
   OS << "\n";
   OS << "#undef DXIL_OP_INTRINSIC\n";
   OS << "#endif\n\n";
 }
 
+/// Emit the IntrinArgSelect type for DirectX intrinsic to DXIL Op lowering
+/// \param Records TableGen records; every IntrinArgSelectType def is emitted
+/// \param OS Output stream
+static void emitDXILIntrinsicArgSelectTypes(const RecordKeeper &Records,
+                                            raw_ostream &OS) {
+  OS << "#ifdef DXIL_OP_INTRINSIC_ARG_SELECT_TYPE\n";
+  for (const Record *Rec :
+       Records.getAllDerivedDefinitions("IntrinArgSelectType")) {
+    StringRef StrippedName = StripIntrinArgSelectTypePrefix(Rec->getName());
+    OS << "DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(" << StrippedName << ")\n";
+  }
+  OS << "#undef DXIL_OP_INTRINSIC_ARG_SELECT_TYPE\n";
+  OS << "#endif\n\n";
+}
+
 /// Emit DXIL operation table
 /// \param A vector of DXIL Ops
 /// \param Output stream
@@ -527,6 +587,7 @@ static void emitDxilOperation(const RecordKeeper &Records, raw_ostream &OS) {
   emitDXILOpClasses(Records, OS);
   emitDXILOpParamTypes(Records, OS);
   emitDXILOpFunctionTypes(DXILOps, OS);
+  emitDXILIntrinsicArgSelectTypes(Records, OS);
   emitDXILIntrinsicMap(DXILOps, OS);
   OS << "#ifdef DXIL_OP_OPERATION_TABLE\n\n";
   emitDXILOperationTableDataStructs(Records, OS);