@@ -5166,6 +5166,9 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51665166 return false;
51675167}
51685168
5169+ // Helper function for CheckHLSLBuiltinFunctionCall
5170+ // Note: UsualArithmeticConversions handles the case where at least
5171+ // one arg isn't a bool
51695172bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
51705173 unsigned NumArgs = TheCall->getNumArgs();
51715174
@@ -5181,47 +5184,13 @@ bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
51815184 Sema::AA_Converting);
51825185 if (ResA.isInvalid())
51835186 return true;
5184- TheCall->setArg(0 , ResA.get());
5187+ TheCall->setArg(i , ResA.get());
51855188 }
51865189 return false;
51875190}
51885191
5189- int overloadOrder(Sema *S, QualType ArgTyA) {
5190- auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
5191- switch (kind) {
5192- case BuiltinType::Short:
5193- case BuiltinType::UShort:
5194- return 1;
5195- case BuiltinType::Int:
5196- case BuiltinType::UInt:
5197- return 2;
5198- case BuiltinType::Long:
5199- case BuiltinType::ULong:
5200- return 3;
5201- case BuiltinType::LongLong:
5202- case BuiltinType::ULongLong:
5203- return 4;
5204- case BuiltinType::Float16:
5205- case BuiltinType::Half:
5206- return 5;
5207- case BuiltinType::Float:
5208- return 6;
5209- default:
5210- break;
5211- }
5212- return 0;
5213- }
5214-
5215- QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
5216- auto *VecTyA = ArgTyA->getAs<VectorType>();
5217- auto *VecTyB = ArgTyB->getAs<VectorType>();
5218- QualType VecTyAElem = VecTyA->getElementType();
5219- QualType VecTyBElem = VecTyB->getElementType();
5220- int vecAElemWidth = overloadOrder(S, VecTyAElem);
5221- int vecBElemWidth = overloadOrder(S, VecTyBElem);
5222- return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
5223- }
5224-
5192+ // Helper function for CheckHLSLBuiltinFunctionCall
5193+ // Handles the CK_HLSLVectorTruncation case for builtins
52255194void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
52265195 assert(TheCall->getNumArgs() > 1);
52275196 ExprResult A = TheCall->getArg(0);
@@ -5246,6 +5215,7 @@ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
52465215 SmallerArg = B.get();
52475216 largerIndex = 0;
52485217 }
5218+
52495219 S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
52505220 << LargerArg->getType() << SmallerArg->getType()
52515221 << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
@@ -5255,61 +5225,79 @@ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
52555225 return;
52565226}
52575227
5258- bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5228+ // Helper function for CheckHLSLBuiltinFunctionCall
5229+ void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
5230+ SourceRange targetSrcRange,
5231+ SourceLocation BuiltinLoc) {
5232+ auto *vecTyTarget = source.get()->getType()->getAs<VectorType>();
5233+ assert(vecTyTarget);
5234+ QualType vecElemT = vecTyTarget->getElementType();
5235+ if (!vecElemT->isFloatingType() && targetTy->isFloatingType()) {
5236+ QualType floatVecTy = S->Context.getVectorType(
5237+ S->Context.FloatTy, vecTyTarget->getNumElements(), VectorKind::Generic);
5238+ int floatByteSize =
5239+ S->Context.getTypeSizeInChars(S->Context.FloatTy).getQuantity();
5240+ int vecElemByteSize = S->Context.getTypeSizeInChars(vecElemT).getQuantity();
5241+ if (vecElemByteSize > floatByteSize)
5242+ S->Diag(BuiltinLoc, diag::warn_hlsl_impcast_bitwidth_reduction)
5243+ << source.get()->getType() << floatVecTy
5244+ << source.get()->getSourceRange() << targetSrcRange;
5245+
5246+ source = S->SemaConvertVectorExpr(
5247+ source.get(), S->Context.CreateTypeSourceInfo(floatVecTy), BuiltinLoc,
5248+ source.get()->getBeginLoc());
5249+ }
5250+ }
5251+
5252+ // Helper function for CheckHLSLBuiltinFunctionCall
5253+ void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
5254+ QualType sourceTy = source.get()->getType();
5255+ auto *vecTyTarget = targetTy->getAs<VectorType>();
5256+ QualType vecElemT = vecTyTarget->getElementType();
5257+ if (vecElemT->isFloatingType() && sourceTy != vecElemT)
5258+ // if float vec splat wil do an unnecessary cast to double
5259+ source = S->ImpCastExprToType(source.get(), vecElemT, CK_FloatingCast);
5260+ source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
5261+ }
5262+
5263+ // Helper function for CheckHLSLBuiltinFunctionCall
5264+ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
52595265 assert(TheCall->getNumArgs() > 1);
52605266 ExprResult A = TheCall->getArg(0);
52615267 ExprResult B = TheCall->getArg(1);
52625268 QualType ArgTyA = A.get()->getType();
52635269 QualType ArgTyB = B.get()->getType();
5264-
52655270 auto *VecTyA = ArgTyA->getAs<VectorType>();
52665271 auto *VecTyB = ArgTyB->getAs<VectorType>();
5272+
52675273 if (VecTyA == nullptr && VecTyB == nullptr)
52685274 return false;
5275+
52695276 if (VecTyA && VecTyB) {
52705277 if (VecTyA->getElementType() == VecTyB->getElementType()) {
52715278 TheCall->setType(VecTyA->getElementType());
52725279 return false;
52735280 }
5274- SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5275- QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
5276- if (CastType == ArgTyA) {
5277- ExprResult ResB = S->SemaConvertVectorExpr(
5278- B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
5279- B.get()->getBeginLoc());
5280- TheCall->setArg(1, ResB.get());
5281- TheCall->setType(VecTyA->getElementType());
5282- return false;
5283- }
5284-
5285- if (CastType == ArgTyB) {
5286- ExprResult ResA = S->SemaConvertVectorExpr(
5287- A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
5288- A.get()->getBeginLoc());
5289- TheCall->setArg(0, ResA.get());
5290- TheCall->setType(VecTyB->getElementType());
5291- return false;
5292- }
5293- return false;
5281+ // Note: type promotion is intended to be handeled via the intrinsics
5282+ // and not the builtin itself.
5283+ S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
5284+ << TheCall->getDirectCallee()
5285+ << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
5286+ return true;
52945287 }
52955288
52965289 if (VecTyB) {
5297- // Convert to the vector result type
5298- ExprResult ResA = A;
5299- if (VecTyB->getElementType() != ArgTyA)
5300- ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
5301- CK_FloatingCast);
5302- ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
5303- TheCall->setArg(0, ResA.get());
5290+ CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
5291+ TheCall->getBeginLoc());
5292+ PromoteVectorArgSplat(S, A, B.get()->getType());
53045293 }
53055294 if (VecTyA) {
5306- ExprResult ResB = B;
5307- if (VecTyA->getElementType() != ArgTyB)
5308- ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
5309- CK_FloatingCast);
5310- ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
5311- TheCall->setArg(1, ResB.get());
5295+ CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
5296+ TheCall->getBeginLoc());
5297+ PromoteVectorArgSplat(S, B, A.get()->getType());
53125298 }
5299+ TheCall->setArg(0, A.get());
5300+ TheCall->setArg(1, B.get());
53135301 return false;
53145302}
53155303
@@ -5322,7 +5310,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
53225310 return true;
53235311 if (PromoteBoolsToInt(this, TheCall))
53245312 return true;
5325- if (PromoteVectorElementCallArgs (this, TheCall))
5313+ if (CheckVectorElementCallArgs (this, TheCall))
53265314 return true;
53275315 PromoteVectorArgTruncation(this, TheCall);
53285316 if (SemaBuiltinVectorToScalarMath(TheCall))
0 commit comments