Skip to content

Commit 1f72056

Browse files
committed
Fix cuda flag with clang-repl
1 parent ecbd2d5 commit 1f72056

File tree

4 files changed

+88
-61
lines changed

4 files changed

+88
-61
lines changed

clang/include/clang/Interpreter/Interpreter.h

+9-4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class Interpreter {
9595
// An optional parser for CUDA offloading
9696
std::unique_ptr<IncrementalParser> DeviceParser;
9797

98+
// An optional action for CUDA offloading
99+
std::unique_ptr<IncrementalAction> DeviceAct;
100+
98101
/// List containing information about each incrementally parsed piece of code.
99102
std::list<PartialTranslationUnit> PTUs;
100103

@@ -129,7 +132,8 @@ class Interpreter {
129132
public:
130133
virtual ~Interpreter();
131134
static llvm::Expected<std::unique_ptr<Interpreter>>
132-
create(std::unique_ptr<CompilerInstance> CI);
135+
create(std::unique_ptr<CompilerInstance> CI,
136+
std::unique_ptr<CompilerInstance> DeviceCI = nullptr);
133137
static llvm::Expected<std::unique_ptr<Interpreter>>
134138
createWithCUDA(std::unique_ptr<CompilerInstance> CI,
135139
std::unique_ptr<CompilerInstance> DCI);
@@ -175,10 +179,11 @@ class Interpreter {
175179
llvm::Expected<Expr *> ExtractValueFromExpr(Expr *E);
176180
llvm::Expected<llvm::orc::ExecutorAddr> CompileDtorCall(CXXRecordDecl *CXXRD);
177181

178-
CodeGenerator *getCodeGen() const;
179-
std::unique_ptr<llvm::Module> GenModule();
182+
CodeGenerator *getCodeGen(IncrementalAction *Action = nullptr) const;
183+
std::unique_ptr<llvm::Module> GenModule(IncrementalAction *Action = nullptr);
180184
PartialTranslationUnit &RegisterPTU(TranslationUnitDecl *TU,
181-
std::unique_ptr<llvm::Module> M = {});
185+
std::unique_ptr<llvm::Module> M = {},
186+
IncrementalAction *Action = nullptr);
182187

183188
// A cache for the compiled destructors used to for de-allocation of managed
184189
// clang::Values.

clang/lib/Interpreter/DeviceOffload.cpp

+20-23
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,21 @@ IncrementalCUDADeviceParser::IncrementalCUDADeviceParser(
2828
std::unique_ptr<CompilerInstance> DeviceInstance,
2929
CompilerInstance &HostInstance,
3030
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> FS,
31-
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs)
31+
llvm::Error &Err, std::list<PartialTranslationUnit> &PTUs)
3232
: IncrementalParser(*DeviceInstance, Err), PTUs(PTUs), VFS(FS),
3333
CodeGenOpts(HostInstance.getCodeGenOpts()),
34-
TargetOpts(HostInstance.getTargetOpts()) {
34+
TargetOpts(DeviceInstance->getTargetOpts()) {
3535
if (Err)
3636
return;
37-
DeviceCI = std::move(DeviceInstance);
3837
StringRef Arch = TargetOpts.CPU;
3938
if (!Arch.starts_with("sm_") || Arch.substr(3).getAsInteger(10, SMVersion)) {
39+
DeviceInstance.release();
4040
Err = llvm::joinErrors(std::move(Err), llvm::make_error<llvm::StringError>(
4141
"Invalid CUDA architecture",
4242
llvm::inconvertibleErrorCode()));
4343
return;
4444
}
45+
DeviceCI = std::move(DeviceInstance);
4546
}
4647

4748
llvm::Expected<TranslationUnitDecl *>
@@ -50,25 +51,6 @@ IncrementalCUDADeviceParser::Parse(llvm::StringRef Input) {
5051
if (!PTU)
5152
return PTU.takeError();
5253

53-
auto PTX = GeneratePTX();
54-
if (!PTX)
55-
return PTX.takeError();
56-
57-
auto Err = GenerateFatbinary();
58-
if (Err)
59-
return std::move(Err);
60-
61-
std::string FatbinFileName =
62-
"/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
63-
VFS->addFile(FatbinFileName, 0,
64-
llvm::MemoryBuffer::getMemBuffer(
65-
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
66-
"", false));
67-
68-
CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
69-
70-
FatbinContent.clear();
71-
7254
return PTU;
7355
}
7456

@@ -78,9 +60,11 @@ llvm::Expected<llvm::StringRef> IncrementalCUDADeviceParser::GeneratePTX() {
7860

7961
const llvm::Target *Target = llvm::TargetRegistry::lookupTarget(
8062
PTU.TheModule->getTargetTriple(), Error);
81-
if (!Target)
63+
if (!Target) {
8264
return llvm::make_error<llvm::StringError>(std::move(Error),
8365
std::error_code());
66+
}
67+
8468
llvm::TargetOptions TO = llvm::TargetOptions();
8569
llvm::TargetMachine *TargetMachine = Target->createTargetMachine(
8670
PTU.TheModule->getTargetTriple(), TargetOpts.CPU, "", TO,
@@ -172,6 +156,19 @@ llvm::Error IncrementalCUDADeviceParser::GenerateFatbinary() {
172156

173157
FatbinContent.append(PTXCode.begin(), PTXCode.end());
174158

159+
auto &PTU = PTUs.back();
160+
161+
std::string FatbinFileName = "/" + PTU.TheModule->getName().str() + ".fatbin";
162+
163+
VFS->addFile(FatbinFileName, 0,
164+
llvm::MemoryBuffer::getMemBuffer(
165+
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
166+
"", false));
167+
168+
CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
169+
170+
FatbinContent.clear();
171+
175172
return llvm::Error::success();
176173
}
177174

clang/lib/Interpreter/DeviceOffload.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ class CodeGenOptions;
2424
class TargetOptions;
2525

2626
class IncrementalCUDADeviceParser : public IncrementalParser {
27-
const std::list<PartialTranslationUnit> &PTUs;
27+
std::list<PartialTranslationUnit> &PTUs;
2828

2929
public:
3030
IncrementalCUDADeviceParser(
3131
std::unique_ptr<CompilerInstance> DeviceInstance,
3232
CompilerInstance &HostInstance,
3333
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> VFS,
34-
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs);
34+
llvm::Error &Err, std::list<PartialTranslationUnit> &PTUs);
3535

3636
llvm::Expected<TranslationUnitDecl *> Parse(llvm::StringRef Input) override;
3737

clang/lib/Interpreter/Interpreter.cpp

+57-32
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,44 @@ const char *const Runtimes = R"(
451451
)";
452452

453453
llvm::Expected<std::unique_ptr<Interpreter>>
454-
Interpreter::create(std::unique_ptr<CompilerInstance> CI) {
454+
Interpreter::create(std::unique_ptr<CompilerInstance> CI,
455+
std::unique_ptr<CompilerInstance> DeviceCI) {
455456
llvm::Error Err = llvm::Error::success();
456457
auto Interp =
457458
std::unique_ptr<Interpreter>(new Interpreter(std::move(CI), Err));
458459
if (Err)
459460
return std::move(Err);
460461

462+
CompilerInstance &HostCI = *(Interp->getCompilerInstance());
463+
464+
if (DeviceCI) {
465+
Interp->DeviceAct = std::make_unique<IncrementalAction>(
466+
*DeviceCI, *Interp->TSCtx->getContext(), Err, *Interp);
467+
468+
if (Err)
469+
return std::move(Err);
470+
471+
DeviceCI->ExecuteAction(*Interp->DeviceAct);
472+
473+
// avoid writing fat binary to disk using an in-memory virtual file system
474+
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> IMVFS =
475+
std::make_unique<llvm::vfs::InMemoryFileSystem>();
476+
llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> OverlayVFS =
477+
std::make_unique<llvm::vfs::OverlayFileSystem>(
478+
llvm::vfs::getRealFileSystem());
479+
OverlayVFS->pushOverlay(IMVFS);
480+
HostCI.createFileManager(OverlayVFS);
481+
482+
auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
483+
std::move(DeviceCI), HostCI, IMVFS, Err,
484+
Interp->PTUs);
485+
486+
if (Err)
487+
return std::move(Err);
488+
489+
Interp->DeviceParser = std::move(DeviceParser);
490+
}
491+
461492
// Add runtime code and set a marker to hide it from user code. Undo will not
462493
// go through that.
463494
auto PTU = Interp->Parse(Runtimes);
@@ -472,29 +503,7 @@ Interpreter::create(std::unique_ptr<CompilerInstance> CI) {
472503
llvm::Expected<std::unique_ptr<Interpreter>>
473504
Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
474505
std::unique_ptr<CompilerInstance> DCI) {
475-
// avoid writing fat binary to disk using an in-memory virtual file system
476-
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> IMVFS =
477-
std::make_unique<llvm::vfs::InMemoryFileSystem>();
478-
llvm::IntrusiveRefCntPtr<llvm::vfs::OverlayFileSystem> OverlayVFS =
479-
std::make_unique<llvm::vfs::OverlayFileSystem>(
480-
llvm::vfs::getRealFileSystem());
481-
OverlayVFS->pushOverlay(IMVFS);
482-
CI->createFileManager(OverlayVFS);
483-
484-
auto Interp = Interpreter::create(std::move(CI));
485-
if (auto E = Interp.takeError())
486-
return std::move(E);
487-
488-
llvm::Error Err = llvm::Error::success();
489-
auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
490-
std::move(DCI), *(*Interp)->getCompilerInstance(), IMVFS, Err,
491-
(*Interp)->PTUs);
492-
if (Err)
493-
return std::move(Err);
494-
495-
(*Interp)->DeviceParser = std::move(DeviceParser);
496-
497-
return Interp;
506+
return Interpreter::create(std::move(CI), std::move(DCI));
498507
}
499508

500509
const CompilerInstance *Interpreter::getCompilerInstance() const {
@@ -532,15 +541,16 @@ size_t Interpreter::getEffectivePTUSize() const {
532541

533542
PartialTranslationUnit &
534543
Interpreter::RegisterPTU(TranslationUnitDecl *TU,
535-
std::unique_ptr<llvm::Module> M /*={}*/) {
544+
std::unique_ptr<llvm::Module> M /*={}*/,
545+
IncrementalAction *Action) {
536546
PTUs.emplace_back(PartialTranslationUnit());
537547
PartialTranslationUnit &LastPTU = PTUs.back();
538548
LastPTU.TUPart = TU;
539549

540550
if (!M)
541-
M = GenModule();
551+
M = GenModule(Action);
542552

543-
assert((!getCodeGen() || M) && "Must have a llvm::Module at this point");
553+
assert((!getCodeGen(Action) || M) && "Must have a llvm::Module at this point");
544554

545555
LastPTU.TheModule = std::move(M);
546556
LLVM_DEBUG(llvm::dbgs() << "compile-ptu " << PTUs.size() - 1
@@ -558,8 +568,21 @@ Interpreter::Parse(llvm::StringRef Code) {
558568
// included in the host compilation
559569
if (DeviceParser) {
560570
llvm::Expected<TranslationUnitDecl *> DeviceTU = DeviceParser->Parse(Code);
561-
if (auto E = DeviceTU.takeError())
571+
if (auto E = DeviceTU.takeError()) {
562572
return std::move(E);
573+
}
574+
575+
auto *CudaParser = llvm::cast<IncrementalCUDADeviceParser>(DeviceParser.get());
576+
577+
PartialTranslationUnit &DevicePTU = RegisterPTU(*DeviceTU, nullptr, DeviceAct.get());
578+
579+
llvm::Expected<llvm::StringRef> PTX = CudaParser->GeneratePTX();
580+
if (!PTX)
581+
return PTX.takeError();
582+
583+
llvm::Error Err = CudaParser->GenerateFatbinary();
584+
if (Err)
585+
return std::move(Err);
563586
}
564587

565588
// Tell the interpreter sliently ignore unused expressions since value
@@ -736,9 +759,9 @@ llvm::Error Interpreter::LoadDynamicLibrary(const char *name) {
736759
return llvm::Error::success();
737760
}
738761

739-
std::unique_ptr<llvm::Module> Interpreter::GenModule() {
762+
std::unique_ptr<llvm::Module> Interpreter::GenModule(IncrementalAction *Action) {
740763
static unsigned ID = 0;
741-
if (CodeGenerator *CG = getCodeGen()) {
764+
if (CodeGenerator *CG = getCodeGen(Action)) {
742765
// Clang's CodeGen is designed to work with a single llvm::Module. In many
743766
// cases for convenience various CodeGen parts have a reference to the
744767
// llvm::Module (TheModule or Module) which does not change when a new
@@ -760,8 +783,10 @@ std::unique_ptr<llvm::Module> Interpreter::GenModule() {
760783
return nullptr;
761784
}
762785

763-
CodeGenerator *Interpreter::getCodeGen() const {
764-
FrontendAction *WrappedAct = Act->getWrapped();
786+
CodeGenerator *Interpreter::getCodeGen(IncrementalAction *Action) const {
787+
if (!Action)
788+
Action = Act.get();
789+
FrontendAction *WrappedAct = Action->getWrapped();
765790
if (!WrappedAct->hasIRSupport())
766791
return nullptr;
767792
return static_cast<CodeGenAction *>(WrappedAct)->getCodeGenerator();

0 commit comments

Comments
 (0)