@@ -63,6 +63,7 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
63
63
#include " common/LLVMWarningsPop.hpp"
64
64
65
65
#include < algorithm>
66
+ #include < map>
66
67
67
68
using namespace llvm ;
68
69
using namespace IGC ;
@@ -108,22 +109,42 @@ namespace //Anonymous
108
109
SAMPLER
109
110
};
110
111
111
- const auto FNAME_ENQUEUE_KERNEL = " _Z14enqueue_kernel" ;
112
- const auto FNAME_ENQUEUE_KERNEL_BASIC = " __enqueue_kernel_basic" ;
113
- const auto FNAME_ENQUEUE_KERNEL_VAARGS = " __enqueue_kernel_vaargs" ;
114
- const auto FNAME_ENQUEUE_KERNEL_EVENTS_VAARGS = " __enqueue_kernel_events_vaargs" ;
115
- const auto FNAME_WORK_GROUP_SIZE_IMPL = " __get_kernel_work_group_size_impl" ;
116
- const auto FNAME_PREFERRED_WORK_GROUP_SIZE_MULTIPLE = " _Z45get_kernel_preferred_work_group_size_multiple" ;
117
- const auto FNAME_PREFERRED_WORK_GROUP_MULTIPLE_IMPL = " __get_kernel_preferred_work_group_multiple_impl" ;
118
- const auto FNAME_MAX_SUB_GROUP_SIZE_FOR_NDRANGE = " _Z41get_kernel_max_sub_group_size_for_ndrange" ;
119
- const auto FNAME_SUB_GROUP_COUNT_FOR_NDRANGE = " _Z38get_kernel_sub_group_count_for_ndrange" ;
120
-
121
- const auto FNAME_SPIRV_ENQUEUE_KERNEL = " __builtin_spirv_OpEnqueueKernel" ;
122
- const auto FNAME_SPIRV_SUB_GROUP_COUNT_FOR_NDRANGE = " __builtin_spirv_OpGetKernelNDrangeSubGroupCount" ;
123
- const auto FNAME_SPIRV_MAX_SUB_GROUP_SIZE_FOR_NDRANGE = " __builtin_spirv_OpGetKernelNDrangeMaxSubGroupSize" ;
124
- const auto FNAME_SPIRV_PREFERRED_WORK_GROUP_SIZE_MULTIPLE = " __builtin_spirv_OpGetKernelPreferredWorkGroupSizeMultiple" ;
125
- const auto FNAME_SPIRV_LOCAL_SIZE_FOR_SUB_GROUP_COUNT = " __builtin_spirv_OpGetKernelLocalSizeForSubgroupCount" ;
126
- const auto FNAME_SPIRV_MAX_NUM_SUB_GROUPS = " __builtin_spirv_OpGetKernelMaxNumSubgroups" ;
112
+ enum class DeviceEnqueueFunction {
113
+ ENQUEUE_KERNEL,
114
+ ENQUEUE_KERNEL_BASIC,
115
+ ENQUEUE_KERNEL_VAARGS,
116
+ ENQUEUE_KERNEL_EVENTS_VAARGS,
117
+ WORK_GROUP_SIZE_IMPL,
118
+ PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
119
+ PREFERRED_WORK_GROUP_MULTIPLE_IMPL,
120
+ MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
121
+ SUB_GROUP_COUNT_FOR_NDRANGE,
122
+ SPIRV_ENQUEUE_KERNEL,
123
+ SPIRV_SUB_GROUP_COUNT_FOR_NDRANGE,
124
+ SPIRV_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
125
+ SPIRV_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
126
+ SPIRV_LOCAL_SIZE_FOR_SUB_GROUP_COUNT,
127
+ SPIRV_MAX_NUM_SUB_GROUPS,
128
+ NUM_FUNCTIONS_WITH_BLOCK_ARGS
129
+ };
130
+
131
+ const std::map<DeviceEnqueueFunction, const char *> DeviceEnqueueFunctionNames = {
132
+ { DeviceEnqueueFunction::ENQUEUE_KERNEL, " _Z14enqueue_kernel" },
133
+ { DeviceEnqueueFunction::ENQUEUE_KERNEL_BASIC, " __enqueue_kernel_basic" },
134
+ { DeviceEnqueueFunction::ENQUEUE_KERNEL_VAARGS, " __enqueue_kernel_vaargs" },
135
+ { DeviceEnqueueFunction::ENQUEUE_KERNEL_EVENTS_VAARGS, " __enqueue_kernel_events_vaargs" },
136
+ { DeviceEnqueueFunction::WORK_GROUP_SIZE_IMPL, " __get_kernel_work_group_size_impl" },
137
+ { DeviceEnqueueFunction::PREFERRED_WORK_GROUP_SIZE_MULTIPLE, " _Z45get_kernel_preferred_work_group_size_multiple" },
138
+ { DeviceEnqueueFunction::PREFERRED_WORK_GROUP_MULTIPLE_IMPL, " __get_kernel_preferred_work_group_multiple_impl" },
139
+ { DeviceEnqueueFunction::MAX_SUB_GROUP_SIZE_FOR_NDRANGE, " _Z41get_kernel_max_sub_group_size_for_ndrange" },
140
+ { DeviceEnqueueFunction::SUB_GROUP_COUNT_FOR_NDRANGE, " _Z38get_kernel_sub_group_count_for_ndrange" },
141
+ { DeviceEnqueueFunction::SPIRV_ENQUEUE_KERNEL, " __builtin_spirv_OpEnqueueKernel" },
142
+ { DeviceEnqueueFunction::SPIRV_SUB_GROUP_COUNT_FOR_NDRANGE, " __builtin_spirv_OpGetKernelNDrangeSubGroupCount" },
143
+ { DeviceEnqueueFunction::SPIRV_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, " __builtin_spirv_OpGetKernelNDrangeMaxSubGroupSize" },
144
+ { DeviceEnqueueFunction::SPIRV_PREFERRED_WORK_GROUP_SIZE_MULTIPLE, " __builtin_spirv_OpGetKernelPreferredWorkGroupSizeMultiple" },
145
+ { DeviceEnqueueFunction::SPIRV_LOCAL_SIZE_FOR_SUB_GROUP_COUNT, " __builtin_spirv_OpGetKernelLocalSizeForSubgroupCount" },
146
+ { DeviceEnqueueFunction::SPIRV_MAX_NUM_SUB_GROUPS, " __builtin_spirv_OpGetKernelMaxNumSubgroups" }
147
+ };
127
148
128
149
// ///////////////////////////////////////////////////////////////////////////////////////////////
129
150
// / helper class to build and query llvm metadata of kernel/dispatcher
@@ -727,10 +748,10 @@ namespace //Anonymous
727
748
return invokeFunc;
728
749
}
729
750
730
- std::vector<llvm::Value*> getLocaSizes () const { return _local_sizes; }
751
+ std::vector<llvm::Value*> getLocalSizes () const { return _local_sizes; }
731
752
732
753
bool hasEvents () const { return getNumWaitEvents () != nullptr ; }
733
- bool hasLocals () const { return getLocaSizes ().size () > 0 ; }
754
+ bool hasLocals () const { return getLocalSizes ().size () > 0 ; }
734
755
};
735
756
736
757
// ////////////////////////////////////////////////////////////////////////
@@ -885,6 +906,11 @@ namespace //Anonymous
885
906
for (unsigned i = localSizesStartArgNum; i < argsNum; i++)
886
907
{
887
908
auto arg = _call.getArgOperand (i);
909
+ if (arg->getType ()->isPointerTy ()) {
910
+ IRBuilder<> builder (&_call);
911
+ arg = builder.CreateLoad (arg);
912
+ }
913
+
888
914
if (!arg->getType ()->isIntegerTy (64 ) && !arg->getType ()->isIntegerTy (32 ))
889
915
report_fatal_error (" OpEnqueueKernel signature does not match" );
890
916
@@ -1354,7 +1380,7 @@ namespace //Anonymous
1354
1380
1355
1381
Function* getInvokeFunctionFromKernelWrapper (const Function* invokeFunc, DataContext& dataContext) {
1356
1382
assert (isInvokeFunctionKernelWrapper (invokeFunc, dataContext));
1357
- const CallInst* inst = dyn_cast<CallInst>(*(invokeFunc->arg_begin ())-> user_begin ( ));
1383
+ const CallInst* inst = dyn_cast<CallInst>(& *(invokeFunc->begin ()-> begin () ));
1358
1384
if (inst) {
1359
1385
return inst->getCalledFunction ();
1360
1386
} else {
@@ -1363,11 +1389,21 @@ namespace //Anonymous
1363
1389
}
1364
1390
1365
1391
bool isEnqueueKernelFunction (StringRef funcName) {
1366
- return funcName.startswith (FNAME_ENQUEUE_KERNEL) ||
1367
- funcName.startswith (FNAME_SPIRV_ENQUEUE_KERNEL) ||
1368
- funcName.startswith (FNAME_ENQUEUE_KERNEL_BASIC) ||
1369
- funcName.startswith (FNAME_ENQUEUE_KERNEL_VAARGS) ||
1370
- funcName.startswith (FNAME_ENQUEUE_KERNEL_EVENTS_VAARGS);
1392
+
1393
+ return funcName.startswith (DeviceEnqueueFunctionNames.at (DeviceEnqueueFunction::ENQUEUE_KERNEL)) ||
1394
+ funcName.startswith (DeviceEnqueueFunctionNames.at (DeviceEnqueueFunction::SPIRV_ENQUEUE_KERNEL)) ||
1395
+ funcName.startswith (DeviceEnqueueFunctionNames.at (DeviceEnqueueFunction::ENQUEUE_KERNEL_BASIC)) ||
1396
+ funcName.startswith (DeviceEnqueueFunctionNames.at (DeviceEnqueueFunction::ENQUEUE_KERNEL_VAARGS)) ||
1397
+ funcName.startswith (DeviceEnqueueFunctionNames.at (DeviceEnqueueFunction::ENQUEUE_KERNEL_EVENTS_VAARGS));
1398
+ }
1399
+
1400
+ bool isDeviceEnqueueFunction (StringRef funcName) {
1401
+ for (auto el : DeviceEnqueueFunctionNames) {
1402
+ if (funcName.startswith (el.second )) {
1403
+ return true ;
1404
+ }
1405
+ }
1406
+ return false ;
1371
1407
}
1372
1408
1373
1409
@@ -1423,7 +1459,8 @@ namespace //Anonymous
1423
1459
{
1424
1460
bool changed = false ;
1425
1461
for (auto &func : M.functions ()) {
1426
- if (!isEnqueueKernelFunction (func.getName ())) continue ;
1462
+ if (!isDeviceEnqueueFunction (func.getName ())) continue ;
1463
+
1427
1464
for (auto user : func.users ()) {
1428
1465
auto callInst = dyn_cast<CallInst>(user);
1429
1466
if (!callInst) continue ;
@@ -1737,70 +1774,72 @@ namespace //Anonymous
1737
1774
CallHandler* DataContext::registerCallHandler (llvm::CallInst& call)
1738
1775
{
1739
1776
// device enqueue call handlers factories registry
1740
- const std::pair<StringRef , std::function<CallHandler*(llvm::CallInst&, DataContext& dm)>> handlers[] =
1777
+ const std::pair<DeviceEnqueueFunction , std::function<CallHandler*(llvm::CallInst&, DataContext& dm)>> handlers[] =
1741
1778
{
1742
1779
{
1743
- FNAME_MAX_SUB_GROUP_SIZE_FOR_NDRANGE ,
1780
+ DeviceEnqueueFunction::MAX_SUB_GROUP_SIZE_FOR_NDRANGE ,
1744
1781
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupSizeCall (new ObjCNDRangeAndBlockCallArgs (call, dm)); }
1745
1782
},
1746
1783
{
1747
- FNAME_PREFERRED_WORK_GROUP_SIZE_MULTIPLE ,
1784
+ DeviceEnqueueFunction::PREFERRED_WORK_GROUP_SIZE_MULTIPLE ,
1748
1785
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupSizeCall (new ObjCBlockCallArgs (call, dm)); }
1749
1786
},
1750
1787
{
1751
- FNAME_PREFERRED_WORK_GROUP_MULTIPLE_IMPL ,
1788
+ DeviceEnqueueFunction::PREFERRED_WORK_GROUP_MULTIPLE_IMPL ,
1752
1789
[](llvm::CallInst& call, DataContext& dm) { return new KernelSubGroupSizeCall (new ObjCBlockCallArgs (call, dm)); }
1753
1790
},
1754
1791
{
1755
- FNAME_WORK_GROUP_SIZE_IMPL ,
1792
+ DeviceEnqueueFunction::WORK_GROUP_SIZE_IMPL ,
1756
1793
[](llvm::CallInst& call, DataContext& dm) { return new KernelMaxWorkGroupSizeCall (new ObjCBlockCallArgs (call, dm)); }
1757
1794
},
1758
1795
{
1759
- FNAME_SUB_GROUP_COUNT_FOR_NDRANGE ,
1796
+ DeviceEnqueueFunction::SUB_GROUP_COUNT_FOR_NDRANGE ,
1760
1797
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupCountForNDRangeCall (new ObjCNDRangeAndBlockCallArgs (call, dm)); }
1761
1798
},
1762
1799
{
1763
- FNAME_ENQUEUE_KERNEL ,
1800
+ DeviceEnqueueFunction::ENQUEUE_KERNEL ,
1764
1801
[](llvm::CallInst& call, DataContext& dm){ return new EnqueueKernelCall (new ObjCEnqueueKernelArgs (call, dm)); }
1765
1802
},
1766
1803
{
1767
- FNAME_ENQUEUE_KERNEL_BASIC ,
1804
+ DeviceEnqueueFunction::ENQUEUE_KERNEL_BASIC ,
1768
1805
[](llvm::CallInst& call, DataContext& dm){ return new EnqueueKernelCall (new ObjCEnqueueKernelArgs (call, dm)); }
1769
1806
},
1770
1807
{
1771
- FNAME_ENQUEUE_KERNEL_VAARGS ,
1808
+ DeviceEnqueueFunction::ENQUEUE_KERNEL_VAARGS ,
1772
1809
[](llvm::CallInst& call, DataContext& dm){ return new EnqueueKernelCall (new ObjCEnqueueKernelArgs (call, dm)); }
1773
1810
},
1774
1811
{
1775
- FNAME_ENQUEUE_KERNEL_EVENTS_VAARGS ,
1812
+ DeviceEnqueueFunction::ENQUEUE_KERNEL_EVENTS_VAARGS ,
1776
1813
[](llvm::CallInst& call, DataContext& dm){ return new EnqueueKernelCall (new ObjCEnqueueKernelArgs (call, dm)); }
1777
1814
},
1778
1815
{
1779
- FNAME_SPIRV_MAX_SUB_GROUP_SIZE_FOR_NDRANGE ,
1816
+ DeviceEnqueueFunction::SPIRV_MAX_SUB_GROUP_SIZE_FOR_NDRANGE ,
1780
1817
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupSizeCall (new SPIRVNDRangeAndInvokeCallArgs (call, dm)); }
1781
1818
},
1782
1819
{
1783
- FNAME_SPIRV_PREFERRED_WORK_GROUP_SIZE_MULTIPLE ,
1820
+ DeviceEnqueueFunction::SPIRV_PREFERRED_WORK_GROUP_SIZE_MULTIPLE ,
1784
1821
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupSizeCall (new SPIRVInvokeCallArgs (call, dm)); }
1785
1822
},
1786
1823
{
1787
- FNAME_SPIRV_LOCAL_SIZE_FOR_SUB_GROUP_COUNT ,
1824
+ DeviceEnqueueFunction::SPIRV_LOCAL_SIZE_FOR_SUB_GROUP_COUNT ,
1788
1825
[](llvm::CallInst& call, DataContext& dm){ return new KernelLocalSizeForSubgroupCount (new SPIRVSubgroupCountAndInvokeCallArgs (call, dm)); }
1789
1826
},
1790
1827
{
1791
- FNAME_SPIRV_MAX_NUM_SUB_GROUPS ,
1828
+ DeviceEnqueueFunction::SPIRV_MAX_NUM_SUB_GROUPS ,
1792
1829
[](llvm::CallInst& call, DataContext& dm){ return new KernelMaxNumSubgroups (new SPIRVInvokeCallArgs (call, dm)); }
1793
1830
},
1794
1831
{
1795
- FNAME_SPIRV_SUB_GROUP_COUNT_FOR_NDRANGE ,
1832
+ DeviceEnqueueFunction::SPIRV_SUB_GROUP_COUNT_FOR_NDRANGE ,
1796
1833
[](llvm::CallInst& call, DataContext& dm){ return new KernelSubGroupCountForNDRangeCall (new SPIRVNDRangeAndInvokeCallArgs (call, dm)); }
1797
1834
},
1798
1835
{
1799
- FNAME_SPIRV_ENQUEUE_KERNEL ,
1836
+ DeviceEnqueueFunction::SPIRV_ENQUEUE_KERNEL ,
1800
1837
[](llvm::CallInst& call, DataContext& dm){ return new EnqueueKernelCall (new SPIRVOpEnqueueKernelCallArgs (call, dm)); }
1801
1838
},
1802
1839
};
1803
1840
1841
+ static_assert (sizeof (handlers) == sizeof (decltype (handlers[0 ])) * (size_t )DeviceEnqueueFunction::NUM_FUNCTIONS_WITH_BLOCK_ARGS, " Not all enqueue functions have handlers!" );
1842
+
1804
1843
auto calledFunction = call.getCalledFunction ();
1805
1844
// fail indirect calls
1806
1845
if (calledFunction == nullptr )
@@ -1814,7 +1853,7 @@ namespace //Anonymous
1814
1853
for (auto & handler_pair : handlers)
1815
1854
{
1816
1855
// if called function name matches one of known
1817
- if (calledFunctionName.startswith (handler_pair.first ))
1856
+ if (calledFunctionName.startswith (DeviceEnqueueFunctionNames. at ( handler_pair.first ) ))
1818
1857
{
1819
1858
// use appropriate factory to construct CallHandler
1820
1859
auto callHandler = handler_pair.second (call, *this );
@@ -2178,16 +2217,20 @@ namespace //Anonymous
2178
2217
llvm::CallInst* BlockInvoke::EmitBlockInvokeCall (IGCLLVM::IRBuilder<>& builder, llvm::ArrayRef<llvm::Argument*> captures, llvm::ArrayRef<llvm::Argument*> tailingArgs) const
2179
2218
{
2180
2219
// IRBuilder: allocate structure
2181
- auto block_descriptor_val = builder.CreateAlloca (_captureStructType, nullptr , " .block_struct" );
2182
- auto dl = getFunction ()->getParent ()->getDataLayout ();
2183
- auto blockStructAlign = getPrefStructAlignment (_captureStructType, &dl);
2184
- block_descriptor_val->setAlignment (blockStructAlign);
2185
- // IRBuilder: store arguments to structure
2186
- StoreInstBuilder storeBuilder (builder);
2187
- for (unsigned argIdx = 0 ; argIdx < getCaptureIndicies ().size (); ++argIdx)
2188
- {
2189
- auto srcArg = captures[argIdx];
2190
- storeBuilder.Store (block_descriptor_val, srcArg, getCaptureIndicies ()[argIdx]);
2220
+ // If we didn't track the capturedStructType, it might have been not used in the kernel.
2221
+ Value* block_descriptor_val = ConstantPointerNull::get (builder.getInt8PtrTy ());
2222
+ if (_captureStructType) {
2223
+ block_descriptor_val = builder.CreateAlloca (_captureStructType, nullptr , " .block_struct" );
2224
+ auto dl = getFunction ()->getParent ()->getDataLayout ();
2225
+ auto blockStructAlign = getPrefStructAlignment (_captureStructType, &dl);
2226
+ cast<AllocaInst>(block_descriptor_val)->setAlignment (blockStructAlign);
2227
+ // IRBuilder: store arguments to structure
2228
+ StoreInstBuilder storeBuilder (builder);
2229
+ for (unsigned argIdx = 0 ; argIdx < getCaptureIndicies ().size (); ++argIdx)
2230
+ {
2231
+ auto srcArg = captures[argIdx];
2232
+ storeBuilder.Store (block_descriptor_val, srcArg, getCaptureIndicies ()[argIdx]);
2233
+ }
2191
2234
}
2192
2235
2193
2236
// IRBuilder: call block_invoke
@@ -2533,14 +2576,14 @@ namespace //Anonymous
2533
2576
if (_deviceExecCall->hasLocals ())
2534
2577
{
2535
2578
auto int32ptrty = Type::getInt32PtrTy (context);
2536
- auto localsBuf = AllocateBuffer (int32ty, _deviceExecCall->getLocaSizes ().size (), " local_size_buf" );
2579
+ auto localsBuf = AllocateBuffer (int32ty, _deviceExecCall->getLocalSizes ().size (), " local_size_buf" );
2537
2580
uint64_t localSizeOffset = 0 ;
2538
- for (auto localSizeValue : _deviceExecCall->getLocaSizes ())
2581
+ for (auto localSizeValue : _deviceExecCall->getLocalSizes ())
2539
2582
{
2540
2583
auto storedSize = storeBuilder.Store (localsBuf, localSizeValue, localSizeOffset);
2541
2584
localSizeOffset += sizeInBlocks (storedSize, int32ty);
2542
2585
}
2543
- assert (_deviceExecCall->getLocaSizes ().size () == localSizeOffset);
2586
+ assert (_deviceExecCall->getLocalSizes ().size () == localSizeOffset);
2544
2587
2545
2588
localSizesBuf = builder.CreatePointerCast (localsBuf, int32ptrty);
2546
2589
localSizesNumValue = llvm::ConstantInt::get (int32ty, localSizeOffset);
0 commit comments