Skip to content

[DirectX] Split resource info into type and binding info. NFC #119773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 18, 2024

Conversation

bogner
Copy link
Contributor

@bogner bogner commented Dec 12, 2024

This splits the DXILResourceAnalysis pass into TypeAnalysis and BindingAnalysis passes. The type analysis pass is made immutable and populated lazily so that it can be used earlier in the pipeline without needing to carefully maintain the invariants of the binding analysis.

Fixes #118400

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

Changes

This splits the DXILResourceAnalysis pass into TypeAnalysis and BindingAnalysis passes. The type analysis pass is made immutable and populated lazily so that it can be used earlier in the pipeline without needing to carefully maintain the invariants of the binding analysis.

Fixes #118400


Patch is 93.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119773.diff

19 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+161-58)
  • (modified) llvm/include/llvm/InitializePasses.h (+2-1)
  • (modified) llvm/include/llvm/LinkAllPasses.h (+2-1)
  • (modified) llvm/lib/Analysis/Analysis.cpp (+2-1)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+225-158)
  • (modified) llvm/lib/Passes/PassRegistry.def (+4-2)
  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+17-11)
  • (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (-7)
  • (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp (-5)
  • (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.h (-1)
  • (modified) llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (-7)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (-6)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+33-22)
  • (modified) llvm/lib/Target/DirectX/DXILPrepare.cpp (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp (+47-32)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+27-20)
  • (modified) llvm/test/Analysis/DXILResource/buffer-frombinding.ll (+2-14)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+2-2)
  • (modified) llvm/unittests/Analysis/DXILResourceTest.cpp (+197-195)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 0205356af54443..2f5dded46538ea 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -25,6 +25,8 @@ class MDTuple;
 class TargetExtType;
 class Value;
 
+class DXILResourceTypeMap;
+
 namespace dxil {
 
 /// The dx.RawBuffer target extension type
@@ -196,27 +198,8 @@ class SamplerExtType : public TargetExtType {
 
 //===----------------------------------------------------------------------===//
 
-class ResourceInfo {
+class ResourceTypeInfo {
 public:
-  struct ResourceBinding {
-    uint32_t RecordID;
-    uint32_t Space;
-    uint32_t LowerBound;
-    uint32_t Size;
-
-    bool operator==(const ResourceBinding &RHS) const {
-      return std::tie(RecordID, Space, LowerBound, Size) ==
-             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
-    }
-    bool operator!=(const ResourceBinding &RHS) const {
-      return !(*this == RHS);
-    }
-    bool operator<(const ResourceBinding &RHS) const {
-      return std::tie(RecordID, Space, LowerBound, Size) <
-             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
-    }
-  };
-
   struct UAVInfo {
     bool GloballyCoherent;
     bool HasCounter;
@@ -266,12 +249,11 @@ class ResourceInfo {
   };
 
 private:
-  ResourceBinding Binding;
   TargetExtType *HandleTy;
 
   // GloballyCoherent and HasCounter aren't really part of the type and need to
-  // be determined by analysis, so they're just provided directly when we
-  // construct these.
+  // be determined by analysis, so they're just provided directly by the
+  // DXILResourceTypeMap when we construct these.
   bool GloballyCoherent;
   bool HasCounter;
 
@@ -279,9 +261,13 @@ class ResourceInfo {
   dxil::ResourceKind Kind;
 
 public:
-  ResourceInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
-               uint32_t Size, TargetExtType *HandleTy,
-               bool GloballyCoherent = false, bool HasCounter = false);
+  ResourceTypeInfo(TargetExtType *HandleTy, const dxil::ResourceClass RC,
+                   const dxil::ResourceKind Kind, bool GloballyCoherent = false,
+                   bool HasCounter = false);
+  ResourceTypeInfo(TargetExtType *HandleTy, bool GloballyCoherent = false,
+                   bool HasCounter = false)
+      : ResourceTypeInfo(HandleTy, {}, dxil::ResourceKind::Invalid,
+                         GloballyCoherent, HasCounter) {}
 
   TargetExtType *getHandleTy() const { return HandleTy; }
 
@@ -303,44 +289,157 @@ class ResourceInfo {
   dxil::SamplerFeedbackType getFeedbackType() const;
   uint32_t getMultiSampleCount() const;
 
-  StringRef getName() const {
-    // TODO: Get the name from the symbol once we include one here.
-    return "";
-  }
   dxil::ResourceClass getResourceClass() const { return RC; }
   dxil::ResourceKind getResourceKind() const { return Kind; }
 
+  bool operator==(const ResourceTypeInfo &RHS) const;
+  bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); }
+  bool operator<(const ResourceTypeInfo &RHS) const;
+
+  void print(raw_ostream &OS, const DataLayout &DL) const;
+};
+
+//===----------------------------------------------------------------------===//
+
+class ResourceBindingInfo {
+public:
+  struct ResourceBinding {
+    uint32_t RecordID;
+    uint32_t Space;
+    uint32_t LowerBound;
+    uint32_t Size;
+
+    bool operator==(const ResourceBinding &RHS) const {
+      return std::tie(RecordID, Space, LowerBound, Size) ==
+             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
+    }
+    bool operator!=(const ResourceBinding &RHS) const {
+      return !(*this == RHS);
+    }
+    bool operator<(const ResourceBinding &RHS) const {
+      return std::tie(RecordID, Space, LowerBound, Size) <
+             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
+    }
+  };
+
+private:
+  ResourceBinding Binding;
+  TargetExtType *HandleTy;
+
+public:
+  ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
+                      uint32_t Size, TargetExtType *HandleTy)
+      : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy) {}
+
   void setBindingID(unsigned ID) { Binding.RecordID = ID; }
 
   const ResourceBinding &getBinding() const { return Binding; }
+  TargetExtType *getHandleTy() const { return HandleTy; }
+  const StringRef getName() const {
+    // TODO: Get the name from the symbol once we include one here.
+    return "";
+  }
 
-  MDTuple *getAsMetadata(Module &M) const;
-  std::pair<uint32_t, uint32_t> getAnnotateProps(Module &M) const;
+  MDTuple *getAsMetadata(Module &M, DXILResourceTypeMap &DRTM) const;
+  MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo RTI) const;
 
-  bool operator==(const ResourceInfo &RHS) const;
-  bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
-  bool operator<(const ResourceInfo &RHS) const;
+  std::pair<uint32_t, uint32_t>
+  getAnnotateProps(Module &M, DXILResourceTypeMap &DRTM) const;
+  std::pair<uint32_t, uint32_t>
+  getAnnotateProps(Module &M, dxil::ResourceTypeInfo RTI) const;
 
-  void print(raw_ostream &OS, const DataLayout &DL) const;
+  bool operator==(const ResourceBindingInfo &RHS) const {
+    return std::tie(Binding, HandleTy) == std::tie(RHS.Binding, RHS.HandleTy);
+  }
+  bool operator!=(const ResourceBindingInfo &RHS) const {
+    return !(*this == RHS);
+  }
+  bool operator<(const ResourceBindingInfo &RHS) const {
+    return Binding < RHS.Binding;
+  }
+
+  void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
+             const DataLayout &DL) const;
+  void print(raw_ostream &OS, dxil::ResourceTypeInfo RTI,
+             const DataLayout &DL) const;
 };
 
 } // namespace dxil
 
 //===----------------------------------------------------------------------===//
 
