Skip to content

[DirectX] Propagate shader flags mask of callees to callers #118306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 14, 2025

Conversation

bharadwajy
Copy link
Contributor

@bharadwajy bharadwajy commented Dec 2, 2024

Propagate shader flags mask of callees to callers.

Add test to verify propagation of shader flags

Addresses #112949

@llvmbot
Copy link
Member

llvmbot commented Dec 2, 2024

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes

Propagate shader flags mask of callees to callers.

Add test to verify propagation of shader flags


Full diff: https://github.com/llvm/llvm-project/pull/118306.diff

4 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+39-4)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.h (+2)
  • (modified) llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll (+7)
  • (added) llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll (+92)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
 }
 
 void ModuleShaderFlags::initialize(const Module &M) {
+  SmallVector<const Function *> WorkList;
   // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
+    if (!F.user_empty()) {
+      WorkList.push_back(&F);
+    }
     ComputedShaderFlags CSF;
     for (const auto &BB : F)
       for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
     CombinedSFMask.merge(CSF);
   }
   llvm::sort(FunctionFlags);
+  // Propagate shader flag mask of functions to their callers.
+  while (!WorkList.empty()) {
+    const Function *Func = WorkList.pop_back_val();
+    if (!Func->user_empty()) {
+      ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+      // Update mask of callers with that of Func
+      for (const auto User : Func->users()) {
+        if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+          const Function *Caller = CI->getParent()->getParent();
+          if (mergeFunctionShaderFlags(Caller, FuncSF))
+            WorkList.push_back(Caller);
+        }
+      }
+    }
+  }
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
   const auto Iter = llvm::lower_bound(
       FunctionFlags, Func,
       [](const std::pair<const Function *, ComputedShaderFlags> FSM,
          const Function *FindFunc) { return (FSM.first < FindFunc); });
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "No Shader Flags Mask exists for function");
-  return Iter->second;
+  return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+    const Function *Func, const ComputedShaderFlags NewSF) {
+  const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+  if ((FuncSFInfo->second & NewSF) != NewSF) {
+    const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+    return true;
+  }
+  return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+  return getFunctionShaderFlagInfo(Func)->second;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
   SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
+  auto getFunctionShaderFlagInfo(const Function *) const;
+  bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Shader Flags for Module Functions
 
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+  call void @test_uitofp_i64(i64 noundef 5)
+  ret void
+}
+
+
 ; CHECK: ; Function test_fdiv_double : 0x00000044
 define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = sitofp i32 %0 to double
