Skip to content

[HLSL][clang] Add elementwise builtins for trig intrinsics #95999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/docs/LanguageExtensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,12 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in
T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types
T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types
T __builtin_elementwise_tan(T x) return the tangent of x interpreted as an angle in radians floating point types
T __builtin_elementwise_asin(T x) return the arcsine of x interpreted as an angle in radians floating point types
T __builtin_elementwise_acos(T x) return the arccosine of x interpreted as an angle in radians floating point types
T __builtin_elementwise_atan(T x) return the arctangent of x interpreted as an angle in radians floating point types
T __builtin_elementwise_sinh(T x) return the hyperbolic sine of angle x in radians floating point types
T __builtin_elementwise_cosh(T x) return the hyperbolic cosine of angle x in radians floating point types
T __builtin_elementwise_tanh(T x) return the hyperbolic tangent of angle x in radians floating point types
T __builtin_elementwise_floor(T x) return the largest integral value less than or equal to x floating point types
T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types
T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types
Expand Down
36 changes: 36 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,24 @@ def ElementwiseAbs : Builtin {
let Prototype = "void(...)";
}

def ElementwiseACos : Builtin {
let Spellings = ["__builtin_elementwise_acos"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseASin : Builtin {
let Spellings = ["__builtin_elementwise_asin"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseATan : Builtin {
let Spellings = ["__builtin_elementwise_atan"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseBitreverse : Builtin {
let Spellings = ["__builtin_elementwise_bitreverse"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down Expand Up @@ -1248,6 +1266,12 @@ def ElementwiseCos : Builtin {
let Prototype = "void(...)";
}

def ElementwiseCosh : Builtin {
let Spellings = ["__builtin_elementwise_cosh"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseExp : Builtin {
let Spellings = ["__builtin_elementwise_exp"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down Expand Up @@ -1320,6 +1344,12 @@ def ElementwiseSin : Builtin {
let Prototype = "void(...)";
}

def ElementwiseSinh : Builtin {
let Spellings = ["__builtin_elementwise_sinh"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseSqrt : Builtin {
let Spellings = ["__builtin_elementwise_sqrt"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand All @@ -1332,6 +1362,12 @@ def ElementwiseTan : Builtin {
let Prototype = "void(...)";
}

def ElementwiseTanh : Builtin {
let Spellings = ["__builtin_elementwise_tanh"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

def ElementwiseTrunc : Builtin {
let Spellings = ["__builtin_elementwise_trunc"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
19 changes: 18 additions & 1 deletion clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3690,7 +3690,15 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

return RValue::get(Result);
}

case Builtin::BI__builtin_elementwise_acos:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::acos, "elt.acos"));
case Builtin::BI__builtin_elementwise_asin:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::asin, "elt.asin"));
case Builtin::BI__builtin_elementwise_atan:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::atan, "elt.atan"));
case Builtin::BI__builtin_elementwise_ceil:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::ceil, "elt.ceil"));
Expand Down Expand Up @@ -3719,6 +3727,9 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_cos:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::cos, "elt.cos"));
case Builtin::BI__builtin_elementwise_cosh:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::cosh, "elt.cosh"));
case Builtin::BI__builtin_elementwise_floor:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::floor, "elt.floor"));
Expand All @@ -3737,9 +3748,15 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_sin:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::sin, "elt.sin"));
case Builtin::BI__builtin_elementwise_sinh:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::sinh, "elt.sinh"));
case Builtin::BI__builtin_elementwise_tan:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::tan, "elt.tan"));
case Builtin::BI__builtin_elementwise_tanh:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::tanh, "elt.tanh"));
case Builtin::BI__builtin_elementwise_trunc:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::trunc, "elt.trunc"));
Expand Down
173 changes: 173 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ double3 abs(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_abs)
double4 abs(double4);

//===----------------------------------------------------------------------===//
// acos builtins
//===----------------------------------------------------------------------===//

