@@ -1080,8 +1080,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1080
1080
}
1081
1081
1082
1082
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch (
1083
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
1084
- EmitFallbackCallbackTy emitTargetCallFallbackCB , TargetKernelArgs &Args,
1083
+ const LocationDescription &Loc, Value *OutlinedFnID,
1084
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB , TargetKernelArgs &Args,
1085
1085
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1086
1086
1087
1087
if (!updateToLocation (Loc))
@@ -1134,7 +1134,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1134
1134
1135
1135
auto CurFn = Builder.GetInsertBlock ()->getParent ();
1136
1136
emitBlock (OffloadFailedBlock, CurFn);
1137
- Builder.restoreIP (emitTargetCallFallbackCB (Builder.saveIP ()));
1137
+ Builder.restoreIP (EmitTargetCallFallbackCB (Builder.saveIP ()));
1138
1138
emitBranch (OffloadContBlock);
1139
1139
emitBlock (OffloadContBlock, CurFn, /* IsFinished=*/ true );
1140
1140
return Builder.saveIP ();
@@ -1736,7 +1736,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1736
1736
// - All code is inserted in the entry block of the current function.
1737
1737
static Value *emitTaskDependencies (
1738
1738
OpenMPIRBuilder &OMPBuilder,
1739
- SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1739
+ const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1740
1740
// Early return if we have no dependencies to process
1741
1741
if (Dependencies.empty ())
1742
1742
return nullptr ;
@@ -6386,12 +6386,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6386
6386
// closing of the region.
6387
6387
auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6388
6388
MapInfo = &GenMapInfoCB (Builder.saveIP ());
6389
- emitOffloadingArrays (AllocaIP, Builder.saveIP (), *MapInfo, Info,
6390
- /* IsNonContiguous=*/ true , DeviceAddrCB,
6391
- CustomMapperCB);
6392
-
6393
6389
TargetDataRTArgs RTArgs;
6394
- emitOffloadingArraysArgument (Builder, RTArgs, Info);
6390
+ emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info, RTArgs,
6391
+ *MapInfo, /* IsNonContiguous=*/ true ,
6392
+ /* ForEndCall=*/ false );
6395
6393
6396
6394
// Emit the number of elements in the offloading arrays.
6397
6395
Value *PointerNum = Builder.getInt32 (Info.NumberOfPtrs );
@@ -6403,16 +6401,45 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6403
6401
SrcLocInfo = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6404
6402
}
6405
6403
6406
- Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6407
- PointerNum, RTArgs.BasePointersArray ,
6408
- RTArgs.PointersArray , RTArgs.SizesArray ,
6409
- RTArgs.MapTypesArray , RTArgs.MapNamesArray ,
6410
- RTArgs.MappersArray };
6404
+ SmallVector<llvm::Value *, 13 > OffloadingArgs = {
6405
+ SrcLocInfo, DeviceID,
6406
+ PointerNum, RTArgs.BasePointersArray ,
6407
+ RTArgs.PointersArray , RTArgs.SizesArray ,
6408
+ RTArgs.MapTypesArray , RTArgs.MapNamesArray ,
6409
+ RTArgs.MappersArray };
6411
6410
6412
6411
if (IsStandAlone) {
6413
6412
assert (MapperFunc && " MapperFunc missing for standalone target data" );
6414
- Builder.CreateCall (getOrCreateRuntimeFunctionPtr (*MapperFunc),
6415
- OffloadingArgs);
6413
+
6414
+ auto TaskBodyCB = [&](Value *, Value *, IRBuilderBase::InsertPoint) {
6415
+ if (Info.HasNoWait ) {
6416
+ OffloadingArgs.push_back (llvm::Constant::getNullValue (Int32));
6417
+ OffloadingArgs.push_back (llvm::Constant::getNullValue (VoidPtr));
6418
+ OffloadingArgs.push_back (llvm::Constant::getNullValue (Int32));
6419
+ OffloadingArgs.push_back (llvm::Constant::getNullValue (VoidPtr));
6420
+ }
6421
+
6422
+ Builder.CreateCall (getOrCreateRuntimeFunctionPtr (*MapperFunc),
6423
+ OffloadingArgs);
6424
+
6425
+ if (Info.HasNoWait ) {
6426
+ BasicBlock *OffloadContBlock =
6427
+ BasicBlock::Create (Builder.getContext (), " omp_offload.cont" );
6428
+ auto *CurFn = Builder.GetInsertBlock ()->getParent ();
6429
+ emitBranch (OffloadContBlock);
6430
+ emitBlock (OffloadContBlock, CurFn, /* IsFinished=*/ true );
6431
+ Builder.restoreIP (Builder.saveIP ());
6432
+ }
6433
+ };
6434
+
6435
+ bool RequiresOuterTargetTask = Info.HasNoWait ;
6436
+
6437
+ if (!RequiresOuterTargetTask)
6438
+ TaskBodyCB (/* DeviceID=*/ nullptr , /* RTLoc=*/ nullptr ,
6439
+ /* TargetTaskAllocaIP=*/ {});
6440
+ else
6441
+ emitTargetTask (TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
6442
+ /* Dependencies=*/ {}, Info.HasNoWait );
6416
6443
} else {
6417
6444
Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr (
6418
6445
omp::OMPRTL___tgt_target_data_begin_mapper);
@@ -6836,13 +6863,18 @@ static void emitTargetOutlinedFunction(
6836
6863
OMPBuilder.emitTargetRegionFunction (EntryInfo, GenerateOutlinedFunction,
6837
6864
IsOffloadEntry, OutlinedFn, OutlinedFnID);
6838
6865
}
6866
+
6839
6867
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask (
6840
- Function *OutlinedFn, Value *OutlinedFnID,
6841
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
6842
- Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
6843
- SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6868
+ TaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
6869
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
6870
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6844
6871
bool HasNoWait) {
6845
6872
6873
+ // The following explains the code-gen scenario for the `target` directive. A
6874
+ // similar scneario is followed for other device-related directives (e.g.
6875
+ // `target enter data`) but in similar fashion since we only need to emit task
6876
+ // that encapsulates the proper runtime call.
6877
+ //
6846
6878
// When we arrive at this function, the target region itself has been
6847
6879
// outlined into the function OutlinedFn.
6848
6880
// So at ths point, for
@@ -6950,22 +6982,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6950
6982
6951
6983
Builder.restoreIP (TargetTaskBodyIP);
6952
6984
6953
- if (OutlinedFnID) {
6954
- // emitKernelLaunch makes the necessary runtime call to offload the kernel.
6955
- // We then outline all that code into a separate function
6956
- // ('kernel_launch_function' in the pseudo code above). This function is
6957
- // then called by the target task proxy function (see
6958
- // '@.omp_target_task_proxy_func' in the pseudo code above)
6959
- // "@.omp_target_task_proxy_func' is generated by
6960
- // emitTargetTaskProxyFunction.
6961
- Builder.restoreIP (emitKernelLaunch (Builder, OutlinedFn, OutlinedFnID,
6962
- EmitTargetCallFallbackCB, Args, DeviceID,
6963
- RTLoc, TargetTaskAllocaIP));
6964
- } else {
6965
- // When OutlinedFnID is set to nullptr, then it's not an offloading call. In
6966
- // this case, we execute the host implementation directly.
6967
- Builder.restoreIP (EmitTargetCallFallbackCB (Builder.saveIP ()));
6968
- }
6985
+ TaskBodyCB (DeviceID, RTLoc, TargetTaskAllocaIP);
6969
6986
6970
6987
OI.ExitBB = Builder.saveIP ().getBlock ();
6971
6988
OI.PostOutlineCB = [this , ToBeDeleted, Dependencies, HasNoWait,
@@ -7153,21 +7170,44 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7153
7170
bool HasDependencies = Dependencies.size () > 0 ;
7154
7171
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7155
7172
7173
+ OpenMPIRBuilder::TargetKernelArgs KArgs;
7174
+
7175
+ auto TaskBodyCB = [&](Value *DeviceID, Value *RTLoc,
7176
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP) {
7177
+ if (OutlinedFnID) {
7178
+ // emitKernelLaunch makes the necessary runtime call to offload the
7179
+ // kernel. We then outline all that code into a separate function
7180
+ // ('kernel_launch_function' in the pseudo code above). This function is
7181
+ // then called by the target task proxy function (see
7182
+ // '@.omp_target_task_proxy_func' in the pseudo code above)
7183
+ // "@.omp_target_task_proxy_func' is generated by
7184
+ // emitTargetTaskProxyFunction.
7185
+ Builder.restoreIP (OMPBuilder.emitKernelLaunch (
7186
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7187
+ RTLoc, TargetTaskAllocaIP));
7188
+ } else {
7189
+ // When OutlinedFnID is set to nullptr, then it's not an offloading
7190
+ // call. In this case, we execute the host implementation directly.
7191
+ OMPBuilder.Builder .restoreIP (
7192
+ EmitTargetCallFallbackCB (OMPBuilder.Builder .saveIP ()));
7193
+ }
7194
+ };
7195
+
7156
7196
// If we don't have an ID for the target region, it means an offload entry
7157
7197
// wasn't created. In this case we just run the host fallback directly.
7158
7198
if (!OutlinedFnID) {
7159
7199
if (RequiresOuterTargetTask) {
7160
7200
// Arguments that are intended to be directly forwarded to an
7161
7201
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
7162
7202
// results in that call not being done.
7163
- OpenMPIRBuilder::TargetKernelArgs KArgs;
7164
- Builder.restoreIP (OMPBuilder.emitTargetTask (
7165
- OutlinedFn, /* OutlinedFnID=*/ nullptr , EmitTargetCallFallbackCB, KArgs,
7166
- /* DeviceID=*/ nullptr , /* RTLoc=*/ nullptr , AllocaIP, Dependencies,
7167
- HasNoWait));
7203
+ Builder.restoreIP (OMPBuilder.emitTargetTask (TaskBodyCB,
7204
+ /* DeviceID=*/ nullptr ,
7205
+ /* RTLoc=*/ nullptr , AllocaIP,
7206
+ Dependencies, HasNoWait));
7168
7207
} else {
7169
7208
Builder.restoreIP (EmitTargetCallFallbackCB (Builder.saveIP ()));
7170
7209
}
7210
+
7171
7211
return ;
7172
7212
}
7173
7213
@@ -7201,20 +7241,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7201
7241
// TODO: Use correct DynCGGroupMem
7202
7242
Value *DynCGGroupMem = Builder.getInt32 (0 );
7203
7243
7204
- OpenMPIRBuilder::TargetKernelArgs KArgs (NumTargetItems, RTArgs, NumIterations,
7205
- NumTeamsC, NumThreadsC, DynCGGroupMem ,
7206
- HasNoWait);
7244
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (
7245
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7246
+ DynCGGroupMem, HasNoWait);
7207
7247
7208
7248
// The presence of certain clauses on the target directive require the
7209
7249
// explicit generation of the target task.
7210
7250
if (RequiresOuterTargetTask) {
7211
7251
Builder.restoreIP (OMPBuilder.emitTargetTask (
7212
- OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7213
- RTLoc, AllocaIP, Dependencies, HasNoWait));
7252
+ TaskBodyCB, DeviceID, RTLoc, AllocaIP, Dependencies, HasNoWait));
7214
7253
} else {
7215
7254
Builder.restoreIP (OMPBuilder.emitKernelLaunch (
7216
- Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7217
- DeviceID, RTLoc, AllocaIP));
7255
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, RTLoc ,
7256
+ AllocaIP));
7218
7257
}
7219
7258
}
7220
7259
0 commit comments