Skip to content

Commit 94815c9

Browse files
vchuravywsmoses
andauthored
Better errors (rust-lang#588)
* More advanced error handler * Improved keep minus one Co-authored-by: William S. Moses <[email protected]>
1 parent abc282d commit 94815c9

File tree

9 files changed

+88
-22
lines changed

9 files changed

+88
-22
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) {
221221
delete TA;
222222
}
223223

224+
void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI,
225+
LLVMValueRef F) {
226+
FnTypeInfo FTI(eunwrap(CTI, cast<Function>(unwrap(F))));
227+
return (void *)&((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer;
228+
}
229+
224230
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
225231
CustomShadowFree FHandle) {
226232
shadowHandlers[std::string(Name)] =
@@ -539,8 +545,36 @@ const char *EnzymeTypeTreeToString(CTypeTreeRef src) {
539545

540546
return cstr;
541547
}
548+
549+
// TODO deprecated
542550
void EnzymeTypeTreeToStringFree(const char *cstr) { delete[] cstr; }
543551

552+
const char *EnzymeTypeAnalyzerToString(void *src) {
553+
auto TA = (TypeAnalyzer *)src;
554+
std::string str;
555+
raw_string_ostream ss(str);
556+
TA->dump(ss);
557+
ss.str();
558+
char *cstr = new char[str.length() + 1];
559+
std::strcpy(cstr, str.c_str());
560+
return cstr;
561+
}
562+
563+
const char *EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils,
564+
void *src) {
565+
std::string str;
566+
raw_string_ostream ss(str);
567+
for (auto z : gutils->invertedPointers) {
568+
ss << "available inversion for " << *z.first << " of " << *z.second << "\n";
569+
}
570+
ss.str();
571+
char *cstr = new char[str.length() + 1];
572+
std::strcpy(cstr, str.c_str());
573+
return cstr;
574+
}
575+
576+
void EnzymeStringFree(const char *cstr) { delete[] cstr; }
577+
544578
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2) {
545579
Instruction *I1 = cast<Instruction>(unwrap(inst1));
546580
Instruction *I2 = cast<Instruction>(unwrap(inst2));

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
8989
cl::opt<bool> EnzymeJuliaAddrLoad(
9090
"enzyme-julia-addr-load", cl::init(false), cl::Hidden,
9191
cl::desc("Mark all loads resulting in an addr(13)* to be legal to redo"));
92-
93-
void (*CustomErrorHandler)(const char *) = nullptr;
9492
}
9593

9694
struct CacheAnalysis {
@@ -1826,7 +1824,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
18261824
llvm::raw_string_ostream ss(s);
18271825
ss << "No augmented forward pass found for " + todiff->getName() << "\n";
18281826
ss << *todiff << "\n";
1829-
CustomErrorHandler(ss.str().c_str());
1827+
CustomErrorHandler(ss.str().c_str(), wrap(todiff),
1828+
ErrorType::NoDerivative, nullptr);
18301829
}
18311830
llvm::errs() << "mod: " << *todiff->getParent() << "\n";
18321831
llvm::errs() << *todiff << "\n";
@@ -3343,7 +3342,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
33433342
ss << "No reverse pass found for " + key.todiff->getName() << "\n";
33443343
ss << *key.todiff << "\n";
33453344
if (CustomErrorHandler) {
3346-
CustomErrorHandler(ss.str().c_str());
3345+
CustomErrorHandler(ss.str().c_str(), wrap(key.todiff),
3346+
ErrorType::NoDerivative, nullptr);
33473347
} else {
33483348
llvm_unreachable(ss.str().c_str());
33493349
}
@@ -3958,7 +3958,8 @@ Function *EnzymeLogic::CreateForwardDiff(
39583958
llvm::raw_string_ostream ss(s);
39593959
ss << "No forward derivative found for " + todiff->getName() << "\n";
39603960
ss << *todiff << "\n";
3961-
CustomErrorHandler(s.c_str());
3961+
CustomErrorHandler(s.c_str(), wrap(todiff), ErrorType::NoDerivative,
3962+
nullptr);
39623963
}
39633964
if (todiff->empty())
39643965
llvm::errs() << *todiff << "\n";

enzyme/Enzyme/EnzymeLogic.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050

5151
extern "C" {
5252
extern llvm::cl::opt<bool> EnzymePrint;
53-
extern void (*CustomErrorHandler)(const char *);
5453
}
5554

5655
enum class AugmentedStruct { Tape, Return, DifferentialReturn };

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ llvm::cl::opt<bool>
100100
EnzymeRematerialize("enzyme-rematerialize", cl::init(true), cl::Hidden,
101101
cl::desc("Rematerialize allocations/shadows in the "
102102
"reverse rather than caching"));
103-
104-
extern void (*CustomErrorHandler)(const char *);
105103
}
106104

107105
unsigned int MD_ToCopy[5] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
@@ -4478,11 +4476,11 @@ end:;
44784476
assert(BuilderM.GetInsertBlock()->getParent());
44794477
assert(oval);
44804478