-class DXILResourceMap {
-  SmallVector<dxil::ResourceInfo> Infos;
+class DXILResourceTypeMap {
+  struct Info {
+    dxil::ResourceClass RC;
+    dxil::ResourceKind Kind;
+    bool GloballyCoherent;
+    bool HasCounter;
+  };
+  DenseMap<TargetExtType *, Info> Infos;
+
+public:
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+
+  dxil::ResourceTypeInfo operator[](TargetExtType *Ty) {
+    Info I = Infos[Ty];
+    return dxil::ResourceTypeInfo(Ty, I.RC, I.Kind, I.GloballyCoherent,
+                                  I.HasCounter);
+  }
+
+  void setGloballyCoherent(TargetExtType *Ty, bool GloballyCoherent) {
+    Infos[Ty].GloballyCoherent = GloballyCoherent;
+  }
+
+  void setHasCounter(TargetExtType *Ty, bool HasCounter) {
+    Infos[Ty].HasCounter = HasCounter;
+  }
+};
+
+class DXILResourceTypeAnalysis
+    : public AnalysisInfoMixin<DXILResourceTypeAnalysis> {
+  friend AnalysisInfoMixin<DXILResourceTypeAnalysis>;
+
+  static AnalysisKey Key;
+
+public:
+  using Result = DXILResourceTypeMap;
+
+  DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) {
+    return Result();
+  }
+};
+
+class DXILResourceTypeWrapperPass : public ImmutablePass {
+  DXILResourceTypeMap DRTM;
+
+  virtual void anchor();
+
+public:
+  static char ID;
+  DXILResourceTypeWrapperPass();
+
+  DXILResourceTypeMap &getResourceTypeMap() { return DRTM; }
+  const DXILResourceTypeMap &getResourceTypeMap() const { return DRTM; }
+};
+
+ModulePass *createDXILResourceTypeWrapperPassPass();
+
+//===----------------------------------------------------------------------===//
+
+class DXILBindingMap {
+  SmallVector<dxil::ResourceBindingInfo> Infos;
   DenseMap<CallInst *, unsigned> CallMap;
   unsigned FirstUAV = 0;
   unsigned FirstCBuffer = 0;
   unsigned FirstSampler = 0;
 
   /// Populate the map given the resource binding calls in the given module.
-  void populate(Module &M);
+  void populate(Module &M, DXILResourceTypeMap &DRTM);
 
 public:
-  using iterator = SmallVector<dxil::ResourceInfo>::iterator;
-  using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
+  using iterator = SmallVector<dxil::ResourceBindingInfo>::iterator;
+  using const_iterator = SmallVector<dxil::ResourceBindingInfo>::const_iterator;
 
   iterator begin() { return Infos.begin(); }
   const_iterator begin() const { return Infos.begin(); }
@@ -399,47 +498,51 @@ class DXILResourceMap {
     return make_range(sampler_begin(), sampler_end());
   }
 
-  void print(raw_ostream &OS, const DataLayout &DL) const;
+  void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
+             const DataLayout &DL) const;
 
-  friend class DXILResourceAnalysis;
-  friend class DXILResourceWrapperPass;
+  friend class DXILResourceBindingAnalysis;
+  friend class DXILResourceBindingWrapperPass;
 };
 
-class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
-  friend AnalysisInfoMixin<DXILResourceAnalysis>;
+class DXILResourceBindingAnalysis
+    : public AnalysisInfoMixin<DXILResourceBindingAnalysis> {
+  friend AnalysisInfoMixin<DXILResourceBindingAnalysis>;
 
   static AnalysisKey Key;
 
 public:
-  using Result = DXILResourceMap;
+  using Result = DXILBindingMap;
 
   /// Gather resource info for the module \c M.
-  DXILResourceMap run(Module &M, ModuleAnalysisManager &AM);
+  DXILBindingMap run(Module &M, ModuleAnalysisManager &AM);
 };
 
-/// Printer pass for the \c DXILResourceAnalysis results.
-class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
+/// Printer pass for the \c DXILResourceBindingAnalysis results.
+class DXILResourceBindingPrinterPass
+    : public PassInfoMixin<DXILResourceBindingPrinterPass> {
   raw_ostream &OS;
 
 public:
-  explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {}
+  explicit DXILResourceBindingPrinterPass(raw_ostream &OS) : OS(OS) {}
 
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 
   static bool isRequired() { return true; }
 };
 
