13
13
14
14
#include " DXILShaderFlags.h"
15
15
#include " DirectX.h"
16
- #include " llvm/ADT/STLExtras.h"
16
+ #include " llvm/ADT/SCCIterator.h"
17
+ #include " llvm/ADT/SmallVector.h"
18
+ #include " llvm/Analysis/CallGraph.h"
17
19
#include " llvm/Analysis/DXILResource.h"
18
20
#include " llvm/IR/Instruction.h"
21
+ #include " llvm/IR/Instructions.h"
19
22
#include " llvm/IR/IntrinsicInst.h"
20
23
#include " llvm/IR/Intrinsics.h"
21
24
#include " llvm/IR/IntrinsicsDirectX.h"
27
30
using namespace llvm ;
28
31
using namespace llvm ::dxil;
29
32
30
- static void updateFunctionFlags (ComputedShaderFlags &CSF, const Instruction &I,
31
- DXILResourceTypeMap &DRTM) {
33
+ // / Update the shader flags mask based on the given instruction.
34
+ // / \param CSF Shader flags mask to update.
35
+ // / \param I Instruction to check.
36
+ void ModuleShaderFlags::updateFunctionFlags (ComputedShaderFlags &CSF,
37
+ const Instruction &I,
38
+ DXILResourceTypeMap &DRTM) {
32
39
if (!CSF.Doubles )
33
40
CSF.Doubles = I.getType ()->isDoubleTy ();
34
41
35
42
if (!CSF.Doubles ) {
36
- for (Value *Op : I.operands ())
37
- CSF.Doubles |= Op->getType ()->isDoubleTy ();
43
+ for (const Value *Op : I.operands ()) {
44
+ if (Op->getType ()->isDoubleTy ()) {
45
+ CSF.Doubles = true ;
46
+ break ;
47
+ }
48
+ }
38
49
}
50
+
39
51
if (CSF.Doubles ) {
40
52
switch (I.getOpcode ()) {
41
53
case Instruction::FDiv:
42
54
case Instruction::UIToFP:
43
55
case Instruction::SIToFP:
44
56
case Instruction::FPToUI:
45
57
case Instruction::FPToSI:
46
- // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
47
- // https://github.com/llvm/llvm-project/issues/114554
48
58
CSF.DX11_1_DoubleExtensions = true ;
49
59
break ;
50
60
}
@@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
62
72
}
63
73
}
64
74
}
75
+ // Handle call instructions
76
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
77
+ const Function *CF = CI->getCalledFunction ();
78
+ // Merge-in shader flags mask of the called function in the current module
79
+ if (FunctionFlags.contains (CF))
80
+ CSF.merge (FunctionFlags[CF]);
81
+
82
+ // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
83
+ // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
84
+ }
65
85
}
66
86
67
- void ModuleShaderFlags::initialize (const Module &M, DXILResourceTypeMap &DRTM) {
68
-
69
- // Collect shader flags for each of the functions
70
- for (const auto &F : M.getFunctionList ()) {
71
- if (F.isDeclaration ()) {
72
- assert (!F.getName ().starts_with (" dx.op." ) &&
73
- " DXIL Shader Flag analysis should not be run post-lowering." );
74
- continue ;
87
+ // / Construct ModuleShaderFlags for module Module M
88
+ void ModuleShaderFlags::initialize (Module &M, DXILResourceTypeMap &DRTM) {
89
+ CallGraph CG (M);
90
+
91
+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
92
+ // of the call graph.
93
+ for (scc_iterator<CallGraph *> SCCI = scc_begin (&CG); !SCCI.isAtEnd ();
94
+ ++SCCI) {
95
+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
96
+
97
+ // Union of shader masks of all functions in CurSCC
98
+ ComputedShaderFlags SCCSF;
99
+ // List of functions in CurSCC that are neither external nor declarations
100
+ // and hence whose flags are collected
101
+ SmallVector<Function *> CurSCCFuncs;
102
+ for (CallGraphNode *CGN : CurSCC) {
103
+ Function *F = CGN->getFunction ();
104
+ if (!F)
105
+ continue ;
106
+
107
+ if (F->isDeclaration ()) {
108
+ assert (!F->getName ().starts_with (" dx.op." ) &&
109
+ " DXIL Shader Flag analysis should not be run post-lowering." );
110
+ continue ;
111
+ }
112
+
113
+ ComputedShaderFlags CSF;
114
+ for (const auto &BB : *F)
115
+ for (const auto &I : BB)
116
+ updateFunctionFlags (CSF, I, DRTM);
117
+ // Update combined shader flags mask for all functions in this SCC
118
+ SCCSF.merge (CSF);
119
+
120
+ CurSCCFuncs.push_back (F);
75
121
}
76
- ComputedShaderFlags CSF;
77
- for (const auto &BB : F)
78
- for (const auto &I : BB)
79
- updateFunctionFlags (CSF, I, DRTM);
80
- // Insert shader flag mask for function F
81
- FunctionFlags.push_back ({&F, CSF});
82
- // Update combined shader flags mask
83
- CombinedSFMask.merge (CSF);
122
+
123
+ // Update combined shader flags mask for all functions of the module
124
+ CombinedSFMask.merge (SCCSF);
125
+
126
+ // Shader flags mask of each of the functions in an SCC of the call graph is
127
+ // the union of all functions in the SCC. Update shader flags masks of
128
+ // functions in CurSCC accordingly. This is trivially true if SCC contains
129
+ // one function.
130
+ for (Function *F : CurSCCFuncs)
131
+ // Merge SCCSF with that of F
132
+ FunctionFlags[F].merge (SCCSF);
84
133
}
85
- llvm::sort (FunctionFlags);
86
134
}
87
135
88
136
void ComputedShaderFlags::print (raw_ostream &OS) const {
@@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
106
154
// / Return the shader flags mask of the specified function Func.
107
155
const ComputedShaderFlags &
108
156
ModuleShaderFlags::getFunctionFlags (const Function *Func) const {
109
- const auto Iter = llvm::lower_bound (
110
- FunctionFlags, Func,
111
- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
112
- const Function *FindFunc) { return (FSM.first < FindFunc); });
157
+ auto Iter = FunctionFlags.find (Func);
113
158
assert ((Iter != FunctionFlags.end () && Iter->first == Func) &&
114
- " No Shader Flags Mask exists for function" );
159
+ " Get Shader Flags : No Shader Flags Mask exists for function" );
115
160
return Iter->second ;
116
161
}
117
162
@@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
142
187
for (const auto &F : M.getFunctionList ()) {
143
188
if (F.isDeclaration ())
144
189
continue ;
145
- auto SFMask = FlagsInfo.getFunctionFlags (&F);
190
+ const ComputedShaderFlags & SFMask = FlagsInfo.getFunctionFlags (&F);
146
191
OS << formatv (" ; Function {0} : {1:x8}\n ;\n " , F.getName (),
147
192
(uint64_t )(SFMask));
148
193
}
0 commit comments