4481-
if (CustomErrorHandler && isa<Constant>(oval)) {
4479+
if (CustomErrorHandler) {
44824480
std::string str;
44834481
raw_string_ostream ss(str);
44844482
ss << "cannot find shadow for " << *oval;
4485-
CustomErrorHandler(str.c_str());
4483+
CustomErrorHandler(str.c_str(), wrap(oval), ErrorType::NoShadow, this);
44864484
}
44874485

44884486
llvm::errs() << *newFunc->getParent() << "\n";

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,17 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {
631631
Invalid = true;
632632
return;
633633
}
634+
if (CustomErrorHandler) {
635+
std::string str;
636+
raw_string_ostream ss(str);
637+
ss << "Illegal updateAnalysis prev:" << prev.str()
638+
<< " new: " << Data.str() << "\n";
639+
ss << "val: " << *Val;
640+
if (Origin)
641+
ss << " origin=" << *Origin;
642+
CustomErrorHandler(str.c_str(), wrap(Val), ErrorType::IllegalTypeAnalysis,
643+
(void *)this);
644+
}
634645
llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
635646
llvm::errs() << *fntypeinfo.Function << "\n";
636647
dump();
@@ -1384,7 +1395,19 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
13841395
return;
13851396

13861397
if (direction & DOWN) {
1387-
updateAnalysis(&gep, pointerAnalysis.KeepMinusOne(), &gep);
1398+
bool legal = true;
1399+
auto keepMinus = pointerAnalysis.KeepMinusOne(legal);
1400+
if (!legal) {
1401+
if (CustomErrorHandler)
1402+
CustomErrorHandler("Could not keep minus one", wrap(&gep),
1403+
ErrorType::IllegalTypeAnalysis, this);
1404+
else {
1405+
dump();
1406+
llvm::errs() << " could not perform minus one for gep'd: " << gep
1407+
<< "\n";
1408+
}
1409+
}
1410+
updateAnalysis(&gep, keepMinus, &gep);
13881411
updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1), &gep);
13891412
}
13901413
if (direction & UP)
@@ -2114,14 +2137,13 @@ void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) {
21142137
updateAnalysis(&I, new_res, &I);
21152138
}
21162139

2117-
void TypeAnalyzer::dump() {
2118-
llvm::errs() << "<analysis>\n";
2140+
void TypeAnalyzer::dump(llvm::raw_ostream &ss) {
2141+
ss << "<analysis>\n";
21192142
for (auto &pair : analysis) {
2120-
llvm::errs() << *pair.first << ": " << pair.second.str()
2121-
<< ", intvals: " << to_string(knownIntegralValues(pair.first))
2122-
<< "\n";
2143+
ss << *pair.first << ": " << pair.second.str()
2144+
<< ", intvals: " << to_string(knownIntegralValues(pair.first)) << "\n";
21232145
}
2124-
llvm::errs() << "</analysis>\n";
2146+
ss << "</analysis>\n";
21252147
}
21262148

21272149
void TypeAnalyzer::visitAtomicRMWInst(llvm::AtomicRMWInst &I) {

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {
347347

348348
TypeTree getReturnAnalysis();
349349

350-
void dump();
350+
void dump(llvm::raw_ostream &ss = llvm::errs());
351351

352352
std::set<int64_t> knownIntegralValues(llvm::Value *val);
353353

enzyme/Enzyme/TypeAnalysis/TypeTree.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
657657
}
658658

659659
/// Keep only pointers (or anything's) to a repeated value (represented by -1)
660-
TypeTree KeepMinusOne() const {
660+
TypeTree KeepMinusOne(bool &legal) const {
661661
TypeTree dat;
662662

663663
for (const auto &pair : mapping) {
@@ -674,9 +674,8 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
674674
dat.insert(pair.first, pair.second);
675675
continue;
676676
}
677-
llvm::errs() << "could not merge test " << str() << "\n";
678-
assert(0 && "could not merge");
679-
llvm_unreachable("could not merge");
677+
legal = false;
678+
break;
680679
}
681680

682681
if (pair.first[1] == -1) {

enzyme/Enzyme/Utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939

4040
using namespace llvm;
4141

42+
extern "C" {
43+
void (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
44+
void *) = nullptr;
45+
}
46+
4247
EnzymeFailure::EnzymeFailure(llvm::StringRef RemarkName,
4348
const llvm::DiagnosticLocation &Loc,
4449
const llvm::Instruction *CodeRegion)

enzyme/Enzyme/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,17 @@ namespace llvm {
6464
class ScalarEvolution;
6565
}
6666

67+
enum class ErrorType {
68+
NoDerivative = 0,
69+
NoShadow = 1,
70+
IllegalTypeAnalysis = 2
71+
};
72+
6773
extern "C" {
6874
/// Print additional debug info relevant to performance
6975
extern llvm::cl::opt<bool> EnzymePrintPerf;
76+
extern void (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
77+
void *);
7078
}
7179

7280
extern std::map<std::string, std::function<llvm::Value *(

0 commit comments

Comments
 (0)