-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[DirectX] Propagate shader flags mask of callees to callers #118306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DirectX] Propagate shader flags mask of callees to callers #118306
Conversation
@llvm/pr-subscribers-backend-directx Author: S. Bharadwaj Yadavalli (bharadwajy) ChangesPropagate shader flags mask of callees to callers. Add test to verify propagation of shader flags Full diff: https://github.com/llvm/llvm-project/pull/118306.diff 4 Files Affected:
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
}
void ModuleShaderFlags::initialize(const Module &M) {
+ SmallVector<const Function *> WorkList;
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
+ if (!F.user_empty()) {
+ WorkList.push_back(&F);
+ }
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
+ // Propagate shader flag mask of functions to their callers.
+ while (!WorkList.empty()) {
+ const Function *Func = WorkList.pop_back_val();
+ if (!Func->user_empty()) {
+ ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
+ }
+ }
+ }
+ }
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
const auto Iter = llvm::lower_bound(
FunctionFlags, Func,
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
- return Iter->second;
+ return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+ const Function *Func, const ComputedShaderFlags NewSF) {
+ const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+ if ((FuncSFInfo->second & NewSF) != NewSF) {
+ const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+ return true;
+ }
+ return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+ return getFunctionShaderFlagInfo(Func)->second;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
+ auto getFunctionShaderFlagInfo(const Function *) const;
+ bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+ call void @test_uitofp_i64(i64 noundef 5)
+ ret void
+}
+
+
; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ; Double-precision floating point
+; CHECK-NEXT: ; Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = sitofp i32 %0 to double
+ ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n6(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = uitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n7(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp ult i64 %0, 6
+ br i1 %2, label %3, label %7
+
+3: ; preds = %1
+ %4 = add nuw nsw i64 %0, 1
+ %5 = uitofp i64 %4 to double
+ %6 = tail call double @call_n1(double noundef %5)
+ br label %10
+
+7: ; preds = %1
+ %8 = trunc i64 %0 to i32
+ %9 = tail call double @call_n4(i32 noundef %8)
+ br label %10
+
+10: ; preds = %7, %3
+ %11 = phi double [ %6, %3 ], [ %9, %7 ]
+ ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+ %2 = fcmp ugt double %0, 5.000000e+00
+ br i1 %2, label %6, label %3
+
+3: ; preds = %1
+ %4 = fptoui double %0 to i64
+ %5 = tail call double @call_n2(i64 noundef %4)
+ br label %9
+
+6: ; preds = %1
+ %7 = fptoui double %0 to i32
+ %8 = tail call double @call_n5(i32 noundef %7)
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+ %2 = fdiv double %0, 3.000000e+00
+ ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+ %1 = tail call double @call_n1(double noundef 1.000000e+00)
+ %2 = tail call double @call_n3(double noundef %1)
+ ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
|
if ((FuncSFInfo->second & NewSF) != NewSF) { | ||
const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This presumes that the merge is a simple bitwise OR operation and you can use this &
to detect a change. It also uses &
operator on the ComputedShaderFlags
type, which I assume works by triggering an implicit cast to uint64_t
. I believe it would be cleaner to allow for a copy of the ComputedShaderFlags
and to compare before and after merge to detect a change.
As a side note, we probably don't need to have all the shader flags defined as bitfield entries with one-by-one translation from bitfields to uint64_t. I certainly get that bitfields don't have a guaranteed layout in C++, but we could instead use an explicit uint64_t
field with an enum defining the (DxilModule) bit shift for each flag (plus generated set/get accessors). If we did that, operations such as this would be more efficient, since it doesn't have to go through multiple translations between individual bits and combined uint64_t
values each time. Copying and merging would be made trivial as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very much an abuse of the user-defined conversion operator... maybe we should remove that to prevent future abuses.
It seems like a better approach here since you're looking to compare equality is to just implement operator==
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait... I'm clearly not thinking straight because that last comment doesn't fully make sense. I do think we should probably get rid of the implicit conversion operator. Clearly operator==
isn't the right approach.
For flag merging, we maybe don't even need to check it conditionally though. This should all be really well optimized straight-line code to merge flags, so the branch is probably unnecessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed merge()
to operate on ComputeShaderFlags
type thereby not relying on implicit cast to uint64_t
.
Updated mergeFunctionShaderFlags()
with the check for shader flag value change.
However, testing for equality still implicitly casts to uint64_t
. Getting rid of implicit conversion operator requires changes to other access functions such as getFeatureFlags()
etc., that use the conversion. I'll plan to delete implicit conversion operator in a follow-on PR. I think this implies that operator==
would need to be implemented. Created an issue for this.
@@ -95,6 +95,8 @@ struct ModuleShaderFlags { | |||
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags; | |||
/// Combined Shader Flag Mask of all functions of the module | |||
ComputedShaderFlags CombinedSFMask{}; | |||
auto getFunctionShaderFlagInfo(const Function *) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM has an "almost never auto
" policy. We should only use auto
where it makes code easier to read and where types are obvious from other context. That is not the case here.
see: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function is deleted in commit that addresses the PR feedback..
/// Return true if mask of Func is changed, else false. | ||
bool ModuleShaderFlags::mergeFunctionShaderFlags( | ||
const Function *Func, const ComputedShaderFlags NewSF) { | ||
const auto FuncSFInfo = getFunctionShaderFlagInfo(Func); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As written, this is probably a const
local variable (copy)... which you then const_cast
below. This is a prime example of why auto
should be used judiciously.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As written, this is probably a
const
local variable (copy)... which you thenconst_cast
below. This is a prime example of whyauto
should be used judiciously.
Yeah, the root cause for the use of const_cast
and auto
return type stems from the desire to define the functionality to search through the FunctionFlags
in a single place. The idea is to use the same function (which is getFunctionShaderFlagsInfo()
) to (a) query for the shader flags of a Function *
and (b) to locate the std::pair<>
for merge/update the shader flags of a given Function *
.
Will try out a different refactoring while keeping that goal intact.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactored to distinguish search result usage in the context of a "getter" vs "setter" (merge) to eliminate the need to use const_cast
.
if (!F.user_empty()) { | ||
WorkList.push_back(&F); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
if (!F.user_empty()) { | |
WorkList.push_back(&F); | |
} | |
if (!F.user_empty()) | |
WorkList.push_back(&F); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed per suggested LLVM coding style.
if (!Func->user_empty()) { | ||
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func); | ||
// Update mask of callers with that of Func | ||
for (const auto User : Func->users()) { | ||
if (const CallInst *CI = dyn_cast<CallInst>(User)) { | ||
const Function *Caller = CI->getParent()->getParent(); | ||
if (mergeFunctionShaderFlags(Caller, FuncSF)) | ||
WorkList.push_back(Caller); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few different style nits here:
if (!Func->user_empty()) { | |
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func); | |
// Update mask of callers with that of Func | |
for (const auto User : Func->users()) { | |
if (const CallInst *CI = dyn_cast<CallInst>(User)) { | |
const Function *Caller = CI->getParent()->getParent(); | |
if (mergeFunctionShaderFlags(Caller, FuncSF)) | |
WorkList.push_back(Caller); | |
} | |
} | |
} | |
if (Func->user_empty()) | |
continue; | |
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func); | |
// Update mask of callers with that of Func | |
for (const auto *User : Func->users()) { | |
const CallInst *CI = dyn_cast<CallInst>(User); | |
if (!CI) | |
continue; | |
const Function *Caller = CI->getParent()->getParent(); | |
if (mergeFunctionShaderFlags(Caller, FuncSF)) | |
WorkList.push_back(Caller); | |
} |
see:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed per suggested LLVM coding style.
[](const std::pair<const Function *, ComputedShaderFlags> FSM, | ||
const Function *FindFunc) { return (FSM.first < FindFunc); }); | ||
const std::pair<const Function *, ComputedShaderFlags> *Iter = | ||
llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't be storing pairs sorted by pointer and doing searches on a vector. A DenseMap will be algorithmically more efficient and is the correct data structure for the use case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't be storing pairs sorted by pointer and doing searches on a vector. A DenseMap will be algorithmically more efficient and is the correct data structure for the use case.
Collection of shader flags in the analysis pass and subsequent queries in later passes follows the "distinct pattern" outlined in LLVM Programmer's Manual. Hence the choice of SmallVector
as opposed to other Map
-like containers such as DenseMap
. Other considerations (including those related to memory usage) listed in DenseMap - with seemingly higher cost for desired functionality - also influenced the choice of SmallVector
. As for the sort order, since the Function
pointers are unique per compilation invocation and just need to be sorted in some order during such an invocation, compareShaderFlagsInfo
explicitly compares function pointers (although default comparator of std::pair
does the same).
That said, I am open to any strong recommendation to change it to DenseMap
for reasons not considered above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes made to use DenseMap
instead of Smallvector
as suggested.
@@ -117,7 +156,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, | |||
for (const auto &F : M.getFunctionList()) { | |||
if (F.isDeclaration()) | |||
continue; | |||
auto SFMask = FlagsInfo.getFunctionFlags(&F); | |||
const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine, but FWIW, since sizeof(ComputedShaderFlags) == 8
, the cost of a copy here is negligible since the address being copied for the reference will also be 64-bits in all the cases we care about.
if (mergeFunctionShaderFlags(Caller, FuncSF)) | ||
WorkList.push_back(Caller); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach is very inefficient. Consider a chain of functions, f1
, f2
, f3
..., where f1
has no flags, f2
has one flag, f3
has another flag, etc. In this case we might process f1
, then add it to the worklist again while processing f2
, process it again, and then add both f2
to the worklist while processing f3
, and then add f1
to be processes again while processing f2
, and so on.
The correct way to do this type of thing is to walk the call graph in post-order. One way to do this is by switching this pass to a CallGraphSCCPass (old PM) / CGSCCPass (new PM).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, since this is an analysis and it's used by module passes we won't be able to simply convert to an SCC pass. As is, we may need to construct the call graph manually to do this in the more efficient way. It's probably fine to just construct a CallGraph and use SCCIterator to process the functions in post order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes made to build a CallGraph
and use scc_iterator
to create and propagate shader flags masks for module functions.
auto Iter = FunctionFlags.find(Func); | ||
assert((Iter != FunctionFlags.end() && Iter->first == Func) && | ||
"Merge Shader Flags : No Shader Flags Mask exists for function"); | ||
Iter->second.merge(SF); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about:
auto Iter = FunctionFlags.find(Func); | |
assert((Iter != FunctionFlags.end() && Iter->first == Func) && | |
"Merge Shader Flags : No Shader Flags Mask exists for function"); | |
Iter->second.merge(SF); | |
assert(FunctionFlags.contains(Func) && "Function does not have shader flags."); | |
FunctionFlags[Func].merge(SF); |
Since this can be reduced to a single line of code, it might not actually be worth having a method for it.
What are the conditions where a function might not have shader flags? Can it be considered valid to just treat any function that doesn't have flags as having no flags?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the conditions where a function might not have shader flags?
AFAIK, in a well-formed module, a function that is not a declaration, has an implementation associated. As a result, it must have a shader flags mask (with zero or a non-zero value) associated with it. Hence, a function that has no shader flags associated with it is not in the module. Such situation is considered a bug in the analysis pass or in the transformation pass that queries for shader flags of such (an unknown) function. Hence the assert
s in mergeDunctionShaderFlags()
and getFunctionFlags()
where, in addition, []
is not used since it inserts a default initialized shader flag mask.
Can it be considered valid to just treat any function that doesn't have flags as having no flags?
By treating a function with non-existent shader flags in the shader flags map as having no flags (and not asserting), it appears that consistency between other passes and information collected by shader flag analysis pass would be broken - at a minimum. In other words, shader flag analysis information is providing incorrect information about a function whose existence it has no information about.
Let me know if I missed something in any of my assumptions about module functions.
assert((Iter != FunctionFlags.end() && Iter->first == Func) && | ||
"No Shader Flags Mask exists for function"); | ||
"Get Shader Flags : No Shader Flags Mask exists for function"); | ||
return Iter->second; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would actually be safer to return FunctionFlags[Func]
instead of using Find. Then if it fails, the returned result is a default initialized flags structure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would actually be safer to return
FunctionFlags[Func]
instead of using Find. Then if it fails, the returned result is a default initialized flags structure.
Returning default initialized flags structure value for F
whose shader flags mask does not exist would amount to providing incorrect information about a function whose existence it has no information about and in addition inserting it to the DenseMap
. Such functionality is not the intent of getFunctionFlags()
. Hence the usage of find
and assertion while getting shader flags for a Function *F
.
// Propagate Shader Flag Masks to callers with another post-order call graph | ||
// walk |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be simpler to propagate the mask to callers in the first loop? Since we're iterating in SCC order, the loop over instructions above should just be able to query the flags for any call it encounters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be simpler to propagate the mask to callers in the first loop? Since we're iterating in SCC order, the loop over instructions above should just be able to query the flags for any call it encounters.
Changes made to compute shader flags mask of SCC functions and propagate them to their callers into one loop.
// Insert shader flag mask for function F | ||
FunctionFlags.insert({F, CSF}); | ||
// Update combined shader flags mask for all functions of the module | ||
CombinedSFMask.merge(CSF); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we have to calculate SCCSF anyways, we can hoist the merge of the combined mask out of this loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we have to calculate SCCSF anyways, we can hoist the merge of the combined mask out of this loop
Changed.
Function *F = CGN->getFunction(); | ||
if (!F) | ||
continue; | ||
mergeFunctionShaderFlags(F, SCCSF); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we defer inserting into the function flags map in the loop above and just calculate SCCSF there, and also remove the CurSCC.size() < 2
check, then we can avoid the extra merge here and simply insert the function flags for the SCC for each function at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we defer inserting into the function flags map in the loop above and just calculate SCCSF there, and also remove the
CurSCC.size() < 2
check, then we can avoid the extra merge here and simply insert the function flags for the SCC for each function at this point.
Changes made to compute shader flags mask of SCC functions and propagate them to their callers into one loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few more comments on the algorithm. Also we should probably have a test that has an SCC of size greater than one.
// Propagate Shader Flag Masks to callers of F | ||
for (const auto User : F->users()) { | ||
if (const CallInst *CI = dyn_cast<CallInst>(User)) { | ||
const Function *Caller = CI->getParent()->getParent(); | ||
if (FunctionFlags.contains(Caller)) | ||
FunctionFlags[Caller].merge(SCCSF); | ||
else | ||
FunctionFlags[Caller] = SCCSF; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't need a loop to propagate flags to callers. Instead, we should handle this in updateFunctionFlags
with something like:
if (auto *CI = dyn_cast<CallInst>(&I))
CSF.merge(FunctionFlags[CI->getCalledFunction()]);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't need a loop to propagate flags to callers. Instead, we should handle this in
updateFunctionFlags
with something like:if (auto *CI = dyn_cast<CallInst>(&I)) CSF.merge(FunctionFlags[CI->getCalledFunction()]);
Moved functionality to updateFunctionFlags()
.
if (FunctionFlags.contains(F)) | ||
FunctionFlags[F].merge(SCCSF); | ||
else | ||
FunctionFlags[F] = SCCSF; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant. FunctionFlags[F].merge(SCCSF)
will do the right thing if the function doesn't yet have a mask.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant.
FunctionFlags[F].merge(SCCSF)
will do the right thing if the function doesn't yet have a mask.
Changed.
if (!F) | ||
continue; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably have the exact same checks here as the loop above (ie, skip both external nodes and declarations). It may be arguably more maintainable to just fill a SmallVector<Function *>
with the functions we handle in the SCC and loop over that instead of looping over CurSCC again and having to repeat the checks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably have the exact same checks here as the loop above (ie, skip both external nodes and declarations). It may be arguably more maintainable to just fill a
SmallVector<Function *>
with the functions we handle in the SCC and loop over that instead of looping over CurSCC again and having to repeat the checks.
Changed to use Smallvector<Function *>
as suggested.
void ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func, | ||
ComputedShaderFlags SF) { | ||
assert(FunctionFlags.contains(Func) && | ||
"Merge Shader Flags : No Shader Flags Mask exists for function"); | ||
FunctionFlags[Func].merge(SF); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't appear to be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't appear to be used.
Deleted unused function mergeFunctionShaderFlags()
.
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library" | |||
; CHECK-NEXT: ; | |||
; CHECK-NEXT: ; Shader Flags for Module Functions | |||
|
|||
;CHECK: ; Function top_level : 0x00000044 | |||
define void @top_level() #0 { | |||
call void @test_uitofp_i64(i64 noundef 5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no void @test_uitofp_i64
, and it isn't clear to me what noundef is doing here. I think you meant call double @test_uitofp_i64(i64 5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no
void @test_uitofp_i64
, and it isn't clear to me what noundef is doing here. I think you meantcall double @test_uitofp_i64(i64 5)
Changed.
Call Graph of code in |
Ah, that's true, but it doesn't seem to be a very good test case. Consider: diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 0392d3c84735..01bf300a7f02 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -97,19 +97,12 @@ void ModuleShaderFlags::initialize(Module &M) {
// Update combined shader flags mask for all functions in this SCC
SCCSF.merge(CSF);
- CurSCCFuncs.push_back(F);
+ FunctionFlags[F].merge(SCCSF);
}
// Update combined shader flags mask for all functions of the module
CombinedSFMask.merge(SCCSF);
- // Shader flags mask of each of the functions in an SCC of the call graph is
- // the union of all functions in the SCC. Update shader flags masks of
- // functions in CurSCC accordingly. This is trivially true if SCC contains
- // one function.
- for (Function *F : CurSCCFuncs)
- // Merge SCCSF with that of F
- FunctionFlags[F].merge(SCCSF);
}
} With this, |
Expanded |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine, though the propagate-function-flags test is pretty inscrutable. It would be nice to use function names that make it clearer what's being tested, and probably have some comments about what the CFG looks like, otherwise this test will be more or less impossible to update in the future.
Add tests to verify propagation of shader flags
805810e
to
b6dfe53
Compare
Updated the test as suggested. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor comments. Take them or leave them.
if (FunctionFlags.contains(CF)) { | ||
CSF.merge(FunctionFlags[CF]); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
if (FunctionFlags.contains(CF)) { | |
CSF.merge(FunctionFlags[CF]); | |
} | |
if (FunctionFlags.contains(CF)) | |
CSF.merge(FunctionFlags[CF]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
continue; | ||
|
||
if (F->isDeclaration()) { | ||
assert(!F->getName().starts_with("dx.op.") && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be report_fatal_error
instead of an assert?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be
report_fatal_error
instead of an assert?
Possibly... however, this change is a result of merging current upstream version of the file. My thought (admittedly not too strong) is to not change as part of this PR.
Propagate shader flags mask of callees to callers.
Add test to verify propagation of shader flags
Addresses #112949