Skip to content

[DirectX] Introduce the DXILResourceAccess pass #116726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
DXILPrettyPrinter.cpp
DXILResource.cpp
DXILResourceAnalysis.cpp
DXILResourceAccess.cpp
DXILShaderFlags.cpp
DXILTranslateMetadata.cpp

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,14 @@ class OpLowerer {
});
}

/// Drop the now-unused `dx.resource.getpointer` declaration \p F.
///
/// Every call is expected to have been rewritten by DXILResourceAccess before
/// this point, which is enforced by the assertion below.
///
/// \returns false unconditionally, i.e. this step never reports an error.
[[nodiscard]] bool lowerGetPointer(Function &F) {
  assert(F.user_empty() && "getpointer operations should have been removed");

  // Only the dead prototype remains, so removing it is all that is left to do.
  F.eraseFromParent();

  constexpr bool HasError = false;
  return HasError;
}

[[nodiscard]] bool lowerTypedBufferStore(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int8Ty = IRB.getInt8Ty();
Expand Down Expand Up @@ -707,6 +715,9 @@ class OpLowerer {
case Intrinsic::dx_handle_fromBinding:
HasErrors |= lowerHandleFromBinding(F);
break;
case Intrinsic::dx_resource_getpointer:
HasErrors |= lowerGetPointer(F);
break;
case Intrinsic::dx_typedBufferLoad:
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
break;
Expand Down
192 changes: 192 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILResourceAccess.h"
#include "DirectX.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "dxil-resource-access"

using namespace llvm;

/// Rewrite every access made through the pointer produced by the
/// `dx.resource.getpointer` call \p II into typed buffer intrinsics, then
/// erase \p II together with the loads, stores and GEPs that hung off it.
///
/// The users of \p II are walked with a worklist:
///  - a GEP contributes an element index (a constant derived from its byte
///    offset, or its second index operand when the offset is not constant) and
///    its own users are queued with that index;
///  - a load becomes a `dx.typedBufferLoad`, followed by an `extractelement`
///    when an index was recorded;
///  - a store becomes a `dx.typedBufferStore`; a store of a single scalar first
///    loads the current element and inserts the scalar at the recorded index.
/// Any other user is treated as an escaped pointer and is unreachable.
///
/// \param II  The `dx.resource.getpointer` call; operand 0 is the
///            "dx.TypedBuffer" handle and operand 1 the element index.
/// \param RTI Resource type information for the handle. Currently unused by
///            this function.
static void replaceTypedBufferAccess(IntrinsicInst *II,
                                     dxil::ResourceTypeInfo &RTI) {
  const DataLayout &DL = II->getDataLayout();

  Value *Handle = II->getOperand(0);
  Value *BufferIndex = II->getOperand(1);

  auto *HandleType = cast<TargetExtType>(Handle->getType());
  assert(HandleType->getName() == "dx.TypedBuffer" &&
         "Unexpected typed buffer type");
  Type *ContainedType = HandleType->getTypeParameter(0);

  // We need the size of an element in bytes so that we can calculate the offset
  // in elements given a total offset in bytes later.
  Type *ScalarType = ContainedType->getScalarType();
  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;

  // Process users keeping track of indexing accumulated from GEPs.
  struct AccessAndIndex {
    User *Access;
    Value *Index;
  };
  SmallVector<AccessAndIndex> Worklist;
  for (User *U : II->users())
    Worklist.push_back({U, nullptr});

  SmallVector<Instruction *> DeadInsts;
  while (!Worklist.empty()) {
    AccessAndIndex Current = Worklist.back();
    Worklist.pop_back();

    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
      // NOTE(review): Current.Index is not consulted here, so an index coming
      // from an enclosing GEP is dropped if GEPs are chained - confirm that
      // chained GEPs cannot reach this pass.
      IRBuilder<> Builder(GEP);

      Value *Index;
      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
        // Constant byte offset: convert it to an index in scalar elements.
        APInt Scaled = ConstantOffset.udiv(ScalarSize);
        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
      } else {
        // Validate the shape up front, before the second index is read, and
        // without mutating any state inside an assert.
        assert(GEP->getNumIndices() == 2 &&
               "Expected exactly two indices in GEP");
        auto IndexIt = GEP->idx_begin();
        assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
               "GEP is not indexing through pointer");
        ++IndexIt;
        Index = *IndexIt;
      }

      for (User *U : GEP->users())
        Worklist.push_back({U, Index});
      DeadInsts.push_back(GEP);

    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
      assert(SI->getValueOperand() != II && "Pointer escaped!");
      IRBuilder<> Builder(SI);

      Value *V = SI->getValueOperand();
      if (V->getType() == ContainedType) {
        // V is already the right type.
      } else if (V->getType() == ScalarType) {
        // We're storing a scalar, so we need to load the current value and only
        // replace the relevant part.
        auto *Load = Builder.CreateIntrinsic(
            ContainedType, Intrinsic::dx_typedBufferLoad,
            {Handle, BufferIndex});
        // If we have an offset from seeing a GEP earlier, use it.
        Value *IndexOp = Current.Index
                             ? Current.Index
                             : ConstantInt::get(Builder.getInt32Ty(), 0);
        V = Builder.CreateInsertElement(Load, V, IndexOp);
      } else {
        llvm_unreachable("Store to typed resource has invalid type");
      }

      auto *Inst = Builder.CreateIntrinsic(
          Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
          {Handle, BufferIndex, V});
      SI->replaceAllUsesWith(Inst);
      DeadInsts.push_back(SI);

    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
      IRBuilder<> Builder(LI);
      Value *V =
          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
                                  {Handle, BufferIndex});
      // Only narrow to a single component when the load went through a GEP.
      if (Current.Index)
        V = Builder.CreateExtractElement(V, Current.Index);

      LI->replaceAllUsesWith(V);
      DeadInsts.push_back(LI);

    } else
      llvm_unreachable("Unhandled instruction - pointer escaped?");
  }

  // A GEP is recorded before any of its users are even queued, so walking the
  // list backwards erases users ahead of the instructions they depend on.
  for (Instruction *Dead : llvm::reverse(DeadInsts))
    Dead->eraseFromParent();
  II->eraseFromParent();
}

