diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 2db4c1729c39f..1e88963345763 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -14,16 +14,21 @@ #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::dxil; -static void updateFunctionFlags(ComputedShaderFlags &CSF, - const Instruction &I) { +static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, + DXILResourceTypeMap &DRTM) { if (!CSF.Doubles) CSF.Doubles = I.getType()->isDoubleTy(); @@ -44,9 +49,23 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, break; } } + + if (auto *II = dyn_cast(&I)) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::dx_typedBufferLoad: { + dxil::ResourceTypeInfo &RTI = + DRTM[cast(II->getArgOperand(0)->getType())]; + if (RTI.isTyped()) + CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; + } + } + } } -void ModuleShaderFlags::initialize(const Module &M) { +void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) { + // Collect shader flags for each of the functions for (const auto &F : M.getFunctionList()) { if (F.isDeclaration()) { @@ -57,7 +76,7 @@ void ModuleShaderFlags::initialize(const Module &M) { ComputedShaderFlags CSF; for (const auto &BB : F) for (const auto &I : BB) - updateFunctionFlags(CSF, I); + updateFunctionFlags(CSF, I, DRTM); // Insert shader flag mask for function F FunctionFlags.push_back({&F, CSF}); // Update combined shader flags mask @@ -104,8 +123,11 @@ AnalysisKey ShaderFlagsAnalysis::Key; ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + DXILResourceTypeMap &DRTM = AM.getResult(M); + ModuleShaderFlags MSFI; - MSFI.initialize(M); + MSFI.initialize(M, DRTM); + return MSFI; } @@ -132,11 +154,22 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { - MSFI.initialize(M); + DXILResourceTypeMap &DRTM = + getAnalysis().getResourceTypeMap(); + + MSFI.initialize(M, DRTM); return false; } +void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive(); +} + char ShaderFlagsAnalysisWrapper::ID = 0; -INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", - "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h index 2d60137f8b191..67ddab39d0f34 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.h +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h @@ -26,6 +26,7 @@ namespace llvm { class Module; class GlobalVariable; +class DXILResourceTypeMap; namespace dxil { @@ -84,7 +85,7 @@ struct ComputedShaderFlags { }; struct ModuleShaderFlags { - void initialize(const Module &); + void initialize(const Module &, DXILResourceTypeMap &DRTM); const ComputedShaderFlags &getFunctionFlags(const Function *) const; const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; } @@ -135,9 +136,7 @@ class ShaderFlagsAnalysisWrapper : public ModulePass { bool runOnModule(Module &M) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } + void getAnalysisUsage(AnalysisUsage &AU) const override; }; } // namespace dxil diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll new file mode 100644 index 0000000000000..b6947393c4533 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll @@ -0,0 +1,44 @@ +; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s +; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=CHECK-OBJ + +target triple = "dxil-pc-shadermodel6.7-library" + +; CHECK-OBJ: - Name: SFI0 +; CHECK-OBJ: Flags: +; CHECK-OBJ: TypedUAVLoadAdditionalFormats: true + +; CHECK: Combined Shader Flags for Module +; CHECK-NEXT: Shader Flags Value: 0x00002000 + +; CHECK: Note: shader requires additional functionality: +; CHECK: Typed UAV Load Additional Formats + +; CHECK: Function multicomponent : 0x00002000 +define <4 x float> @multicomponent() #0 { + %res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false) + %val = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0) + ret <4 x float> %val +} + +; CHECK: Function onecomponent : 0x00000000 +define float @onecomponent() #0 { + %res = call target("dx.TypedBuffer", float, 1, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false) + %val = call float @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", float, 1, 0, 0) %res, i32 0) + ret float %val +} + +; CHECK: Function noload : 0x00000000 +define void @noload(<4 x float> %val) #0 { + %res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false) + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0, + <4 x float> %val) + ret void +} + +attributes #0 = { convergent norecurse nounwind "hlsl.export"}