Skip to content

[HLSL] Add various overloads for MiniEngine #139800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

V-FEXrt
Copy link
Contributor

@V-FEXrt V-FEXrt commented May 13, 2025

Partial implementation of llvm/wg-hlsl#264

Adds several overloads to various intrinsic functions used by MiniEngine

@V-FEXrt V-FEXrt requested review from spall and farzonl May 13, 2025 21:47
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics HLSL HLSL Language Support labels May 13, 2025
@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Ashley Coleman (V-FEXrt)

Changes

Partial implementation of llvm/wg-hlsl#264

Adds several overloads to various intrinsic functions used by MiniEngine


Full diff: https://github.com/llvm/llvm-project/pull/139800.diff

5 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+59-45)
  • (modified) clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl (+6)
  • (added) clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl (+29)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl (+38)
  • (modified) clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl (+13)
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 4874206d349c0..a2f8f658b292e 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -158,6 +158,42 @@ namespace hlsl {
     return fn((float4)V1, (float4)V2, (float4)V3);                             \
   }
 
+#define _DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(fn)                         \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      vector<T, N> V1, T V2) {                                                 \
+    return fn(V1, (vector<T, N>)V2);                                           \
+  }                                                                            \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      T V1, vector<T, N> V2) {                                                 \
+    return fn((vector<T, N>)V1, V2);                                           \
+  }
+
+#define _DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(fn)                        \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      T V1, vector<T, N> V2, vector<T, N> V3) {                                \
+    return fn((vector<T, N>)V1, V2, V3);                                       \
+  }                                                                            \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      vector<T, N> V1, T V2, vector<T, N> V3) {                                \
+    return fn(V1, (vector<T, N>)V2, V3);                                       \
+  }                                                                            \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      vector<T, N> V1, vector<T, N> V2, T V3) {                                \
+    return fn(V1, V2, (vector<T, N>)V3);                                       \
+  }
+
+#define _DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(fn)                 \
+  template <typename T, uint N>                                                \
+  constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn(         \
+      vector<T, N> V1, T V2, T V3) {                                           \
+    return fn(V1, (vector<T, N>)V2, (vector<T, N>)V3);                         \
+  }
+
 //===----------------------------------------------------------------------===//
 // acos builtins overloads
 //===----------------------------------------------------------------------===//
@@ -197,23 +233,8 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(ceil)
 // clamp builtins overloads
 //===----------------------------------------------------------------------===//
 
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, vector<T, N> p1, T p2) {
-  return clamp(p0, p1, (vector<T, N>)p2);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, T p1, vector<T, N> p2) {
-  return clamp(p0, (vector<T, N>)p1, p2);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, T p1, T p2) {
-  return clamp(p0, (vector<T, N>)p1, (vector<T, N>)p2);
-}
+_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(clamp)
+_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(clamp)
 
 //===----------------------------------------------------------------------===//
 // cos builtins overloads
@@ -236,6 +257,22 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(cosh)
 _DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(degrees)
 _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(degrees)
 
+//===----------------------------------------------------------------------===//
+// dot builtins overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(vector<T, N> V1,
+                                                          T V2) {
+  return dot(V1, (vector<T, N>)V2);
+}
+
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(T V1,
+                                                          vector<T, N> V2) {
+  return dot((vector<T, N>)V1, V2);
+}
+
 //===----------------------------------------------------------------------===//
 // exp builtins overloads
 //===----------------------------------------------------------------------===//
@@ -277,14 +314,10 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
 // lerp builtins overloads
 //===----------------------------------------------------------------------===//
 
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-lerp(vector<T, N> x, vector<T, N> y, T s) {
-  return lerp(x, y, (vector<T, N>)s);
-}
-
 _DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
 _DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp)
+_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(lerp)
+_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(lerp)
 
 //===----------------------------------------------------------------------===//
 // log builtins overloads