/// Find every `dx.resource.getpointer` call in \p F and rewrite the accesses
/// made through it, using \p DRTM to look up the type info of each handle.
///
/// \returns true if at least one call was rewritten.
static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) {
  using PointerAndInfo = std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>;

  // Gather the calls first: the rewrite below erases instructions, which must
  // not happen while the function body is still being iterated.
  SmallVector<PointerAndInfo> Pending;
  for (BasicBlock &Block : F) {
    for (Instruction &Inst : Block) {
      auto *Call = dyn_cast<IntrinsicInst>(&Inst);
      if (!Call)
        continue;
      if (Call->getIntrinsicID() != Intrinsic::dx_resource_getpointer)
        continue;

      auto *HandleTy = cast<TargetExtType>(Call->getArgOperand(0)->getType());
      Pending.emplace_back(Call, DRTM[HandleTy]);
    }
  }

  bool Changed = false;
  for (auto &[GetPointer, TypeInfo] : Pending) {
    // TODO: handle other resource types. We should probably have an
    // `unreachable` here once we've added support for all of them.
    if (!TypeInfo.isTyped())
      continue;

    Changed = true;
    replaceTypedBufferAccess(GetPointer, TypeInfo);
  }

  return Changed;
}

/// New pass manager entry point: rewrite resource pointer accesses in \p F.
///
/// The resource type map is taken from the module analysis cache and is
/// asserted to be present. When nothing was rewritten all analyses are
/// preserved; otherwise only DXILResourceTypeAnalysis and
/// DominatorTreeAnalysis are.
PreservedAnalyses DXILResourceAccess::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
  auto &ModuleProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
  DXILResourceTypeMap *TypeMap =
      ModuleProxy.getCachedResult<DXILResourceTypeAnalysis>(*F.getParent());
  assert(TypeMap && "DXILResourceTypeAnalysis must be available");

  if (transformResourcePointers(F, *TypeMap)) {
    PreservedAnalyses Preserved;
    Preserved.preserve<DXILResourceTypeAnalysis>();
    Preserved.preserve<DominatorTreeAnalysis>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}

namespace {
/// Legacy pass manager wrapper around transformResourcePointers().
class DXILResourceAccessLegacy : public FunctionPass {
public:
  static char ID; // Pass identification.

