Skip to content

[DirectX] Propagate shader flags mask of callees to callers #118306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 75 additions & 30 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
Expand All @@ -27,24 +30,31 @@
using namespace llvm;
using namespace llvm::dxil;

static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
DXILResourceTypeMap &DRTM) {
/// Update the shader flags mask based on the given instruction.
/// \param CSF Shader flags mask to update.
/// \param I Instruction to check.
void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
const Instruction &I,
DXILResourceTypeMap &DRTM) {
if (!CSF.Doubles)
CSF.Doubles = I.getType()->isDoubleTy();

if (!CSF.Doubles) {
for (Value *Op : I.operands())
CSF.Doubles |= Op->getType()->isDoubleTy();
for (const Value *Op : I.operands()) {
if (Op->getType()->isDoubleTy()) {
CSF.Doubles = true;
break;
}
}
}

if (CSF.Doubles) {
switch (I.getOpcode()) {
case Instruction::FDiv:
case Instruction::UIToFP:
case Instruction::SIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
// https://github.com/llvm/llvm-project/issues/114554
CSF.DX11_1_DoubleExtensions = true;
break;
}
Expand All @@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
}
}
}
// Handle call instructions
if (auto *CI = dyn_cast<CallInst>(&I)) {
const Function *CF = CI->getCalledFunction();
// Merge-in shader flags mask of the called function in the current module
if (FunctionFlags.contains(CF))
CSF.merge(FunctionFlags[CF]);

// TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
// DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
}
}

void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {

// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration()) {
assert(!F.getName().starts_with("dx.op.") &&
"DXIL Shader Flag analysis should not be run post-lowering.");
continue;
/// Construct ModuleShaderFlags for module Module M
void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
CallGraph CG(M);

// Compute Shader Flags Mask for all functions using post-order visit of SCC
// of the call graph.
for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
++SCCI) {
const std::vector<CallGraphNode *> &CurSCC = *SCCI;

// Union of shader masks of all functions in CurSCC
ComputedShaderFlags SCCSF;
// List of functions in CurSCC that are neither external nor declarations
// and hence whose flags are collected
SmallVector<Function *> CurSCCFuncs;
for (CallGraphNode *CGN : CurSCC) {
Function *F = CGN->getFunction();
if (!F)
continue;

if (F->isDeclaration()) {
assert(!F->getName().starts_with("dx.op.") &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be report_fatal_error instead of an assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be report_fatal_error instead of an assert?

Possibly... however, this change is a result of merging current upstream version of the file. My thought (admittedly not too strong) is to not change as part of this PR.

"DXIL Shader Flag analysis should not be run post-lowering.");
continue;
}

ComputedShaderFlags CSF;
for (const auto &BB : *F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I, DRTM);
// Update combined shader flags mask for all functions in this SCC
SCCSF.merge(CSF);

CurSCCFuncs.push_back(F);
}
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I, DRTM);
// Insert shader flag mask for function F
FunctionFlags.push_back({&F, CSF});
// Update combined shader flags mask
CombinedSFMask.merge(CSF);

// Update combined shader flags mask for all functions of the module
CombinedSFMask.merge(SCCSF);

// Shader flags mask of each of the functions in an SCC of the call graph is
// the union of all functions in the SCC. Update shader flags masks of
// functions in CurSCC accordingly. This is trivially true if SCC contains
// one function.
for (Function *F : CurSCCFuncs)
// Merge SCCSF with that of F
FunctionFlags[F].merge(SCCSF);
}
llvm::sort(FunctionFlags);
}

void ComputedShaderFlags::print(raw_ostream &OS) const {
Expand All @@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
const auto Iter = llvm::lower_bound(
FunctionFlags, Func,
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
"Get Shader Flags : No Shader Flags Mask exists for function");
return Iter->second;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually be safer to return FunctionFlags[Func] instead of using Find. Then if it fails, the returned result is a default initialized flags structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would actually be safer to return FunctionFlags[Func] instead of using Find. Then if it fails, the returned result is a default initialized flags structure.

Returning default initialized flags structure value for F whose shader flags mask does not exist would amount to providing incorrect information about a function whose existence it has no information about and in addition inserting it to the DenseMap. Such functionality is not the intent of getFunctionFlags(). Hence the usage of find and assertion while getting shader flags for a Function *F.

}

Expand Down Expand Up @@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
auto SFMask = FlagsInfo.getFunctionFlags(&F);
const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but FWIW, since sizeof(ComputedShaderFlags) == 8, the cost of a copy here is negligible since the address being copied for the reference will also be 64-bits in all the cases we care about.

OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
(uint64_t)(SFMask));
}
Expand Down
20 changes: 10 additions & 10 deletions llvm/lib/Target/DirectX/DXILShaderFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,31 @@ struct ComputedShaderFlags {
return FeatureFlags;
}

void merge(const uint64_t IVal) {
void merge(const ComputedShaderFlags CSF) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
FlagName |= (IVal & getMask(DxilModuleBit));
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
FlagName |= (IVal & getMask(DxilModuleBit));
FlagName |= CSF.FlagName;
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
return;
}

void print(raw_ostream &OS = dbgs()) const;
LLVM_DUMP_METHOD void dump() const { print(); }
};

struct ModuleShaderFlags {
void initialize(const Module &, DXILResourceTypeMap &DRTM);
void initialize(Module &, DXILResourceTypeMap &DRTM);
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }

