Skip to content

Commit 86ffa75

Browse files
authored
Julia malloc handling (rust-lang#754)
1 parent f10b38d commit 86ffa75

File tree

2 files changed

+65
-55
lines changed

2 files changed

+65
-55
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10298,56 +10298,60 @@ class AdjointGenerator
1029810298
cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
1029910299
cast<CallInst>(anti)->setDebugLoc(dbgLoc);
1030010300

10301+
if (anti->getType()->isPointerTy()) {
1030110302
#if LLVM_VERSION_MAJOR >= 14
10302-
cast<CallInst>(anti)->addAttributeAtIndex(
10303-
AttributeList::ReturnIndex, Attribute::NoAlias);
10304-
cast<CallInst>(anti)->addAttributeAtIndex(
10305-
AttributeList::ReturnIndex, Attribute::NonNull);
10306-
#else
10307-
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10308-
Attribute::NoAlias);
10309-
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10310-
Attribute::NonNull);
10311-
#endif
10312-
10313-
if (called->getName() == "malloc" ||
10314-
called->getName() == "_Znwm") {
10315-
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10316-
unsigned derefBytes = ci->getLimitedValue();
10317-
CallInst *cal =
10318-
cast<CallInst>(gutils->getNewFromOriginal(orig));
10303+
cast<CallInst>(anti)->addAttributeAtIndex(
10304+
AttributeList::ReturnIndex, Attribute::NoAlias);
10305+
cast<CallInst>(anti)->addAttributeAtIndex(
10306+
AttributeList::ReturnIndex, Attribute::NonNull);
10307+
#else
10308+
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10309+
Attribute::NoAlias);
10310+
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10311+
Attribute::NonNull);
10312+
#endif
10313+
10314+
if (called->getName() == "malloc" ||
10315+
called->getName() == "_Znwm") {
10316+
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10317+
unsigned derefBytes = ci->getLimitedValue();
10318+
CallInst *cal =
10319+
cast<CallInst>(gutils->getNewFromOriginal(orig));
1031910320
#if LLVM_VERSION_MAJOR >= 14
10320-
cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
10321-
cal->addDereferenceableRetAttr(derefBytes);
10321+
cast<CallInst>(anti)->addDereferenceableRetAttr(
10322+
derefBytes);
10323+
cal->addDereferenceableRetAttr(derefBytes);
1032210324
#if !defined(FLANG) && !defined(ROCM)
10323-
AttrBuilder B(called->getContext());
10324-
#else
10325-
AttrBuilder B;
10326-
#endif
10327-
B.addDereferenceableOrNullAttr(derefBytes);
10328-
cast<CallInst>(anti)->setAttributes(
10329-
cast<CallInst>(anti)->getAttributes().addRetAttributes(
10330-
orig->getContext(), B));
10331-
cal->setAttributes(cal->getAttributes().addRetAttributes(
10332-
orig->getContext(), B));
10333-
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10334-
Attribute::NoAlias);
10335-
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10336-
Attribute::NonNull);
10337-
#else
10338-
cast<CallInst>(anti)->addDereferenceableAttr(
10339-
llvm::AttributeList::ReturnIndex, derefBytes);
10340-
cal->addDereferenceableAttr(
10341-
llvm::AttributeList::ReturnIndex, derefBytes);
10342-
cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10343-
llvm::AttributeList::ReturnIndex, derefBytes);
10344-
cal->addDereferenceableOrNullAttr(
10345-
llvm::AttributeList::ReturnIndex, derefBytes);
10346-
cal->addAttribute(AttributeList::ReturnIndex,
10347-
Attribute::NoAlias);
10348-
cal->addAttribute(AttributeList::ReturnIndex,
10349-
Attribute::NonNull);
10325+
AttrBuilder B(called->getContext());
10326+
#else
10327+
AttrBuilder B;
10328+
#endif
10329+
B.addDereferenceableOrNullAttr(derefBytes);
10330+
cast<CallInst>(anti)->setAttributes(
10331+
cast<CallInst>(anti)
10332+
->getAttributes()
10333+
.addRetAttributes(orig->getContext(), B));
10334+
cal->setAttributes(cal->getAttributes().addRetAttributes(
10335+
orig->getContext(), B));
10336+
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10337+
Attribute::NoAlias);
10338+
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10339+
Attribute::NonNull);
10340+
#else
10341+
cast<CallInst>(anti)->addDereferenceableAttr(
10342+
llvm::AttributeList::ReturnIndex, derefBytes);
10343+
cal->addDereferenceableAttr(
10344+
llvm::AttributeList::ReturnIndex, derefBytes);
10345+
cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10346+
llvm::AttributeList::ReturnIndex, derefBytes);
10347+
cal->addDereferenceableOrNullAttr(
10348+
llvm::AttributeList::ReturnIndex, derefBytes);
10349+
cal->addAttribute(AttributeList::ReturnIndex,
10350+
Attribute::NoAlias);
10351+
cal->addAttribute(AttributeList::ReturnIndex,
10352+
Attribute::NonNull);
1035010353
#endif
10354+
}
1035110355
}
1035210356
}
1035310357
return anti;
@@ -10425,10 +10429,6 @@ class AdjointGenerator
1042510429
Value *tofree = lookup(anti, Builder2);
1042610430
assert(tofree);
1042710431
assert(tofree->getType());
10428-
assert(Type::getInt8Ty(tofree->getContext()));
10429-
assert(
10430-
PointerType::getUnqual(Type::getInt8Ty(tofree->getContext())));
10431-
assert(Type::getInt8PtrTy(tofree->getContext()));
1043210432
auto rule = [&](Value *tofree) {
1043310433
auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc,
1043410434
gutils->TLI);

enzyme/Enzyme/LibraryFuncs.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ extern std::map<std::string,
4242
/// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp
4343
static inline bool isAllocationFunction(const llvm::Function &F,
4444
const llvm::TargetLibraryInfo &TLI) {
45-
if (F.getName() == "calloc")
45+
if (F.getName() == "calloc" || F.getName() == "malloc")
4646
return true;
4747
if (F.getName() == "swift_allocObject")
4848
return true;
@@ -220,9 +220,14 @@ static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb,
220220
}
221221
Value *dst_arg = toZero;
222222

223-
dst_arg = bb.CreateBitCast(
224-
dst_arg, Type::getInt8PtrTy(toZero->getContext(),
225-
toZero->getType()->getPointerAddressSpace()));
223+
if (dst_arg->getType()->isIntegerTy())
224+
dst_arg =
225+
bb.CreateIntToPtr(dst_arg, Type::getInt8PtrTy(toZero->getContext()));
226+
else
227+
dst_arg = bb.CreateBitCast(
228+
dst_arg,
229+
Type::getInt8PtrTy(toZero->getContext(),
230+
toZero->getType()->getPointerAddressSpace()));
226231

227232
auto val_arg = ConstantInt::get(Type::getInt8Ty(toZero->getContext()), 0);
228233
auto len_arg =
@@ -330,8 +335,13 @@ freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree,
330335
&allocationfn);
331336
}
332337

338+
if (tofree->getType()->isIntegerTy())
339+
tofree = builder.CreateIntToPtr(tofree,
340+
Type::getInt8PtrTy(tofree->getContext()));
341+
333342
llvm::LibFunc libfunc;
334-
if (allocationfn.getName() == "calloc") {
343+
if (allocationfn.getName() == "calloc" ||
344+
allocationfn.getName() == "malloc") {
335345
libfunc = LibFunc_malloc;
336346
} else {
337347
bool res = TLI.getLibFunc(allocationfn, libfunc);

0 commit comments

Comments
 (0)