  DXILResourceAccessLegacy() : FunctionPass(ID) {}

  StringRef getPassName() const override { return "DXIL Resource Access"; }

  /// Requires the resource type map and declares the dominator tree preserved.
  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
    AU.addRequired<DXILResourceTypeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }

  /// Rewrites resource pointer accesses in \p F; returns true on change.
  bool runOnFunction(Function &F) override {
    auto &TypePass = getAnalysis<DXILResourceTypeWrapperPass>();
    return transformResourcePointers(F, TypePass.getResourceTypeMap());
  }
};
char DXILResourceAccessLegacy::ID = 0;
} // end anonymous namespace

// Register the legacy pass under DEBUG_TYPE ("dxil-resource-access") and
// record its dependency on DXILResourceTypeWrapperPass, which runOnFunction
// queries for the resource type map.
INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)

/// Factory for the legacy pass manager version of the resource access pass.
/// The returned pass is heap-allocated and handed to the caller.
FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
  FunctionPass *LegacyPass = new DXILResourceAccessLegacy();
  return LegacyPass;
}
28 changes: 28 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass for replacing pointers to DXIL resources with load and store
/// operations.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H

#include "llvm/IR/PassManager.h"

namespace llvm {

/// New pass manager function pass that replaces accesses made through DXIL
/// resource pointers with load and store operations.
class DXILResourceAccess : public PassInfoMixin<DXILResourceAccess> {
public:
/// Run the pass over \p F, returning the set of analyses that remain valid.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
7 changes: 7 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H

namespace llvm {
class FunctionPass;
class ModulePass;
class PassRegistry;
class raw_ostream;
Expand Down Expand Up @@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
ModulePass *createDXILOpLoweringLegacyPass();

/// Initializer for DXILResourceAccess
void initializeDXILResourceAccessLegacyPass(PassRegistry &);

/// Pass to update resource accesses to use load/store directly.
FunctionPass *createDXILResourceAccessLegacyPass();

/// Initializer for DXILTranslateMetadata.
void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
// TODO: rename to print<foo> after NPM switch
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
#undef MODULE_PASS

#ifndef FUNCTION_PASS
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
#undef FUNCTION_PASS
5 changes: 4 additions & 1 deletion llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DXILOpLowering.h"
#include "DXILPrettyPrinter.h"
#include "DXILResourceAccess.h"
#include "DXILResourceAnalysis.h"
#include "DXILShaderFlags.h"
#include "DXILTranslateMetadata.h"
Expand Down Expand Up @@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeWriteDXILPassPass(*PR);
initializeDXContainerGlobalsPass(*PR);
initializeDXILOpLoweringLegacyPass(*PR);
initializeDXILResourceAccessLegacyPass(*PR);
initializeDXILTranslateMetadataLegacyPass(*PR);
initializeDXILResourceMDWrapperPass(*PR);
initializeShaderFlagsAnalysisWrapperPass(*PR);
Expand Down Expand Up @@ -92,9 +94,10 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
addPass(createDXILFlattenArraysLegacyPass());
addPass(createDXILResourceAccessLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createDXILFlattenArraysLegacyPass());
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: opt -S -dxil-resource-access %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

declare void @use_float4(<4 x float>)
declare void @use_float(float)

; CHECK-LABEL: define void @load_float4
; Exercises three ways of reading through the pointer returned by
; dx.resource.getpointer on a <4 x float> typed buffer: a whole-vector load,
; a load through a GEP with a constant element index, and a load through a GEP
; with a dynamic element index.
define void @load_float4(i32 %index, i32 %elemindex) {
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; The getpointer call itself must not survive the pass.
; CHECK-NOT: @llvm.dx.resource.getpointer
%ptr = call ptr @llvm.dx.resource.getpointer(
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)

; Whole-vector load: expected to become a single typed buffer load.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
%vec_data = load <4 x float>, ptr %ptr
call void @use_float4(<4 x float> %vec_data)

; Constant element index: typed buffer load followed by extracting lane 1.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
%y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
%y_data = load float, ptr %y_ptr
call void @use_float(float %y_data)

; Dynamic element index: the GEP index operand is used for the extract as-is.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
%dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
%dyndata = load float, ptr %dynamic
call void @use_float(float %dyndata)

ret void
}
Loading
Loading