Skip to content

Commit 481bce0

Browse files
Adding splitdouble HLSL function (#109331)
- Adding hlsl `splitdouble` intrinsics - Adding DXIL lowering - Adding SPIRV lowering - Adding test Fixes: #108901 --------- Co-authored-by: Joao Saffran <[email protected]>
1 parent 7bd8a16 commit 481bce0

File tree

17 files changed

+581
-76
lines changed

17 files changed

+581
-76
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4871,6 +4871,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
48714871
let Prototype = "void(...)";
48724872
}
48734873

4874+
def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
4875+
let Spellings = ["__builtin_hlsl_elementwise_splitdouble"];
4876+
let Attributes = [NoThrow, Const];
4877+
let Prototype = "void(...)";
4878+
}
4879+
48744880
// Builtins for XRay.
48754881
def XRayCustomEvent : Builtin {
48764882
let Spellings = ["__xray_customevent"];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "CGObjCRuntime.h"
1818
#include "CGOpenCLRuntime.h"
1919
#include "CGRecordLayout.h"
20+
#include "CGValue.h"
2021
#include "CodeGenFunction.h"
2122
#include "CodeGenModule.h"
2223
#include "ConstantEmitter.h"
@@ -25,8 +26,10 @@
2526
#include "clang/AST/ASTContext.h"
2627
#include "clang/AST/Attr.h"
2728
#include "clang/AST/Decl.h"
29+
#include "clang/AST/Expr.h"
2830
#include "clang/AST/OSLog.h"
2931
#include "clang/AST/OperationKinds.h"
32+
#include "clang/AST/Type.h"
3033
#include "clang/Basic/TargetBuiltins.h"
3134
#include "clang/Basic/TargetInfo.h"
3235
#include "clang/Basic/TargetOptions.h"
@@ -67,6 +70,7 @@
6770
#include "llvm/TargetParser/X86TargetParser.h"
6871
#include <optional>
6972
#include <sstream>
73+
#include <utility>
7074

7175
using namespace clang;
7276
using namespace CodeGen;
@@ -95,6 +99,76 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
9599
I->addAnnotationMetadata("auto-init");
96100
}
97101

