diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h index 4874206d349c0..a2f8f658b292e 100644 --- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h +++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h @@ -158,6 +158,42 @@ namespace hlsl { return fn((float4)V1, (float4)V2, (float4)V3); \ } +#define _DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(fn) \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + vector V1, T V2) { \ + return fn(V1, (vector)V2); \ + } \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + T V1, vector V2) { \ + return fn((vector)V1, V2); \ + } + +#define _DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(fn) \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + T V1, vector V2, vector V3) { \ + return fn((vector)V1, V2, V3); \ + } \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + vector V1, T V2, vector V3) { \ + return fn(V1, (vector)V2, V3); \ + } \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + vector V1, vector V2, T V3) { \ + return fn(V1, V2, (vector)V3); \ + } + +#define _DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(fn) \ + template \ + constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> fn( \ + vector V1, T V2, T V3) { \ + return fn(V1, (vector)V2, (vector)V3); \ + } + //===----------------------------------------------------------------------===// // acos builtins overloads //===----------------------------------------------------------------------===// @@ -197,23 +233,8 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(ceil) // clamp builtins overloads //===----------------------------------------------------------------------===// -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -clamp(vector p0, vector p1, T p2) { - return clamp(p0, p1, (vector)p2); -} - -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -clamp(vector p0, T p1, vector p2) { - return clamp(p0, (vector)p1, p2); -} - -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -clamp(vector p0, T p1, T p2) { - return clamp(p0, (vector)p1, (vector)p2); -} +_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(clamp) +_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(clamp) //===----------------------------------------------------------------------===// // cos builtins overloads @@ -236,6 +257,22 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(cosh) _DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(degrees) _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(degrees) +//===----------------------------------------------------------------------===// +// dot builtins overloads +//===----------------------------------------------------------------------===// + +template +constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(vector V1, + T V2) { + return dot(V1, (vector)V2); +} + +template +constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(T V1, + vector V2) { + return dot((vector)V1, V2); +} + //===----------------------------------------------------------------------===// // exp builtins overloads //===----------------------------------------------------------------------===// @@ -277,14 +314,10 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); } // lerp builtins overloads //===----------------------------------------------------------------------===// -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -lerp(vector x, vector y, T s) { - return lerp(x, y, (vector)s); -} - _DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp) _DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp) +_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(lerp) +_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(lerp) //===----------------------------------------------------------------------===// // log builtins overloads @@ -311,33 +344,13 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(log2) // max builtins overloads //===----------------------------------------------------------------------===// -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -max(vector p0, T p1) { - return max(p0, (vector)p1); -} - -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -max(T p0, vector p1) { - return max((vector)p0, p1); -} +_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(max) //===----------------------------------------------------------------------===// // min builtins overloads //===----------------------------------------------------------------------===// -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -min(vector p0, T p1) { - return min(p0, (vector)p1); -} - -template -constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector> -min(T p0, vector p1) { - return min((vector)p0, p1); -} +_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(min) //===----------------------------------------------------------------------===// // normalize builtins overloads @@ -352,6 +365,7 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(normalize) _DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(pow) _DXC_COMPAT_BINARY_INTEGER_OVERLOADS(pow) +_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(pow) //===----------------------------------------------------------------------===// // rsqrt builtins overloads diff --git a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl index c0e1e914831aa..5bf23db7671ec 100644 --- a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl +++ b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl @@ -90,6 +90,12 @@ double4 test_clamp_double4_mismatch1(double4 p0, double p1) { return clamp(p0, p // CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> %{{.*}}, <4 x double> [[CONV1]], <4 x double> %{{.*}}) // CHECK: ret <4 x double> [[CLAMP]] double4 test_clamp_double4_mismatch2(double4 p0, double p1) { return clamp(p0, p1,p0); } +// CHECK: define [[FNATTRS]] [[FFNATTRS]] <4 x double> {{.*}}test_clamp_double4_mismatch3 +// CHECK: [[CONV0:%.*]] = insertelement <4 x double> poison, double %{{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <4 x double> [[CONV0]], <4 x double> poison, <4 x i32> zeroinitializer +// CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> [[CONV1]], <4 x double> %{{.*}}, <4 x double> %{{.*}}) +// CHECK: ret <4 x double> [[CLAMP]] +double4 test_clamp_double4_mismatch3(double4 p0, double p1) { return clamp(p1, p0, p0); } // CHECK: define [[FNATTRS]] <3 x i32> {{.*}}test_overloads3 // CHECK: [[CONV0:%.*]] = insertelement <3 x i32> poison, i32 %{{.*}}, i64 0 diff --git a/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl new file mode 100644 index 0000000000000..33f0c7625b2eb --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple dxil-pc-shadermodel6.3-library %s \ +// RUN: -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \ +// RUN: -DTARGET=dx -DFNATTRS=noundef -DFFNATTRS="nofpclass(nan inf)" + +// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple spirv-unknown-vulkan-compute %s \ +// RUN: -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \ +// RUN: -DTARGET=spv -DFNATTRS="spir_func noundef" -DFFNATTRS="nofpclass(nan inf)" + +// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch1 +// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer +// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> %{{.*}}, <4 x float> [[CONV1]]) +// CHECK: ret float [[DOT]] +float test_dot_float4_mismatch1(float4 p0, float p1) { return dot(p0, p1); } + +// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch2 +// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer +// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> [[CONV1]], <4 x float> %{{.*}}) +// CHECK: ret float [[DOT]] +float test_dot_float4_mismatch2(float4 p0, float p1) { return dot(p1, p0); } + +// CHECK: define [[FNATTRS]] i32 {{.*}}test_dot_int2_mismatch1 +// CHECK: [[CONV0:%.*]] = insertelement <2 x i32> poison, i32 %{{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <2 x i32> [[CONV0]], <2 x i32> poison, <2 x i32> zeroinitializer +// CHECK: [[DOT:%.*]] = call {{.*}} i32 @llvm.[[TARGET]].sdot.v2i32(<2 x i32> %{{.*}}, <2 x i32> [[CONV1]]) +// CHECK: ret i32 [[DOT]] +int test_dot_int2_mismatch1(int2 p0, int p1) { return dot(p0, p1); } + diff --git a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl index 3cb14f8555cab..0158370a847d1 100644 --- a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl +++ b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl @@ -179,3 +179,41 @@ half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); } float3 test_lerp_float_scalar(float3 x, float3 y, float s) { return lerp(x, y, s); } + +// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar1Dv2_ff( +// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> [[SPLAT]]) +// CHECK: ret <2 x float> [[LERP]] +float2 test_lerp_float_scalar1(float2 v, float s) { + return lerp(v, v, s); +} + +// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar2Dv2_ff( +// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT]], <2 x float> {{.*}}) +// CHECK: ret <2 x float> [[LERP]] +float2 test_lerp_float_scalar2(float2 v, float s) { + return lerp(v, s, v); +} + +// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar3Dv2_ff( +// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> [[SPLAT]], <2 x float> {{.*}}, <2 x float> {{.*}}) +// CHECK: ret <2 x float> [[LERP]] +float2 test_lerp_float_scalar3(float2 v, float s) { + return lerp(s, v, v); +} + +// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar4Dv2_ff( +// CHECK: [[SPLATINSERT0:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[SPLAT0:%.*]] = shufflevector <2 x float> [[SPLATINSERT0]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[SPLATINSERT1:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0 +// CHECK: [[SPLAT1:%.*]] = shufflevector <2 x float> [[SPLATINSERT1]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT0]], <2 x float> [[SPLAT1]]) +// CHECK: ret <2 x float> [[LERP]] +float2 test_lerp_float_scalar4(float2 v, float s) { + return lerp(v, s, s); +} diff --git a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl index 39003aef7b7b5..3fc8cfcd9a8cb 100644 --- a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl +++ b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl @@ -126,3 +126,16 @@ float3 test_pow_uint64_t3(uint64_t3 p0, uint64_t3 p1) { return pow(p0, p1); } // CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <4 x float> @llvm.pow.v4f32(<4 x float> [[CONV0]], <4 x float> [[CONV1]]) // CHECK: ret <4 x float> [[POW]] float4 test_pow_uint64_t4(uint64_t4 p0, uint64_t4 p1) { return pow(p0, p1); } + +// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch1 +// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> {{.*}}, <2 x float> [[CONV1]]) +// CHECK: ret <2 x float> [[POW]] +float2 test_pow_float2_mismatch1(float2 p0, float p1) { return pow(p0, p1); } +// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch2 +// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0 +// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> [[CONV1]], <2 x float> {{.*}}) +// CHECK: ret <2 x float> [[POW]] +float2 test_pow_float2_mismatch2(float2 p0, float p1) { return pow(p1, p0); }