10
10
11
11
#include " clang-mlir.h"
12
12
#include " TypeUtils.h"
13
+ #include " mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h"
13
14
#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14
15
#include " mlir/Dialect/DLTI/DLTI.h"
15
16
#include " mlir/Dialect/SCF/IR/SCF.h"
45
46
#include " mlir/Dialect/SYCL/IR/SYCLOpsDialect.h.inc"
46
47
#include " mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
47
48
48
- static bool DEBUG_FUNCTION = false ;
49
49
static bool BREAKPOINT_FUNCTION = false ;
50
50
51
51
using namespace std ;
@@ -56,6 +56,7 @@ using namespace llvm::opt;
56
56
using namespace mlir ;
57
57
using namespace mlir ::arith;
58
58
using namespace mlir ::func;
59
+ using namespace mlir ::sycl;
59
60
using namespace mlirclang ;
60
61
61
62
static cl::opt<bool >
@@ -68,6 +69,10 @@ static cl::opt<bool> memRefABI("memref-abi", cl::init(true),
68
69
cl::opt<std::string> PrefixABI (" prefix-abi" , cl::init(" " ),
69
70
cl::desc(" Prefix for emitted symbols" ));
70
71
72
+ static cl::opt<bool > DebugFunction (
73
+ " debug-function" , cl::init(false ),
74
+ cl::desc(" Print informations about functions being processed." ));
75
+
71
76
static cl::opt<bool >
72
77
CombinedStructABI (" struct-abi" , cl::init(true ),
73
78
cl::desc(" Use literal LLVM ABI for structs" ));
@@ -111,6 +116,34 @@ MLIRScanner::MLIRScanner(MLIRASTConsumer &Glob,
111
116
: Glob(Glob), module(module ), builder(module ->getContext ()),
112
117
loc(builder.getUnknownLoc()), ThisCapture(nullptr ), LTInfo(LTInfo) {}
113
118
119
+ void MLIRScanner::initSupportedConstructors () {
120
+ // List from SYCLFuncRegistry.cpp Please modify as new constructors are
121
+ // added to that file.
122
+ supportedCons.insert (" _ZN2cl4sycl2idILi1EEC1Ev" );
123
+ supportedCons.insert (" _ZN2cl4sycl2idILi2EEC1Ev" );
124
+ supportedCons.insert (" _ZN2cl4sycl2idILi3EEC1Ev" );
125
+ supportedCons.insert (
126
+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE" );
127
+ supportedCons.insert (
128
+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE" );
129
+ supportedCons.insert (
130
+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE" );
131
+ supportedCons.insert (
132
+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm" );
133
+ supportedCons.insert (
134
+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm" );
135
+ supportedCons.insert (
136
+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm" );
137
+ supportedCons.insert (
138
+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm" );
139
+ supportedCons.insert (
140
+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm" );
141
+ supportedCons.insert (
142
+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm" );
143
+ supportedCons.insert (" _ZN2cl4sycl6detail5arrayILi1EEC1ILi1EEENSt9enable_"
144
+ " ifIXeqT_Li1EEmE4typeE" );
145
+ }
146
+
114
147
void MLIRScanner::init (mlir::func::FuncOp function, const FunctionDecl *fd) {
115
148
this ->function = function;
116
149
this ->EmittingFunctionDecl = fd;
@@ -120,6 +153,7 @@ void MLIRScanner::init(mlir::func::FuncOp function, const FunctionDecl *fd) {
120
153
llvm::errs () << *fd << " \n " ;
121
154
}
122
155
156
+ initSupportedConstructors ();
123
157
setEntryAndAllocBlock (function.addEntryBlock ());
124
158
125
159
unsigned i = 0 ;
@@ -1363,6 +1397,16 @@ MLIRScanner::VisitCXXConstructExpr(clang::CXXConstructExpr *cons) {
1363
1397
return VisitConstructCommon (cons, /* name*/ nullptr , /* space*/ 0 );
1364
1398
}
1365
1399
1400
+ static void getMangledFuncName (std::string &name, const FunctionDecl *FD,
1401
+ CodeGen::CodeGenModule &CGM) {
1402
+ if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
1403
+ name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
1404
+ else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
1405
+ name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
1406
+ else
1407
+ name = CGM.getMangledName (FD).str ();
1408
+ }
1409
+
1366
1410
ValueCategory MLIRScanner::VisitConstructCommon (clang::CXXConstructExpr *cons,
1367
1411
VarDecl *name, unsigned memtype,
1368
1412
mlir::Value op,
@@ -1439,11 +1483,33 @@ ValueCategory MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons,
1439
1483
assert (obj.isReference );
1440
1484
}
1441
1485
1442
- // / If the constructor is part of the SYCL namespace, we do not want the
1486
+ // / If the constructor is part of the SYCL namespace, we may not want the
1443
1487
// / GetOrCreateMLIRFunction to add this FuncOp to the functionsToEmit dequeu,
1444
- // / since we will create it's equivalent with SYCL operations.
1445
- const auto ShouldEmit = !mlirclang::isNamespaceSYCL (
1488
+ // / since we will create it's equivalent with SYCL operations. Please note
1489
+ // / that we still generate some constructors that we need for lowering some
1490
+ // / sycl op. Therefore, in those case, we set ShouldEmit back to "true" by
1491
+ // / looking them up in our "registry" of supported constructors.
1492
+
1493
+ bool ShouldEmit = !mlirclang::isNamespaceSYCL (
1446
1494
cons->getConstructor ()->getEnclosingNamespaceContext ());
1495
+
1496
+ if (const FunctionDecl *FuncDecl =
1497
+ dyn_cast<FunctionDecl>(cons->getConstructor ())) {
1498
+ std::string name;
1499
+ getMangledFuncName (name, FuncDecl, Glob.CGM );
1500
+ name = (PrefixABI + name);
1501
+
1502
+ if (DebugFunction) {
1503
+ llvm::dbgs () << " Starting codegen of " << name << " \n " ;
1504
+ }
1505
+ if (isSupportedConstructor (name)) {
1506
+ if (DebugFunction) {
1507
+ llvm::dbgs () << " Function found in registry, continue codegen-ing...\n " ;
1508
+ }
1509
+ ShouldEmit = true ;
1510
+ }
1511
+ }
1512
+
1447
1513
auto tocall =
1448
1514
Glob.GetOrCreateMLIRFunction (cons->getConstructor (), ShouldEmit);
1449
1515
@@ -4262,12 +4328,7 @@ mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateFreeFunction() {
4262
4328
mlir::LLVM::LLVMFuncOp
4263
4329
MLIRASTConsumer::GetOrCreateLLVMFunction (const FunctionDecl *FD) {
4264
4330
std::string name;
4265
- if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4266
- name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4267
- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4268
- name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4269
- else
4270
- name = CGM.getMangledName (FD).str ();
4331
+ getMangledFuncName (name, FD, CGM);
4271
4332
4272
4333
if (name != " malloc" && name != " free" )
4273
4334
name = (PrefixABI + name);
@@ -4630,25 +4691,20 @@ mlir::Value MLIRASTConsumer::GetOrCreateGlobalLLVMString(
4630
4691
return globalPtr;
4631
4692
}
4632
4693
4633
- mlir::func::FuncOp
4634
- MLIRASTConsumer::GetOrCreateMLIRFunction (const FunctionDecl *FD,
4635
- const bool ShouldEmit,
4636
- bool getDeviceStub) {
4694
+ mlir::func::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction (
4695
+ const FunctionDecl *FD, const bool ShouldEmit, bool getDeviceStub) {
4637
4696
assert (FD->getTemplatedKind () !=
4638
4697
FunctionDecl::TemplatedKind::TK_FunctionTemplate);
4639
4698
assert (
4640
4699
FD->getTemplatedKind () !=
4641
4700
FunctionDecl::TemplatedKind::TK_DependentFunctionTemplateSpecialization);
4701
+
4642
4702
std::string name;
4643
4703
if (getDeviceStub)
4644
4704
name =
4645
4705
CGM.getMangledName (GlobalDecl (FD, KernelReferenceKind::Kernel)).str ();
4646
- else if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4647
- name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4648
- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4649
- name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4650
4706
else
4651
- name = CGM. getMangledName (FD). str ( );
4707
+ getMangledFuncName ( name, FD, CGM);
4652
4708
4653
4709
name = (PrefixABI + name);
4654
4710
@@ -4855,7 +4911,7 @@ void MLIRASTConsumer::run() {
4855
4911
while (functionsToEmit.size ()) {
4856
4912
const FunctionDecl *FD = functionsToEmit.front ();
4857
4913
4858
- if (BREAKPOINT_FUNCTION && DEBUG_FUNCTION ) {
4914
+ if (BREAKPOINT_FUNCTION && DebugFunction ) {
4859
4915
printf (" \n " );
4860
4916
printf (" -- FUNCTION BEING EMITTED : \033 [0;32m %s \033 [0m -- \n " ,
4861
4917
FD->getNameAsString ().c_str ());
@@ -4870,14 +4926,7 @@ void MLIRASTConsumer::run() {
4870
4926
TK_DependentFunctionTemplateSpecialization);
4871
4927
std::string name;
4872
4928
4873
- if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4874
- name =
4875
- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4876
- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4877
- name =
4878
- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4879
- else
4880
- name = CGM.getMangledName (FD).str ();
4929
+ getMangledFuncName (name, FD, CGM);
4881
4930
4882
4931
if (done.count (name))
4883
4932
continue ;
@@ -4886,7 +4935,7 @@ void MLIRASTConsumer::run() {
4886
4935
auto Function = GetOrCreateMLIRFunction (FD, true );
4887
4936
ms.init (Function, FD);
4888
4937
4889
- if (BREAKPOINT_FUNCTION && DEBUG_FUNCTION ) {
4938
+ if (BREAKPOINT_FUNCTION && DebugFunction ) {
4890
4939
printf (" \n " );
4891
4940
Function.dump ();
4892
4941
printf (" \n " );
@@ -4926,7 +4975,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
4926
4975
HandleDeclContext (NS);
4927
4976
continue ;
4928
4977
}
4929
- FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(D);
4978
+ const FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(D);
4930
4979
if (!fd) {
4931
4980
continue ;
4932
4981
}
@@ -4953,14 +5002,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
4953
5002
externLinkage = false ;
4954
5003
4955
5004
std::string name;
4956
- if (auto CC = dyn_cast<CXXConstructorDecl>(fd))
4957
- name =
4958
- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4959
- else if (auto CC = dyn_cast<CXXDestructorDecl>(fd))
4960
- name =
4961
- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4962
- else
4963
- name = CGM.getMangledName (fd).str ();
5005
+ getMangledFuncName (name, fd, CGM);
4964
5006
4965
5007
// Don't create std functions unless necessary
4966
5008
if (StringRef (name).startswith (" _ZNKSt" ))
@@ -5002,7 +5044,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
5002
5044
HandleDeclContext (NS);
5003
5045
continue ;
5004
5046
}
5005
- FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(*it);
5047
+ const FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(*it);
5006
5048
if (!fd) {
5007
5049
continue ;
5008
5050
}
@@ -5034,14 +5076,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
5034
5076
externLinkage = false ;
5035
5077
5036
5078
std::string name;
5037
- if (auto CC = dyn_cast<CXXConstructorDecl>(fd))
5038
- name =
5039
- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
5040
- else if (auto CC = dyn_cast<CXXDestructorDecl>(fd))
5041
- name =
5042
- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
5043
- else
5044
- name = CGM.getMangledName (fd).str ();
5079
+ getMangledFuncName (name, fd, CGM);
5045
5080
5046
5081
// Don't create std functions unless necessary
5047
5082
if (StringRef (name).startswith (" _ZNKSt" ))
0 commit comments