@@ -1080,8 +1080,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1080
1080
}
1081
1081
1082
1082
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1083
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
1084
- EmitFallbackCallbackTy emitTargetCallFallbackCB , TargetKernelArgs &Args,
1083
+ const LocationDescription &Loc, Value *OutlinedFnID,
1084
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB , TargetKernelArgs &Args,
1085
1085
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1086
1086
1087
1087
if (!updateToLocation(Loc))
@@ -1134,7 +1134,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1134
1134
1135
1135
auto CurFn = Builder.GetInsertBlock()->getParent();
1136
1136
emitBlock(OffloadFailedBlock, CurFn);
1137
- Builder.restoreIP(emitTargetCallFallbackCB (Builder.saveIP()));
1137
+ Builder.restoreIP(EmitTargetCallFallbackCB (Builder.saveIP()));
1138
1138
emitBranch(OffloadContBlock);
1139
1139
emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
1140
1140
return Builder.saveIP();
@@ -1736,7 +1736,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1736
1736
// - All code is inserted in the entry block of the current function.
1737
1737
static Value *emitTaskDependencies(
1738
1738
OpenMPIRBuilder &OMPBuilder,
1739
- SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1739
+ const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1740
1740
// Early return if we have no dependencies to process
1741
1741
if (Dependencies.empty())
1742
1742
return nullptr;
@@ -6403,16 +6403,44 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6403
6403
SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6404
6404
}
6405
6405
6406
- Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6407
- PointerNum, RTArgs.BasePointersArray,
6408
- RTArgs.PointersArray, RTArgs.SizesArray,
6409
- RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6410
- RTArgs.MappersArray};
6406
+ SmallVector<llvm::Value *, 13> OffloadingArgs = {
6407
+ SrcLocInfo, DeviceID,
6408
+ PointerNum, RTArgs.BasePointersArray,
6409
+ RTArgs.PointersArray, RTArgs.SizesArray,
6410
+ RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6411
+ RTArgs.MappersArray};
6411
6412
6412
6413
if (IsStandAlone) {
6413
6414
assert(MapperFunc && "MapperFunc missing for standalone target data");
6414
- Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6415
- OffloadingArgs);
6415
+
6416
+ auto TaskBodyCB = [&](Value *, Value *, IRBuilderBase::InsertPoint) {
6417
+ if (Info.HasNoWait) {
6418
+ OffloadingArgs.append({llvm::Constant::getNullValue(Int32),
6419
+ llvm::Constant::getNullValue(VoidPtr),
6420
+ llvm::Constant::getNullValue(Int32),
6421
+ llvm::Constant::getNullValue(VoidPtr)});
6422
+ }
6423
+
6424
+ Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6425
+ OffloadingArgs);
6426
+
6427
+ if (Info.HasNoWait) {
6428
+ BasicBlock *OffloadContBlock =
6429
+ BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
6430
+ Function *CurFn = Builder.GetInsertBlock()->getParent();
6431
+ emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
6432
+ Builder.restoreIP(Builder.saveIP());
6433
+ }
6434
+ };
6435
+
6436
+ bool RequiresOuterTargetTask = Info.HasNoWait;
6437
+
6438
+ if (!RequiresOuterTargetTask)
6439
+ TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
6440
+ /*TargetTaskAllocaIP=*/{});
6441
+ else
6442
+ emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
6443
+ /*Dependencies=*/{}, Info.HasNoWait);
6416
6444
} else {
6417
6445
Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6418
6446
omp::OMPRTL___tgt_target_data_begin_mapper);
@@ -6836,13 +6864,18 @@ static void emitTargetOutlinedFunction(
6836
6864
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
6837
6865
IsOffloadEntry, OutlinedFn, OutlinedFnID);
6838
6866
}
6867
+
6839
6868
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6840
- Function *OutlinedFn, Value *OutlinedFnID,
6841
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
6842
- Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
6843
- SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6869
+ TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
6870
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
6871
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6844
6872
bool HasNoWait) {
6845
6873
6874
+ // The following explains the code-gen scenario for the `target` directive. A
6875
+ // similar scenario is followed for other device-related directives (e.g.
6876
+ // `target enter data`) but in similar fashion since we only need to emit a task
6877
+ // that encapsulates the proper runtime call.
6878
+ //
6846
6879
// When we arrive at this function, the target region itself has been
6847
6880
// outlined into the function OutlinedFn.
6848
6881
// So at this point, for
@@ -6950,22 +6983,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6950
6983
6951
6984
Builder.restoreIP(TargetTaskBodyIP);
6952
6985
6953
- if (OutlinedFnID) {
6954
- // emitKernelLaunch makes the necessary runtime call to offload the kernel.
6955
- // We then outline all that code into a separate function
6956
- // ('kernel_launch_function' in the pseudo code above). This function is
6957
- // then called by the target task proxy function (see
6958
- // '@.omp_target_task_proxy_func' in the pseudo code above)
6959
- // "@.omp_target_task_proxy_func' is generated by
6960
- // emitTargetTaskProxyFunction.
6961
- Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
6962
- EmitTargetCallFallbackCB, Args, DeviceID,
6963
- RTLoc, TargetTaskAllocaIP));
6964
- } else {
6965
- // When OutlinedFnID is set to nullptr, then it's not an offloading call. In
6966
- // this case, we execute the host implementation directly.
6967
- Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
6968
- }
6986
+ TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP);
6969
6987
6970
6988
OI.ExitBB = Builder.saveIP().getBlock();
6971
6989
OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, HasNoWait,
@@ -7153,18 +7171,40 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7153
7171
bool HasDependencies = Dependencies.size() > 0;
7154
7172
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7155
7173
7174
+ OpenMPIRBuilder::TargetKernelArgs KArgs;
7175
+
7176
+ auto TaskBodyCB = [&](Value *DeviceID, Value *RTLoc,
7177
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP) {
7178
+ if (OutlinedFnID) {
7179
+ // emitKernelLaunch makes the necessary runtime call to offload the
7180
+ // kernel. We then outline all that code into a separate function
7181
+ // ('kernel_launch_function' in the pseudo code above). This function is
7182
+ // then called by the target task proxy function (see
7183
+ // '@.omp_target_task_proxy_func' in the pseudo code above)
7184
+ // '@.omp_target_task_proxy_func' is generated by
7185
+ // emitTargetTaskProxyFunction.
7186
+ Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7187
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7188
+ RTLoc, TargetTaskAllocaIP));
7189
+ } else {
7190
+ // When OutlinedFnID is set to nullptr, then it's not an offloading
7191
+ // call. In this case, we execute the host implementation directly.
7192
+ OMPBuilder.Builder.restoreIP(
7193
+ EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP()));
7194
+ }
7195
+ };
7196
+
7156
7197
// If we don't have an ID for the target region, it means an offload entry
7157
7198
// wasn't created. In this case we just run the host fallback directly.
7158
7199
if (!OutlinedFnID) {
7159
7200
if (RequiresOuterTargetTask) {
7160
7201
// Arguments that are intended to be directly forwarded to an
7161
7202
// emitKernelLaunch call are passed as nullptr, since OutlinedFnID=nullptr
7162
7203
// results in that call not being done.
7163
- OpenMPIRBuilder::TargetKernelArgs KArgs;
7164
- Builder.restoreIP(OMPBuilder.emitTargetTask(
7165
- OutlinedFn, /*OutlinedFnID=*/nullptr, EmitTargetCallFallbackCB, KArgs,
7166
- /*DeviceID=*/nullptr, /*RTLoc=*/nullptr, AllocaIP, Dependencies,
7167
- HasNoWait));
7204
+ Builder.restoreIP(OMPBuilder.emitTargetTask(TaskBodyCB,
7205
+ /*DeviceID=*/nullptr,
7206
+ /*RTLoc=*/nullptr, AllocaIP,
7207
+ Dependencies, HasNoWait));
7168
7208
} else {
7169
7209
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
7170
7210
}
@@ -7201,20 +7241,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7201
7241
// TODO: Use correct DynCGGroupMem
7202
7242
Value *DynCGGroupMem = Builder.getInt32(0);
7203
7243
7204
- OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
7205
- NumTeamsC, NumThreadsC, DynCGGroupMem ,
7206
- HasNoWait);
7244
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
7245
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7246
+ DynCGGroupMem, HasNoWait);
7207
7247
7208
7248
// The presence of certain clauses on the target directive require the
7209
7249
// explicit generation of the target task.
7210
7250
if (RequiresOuterTargetTask) {
7211
7251
Builder.restoreIP(OMPBuilder.emitTargetTask(
7212
- OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7213
- RTLoc, AllocaIP, Dependencies, HasNoWait));
7252
+ TaskBodyCB, DeviceID, RTLoc, AllocaIP, Dependencies, HasNoWait));
7214
7253
} else {
7215
7254
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7216
- Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7217
- DeviceID, RTLoc, AllocaIP));
7255
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, RTLoc ,
7256
+ AllocaIP));
7218
7257
}
7219
7258
}
7220
7259
0 commit comments