Skip to content

Commit 080b65f

Browse files
committed
remove all type promotion
1 parent 8d8d1ff commit 080b65f

File tree

6 files changed

+90
-192
lines changed

6 files changed

+90
-192
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -17996,8 +17996,10 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1799617996
assert(T0->getScalarType() == T1->getScalarType() &&
1799717997
"Dot product of vectors need the same element types.");
1799817998

17999-
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
18000-
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
17999+
[[maybe_unused]] auto *VecTy0 =
18000+
E->getArg(0)->getType()->getAs<VectorType>();
18001+
[[maybe_unused]] auto *VecTy1 =
18002+
E->getArg(1)->getType()->getAs<VectorType>();
1800118003
// A HLSLVectorTruncation should have happend
1800218004
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
1800318005
"Dot product requires vectors to be of the same size.");

clang/lib/Headers/hlsl/hlsl_intrinsics.h

+6
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ double4 cos(double4);
182182
//===----------------------------------------------------------------------===//
183183
// dot product builtins
184184
//===----------------------------------------------------------------------===//
185+
186+
/// \fn K dot(T X, T Y)
187+
/// \brief Return the dot product (a scalar value) of \a X and \a Y.
188+
/// \param X The X input value.
189+
/// \param Y The Y input value.
190+
185191
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
186192
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
187193
half dot(half, half);

clang/lib/Sema/SemaChecking.cpp

+35-125
Original file line numberDiff line numberDiff line change
@@ -2962,9 +2962,8 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
29622962
}
29632963
}
29642964

2965-
if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
2965+
if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall))
29662966
return ExprError();
2967-
}
29682967

29692968
// Since the target specific builtins for each arch overlap, only check those
29702969
// of the arch we are compiling for.
@@ -5166,96 +5165,6 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51665165
return false;
51675166
}
51685167

5169-
// Helper function for CheckHLSLBuiltinFunctionCall
5170-
// Note: UsualArithmeticConversions handles the case where at least
5171-
// one arg isn't a bool
5172-
bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
5173-
unsigned NumArgs = TheCall->getNumArgs();
5174-
5175-
for (unsigned i = 0; i < NumArgs; ++i) {
5176-
ExprResult A = TheCall->getArg(i);
5177-
if (!A.get()->getType()->isBooleanType())
5178-
return false;
5179-
}
5180-
// if we got here all args are bool
5181-
for (unsigned i = 0; i < NumArgs; ++i) {
5182-
ExprResult A = TheCall->getArg(i);
5183-
ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
5184-
Sema::AA_Converting);
5185-
if (ResA.isInvalid())
5186-
return true;
5187-
TheCall->setArg(i, ResA.get());
5188-
}
5189-
return false;
5190-
}
5191-
5192-
// Helper function for CheckHLSLBuiltinFunctionCall
5193-
// Handles the CK_HLSLVectorTruncation case for builtins
5194-
void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
5195-
assert(TheCall->getNumArgs() > 1);
5196-
ExprResult A = TheCall->getArg(0);
5197-
ExprResult B = TheCall->getArg(1);
5198-
QualType ArgTyA = A.get()->getType();
5199-
QualType ArgTyB = B.get()->getType();
5200-
5201-
auto *VecTyA = ArgTyA->getAs<VectorType>();
5202-
auto *VecTyB = ArgTyB->getAs<VectorType>();
5203-
if (VecTyA == nullptr && VecTyB == nullptr)
5204-
return;
5205-
if (VecTyA == nullptr || VecTyB == nullptr)
5206-
return;
5207-
if (VecTyA->getNumElements() == VecTyB->getNumElements())
5208-
return;
5209-
5210-
Expr *LargerArg = B.get();
5211-
Expr *SmallerArg = A.get();
5212-
int largerIndex = 1;
5213-
if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
5214-
LargerArg = A.get();
5215-
SmallerArg = B.get();
5216-
largerIndex = 0;
5217-
}
5218-
5219-
S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
5220-
<< LargerArg->getType() << SmallerArg->getType()
5221-
<< LargerArg->getSourceRange() << SmallerArg->getSourceRange();
5222-
ExprResult ResLargerArg = S->ImpCastExprToType(
5223-
LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
5224-
TheCall->setArg(largerIndex, ResLargerArg.get());
5225-
return;
5226-
}
5227-
5228-
// Helper function for CheckHLSLBuiltinFunctionCall
5229-
void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
5230-
SourceRange targetSrcRange,
5231-
SourceLocation BuiltinLoc) {
5232-
auto *vecTyTarget = source.get()->getType()->getAs<VectorType>();
5233-
assert(vecTyTarget);
5234-
QualType vecElemT = vecTyTarget->getElementType();
5235-
if (!vecElemT->isFloatingType() && targetTy->isFloatingType()) {
5236-
QualType floatVecTy = S->Context.getVectorType(
5237-
S->Context.FloatTy, vecTyTarget->getNumElements(), VectorKind::Generic);
5238-
5239-
S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
5240-
<< source.get()->getType() << floatVecTy
5241-
<< source.get()->getSourceRange() << targetSrcRange;
5242-
source = S->SemaConvertVectorExpr(
5243-
source.get(), S->Context.CreateTypeSourceInfo(floatVecTy), BuiltinLoc,
5244-
source.get()->getBeginLoc());
5245-
}
5246-
}
5247-
5248-
// Helper function for CheckHLSLBuiltinFunctionCall
5249-
void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
5250-
QualType sourceTy = source.get()->getType();
5251-
auto *vecTyTarget = targetTy->getAs<VectorType>();
5252-
QualType vecElemT = vecTyTarget->getElementType();
5253-
if (vecElemT->isFloatingType() && sourceTy != vecElemT)
5254-
// if float vec splat wil do an unnecessary cast to double
5255-
source = S->ImpCastExprToType(source.get(), vecElemT, CK_FloatingCast);
5256-
source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
5257-
}
5258-
52595168
// Helper function for CheckHLSLBuiltinFunctionCall
52605169
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
52615170
assert(TheCall->getNumArgs() > 1);
@@ -5265,36 +5174,42 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
52655174
QualType ArgTyB = B.get()->getType();
52665175
auto *VecTyA = ArgTyA->getAs<VectorType>();
52675176
auto *VecTyB = ArgTyB->getAs<VectorType>();
5268-
5177+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
52695178
if (VecTyA == nullptr && VecTyB == nullptr)
52705179
return false;
52715180

