Skip to content

Commit 34e85a2

Browse files
authored
Add addr replacement handling of memtransfer (rust-lang#879)
* Add addr replacement handling of memtransfer
* Update test
1 parent a49ef5e commit 34e85a2

File tree

5 files changed

+30
-6
lines changed

5 files changed

+30
-6
lines changed

enzyme/Enzyme/FunctionUtils.cpp

+26-1
Original file line number | Diff line number | Diff line change
@@ -350,10 +350,35 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
350350
Intrinsic::getDeclaration(MS->getParent()->getParent()->getParent(),
351351
Intrinsic::memset, tys),
352352
nargs));
353-
nMS->copyIRFlags(MS);
353+
nMS->copyMetadata(*MS);
354+
nMS->setAttributes(MS->getAttributes());
354355
toErase.push_back(MS);
355356
continue;
356357
}
358+
if (auto MTI = dyn_cast<MemTransferInst>(inst)) {
359+
IRBuilder<> B(MTI);
360+
361+
Value *nargs[4] = {MTI->getArgOperand(0), MTI->getArgOperand(1),
362+
MTI->getArgOperand(2), MTI->getArgOperand(3)};
363+
364+
if (nargs[0] == prev)
365+
nargs[0] = rep;
366+
367+
if (nargs[1] == prev)
368+
nargs[1] = rep;
369+
370+
Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
371+
nargs[2]->getType()};
372+
373+
auto nMTI = cast<CallInst>(B.CreateCall(
374+
Intrinsic::getDeclaration(MTI->getParent()->getParent()->getParent(),
375+
MTI->getIntrinsicID(), tys),
376+
nargs));
377+
nMTI->copyMetadata(*MTI);
378+
nMTI->setAttributes(MTI->getAttributes());
379+
toErase.push_back(MTI);
380+
continue;
381+
}
357382
if (auto CI = dyn_cast<CallInst>(inst)) {
358383
if (auto F = CI->getCalledFunction()) {
359384
if (F->getName() == "julia.write_barrier" && legal) {

enzyme/Enzyme/Utils.cpp

-1
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,6 @@ llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
6363
SmallVector<llvm::Instruction *, 2> res;
6464
if (EnzymePostCacheStore) {
6565
uint64_t size = 0;
66-
LLVMValueRef V2 = nullptr;
6766
auto ptr = EnzymePostCacheStore(wrap(SI), wrap(&B), &size);
6867
for (size_t i = 0; i < size; i++) {
6968
res.push_back(cast<Instruction>(unwrap(ptr[i])));

enzyme/test/Enzyme/ReverseMode/writeonlyretcjl.ll

+1-1
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ entry:
4040
; CHECK-NEXT: %"r'ai" = alloca double, i64 1, align 8
4141
; CHECK-NEXT: %0 = bitcast double* %"r'ai" to {}*
4242
; CHECK-NEXT: %1 = bitcast {}* %0 to i8*
43-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %1, i8 0, i64 8, i1 false)
43+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %1, i8 0, i64 8, i1 false)
4444
; CHECK-NEXT: %2 = load double, double* %"r'ai", align 8
4545
; CHECK-NEXT: %3 = fadd fast double %2, %differeturn
4646
; CHECK-NEXT: store double %3, double* %"r'ai", align 8,

enzyme/test/Enzyme/ReverseMode/writeonlyretjl.ll

+1-1
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ entry:
7272
; CHECK-NEXT: %"r'ai" = alloca double, i64 1, align 8
7373
; CHECK-NEXT: %0 = bitcast double* %"r'ai" to {}*
7474
; CHECK-NEXT: %1 = bitcast {}* %0 to i8*
75-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %1, i8 0, i64 8, i1 false)
75+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %1, i8 0, i64 8, i1 false)
7676
; CHECK-NEXT: %2 = load double, double* %"r'ai", align 8
7777
; CHECK-NEXT: %3 = fadd fast double %2, %differeturn
7878
; CHECK-NEXT: store double %3, double* %"r'ai", align 8

enzyme/test/Enzyme/ReverseMode/writeonlyretjlptr.ll

+2-2
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ define double @dsquare({} addrspace(10)* %x, {} addrspace(10)* %dx) {
6464
; CHECK-NEXT: %"r'ai" = alloca {} addrspace(10)*, i64 1, align 8
6565
; CHECK-NEXT: %1 = bitcast {} addrspace(10)** %"r'ai" to {}*
6666
; CHECK-NEXT: %2 = bitcast {}* %1 to i8*
67-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %2, i8 0, i64 8, i1 false)
67+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %2, i8 0, i64 8, i1 false)
6868
; CHECK-NEXT: call void @augmented_subsq({} addrspace(10)** %r, {} addrspace(10)** %"r'ai", {} addrspace(10)* %x, {} addrspace(10)* %"x'")
6969
; CHECK-NEXT: %l = load {} addrspace(10)*, {} addrspace(10)** %r, align 8
7070
; CHECK-NEXT: %bc = bitcast {} addrspace(10)* %l to double addrspace(10)*
@@ -78,7 +78,7 @@ define double @dsquare({} addrspace(10)* %x, {} addrspace(10)* %dx) {
7878
; CHECK-NEXT: %"r'ai" = alloca {} addrspace(10)*, i64 1, align 8
7979
; CHECK-NEXT: %0 = bitcast {} addrspace(10)** %"r'ai" to {}*
8080
; CHECK-NEXT: %1 = bitcast {}* %0 to i8*
81-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %1, i8 0, i64 8, i1 false)
81+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %1, i8 0, i64 8, i1 false)
8282
; CHECK-NEXT: %"l'ipl" = load {} addrspace(10)*, {} addrspace(10)** %"r'ai", align 8
8383
; CHECK-NEXT: %"bc'ipc" = bitcast {} addrspace(10)* %"l'ipl" to double addrspace(10)*
8484
; CHECK-NEXT: %2 = load double, double addrspace(10)* %"bc'ipc"

0 commit comments

Comments
 (0)