/// \fn T acos(T Val)
/// \brief Returns the arccosine of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
half acos(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
half2 acos(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
half3 acos(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
half4 acos(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
float acos(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
float2 acos(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
float3 acos(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_acos)
float4 acos(float4);

//===----------------------------------------------------------------------===//
// all builtins
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -331,6 +359,62 @@ bool any(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_any)
bool any(double4);

//===----------------------------------------------------------------------===//
// asin builtins
//===----------------------------------------------------------------------===//

/// \fn T asin(T Val)
/// \brief Returns the arcsine of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
half asin(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
half2 asin(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
half3 asin(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
half4 asin(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
float asin(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
float2 asin(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
float3 asin(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
float4 asin(float4);

//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//

/// \fn T atan(T Val)
/// \brief Returns the arctangent of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
half atan(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
half2 atan(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
half3 atan(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
half4 atan(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
float atan(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
float2 atan(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
float3 atan(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_atan)
float4 atan(float4);

//===----------------------------------------------------------------------===//
// ceil builtins
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -502,6 +586,34 @@ float3 cos(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
float4 cos(float4);

//===----------------------------------------------------------------------===//
// cosh builtins
//===----------------------------------------------------------------------===//

/// \fn T cosh(T Val)
/// \brief Returns the hyperbolic cosine of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
half cosh(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
half2 cosh(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
half3 cosh(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
half4 cosh(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
float cosh(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
float2 cosh(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
float3 cosh(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh)
float4 cosh(float4);

//===----------------------------------------------------------------------===//
// dot product builtins
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1418,6 +1530,34 @@ float3 sin(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sin)
float4 sin(float4);

//===----------------------------------------------------------------------===//
// sinh builtins
//===----------------------------------------------------------------------===//

/// \fn T sinh(T Val)
/// \brief Returns the hyperbolic sine of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
half sinh(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
half2 sinh(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
half3 sinh(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
half4 sinh(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
float sinh(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
float2 sinh(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
float3 sinh(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sinh)
float4 sinh(float4);

//===----------------------------------------------------------------------===//
// sqrt builtins
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1451,6 +1591,11 @@ float4 sqrt(float4);
//===----------------------------------------------------------------------===//
// tan builtins
//===----------------------------------------------------------------------===//

/// \fn T tan(T Val)
/// \brief Returns the tangent of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tan)
half tan(half);
Expand All @@ -1471,6 +1616,34 @@ float3 tan(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tan)
float4 tan(float4);

//===----------------------------------------------------------------------===//
// tanh builtins
//===----------------------------------------------------------------------===//

/// \fn T tanh(T Val)
/// \brief Returns the hyperbolic tangent of the input value, \a Val.
/// \param Val The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
half tanh(half);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
half2 tanh(half2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
half3 tanh(half3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
half4 tanh(half4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
float tanh(float);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
float2 tanh(float2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
float3 tanh(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_tanh)
float4 tanh(float4);

//===----------------------------------------------------------------------===//
// trunc builtins
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3160,8 +3160,12 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,

// These builtins restrict the element type to floating point
// types only.
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
case Builtin::BI__builtin_elementwise_ceil:
case Builtin::BI__builtin_elementwise_cos:
case Builtin::BI__builtin_elementwise_cosh:
case Builtin::BI__builtin_elementwise_exp:
case Builtin::BI__builtin_elementwise_exp2:
case Builtin::BI__builtin_elementwise_floor:
Expand All @@ -3173,8 +3177,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_rint:
case Builtin::BI__builtin_elementwise_nearbyint:
case Builtin::BI__builtin_elementwise_sin:
case Builtin::BI__builtin_elementwise_sinh:
case Builtin::BI__builtin_elementwise_sqrt:
case Builtin::BI__builtin_elementwise_tan:
case Builtin::BI__builtin_elementwise_tanh:
case Builtin::BI__builtin_elementwise_trunc:
case Builtin::BI__builtin_elementwise_canonicalize: {
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
Expand Down Expand Up @@ -3635,8 +3641,12 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
case Builtin::BI__builtin_elementwise_ceil:
case Builtin::BI__builtin_elementwise_cos:
case Builtin::BI__builtin_elementwise_cosh:
case Builtin::BI__builtin_elementwise_exp:
case Builtin::BI__builtin_elementwise_exp2:
case Builtin::BI__builtin_elementwise_floor:
Expand All @@ -3646,8 +3656,10 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
case Builtin::BI__builtin_elementwise_pow:
case Builtin::BI__builtin_elementwise_roundeven:
case Builtin::BI__builtin_elementwise_sin:
case Builtin::BI__builtin_elementwise_sinh:
case Builtin::BI__builtin_elementwise_sqrt:
case Builtin::BI__builtin_elementwise_tan:
case Builtin::BI__builtin_elementwise_tanh:
case Builtin::BI__builtin_elementwise_trunc: {
if (CheckFloatOrHalfRepresentations(this, TheCall))
return true;
Expand Down
Loading
Loading