Skip to content

[DirectX] Introduce the DXILResourceAccess pass #116726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 18, 2024

Conversation

bogner
Copy link
Contributor

@bogner bogner commented Nov 19, 2024

This pass transforms resource access via llvm.dx.resource.getpointer into buffer loads and stores.

Fixes #114848.

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Justin Bogner (bogner)

Changes

This pass transforms resource access via llvm.dx.resource.getpointer into buffer loads and stores.

Fixes #114848.


Patch is 24.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116726.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.cpp (+196)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.h (+28)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+7)
  • (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+4-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+3)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll (+35)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll (+103)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f05450..cd2ea3e07ee5b5 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -275,6 +275,9 @@ class DXILResourceMap {
   DXILResourceMap(
       SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
 
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+
   iterator begin() { return Resources.begin(); }
   const_iterator begin() const { return Resources.begin(); }
   iterator end() { return Resources.end(); }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 48a9595f844f05..0d324f541d7663 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+def int_dx_resource_getpointer
+    : DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
+                            [IntrNoMem]>;
 def int_dx_typedBufferLoad
     : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
                             [IntrReadMem]>;
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690d..44909376928d65 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -744,6 +744,12 @@ DXILResourceMap::DXILResourceMap(
   }
 }
 
+bool DXILResourceMap::invalidate(Module &M, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &Inv) {
+  auto PAC = PA.getChecker<DXILResourceAnalysis>();
+  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>());
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index a726071e0dcecd..26315db891b577 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
   DXILPrettyPrinter.cpp
   DXILResource.cpp
   DXILResourceAnalysis.cpp
+  DXILResourceAccess.cpp
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
 
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
new file mode 100644
index 00000000000000..f9b28800b74909
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
@@ -0,0 +1,196 @@
+//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILResourceAccess.h"
+#include "DirectX.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/InitializePasses.h"
+
+#define DEBUG_TYPE "dxil-resource-access"
+
+using namespace llvm;
+
+static void replaceTypedBufferAccess(IntrinsicInst *II,
+                                     dxil::ResourceInfo &RI) {
+  const DataLayout &DL = II->getDataLayout();
+
+  auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
+  assert(HandleType->getName() == "dx.TypedBuffer" &&
+         "Unexpected typed buffer type");
+  Type *ContainedType = HandleType->getTypeParameter(0);
+  Type *ScalarType = ContainedType->getScalarType();
+  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
+  int NumElements = ContainedType->getNumContainedTypes();
+  if (!NumElements)
+    NumElements = 1;
+
+  // Process users keeping track of indexing accumulated from GEPs.
+  struct AccessAndIndex {
+    User *Access;
+    Value *Index;
+  };
+  SmallVector<AccessAndIndex> Worklist;
+  for (User *U : II->users())
+    Worklist.push_back({U, nullptr});
+
+  SmallVector<Instruction *> DeadInsts;
+  while (!Worklist.empty()) {
+    AccessAndIndex Current = Worklist.back();
+    Worklist.pop_back();
+
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
+      IRBuilder<> Builder(GEP);
+
+      Value *Index;
+      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
+        APInt Scaled = ConstantOffset.udiv(ScalarSize);
+        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
+      } else {
+        auto IndexIt = GEP->idx_begin();
+        assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
+               "GEP is not indexing through pointer");
+        ++IndexIt;
+        Index = *IndexIt;
+        assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
+      }
+
+      for (User *U : GEP->users())
+        Worklist.push_back({U, Index});
+      DeadInsts.push_back(GEP);
+
+    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
+      assert(SI->getValueOperand() != II && "Pointer escaped!");
+      IRBuilder<> Builder(SI);
+
+      Value *V = SI->getValueOperand();
+      if (V->getType() == ContainedType) {
+        // V is already the right type.
+      } else if (V->getType() == ScalarType) {
+        // We're storing a scalar, so we need to load the current value and only
+        // replace the relevant part.
+        auto *Load = Builder.CreateIntrinsic(
+            ContainedType, Intrinsic::dx_typedBufferLoad,
+            {II->getOperand(0), II->getOperand(1)});
+        // If we have an offset from seeing a GEP earlier, use it.
+        Value *IndexOp = Current.Index
+                             ? Current.Index
+                             : ConstantInt::get(Builder.getInt32Ty(), 0);
+        V = Builder.CreateInsertElement(Load, V, IndexOp);
+      } else {
+        llvm_unreachable("Store to typed resource has invalid type");
+      }
+
+      auto *Inst = Builder.CreateIntrinsic(
+          Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
+          {II->getOperand(0), II->getOperand(1), V});
+      SI->replaceAllUsesWith(Inst);
+      DeadInsts.push_back(SI);
+
+    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
+      IRBuilder<> Builder(LI);
+      Value *V =
+          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
+                                  {II->getOperand(0), II->getOperand(1)});
+      if (Current.Index)
+        V = Builder.CreateExtractElement(V, Current.Index);
+
+      LI->replaceAllUsesWith(V);
+      DeadInsts.push_back(LI);
+
+    } else
+      llvm_unreachable("Unhandled instruction - pointer escaped?");
+  }
+
+  // Traverse the now-dead instructions in RPO and remove them.
+  for (Instruction *Dead : llvm::reverse(DeadInsts))
+    Dead->eraseFromParent();
+  II->eraseFromParent();
+}
+
+static bool transformResourcePointers(Function &F, DXILResourceMap &DRM) {
+  // TODO: Should we have a more efficient way to find resources used in a
+  // particular function?
+  SmallVector<std::pair<IntrinsicInst *, dxil::ResourceInfo &>> Resources;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (auto *CI = dyn_cast<CallInst>(&I)) {
+        auto It = DRM.find(CI);
+        if (It == DRM.end())
+          continue;
+        for (User *U : CI->users())
+          if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U))
+            if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
+              Resources.emplace_back(II, *It);
+      }
+
+  for (const auto &[II, RI] : Resources) {
+    if (RI.isTyped())
+      replaceTypedBufferAccess(II, RI);
+
+    // TODO: handle other resource types. We should probably have an
+    // `unreachable` here once we've added support for all of them.
+  }
+
+  return false;
+}
+
+PreservedAnalyses DXILResourceAccess::run(Function &F,
+                                          FunctionAnalysisManager &FAM) {
+  auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  DXILResourceMap *DRM =
+      MAMProxy.getCachedResult<DXILResourceAnalysis>(*F.getParent());
+  assert(DRM && "DXILResourceAnalysis must be available");
+
+  bool MadeChanges = transformResourcePointers(F, *DRM);
+  if (!MadeChanges)
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  PA.preserve<DominatorTreeAnalysis>();
+  return PA;
+}
+
+namespace {
+class DXILResourceAccessLegacy : public FunctionPass {
+public:
+  bool runOnFunction(Function &F) override {
+    DXILResourceMap &DRM =
+        getAnalysis<DXILResourceWrapperPass>().getResourceMap();
+
+    return transformResourcePointers(F, DRM);
+  }
+  StringRef getPassName() const override { return "DXIL Resource Access"; }
+  DXILResourceAccessLegacy() : FunctionPass(ID) {}
+
+  static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    AU.addRequired<DXILResourceWrapperPass>();
+    AU.addPreserved<DXILResourceWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+  }
+};
+char DXILResourceAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
+                      "DXIL Resource Access", false, false)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
+INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
+                    "DXIL Resource Access", false, false)
+
+FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
+  return new DXILResourceAccessLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.h b/llvm/lib/Target/DirectX/DXILResourceAccess.h
new file mode 100644
index 00000000000000..ac47db21266f64
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.h
@@ -0,0 +1,28 @@
+//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file Pass for replacing pointers to DXIL resources with load and store
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILResourceAccess: public PassInfoMixin<DXILResourceAccess> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3454f16ecd5955..add23587de7d58 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -12,6 +12,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
 
 namespace llvm {
+class FunctionPass;
 class ModulePass;
 class PassRegistry;
 class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 /// Pass to lowering LLVM intrinsic call to DXIL op function call.
 ModulePass *createDXILOpLoweringLegacyPass();
 
+/// Initializer for DXILResourceAccess
+void initializeDXILResourceAccessLegacyPass(PassRegistry &);
+
+/// Pass to update resource accesses to use load/store directly.
+FunctionPass *createDXILResourceAccessLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index a0f864ed39375f..87591b104ce52c 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
 // TODO: rename to print<foo> after NPM switch
 MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
 #undef MODULE_PASS
+
+#ifndef FUNCTION_PASS
+#define FUNCTION_PASS(NAME, CREATE_PASS)
+#endif
+FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
+#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 655427a3e80209..9dade16ffe2732 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -17,6 +17,7 @@
 #include "DXILIntrinsicExpansion.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
+#include "DXILResourceAccess.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeWriteDXILPassPass(*PR);
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
+  initializeDXILResourceAccessLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -91,9 +93,10 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
+    addPass(createDXILFlattenArraysLegacyPass());
+    addPass(createDXILResourceAccessLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
-    addPass(createDXILFlattenArraysLegacyPass());
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILFinalizeLinkageLegacyPass());
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 03d069c9fcb36d..9341bc8bc02de6 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Argument.h"
@@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<DXILResourceWrapperPass>();
 }
 
 char ScalarizerLegacyPass::ID = 0;