102+
/// Emit code for `__builtin_hlsl_elementwise_splitdouble`, the builtin behind
/// the three-argument HLSL `asuint(double, out uint, out uint)` overloads.
///
/// The double operand (scalar or vector) is reinterpreted as pairs of 32-bit
/// integers. For each double, the first integer of the pair is stored to the
/// temporary backing the first `out` argument (low bits) and the second
/// integer to the temporary backing the second `out` argument (high bits).
/// The writebacks recorded for the `out` arguments are emitted afterwards.
///
/// \param E   The builtin call. Argument 0 is the value to split; arguments 1
///            and 2 must be HLSLOutArgExpr nodes.
/// \param CGF The function being emitted into.
/// \returns The store instruction that writes the high bits.
static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
  // Both out arguments are dereferenced unconditionally below, so use cast<>
  // (which asserts on a kind mismatch) rather than an unchecked dyn_cast<>
  // that would silently yield a null pointer.
  const auto *OutArg1 = cast<HLSLOutArgExpr>(E->getArg(1));
  const auto *OutArg2 = cast<HLSLOutArgExpr>(E->getArg(2));

  CallArgList Args;
  LValue Op1TmpLValue =
      CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
  LValue Op2TmpLValue =
      CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());

  if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
    Args.reverseWritebacks();

  Value *LowBits = nullptr;
  Value *HighBits = nullptr;

  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
    // DXIL: call the dx.splitdouble intrinsic, which returns both halves as a
    // two-member struct of i32 (or of i32 vectors matching the operand width).
    llvm::Type *RetElementTy = CGF->Int32Ty;
    if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
      RetElementTy = llvm::VectorType::get(
          CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
    auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);

    CallInst *CI = CGF->Builder.CreateIntrinsic(
        RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");

    LowBits = CGF->Builder.CreateExtractValue(CI, 0);
    HighBits = CGF->Builder.CreateExtractValue(CI, 1);
  } else {
    // For non-DXIL targets we generate the instructions directly: bitcast the
    // operand to a vector of i32 and pick out the even/odd lanes.
    if (!Op0->getType()->isVectorTy()) {
      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
      Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);

      // The explicit uint64_t selects the index overload; a bare literal 0
      // would also be a valid null `Value *`.
      LowBits =
          CGF->Builder.CreateExtractElement(Bitcast, static_cast<uint64_t>(0));
      HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
    } else {
      int NumElements = 1;
      if (const auto *VecTy =
              E->getArg(0)->getType()->getAs<clang::VectorType>())
        NumElements = VecTy->getNumElements();

      FixedVectorType *Uint32VecTy =
          FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
      Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
      if (NumElements == 1) {
        LowBits = CGF->Builder.CreateExtractElement(Uint32Vec,
                                                    static_cast<uint64_t>(0));
        HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
      } else {
        // Lane 2*I holds the first i32 of double I, lane 2*I+1 the second.
        // The index is named Idx so it does not shadow the CallExpr `E`.
        SmallVector<int> EvenMask, OddMask;
        for (int Idx = 0; Idx != NumElements; ++Idx) {
          EvenMask.push_back(Idx * 2);
          OddMask.push_back(Idx * 2 + 1);
        }
        LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
        HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
      }
    }
  }

  // Store into the out-argument temporaries, then emit the writebacks that
  // EmitHLSLOutArgExpr recorded in Args.
  CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
  auto *LastInst =
      CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
  CGF->EmitWritebacks(Args);
  return LastInst;
}
171+
98172
/// getBuiltinLibFunction - Given a builtin id for a function like
99173
/// "__builtin_fabsf", return a Function* for "fabsf".
100174
llvm::Constant *CodeGenModule::getBuiltinLibFunction(const FunctionDecl *FD,
@@ -18959,6 +19033,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1895919033
CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
1896019034
nullptr, "hlsl.radians");
1896119035
}
19036+
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
19037+
19038+
assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
19039+
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
19040+
E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
19041+
"asuint operands types mismatch");
19042+
return handleHlslSplitdouble(E, this);
19043+
}
1896219044
}
1896319045
return nullptr;
1896419046
}

clang/lib/CodeGen/CGCall.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "llvm/IR/IntrinsicInst.h"
4141
#include "llvm/IR/Intrinsics.h"
4242
#include "llvm/IR/Type.h"
43+
#include "llvm/Support/Path.h"
4344
#include "llvm/Transforms/Utils/Local.h"
4445
#include <optional>
4546
using namespace clang;
@@ -4243,12 +4244,6 @@ static void emitWriteback(CodeGenFunction &CGF,
42434244
CGF.EmitBlock(contBB);
42444245
}
42454246

4246-
static void emitWritebacks(CodeGenFunction &CGF,
4247-
const CallArgList &args) {
4248-
for (const auto &I : args.writebacks())
4249-
emitWriteback(CGF, I);
4250-
}
4251-
42524247
static void deactivateArgCleanupsBeforeCall(CodeGenFunction &CGF,
42534248
const CallArgList &CallArgs) {
42544249
ArrayRef<CallArgList::CallArgCleanup> Cleanups =
@@ -4717,6 +4712,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
47174712
IsUsed = true;
47184713
}
47194714

4715+
void CodeGenFunction::EmitWritebacks(const CallArgList &args) {
4716+
for (const auto &I : args.writebacks())
4717+
emitWriteback(*this, I);
4718+
}
4719+
47204720
void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
47214721
QualType type) {
47224722
DisableDebugLocationUpdates Dis(*this, E);
@@ -5940,7 +5940,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
59405940
// Emit any call-associated writebacks immediately. Arguably this
59415941
// should happen after any return-value munging.
59425942
if (CallArgs.hasWritebacks())
5943-
emitWritebacks(*this, CallArgs);
5943+
EmitWritebacks(CallArgs);
59445944

59455945
// The stack cleanup for inalloca arguments has to run out of the normal
59465946
// lexical order, so deactivate it and run it manually here.

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5460,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
54605460
return getOrCreateOpaqueLValueMapping(e);
54615461
}
54625462

