Skip to content

[flang][OpenMP] Support target enter|update|exit .. nowait #113305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9672,8 +9672,8 @@ static void emitTargetCallKernelLaunch(
DynCGGroupMem, HasNoWait);

CGF.Builder.restoreIP(OMPRuntime->getOMPBuilder().emitKernelLaunch(
CGF.Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, Args,
DeviceID, RTLoc, AllocaIP));
CGF.Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args, DeviceID,
RTLoc, AllocaIP));
};

if (RequiresOuterTask)
Expand Down
39 changes: 25 additions & 14 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2264,6 +2264,9 @@ class OpenMPIRBuilder {

bool EmitDebug = false;

/// Whether the `target ... data` directive has a `nowait` clause.
bool HasNoWait = false;

explicit TargetDataInfo() {}
explicit TargetDataInfo(bool RequiresDevicePointerInfo,
bool SeparateBeginEndCalls)
Expand Down Expand Up @@ -2342,26 +2345,34 @@ class OpenMPIRBuilder {
/// Generate a target region entry call and host fallback call.
///
/// \param Loc The location at which the request originated and is fulfilled.
/// \param OutlinedFn The outlined kernel function.
/// \param OutlinedFnID The ooulined function ID.
/// \param EmitTargetCallFallbackCB Call back function to generate host
/// fallback code.
/// \param Args Data structure holding information about the kernel arguments.
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
InsertPointTy emitKernelLaunch(
const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP);
InsertPointTy
emitKernelLaunch(const LocationDescription &Loc, Value *OutlinedFnID,
EmitFallbackCallbackTy EmitTargetCallFallbackCB,
TargetKernelArgs &Args, Value *DeviceID, Value *RTLoc,
InsertPointTy AllocaIP);

/// Callback type for generating the bodies of device directives that require
/// outer target tasks (e.g. in case of having `nowait` or `depend` clauses).
///
/// \param DeviceID The ID of the device on which the target region will
/// execute.
/// \param RTLoc Source location identifier
/// \Param TargetTaskAllocaIP Insertion point for the alloca block of the
/// generated task.
using TargetTaskBodyCallbackTy =
function_ref<void(Value *DeviceID, Value *RTLoc,
IRBuilderBase::InsertPoint TargetTaskAllocaIP)>;

/// Generate a target-task for the target construct
///
/// \param OutlinedFn The outlined device/target kernel function.
/// \param OutlinedFnID The ooulined function ID.
/// \param EmitTargetCallFallbackCB Call back function to generate host
/// fallback code.
/// \param Args Data structure holding information about the kernel arguments.
/// \param TaskBodyCB Callback to generate the actual body of the target task.
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
Expand All @@ -2370,10 +2381,10 @@ class OpenMPIRBuilder {
/// \param HasNoWait True if the target construct had 'nowait' on it, false
/// otherwise
InsertPointTy emitTargetTask(
Function *OutlinedFn, Value *OutlinedFnID,
EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP,
SmallVector<OpenMPIRBuilder::DependData> &Dependencies, bool HasNoWait);
TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
bool HasNoWait);

/// Emit the arguments to be passed to the runtime library based on the
/// arrays of base pointers, pointers, sizes, map types, and mappers. If
Expand Down
125 changes: 82 additions & 43 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
const LocationDescription &Loc, Value *OutlinedFnID,
EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

if (!updateToLocation(Loc))
Expand Down Expand Up @@ -1134,7 +1134,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(

auto CurFn = Builder.GetInsertBlock()->getParent();
emitBlock(OffloadFailedBlock, CurFn);
Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
emitBranch(OffloadContBlock);
emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
return Builder.saveIP();
Expand Down Expand Up @@ -1736,7 +1736,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
// - All code is inserted in the entry block of the current function.
static Value *emitTaskDependencies(
OpenMPIRBuilder &OMPBuilder,
SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
// Early return if we have no dependencies to process
if (Dependencies.empty())
return nullptr;
Expand Down Expand Up @@ -6403,16 +6403,44 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
}

Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
PointerNum, RTArgs.BasePointersArray,
RTArgs.PointersArray, RTArgs.SizesArray,
RTArgs.MapTypesArray, RTArgs.MapNamesArray,
RTArgs.MappersArray};
SmallVector<llvm::Value *, 13> OffloadingArgs = {
SrcLocInfo, DeviceID,
PointerNum, RTArgs.BasePointersArray,
RTArgs.PointersArray, RTArgs.SizesArray,
RTArgs.MapTypesArray, RTArgs.MapNamesArray,
RTArgs.MappersArray};

if (IsStandAlone) {
assert(MapperFunc && "MapperFunc missing for standalone target data");
Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
OffloadingArgs);

auto TaskBodyCB = [&](Value *, Value *, IRBuilderBase::InsertPoint) {
if (Info.HasNoWait) {
OffloadingArgs.append({llvm::Constant::getNullValue(Int32),
llvm::Constant::getNullValue(VoidPtr),
llvm::Constant::getNullValue(Int32),
llvm::Constant::getNullValue(VoidPtr)});
}

Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
OffloadingArgs);

if (Info.HasNoWait) {
BasicBlock *OffloadContBlock =
BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
Function *CurFn = Builder.GetInsertBlock()->getParent();
emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
Builder.restoreIP(Builder.saveIP());
}
};

bool RequiresOuterTargetTask = Info.HasNoWait;

if (!RequiresOuterTargetTask)
TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
/*TargetTaskAllocaIP=*/{});
else
emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
/*Dependencies=*/{}, Info.HasNoWait);
} else {
Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
omp::OMPRTL___tgt_target_data_begin_mapper);
Expand Down Expand Up @@ -6836,13 +6864,18 @@ static void emitTargetOutlinedFunction(
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
IsOffloadEntry, OutlinedFn, OutlinedFnID);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
Function *OutlinedFn, Value *OutlinedFnID,
EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
bool HasNoWait) {

// The following explains the code-gen scenario for the `target` directive. A
// similar scneario is followed for other device-related directives (e.g.
// `target enter data`) but in similar fashion since we only need to emit task
// that encapsulates the proper runtime call.
//
// When we arrive at this function, the target region itself has been
// outlined into the function OutlinedFn.
// So at ths point, for
Expand Down Expand Up @@ -6950,22 +6983,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(

Builder.restoreIP(TargetTaskBodyIP);

if (OutlinedFnID) {
// emitKernelLaunch makes the necessary runtime call to offload the kernel.
// We then outline all that code into a separate function
// ('kernel_launch_function' in the pseudo code above). This function is
// then called by the target task proxy function (see
// '@.omp_target_task_proxy_func' in the pseudo code above)
// "@.omp_target_task_proxy_func' is generated by
// emitTargetTaskProxyFunction.
Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
EmitTargetCallFallbackCB, Args, DeviceID,
RTLoc, TargetTaskAllocaIP));
} else {
// When OutlinedFnID is set to nullptr, then it's not an offloading call. In
// this case, we execute the host implementation directly.
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP);