@@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
   bool Changed = Impl.visit(F);
   PreservedAnalyses PA;
   PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<DXILResourceAnalysis>();
   return Changed ? PA : PreservedAnalyses::all();
 }
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
new file mode 100644
index 00000000000000..2c17ec674632ba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
@@ -0,0 +1,35 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @use_float4(<4 x float>)
+declare void @use_float(<4 x float>)
+
+; CHECK-LABEL: define void @load_float4
+define void @load_float4(i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  %vec_data = load <4 x float>, ptr %ptr
+  call void @use_float4(<4 x float> %vec_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 4
+  %y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 4
+  %y_data = load float, ptr %y_ptr
+  call void @use_float(float %y_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  %dyndata = load float, ptr %dynamic
+  call void @use_float(float %dyndata)
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
new file mode 100644
index 00000000000000..dd63acc3c0e96c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
@@ -0,0 +1,103 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; CHECK-LABEL: define void @store_float4
+define void @store_float4(<4 x float> %data, i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; Store the whole value
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %data)
+  store <4 x float> %data, ptr %ptr
+
+  ; Store just the .x component
+  %scalar = extractelement <4 x float> %data, i32 0
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 0
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  store float %scalar, ptr %ptr
+
+  ; Store just the .y component
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 1
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %y_ptr = getelementptr inbounds i8, ptr %ptr, i32 4
+  store float %scalar, ptr %y_ptr
+
+  ; Store to one of the elements dynamically
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 %elemindex
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  store float %scalar, ptr %dynamic
+
+  ret void
+}
+
+; CHECK-LABEL: define void @store_half4
+define void @store_half4(<4 x half> %data, i32 %index) {
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
+      @llvm....
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

Changes

This pass transforms resource access via llvm.dx.resource.getpointer into buffer loads and stores.

Fixes #114848.


Patch is 24.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116726.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.cpp (+196)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.h (+28)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+7)
  • (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+4-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+3)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll (+35)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll (+103)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f05450..cd2ea3e07ee5b5 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -275,6 +275,9 @@ class DXILResourceMap {
   DXILResourceMap(
       SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
 
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+
   iterator begin() { return Resources.begin(); }
   const_iterator begin() const { return Resources.begin(); }
   iterator end() { return Resources.end(); }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 48a9595f844f05..0d324f541d7663 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+def int_dx_resource_getpointer
+    : DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
+                            [IntrNoMem]>;
 def int_dx_typedBufferLoad
     : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
                             [IntrReadMem]>;
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690d..44909376928d65 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -744,6 +744,12 @@ DXILResourceMap::DXILResourceMap(
   }
 }
 
+bool DXILResourceMap::invalidate(Module &M, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &Inv) {
+  auto PAC = PA.getChecker<DXILResourceAnalysis>();
+  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>());
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index a726071e0dcecd..26315db891b577 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
   DXILPrettyPrinter.cpp
   DXILResource.cpp
   DXILResourceAnalysis.cpp
+  DXILResourceAccess.cpp
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
 
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
new file mode 100644
index 00000000000000..f9b28800b74909
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
@@ -0,0 +1,196 @@
+//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILResourceAccess.h"
+#include "DirectX.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/InitializePasses.h"
+
+#define DEBUG_TYPE "dxil-resource-access"
+
+using namespace llvm;
+
+static void replaceTypedBufferAccess(IntrinsicInst *II,
+                                     dxil::ResourceInfo &RI) {
+  const DataLayout &DL = II->getDataLayout();
+
+  auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
+  assert(HandleType->getName() == "dx.TypedBuffer" &&
+         "Unexpected typed buffer type");
+  Type *ContainedType = HandleType->getTypeParameter(0);
+  Type *ScalarType = ContainedType->getScalarType();
+  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
+  int NumElements = ContainedType->getNumContainedTypes();
+  if (!NumElements)
+    NumElements = 1;
+
+  // Process users keeping track of indexing accumulated from GEPs.
+  struct AccessAndIndex {
+    User *Access;
+    Value *Index;
+  };
+  SmallVector<AccessAndIndex> Worklist;
+  for (User *U : II->users())
+    Worklist.push_back({U, nullptr});
+
+  SmallVector<Instruction *> DeadInsts;
+  while (!Worklist.empty()) {
+    AccessAndIndex Current = Worklist.back();
+    Worklist.pop_back();
+
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
+      IRBuilder<> Builder(GEP);
+
+      Value *Index;
+      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
+        APInt Scaled = ConstantOffset.udiv(ScalarSize);
+        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
+      } else {
+        auto IndexIt = GEP->idx_begin();
+        assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
+               "GEP is not indexing through pointer");
+        ++IndexIt;
+        Index = *IndexIt;
+        assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
+      }
+
+      for (User *U : GEP->users())
+        Worklist.push_back({U, Index});
+      DeadInsts.push_back(GEP);
+
+    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
+      assert(SI->getValueOperand() != II && "Pointer escaped!");
+      IRBuilder<> Builder(SI);
+
+      Value *V = SI->getValueOperand();
+      if (V->getType() == ContainedType) {
+        // V is already the right type.
+      } else if (V->getType() == ScalarType) {
+        // We're storing a scalar, so we need to load the current value and only
+        // replace the relevant part.
+        auto *Load = Builder.CreateIntrinsic(
+            ContainedType, Intrinsic::dx_typedBufferLoad,
+            {II->getOperand(0), II->getOperand(1)});
+        // If we have an offset from seeing a GEP earlier, use it.
+        Value *IndexOp = Current.Index
+                             ? Current.Index
+                             : ConstantInt::get(Builder.getInt32Ty(), 0);
+        V = Builder.CreateInsertElement(Load, V, IndexOp);
+      } else {
+        llvm_unreachable("Store to typed resource has invalid type");
+      }
+
+      auto *Inst = Builder.CreateIntrinsic(
+          Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
+          {II->getOperand(0), II->getOperand(1), V});
+      SI->replaceAllUsesWith(Inst);
+      DeadInsts.push_back(SI);
+
+    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
+      IRBuilder<> Builder(LI);
+      Value *V =
+          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
+                                  {II->getOperand(0), II->getOperand(1)});
+      if (Current.Index)
+        V = Builder.CreateExtractElement(V, Current.Index);
+
+      LI->replaceAllUsesWith(V);
+      DeadInsts.push_back(LI);
+
+    } else
+      llvm_unreachable("Unhandled instruction - pointer escaped?");
+  }
+
+  // Traverse the now-dead instructions in RPO and remove them.
+  for (Instruction *Dead : llvm::reverse(DeadInsts))
+    Dead->eraseFromParent();
+  II->eraseFromParent();
+}
+
+static bool transformResourcePointers(Function &F, DXILResourceMap &DRM) {
+  // TODO: Should we have a more efficient way to find resources used in a
+  // particular function?
+  SmallVector<std::pair<IntrinsicInst *, dxil::ResourceInfo &>> Resources;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (auto *CI = dyn_cast<CallInst>(&I)) {
+        auto It = DRM.find(CI);
+        if (It == DRM.end())
+          continue;
+        for (User *U : CI->users())
+          if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U))
+            if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
+              Resources.emplace_back(II, *It);
+      }
+
+  for (const auto &[II, RI] : Resources) {
+    if (RI.isTyped())
+      replaceTypedBufferAccess(II, RI);
+
+    // TODO: handle other resource types. We should probably have an
+    // `unreachable` here once we've added support for all of them.
+  }
+
+  return false;
+}
+
+PreservedAnalyses DXILResourceAccess::run(Function &F,
+                                          FunctionAnalysisManager &FAM) {
+  auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  DXILResourceMap *DRM =
+      MAMProxy.getCachedResult<DXILResourceAnalysis>(*F.getParent());
+  assert(DRM && "DXILResourceAnalysis must be available");
+
+  bool MadeChanges = transformResourcePointers(F, *DRM);
+  if (!MadeChanges)
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  PA.preserve<DominatorTreeAnalysis>();
+  return PA;
+}
+
+namespace {
+class DXILResourceAccessLegacy : public FunctionPass {
+public:
+  bool runOnFunction(Function &F) override {
+    DXILResourceMap &DRM =
+        getAnalysis<DXILResourceWrapperPass>().getResourceMap();
+
+    return transformResourcePointers(F, DRM);
+  }
+  StringRef getPassName() const override { return "DXIL Resource Access"; }
+  DXILResourceAccessLegacy() : FunctionPass(ID) {}
+
+  static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    AU.addRequired<DXILResourceWrapperPass>();
+    AU.addPreserved<DXILResourceWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+  }
+};
+char DXILResourceAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
+                      "DXIL Resource Access", false, false)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
+INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
+                    "DXIL Resource Access", false, false)
+
+FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
+  return new DXILResourceAccessLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.h b/llvm/lib/Target/DirectX/DXILResourceAccess.h
new file mode 100644
index 00000000000000..ac47db21266f64
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.h
@@ -0,0 +1,28 @@
+//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file Pass for replacing pointers to DXIL resources with load and store
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILResourceAccess: public PassInfoMixin<DXILResourceAccess> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3454f16ecd5955..add23587de7d58 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -12,6 +12,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
 
 namespace llvm {
+class FunctionPass;
 class ModulePass;
 class PassRegistry;
 class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 /// Pass to lowering LLVM intrinsic call to DXIL op function call.
 ModulePass *createDXILOpLoweringLegacyPass();
 
+/// Initializer for DXILResourceAccess
+void initializeDXILResourceAccessLegacyPass(PassRegistry &);
+
+/// Pass to update resource accesses to use load/store directly.
+FunctionPass *createDXILResourceAccessLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index a0f864ed39375f..87591b104ce52c 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
 // TODO: rename to print<foo> after NPM switch
 MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
 #undef MODULE_PASS
+
+#ifndef FUNCTION_PASS
+#define FUNCTION_PASS(NAME, CREATE_PASS)
+#endif
+FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
+#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 655427a3e80209..9dade16ffe2732 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -17,6 +17,7 @@
 #include "DXILIntrinsicExpansion.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
+#include "DXILResourceAccess.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeWriteDXILPassPass(*PR);
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
+  initializeDXILResourceAccessLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -91,9 +93,10 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
+    addPass(createDXILFlattenArraysLegacyPass());
+    addPass(createDXILResourceAccessLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
-    addPass(createDXILFlattenArraysLegacyPass());
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILFinalizeLinkageLegacyPass());
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 03d069c9fcb36d..9341bc8bc02de6 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Argument.h"
@@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<DXILResourceWrapperPass>();
 }
 
 char ScalarizerLegacyPass::ID = 0;
