-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[DirectX] Propagate shader flags mask of callees to callers #118306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,12 @@ | |
|
||
#include "DXILShaderFlags.h" | ||
#include "DirectX.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SCCIterator.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Analysis/CallGraph.h" | ||
#include "llvm/Analysis/DXILResource.h" | ||
#include "llvm/IR/Instruction.h" | ||
#include "llvm/IR/Instructions.h" | ||
#include "llvm/IR/IntrinsicInst.h" | ||
#include "llvm/IR/Intrinsics.h" | ||
#include "llvm/IR/IntrinsicsDirectX.h" | ||
|
@@ -27,24 +30,31 @@ | |
using namespace llvm; | ||
using namespace llvm::dxil; | ||
|
||
static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, | ||
DXILResourceTypeMap &DRTM) { | ||
/// Update the shader flags mask based on the given instruction. | ||
/// \param CSF Shader flags mask to update. | ||
/// \param I Instruction to check. | ||
void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, | ||
const Instruction &I, | ||
DXILResourceTypeMap &DRTM) { | ||
if (!CSF.Doubles) | ||
CSF.Doubles = I.getType()->isDoubleTy(); | ||
|
||
if (!CSF.Doubles) { | ||
for (Value *Op : I.operands()) | ||
CSF.Doubles |= Op->getType()->isDoubleTy(); | ||
for (const Value *Op : I.operands()) { | ||
if (Op->getType()->isDoubleTy()) { | ||
CSF.Doubles = true; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
if (CSF.Doubles) { | ||
switch (I.getOpcode()) { | ||
case Instruction::FDiv: | ||
case Instruction::UIToFP: | ||
case Instruction::SIToFP: | ||
case Instruction::FPToUI: | ||
case Instruction::FPToSI: | ||
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma | ||
// https://github.com/llvm/llvm-project/issues/114554 | ||
CSF.DX11_1_DoubleExtensions = true; | ||
break; | ||
} | ||
|
@@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, | |
} | ||
} | ||
} | ||
// Handle call instructions | ||
if (auto *CI = dyn_cast<CallInst>(&I)) { | ||
const Function *CF = CI->getCalledFunction(); | ||
// Merge-in shader flags mask of the called function in the current module | ||
if (FunctionFlags.contains(CF)) | ||
CSF.merge(FunctionFlags[CF]); | ||
|
||
// TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic | ||
// DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 | ||
} | ||
} | ||
|
||
void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) { | ||
|
||
// Collect shader flags for each of the functions | ||
for (const auto &F : M.getFunctionList()) { | ||
if (F.isDeclaration()) { | ||
assert(!F.getName().starts_with("dx.op.") && | ||
"DXIL Shader Flag analysis should not be run post-lowering."); | ||
continue; | ||
/// Construct ModuleShaderFlags for module M | ||
void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) { | ||
CallGraph CG(M); | ||
|
||
// Compute Shader Flags Mask for all functions using post-order visit of SCC | ||
// of the call graph. | ||
for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd(); | ||
++SCCI) { | ||
const std::vector<CallGraphNode *> &CurSCC = *SCCI; | ||
|
||
// Union of shader masks of all functions in CurSCC | ||
ComputedShaderFlags SCCSF; | ||
// List of functions in CurSCC that are neither external nor declarations | ||
// and hence whose flags are collected | ||
SmallVector<Function *> CurSCCFuncs; | ||
for (CallGraphNode *CGN : CurSCC) { | ||
Function *F = CGN->getFunction(); | ||
if (!F) | ||
continue; | ||
|
||
if (F->isDeclaration()) { | ||
assert(!F->getName().starts_with("dx.op.") && | ||
"DXIL Shader Flag analysis should not be run post-lowering."); | ||
continue; | ||
} | ||
|
||
ComputedShaderFlags CSF; | ||
for (const auto &BB : *F) | ||
for (const auto &I : BB) | ||
updateFunctionFlags(CSF, I, DRTM); | ||
// Update combined shader flags mask for all functions in this SCC | ||
SCCSF.merge(CSF); | ||
|
||
CurSCCFuncs.push_back(F); | ||
} | ||
ComputedShaderFlags CSF; | ||
for (const auto &BB : F) | ||
for (const auto &I : BB) | ||
updateFunctionFlags(CSF, I, DRTM); | ||
// Insert shader flag mask for function F | ||
FunctionFlags.push_back({&F, CSF}); | ||
// Update combined shader flags mask | ||
CombinedSFMask.merge(CSF); | ||
|
||
// Update combined shader flags mask for all functions of the module | ||
CombinedSFMask.merge(SCCSF); | ||
|
||
// Shader flags mask of each of the functions in an SCC of the call graph is | ||
// the union of the masks of all functions in the SCC. Update shader flags | ||
// masks of functions in CurSCC accordingly. This is trivially true if the | ||
// SCC contains one function. | ||
for (Function *F : CurSCCFuncs) | ||
// Merge SCCSF with that of F | ||
FunctionFlags[F].merge(SCCSF); | ||
} | ||
llvm::sort(FunctionFlags); | ||
} | ||
|
||
void ComputedShaderFlags::print(raw_ostream &OS) const { | ||
|
@@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const { | |
/// Return the shader flags mask of the specified function Func. | ||
const ComputedShaderFlags & | ||
ModuleShaderFlags::getFunctionFlags(const Function *Func) const { | ||
const auto Iter = llvm::lower_bound( | ||
FunctionFlags, Func, | ||
[](const std::pair<const Function *, ComputedShaderFlags> FSM, | ||
const Function *FindFunc) { return (FSM.first < FindFunc); }); | ||
auto Iter = FunctionFlags.find(Func); | ||
assert((Iter != FunctionFlags.end() && Iter->first == Func) && | ||
"No Shader Flags Mask exists for function"); | ||
"Get Shader Flags : No Shader Flags Mask exists for function"); | ||
return Iter->second; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would actually be safer to return
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Returning default initialized flags structure value for |
||
} | ||
|
||
|
@@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, | |
for (const auto &F : M.getFunctionList()) { | ||
if (F.isDeclaration()) | ||
continue; | ||
auto SFMask = FlagsInfo.getFunctionFlags(&F); | ||
const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is fine, but FWIW, since |
||
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), | ||
(uint64_t)(SFMask)); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s | ||
|
||
target triple = "dxil-pc-shadermodel6.7-library" | ||
|
||
; CHECK: ; Combined Shader Flags for Module | ||
; CHECK-NEXT: ; Shader Flags Value: 0x00000044 | ||
; CHECK-NEXT: ; | ||
; CHECK-NEXT: ; Note: shader requires additional functionality: | ||
; CHECK-NEXT: ; Double-precision floating point | ||
; CHECK-NEXT: ; Double-precision extensions for 11.1 | ||
; CHECK-NEXT: ; Note: extra DXIL module flags: | ||
; CHECK-NEXT: ; | ||
; CHECK-NEXT: ; Shader Flags for Module Functions | ||
|
||
; Call Graph of test source | ||
; main -> [get_fptoui_flag, get_sitofp_fdiv_flag] | ||
; get_fptoui_flag -> [get_sitofp_uitofp_flag, call_get_uitofp_flag] | ||
; get_sitofp_uitofp_flag -> [call_get_fptoui_flag, call_get_sitofp_flag] | ||
; call_get_fptoui_flag -> [get_fptoui_flag] | ||
; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags] | ||
; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag] | ||
; call_get_sitofp_fdiv_flag -> [get_sitofp_fdiv_flag] | ||
; call_get_sitofp_flag -> [get_sitofp_flag] | ||
; call_get_uitofp_flag -> [get_uitofp_flag] | ||
; get_sitofp_flag -> [] | ||
; get_uitofp_flag -> [] | ||
; get_no_flags -> [] | ||
; | ||
; Strongly Connected Components in the CG | ||
; [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag] | ||
; [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag] | ||
|
||
; | ||
; CHECK: ; Function get_sitofp_flag : 0x00000044 | ||
; Leaf function: converts the i32 argument to double with sitofp and returns
; it. Makes no calls.
define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = sitofp i32 %0 to double
ret double %2
}
|
||
; CHECK: ; Function call_get_sitofp_flag : 0x00000044 | ||
; Thin wrapper: forwards its argument to @get_sitofp_flag and returns the
; result. It has no conversion or fdiv instruction of its own, so it
; presumably exercises propagation of the callee's flags to the caller.
define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_sitofp_flag(i32 noundef %0)
ret double %2
}
|
||
; CHECK: ; Function get_uitofp_flag : 0x00000044 | ||
; Leaf function: converts the i32 argument to double with uitofp and returns
; it. Makes no calls.
define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = uitofp i32 %0 to double
ret double %2
}
|
||
; CHECK: ; Function call_get_uitofp_flag : 0x00000044 | ||
; Thin wrapper: forwards its argument to @get_uitofp_flag and returns the
; result. It has no conversion or fdiv instruction of its own.
define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_uitofp_flag(i32 noundef %0)
ret double %2
}
|
||
; CHECK: ; Function call_get_fptoui_flag : 0x00000044 | ||
; Thin wrapper: forwards its double argument to @get_fptoui_flag and returns
; the result. This call closes the cycle
; get_fptoui_flag -> get_sitofp_uitofp_flag -> call_get_fptoui_flag.
define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_fptoui_flag(double noundef %0)
ret double %2
}
|
||
; CHECK: ; Function get_fptoui_flag : 0x00000044 | ||
; Branches on an unordered "greater than 5.0" compare of the argument.
;   false (%3): fptoui to i64, then calls @get_sitofp_uitofp_flag.
;   true  (%6): fptoui to i32, then calls @call_get_uitofp_flag.
; Returns whichever call result was produced, selected by the phi in %9.
define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
%2 = fcmp ugt double %0, 5.000000e+00
br i1 %2, label %6, label %3
|
||
3: ; preds = %1
%4 = fptoui double %0 to i64
%5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4)
br label %9
|
||
6: ; preds = %1
%7 = fptoui double %0 to i32
%8 = tail call double @call_get_uitofp_flag(i32 noundef %7)
br label %9
|
||
9: ; preds = %6, %3
%10 = phi double [ %5, %3 ], [ %8, %6 ]
ret double %10
}
|
||
; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044 | ||
; Branches on an unsigned "less than 6" compare of the i64 argument.
;   true  (%3): adds 1, converts with uitofp, then calls @call_get_fptoui_flag.
;   false (%7): truncates to i32, then calls @call_get_sitofp_flag.
; Returns the call result selected by the phi in %10. Note that the body
; contains a uitofp but no sitofp instruction; any sitofp is only reached
; through the call chain.
define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 {
%2 = icmp ult i64 %0, 6
br i1 %2, label %3, label %7
|
||
3: ; preds = %1
%4 = add nuw nsw i64 %0, 1
%5 = uitofp i64 %4 to double
%6 = tail call double @call_get_fptoui_flag(double noundef %5)
br label %10
|
||
7: ; preds = %1
%8 = trunc i64 %0 to i32
%9 = tail call double @call_get_sitofp_flag(i32 noundef %8)
br label %10
|
||
10: ; preds = %7, %3
%11 = phi double [ %6, %3 ], [ %9, %7 ]
ret double %11
}
|
||
; CHECK: ; Function get_no_flags : 0x00000000 | ||
; Leaf function: returns the square of its i32 argument. It uses no
; double-typed values and makes no calls.
define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 {
%2 = mul nsw i32 %0, %0
ret i32 %2
}
|
||
; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044 | ||
; Branches on whether the i32 argument is zero.
;   false (%3): computes the square of the argument.
;   true  (%5): calls @get_sitofp_fdiv_flag(0); the double result is unused.
; Returns the square or 0, selected by the phi in %7. This call closes the
; cycle get_sitofp_fdiv_flag -> get_all_doubles_flags ->
; call_get_sitofp_fdiv_flag.
define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = icmp eq i32 %0, 0
br i1 %2, label %5, label %3
|
||
3: ; preds = %1
%4 = mul nsw i32 %0, %0
br label %7
|
||
5: ; preds = %1
%6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0)
br label %7
|
||
7: ; preds = %5, %3
%8 = phi i32 [ %4, %3 ], [ 0, %5 ]
ret i32 %8
}
|
||
; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044 | ||
; Branches on a signed "greater than 5" compare of the i32 argument.
;   true  (%3): calls @get_no_flags and converts the result with sitofp.
;   false (%6): calls @get_all_doubles_flags and divides the result by 3.0.
; Returns the value selected by the phi in %9.
define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = icmp sgt i32 %0, 5
br i1 %2, label %3, label %6
|
||
3: ; preds = %1
%4 = tail call i32 @get_no_flags(i32 noundef %0)
%5 = sitofp i32 %4 to double
br label %9
|
||
6: ; preds = %1
%7 = tail call double @get_all_doubles_flags(i32 noundef %0)
%8 = fdiv double %7, 3.000000e+00
br label %9
|
||
9: ; preds = %6, %3
%10 = phi double [ %5, %3 ], [ %8, %6 ]
ret double %10
}
|
||
; CHECK: ; Function get_all_doubles_flags : 0x00000044 | ||
; Calls @call_get_sitofp_fdiv_flag and returns 10.0 if the result is zero,
; otherwise 100.0. It has no conversion or fdiv instruction of its own; its
; only double-typed values are the select and its constant operands.
define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0)
%3 = icmp eq i32 %2, 0
%4 = select i1 %3, double 1.000000e+01, double 1.000000e+02
ret double %4
}
|
||
; CHECK: ; Function main : 0x00000044 | ||
; Root of the test call graph: calls @get_fptoui_flag(1.0) and
; @get_sitofp_fdiv_flag(4), adds the two results, and returns 1 if the sum
; compares ordered-greater-than 0.0, otherwise 0.
define i32 @main() local_unnamed_addr #0 {
%1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00)
%2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4)
%3 = fadd double %1, %2
%4 = fcmp ogt double %3, 0.000000e+00
%5 = zext i1 %4 to i32
ret i32 %5
}
|
||
attributes #0 = { convergent norecurse nounwind "hlsl.export"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be
report_fatal_error
instead of an assert?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibly... however, this change is a result of merging current upstream version of the file. My thought (admittedly not too strong) is to not change as part of this PR.