OI.ExitBB = Builder.saveIP().getBlock();
OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, HasNoWait,
Expand Down Expand Up @@ -7153,18 +7171,40 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

OpenMPIRBuilder::TargetKernelArgs KArgs;

auto TaskBodyCB = [&](Value *DeviceID, Value *RTLoc,
IRBuilderBase::InsertPoint TargetTaskAllocaIP) {
if (OutlinedFnID) {
// emitKernelLaunch makes the necessary runtime call to offload the
// kernel. We then outline all that code into a separate function
// ('kernel_launch_function' in the pseudo code above). This function is
// then called by the target task proxy function (see
// '@.omp_target_task_proxy_func' in the pseudo code above)
// "@.omp_target_task_proxy_func' is generated by
// emitTargetTaskProxyFunction.
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, TargetTaskAllocaIP));
} else {
// When OutlinedFnID is set to nullptr, then it's not an offloading
// call. In this case, we execute the host implementation directly.
OMPBuilder.Builder.restoreIP(
EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP()));
}
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
if (RequiresOuterTargetTask) {
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
// results in that call not being done.
OpenMPIRBuilder::TargetKernelArgs KArgs;
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, /*OutlinedFnID=*/nullptr, EmitTargetCallFallbackCB, KArgs,
/*DeviceID=*/nullptr, /*RTLoc=*/nullptr, AllocaIP, Dependencies,
HasNoWait));
Builder.restoreIP(OMPBuilder.emitTargetTask(TaskBodyCB,
/*DeviceID=*/nullptr,
/*RTLoc=*/nullptr, AllocaIP,
Dependencies, HasNoWait));
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
Expand Down Expand Up @@ -7201,20 +7241,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
NumTeamsC, NumThreadsC, DynCGGroupMem,
HasNoWait);
KArgs = OpenMPIRBuilder::TargetKernelArgs(
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, AllocaIP, Dependencies, HasNoWait));
TaskBodyCB, DeviceID, RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, RTLoc,
AllocaIP));
}
}

Expand Down
34 changes: 22 additions & 12 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
/*SeparateBeginEndCalls=*/true);

LogicalResult result =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
Expand All @@ -2905,9 +2907,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return success();
})
.Case([&](omp::TargetEnterDataOp enterDataOp) {
if (enterDataOp.getNowait())
if (!enterDataOp.getDependVars().empty())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling of depend looks unrelated to handling nowait. Since it reuses code, I don't mind in this case.

return (LogicalResult)(enterDataOp.emitError(
"`nowait` is not supported yet"));
"`depend` is not supported yet"));

if (auto ifVar = enterDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
Expand All @@ -2917,14 +2919,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
RTLFn =
enterDataOp.getNowait()
? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
: llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
mapVars = enterDataOp.getMapVars();
info.HasNoWait = enterDataOp.getNowait();
return success();
})
.Case([&](omp::TargetExitDataOp exitDataOp) {
if (exitDataOp.getNowait())
if (!exitDataOp.getDependVars().empty())
return (LogicalResult)(exitDataOp.emitError(
"`nowait` is not supported yet"));
"`depend` is not supported yet"));

if (auto ifVar = exitDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
Expand All @@ -2935,14 +2941,17 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();

RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
RTLFn = exitDataOp.getNowait()
? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
: llvm::omp::OMPRTL___tgt_target_data_end_mapper;
mapVars = exitDataOp.getMapVars();
info.HasNoWait = exitDataOp.getNowait();
return success();
})
.Case([&](omp::TargetUpdateOp updateDataOp) {
if (updateDataOp.getNowait())
if (!updateDataOp.getDependVars().empty())
return (LogicalResult)(updateDataOp.emitError(
"`nowait` is not supported yet"));
"`depend` is not supported yet"));

if (auto ifVar = updateDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
Expand All @@ -2953,8 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();

RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
RTLFn =
updateDataOp.getNowait()
? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
: llvm::omp::OMPRTL___tgt_target_data_update_mapper;
mapVars = updateDataOp.getMapVars();
info.HasNoWait = updateDataOp.getNowait();
return success();
})
.Default([&](Operation *op) {
Expand Down Expand Up @@ -3005,9 +3018,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
: basePointer);
};

llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
/*SeparateBeginEndCalls=*/true);

using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
LogicalResult bodyGenStatus = success();
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
Expand Down
Loading
Loading