diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 8668b25661dec..7e89f84319877 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLSelect : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_select"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + // Builtins for XRay. def XRayCustomEvent : Builtin { let Spellings = ["__xray_customevent"]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index dcb49d8a67604..68c6993089dcb 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -9206,6 +9206,9 @@ def err_typecheck_expect_scalar_operand : Error< "operand of type %0 where arithmetic or pointer type is required">; def err_typecheck_cond_incompatible_operands : Error< "incompatible operand types%diff{ ($ and $)|}0,1">; +def err_typecheck_expect_scalar_or_vector : Error< + "invalid operand of type %0 where %1 or " + "a vector of such type is required">; def err_typecheck_expect_flt_or_vector : Error< "invalid operand of type %0 where floating, complex or " "a vector of such types is required">; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index e4d169d2ad603..79f09bfd0f13d 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -6241,8 +6241,20 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, } // EmitHLSLBuiltinExpr will check getLangOpts().HLSL - if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E)) - return RValue::get(V); + if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E, ReturnValue)) { + switch (EvalKind) { + case TEK_Scalar: + if (V->getType()->isVoidTy()) + return RValue::get(nullptr); + return RValue::get(V); + case TEK_Aggregate: + return RValue::getAggregate(ReturnValue.getAddress(), + ReturnValue.isVolatile()); + case TEK_Complex: + llvm_unreachable("No current hlsl builtin returns complex"); + } + llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr"); + } if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice) return EmitHipStdParUnsupportedBuiltin(this, FD); @@ -18508,7 +18520,8 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) { } Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, - const CallExpr *E) { + const CallExpr *E, + ReturnValueSlot ReturnValue) { if (!getLangOpts().HLSL) return nullptr; @@ -18695,6 +18708,27 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef{Op0}, nullptr, "hlsl.saturate"); } + case Builtin::BI__builtin_hlsl_select: { + Value *OpCond = EmitScalarExpr(E->getArg(0)); + RValue RValTrue = EmitAnyExpr(E->getArg(1)); + Value *OpTrue = + RValTrue.isScalar() + ? RValTrue.getScalarVal() + : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this); + RValue RValFalse = EmitAnyExpr(E->getArg(2)); + Value *OpFalse = + RValFalse.isScalar() + ? RValFalse.getScalarVal() + : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this); + + Value *SelectVal = + Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select"); + if (!RValTrue.isScalar()) + Builder.CreateStore(SelectVal, ReturnValue.getAddress(), + ReturnValue.isVolatile()); + + return SelectVal; + } case Builtin::BI__builtin_hlsl_wave_get_lane_index: { return EmitRuntimeCall(CGM.CreateRuntimeFunction( llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index", diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 368fc112187ff..35a5275e61211 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4700,7 +4700,8 @@ class CodeGenFunction : public CodeGenTypeCache { llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E); - llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E); + llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E, + ReturnValueSlot ReturnValue); llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx, const CallExpr *E); llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index 6d38b668fe770..f7e37511cbe4e 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -1603,6 +1603,32 @@ double3 saturate(double3); _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate) double4 saturate(double4); +//===----------------------------------------------------------------------===// +// select builtins +//===----------------------------------------------------------------------===// + +/// \fn T select(bool Cond, T TrueVal, T FalseVal) +/// \brief ternary operator. +/// \param Cond The Condition input value. +/// \param TrueVal The Value returned if Cond is true. +/// \param FalseVal The Value returned if Cond is false. + +template +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) +T select(bool, T, T); + +/// \fn vector select(vector Conds, vector TrueVals, +/// vector FalseVals) +/// \brief ternary operator for vectors. All vectors must be the same size. +/// \param Conds The Condition input values. +/// \param TrueVals The vector values are chosen from when conditions are true. +/// \param FalseVals The vector values are chosen from when conditions are +/// false. + +template +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select) +vector select(vector, vector, vector); + //===----------------------------------------------------------------------===// // sin builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index fabc6f32906b1..49482fdfd0b1e 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1512,6 +1512,79 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, TheCall->setType(ReturnType); } +static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, + unsigned ArgIndex) { + assert(TheCall->getNumArgs() >= ArgIndex); + QualType ArgType = TheCall->getArg(ArgIndex)->getType(); + auto *VTy = ArgType->getAs(); + // not the scalar or vector + if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) || + (VTy && + S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) { + S->Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_scalar_or_vector) + << ArgType << Scalar; + return true; + } + return false; +} + +static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() == 3); + Expr *Arg1 = TheCall->getArg(1); + Expr *Arg2 = TheCall->getArg(2); + if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() + << Arg2->getSourceRange(); + return true; + } + + TheCall->setType(Arg1->getType()); + return false; +} + +static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() == 3); + Expr *Arg1 = TheCall->getArg(1); + Expr *Arg2 = TheCall->getArg(2); + if (!Arg1->getType()->isVectorType()) { + S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type) + << "Second" << TheCall->getDirectCallee() << Arg1->getType() + << Arg1->getSourceRange(); + return true; + } + + if (!Arg2->getType()->isVectorType()) { + S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type) + << "Third" << TheCall->getDirectCallee() << Arg2->getType() + << Arg2->getSourceRange(); + return true; + } + + if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() + << Arg2->getSourceRange(); + return true; + } + + // caller has checked that Arg0 is a vector. + // check all three args have the same length. + if (TheCall->getArg(0)->getType()->getAs()->getNumElements() != + Arg1->getType()->getAs()->getNumElements()) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_vector_lengths_not_equal) + << TheCall->getArg(0)->getType() << Arg1->getType() + << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange(); + return true; + } + TheCall->setType(Arg1->getType()); + return false; +} + // Note: returning true in this case results in CheckBuiltinFunctionCall // returning an ExprError bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { @@ -1544,6 +1617,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_select: { + if (SemaRef.checkArgCount(TheCall, 3)) + return true; + if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0)) + return true; + QualType ArgTy = TheCall->getArg(0)->getType(); + if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall)) + return true; + auto *VTy = ArgTy->getAs(); + if (VTy && VTy->getElementType()->isBooleanType() && + CheckVectorSelect(&SemaRef, TheCall)) + return true; + break; + } case Builtin::BI__builtin_hlsl_elementwise_saturate: case Builtin::BI__builtin_hlsl_elementwise_rcp: { if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall)) diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl new file mode 100644 index 0000000000000..cade938b71a2b --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/select.hlsl @@ -0,0 +1,54 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s --check-prefixes=CHECK + +// CHECK-LABEL: test_select_bool_int +// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, i32 {{%.*}}, i32 {{%.*}} +// CHECK: ret i32 [[SELECT]] +int test_select_bool_int(bool cond0, int tVal, int fVal) { + return select(cond0, tVal, fVal); +} + +struct S { int a; }; +// CHECK-LABEL: test_select_infer +// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, ptr {{%.*}}, ptr {{%.*}} +// CHECK: store ptr [[SELECT]] +// CHECK: ret void +struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) { + return select(cond0, tVal, fVal); +} + +// CHECK-LABEL: test_select_bool_vector +// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}} +// CHECK: ret <2 x i32> [[SELECT]] +int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) { + return select(cond0, tVal, fVal); +} + +// CHECK-LABEL: test_select_vector_1 +// CHECK: [[SELECT:%.*]] = select <1 x i1> {{%.*}}, <1 x i32> {{%.*}}, <1 x i32> {{%.*}} +// CHECK: ret <1 x i32> [[SELECT]] +int1 test_select_vector_1(bool1 cond0, int1 tVals, int1 fVals) { + return select(cond0, tVals, fVals); +} + +// CHECK-LABEL: test_select_vector_2 +// CHECK: [[SELECT:%.*]] = select <2 x i1> {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}} +// CHECK: ret <2 x i32> [[SELECT]] +int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) { + return select(cond0, tVals, fVals); +} + +// CHECK-LABEL: test_select_vector_3 +// CHECK: [[SELECT:%.*]] = select <3 x i1> {{%.*}}, <3 x i32> {{%.*}}, <3 x i32> {{%.*}} +// CHECK: ret <3 x i32> [[SELECT]] +int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) { + return select(cond0, tVals, fVals); +} + +// CHECK-LABEL: test_select_vector_4 +// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> {{%.*}} +// CHECK: ret <4 x i32> [[SELECT]] +int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) { + return select(cond0, tVals, fVals); +} diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl new file mode 100644 index 0000000000000..34b5fb6d54cd5 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl @@ -0,0 +1,119 @@ +// RUN: %clang_cc1 -finclude-default-header +// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only +// -disable-llvm-passes -verify -verify-ignore-unexpected + +int test_no_arg() { + return select(); + // expected-error@-1 {{no matching function for call to 'select'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template + // not viable: requires 3 arguments, but 0 were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: requires 3 arguments, but 0 were provided}} +} + +int test_too_few_args(bool p0) { + return select(p0); + // expected-error@-1 {{no matching function for call to 'select'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: requires 3 arguments, but 1 was provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: requires 3 arguments, but 1 was provided}} +} + +int test_too_many_args(bool p0, int t0, int f0, int g0) { + return select(p0, t0, f0, g0); + // expected-error@-1 {{no matching function for call to 'select'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: requires 3 arguments, but 4 were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: requires 3 arguments, but 4 were provided}} +} + +int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) { + return select(p0, t0, f0); + // expected-error@-1 {{no matching function for call to 'select'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: no known conversion from 'vector' (vector of 1 'int' value) + // to 'bool' for 1st argument}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could + // not match 'vector' against 'int'}} +} + +int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) { + return select(p0, t0, f0); + // expected-warning@-1 {{implicit conversion truncates vector: + // 'vector' (vector of 2 'int' values) to 'vector' + // (vector of 1 'int' value)}} +} + +int2 test_select_vector_vals_not_vecs(bool2 p0, int t0, + int f0) { + return select(p0, t0, f0); + // expected-error@-1 {{no matching function for call to 'select'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: + // could not match 'vector' against 'int'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not + // viable: no known conversion from 'vector' + // (vector of 2 'bool' values) to 'bool' for 1st argument}} +} + +int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) { + return select(p0, t0, f0); // produce warnings + // expected-warning@-1 {{implicit conversion truncates vector: + // 'vector' (vector of 2 'bool' values) to 'vector' + // (vector of 1 'bool' value)}} + // expected-warning@-2 {{implicit conversion truncates vector: + // 'vector' (vector of 2 'int' values) to 'vector' + // (vector of 1 'int' value)}} +} + +// __builtin_hlsl_select tests +int test_select_builtin_wrong_arg_count(bool p0, int t0) { + return __builtin_hlsl_select(p0, t0); + // expected-error@-1 {{too few arguments to function call, expected 3, + // have 2}} +} + +// not a bool or a vector of bool. should be 2 errors. +int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{passing 'int' to parameter of incompatible type + // 'bool'}} + // expected-error@-2 {{First argument to __builtin_hlsl_select must be of + // vector type}} + } + +int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{passing 'vector' (vector of 1 'int' value) to + // parameter of incompatible type 'bool'}} + // expected-error@-2 {{First argument to __builtin_hlsl_select must be of + // vector type}} +} + +// if a bool last 2 args are of same type +int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{arguments are of different types ('int' vs 'double')}} +} + +// if a vector second arg isnt a vector +int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of + // vector type}} +} + +// if a vector third arg isn't a vector +int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of + // vector type}} +} + +// if vector last 2 aren't same type (so both are vectors but wrong type) +int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) { + return __builtin_hlsl_select(p0, t0, f0); + // expected-error@-1 {{arguments are of different types ('vector' + // vs 'vector')}} +}