Skip to content

Commit 51284d6

Browse files
authored
eigen test (rust-lang#363)
1 parent df488bd commit 51284d6

File tree

4 files changed

+80
-30
lines changed

4 files changed

+80
-30
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7903,8 +7903,6 @@ class AdjointGenerator
79037903
IRBuilder<> Builder2(&call);
79047904
getForwardBuilder(Builder2);
79057905

7906-
bool retUsed = subretused;
7907-
79087906
SmallVector<Value *, 8> args;
79097907
std::vector<DIFFE_TYPE> argsInverted;
79107908
std::map<int, Type *> gradByVal;
@@ -7970,7 +7968,7 @@ class AdjointGenerator
79707968

79717969
auto newcalled = gutils->Logic.CreateForwardDiff(
79727970
cast<Function>(called), subretType, argsInverted, gutils->TLI,
7973-
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
7971+
TR.analyzer.interprocedural, /*returnValue*/ subretused,
79747972
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
79757973
nextTypeInfo, {});
79767974

@@ -7989,30 +7987,46 @@ class AdjointGenerator
79897987
}
79907988
#endif
79917989

7992-
if (!newcalled->getReturnType()->isVoidTy()) {
7993-
bool structret = retUsed && subretType != DIFFE_TYPE::CONSTANT;
7994-
auto newcall = gutils->getNewFromOriginal(orig);
7995-
Value *diffe;
7996-
if (structret) {
7997-
diffe = Builder2.CreateExtractValue(diffes, 1);
7998-
} else {
7999-
diffe = diffes;
8000-
}
7990+
auto newcall = gutils->getNewFromOriginal(orig);
7991+
auto ifound = gutils->invertedPointers.find(orig);
7992+
Value *primal = nullptr;
7993+
Value *diffe = nullptr;
80017994

8002-
auto ifound = gutils->invertedPointers.find(orig);
8003-
if (ifound != gutils->invertedPointers.end()) {
8004-
auto placeholder = cast<PHINode>(&*ifound->second);
7995+
if (subretused && subretType != DIFFE_TYPE::CONSTANT) {
7996+
primal = Builder2.CreateExtractValue(diffes, 0);
7997+
diffe = Builder2.CreateExtractValue(diffes, 1);
7998+
} else if (!newcalled->getReturnType()->isVoidTy()) {
7999+
diffe = diffes;
8000+
}
8001+
8002+
if (ifound != gutils->invertedPointers.end()) {
8003+
auto placeholder = cast<PHINode>(&*ifound->second);
8004+
if (primal) {
8005+
gutils->replaceAWithB(newcall, primal);
8006+
gutils->erase(newcall);
8007+
}
8008+
if (diffe) {
80058009
gutils->replaceAWithB(placeholder, diffe);
8006-
gutils->erase(placeholder);
80078010
} else {
8008-
gutils->replaceAWithB(newcall, diffe);
8011+
gutils->invertedPointers.erase(ifound);
8012+
}
8013+
gutils->erase(placeholder);
8014+
} else {
8015+
if (primal && diffe) {
8016+
gutils->replaceAWithB(newcall, primal);
8017+
if (!gutils->isConstantValue(&call)) {
8018+
setDiffe(&call, diffe, Builder2);
8019+
}
80098020
gutils->erase(newcall);
8021+
} else if (diffe) {
8022+
gutils->replaceAWithB(newcall, diffe);
80108023
if (!gutils->isConstantValue(&call)) {
80118024
setDiffe(&call, diffe, Builder2);
80128025
}
8026+
gutils->erase(newcall);
8027+
} else {
8028+
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
80138029
}
8014-
} else {
8015-
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
80168030
}
80178031

80188032
return;

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,8 +2380,16 @@ void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
23802380
switch (retVal) {
23812381
case ReturnType::Return: {
23822382
auto ret = inst->getOperand(0);
2383-
toret = retType == DIFFE_TYPE::CONSTANT ? gutils->getNewFromOriginal(ret)
2384-
: gutils->diffe(ret, nBuilder);
2383+
2384+
if (retType == DIFFE_TYPE::CONSTANT) {
2385+
toret = gutils->getNewFromOriginal(ret);
2386+
} else if (!ret->getType()->isFPOrFPVectorTy() &&
2387+
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
2388+
toret = gutils->invertPointerM(ret, nBuilder);
2389+
} else {
2390+
toret = gutils->diffe(ret, nBuilder);
2391+
}
2392+
23852393
break;
23862394
}
23872395
case ReturnType::TwoReturns: {
@@ -2392,7 +2400,8 @@ void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
23922400
toret =
23932401
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);
23942402

2395-
if (TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
2403+
if (!ret->getType()->isFPOrFPVectorTy() &&
2404+
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
23962405
toret = nBuilder.CreateInsertValue(
23972406
toret, gutils->invertPointerM(ret, nBuilder), 1);
23982407
} else {
@@ -3717,7 +3726,6 @@ Function *EnzymeLogic::CreateForwardDiff(
37173726
return foundcalled;
37183727
}
37193728

3720-
auto TRo = TA.analyzeFunction(oldTypeInfo);
37213729
bool retActive = retType != DIFFE_TYPE::CONSTANT;
37223730

37233731
ReturnType retVal =

enzyme/test/Enzyme/ForwardMode/ptr-ret.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ declare dso_local double @_Z16__enzyme_fwddiffz(...)
3636

3737
; CHECK: define internal double @fwddiffe_Z6squared(double %x, double %"x'")
3838
; CHECK-NEXT: entry:
39-
; CHECK-NEXT: %call = call double* @_Z6toHeapd(double %x)
4039
; CHECK-NEXT: %0 = call { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
41-
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 1
42-
; CHECK-NEXT: %2 = load double, double* %call, align 8
40+
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 0
41+
; CHECK-NEXT: %2 = extractvalue { double*, double* } %0, 1
4342
; CHECK-NEXT: %3 = load double, double* %1, align 8
44-
; CHECK-NEXT: %4 = fmul fast double %3, %x
45-
; CHECK-NEXT: %5 = fmul fast double %"x'", %2
46-
; CHECK-NEXT: %6 = fadd fast double %4, %5
47-
; CHECK-NEXT: ret double %6
43+
; CHECK-NEXT: %4 = load double, double* %2, align 8
44+
; CHECK-NEXT: %5 = fmul fast double %4, %x
45+
; CHECK-NEXT: %6 = fmul fast double %"x'", %3
46+
; CHECK-NEXT: %7 = fadd fast double %5, %6
47+
; CHECK-NEXT: ret double %7
4848
; CHECK-NEXT: }
4949

5050
; CHECK: define internal { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang++ -mllvm -force-vector-width=1 -ffast-math -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
5+
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
6+
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
7+
8+
#include "test_utils.h"
9+
#include <eigen3/Eigen/Core>
10+
#include <eigen3/Eigen/Dense>
11+
12+
double __enzyme_fwddiff(double(double), double, double);
13+
14+
double square(double x) {
15+
Eigen::Vector3d v(x, x * x, x * x * x);
16+
v *= 2;
17+
return v[1];
18+
}
19+
20+
double dsquare(double x) { return __enzyme_fwddiff(square, x, 1.0); }
21+
22+
int main() {
23+
double x = 4;
24+
double res = dsquare(x);
25+
APPROX_EQ(res, 16.0, 1e-10);
26+
printf("dsquare(%f)=%f\n", x, res);
27+
return 0;
28+
}

0 commit comments

Comments
 (0)