Skip to content

[HLSL] Implement support for HLSL intrinsic - select #107129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 9, 2024
Merged
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

// HLSL `select` intrinsic. The "void(...)" prototype leaves the signature open
// here; the argument count and operand types are validated in
// SemaHLSL::CheckBuiltinFunctionCall, which also sets the call's result type.
def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_select"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -9206,6 +9206,9 @@ def err_typecheck_expect_scalar_operand : Error<
"operand of type %0 where arithmetic or pointer type is required">;
def err_typecheck_cond_incompatible_operands : Error<
"incompatible operand types%diff{ ($ and $)|}0,1">;
// %0 is the type of the offending operand; %1 is the scalar type that the
// operand (or the elements of a vector operand) is required to have.
def err_typecheck_expect_scalar_or_vector : Error<
"invalid operand of type %0 where %1 or "
"a vector of such type is required">;
def err_typecheck_expect_flt_or_vector : Error<
"invalid operand of type %0 where floating, complex or "
"a vector of such types is required">;
Expand Down
40 changes: 37 additions & 3 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6241,8 +6241,20 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
}

// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
return RValue::get(V);
// Wrap the value produced by the HLSL builtin emitter in an RValue that
// matches how the call expression is being evaluated.
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E, ReturnValue)) {
switch (EvalKind) {
case TEK_Scalar:
// A void-typed result carries no value to hand back.
if (V->getType()->isVoidTy())
return RValue::get(nullptr);
return RValue::get(V);
case TEK_Aggregate:
// The aggregate result is taken from the return slot, not from V.
// Presumably the builtin emitter has already written it there (the select
// builtin stores into ReturnValue for non-scalar operands) - confirm for
// any new aggregate-returning builtin.
return RValue::getAggregate(ReturnValue.getAddress(),
ReturnValue.isVolatile());
case TEK_Complex:
llvm_unreachable("No current hlsl builtin returns complex");
}
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
}

if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
return EmitHipStdParUnsupportedBuiltin(this, FD);
Expand Down Expand Up @@ -18508,7 +18520,8 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
const CallExpr *E,
ReturnValueSlot ReturnValue) {
if (!getLangOpts().HLSL)
return nullptr;

Expand Down Expand Up @@ -18695,6 +18708,27 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.saturate");
}
// select(cond, trueVal, falseVal): lowered to a single IR select named
// "hlsl.select".
case Builtin::BI__builtin_hlsl_select: {
Value *OpCond = EmitScalarExpr(E->getArg(0));
// The value operands may be scalar or aggregate. For an aggregate, the
// operand used below is the pointer returned by getAggregatePointer, so the
// select chooses between two pointers rather than two values.
RValue RValTrue = EmitAnyExpr(E->getArg(1));
Value *OpTrue =
RValTrue.isScalar()
? RValTrue.getScalarVal()
: RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
RValue RValFalse = EmitAnyExpr(E->getArg(2));
Value *OpFalse =
RValFalse.isScalar()
? RValFalse.getScalarVal()
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);

Value *SelectVal =
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
// Non-scalar operands: write the select result into the caller's return slot.
// NOTE(review): in this path SelectVal is the selected *pointer*, so this
// stores a pointer into the slot instead of copying the aggregate's contents,
// and it assumes ReturnValue carries a usable address - confirm intended.
if (!RValTrue.isScalar())
Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
ReturnValue.isVolatile());

return SelectVal;
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4700,7 +4700,8 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
ReturnValueSlot ReturnValue);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
const CallExpr *E);
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,32 @@ double3 saturate(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
double4 saturate(double4);

//===----------------------------------------------------------------------===//
// select builtins
//===----------------------------------------------------------------------===//

/// \fn T select(bool Cond, T TrueVal, T FalseVal)
/// \brief Ternary selection: yields one of two values depending on a
/// condition.
/// \param Cond The condition input value.
/// \param TrueVal The value returned if Cond is true.
/// \param FalseVal The value returned if Cond is false.

template <typename T>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
T select(bool, T, T);

/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
/// vector<T,Sz> FalseVals)
/// \brief Ternary selection for vectors. All three vectors must have the same
/// number of elements.
/// \param Conds The condition input values.
/// \param TrueVals The vector from which values are chosen where the
/// corresponding condition is true.
/// \param FalseVals The vector from which values are chosen where the
/// corresponding condition is false.

template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);

