13
13
14
14
#include " DXILShaderFlags.h"
15
15
#include " DirectX.h"
16
+ #include " llvm/ADT/SCCIterator.h"
16
17
#include " llvm/ADT/STLExtras.h"
18
+ #include " llvm/ADT/SmallVector.h"
19
+ #include " llvm/Analysis/CallGraph.h"
17
20
#include " llvm/Analysis/DXILResource.h"
18
21
#include " llvm/IR/Instruction.h"
22
+ #include " llvm/IR/Instructions.h"
19
23
#include " llvm/IR/IntrinsicInst.h"
20
24
#include " llvm/IR/Intrinsics.h"
21
25
#include " llvm/IR/IntrinsicsDirectX.h"
27
31
using namespace llvm ;
28
32
using namespace llvm ::dxil;
29
33
30
- static void updateFunctionFlags (ComputedShaderFlags &CSF, const Instruction &I,
31
- DXILResourceTypeMap &DRTM) {
34
+ // / Update the shader flags mask based on the given instruction.
35
+ // / \param CSF Shader flags mask to update.
36
+ // / \param I Instruction to check.
37
+ void ModuleShaderFlags::updateFunctionFlags (ComputedShaderFlags &CSF,
38
+ const Instruction &I,
39
+ DXILResourceTypeMap &DRTM) {
32
40
if (!CSF.Doubles )
33
41
CSF.Doubles = I.getType ()->isDoubleTy ();
34
42
35
43
if (!CSF.Doubles ) {
36
- for (Value *Op : I.operands ())
37
- CSF.Doubles |= Op->getType ()->isDoubleTy ();
44
+ for (const Value *Op : I.operands ()) {
45
+ if (Op->getType ()->isDoubleTy ()) {
46
+ CSF.Doubles = true ;
47
+ break ;
48
+ }
49
+ }
38
50
}
51
+
39
52
if (CSF.Doubles ) {
40
53
switch (I.getOpcode ()) {
41
54
case Instruction::FDiv:
42
55
case Instruction::UIToFP:
43
56
case Instruction::SIToFP:
44
57
case Instruction::FPToUI:
45
58
case Instruction::FPToSI:
46
- // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
47
- // https://github.com/llvm/llvm-project/issues/114554
48
59
CSF.DX11_1_DoubleExtensions = true ;
49
60
break ;
50
61
}
@@ -62,27 +73,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
62
73
}
63
74
}
64
75
}
76
+ // Handle call instructions
77
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
78
+ const Function *CF = CI->getCalledFunction ();
79
+ // Merge-in shader flags mask of the called function in the current module
80
+ if (FunctionFlags.contains (CF)) {
81
+ CSF.merge (FunctionFlags[CF]);
82
+ }
83
+ // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
84
+ // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
85
+ }
65
86
}
66
87
67
- void ModuleShaderFlags::initialize (const Module &M, DXILResourceTypeMap &DRTM) {
68
-
69
- // Collect shader flags for each of the functions
70
- for (const auto &F : M.getFunctionList ()) {
71
- if (F.isDeclaration ()) {
72
- assert (!F.getName ().starts_with (" dx.op." ) &&
73
- " DXIL Shader Flag analysis should not be run post-lowering." );
74
- continue ;
88
+ // / Construct ModuleShaderFlags for module Module M
89
+ void ModuleShaderFlags::initialize (Module &M, DXILResourceTypeMap &DRTM) {
90
+ CallGraph CG (M);
91
+
92
+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
93
+ // of the call graph.
94
+ for (scc_iterator<CallGraph *> SCCI = scc_begin (&CG); !SCCI.isAtEnd ();
95
+ ++SCCI) {
96
+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
97
+
98
+ // Union of shader masks of all functions in CurSCC
99
+ ComputedShaderFlags SCCSF;
100
+ // List of functions in CurSCC that are neither external nor declarations
101
+ // and hence whose flags are collected
102
+ SmallVector<Function *> CurSCCFuncs;
103
+ for (CallGraphNode *CGN : CurSCC) {
104
+ Function *F = CGN->getFunction ();
105
+ if (!F)
106
+ continue ;
107
+
108
+ if (F->isDeclaration ()) {
109
+ assert (!F->getName ().starts_with (" dx.op." ) &&
110
+ " DXIL Shader Flag analysis should not be run post-lowering." );
111
+ continue ;
112
+ }
113
+
114
+ ComputedShaderFlags CSF;
115
+ for (const auto &BB : *F)
116
+ for (const auto &I : BB)
117
+ updateFunctionFlags (CSF, I, DRTM);
118
+ // Update combined shader flags mask for all functions in this SCC
119
+ SCCSF.merge (CSF);
120
+
121
+ CurSCCFuncs.push_back (F);
75
122
}
76
- ComputedShaderFlags CSF;
77
- for (const auto &BB : F)
78
- for (const auto &I : BB)
79
- updateFunctionFlags (CSF, I, DRTM);
80
- // Insert shader flag mask for function F
81
- FunctionFlags.push_back ({&F, CSF});
82
- // Update combined shader flags mask
83
- CombinedSFMask.merge (CSF);
123
+
124
+ // Update combined shader flags mask for all functions of the module
125
+ CombinedSFMask.merge (SCCSF);
126
+
127
+ // Shader flags mask of each of the functions in an SCC of the call graph is
128
+ // the union of all functions in the SCC. Update shader flags masks of
129
+ // functions in CurSCC accordingly. This is trivially true if SCC contains
130
+ // one function.
131
+ for (Function *F : CurSCCFuncs)
132
+ // Merge SCCSF with that of F
133
+ FunctionFlags[F].merge (SCCSF);
84
134
}
85
- llvm::sort (FunctionFlags);
86
135
}
87
136
88
137
void ComputedShaderFlags::print (raw_ostream &OS) const {
@@ -106,12 +155,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
106
155
// / Return the shader flags mask of the specified function Func.
107
156
const ComputedShaderFlags &
108
157
ModuleShaderFlags::getFunctionFlags (const Function *Func) const {
109
- const auto Iter = llvm::lower_bound (
110
- FunctionFlags, Func,
111
- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
112
- const Function *FindFunc) { return (FSM.first < FindFunc); });
158
+ auto Iter = FunctionFlags.find (Func);
113
159
assert ((Iter != FunctionFlags.end () && Iter->first == Func) &&
114
- " No Shader Flags Mask exists for function" );
160
+ " Get Shader Flags : No Shader Flags Mask exists for function" );
115
161
return Iter->second ;
116
162
}
117
163
@@ -142,7 +188,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
142
188
for (const auto &F : M.getFunctionList ()) {
143
189
if (F.isDeclaration ())
144
190
continue ;
145
- auto SFMask = FlagsInfo.getFunctionFlags (&F);
191
+ const ComputedShaderFlags & SFMask = FlagsInfo.getFunctionFlags (&F);
146
192
OS << formatv (" ; Function {0} : {1:x8}\n ;\n " , F.getName (),
147
193
(uint64_t )(SFMask));
148
194
}
0 commit comments