@@ -311,33 +344,13 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(log2)
 // max builtins overloads
 //===----------------------------------------------------------------------===//
 
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-max(vector<T, N> p0, T p1) {
-  return max(p0, (vector<T, N>)p1);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-max(T p0, vector<T, N> p1) {
-  return max((vector<T, N>)p0, p1);
-}
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(max)
 
 //===----------------------------------------------------------------------===//
 // min builtins overloads
 //===----------------------------------------------------------------------===//
 
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-min(vector<T, N> p0, T p1) {
-  return min(p0, (vector<T, N>)p1);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-min(T p0, vector<T, N> p1) {
-  return min((vector<T, N>)p0, p1);
-}
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(min)
 
 //===----------------------------------------------------------------------===//
 // normalize builtins overloads
@@ -352,6 +365,7 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(normalize)
 
 _DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(pow)
 _DXC_COMPAT_BINARY_INTEGER_OVERLOADS(pow)
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(pow)
 
 //===----------------------------------------------------------------------===//
 // rsqrt builtins overloads
diff --git a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
index c0e1e914831aa..5bf23db7671ec 100644
--- a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
@@ -90,6 +90,12 @@ double4 test_clamp_double4_mismatch1(double4 p0, double p1) { return clamp(p0, p
 // CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> %{{.*}}, <4 x double> [[CONV1]], <4 x double> %{{.*}})
 // CHECK: ret <4 x double> [[CLAMP]]
 double4 test_clamp_double4_mismatch2(double4 p0, double p1) { return clamp(p0, p1,p0); }
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] <4 x double> {{.*}}test_clamp_double4_mismatch3
+// CHECK: [[CONV0:%.*]] = insertelement <4 x double> poison, double %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x double> [[CONV0]], <4 x double> poison, <4 x i32> zeroinitializer
+// CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> [[CONV1]], <4 x double> %{{.*}}, <4 x double> %{{.*}})
+// CHECK: ret <4 x double> [[CLAMP]]
+double4 test_clamp_double4_mismatch3(double4 p0, double p1) { return clamp(p1, p0, p0); }
 
 // CHECK: define [[FNATTRS]] <3 x i32> {{.*}}test_overloads3
 // CHECK: [[CONV0:%.*]] = insertelement <3 x i32> poison, i32 %{{.*}}, i64 0
diff --git a/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl
new file mode 100644
index 0000000000000..33f0c7625b2eb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple dxil-pc-shadermodel6.3-library %s \
+// RUN:  -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \
+// RUN:  -DTARGET=dx -DFNATTRS=noundef -DFFNATTRS="nofpclass(nan inf)"
+
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple spirv-unknown-vulkan-compute %s \
+// RUN:  -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \
+// RUN:  -DTARGET=spv -DFNATTRS="spir_func noundef" -DFFNATTRS="nofpclass(nan inf)"
+
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> %{{.*}}, <4 x float> [[CONV1]])
+// CHECK: ret float [[DOT]]
+float test_dot_float4_mismatch1(float4 p0, float p1) { return dot(p0, p1); }
+
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch2
+// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> [[CONV1]], <4 x float> %{{.*}})
+// CHECK: ret float [[DOT]]
+float test_dot_float4_mismatch2(float4 p0, float p1) { return dot(p1, p0); }
+
+// CHECK: define [[FNATTRS]] i32 {{.*}}test_dot_int2_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <2 x i32> poison, i32 %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x i32> [[CONV0]], <2 x i32> poison, <2 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} i32 @llvm.[[TARGET]].sdot.v2i32(<2 x i32> %{{.*}}, <2 x i32> [[CONV1]])
+// CHECK: ret i32 [[DOT]]
+int test_dot_int2_mismatch1(int2 p0, int p1) { return dot(p0, p1); }
+
diff --git a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
index 3cb14f8555cab..0158370a847d1 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
@@ -179,3 +179,41 @@ half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
 float3 test_lerp_float_scalar(float3 x, float3 y, float s) {
   return lerp(x, y, s);
 }
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar1Dv2_ff(
+// CHECK:    [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> [[SPLAT]])
+// CHECK:    ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar1(float2 v, float s) {
+  return lerp(v, v, s);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar2Dv2_ff(
+// CHECK:    [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT]], <2 x float> {{.*}})
+// CHECK:    ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar2(float2 v, float s) {
+  return lerp(v, s, v);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar3Dv2_ff(
+// CHECK:    [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> [[SPLAT]], <2 x float> {{.*}}, <2 x float> {{.*}})
+// CHECK:    ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar3(float2 v, float s) {
+  return lerp(s, v, v);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar4Dv2_ff(
+// CHECK:    [[SPLATINSERT0:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT0:%.*]] = shufflevector <2 x float> [[SPLATINSERT0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK:    [[SPLATINSERT1:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT1:%.*]] = shufflevector <2 x float> [[SPLATINSERT1]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT0]], <2 x float> [[SPLAT1]])
+// CHECK:    ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar4(float2 v, float s) {
+  return lerp(v, s, s);
+}
diff --git a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
index 39003aef7b7b5..3fc8cfcd9a8cb 100644
--- a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
@@ -126,3 +126,16 @@ float3 test_pow_uint64_t3(uint64_t3 p0, uint64_t3 p1) { return pow(p0, p1); }
 // CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <4 x float> @llvm.pow.v4f32(<4 x float> [[CONV0]], <4 x float> [[CONV1]])
 // CHECK: ret <4 x float> [[POW]]
 float4 test_pow_uint64_t4(uint64_t4 p0, uint64_t4 p1) { return pow(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> {{.*}}, <2 x float> [[CONV1]])
+// CHECK: ret <2 x float> [[POW]]
+float2 test_pow_float2_mismatch1(float2 p0, float p1) { return pow(p0, p1); }
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch2
+// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> [[CONV1]], <2 x float> {{.*}})
+// CHECK: ret <2 x float> [[POW]]
+float2 test_pow_float2_mismatch2(float2 p0, float p1) { return pow(p1, p0); }

return fn(V1, V2, (vector<T, N>)V3); \
}

#define _DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(fn) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you put this overload in its own macro and not in the previous one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some intrinsics didn't need both so I didn't want to unnecessarily add extra overloads

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I thought in all cases you used both TERNARY_SINGLE and TERNARY_VECTOR? Which one doesn't use both?

// dot builtins overloads
//===----------------------------------------------------------------------===//

template <typename T, uint N>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use your '_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS' macro here right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its subtle, but no (I actually made that mistake myself).

_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS creates type vector<T, N> fn(vector<T, N> x, T y but dot requires T fn(vector<T, N> x, T y) (return types are different shape)

// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> [[SPLAT]])
// CHECK: ret <2 x float> [[LERP]]
float2 test_lerp_float_scalar1(float2 v, float s) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this test worth adding since its a duplicate of the one above, just uses float2 instead of float3?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly no, I was on the fence, but I couldn't quite figure out how to group them. I suppose I can just move up and rename the params

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants