Skip to content

Commit a4b7a2d

Browse files
authored
[DirectX] Propagate shader flags mask of callees to callers (#118306)
Propagate shader flags mask of callees to callers. Add tests to verify propagation of shader flags.
1 parent 5187482 commit a4b7a2d

File tree

4 files changed

+259
-40
lines changed

4 files changed

+259
-40
lines changed

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313

1414
#include "DXILShaderFlags.h"
1515
#include "DirectX.h"
16-
#include "llvm/ADT/STLExtras.h"
16+
#include "llvm/ADT/SCCIterator.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/Analysis/CallGraph.h"
1719
#include "llvm/Analysis/DXILResource.h"
1820
#include "llvm/IR/Instruction.h"
21+
#include "llvm/IR/Instructions.h"
1922
#include "llvm/IR/IntrinsicInst.h"
2023
#include "llvm/IR/Intrinsics.h"
2124
#include "llvm/IR/IntrinsicsDirectX.h"
@@ -27,24 +30,31 @@
2730
using namespace llvm;
2831
using namespace llvm::dxil;
2932

30-
static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
31-
DXILResourceTypeMap &DRTM) {
33+
/// Update the shader flags mask based on the given instruction.
34+
/// \param CSF Shader flags mask to update.
35+
/// \param I Instruction to check.
36+
void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
37+
const Instruction &I,
38+
DXILResourceTypeMap &DRTM) {
3239
if (!CSF.Doubles)
3340
CSF.Doubles = I.getType()->isDoubleTy();
3441

3542
if (!CSF.Doubles) {
36-
for (Value *Op : I.operands())
37-
CSF.Doubles |= Op->getType()->isDoubleTy();
43+
for (const Value *Op : I.operands()) {
44+
if (Op->getType()->isDoubleTy()) {
45+
CSF.Doubles = true;
46+
break;
47+
}
48+
}
3849
}
50+
3951
if (CSF.Doubles) {
4052
switch (I.getOpcode()) {
4153
case Instruction::FDiv:
4254
case Instruction::UIToFP:
4355
case Instruction::SIToFP:
4456
case Instruction::FPToUI:
4557
case Instruction::FPToSI:
46-
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
47-
// https://github.com/llvm/llvm-project/issues/114554
4858
CSF.DX11_1_DoubleExtensions = true;
4959
break;
5060
}
@@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
6272
}
6373
}
6474
}
75+
// Handle call instructions
76+
if (auto *CI = dyn_cast<CallInst>(&I)) {
77+
const Function *CF = CI->getCalledFunction();
78+
// Merge-in shader flags mask of the called function in the current module
79+
if (FunctionFlags.contains(CF))
80+
CSF.merge(FunctionFlags[CF]);
81+
82+
// TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
83+
// DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
84+
}
6585
}
6686

67-
void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
68-
69-
// Collect shader flags for each of the functions
70-
for (const auto &F : M.getFunctionList()) {
71-
if (F.isDeclaration()) {
72-
assert(!F.getName().starts_with("dx.op.") &&
73-
"DXIL Shader Flag analysis should not be run post-lowering.");
74-
continue;
87+
/// Construct ModuleShaderFlags for module M
88+
void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
89+
CallGraph CG(M);
90+
91+
// Compute Shader Flags Mask for all functions using post-order visit of SCC
92+
// of the call graph.
93+
for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
94+
++SCCI) {
95+
const std::vector<CallGraphNode *> &CurSCC = *SCCI;
96+
97+
// Union of shader masks of all functions in CurSCC
98+
ComputedShaderFlags SCCSF;
99+
// List of functions in CurSCC that are neither external nor declarations
100+
// and hence whose flags are collected
101+
SmallVector<Function *> CurSCCFuncs;
102+
for (CallGraphNode *CGN : CurSCC) {
103+
Function *F = CGN->getFunction();
104+
if (!F)
105+
continue;
106+
107+
if (F->isDeclaration()) {
108+
assert(!F->getName().starts_with("dx.op.") &&
109+
"DXIL Shader Flag analysis should not be run post-lowering.");
110+
continue;
111+
}
112+
113+
ComputedShaderFlags CSF;
114+
for (const auto &BB : *F)
115+
for (const auto &I : BB)
116+
updateFunctionFlags(CSF, I, DRTM);
117+
// Update combined shader flags mask for all functions in this SCC
118+
SCCSF.merge(CSF);
119+
120+
CurSCCFuncs.push_back(F);
75121
}
76-
ComputedShaderFlags CSF;
77-
for (const auto &BB : F)
78-
for (const auto &I : BB)
79-
updateFunctionFlags(CSF, I, DRTM);
80-
// Insert shader flag mask for function F
81-
FunctionFlags.push_back({&F, CSF});
82-
// Update combined shader flags mask
83-
CombinedSFMask.merge(CSF);
122+
123+
// Update combined shader flags mask for all functions of the module
124+
CombinedSFMask.merge(SCCSF);
125+
126+
// Shader flags mask of each of the functions in an SCC of the call graph is
127+
// the union of the masks of all functions in the SCC. Update shader flags masks of
128+
// functions in CurSCC accordingly. This is trivially true if SCC contains
129+
// one function.
130+
for (Function *F : CurSCCFuncs)
131+
// Merge SCCSF with that of F
132+
FunctionFlags[F].merge(SCCSF);
84133
}
85-
llvm::sort(FunctionFlags);
86134
}
87135