//===----------------------------------------------------------------------===//
// sin builtins
//===----------------------------------------------------------------------===//
Expand Down
87 changes: 87 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,79 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
TheCall->setType(ReturnType);
}

/// Checks that argument \p ArgIndex of \p TheCall either has type \p Scalar or
/// is a vector whose element type is \p Scalar (qualifiers are ignored).
///
/// \param S        Sema instance used for type comparison and diagnostics.
/// \param TheCall  The builtin call whose argument is inspected.
/// \param Scalar   The required scalar (or vector element) type.
/// \param ArgIndex Zero-based index of the argument to check.
/// \returns true after emitting err_typecheck_expect_scalar_or_vector when the
/// argument has any other type; false when the argument is acceptable.
static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
                                unsigned ArgIndex) {
  // ArgIndex is zero-based, so it must be strictly below the argument count;
  // ArgIndex == getNumArgs() would index past the last argument.
  assert(ArgIndex < TheCall->getNumArgs());
  Expr *Arg = TheCall->getArg(ArgIndex);
  QualType ArgType = Arg->getType();
  const auto *VTy = ArgType->getAs<VectorType>();

  const bool IsScalar = S->Context.hasSameUnqualifiedType(ArgType, Scalar);
  const bool IsVectorOfScalar =
      VTy && S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar);
  if (IsScalar || IsVectorOfScalar)
    return false;

  // Neither the scalar nor vector<scalar>: point the diagnostic at the
  // argument that was actually checked, not unconditionally at argument 0.
  S->Diag(Arg->getBeginLoc(), diag::err_typecheck_expect_scalar_or_vector)
      << ArgType << Scalar;
  return true;
}

/// Validates a select call whose condition is a scalar bool: the second and
/// third arguments must have the same unqualified type. On success the call's
/// result type is set to the type of the second argument.
///
/// \returns true if a diagnostic was emitted, false if the call is valid.
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *TrueExpr = TheCall->getArg(1);
  Expr *FalseExpr = TheCall->getArg(2);
  QualType TrueTy = TrueExpr->getType();
  QualType FalseTy = FalseExpr->getType();

  // Matching value operands: the call yields that common type.
  if (S->Context.hasSameUnqualifiedType(TrueTy, FalseTy)) {
    TheCall->setType(TrueTy);
    return false;
  }

  S->Diag(TheCall->getBeginLoc(), diag::err_typecheck_call_different_arg_types)
      << TrueTy << FalseTy << TrueExpr->getSourceRange()
      << FalseExpr->getSourceRange();
  return true;
}

/// Validates a select call whose condition is a vector of bool. The second and
/// third arguments must both be vectors of the same unqualified type, and
/// their element count must equal that of the condition vector. On success the
/// call's result type is set to the type of the second argument.
///
/// The caller is expected to have verified that argument 0 is a vector.
///
/// \returns true if a diagnostic was emitted, false if the call is valid.
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
  assert(TheCall->getNumArgs() == 3);
  Expr *CondExpr = TheCall->getArg(0);
  Expr *TrueExpr = TheCall->getArg(1);
  Expr *FalseExpr = TheCall->getArg(2);
  QualType TrueTy = TrueExpr->getType();
  QualType FalseTy = FalseExpr->getType();

  // Both value operands have to be vectors; report the first that is not.
  if (!TrueTy->isVectorType()) {
    S->Diag(TrueExpr->getBeginLoc(), diag::err_builtin_non_vector_type)
        << "Second" << TheCall->getDirectCallee() << TrueTy
        << TrueExpr->getSourceRange();
    return true;
  }
  if (!FalseTy->isVectorType()) {
    S->Diag(FalseExpr->getBeginLoc(), diag::err_builtin_non_vector_type)
        << "Third" << TheCall->getDirectCallee() << FalseTy
        << FalseExpr->getSourceRange();
    return true;
  }

  // The two value vectors must agree in type.
  if (!S->Context.hasSameUnqualifiedType(TrueTy, FalseTy)) {
    S->Diag(TheCall->getBeginLoc(),
            diag::err_typecheck_call_different_arg_types)
        << TrueTy << FalseTy << TrueExpr->getSourceRange()
        << FalseExpr->getSourceRange();
    return true;
  }

  // The value vectors already share a type, so comparing the condition
  // against the second argument covers the length of all three.
  QualType CondTy = CondExpr->getType();
  const unsigned CondLen = CondTy->getAs<VectorType>()->getNumElements();
  const unsigned ValueLen = TrueTy->getAs<VectorType>()->getNumElements();
  if (CondLen != ValueLen) {
    S->Diag(TheCall->getBeginLoc(),
            diag::err_typecheck_vector_lengths_not_equal)
        << CondTy << TrueTy << CondExpr->getSourceRange()
        << TrueExpr->getSourceRange();
    return true;
  }

  TheCall->setType(TrueTy);
  return false;
}

// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
Expand Down Expand Up @@ -1544,6 +1617,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_select: {
  // select takes exactly three arguments, and the condition (argument 0) must
  // be bool or a vector of bool.
  if (SemaRef.checkArgCount(TheCall, 3) ||
      CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
    return true;

  QualType CondTy = TheCall->getArg(0)->getType();
  // Scalar bool condition.
  if (CondTy->isBooleanType()) {
    if (CheckBoolSelect(&SemaRef, TheCall))
      return true;
  }
  // Vector-of-bool condition.
  const auto *CondVecTy = CondTy->getAs<VectorType>();
  const bool HasBoolVectorCond =
      CondVecTy && CondVecTy->getElementType()->isBooleanType();
  if (HasBoolVectorCond && CheckVectorSelect(&SemaRef, TheCall))
    return true;
  break;
}
case Builtin::BI__builtin_hlsl_elementwise_saturate:
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
Expand Down
54 changes: 54 additions & 0 deletions clang/test/CodeGenHLSL/builtins/select.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK

// Verifies that each form of the HLSL select() intrinsic is emitted as an
// LLVM IR select instruction: scalar condition with scalar, struct and vector
// operands, and bool-vector conditions of length 1 through 4.

// Scalar condition, scalar operands, explicit template argument.
// CHECK-LABEL: test_select_bool_int
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, i32 {{%.*}}, i32 {{%.*}}
// CHECK: ret i32 [[SELECT]]
int test_select_bool_int(bool cond0, int tVal, int fVal) {
return select<int>(cond0, tVal, fVal);
}

// Scalar condition, struct operands, template argument deduced.
// NOTE(review): the expected IR selects between two pointers and then stores
// the selected pointer, rather than copying the struct contents - confirm that
// this is the intended lowering for aggregates.
struct S { int a; };
// CHECK-LABEL: test_select_infer
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, ptr {{%.*}}, ptr {{%.*}}
// CHECK: store ptr [[SELECT]]
// CHECK: ret void
struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
return select(cond0, tVal, fVal);
}

// Scalar condition choosing between two whole vectors.
// CHECK-LABEL: test_select_bool_vector
// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
// CHECK: ret <2 x i32> [[SELECT]]
int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) {
return select<int2>(cond0, tVal, fVal);
}

// Vector conditions: one case per vector length.
// CHECK-LABEL: test_select_vector_1
// CHECK: [[SELECT:%.*]] = select <1 x i1> {{%.*}}, <1 x i32> {{%.*}}, <1 x i32> {{%.*}}
// CHECK: ret <1 x i32> [[SELECT]]
int1 test_select_vector_1(bool1 cond0, int1 tVals, int1 fVals) {
return select<int,1>(cond0, tVals, fVals);
}

// CHECK-LABEL: test_select_vector_2
// CHECK: [[SELECT:%.*]] = select <2 x i1> {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
// CHECK: ret <2 x i32> [[SELECT]]
int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) {
return select<int,2>(cond0, tVals, fVals);
}

// CHECK-LABEL: test_select_vector_3
// CHECK: [[SELECT:%.*]] = select <3 x i1> {{%.*}}, <3 x i32> {{%.*}}, <3 x i32> {{%.*}}
// CHECK: ret <3 x i32> [[SELECT]]
int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
return select<int,3>(cond0, tVals, fVals);
}

// Vector condition with template arguments deduced.
// CHECK-LABEL: test_select_vector_4
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> {{%.*}}
// CHECK: ret <4 x i32> [[SELECT]]
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
return select(cond0, tVals, fVals);
}
Loading
Loading