52725181
if (VecTyA && VecTyB) {
5273-
if (VecTyA->getElementType() == VecTyB->getElementType()) {
5274-
TheCall->setType(VecTyA->getElementType());
5275-
return false;
5182+
bool retValue = false;
5183+
if (VecTyA->getElementType() != VecTyB->getElementType()) {
5184+
// Note: type promotion is intended to be handeled via the intrinsics
5185+
// and not the builtin itself.
5186+
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
5187+
<< TheCall->getDirectCallee()
5188+
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
5189+
retValue = true;
5190+
}
5191+
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
5192+
// if we get here a HLSLVectorTruncation is needed.
5193+
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5194+
<< TheCall->getDirectCallee()
5195+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5196+
TheCall->getArg(1)->getEndLoc());
5197+
retValue = true;
52765198
}
5277-
// Note: type promotion is intended to be handeled via the intrinsics
5278-
// and not the builtin itself.
5279-
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
5280-
<< TheCall->getDirectCallee()
5281-
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
5282-
return true;
5283-
}
52845199

5285-
if (VecTyB) {
5286-
CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
5287-
TheCall->getBeginLoc());
5288-
PromoteVectorArgSplat(S, A, B.get()->getType());
5289-
}
5290-
if (VecTyA) {
5291-
CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
5292-
TheCall->getBeginLoc());
5293-
PromoteVectorArgSplat(S, B, A.get()->getType());
5200+
if (retValue)
5201+
TheCall->setType(VecTyA->getElementType());
5202+
5203+
return retValue;
52945204
}
5295-
TheCall->setArg(0, A.get());
5296-
TheCall->setArg(1, B.get());
5297-
return false;
5205+
5206+
// Note: if we get here one of the args is a scalar which
5207+
// requires a VectorSplat on Arg0 or Arg1
5208+
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5209+
<< TheCall->getDirectCallee()
5210+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5211+
TheCall->getArg(1)->getEndLoc());
5212+
return true;
52985213
}
52995214

53005215
// Note: returning true in this case results in CheckBuiltinFunctionCall
@@ -5304,11 +5219,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
53045219
case Builtin::BI__builtin_hlsl_dot: {
53055220
if (checkArgCount(*this, TheCall, 2))
53065221
return true;
5307-
if (PromoteBoolsToInt(this, TheCall))
5308-
return true;
53095222
if (CheckVectorElementCallArgs(this, TheCall))
53105223
return true;
5311-
PromoteVectorArgTruncation(this, TheCall);
53125224
if (SemaBuiltinVectorToScalarMath(TheCall))
53135225
return true;
53145226
break;
@@ -19759,24 +19671,22 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1975919671

1976019672
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1976119673
QualType Res;
19762-
bool result = SemaBuiltinVectorMath(TheCall, Res);
19763-
if (result)
19674+
if (SemaBuiltinVectorMath(TheCall, Res))
1976419675
return true;
1976519676
TheCall->setType(Res);
1976619677
return false;
1976719678
}
1976819679

1976919680
bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
1977019681
QualType Res;
19771-
bool result = SemaBuiltinVectorMath(TheCall, Res);
19772-
if (result)
19682+
if (SemaBuiltinVectorMath(TheCall, Res))
1977319683
return true;
1977419684

19775-
if (auto *VecTy0 = Res->getAs<VectorType>()) {
19685+
if (auto *VecTy0 = Res->getAs<VectorType>())
1977619686
TheCall->setType(VecTy0->getElementType());
19777-
} else {
19687+
else
1977819688
TheCall->setType(Res);
19779-
}
19689+
1978019690
return false;
1978119691
}
1978219692