-class DXILResourceWrapperPass : public ModulePass {
-  std::unique_ptr<DXILResourceMap> Map;
+class DXILResourceBindingWrapperPass : public ModulePass {
+  std::unique_ptr<DXILBindingMap> Map;
+  DXILResourceTypeMap *DRTM;
 
 public:
   static char ID; // Class identification, replacement for typeinfo
 
-  DXILResourceWrapperPass();
-  ~DXILResourceWrapperPass() override;
+  DXILResourceBindingWrapperPass();
+  ~DXILResourceBindingWrapperPass() override;
 
-  const DXILResourceMap &getResourceMap() const { return *Map; }
-  DXILResourceMap &getResourceMap() { return *Map; }
+  const DXILBindingMap &getBindingMap() const { return *Map; }
+  DXILBindingMap &getBindingMap() { return *Map; }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override;
   bool runOnModule(Module &M) override;
@@ -449,7 +552,7 @@ class DXILResourceWrapperPass : public ModulePass {
   void dump() const;
 };
 
-ModulePass *createDXILResourceWrapperPassPass();
+ModulePass *createDXILResourceBindingWrapperPassPass();
 
 } // namespace llvm
 
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index 7d829cf5b9b015..1cb9013bc48cc5 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -84,7 +84,8 @@ void initializeDAHPass(PassRegistry &);
 void initializeDCELegacyPassPass(PassRegistry &);
 void initializeDXILMetadataAnalysisWrapperPassPass(PassRegistry &);
 void initializeDXILMetadataAnalysisWrapperPrinterPass(PassRegistry &);
-void initializeDXILResourceWrapperPassPass(PassRegistry &);
+void initializeDXILResourceBindingWrapperPassPass(PassRegistry &);
+void initializeDXILResourceTypeWrapperPassPass(PassRegistry &);
 void initializeDeadMachineInstructionElimPass(PassRegistry &);
 void initializeDebugifyMachineModulePass(PassRegistry &);
 void initializeDependenceAnalysisWrapperPassPass(PassRegistry &);
diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h
index 54245ca0b70222..ac1970334de0cd 100644
--- a/llvm/include/llvm/LinkAllPasses.h
+++ b/llvm/include/llvm/LinkAllPasses.h
@@ -70,7 +70,8 @@ struct ForcePassLinking {
     (void)llvm::createCallGraphViewerPass();
     (void)llvm::createCFGSimplificationPass();
     (void)llvm::createStructurizeCFGPass();
-    (void)llvm::createDXILResourceWrapperPassPass();
+    (void)llvm::createDXILResourceBindingWrapperPassPass();
+    (void)llvm::createDXILResourceTypeWrapperPassPass();
     (void)llvm::createDeadArgEliminationPass();
     (void)llvm::createDeadCodeEliminationPass();
     (void)llvm::createDependenceAnalysisWrapperPass();
diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp
index 58723469f21ca8..bc2b8a57f83a7a 100644
--- a/llvm/lib/Analysis/Analysis.cpp
+++ b/llvm/lib/Analysis/Analysis.cpp
@@ -25,7 +25,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) {
   initializeCallGraphDOTPrinterPass(Registry);
   initializeCallGraphViewerPass(Registry);
   initializeCycleInfoWrapperPassPass(Registry);
-  initializeDXILResourceWrapperPassPass(Registry);
+  initializeDXILResourceBindingWrapperPassPass(Registry);
+  initializeDXILResourceTypeWrapperPassPass(Registry);
   initializeDependenceAnalysisWrapperPassPass(Registry);
   initializeDominanceFrontierWrapperPassPass(Registry);
   initializeDomViewerWrapperPassPass(Registry);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index f96a9468d6bc54..e1942a0c4930cd 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -177,12 +177,19 @@ static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) {
   return ElementType::Invalid;
 }
 
-ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
-                           uint32_t LowerBound, uint32_t Size,
-                           TargetExtType *HandleTy, bool GloballyCoherent,
-                           bool HasCounter)
-    : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy),
-      GloballyCoherent(GloballyCoherent), HasCounter(HasCounter) {
+ResourceTypeInfo::ResourceTypeInfo(TargetExtType *HandleTy,
+                                   const dxil::ResourceClass RC_,
+                                   const dxil::ResourceKind Kind_,
+                                   bool GloballyCoherent, bool HasCounter)
+    : HandleTy(HandleTy), GloballyCoherent(GloballyCoherent),
+      HasCounter(HasCounter) {
+  // If we're provided a resource class and kind, trust them.
+  if (Kind_ != dxil::ResourceKind::Invalid) {
+    RC = RC_;
+    Kind = Kind_;
+    return;
+  }
+
   if (auto *Ty = dyn_cast<RawBufferExtType>(HandleTy)) {
     RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
     Kind = Ty->isStructured() ? ResourceKind::StructuredBuffer
@@ -209,21 +216,21 @@ ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
     llvm_unreachable("Unknown handle type");
 }
 
-bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; }
+bool ResourceTypeInfo::isUAV() const { return RC == ResourceClass::UAV; }
 
-bool ResourceInfo::isCBuffer() const {
+bool ResourceTypeInfo::isCBuffer() const {
   return RC == ResourceClass::CBuffer;
 }
 
-bool ResourceInfo::isSampler() const {
+bool ResourceTypeInfo::isSampler() const {
   return RC == ResourceClass::Sampler;
 }
 
-bool ResourceInfo::isStruct() const {
+bool ResourceTypeInfo::isStruct() const {
   return Kind == ResourceKind::StructuredBuffer;
 }
 
-bool ResourceInfo::isTyped() const {
+bool ResourceTypeInfo::isTyped() const {
   switch (Kind) {
   case ResourceKind::Texture1D:
   case ResourceKind::Texture2D:
@@ -252,12 +259,12 @@ bool ResourceInfo::isTyped() const {
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-bool ResourceInfo::isFeedback() const {
+bool ResourceTypeInfo::isFeedback() const {
   return Kind == ResourceKind::FeedbackTexture2D ||
          Kind == ResourceKind::FeedbackTexture2DArray;
 }
 
-bool ResourceInfo::isMultiSample() const {
+bool ResourceTypeInfo::isMultiSample() const {
   return Kind == ResourceKind::Texture2DMS ||
          Kind == ResourceKind::Texture2DMSArray;
 }
@@ -293,24 +300,24 @@ static bool isROV(dxil::ResourceKind Kind, TargetExtType *Ty) {
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-ResourceInfo::UAVInfo ResourceInfo::getUAV() const {
+ResourceTypeInfo::UAVInfo ResourceTypeInfo::getUAV() const {
   assert(isUAV() && "Not a UAV");
   return {GloballyCoherent, HasCounter, isROV(Kind, HandleTy)};
 }
 
-uint32_t ResourceInfo::getCBufferSize(const DataLayout &DL) const {
+uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const {
   assert(isCBuffer() && "Not a CBuffer");
   Type *Ty = cast<CBufferExtType>(HandleTy)->getResourceType();
   return DL.getTypeSizeInBits(Ty) / 8;
 }
 
-dxil::SamplerType ResourceInfo::getSamplerType() const {
+dxil::SamplerType ResourceTypeInfo::getSamplerType() const {
   assert(isSampler() && "Not a Sampler");
   return cast<SamplerExtType>(HandleTy)->getSamplerType();
 }
 
-ResourceInfo::StructInfo
-ResourceInfo::getStruct(const DataLayout &DL) const {
+ResourceTypeInfo::StructInfo
+ResourceTypeInfo::getStruct(const DataLayout &DL) const {
   assert(isStruct() && "Not a Struct");
 
   Type *ElTy = cast<RawBufferExtType>(HandleTy)->getResourceType();
@@ -360,7 +367,7 @@ static std::pair<Type *, bool> getTypedElementType(dxil::ResourceKind Kind,
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-ResourceInfo::TypedInfo ResourceInfo::getTyped() const {
+ResourceTypeInfo::TypedInfo ResourceTypeInfo::getTyped() const {
   assert(isTyped() && "Not typed");
 
   auto [ElTy, IsSigned] = getTypedElementType(Kind, HandleTy);
@@ -371,17 +378,85 @@ ResourceInfo::TypedInfo ResourceInfo::getTyped() const {
   return {ET, Count};
 }
 
-dxil::SamplerFeedbackType ResourceInfo::getFeedbackType() const {
+dxil::SamplerFeedbackType ResourceTypeInfo::getFeedbackType() const {
   assert(isFeedback() && "Not Feedback");
   return cast<FeedbackTextureExtType>(HandleTy)->getFeedbackType();
 }
-
-uint32_t ResourceInfo::getMultiSampleCount() const {
+uint32_t ResourceTypeInfo::getMultiSampleCount() const {
   assert(isMultiSample() && "Not MultiSampled");
   return cast<MSTextureExtType>(HandleTy)->getSampleCount();
 }
 
-MDTuple *ResourceInfo::getAsMetadata(Module &M) const {
+bool ResourceTypeInfo::operator==(const ResourceTypeInfo &RHS) const {
+  return std::tie(HandleTy, GloballyCoherent, HasCounter) ==
+         std::tie(RHS.HandleTy, RHS.GloballyCoherent, RHS.HasCounter);
+}
+
+bool ResourceTypeInfo::operator<(const ResourceTypeInfo &RHS) const {
+  // An empty datalayout is sufficient for sorting purposes.
+  DataLayout DummyDL;
+  if (std::tie(RC, Kind) < std::tie(RHS.RC, RHS.Kind))
+    return true;
+  if (isCBuffer() && RHS.isCBuffer() &&
+      getCBufferSize(DummyDL) < RHS.getCBufferSize(DummyDL))
+    return true;
+  if (isSampler() && RHS.isSampler() && getSamplerType() < RHS.getSamplerType())
+    return true;
+  if (isUAV() && RHS.isUAV() && getUAV() < RHS.getUAV())
+    return true;
+  if (isStruct() && RHS.isStruct() &&
+      getStruct(DummyDL) < RHS.getStruct(DummyDL))
+    return true;
+  if (isFeedback() && RHS.isFeedback() &&
+      getFeedbackType() < RHS.getFeedbackType())
+    return true;
+  if (isTyped() && RHS.isTyped() && getTyped() < RHS.getTyped())
+    return true;
+  if (isMultiSample() && RHS.isMultiSample() &&
+      getMultiSampleCount() < RHS.getMultiSampleCount())
+    return true;
+  return false;
+}
+
+void ResourceTypeInfo::print(raw_ostream &OS, const DataLayout &DL) const {
+  OS << "  Class: " << getResourceClassName(RC) << "\n"
+     << "  Kind: " << getResourceKindName(Kind) << "\n";
+
+  if (isCBuffer()) {
+    OS << "  CBuffer size: " << getCBufferSize(DL) << "\n";
+  } else if (isSampler()) {
+    OS << "  Sampler Type: " << getSamplerTypeName(getSamplerType()) << "\n";
+  } else {
+    if (isUAV()) {
+      UAVInfo UAVFlags = getUAV();
+      OS << "  Globally Coherent: " << UAVFlags.GloballyCoherent << "\n"
+         << "  HasCounter: " << UAVFlags.HasCounter << "\n"
+         << "  IsROV: " << UAVFlags.IsROV << "\n";
+    }
+    if (isMultiSample())
+      OS << "  Sample Count: " << getMultiSampleCount() << "\n";
+
+    if (isStruct()) {
+      StructInfo Struct = getStruct(DL);
+      OS << "  Buffer Stride: " << Struct.Stride << "\n";
+      OS << "  Alignment: " << Struct.AlignLog2 << "\n";
+    } else if (isTyped()) {
+      TypedInfo Typed = ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Justin Bogner (bogner)

Changes

This splits the DXILResourceAnalysis pass into TypeAnalysis and BindingAnalysis passes. The type analysis pass is made immutable and populated lazily so that it can be used earlier in the pipeline without needing to carefully maintain the invariants of the binding analysis.

Fixes #118400


Patch is 93.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119773.diff

19 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+161-58)
  • (modified) llvm/include/llvm/InitializePasses.h (+2-1)
  • (modified) llvm/include/llvm/LinkAllPasses.h (+2-1)
  • (modified) llvm/lib/Analysis/Analysis.cpp (+2-1)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+225-158)
  • (modified) llvm/lib/Passes/PassRegistry.def (+4-2)
  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+17-11)
  • (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (-7)
  • (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp (-5)
  • (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.h (-1)
  • (modified) llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (-7)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (-6)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+33-22)
  • (modified) llvm/lib/Target/DirectX/DXILPrepare.cpp (+1-1)
  • (modified) llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp (+47-32)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+27-20)
  • (modified) llvm/test/Analysis/DXILResource/buffer-frombinding.ll (+2-14)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+2-2)
  • (modified) llvm/unittests/Analysis/DXILResourceTest.cpp (+197-195)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 0205356af54443..2f5dded46538ea 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -25,6 +25,8 @@ class MDTuple;
 class TargetExtType;
 class Value;
 
+class DXILResourceTypeMap;
+
 namespace dxil {
 
 /// The dx.RawBuffer target extension type
@@ -196,27 +198,8 @@ class SamplerExtType : public TargetExtType {
 
 //===----------------------------------------------------------------------===//
 
-class ResourceInfo {
+class ResourceTypeInfo {
 public:
-  struct ResourceBinding {
-    uint32_t RecordID;
-    uint32_t Space;
-    uint32_t LowerBound;
-    uint32_t Size;
-
-    bool operator==(const ResourceBinding &RHS) const {
-      return std::tie(RecordID, Space, LowerBound, Size) ==
-             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
-    }
-    bool operator!=(const ResourceBinding &RHS) const {
-      return !(*this == RHS);
-    }
-    bool operator<(const ResourceBinding &RHS) const {
-      return std::tie(RecordID, Space, LowerBound, Size) <
-             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
-    }
-  };
-
   struct UAVInfo {
     bool GloballyCoherent;
     bool HasCounter;
@@ -266,12 +249,11 @@ class ResourceInfo {
   };
 
 private:
-  ResourceBinding Binding;
   TargetExtType *HandleTy;
 
   // GloballyCoherent and HasCounter aren't really part of the type and need to
-  // be determined by analysis, so they're just provided directly when we
-  // construct these.
+  // be determined by analysis, so they're just provided directly by the
+  // DXILResourceTypeMap when we construct these.
   bool GloballyCoherent;
   bool HasCounter;
 
@@ -279,9 +261,13 @@ class ResourceInfo {
   dxil::ResourceKind Kind;
 
 public:
-  ResourceInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
-               uint32_t Size, TargetExtType *HandleTy,
-               bool GloballyCoherent = false, bool HasCounter = false);
+  ResourceTypeInfo(TargetExtType *HandleTy, const dxil::ResourceClass RC,
+                   const dxil::ResourceKind Kind, bool GloballyCoherent = false,
+                   bool HasCounter = false);
+  ResourceTypeInfo(TargetExtType *HandleTy, bool GloballyCoherent = false,
+                   bool HasCounter = false)
+      : ResourceTypeInfo(HandleTy, {}, dxil::ResourceKind::Invalid,
+                         GloballyCoherent, HasCounter) {}
 
   TargetExtType *getHandleTy() const { return HandleTy; }
 
@@ -303,44 +289,157 @@ class ResourceInfo {
   dxil::SamplerFeedbackType getFeedbackType() const;
   uint32_t getMultiSampleCount() const;
 
-  StringRef getName() const {
-    // TODO: Get the name from the symbol once we include one here.
-    return "";
-  }
   dxil::ResourceClass getResourceClass() const { return RC; }
   dxil::ResourceKind getResourceKind() const { return Kind; }
 
+  bool operator==(const ResourceTypeInfo &RHS) const;
+  bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); }
+  bool operator<(const ResourceTypeInfo &RHS) const;
+
+  void print(raw_ostream &OS, const DataLayout &DL) const;
+};
+
+//===----------------------------------------------------------------------===//
+
+class ResourceBindingInfo {
+public:
+  struct ResourceBinding {
+    uint32_t RecordID;
+    uint32_t Space;
+    uint32_t LowerBound;
+    uint32_t Size;
+
+    bool operator==(const ResourceBinding &RHS) const {
+      return std::tie(RecordID, Space, LowerBound, Size) ==
+             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
+    }
+    bool operator!=(const ResourceBinding &RHS) const {
+      return !(*this == RHS);
+    }
+    bool operator<(const ResourceBinding &RHS) const {
+      return std::tie(RecordID, Space, LowerBound, Size) <
+             std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
+    }
+  };
+
+private:
+  ResourceBinding Binding;
+  TargetExtType *HandleTy;
+
+public:
+  ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
+                      uint32_t Size, TargetExtType *HandleTy)
+      : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy) {}
+
   void setBindingID(unsigned ID) { Binding.RecordID = ID; }
 
   const ResourceBinding &getBinding() const { return Binding; }
+  TargetExtType *getHandleTy() const { return HandleTy; }
+  const StringRef getName() const {
+    // TODO: Get the name from the symbol once we include one here.
+    return "";
+  }
 
-  MDTuple *getAsMetadata(Module &M) const;
-  std::pair<uint32_t, uint32_t> getAnnotateProps(Module &M) const;
+  MDTuple *getAsMetadata(Module &M, DXILResourceTypeMap &DRTM) const;
+  MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo RTI) const;
 
-  bool operator==(const ResourceInfo &RHS) const;
-  bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
-  bool operator<(const ResourceInfo &RHS) const;
+  std::pair<uint32_t, uint32_t>
+  getAnnotateProps(Module &M, DXILResourceTypeMap &DRTM) const;
+  std::pair<uint32_t, uint32_t>
+  getAnnotateProps(Module &M, dxil::ResourceTypeInfo RTI) const;
 
-  void print(raw_ostream &OS, const DataLayout &DL) const;
+  bool operator==(const ResourceBindingInfo &RHS) const {
+    return std::tie(Binding, HandleTy) == std::tie(RHS.Binding, RHS.HandleTy);
+  }
+  bool operator!=(const ResourceBindingInfo &RHS) const {
+    return !(*this == RHS);
+  }
+  bool operator<(const ResourceBindingInfo &RHS) const {
+    return Binding < RHS.Binding;
+  }
+
+  void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
+             const DataLayout &DL) const;
+  void print(raw_ostream &OS, dxil::ResourceTypeInfo RTI,
+             const DataLayout &DL) const;
 };
 
 } // namespace dxil
 
 //===----------------------------------------------------------------------===//
 
-class DXILResourceMap {
-  SmallVector<dxil::ResourceInfo> Infos;
+class DXILResourceTypeMap {
+  struct Info {
+    dxil::ResourceClass RC;
+    dxil::ResourceKind Kind;
+    bool GloballyCoherent;
+    bool HasCounter;
+  };
+  DenseMap<TargetExtType *, Info> Infos;
+
+public:
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+
+  dxil::ResourceTypeInfo operator[](TargetExtType *Ty) {
+    Info I = Infos[Ty];
+    return dxil::ResourceTypeInfo(Ty, I.RC, I.Kind, I.GloballyCoherent,
+                                  I.HasCounter);
+  }
+
+  void setGloballyCoherent(TargetExtType *Ty, bool GloballyCoherent) {
+    Infos[Ty].GloballyCoherent = GloballyCoherent;
+  }
+
+  void setHasCounter(TargetExtType *Ty, bool HasCounter) {
+    Infos[Ty].HasCounter = HasCounter;
+  }
+};
+
+class DXILResourceTypeAnalysis
+    : public AnalysisInfoMixin<DXILResourceTypeAnalysis> {
+  friend AnalysisInfoMixin<DXILResourceTypeAnalysis>;
+
+  static AnalysisKey Key;
+
+public:
+  using Result = DXILResourceTypeMap;
+
+  DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) {
+    return Result();
+  }
+};
+
+class DXILResourceTypeWrapperPass : public ImmutablePass {
+  DXILResourceTypeMap DRTM;
+
+  virtual void anchor();
+
+public:
+  static char ID;
+  DXILResourceTypeWrapperPass();
+
+  DXILResourceTypeMap &getResourceTypeMap() { return DRTM; }
+  const DXILResourceTypeMap &getResourceTypeMap() const { return DRTM; }
+};
+
+ModulePass *createDXILResourceTypeWrapperPassPass();
+
+//===----------------------------------------------------------------------===//
+
+class DXILBindingMap {
+  SmallVector<dxil::ResourceBindingInfo> Infos;
   DenseMap<CallInst *, unsigned> CallMap;
   unsigned FirstUAV = 0;
   unsigned FirstCBuffer = 0;
   unsigned FirstSampler = 0;
 
   /// Populate the map given the resource binding calls in the given module.
-  void populate(Module &M);
+  void populate(Module &M, DXILResourceTypeMap &DRTM);
 
 public:
-  using iterator = SmallVector<dxil::ResourceInfo>::iterator;
-  using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
+  using iterator = SmallVector<dxil::ResourceBindingInfo>::iterator;
+  using const_iterator = SmallVector<dxil::ResourceBindingInfo>::const_iterator;
 
   iterator begin() { return Infos.begin(); }
   const_iterator begin() const { return Infos.begin(); }
@@ -399,47 +498,51 @@ class DXILResourceMap {
     return make_range(sampler_begin(), sampler_end());
   }
 
-  void print(raw_ostream &OS, const DataLayout &DL) const;
+  void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
+             const DataLayout &DL) const;
 
-  friend class DXILResourceAnalysis;
-  friend class DXILResourceWrapperPass;
+  friend class DXILResourceBindingAnalysis;
+  friend class DXILResourceBindingWrapperPass;
 };
 
-class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
-  friend AnalysisInfoMixin<DXILResourceAnalysis>;
+class DXILResourceBindingAnalysis
+    : public AnalysisInfoMixin<DXILResourceBindingAnalysis> {
+  friend AnalysisInfoMixin<DXILResourceBindingAnalysis>;
 
   static AnalysisKey Key;
 
 public:
-  using Result = DXILResourceMap;
+  using Result = DXILBindingMap;
 
   /// Gather resource info for the module \c M.
-  DXILResourceMap run(Module &M, ModuleAnalysisManager &AM);
+  DXILBindingMap run(Module &M, ModuleAnalysisManager &AM);
 };
 
-/// Printer pass for the \c DXILResourceAnalysis results.
-class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
+/// Printer pass for the \c DXILResourceBindingAnalysis results.
+class DXILResourceBindingPrinterPass
+    : public PassInfoMixin<DXILResourceBindingPrinterPass> {
   raw_ostream &OS;
 
 public:
-  explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {}
+  explicit DXILResourceBindingPrinterPass(raw_ostream &OS) : OS(OS) {}
 
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 
   static bool isRequired() { return true; }
 };
 
-class DXILResourceWrapperPass : public ModulePass {
-  std::unique_ptr<DXILResourceMap> Map;
+class DXILResourceBindingWrapperPass : public ModulePass {
+  std::unique_ptr<DXILBindingMap> Map;
+  DXILResourceTypeMap *DRTM;
 
 public:
   static char ID; // Class identification, replacement for typeinfo
 
-  DXILResourceWrapperPass();
-  ~DXILResourceWrapperPass() override;
+  DXILResourceBindingWrapperPass();
+  ~DXILResourceBindingWrapperPass() override;
 
-  const DXILResourceMap &getResourceMap() const { return *Map; }
-  DXILResourceMap &getResourceMap() { return *Map; }
+  const DXILBindingMap &getBindingMap() const { return *Map; }
+  DXILBindingMap &getBindingMap() { return *Map; }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override;
   bool runOnModule(Module &M) override;
@@ -449,7 +552,7 @@ class DXILResourceWrapperPass : public ModulePass {
   void dump() const;
 };
 
-ModulePass *createDXILResourceWrapperPassPass();
+ModulePass *createDXILResourceBindingWrapperPassPass();
 
 } // namespace llvm
 
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index 7d829cf5b9b015..1cb9013bc48cc5 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -84,7 +84,8 @@ void initializeDAHPass(PassRegistry &);
 void initializeDCELegacyPassPass(PassRegistry &);
 void initializeDXILMetadataAnalysisWrapperPassPass(PassRegistry &);
 void initializeDXILMetadataAnalysisWrapperPrinterPass(PassRegistry &);
-void initializeDXILResourceWrapperPassPass(PassRegistry &);
+void initializeDXILResourceBindingWrapperPassPass(PassRegistry &);
+void initializeDXILResourceTypeWrapperPassPass(PassRegistry &);
 void initializeDeadMachineInstructionElimPass(PassRegistry &);
 void initializeDebugifyMachineModulePass(PassRegistry &);
 void initializeDependenceAnalysisWrapperPassPass(PassRegistry &);
diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h
index 54245ca0b70222..ac1970334de0cd 100644
--- a/llvm/include/llvm/LinkAllPasses.h
+++ b/llvm/include/llvm/LinkAllPasses.h
@@ -70,7 +70,8 @@ struct ForcePassLinking {
     (void)llvm::createCallGraphViewerPass();
     (void)llvm::createCFGSimplificationPass();
     (void)llvm::createStructurizeCFGPass();
-    (void)llvm::createDXILResourceWrapperPassPass();
+    (void)llvm::createDXILResourceBindingWrapperPassPass();
+    (void)llvm::createDXILResourceTypeWrapperPassPass();
     (void)llvm::createDeadArgEliminationPass();
     (void)llvm::createDeadCodeEliminationPass();
     (void)llvm::createDependenceAnalysisWrapperPass();
diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp
index 58723469f21ca8..bc2b8a57f83a7a 100644
--- a/llvm/lib/Analysis/Analysis.cpp
+++ b/llvm/lib/Analysis/Analysis.cpp
@@ -25,7 +25,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) {
   initializeCallGraphDOTPrinterPass(Registry);
   initializeCallGraphViewerPass(Registry);
   initializeCycleInfoWrapperPassPass(Registry);
-  initializeDXILResourceWrapperPassPass(Registry);
+  initializeDXILResourceBindingWrapperPassPass(Registry);
+  initializeDXILResourceTypeWrapperPassPass(Registry);
   initializeDependenceAnalysisWrapperPassPass(Registry);
   initializeDominanceFrontierWrapperPassPass(Registry);
   initializeDomViewerWrapperPassPass(Registry);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index f96a9468d6bc54..e1942a0c4930cd 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -177,12 +177,19 @@ static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) {
   return ElementType::Invalid;
 }
 
-ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
-                           uint32_t LowerBound, uint32_t Size,
-                           TargetExtType *HandleTy, bool GloballyCoherent,
-                           bool HasCounter)
-    : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy),
-      GloballyCoherent(GloballyCoherent), HasCounter(HasCounter) {
+ResourceTypeInfo::ResourceTypeInfo(TargetExtType *HandleTy,
+                                   const dxil::ResourceClass RC_,
+                                   const dxil::ResourceKind Kind_,
+                                   bool GloballyCoherent, bool HasCounter)
+    : HandleTy(HandleTy), GloballyCoherent(GloballyCoherent),
+      HasCounter(HasCounter) {
+  // If we're provided a resource class and kind, trust them.
+  if (Kind_ != dxil::ResourceKind::Invalid) {
+    RC = RC_;
+    Kind = Kind_;
+    return;
+  }
+
   if (auto *Ty = dyn_cast<RawBufferExtType>(HandleTy)) {
     RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
     Kind = Ty->isStructured() ? ResourceKind::StructuredBuffer
@@ -209,21 +216,21 @@ ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
     llvm_unreachable("Unknown handle type");
 }
 
-bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; }
+bool ResourceTypeInfo::isUAV() const { return RC == ResourceClass::UAV; }
 
-bool ResourceInfo::isCBuffer() const {
+bool ResourceTypeInfo::isCBuffer() const {
   return RC == ResourceClass::CBuffer;
 }
 
-bool ResourceInfo::isSampler() const {
+bool ResourceTypeInfo::isSampler() const {
   return RC == ResourceClass::Sampler;
 }
 
-bool ResourceInfo::isStruct() const {
+bool ResourceTypeInfo::isStruct() const {
   return Kind == ResourceKind::StructuredBuffer;
 }
 
-bool ResourceInfo::isTyped() const {
+bool ResourceTypeInfo::isTyped() const {
   switch (Kind) {
   case ResourceKind::Texture1D:
   case ResourceKind::Texture2D:
@@ -252,12 +259,12 @@ bool ResourceInfo::isTyped() const {
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-bool ResourceInfo::isFeedback() const {
+bool ResourceTypeInfo::isFeedback() const {
   return Kind == ResourceKind::FeedbackTexture2D ||
          Kind == ResourceKind::FeedbackTexture2DArray;
 }
 
-bool ResourceInfo::isMultiSample() const {
+bool ResourceTypeInfo::isMultiSample() const {
   return Kind == ResourceKind::Texture2DMS ||
          Kind == ResourceKind::Texture2DMSArray;
 }
@@ -293,24 +300,24 @@ static bool isROV(dxil::ResourceKind Kind, TargetExtType *Ty) {
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-ResourceInfo::UAVInfo ResourceInfo::getUAV() const {
+ResourceTypeInfo::UAVInfo ResourceTypeInfo::getUAV() const {
   assert(isUAV() && "Not a UAV");
   return {GloballyCoherent, HasCounter, isROV(Kind, HandleTy)};
 }
 
-uint32_t ResourceInfo::getCBufferSize(const DataLayout &DL) const {
+uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const {
   assert(isCBuffer() && "Not a CBuffer");
   Type *Ty = cast<CBufferExtType>(HandleTy)->getResourceType();
   return DL.getTypeSizeInBits(Ty) / 8;
 }
 
-dxil::SamplerType ResourceInfo::getSamplerType() const {
+dxil::SamplerType ResourceTypeInfo::getSamplerType() const {
   assert(isSampler() && "Not a Sampler");
   return cast<SamplerExtType>(HandleTy)->getSamplerType();
 }
 
-ResourceInfo::StructInfo
-ResourceInfo::getStruct(const DataLayout &DL) const {
+ResourceTypeInfo::StructInfo
+ResourceTypeInfo::getStruct(const DataLayout &DL) const {
   assert(isStruct() && "Not a Struct");
 
   Type *ElTy = cast<RawBufferExtType>(HandleTy)->getResourceType();
@@ -360,7 +367,7 @@ static std::pair<Type *, bool> getTypedElementType(dxil::ResourceKind Kind,
   llvm_unreachable("Unhandled ResourceKind enum");
 }
 
-ResourceInfo::TypedInfo ResourceInfo::getTyped() const {
+ResourceTypeInfo::TypedInfo ResourceTypeInfo::getTyped() const {
   assert(isTyped() && "Not typed");
 
   auto [ElTy, IsSigned] = getTypedElementType(Kind, HandleTy);
@@ -371,17 +378,85 @@ ResourceInfo::TypedInfo ResourceInfo::getTyped() const {
   return {ET, Count};
 }
 
-dxil::SamplerFeedbackType ResourceInfo::getFeedbackType() const {
+dxil::SamplerFeedbackType ResourceTypeInfo::getFeedbackType() const {
   assert(isFeedback() && "Not Feedback");
   return cast<FeedbackTextureExtType>(HandleTy)->getFeedbackType();
 }
-
-uint32_t ResourceInfo::getMultiSampleCount() const {
+uint32_t ResourceTypeInfo::getMultiSampleCount() const {
   assert(isMultiSample() && "Not MultiSampled");
   return cast<MSTextureExtType>(HandleTy)->getSampleCount();
 }
 
-MDTuple *ResourceInfo::getAsMetadata(Module &M) const {
+bool ResourceTypeInfo::operator==(const ResourceTypeInfo &RHS) const {
+  return std::tie(HandleTy, GloballyCoherent, HasCounter) ==
+         std::tie(RHS.HandleTy, RHS.GloballyCoherent, RHS.HasCounter);
+}
+
+bool ResourceTypeInfo::operator<(const ResourceTypeInfo &RHS) const {
+  // An empty datalayout is sufficient for sorting purposes.
+  DataLayout DummyDL;
+  if (std::tie(RC, Kind) < std::tie(RHS.RC, RHS.Kind))
+    return true;
+  if (isCBuffer() && RHS.isCBuffer() &&
+      getCBufferSize(DummyDL) < RHS.getCBufferSize(DummyDL))
+    return true;
+  if (isSampler() && RHS.isSampler() && getSamplerType() < RHS.getSamplerType())
+    return true;
+  if (isUAV() && RHS.isUAV() && getUAV() < RHS.getUAV())
+    return true;
+  if (isStruct() && RHS.isStruct() &&
+      getStruct(DummyDL) < RHS.getStruct(DummyDL))
+    return true;
+  if (isFeedback() && RHS.isFeedback() &&
+      getFeedbackType() < RHS.getFeedbackType())
+    return true;
+  if (isTyped() && RHS.isTyped() && getTyped() < RHS.getTyped())
+    return true;
+  if (isMultiSample() && RHS.isMultiSample() &&
+      getMultiSampleCount() < RHS.getMultiSampleCount())
+    return true;
+  return false;
+}
+
+void ResourceTypeInfo::print(raw_ostream &OS, const DataLayout &DL) const {
+  OS << "  Class: " << getResourceClassName(RC) << "\n"
+     << "  Kind: " << getResourceKindName(Kind) << "\n";
+
+  if (isCBuffer()) {
+    OS << "  CBuffer size: " << getCBufferSize(DL) << "\n";
+  } else if (isSampler()) {
+    OS << "  Sampler Type: " << getSamplerTypeName(getSamplerType()) << "\n";
+  } else {
+    if (isUAV()) {
+      UAVInfo UAVFlags = getUAV();
+      OS << "  Globally Coherent: " << UAVFlags.GloballyCoherent << "\n"
+         << "  HasCounter: " << UAVFlags.HasCounter << "\n"
+         << "  IsROV: " << UAVFlags.IsROV << "\n";
+    }
+    if (isMultiSample())
+      OS << "  Sample Count: " << getMultiSampleCount() << "\n";
+
+    if (isStruct()) {
+      StructInfo Struct = getStruct(DL);
+      OS << "  Buffer Stride: " << Struct.Stride << "\n";
+      OS << "  Alignment: " << Struct.AlignLog2 << "\n";
+    } else if (isTyped()) {
+      TypedInfo Typed = ...
[truncated]

bogner added a commit to bogner/llvm-project that referenced this pull request Dec 12, 2024
@bogner bogner linked an issue Dec 12, 2024 that may be closed by this pull request
@bogner bogner force-pushed the 2024-12-12-split-resource-analysis branch from 7bf113d to aefabdf Compare December 16, 2024 18:49
@bogner bogner changed the base branch from users/bogner/119772 to main December 16, 2024 23:07
This splits the DXILResourceAnalysis pass into TypeAnalysis and
BindingAnalysis passes. The type analysis pass is made immutable and
populated lazily so that it can be used earlier in the pipeline without
needing to carefully maintain the invariants of the binding analysis.

Fixes llvm#118400
@bogner bogner force-pushed the 2024-12-12-split-resource-analysis branch from 801f94b to 305a0f5 Compare December 16, 2024 23:13
bogner added a commit to bogner/llvm-project that referenced this pull request Dec 16, 2024
Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@bogner bogner merged commit 3eca15c into llvm:main Dec 18, 2024
8 checks passed
bogner added a commit to bogner/llvm-project that referenced this pull request Dec 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

4 participants