+  ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n6(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = uitofp i32 %0 to double
+  ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n7(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp ult i64 %0, 6
+  br i1 %2, label %3, label %7
+
+3:                                                ; preds = %1
+  %4 = add nuw nsw i64 %0, 1
+  %5 = uitofp i64 %4 to double
+  %6 = tail call double @call_n1(double noundef %5)
+  br label %10
+
+7:                                                ; preds = %1
+  %8 = trunc i64 %0 to i32
+  %9 = tail call double @call_n4(i32 noundef %8)
+  br label %10
+
+10:                                               ; preds = %7, %3
+  %11 = phi double [ %6, %3 ], [ %9, %7 ]
+  ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+  %2 = fcmp ugt double %0, 5.000000e+00
+  br i1 %2, label %6, label %3
+
+3:                                                ; preds = %1
+  %4 = fptoui double %0 to i64
+  %5 = tail call double @call_n2(i64 noundef %4)
+  br label %9
+
+6:                                                ; preds = %1
+  %7 = fptoui double %0 to i32
+  %8 = tail call double @call_n5(i32 noundef %7)
+  br label %9
+
+9:                                                ; preds = %6, %3
+  %10 = phi double [ %5, %3 ], [ %8, %6 ]
+  ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+  %2 = fdiv double %0, 3.000000e+00
+  ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+  %1 = tail call double @call_n1(double noundef 1.000000e+00)
+  %2 = tail call double @call_n3(double noundef %1)
+  ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}

@bharadwajy bharadwajy linked an issue Dec 2, 2024 that may be closed by this pull request
Comment on lines 119 to 120
if ((FuncSFInfo->second & NewSF) != NewSF) {
const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This presumes that the merge is a simple bitwise OR operation and you can use this & to detect a change. It also uses & operator on the ComputedShaderFlags type, which I assume works by triggering an implicit cast to uint64_t. I believe it would be cleaner to allow for a copy of the ComputedShaderFlags and to compare before and after merge to detect a change.

As a side note, we probably don't need to have all the shader flags defined as bitfield entries with one-by-one translation from bitfields to uint64_t. I certainly get that bitfields don't have a guaranteed layout in C++, but we could instead use an explicit uint64_t field with an enum defining the (DxilModule) bit shift for each flag (plus generated set/get accessors). If we did that, operations such as this would be more efficient, since it doesn't have to go through multiple translations between individual bits and combined uint64_t values each time. Copying and merging would be made trivial as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very much an abuse of the user-defined conversion operator... maybe we should remove that to prevent future abuses.

It seems like a better approach here since you're looking to compare equality is to just implement operator==.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait... I'm clearly not thinking straight because that last comment doesn't fully make sense. I do think we should probably get rid of the implicit conversion operator. Clearly operator== isn't the right approach.

For flag merging, we maybe don't even need to check it conditionally though. This should all be really well optimized straight-line code to merge flags, so the branch is probably unnecessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed merge() to operate on ComputeShaderFlags type thereby not relying on implicit cast to uint64_t.

Updated mergeFunctionShaderFlags() with the check for shader flag value change.

However, testing for equality still implicitly casts to uint64_t. Getting rid of implicit conversion operator requires changes to other access functions such as getFeatureFlags() etc., that use the conversion. I'll plan to delete implicit conversion operator in a follow-on PR. I think this implies that operator== would need to be implemented. Created an issue for this.

@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
auto getFunctionShaderFlagInfo(const Function *) const;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM has an "almost never auto" policy. We should only use auto where it makes code easier to read and where types are obvious from other context. That is not the case here.

see: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function is deleted in commit that addresses the PR feedback..

/// Return true if mask of Func is changed, else false.
bool ModuleShaderFlags::mergeFunctionShaderFlags(
const Function *Func, const ComputedShaderFlags NewSF) {
const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written, this is probably a const local variable (copy)... which you then const_cast below. This is a prime example of why auto should be used judiciously.

Copy link
Contributor Author

@bharadwajy bharadwajy Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written, this is probably a const local variable (copy)... which you then const_cast below. This is a prime example of why auto should be used judiciously.

Yeah, the root cause for the use of const_cast and auto return type stems from the desire to define the functionality to search through the FunctionFlags in a single place. The idea is to use the same function (which is getFunctionShaderFlagsInfo()) to (a) query for the shader flags of a Function * and (b) to locate the std::pair<> for merge/update the shader flags of a given Function *.

Will try out a different refactoring while keeping that goal intact.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to distinguish search result usage in the context of a "getter" vs "setter" (merge) to eliminate the need to use const_cast.

Comment on lines 56 to 82
if (!F.user_empty()) {
WorkList.push_back(&F);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (!F.user_empty()) {
WorkList.push_back(&F);
}
if (!F.user_empty())
WorkList.push_back(&F);

see: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed per suggested LLVM coding style.

Comment on lines 72 to 122
if (!Func->user_empty()) {
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
// Update mask of callers with that of Func
for (const auto User : Func->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
const Function *Caller = CI->getParent()->getParent();
if (mergeFunctionShaderFlags(Caller, FuncSF))
WorkList.push_back(Caller);
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few different style nits here:

Suggested change
if (!Func->user_empty()) {
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
// Update mask of callers with that of Func
for (const auto User : Func->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
const Function *Caller = CI->getParent()->getParent();
if (mergeFunctionShaderFlags(Caller, FuncSF))
WorkList.push_back(Caller);
}
}
}
if (Func->user_empty())
continue;
const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
// Update mask of callers with that of Func
for (const auto *User : Func->users()) {
const CallInst *CI = dyn_cast<CallInst>(User);
if (!CI)
continue;
const Function *Caller = CI->getParent()->getParent();
if (mergeFunctionShaderFlags(Caller, FuncSF))
WorkList.push_back(Caller);
}

see:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed per suggested LLVM coding style.

[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
const std::pair<const Function *, ComputedShaderFlags> *Iter =
llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be storing pairs sorted by pointer and doing searches on a vector. A DenseMap will be algorithmically more efficient and is the correct data structure for the use case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be storing pairs sorted by pointer and doing searches on a vector. A DenseMap will be algorithmically more efficient and is the correct data structure for the use case.

Collection of shader flags in the analysis pass and subsequent queries in later passes follows the "distinct pattern" outlined in LLVM Programmer's Manual. Hence the choice of SmallVector as opposed to other Map-like containers such as DenseMap. Other considerations (including those related to memory usage) listed in DenseMap - with seemingly higher cost for desired functionality - also influenced the choice of SmallVector. As for the sort order, since the Function pointers are unique per compilation invocation and just need to be sorted in some order during such an invocation, compareShaderFlagsInfo explicitly compares function pointers (although default comparator of std::pair does the same).

That said, I am open to any strong recommendation to change it to DenseMap for reasons not considered above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes made to use DenseMap instead of Smallvector as suggested.

@@ -117,7 +156,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
auto SFMask = FlagsInfo.getFunctionFlags(&F);
const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but FWIW, since sizeof(ComputedShaderFlags) == 8, the cost of a copy here is negligible since the address being copied for the reference will also be 64-bits in all the cases we care about.

Comment on lines 78 to 79
if (mergeFunctionShaderFlags(Caller, FuncSF))
WorkList.push_back(Caller);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is very inefficient. Consider a chain of functions, f1, f2, f3..., where f1 has no flags, f2 has one flag, f3 has another flag, etc. In this case we might process f1, then add it to the worklist again while processing f2, process it again, and then add both f2 to the worklist while processing f3, and then add f1 to be processes again while processing f2, and so on.

The correct way to do this type of thing is to walk the call graph in post-order. One way to do this is by switching this pass to a CallGraphSCCPass (old PM) / CGSCCPass (new PM).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, since this is an analysis and it's used by module passes we won't be able to simply convert to an SCC pass. As is, we may need to construct the call graph manually to do this in the more efficient way. It's probably fine to just construct a CallGraph and use SCCIterator to process the functions in post order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes made to build a CallGraph and use scc_iterator to create and propagate shader flags masks for module functions.

Comment on lines 158 to 161
auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
Iter->second.merge(SF);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:

Suggested change
auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
Iter->second.merge(SF);
assert(FunctionFlags.contains(Func) && "Function does not have shader flags.");
FunctionFlags[Func].merge(SF);

Since this can be reduced to a single line of code, it might not actually be worth having a method for it.

What are the conditions where a function might not have shader flags? Can it be considered valid to just treat any function that doesn't have flags as having no flags?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the conditions where a function might not have shader flags?

AFAIK, in a well-formed module, a function that is not a declaration, has an implementation associated. As a result, it must have a shader flags mask (with zero or a non-zero value) associated with it. Hence, a function that has no shader flags associated with it is not in the module. Such situation is considered a bug in the analysis pass or in the transformation pass that queries for shader flags of such (an unknown) function. Hence the asserts in mergeDunctionShaderFlags() and getFunctionFlags() where, in addition, [] is not used since it inserts a default initialized shader flag mask.

Can it be considered valid to just treat any function that doesn't have flags as having no flags?

By treating a function with non-existent shader flags in the shader flags map as having no flags (and not asserting), it appears that consistency between other passes and information collected by shader flag analysis pass would be broken - at a minimum. In other words, shader flag analysis information is providing incorrect information about a function whose existence it has no information about.

Let me know if I missed something in any of my assumptions about module functions.

assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
"Get Shader Flags : No Shader Flags Mask exists for function");
return Iter->second;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually be safer to return FunctionFlags[Func] instead of using Find. Then if it fails, the returned result is a default initialized flags structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually be safer to return FunctionFlags[Func] instead of using Find. Then if it fails, the returned result is a default initialized flags structure.

Returning default initialized flags structure value for F whose shader flags mask does not exist would amount to providing incorrect information about a function whose existence it has no information about and in addition inserting it to the DenseMap. Such functionality is not the intent of getFunctionFlags(). Hence the usage of find and assertion while getting shader flags for a Function *F.

Comment on lines 97 to 98
// Propagate Shader Flag Masks to callers with another post-order call graph
// walk
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be simpler to propagate the mask to callers in the first loop? Since we're iterating in SCC order, the loop over instructions above should just be able to query the flags for any call it encounters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be simpler to propagate the mask to callers in the first loop? Since we're iterating in SCC order, the loop over instructions above should just be able to query the flags for any call it encounters.

Changes made to compute shader flags mask of SCC functions and propagate them to their callers into one loop.

// Insert shader flag mask for function F
FunctionFlags.insert({F, CSF});
// Update combined shader flags mask for all functions of the module
CombinedSFMask.merge(CSF);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have to calculate SCCSF anyways, we can hoist the merge of the combined mask out of this loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have to calculate SCCSF anyways, we can hoist the merge of the combined mask out of this loop

Changed.

Function *F = CGN->getFunction();
if (!F)
continue;
mergeFunctionShaderFlags(F, SCCSF);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we defer inserting into the function flags map in the loop above and just calculate SCCSF there, and also remove the CurSCC.size() < 2 check, then we can avoid the extra merge here and simply insert the function flags for the SCC for each function at this point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we defer inserting into the function flags map in the loop above and just calculate SCCSF there, and also remove the CurSCC.size() < 2 check, then we can avoid the extra merge here and simply insert the function flags for the SCC for each function at this point.

Changes made to compute shader flags mask of SCC functions and propagate them to their callers into one loop.

@bharadwajy bharadwajy requested a review from bogner January 3, 2025 20:22
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more comments on the algorithm. Also we should probably have a test that has an SCC of size greater than one.

Comment on lines 97 to 106
// Propagate Shader Flag Masks to callers of F
for (const auto User : F->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
const Function *Caller = CI->getParent()->getParent();
if (FunctionFlags.contains(Caller))
FunctionFlags[Caller].merge(SCCSF);
else
FunctionFlags[Caller] = SCCSF;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need a loop to propagate flags to callers. Instead, we should handle this in updateFunctionFlags with something like:

  if (auto *CI = dyn_cast<CallInst>(&I))
    CSF.merge(FunctionFlags[CI->getCalledFunction()]);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need a loop to propagate flags to callers. Instead, we should handle this in updateFunctionFlags with something like:

  if (auto *CI = dyn_cast<CallInst>(&I))
    CSF.merge(FunctionFlags[CI->getCalledFunction()]);

Moved functionality to updateFunctionFlags().

Comment on lines 93 to 96
if (FunctionFlags.contains(F))
FunctionFlags[F].merge(SCCSF);
else
FunctionFlags[F] = SCCSF;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant. FunctionFlags[F].merge(SCCSF) will do the right thing if the function doesn't yet have a mask.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant. FunctionFlags[F].merge(SCCSF) will do the right thing if the function doesn't yet have a mask.

Changed.

Comment on lines 88 to 89
if (!F)
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably have the exact same checks here as the loop above (ie, skip both external nodes and declarations). It may be arguably more maintainable to just fill a SmallVector<Function *> with the functions we handle in the SCC and loop over that instead of looping over CurSCC again and having to repeat the checks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably have the exact same checks here as the loop above (ie, skip both external nodes and declarations). It may be arguably more maintainable to just fill a SmallVector<Function *> with the functions we handle in the SCC and loop over that instead of looping over CurSCC again and having to repeat the checks.

Changed to use Smallvector<Function *> as suggested.

Comment on lines 140 to 145
void ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
ComputedShaderFlags SF) {
assert(FunctionFlags.contains(Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
FunctionFlags[Func].merge(SF);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't appear to be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't appear to be used.

Deleted unused function mergeFunctionShaderFlags().

@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions

;CHECK: ; Function top_level : 0x00000044
define void @top_level() #0 {
call void @test_uitofp_i64(i64 noundef 5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no void @test_uitofp_i64, and it isn't clear to me what noundef is doing here. I think you meant call double @test_uitofp_i64(i64 5)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no void @test_uitofp_i64, and it isn't clear to me what noundef is doing here. I think you meant call double @test_uitofp_i64(i64 5)

Changed.

@bharadwajy
Copy link
Contributor Author

A few more comments on the algorithm. Also we should probably have a test that has an SCC of size greater than one.

Call Graph of code in propagate-function-flags-test.ll has an SCC with size 3. Let me know if you had some other test scenario in mind. Thanks!

@bharadwajy bharadwajy requested a review from bogner January 7, 2025 16:11
@bogner
Copy link
Contributor

bogner commented Jan 9, 2025

A few more comments on the algorithm. Also we should probably have a test that has an SCC of size greater than one.

Call Graph of code in propagate-function-flags-test.ll has an SCC with size 3. Let me know if you had some other test scenario in mind. Thanks!

Ah, that's true, but it doesn't seem to be a very good test case. Consider:

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 0392d3c84735..01bf300a7f02 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -97,19 +97,12 @@ void ModuleShaderFlags::initialize(Module &M) {
       // Update combined shader flags mask for all functions in this SCC
       SCCSF.merge(CSF);

-      CurSCCFuncs.push_back(F);
+      FunctionFlags[F].merge(SCCSF);
     }

     // Update combined shader flags mask for all functions of the module
     CombinedSFMask.merge(SCCSF);

-    // Shader flags mask of each of the functions in an SCC of the call graph is
-    // the union of all functions in the SCC. Update shader flags masks of
-    // functions in CurSCC accordingly. This is trivially true if SCC contains
-    // one function.
-    for (Function *F : CurSCCFuncs)
-      // Merge SCCSF with that of F
-      FunctionFlags[F].merge(SCCSF);
   }
 }

With this, propagate-function-flags-test.ll still passes (perhaps by coincidence). So it doesn't seem to be a very good test of how we're handling SCCs with more than one function.

@bharadwajy
Copy link
Contributor Author

A few more comments on the algorithm. Also we should probably have a test that has an SCC of size greater than one.

Call Graph of code in propagate-function-flags-test.ll has an SCC with size 3. Let me know if you had some other test scenario in mind. Thanks!

Ah, that's true, but it doesn't seem to be a very good test case. Consider:

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 0392d3c84735..01bf300a7f02 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -97,19 +97,12 @@ void ModuleShaderFlags::initialize(Module &M) {
       // Update combined shader flags mask for all functions in this SCC
       SCCSF.merge(CSF);

-      CurSCCFuncs.push_back(F);
+      FunctionFlags[F].merge(SCCSF);
     }

     // Update combined shader flags mask for all functions of the module
     CombinedSFMask.merge(SCCSF);

-    // Shader flags mask of each of the functions in an SCC of the call graph is
-    // the union of all functions in the SCC. Update shader flags masks of
-    // functions in CurSCC accordingly. This is trivially true if SCC contains
-    // one function.
-    for (Function *F : CurSCCFuncs)
-      // Merge SCCSF with that of F
-      FunctionFlags[F].merge(SCCSF);
   }
 }

With this, propagate-function-flags-test.ll still passes (perhaps by coincidence). So it doesn't seem to be a very good test of how we're handling SCCs with more than one function.

Expanded propagate-function-flags.ll to test the case highlighted above.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine, though the propagate-function-flags test is pretty inscrutable. It would be nice to use function names that make it clearer what's being tested, and probably have some comments about what the CFG looks like, otherwise this test will be more or less impossible to update in the future.

Add tests to verify propagation of shader flags
@bharadwajy bharadwajy force-pushed the shader-flags/propagate-shader-flags branch from 805810e to b6dfe53 Compare January 14, 2025 14:04
@bharadwajy
Copy link
Contributor Author

This looks fine, though the propagate-function-flags test is pretty inscrutable. It would be nice to use function names that make it clearer what's being tested, and probably have some comments about what the CFG looks like, otherwise this test will be more or less impossible to update in the future.

Updated the test as suggested. Thanks!

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments. Take them or leave them.

Comment on lines 80 to 82
if (FunctionFlags.contains(CF)) {
CSF.merge(FunctionFlags[CF]);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (FunctionFlags.contains(CF)) {
CSF.merge(FunctionFlags[CF]);
}
if (FunctionFlags.contains(CF))
CSF.merge(FunctionFlags[CF]);

see: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

continue;

if (F->isDeclaration()) {
assert(!F->getName().starts_with("dx.op.") &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be report_fatal_error instead of an assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be report_fatal_error instead of an assert?

Possibly... however, this change is a result of merging current upstream version of the file. My thought (admittedly not too strong) is to not change as part of this PR.

@bharadwajy bharadwajy merged commit a4b7a2d into llvm:main Jan 14, 2025
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[DirectX] Propagate shader flags mask for entry functions
5 participants