Skip to content

[DXIL] exp, any, lerp, & rcp Intrinsic Lowering #84526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isFloatingType() const; // C99 6.2.5p11 (real floating + complex)
bool isHalfType() const; // OpenCL 6.1.1.1, NEON (IEEE 754-2008 half)
bool isFloat16Type() const; // C11 extension ISO/IEC TS 18661
bool isFloat32Type() const;
bool isBFloat16Type() const;
bool isFloat128Type() const;
bool isIbm128Type() const;
Expand Down Expand Up @@ -7452,6 +7453,10 @@ inline bool Type::isFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::Float16);
}

inline bool Type::isFloat32Type() const {
return isSpecificBuiltinType(BuiltinType::Float);
}

inline bool Type::isBFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::BFloat16);
}
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4598,7 +4598,7 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {

def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

Expand Down
35 changes: 4 additions & 31 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18021,38 +18021,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Value *S = EmitScalarExpr(E->getArg(2));
llvm::Type *Xty = X->getType();
llvm::Type *Yty = Y->getType();
llvm::Type *Sty = S->getType();
if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
if (Xty->isFloatingPointTy()) {
auto V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
}
llvm_unreachable("Scalar Lerp is only supported on floats.");
}
// A VectorSplat should have happened
assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
"Lerp of vector and scalar is not supported.");

[[maybe_unused]] auto *XVecTy =
E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *YVecTy =
E->getArg(1)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *SVecTy =
E->getArg(2)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
XVecTy->getNumElements() == SVecTy->getNumElements() &&
"Lerp requires vectors to be of the same size.");
assert(XVecTy->getElementType()->isRealFloatingType() &&
XVecTy->getElementType() == YVecTy->getElementType() &&
XVecTy->getElementType() == SVecTy->getElementType() &&
"Lerp requires float vectors to be of the same type.");
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
llvm_unreachable("lerp operand must have a float representation");
return Builder.CreateIntrinsic(
/*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
nullptr, "dx.lerp");
/*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
}
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Expand Down
51 changes: 37 additions & 14 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5234,10 +5234,6 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
TheCall->getArg(1)->getEndLoc());
retValue = true;
}

if (!retValue)
TheCall->setType(VecTyA->getElementType());

return retValue;
}
}
Expand All @@ -5251,11 +5247,12 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}

bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
QualType ExpectedType = S->Context.FloatTy;
bool CheckArgsTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
QualType PassedType = TheCall->getArg(i)->getType();
if (!PassedType->hasFloatingRepresentation()) {
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
Expand All @@ -5268,6 +5265,26 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
return false;
}

bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
}

bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->getAs<clang::VectorType>()->getElementType()
: PassedType;
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
};
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
}

void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
QualType ReturnType) {
auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
Expand Down Expand Up @@ -5295,21 +5312,27 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
if (checkArgCount(*this, TheCall, 1))
return true;
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
return true;
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
case Builtin::BI__builtin_hlsl_elementwise_rcp:
case Builtin::BI__builtin_hlsl_elementwise_frac: {
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_isinf: {
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
break;
}
case Builtin::BI__builtin_hlsl_lerp: {
Expand All @@ -5319,7 +5342,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (SemaBuiltinElementwiseTernaryMath(TheCall))
return true;
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
break;
}
Expand Down
22 changes: 0 additions & 22 deletions clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s



// CHECK-LABEL: builtin_lerp_half_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to half
// CHECK: ret half %conv3
half builtin_lerp_half_scalar (half p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// CHECK-LABEL: builtin_lerp_float_scalar
// CHECK: %3 = fsub double %conv1, %conv
// CHECK: %4 = fmul double %conv2, %3
// CHECK: %dx.lerp = fadd double %conv, %4
// CHECK: %conv3 = fptrunc double %dx.lerp to float
// CHECK: ret float %conv3
float builtin_lerp_float_scalar ( float p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
}

// CHECK-LABEL: builtin_lerp_half_vector
// CHECK: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// CHECK: ret <3 x half> %dx.lerp
Expand Down
27 changes: 11 additions & 16 deletions clang/test/CodeGenHLSL/builtins/lerp.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,46 @@
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF

// NATIVE_HALF: %3 = fsub half %1, %0
// NATIVE_HALF: %4 = fmul half %2, %3
// NATIVE_HALF: %dx.lerp = fadd half %0, %4

// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
// NATIVE_HALF: ret half %dx.lerp
// NO_HALF: %3 = fsub float %1, %0
// NO_HALF: %4 = fmul float %2, %3
// NO_HALF: %dx.lerp = fadd float %0, %4
// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
// NO_HALF: ret float %dx.lerp
half test_lerp_half(half p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
// NATIVE_HALF: ret <2 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// NO_HALF: ret <2 x float> %dx.lerp
half2 test_lerp_half2(half2 p0, half2 p1) { return lerp(p0, p0, p0); }
half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
// NATIVE_HALF: ret <3 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// NO_HALF: ret <3 x float> %dx.lerp
half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }

// NATIVE_HALF: %dx.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
// NATIVE_HALF: ret <4 x half> %dx.lerp
// NO_HALF: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// NO_HALF: ret <4 x float> %dx.lerp
half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }

// CHECK: %3 = fsub float %1, %0
// CHECK: %4 = fmul float %2, %3
// CHECK: %dx.lerp = fadd float %0, %4
// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
// CHECK: ret float %dx.lerp
float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
float test_lerp_float(float p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
float2 test_lerp_float2(float2 p0, float2 p1) { return lerp(p0, p0, p0); }
float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
// CHECK: ret <3 x float> %dx.lerp
float3 test_lerp_float3(float3 p0, float3 p1) { return lerp(p0, p0, p0); }
float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
// CHECK: ret <4 x float> %dx.lerp
float4 test_lerp_float4(float4 p0, float4 p1) { return lerp(p0, p0, p0); }
float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }

// CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
// CHECK: ret <2 x float> %dx.lerp
Expand Down
12 changes: 12 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ float2 builtin_frac_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_frac(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_frac_half_scalar (half p0) {
return __builtin_hlsl_elementwise_frac (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_frac_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_frac (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

11 changes: 11 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ bool2 builtin_isinf_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_isinf(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_isinf_half_scalar (half p0) {
return __builtin_hlsl_elementwise_isinf (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_isinf_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_isinf (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
17 changes: 15 additions & 2 deletions clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,18 @@ float builtin_lerp_int_to_float_promotion(float p0, int p1) {

float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {
return __builtin_hlsl_lerp(p0, p1, p2);
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
}
// expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
}

// note: DefaultVariadicArgumentPromotion --> DefaultArgumentPromotion has already promoted to double
// we don't know anymore that the input was half when __builtin_hlsl_lerp is called so we default to float
// for expected type
half builtin_lerp_half_scalar (half p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_lerp_float_scalar ( float p0) {
return __builtin_hlsl_lerp ( p0, p0, p0 );
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
11 changes: 11 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_rsqrt(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}

// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
half builtin_rsqrt_half_scalar (half p0) {
return __builtin_hlsl_elementwise_rsqrt (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}

float builtin_rsqrt_float_scalar ( float p0) {
return __builtin_hlsl_elementwise_rsqrt (p0);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
}
4 changes: 1 addition & 3 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def int_dx_isinf :
DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
[llvm_anyfloat_ty]>;

def int_dx_lerp :
Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;

def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
DirectXSubtarget.cpp
DirectXTargetMachine.cpp
DXContainerGlobals.cpp
DXILIntrinsicExpansion.cpp
DXILMetadata.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
Expand Down
Loading