Skip to content

[HLSL][SPIRV][DXIL] Implement WaveActiveSum intrinsic #112400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4750,6 +4750,12 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Prototype = "unsigned int(bool)";
}

// Builtin behind the HLSL WaveActiveSum overloads. The prototype is variadic,
// so the builtin signature itself does not constrain the argument count or
// operand type.
def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_sum"];
let Attributes = [NoThrow, Const];
let Prototype = "void (...)";
}

def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_get_lane_index"];
let Attributes = [NoThrow, Const];
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -9236,6 +9236,9 @@ def err_typecheck_expect_scalar_or_vector : Error<
"a vector of such type is required">;
def err_typecheck_expect_any_scalar_or_vector : Error<
"invalid operand of type %0 where a scalar or vector is required">;
// %0 is the type of the offending operand; %1 is the scalar type that is
// disallowed both on its own and as a vector element type.
def err_typecheck_expect_scalar_or_vector_not_type : Error<
"invalid operand of type %0 where %1 or "
"a vector of such type is not allowed">;
def err_typecheck_expect_flt_or_vector : Error<
"invalid operand of type %0 where floating, complex or "
"a vector of such types is required">;
Expand Down
34 changes: 34 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18715,6 +18715,23 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getUDotIntrinsic();
}

// Return the wave active sum intrinsic for the given target architecture.
//
// QT is the source-level type of the operand and may be either a scalar or a
// vector. On DXIL it selects between the unsigned-integer variant and the
// variant shared by signed-integer and floating-point operands; SPIR-V uses a
// single intrinsic for every operand type. RT is currently unused.
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
                                               CGHLSLRuntime &RT, QualType QT) {
  switch (Arch) {
  case llvm::Triple::spirv:
    return llvm::Intrinsic::spv_wave_active_sum;
  case llvm::Triple::dxil: {
    // hasUnsignedIntegerRepresentation() looks through vector types, so a
    // vector of unsigned integers selects the unsigned variant just like its
    // scalar element type does. isUnsignedIntegerType() would return false for
    // every vector type and silently pick the signed variant.
    if (QT->hasUnsignedIntegerRepresentation())
      return llvm::Intrinsic::dx_wave_active_usum;
    return llvm::Intrinsic::dx_wave_active_sum;
  }
  default:
    llvm_unreachable("Intrinsic WaveActiveSum"
                     " not supported by target architecture");
  }
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
Expand Down Expand Up @@ -18960,6 +18977,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_sum: {
// Due to the use of variadic arguments, explicitly retrieve the argument.
Value *OpExpr = EmitScalarExpr(E->getArg(0));
// The callee takes and returns the LLVM type of the operand.
llvm::FunctionType *FT = llvm::FunctionType::get(
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
// The intrinsic is chosen from the target architecture and the source-level
// type of the operand.
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());

// Get the overloaded intrinsic name for the operand's LLVM type.
std::string Name =
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
// Call through a runtime function declaration created with
// AssumeConvergent set.
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
/*Local=*/false,
/*AssumeConvergent=*/true),
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
Expand Down
99 changes: 99 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2217,6 +2217,105 @@ __attribute__((convergent)) double3 WaveReadLaneAt(double3, int32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double4 WaveReadLaneAt(double4, int32_t);

//===----------------------------------------------------------------------===//
// WaveActiveSum builtins
//===----------------------------------------------------------------------===//

// Scalar half overload, aliased to the builtin and declared convergent.
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) half WaveActiveSum(half);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the __attribute__((convergent)) spelling to be consistent with the rest of the file.

// half vector overloads, aliased to the builtin and declared convergent.
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) half2 WaveActiveSum(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) half3 WaveActiveSum(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) half4 WaveActiveSum(half4);