private:
/// Vector of sorted Function-Shader Flag mask pairs representing properties
/// of each of the functions in the module. Shader Flags of each function
/// represent both module-level and function-level flags
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Map of Function-Shader Flag Mask pairs representing properties of each of
/// the functions in the module. Shader Flags of each function represent both
/// module-level and function-level flags
DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
void updateFunctionFlags(ComputedShaderFlags &, const Instruction &,
DXILResourceTypeMap &);
};

class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
Expand Down
7 changes: 7 additions & 0 deletions llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions

;CHECK: ; Function top_level : 0x00000044
define double @top_level() #0 {
%r = call double @test_uitofp_i64(i64 5)
ret double %r
}


; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
Expand Down
167 changes: 167 additions & 0 deletions llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.7-library"

; CHECK: ; Combined Shader Flags for Module
; CHECK-NEXT: ; Shader Flags Value: 0x00000044
; CHECK-NEXT: ;
; CHECK-NEXT: ; Note: shader requires additional functionality:
; CHECK-NEXT: ; Double-precision floating point
; CHECK-NEXT: ; Double-precision extensions for 11.1
; CHECK-NEXT: ; Note: extra DXIL module flags:
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions

; Call Graph of test source
; main -> [get_fptoui_flag, get_sitofp_fdiv_flag]
; get_fptoui_flag -> [get_sitofp_uitofp_flag, call_get_uitofp_flag]
; get_sitofp_uitofp_flag -> [call_get_fptoui_flag, call_get_sitofp_flag]
; call_get_fptoui_flag -> [get_fptoui_flag]
; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags]
; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag]
; call_get_sitofp_fdiv_flag -> [get_sitofp_fdiv_flag]
; call_get_sitofp_flag -> [get_sitofp_flag]
; call_get_uitofp_flag -> [get_uitofp_flag]
; get_sitofp_flag -> []
; get_uitofp_flag -> []
; get_no_flags -> []
;
; Strongly Connected Component in the CG
; [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag]
; [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag]

;
; CHECK: ; Function get_sitofp_flag : 0x00000044
define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = sitofp i32 %0 to double
ret double %2
}

; CHECK: ; Function call_get_sitofp_flag : 0x00000044
define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_sitofp_flag(i32 noundef %0)
ret double %2
}

; CHECK: ; Function get_uitofp_flag : 0x00000044
define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = uitofp i32 %0 to double
ret double %2
}

; CHECK: ; Function call_get_uitofp_flag : 0x00000044
define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_uitofp_flag(i32 noundef %0)
ret double %2
}

; CHECK: ; Function call_get_fptoui_flag : 0x00000044
define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
%2 = tail call double @get_fptoui_flag(double noundef %0)
ret double %2
}

; CHECK: ; Function get_fptoui_flag : 0x00000044
define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
%2 = fcmp ugt double %0, 5.000000e+00
br i1 %2, label %6, label %3

3: ; preds = %1
%4 = fptoui double %0 to i64
%5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4)
br label %9

6: ; preds = %1
%7 = fptoui double %0 to i32
%8 = tail call double @call_get_uitofp_flag(i32 noundef %7)
br label %9

9: ; preds = %6, %3
%10 = phi double [ %5, %3 ], [ %8, %6 ]
ret double %10
}

; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044
define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 {
%2 = icmp ult i64 %0, 6
br i1 %2, label %3, label %7

3: ; preds = %1
%4 = add nuw nsw i64 %0, 1
%5 = uitofp i64 %4 to double
%6 = tail call double @call_get_fptoui_flag(double noundef %5)
br label %10

7: ; preds = %1
%8 = trunc i64 %0 to i32
%9 = tail call double @call_get_sitofp_flag(i32 noundef %8)
br label %10

10: ; preds = %7, %3
%11 = phi double [ %6, %3 ], [ %9, %7 ]
ret double %11
}

; CHECK: ; Function get_no_flags : 0x00000000
define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 {
%2 = mul nsw i32 %0, %0
ret i32 %2
}

; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044
define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = icmp eq i32 %0, 0
br i1 %2, label %5, label %3

3: ; preds = %1
%4 = mul nsw i32 %0, %0
br label %7

5: ; preds = %1
%6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0)
br label %7

7: ; preds = %5, %3
%8 = phi i32 [ %4, %3 ], [ 0, %5 ]
ret i32 %8
}

; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044
define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
%2 = icmp sgt i32 %0, 5
br i1 %2, label %3, label %6

3: ; preds = %1
%4 = tail call i32 @get_no_flags(i32 noundef %0)
%5 = sitofp i32 %4 to double
br label %9

6: ; preds = %1
%7 = tail call double @get_all_doubles_flags(i32 noundef %0)
%8 = fdiv double %7, 3.000000e+00
br label %9

9: ; preds = %6, %3
%10 = phi double [ %5, %3 ], [ %8, %6 ]
ret double %10
}

; CHECK: ; Function get_all_doubles_flags : 0x00000044
define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 {
%2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0)
%3 = icmp eq i32 %2, 0
%4 = select i1 %3, double 1.000000e+01, double 1.000000e+02
ret double %4
}

; CHECK: ; Function main : 0x00000044
define i32 @main() local_unnamed_addr #0 {
%1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00)
%2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4)
%3 = fadd double %1, %2
%4 = fcmp ogt double %3, 0.000000e+00
%5 = zext i1 %4 to i32
ret i32 %5
}

attributes #0 = { convergent norecurse nounwind "hlsl.export"}
Loading