Skip to content

[NVPTX] add an optional early copy of byval arguments #113384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

struct NVPTXCopyByValArgsPass : PassInfoMixin<NVPTXCopyByValArgsPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

namespace NVPTX {
enum DrvInterface {
NVCL,
Expand Down
74 changes: 49 additions & 25 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,33 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
PI.setAborted(&II);
}
}; // struct ArgUseChecker

void copyByValParam(Function &F, Argument &Arg) {
LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
// Otherwise we have to create a temporary copy.
BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
Type *StructType = Arg.getParamByValType();
const DataLayout &DL = F.getDataLayout();
AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
Arg.getName(), FirstInst);
// Set the alignment to alignment of the byval parameter. This is because,
// later load/stores assume that alignment, and we are going to replace
// the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
Arg.replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
&Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
Arg.getName(), FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load
// is definitely not volatile.
const auto ArgSize = *AllocA->getAllocationSize(DL);
IRBuilder<> IRB(&*FirstInst);
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
ArgSize);
}
} // namespace

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Expand All @@ -558,7 +585,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,

ArgUseChecker AUC(DL, IsGridConstant);
ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
// Easy case, accessing parameter directly is fine.
if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
// Convert all loads and intermediate operations to use parameter AS and
Expand Down Expand Up @@ -587,7 +614,6 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
// However, we're still not allowed to write to it. If the user specified
// `__grid_constant__` for the argument, we'll consider escaped pointer as
// read-only.
unsigned AS = DL.getAllocaAddrSpace();
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
// Replace all argument pointer uses (which might include a device function
Expand All @@ -612,29 +638,8 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,

// Do not replace Arg in the cast to param space
CastToParam->setOperand(0, Arg);
} else {
LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
// Otherwise we have to create a temporary copy.
AllocaInst *AllocA =
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
// Set the alignment to alignment of the byval parameter. This is because,
// later load/stores assume that alignment, and we are going to replace
// the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
Arg->replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
Arg->getName(), FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load
// is definitely not volatile.
const auto ArgSize = *AllocA->getAllocationSize(DL);
IRBuilder<> IRB(&*FirstInst);
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
ArgSize);
}
} else
copyByValParam(*Func, *Arg);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
Expand Down Expand Up @@ -734,3 +739,22 @@ bool NVPTXLowerArgs::runOnFunction(Function &F) {
}

FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }

static bool copyFunctionByValArgs(Function &F) {
LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
<< "\n");
bool Changed = false;
for (Argument &Arg : F.args())
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
!(isParamGridConstant(Arg) && isKernelFunction(F))) {
copyByValParam(F, Arg);
Changed = true;
}
return Changed;
}

PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
FunctionAnalysisManager &AM) {
return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
: PreservedAnalyses::all();
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
#endif
FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
#undef FUNCTION_PASS
25 changes: 25 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ static cl::opt<bool> UseShortPointersOpt(
"Use 32-bit pointers for accessing const/local/shared address spaces."),
cl::init(false), cl::Hidden);

// byval arguments in NVPTX are special. We're only allowed to read from them
// using a special instruction, and if we ever need to write to them or take an
// address, we must make a local copy and use it, instead.
//
// The problem is that local copies are very expensive, and we create them very
// late in the compilation pipeline, so LLVM does not have much of a chance to
// eliminate them, if they turn out to be unnecessary.
//
// One way around that is to create such copies early on, and let them percolate
// through the optimizations. The copying itself will never trigger creation of
// another copy later on, as the reads are allowed. If LLVM can eliminate it,
// it's a win. It the full optimization pipeline can't remove the copy, that's
// as good as it gets in terms of the effort we could've done, and it's
// certainly a much better effort than what we do now.
//
// This early injection of the copies has potential to create undesireable
// side-effects, so it's disabled by default, for now, until it sees more
// testing.
static cl::opt<bool> EarlyByValArgsCopy(
"nvptx-early-byval-copy",
cl::desc("Create a copy of byval function arguments early."),
cl::init(false), cl::Hidden);

namespace llvm {

void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
Expand Down Expand Up @@ -236,6 +259,8 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
// point, if issues crop up, consider disabling.
FPM.addPass(NVVMIntrRangePass());
if (EarlyByValArgsCopy)
FPM.addPass(NVPTXCopyByValArgsPass());
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
});
}
Expand Down
Loading
Loading