5463-
void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
5464-
CallArgList &Args, QualType Ty) {
5465-
5463+
std::pair<LValue, LValue>
5464+
CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
54665465
// Emitting the casted temporary through an opaque value.
54675466
LValue BaseLV = EmitLValue(E->getArgLValue());
54685467
OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5476,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
54765475
TempLV);
54775476

54785477
OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
5478+
return std::make_pair(BaseLV, TempLV);
5479+
}
5480+
5481+
LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
5482+
CallArgList &Args, QualType Ty) {
5483+
5484+
auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
54795485

54805486
llvm::Value *Addr = TempLV.getAddress().getBasePointer();
54815487
llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5488,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
54885494
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
54895495
LifetimeSize);
54905496
Args.add(RValue::get(TmpAddr, *this), Ty);
5497+
return TempLV;
54915498
}
54925499

54935500
LValue

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
42964296
LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
42974297
LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
42984298
LValue EmitHLSLArrayAssignLValue(const BinaryOperator *E);
4299-
void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
4300-
QualType Ty);
4299+
4300+
std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
4301+
QualType Ty);
4302+
LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
4303+
QualType Ty);
43014304

43024305
Address EmitExtVectorElementLValue(LValue V);
43034306

@@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache {
51475150
SourceLocation ArgLoc, AbstractCallee AC,
51485151
unsigned ParmNum);
51495152

5153+
/// EmitWritebacks - Emit the writebacks recorded in \p Args.
5154+
void EmitWritebacks(const CallArgList &Args);
5155+
51505156
/// EmitCallArg - Emit a single call argument.
51515157
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
51525158

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,24 @@ template <typename T> constexpr uint asuint(T F) {
438438
return __detail::bit_cast<uint, T>(F);
439439
}
440440

441+
//===----------------------------------------------------------------------===//
442+
// asuint splitdouble builtins
443+
//===----------------------------------------------------------------------===//
444+
445+
/// \fn void asuint(double D, out uint lowbits, out uint highbits)
446+
/// \brief Splits double D into its lowbits and highbits and interprets each as a uint.
447+
/// \param D The input double.
448+
/// \param lowbits The output lowbits of D.
449+
/// \param highbits The output highbits of D.
450+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
451+
void asuint(double, out uint, out uint);
452+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
453+
void asuint(double2, out uint2, out uint2);
454+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
455+
void asuint(double3, out uint3, out uint3);
456+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
457+
void asuint(double4, out uint4, out uint4);
458+
441459
//===----------------------------------------------------------------------===//
442460
// atan builtins
443461
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,18 +1698,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
16981698
return true;
16991699
}
17001700

