Skip to content

Commit a49ef5e

Browse files
authored
Custom alloc zeroing (rust-lang#880)
1 parent fb64f4d commit a49ef5e

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

enzyme/Enzyme/Utils.cpp

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,8 @@ void (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
4848
void *) = nullptr;
4949
LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef,
5050
/*Count*/ LLVMValueRef,
51-
/*Align*/ LLVMValueRef, uint8_t) = nullptr;
51+
/*Align*/ LLVMValueRef, uint8_t,
52+
LLVMValueRef *) = nullptr;
5253
LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr;
5354
void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef,
5455
LLVMValueRef) = nullptr;
@@ -91,10 +92,6 @@ Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
9192
CreateAllocation(B, RT, P, "tapemem", &malloccall, nullptr)->getType();
9293
if (auto F = getFunctionFromCall(malloccall)) {
9394
custom = F->getName() != "malloc";
94-
if (F->getName() == "julia.gc_alloc_obj" ||
95-
F->getName() == "jl_gc_alloc_typed" ||
96-
F->getName() == "ijl_gc_alloc_typed")
97-
ZeroInit = false;
9895
}
9996
allocType = cast<PointerType>(malloccall->getType());
10097
BB->eraseFromParent();
@@ -173,7 +170,8 @@ Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
173170
next->getType(),
174171
newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(RT) / 8);
175172
auto elSize = B.CreateUDiv(next, tsize, "", /*isExact*/ true);
176-
gVal = CreateAllocation(B, RT, elSize, "", nullptr, nullptr);
173+
Instruction *SubZero = nullptr;
174+
gVal = CreateAllocation(B, RT, elSize, "", nullptr, &SubZero);
177175

178176
gVal = B.CreatePointerCast(
179177
gVal, PointerType::get(
@@ -187,6 +185,21 @@ Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
187185
margs[2]->getType()};
188186
auto memsetF = Intrinsic::getDeclaration(&M, Intrinsic::memcpy, tys);
189187
B.CreateCall(memsetF, margs);
188+
if (SubZero) {
189+
ZeroInit = false;
190+
IRBuilder<> BB(SubZero);
191+
Value *zeroSize = BB.CreateSub(next, prevSize);
192+
Value *tmp = SubZero->getOperand(0);
193+
194+
#if LLVM_VERSION_MAJOR > 7
195+
tmp = BB.CreateInBoundsGEP(tmp->getType()->getPointerElementType(), tmp,
196+
prevSize);
197+
#else
198+
tmp = BB.CreateInBoundsGEP(tmp, prevSize);
199+
#endif
200+
SubZero->setOperand(0, tmp);
201+
SubZero->setOperand(2, zeroSize);
202+
}
190203
}
191204

192205
if (ZeroInit) {
@@ -255,15 +268,21 @@ Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
255268
auto Align = ConstantInt::get(Count->getType(), AlignI);
256269
CallInst *malloccall = nullptr;
257270
if (CustomAllocator) {
271+
LLVMValueRef wzeromem = nullptr;
258272
res = unwrap(CustomAllocator(wrap(&Builder), wrap(T), wrap(Count),
259-
wrap(Align), isDefault));
273+
wrap(Align), isDefault,
274+
ZeroMem ? &wzeromem : nullptr));
260275
if (auto I = dyn_cast<Instruction>(res))
261276
I->setName(Name);
262277

263278
malloccall = dyn_cast<CallInst>(res);
264279
if (malloccall == nullptr) {
265280
malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0));
266281
}
282+
if (ZeroMem) {
283+
*ZeroMem = cast_or_null<Instruction>(unwrap(wzeromem));
284+
ZeroMem = nullptr;
285+
}
267286
} else {
268287
if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) {
269288
res = CallInst::CreateMalloc(Builder.GetInsertBlock(), Count->getType(),
@@ -324,11 +343,6 @@ Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
324343
if (caller) {
325344
*caller = malloccall;
326345
}
327-
if (auto F = getFunctionFromCall(malloccall))
328-
if (F->getName() == "julia.gc_alloc_obj" ||
329-
F->getName() == "jl_gc_alloc_typed" ||
330-
F->getName() == "ijl_gc_alloc_typed")
331-
ZeroMem = nullptr;
332346
if (ZeroMem) {
333347
auto PT = cast<PointerType>(malloccall->getType());
334348
Value *tozero = malloccall;

0 commit comments

Comments
 (0)