Original file line numberDiff line numberDiff line change
@@ -1,44 +1,6 @@
1-
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2-
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
3-
// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
4-
// RUN: --check-prefixes=CHECK
5-
6-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
7-
// CHECK: ret float %dx.dot
8-
float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
9-
return __builtin_hlsl_dot( p0, p1 );
10-
}
11-
12-
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
13-
// CHECK: ret float %dx.dot
14-
float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
15-
return __builtin_hlsl_dot( p0, p1 );
16-
}
17-
18-
// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
19-
// CHECK: ret float %dx.dot
20-
float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
21-
return __builtin_hlsl_dot( p0, p1 );
22-
}
23-
24-
// CHECK: %conv = sitofp i32 %1 to float
25-
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
26-
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
27-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
28-
// CHECK: ret float %dx.dot
29-
float test_dot_float2_int_splat ( float2 p0, int p1 ) {
30-
return __builtin_hlsl_dot ( p0, p1 );
31-
}
32-
33-
// CHECK: %conv = sitofp i32 %1 to float
34-
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
35-
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
36-
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
37-
// CHECK: ret float %dx.dot
38-
float test_dot_float3_int_splat ( float3 p0, int p1 ) {
39-
return __builtin_hlsl_dot ( p0, p1 );
40-
}
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
412

3+
// CHECK-LABEL: builtin_bool_to_float_type_promotion
424
// CHECK: %conv1 = uitofp i1 %tobool to double
435
// CHECK: %dx.dot = fmul double %conv, %conv1
446
// CHECK: %conv2 = fptrunc double %dx.dot to float
@@ -47,6 +9,7 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
479
return __builtin_hlsl_dot ( p0, p1 );
4810
}
4911

12+
// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
5013
// CHECK: %conv = uitofp i1 %tobool to double
5114
// CHECK: %conv1 = fpext float %1 to double
5215
// CHECK: %dx.dot = fmul double %conv, %conv1
@@ -56,28 +19,12 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
5619
return __builtin_hlsl_dot ( p0, p1 );
5720
}
5821

59-
// CHECK: %conv = zext i1 %tobool to i32
60-
// CHECK: %conv3 = zext i1 %tobool2 to i32
61-
// CHECK: %dx.dot = mul i32 %conv, %conv3
62-
// CHECK: ret i32 %dx.dot
63-
int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
64-
return __builtin_hlsl_dot ( p0, p1 );
65-
}
66-
22+
// CHECK-LABEL: builtin_dot_int_to_float_promotion
6723
// CHECK: %conv = fpext float %0 to double
6824
// CHECK: %conv1 = sitofp i32 %1 to double
6925
// CHECK: dx.dot = fmul double %conv, %conv1
7026
// CHECK: %conv2 = fptrunc double %dx.dot to float
7127
// CHECK: ret float %conv2
72-
float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
73-
return __builtin_hlsl_dot ( p0, p1 );
74-
}
75-
76-
77-
// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float>
78-
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
79-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
80-
// CHECK: ret float %dx.dot
81-
float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) {
28+
float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
8229
return __builtin_hlsl_dot ( p0, p1 );
8330
}

clang/test/CodeGenHLSL/builtins/dot.hlsl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
22
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
3-
// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
3+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
44
// RUN: --check-prefixes=CHECK,NATIVE_HALF
5-
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
5+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
66
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
77
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
88

clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl

+37-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2-
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm \
3-
// RUN: -disable-llvm-passes -verify -verify-ignore-unexpected
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
42

53
float test_no_second_arg ( float2 p0) {
64
return __builtin_hlsl_dot ( p0 );
@@ -25,7 +23,7 @@ float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
2523

2624
float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) {
2725
return __builtin_hlsl_dot ( p0, p1 );
28-
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float2' (aka 'vector<float, 2>')}}
26+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
2927
}
3028

3129
float test_dot_scalar_mismatch ( float p0, int p1 ) {
@@ -75,3 +73,38 @@ int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1
7573
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
7674
}
7775
#endif
76+
77+
float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
78+
return __builtin_hlsl_dot( p0, p1 );
79+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
80+
}
81+
82+
float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
83+
return __builtin_hlsl_dot( p0, p1 );
84+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
85+
}
86+
87+
float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
88+
return __builtin_hlsl_dot( p0, p1 );
89+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
90+
}
91+
92+
float test_dot_float2_int_splat ( float2 p0, int p1 ) {
93+
return __builtin_hlsl_dot ( p0, p1 );
94+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
95+
}
96+
97+
float test_dot_float3_int_splat ( float3 p0, int p1 ) {
98+
return __builtin_hlsl_dot ( p0, p1 );
99+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
100+
}
101+
102+
float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) {
103+
return __builtin_hlsl_dot ( p0, p1 );
104+
// expected-error@-1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
105+
}
106+
107+
int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
108+
return __builtin_hlsl_dot ( p0, p1 );
109+
// expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
110+
}

0 commit comments

Comments
 (0)