-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[HLSL][SPIRV][DXIL] Implement WaveActiveSum
intrinsic
#112400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1824,6 +1824,23 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, | |
return false; | ||
} | ||
|
||
/// Rejects a builtin call argument whose type is \p Scalar, or a vector whose
/// element type is \p Scalar (qualifiers are ignored in both comparisons).
///
/// \param S        Sema instance used for type comparison and diagnostics.
/// \param TheCall  The builtin call whose argument is inspected.
/// \param Scalar   The scalar type that is not allowed for the argument.
/// \param ArgIndex Zero-based index of the argument to inspect.
/// \returns true after emitting
///          err_typecheck_expect_scalar_or_vector_not_type when the argument
///          is rejected; false when the argument type is acceptable.
static bool CheckNotScalarType(Sema *S, CallExpr *TheCall, QualType Scalar,
                               unsigned ArgIndex) {
  // ArgIndex is zero-based, so it has to be strictly below the argument
  // count; `>=` would let getArg() read one past the last argument.
  assert(ArgIndex < TheCall->getNumArgs());
  Expr *Arg = TheCall->getArg(ArgIndex);
  QualType ArgType = Arg->getType();
  const auto *VTy = ArgType->getAs<VectorType>();
  // Matches either the scalar itself or vector<scalar>.
  if (S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
      (VTy &&
       S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))) {
    // Report at the argument that was actually checked rather than always at
    // the first argument.
    S->Diag(Arg->getBeginLoc(),
            diag::err_typecheck_expect_scalar_or_vector_not_type)
        << ArgType << Scalar;
    return true;
  }
  return false;
}
Comment on lines
+1827
to
+1842
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function seems like it's being made generic in a way that isn't particularly useful. Do we have cases where we reject a single specific scalar type, and where that type is something other than bool? Also, I find the error message kind of confusing: "invalid operand of type 'bool' where 'bool' or a vector of such type is not allowed".
This is redundant, isn't it? What information is this providing that "invalid operand of type 'bool'" wouldn't? |
||
|
||
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { | ||
assert(TheCall->getNumArgs() == 3); | ||
Expr *Arg1 = TheCall->getArg(1); | ||
|
@@ -2059,6 +2076,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { | |
TheCall->setType(ArgTyA); | ||
break; | ||
} | ||
case Builtin::BI__builtin_hlsl_wave_active_sum: { | ||
if (SemaRef.checkArgCount(TheCall, 1)) | ||
return true; | ||
|
||
// Ensure input expr type is a scalar/vector and the same as the return type | ||
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) | ||
return true; | ||
if (CheckNotScalarType(&SemaRef, TheCall, getASTContext().BoolTy, 0)) | ||
return true; | ||
ExprResult Expr = TheCall->getArg(0); | ||
QualType ArgTyExpr = Expr.get()->getType(); | ||
TheCall->setType(ArgTyExpr); | ||
break; | ||
} | ||
// Note these are llvm builtins that we want to catch invalid intrinsic | ||
// generation. Normal handling of these builtins will occur elsewhere. | ||
case Builtin::BI__builtin_elementwise_bitreverse: { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ | ||
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL | ||
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ | ||
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV | ||
|
||
// Test basic lowering to runtime function call. | ||
|
||
// CHECK-LABEL: test_int | ||
// Scalar int case: WaveActiveSum(int) must become a call to the i32 overload
// of the target's wave.active.sum intrinsic (llvm.spv.* for SPIR-V, llvm.dx.*
// for DXIL), and the value returned must be the result of that call.
int test_int(int expr) { | ||
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.active.sum.i32([[TY]] %[[#]]) | ||
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.sum.i32([[TY]] %[[#]]) | ||
// CHECK: ret [[TY]] %[[RET]] | ||
return WaveActiveSum(expr); | ||
} | ||
|
||
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.sum.i32([[TY]]) #[[#attr:]] | ||
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.active.sum.i32([[TY]]) #[[#attr:]] | ||
|
||
// CHECK-LABEL: test_uint64_t | ||
// Unsigned 64-bit case: the DXIL directive requires the separate
// llvm.dx.wave.active.usum.i64 intrinsic, while the SPIR-V directive still
// requires llvm.spv.wave.active.sum.i64. The call result must be returned.
uint64_t test_uint64_t(uint64_t expr) { | ||
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.active.sum.i64([[TY]] %[[#]]) | ||
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.usum.i64([[TY]] %[[#]]) | ||
// CHECK: ret [[TY]] %[[RET]] | ||
return WaveActiveSum(expr); | ||
} | ||
|
||
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.usum.i64([[TY]]) #[[#attr:]] | ||
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.active.sum.i64([[TY]]) #[[#attr:]] | ||
|
||
// Test basic lowering to runtime function call with array and float value. | ||
|
||
// CHECK-LABEL: test_floatv4 | ||
// Vector float case: WaveActiveSum(float4) must become a call to the v4f32
// overload of the target's wave.active.sum intrinsic, and the value returned
// must be the result of that call.
float4 test_floatv4(float4 expr) {
  // CHECK-SPIRV: %[[RET1:.*]] = call spir_func [[TY1:.*]] @llvm.spv.wave.active.sum.v4f32([[TY1]] %[[#]])
  // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.active.sum.v4f32([[TY1]] %[[#]])
  // CHECK: ret [[TY1]] %[[RET1]]
  return WaveActiveSum(expr);
}
|
||
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.active.sum.v4f32([[TY1]]) #[[#attr]] | ||
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.active.sum.v4f32([[TY1]]) #[[#attr]] | ||
|
||
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify | ||
|
||
// Calls the builtin with no arguments; the verify directive on the line after
// the call pins the "too few arguments" diagnostic.
int test_too_few_arg() { | ||
return __builtin_hlsl_wave_active_sum(); | ||
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}} | ||
} | ||
|
||
// Calls the builtin with two arguments; the verify directive on the line
// after the call pins the "too many arguments" diagnostic.
float2 test_too_many_arg(float2 p0) { | ||
return __builtin_hlsl_wave_active_sum(p0, p0); | ||
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}} | ||
} | ||
|
||
// Passes a scalar bool operand; the verify directive on the line after the
// call pins the diagnostic that rejects bool operands.
bool test_expr_bool_type_check(bool p0) { | ||
return __builtin_hlsl_wave_active_sum(p0); | ||
// expected-error@-1 {{invalid operand of type 'bool' where 'bool' or a vector of such type is not allowed}} | ||
} | ||
|
||
// Passes a bool2 operand; the verify directive on the line after the call
// pins the diagnostic that rejects vectors of bool as well.
bool2 test_expr_bool_vec_type_check(bool2 p0) { | ||
return __builtin_hlsl_wave_active_sum(p0); | ||
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>') where 'bool' or a vector of such type is not allowed}} | ||
} | ||
|
||
// Aggregate type used below as an operand that is neither a scalar nor a
// vector.
struct S { float f; }; | ||
|
||
// Passes a struct operand; the verify directive on the line after the call
// pins the "scalar or vector is required" diagnostic.
S test_expr_struct_type_check(S p0) { | ||
return __builtin_hlsl_wave_active_sum(p0); | ||
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,8 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV | |
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; | ||
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; | ||
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>; | ||
def int_dx_wave_active_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could further be split into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a difference between float and signed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No difference in how they are lowered to a DXIL op. They will both be lowered to a |
||
def int_dx_wave_active_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; | ||
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; | ||
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; | ||
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use the
__attribute__((convergent))
spelling to be consistent with the rest of the file.