// 16-bit integer overloads; only declared when __HLSL_ENABLE_16_BIT is defined.
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int16_t WaveActiveSum(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int16_t2 WaveActiveSum(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int16_t3 WaveActiveSum(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int16_t4 WaveActiveSum(int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint16_t WaveActiveSum(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint16_t2 WaveActiveSum(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint16_t3 WaveActiveSum(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint16_t4 WaveActiveSum(uint16_t4);
#endif

// int and uint overloads (scalar and 2-4 element vectors), aliased to the
// builtin and declared convergent.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int WaveActiveSum(int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int2 WaveActiveSum(int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int3 WaveActiveSum(int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int4 WaveActiveSum(int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint WaveActiveSum(uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint2 WaveActiveSum(uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint3 WaveActiveSum(uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint4 WaveActiveSum(uint4);

// int64_t and uint64_t overloads (scalar and 2-4 element vectors), aliased to
// the builtin and declared convergent.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int64_t WaveActiveSum(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int64_t2 WaveActiveSum(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int64_t3 WaveActiveSum(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) int64_t4 WaveActiveSum(int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint64_t WaveActiveSum(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint64_t2 WaveActiveSum(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint64_t3 WaveActiveSum(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) uint64_t4 WaveActiveSum(uint64_t4);

// float and double overloads (scalar and 2-4 element vectors), aliased to the
// builtin and declared convergent.
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) float WaveActiveSum(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) float2 WaveActiveSum(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) float3 WaveActiveSum(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) float4 WaveActiveSum(float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double WaveActiveSum(double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double2 WaveActiveSum(double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double3 WaveActiveSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double4 WaveActiveSum(double4);

//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,23 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
return false;
}

// Diagnoses the argument at ArgIndex when its type is Scalar, or a vector whose
// element type is Scalar (qualifiers are ignored in both comparisons).
//
// Returns true if a diagnostic was emitted and false otherwise.
static bool CheckNotScalarType(Sema *S, CallExpr *TheCall, QualType Scalar,
                               unsigned ArgIndex) {
  // getArg(ArgIndex) is only valid for indices strictly below the argument
  // count.
  assert(TheCall->getNumArgs() > ArgIndex);
  Expr *Arg = TheCall->getArg(ArgIndex);
  QualType ArgType = Arg->getType();
  auto *VTy = ArgType->getAs<VectorType>();
  // is the scalar or vector<scalar>
  if (S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
      (VTy &&
       S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))) {
    // Report at the argument that was actually checked, not always the first.
    S->Diag(Arg->getBeginLoc(),
            diag::err_typecheck_expect_scalar_or_vector_not_type)
        << ArgType << Scalar;
    return true;
  }
  return false;
}
Comment on lines +1827 to +1842
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function seems like it's being made generic in a way that isn't particularly useful. Do we have cases where we reject a single specific scalar type, and where that type is something other than bool?

Also, I find the error message kind of confusing:

invalid operand of type 'bool' where 'bool' or a vector of such type is not allowed

This is redundant, isn't it? What information is this providing that "invalid operand of type 'bool'" wouldn't?


static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
Expand Down Expand Up @@ -2059,6 +2076,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyA);
break;
}
case Builtin::BI__builtin_hlsl_wave_active_sum: {
// Require exactly one argument.
if (SemaRef.checkArgCount(TheCall, 1))
return true;

// Ensure the input expression type is a scalar or a vector.
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
return true;
// Reject bool and vectors of bool.
if (CheckNotScalarType(&SemaRef, TheCall, getASTContext().BoolTy, 0))
return true;
// The call's result type is the type of its single argument.
ExprResult Expr = TheCall->getArg(0);
QualType ArgTyExpr = Expr.get()->getType();
TheCall->setType(ArgTyExpr);
break;
}
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builtins will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
45 changes: 45 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// int operand: DXIL is expected to call the dx.wave.active.sum intrinsic and
// SPIR-V the spv.wave.active.sum intrinsic, both overloaded on i32.
// CHECK-LABEL: test_int
int test_int(int expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.active.sum.i32([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.sum.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveSum(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.sum.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.active.sum.i32([[TY]]) #[[#attr:]]

// uint64_t operand: DXIL is expected to call the unsigned dx.wave.active.usum
// intrinsic, while SPIR-V keeps using spv.wave.active.sum.
// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.active.sum.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.usum.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveSum(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.usum.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.active.sum.i64([[TY]]) #[[#attr:]]

// Test basic lowering to runtime function call with a float vector value.

// float4 operand: both targets are expected to call their sum intrinsic
// overloaded on v4f32.
// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr) {
  // CHECK-SPIRV: %[[RET1:.*]] = call spir_func [[TY1:.*]] @llvm.spv.wave.active.sum.v4f32([[TY1]] %[[#]])
  // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.active.sum.v4f32([[TY1]] %[[#]])
  // CHECK: ret [[TY1]] %[[RET1]]
  return WaveActiveSum(expr);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.active.sum.v4f32([[TY1]]) #[[#attr]]
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.active.sum.v4f32([[TY1]]) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
28 changes: 28 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveSum-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

// Calling the builtin with no argument must be diagnosed.
int test_too_few_arg() {
return __builtin_hlsl_wave_active_sum();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

// Calling the builtin with more than one argument must be diagnosed.
float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_active_sum(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

// A scalar bool operand must be diagnosed.
bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_active_sum(p0);
// expected-error@-1 {{invalid operand of type 'bool' where 'bool' or a vector of such type is not allowed}}
}

// A vector of bool must be diagnosed as well.
bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_active_sum(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>') where 'bool' or a vector of such type is not allowed}}
}

// A struct is neither a scalar nor a vector and must be diagnosed.
struct S { float f; };

S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_sum(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
// Wave active sum overloaded on any type; the result type matches the operand type.
def int_dx_wave_active_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could further be split into int_dx_wave_active_fsum and int_dx_wave_active_ssum to denote the difference in float/signed integer. They were merged into one as the separation is not required for lowering to WaveActiveOp. Would appreciate if reviewers weighed in. Could increase readability/documentation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a difference between float and signed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No difference in how they are lowered to a DXIL op. They will both be lowered to a WaveActiveOp with the same flags.

// Unsigned wave active sum; overloads are restricted to integer types.
def int_dx_wave_active_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
// Wave active sum overloaded on any type; the result type matches the operand type.
def int_spv_wave_active_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
30 changes: 30 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,18 @@ defset list<DXILConstant> BarrierModes = {
def BarrierMode_AllMemoryBarrierWithGroupSync : DXILConstant<11>;
}

// Operation selector constants, passed as an i8 argument to WaveActiveOp.
defset list<DXILConstant> WaveOpKind = {
def WaveOpKind_Sum : DXILConstant<0>;
def WaveOpKind_Product : DXILConstant<1>;
def WaveOpKind_Min : DXILConstant<2>;
def WaveOpKind_Max : DXILConstant<3>;
}

// Signedness selector constants, passed as an i8 argument to WaveActiveOp.
defset list<DXILConstant> SignedOpKind = {
def SignedOpKind_Signed : DXILConstant<0>;
def SignedOpKind_Unsigned : DXILConstant<1>;
}

// Intrinsic arg selection
class Arg {
int index = -1;
Expand Down Expand Up @@ -842,6 +854,24 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
let stages = [Stages<DXIL1_6, [all_stages]>];
}

// DXIL op 119. Both dx wave active sum intrinsics select this op with
// WaveOpKind_Sum; they differ only in the signedness constant passed as the
// third argument.
def WaveActiveOp : DXILOp<119, waveActiveOp> {
let Doc = "returns the result of the operation across waves";
let intrinsic_selects = [
IntrinsicSelect<
int_dx_wave_active_sum,
[ ArgSelect<0>, ArgI8<WaveOpKind_Sum>, ArgI8<SignedOpKind_Signed> ]>,
IntrinsicSelect<
int_dx_wave_active_usum,
[ ArgSelect<0>, ArgI8<WaveOpKind_Sum>, ArgI8<SignedOpKind_Unsigned> ]>,
];

// Operands: the value, then the i8 operation kind, then the i8 signedness.
let arguments = [OverloadTy, Int8Ty, Int8Ty];
let result = OverloadTy;
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let Doc = "returns 1 for the first lane in the wave";
let LLVMIntrinsic = int_dx_wave_is_first_lane;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
switch (ID) {
case Intrinsic::dx_frac:
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_wave_active_sum:
case Intrinsic::dx_wave_active_usum:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_splitdouble:
return true;
Expand Down
Loading
Loading