88136
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
106154
/// Return the shader flags mask of the specified function Func.
107155
const ComputedShaderFlags &
108156
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
109-
const auto Iter = llvm::lower_bound(
110-
FunctionFlags, Func,
111-
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
112-
const Function *FindFunc) { return (FSM.first < FindFunc); });
157+
auto Iter = FunctionFlags.find(Func);
113158
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
114-
"No Shader Flags Mask exists for function");
159+
"Get Shader Flags : No Shader Flags Mask exists for function");
115160
return Iter->second;
116161
}
117162

@@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
142187
for (const auto &F : M.getFunctionList()) {
143188
if (F.isDeclaration())
144189
continue;
145-
auto SFMask = FlagsInfo.getFunctionFlags(&F);
190+
const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
146191
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
147192
(uint64_t)(SFMask));
148193
}

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,31 +71,31 @@ struct ComputedShaderFlags {
7171
return FeatureFlags;
7272
}
7373

74-
void merge(const uint64_t IVal) {
74+
void merge(const ComputedShaderFlags CSF) {
7575
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
76-
FlagName |= (IVal & getMask(DxilModuleBit));
77-
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
78-
FlagName |= (IVal & getMask(DxilModuleBit));
76+
FlagName |= CSF.FlagName;
77+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
7978
#include "llvm/BinaryFormat/DXContainerConstants.def"
80-
return;
8179
}
8280

8381
void print(raw_ostream &OS = dbgs()) const;
8482
LLVM_DUMP_METHOD void dump() const { print(); }
8583
};
8684

8785
struct ModuleShaderFlags {
88-
void initialize(const Module &, DXILResourceTypeMap &DRTM);
86+
void initialize(Module &, DXILResourceTypeMap &DRTM);
8987
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
9088
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
9189

9290
private:
93-
/// Vector of sorted Function-Shader Flag mask pairs representing properties
94-
/// of each of the functions in the module. Shader Flags of each function
95-
/// represent both module-level and function-level flags
96-
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
91+
/// Map of Function-Shader Flag Mask pairs representing properties of each of
92+
/// the functions in the module. Shader Flags of each function represent both
93+
/// module-level and function-level flags
94+
DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
9795
/// Combined Shader Flag Mask of all functions of the module
9896
ComputedShaderFlags CombinedSFMask{};
97+
void updateFunctionFlags(ComputedShaderFlags &, const Instruction &,
98+
DXILResourceTypeMap &);
9999
};
100100