1701-
static bool CheckArgsTypesAreCorrect(
1701+
/// Diagnose \p Arg if its type is rejected by \p Check.
///
/// This helper is only used within this file, so it has internal linkage like
/// the other Check* helpers here.
///
/// \param S            Sema instance used to build types and emit diagnostics.
/// \param Arg          The argument expression to validate.
/// \param ExpectedType Type reported as "expected" in the diagnostic. When the
///                     passed type is a vector, this is replaced by a vector
///                     of ExpectedType with the same element count and kind.
/// \param Check        Predicate that returns true when the passed type is
///                     NOT acceptable.
/// \returns true if a diagnostic was emitted, false if the argument is fine.
static bool CheckArgTypeIsCorrect(
    Sema *S, Expr *Arg, QualType ExpectedType,
    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
  QualType PassedType = Arg->getType();
  if (!Check(PassedType))
    return false;

  // Report the expected type in the same vector shape the caller passed.
  if (auto *VecTyA = PassedType->getAs<VectorType>())
    ExpectedType = S->Context.getVectorType(
        ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
  S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
      << PassedType << ExpectedType << 1 << 0 << 0;
  return true;
}
1715+
1716+
bool CheckAllArgTypesAreCorrect(
17021717
Sema *S, CallExpr *TheCall, QualType ExpectedType,
17031718
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
17041719
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
1705-
QualType PassedType = TheCall->getArg(i)->getType();
1706-
if (Check(PassedType)) {
1707-
if (auto *VecTyA = PassedType->getAs<VectorType>())
1708-
ExpectedType = S->Context.getVectorType(
1709-
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
1710-
S->Diag(TheCall->getArg(0)->getBeginLoc(),
1711-
diag::err_typecheck_convert_incompatible)
1712-
<< PassedType << ExpectedType << 1 << 0 << 0;
1720+
Expr *Arg = TheCall->getArg(i);
1721+
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
17131722
return true;
17141723
}
17151724
}
@@ -1720,8 +1729,8 @@ static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
17201729
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
17211730
return !PassedType->hasFloatingRepresentation();
17221731
};
1723-
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1724-
checkAllFloatTypes);
1732+
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1733+
checkAllFloatTypes);
17251734
}
17261735

17271736
static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
@@ -1732,8 +1741,19 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
17321741
: PassedType;
17331742
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
17341743
};
1735-
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1736-
checkFloatorHalf);
1744+
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1745+
checkFloatorHalf);
1746+
}
1747+
1748+
/// Check that argument \p ArgIndex of \p TheCall is a modifiable lvalue once
/// its casts are ignored.
///
/// \returns false when the argument is a valid modifiable lvalue; otherwise
/// emits error_hlsl_inout_lvalue and returns true.
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                  unsigned ArgIndex) {
  auto *OutArg = TheCall->getArg(ArgIndex);
  SourceLocation DiagLoc = OutArg->getExprLoc();
  const auto Modifiability =
      OutArg->IgnoreCasts()->isModifiableLvalue(S->Context, &DiagLoc);
  if (Modifiability != Expr::MLV_Valid) {
    S->Diag(DiagLoc, diag::error_hlsl_inout_lvalue) << OutArg << 0;
    return true;
  }
  return false;
}
17381758

17391759
static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
@@ -1742,24 +1762,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
17421762
return VecTy->getElementType()->isDoubleType();
17431763
return false;
17441764
};
1745-
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1746-
checkDoubleVector);
1765+
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
1766+
checkDoubleVector);
17471767
}
17481768
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
17491769
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
17501770
return !PassedType->hasIntegerRepresentation() &&
17511771
!PassedType->hasFloatingRepresentation();
17521772
};
1753-
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.IntTy,
1754-
checkAllSignedTypes);
1773+
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
1774+
checkAllSignedTypes);
17551775
}
17561776

17571777
static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
17581778
auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
17591779
return !PassedType->hasUnsignedIntegerRepresentation();
17601780
};
1761-
return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
1762-
checkAllUnsignedTypes);
1781+
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
1782+
checkAllUnsignedTypes);
17631783
}
17641784

17651785
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2074,6 +2094,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
20742094
return true;
20752095
break;
20762096
}
2097+
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
2098+
if (SemaRef.checkArgCount(TheCall, 3))
2099+
return true;
2100+
2101+
if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
2102+
CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
2103+
1) ||
2104+
CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
2105+
2))
2106+
return true;
2107+
2108+
if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
2109+
CheckModifiableLValue(&SemaRef, TheCall, 2))
2110+
return true;
2111+
break;
2112+
}
20772113
case Builtin::BI__builtin_elementwise_acos:
20782114
case Builtin::BI__builtin_elementwise_asin:
20792115
case Builtin::BI__builtin_elementwise_atan:

0 commit comments

Comments
 (0)