@@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
   bool Changed = Impl.visit(F);
   PreservedAnalyses PA;
   PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<DXILResourceAnalysis>();
   return Changed ? PA : PreservedAnalyses::all();
 }
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
new file mode 100644
index 00000000000000..2c17ec674632ba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
@@ -0,0 +1,35 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @use_float4(<4 x float>)
+declare void @use_float(<4 x float>)
+
+; CHECK-LABEL: define void @load_float4
+define void @load_float4(i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  %vec_data = load <4 x float>, ptr %ptr
+  call void @use_float4(<4 x float> %vec_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 4
+  %y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 4
+  %y_data = load float, ptr %y_ptr
+  call void @use_float(float %y_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  %dyndata = load float, ptr %dynamic
+  call void @use_float(float %dyndata)
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
new file mode 100644
index 00000000000000..dd63acc3c0e96c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
@@ -0,0 +1,103 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; CHECK-LABEL: define void @store_float4
+define void @store_float4(<4 x float> %data, i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; Store the whole value
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %data)
+  store <4 x float> %data, ptr %ptr
+
+  ; Store just the .x component
+  %scalar = extractelement <4 x float> %data, i32 0
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 0
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  store float %scalar, ptr %ptr
+
+  ; Store just the .y component
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 1
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %y_ptr = getelementptr inbounds i8, ptr %ptr, i32 4
+  store float %scalar, ptr %y_ptr
+
+  ; Store to one of the elements dynamically
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 %elemindex
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  store float %scalar, ptr %dynamic
+
+  ret void
+}
+
+; CHECK-LABEL: define void @store_half4
+define void @store_half4(<4 x half> %data, i32 %index) {
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
+      @llvm....
[truncated]

Copy link

github-actions bot commented Nov 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of questions about the tests, otherwise LGTM!

Copy link
Contributor

@damyanp damyanp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

bogner added a commit to bogner/llvm-project that referenced this pull request Nov 27, 2024
Including this removes the PR's dependency on llvm#116726
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 765ddaf to f12a6bf Compare December 12, 2024 22:18
@bogner bogner changed the base branch from main to users/bogner/119773 December 12, 2024 22:29
@bogner bogner force-pushed the 2024-11-18-resource-access branch from f12a6bf to 958f7ec Compare December 12, 2024 22:49
@bogner bogner force-pushed the users/bogner/119773 branch 3 times, most recently from f8291d2 to 4f018d3 Compare December 16, 2024 23:15
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 0ee6dbf to 2171296 Compare December 16, 2024 23:18
@bogner bogner changed the base branch from users/bogner/119773 to main December 18, 2024 16:03
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 2171296 to 4666082 Compare December 18, 2024 16:06
@bogner bogner merged commit 0fca76d into llvm:main Dec 18, 2024
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[DirectX] Replace resource accesses in the device address space to typedbuffer load and store
7 participants