101101
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {

llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
1212
; CHECK-NEXT: ;
1313
; CHECK-NEXT: ; Shader Flags for Module Functions
1414

15+
;CHECK: ; Function top_level : 0x00000044
16+
define double @top_level() #0 {
17+
%r = call double @test_uitofp_i64(i64 5)
18+
ret double %r
19+
}
20+
21+
1522
; CHECK: ; Function test_fdiv_double : 0x00000044
1623
define double @test_fdiv_double(double %a, double %b) #0 {
1724
%res = fdiv double %a, %b
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.7-library"
4+
5+
; CHECK: ; Combined Shader Flags for Module
6+
; CHECK-NEXT: ; Shader Flags Value: 0x00000044
7+
; CHECK-NEXT: ;
8+
; CHECK-NEXT: ; Note: shader requires additional functionality:
9+
; CHECK-NEXT: ; Double-precision floating point
10+
; CHECK-NEXT: ; Double-precision extensions for 11.1
11+
; CHECK-NEXT: ; Note: extra DXIL module flags:
12+
; CHECK-NEXT: ;
13+
; CHECK-NEXT: ; Shader Flags for Module Functions
14+
15+
; Call Graph of test source
16+
; main -> [get_fptoui_flag, get_sitofp_fdiv_flag]
17+
; get_fptoui_flag -> [get_sitofp_uitofp_flag, call_get_uitofp_flag]
18+
; get_sitofp_uitofp_flag -> [call_get_fptoui_flag, call_get_sitofp_flag]
19+
; call_get_fptoui_flag -> [get_fptoui_flag]
20+
; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags]
21+
; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag]
22+
; call_get_sitofp_fdiv_flag -> [get_sitofp_fdiv_flag]
23+
; call_get_sitofp_flag -> [get_sitofp_flag]
24+
; call_get_uitofp_flag -> [get_uitofp_flag]
25+
; get_sitofp_flag -> []
26+
; get_uitofp_flag -> []
27+
; get_no_flags -> []
28+
;
29+
; Strongly Connected Components in the CG
30+
; [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag]
31+
; [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag]
32+
33+
;
34+
; CHECK: ; Function get_sitofp_flag : 0x00000044
35+
define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
36+
%2 = sitofp i32 %0 to double
37+
ret double %2
38+
}
39+
40+
; CHECK: ; Function call_get_sitofp_flag : 0x00000044
41+
define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
42+
%2 = tail call double @get_sitofp_flag(i32 noundef %0)
43+
ret double %2
44+
}
45+
46+
; CHECK: ; Function get_uitofp_flag : 0x00000044
47+
define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
48+
%2 = uitofp i32 %0 to double
49+
ret double %2
50+
}
51+
52+
; CHECK: ; Function call_get_uitofp_flag : 0x00000044
53+
define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
54+
%2 = tail call double @get_uitofp_flag(i32 noundef %0)
55+
ret double %2
56+
}
57+
58+
; CHECK: ; Function call_get_fptoui_flag : 0x00000044
59+
define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
60+
%2 = tail call double @get_fptoui_flag(double noundef %0)
61+
ret double %2
62+
}
63+
64+
; CHECK: ; Function get_fptoui_flag : 0x00000044
65+
define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
66+
%2 = fcmp ugt double %0, 5.000000e+00
67+
br i1 %2, label %6, label %3
68+
69+
3: ; preds = %1
70+
%4 = fptoui double %0 to i64
71+
%5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4)
72+
br label %9
73+
74+
6: ; preds = %1
75+
%7 = fptoui double %0 to i32
76+
%8 = tail call double @call_get_uitofp_flag(i32 noundef %7)
77+
br label %9
78+
79+
9: ; preds = %6, %3
80+
%10 = phi double [ %5, %3 ], [ %8, %6 ]
81+
ret double %10
82+
}
83+
84+
; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044
85+
define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 {
86+
%2 = icmp ult i64 %0, 6
87+
br i1 %2, label %3, label %7
88+
89+
3: ; preds = %1
90+
%4 = add nuw nsw i64 %0, 1
91+
%5 = uitofp i64 %4 to double
92+
%6 = tail call double @call_get_fptoui_flag(double noundef %5)
93+
br label %10
94+
95+
7: ; preds = %1
96+
%8 = trunc i64 %0 to i32
97+
%9 = tail call double @call_get_sitofp_flag(i32 noundef %8)
98+
br label %10
99+
100+
10: ; preds = %7, %3
101+
%11 = phi double [ %6, %3 ], [ %9, %7 ]
102+
ret double %11
103+
}
104+
105+
; CHECK: ; Function get_no_flags : 0x00000000
106+
define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 {
107+
%2 = mul nsw i32 %0, %0
108+
ret i32 %2
109+
}
110+
111+
; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044
112+
define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
113+
%2 = icmp eq i32 %0, 0
114+
br i1 %2, label %5, label %3
115+
116+
3: ; preds = %1
117+
%4 = mul nsw i32 %0, %0
118+
br label %7
119+
120+
5: ; preds = %1
121+
%6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0)
122+
br label %7
123+
124+
7: ; preds = %5, %3
125+
%8 = phi i32 [ %4, %3 ], [ 0, %5 ]
126+
ret i32 %8
127+
}
128+
129+
; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044
130+
define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
131+
%2 = icmp sgt i32 %0, 5
132+
br i1 %2, label %3, label %6
133+
134+
3: ; preds = %1
135+
%4 = tail call i32 @get_no_flags(i32 noundef %0)
136+
%5 = sitofp i32 %4 to double
137+
br label %9
138+
139+
6: ; preds = %1
140+
%7 = tail call double @get_all_doubles_flags(i32 noundef %0)
141+
%8 = fdiv double %7, 3.000000e+00
142+
br label %9
143+
144+
9: ; preds = %6, %3
145+
%10 = phi double [ %5, %3 ], [ %8, %6 ]
146+
ret double %10
147+
}
148+
149+
; CHECK: ; Function get_all_doubles_flags : 0x00000044
150+
define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 {
151+
%2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0)
152+
%3 = icmp eq i32 %2, 0
153+
%4 = select i1 %3, double 1.000000e+01, double 1.000000e+02
154+
ret double %4
155+
}
156+
157+
; CHECK: ; Function main : 0x00000044
158+
define i32 @main() local_unnamed_addr #0 {
159+
%1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00)
160+
%2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4)
161+
%3 = fadd double %1, %2
162+
%4 = fcmp ogt double %3, 0.000000e+00
163+
%5 = zext i1 %4 to i32
164+
ret i32 %5
165+
}
166+
167+
attributes #0 = { convergent norecurse nounwind "hlsl.export"}

0 commit comments

Comments
 (0)