Skip to content

Commit 5c89a86

Browse files
authored
Handle hypot (rust-lang#568)
1 parent f3bd406 commit 5c89a86

File tree

7 files changed

+327
-6
lines changed

7 files changed

+327
-6
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8274,6 +8274,97 @@ class AdjointGenerator
82748274
}
82758275
}
82768276

8277+
if (funcName == "hypot" || funcName == "hypotf" || funcName == "hypotl") {
8278+
if (gutils->knownRecomputeHeuristic.find(orig) !=
8279+
gutils->knownRecomputeHeuristic.end()) {
8280+
if (!gutils->knownRecomputeHeuristic[orig]) {
8281+
gutils->cacheForReverse(BuilderZ, newCall,
8282+
getIndex(orig, CacheType::Self));
8283+
}
8284+
}
8285+
eraseIfUnused(*orig);
8286+
if (gutils->isConstantInstruction(orig))
8287+
return;
8288+
8289+
switch (Mode) {
8290+
case DerivativeMode::ForwardModeSplit:
8291+
case DerivativeMode::ForwardMode: {
8292+
IRBuilder<> Builder2(&call);
8293+
getForwardBuilder(Builder2);
8294+
8295+
Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0));
8296+
Value *y = gutils->getNewFromOriginal(orig->getArgOperand(1));
8297+
Value *args[] = {x, y};
8298+
#if LLVM_VERSION_MAJOR >= 11
8299+
auto callval = orig->getCalledOperand();
8300+
#else
8301+
auto callval = orig->getCalledValue();
8302+
#endif
8303+
CallInst *cubcall = cast<CallInst>(
8304+
Builder2.CreateCall(orig->getFunctionType(), callval, args));
8305+
cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
8306+
cubcall->setCallingConv(orig->getCallingConv());
8307+
8308+
auto rule = [&](Value *dx, Value *dy) {
8309+
Value *t;
8310+
if (dx)
8311+
dx = Builder2.CreateFMul(x, dx);
8312+
if (dy)
8313+
dy = Builder2.CreateFMul(y, dy);
8314+
if (dy && dx)
8315+
t = Builder2.CreateFAdd(dx, dy);
8316+
else if (dx)
8317+
t = dx;
8318+
else
8319+
t = dy;
8320+
return Builder2.CreateFDiv(t, cubcall);
8321+
};
8322+
8323+
Value *dif =
8324+
applyChainRule(call.getType(), Builder2, rule,
8325+
gutils->isConstantValue(orig->getOperand(0))
8326+
? nullptr
8327+
: diffe(orig->getOperand(0), Builder2),
8328+
gutils->isConstantValue(orig->getOperand(1))
8329+
? nullptr
8330+
: diffe(orig->getOperand(1), Builder2));
8331+
setDiffe(orig, dif, Builder2);
8332+
return;
8333+
}
8334+
case DerivativeMode::ReverseModeGradient:
8335+
case DerivativeMode::ReverseModeCombined: {
8336+
IRBuilder<> Builder2(call.getParent());
8337+
getReverseBuilder(Builder2);
8338+
8339+
Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
8340+
Builder2);
8341+
Value *y = lookup(gutils->getNewFromOriginal(orig->getArgOperand(1)),
8342+
Builder2);
8343+
Value *args[] = {x, y};
8344+
#if LLVM_VERSION_MAJOR >= 11
8345+
auto callval = orig->getCalledOperand();
8346+
#else
8347+
auto callval = orig->getCalledValue();
8348+
#endif
8349+
CallInst *cubcall = cast<CallInst>(
8350+
Builder2.CreateCall(orig->getFunctionType(), callval, args));
8351+
cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
8352+
cubcall->setCallingConv(orig->getCallingConv());
8353+
for (int i = 0; i < 2; i++) {
8354+
if (!gutils->isConstantValue(orig->getArgOperand(i))) {
8355+
Value *dif0 = Builder2.CreateFDiv(
8356+
Builder2.CreateFMul(diffe(orig, Builder2), args[i]), cubcall);
8357+
addToDiffe(orig->getArgOperand(i), dif0, Builder2, x->getType());
8358+
}
8359+
}
8360+
return;
8361+
}
8362+
case DerivativeMode::ReverseModePrimal: {
8363+
return;
8364+
}
8365+
}
8366+
}
8367+
82778368
if (funcName == "tanhf" || funcName == "tanh") {
82788369
if (gutils->knownRecomputeHeuristic.find(orig) !=
82798370
gutils->knownRecomputeHeuristic.end()) {

enzyme/Enzyme/GradientUtils.h

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1732,13 +1732,15 @@ class GradientUtils : public CacheUtility {
17321732
Value *vals[size] = {args...};
17331733

17341734
for (size_t i = 0; i < size; ++i)
1735-
assert(cast<ArrayType>(vals[i]->getType())->getNumElements() == width);
1735+
if (vals[i])
1736+
assert(cast<ArrayType>(vals[i]->getType())->getNumElements() ==
1737+
width);
17361738

17371739
Type *wrappedType = ArrayType::get(diffType, width);
17381740
Value *res = UndefValue::get(wrappedType);
17391741
for (unsigned int i = 0; i < getWidth(); ++i) {
1740-
auto tup =
1741-
std::tuple<Args...>{(Builder.CreateExtractValue(args, {i}))...};
1742+
auto tup = std::tuple<Args...>{
1743+
(args ? Builder.CreateExtractValue(args, {i}) : nullptr)...};
17421744
auto diff = std::apply(rule, std::move(tup));
17431745
res = Builder.CreateInsertValue(res, diff, {i});
17441746
}
@@ -1757,11 +1759,13 @@ class GradientUtils : public CacheUtility {
17571759
Value *vals[size] = {args...};
17581760

17591761
for (size_t i = 0; i < size; ++i)
1760-
assert(cast<ArrayType>(vals[i]->getType())->getNumElements() == width);
1762+
if (vals[i])
1763+
assert(cast<ArrayType>(vals[i]->getType())->getNumElements() ==
1764+
width);
17611765

17621766
for (unsigned int i = 0; i < getWidth(); ++i) {
1763-
auto tup =
1764-
std::tuple<Args...>{(Builder.CreateExtractValue(args, {i}))...};
1767+
auto tup = std::tuple<Args...>{
1768+
(args ? Builder.CreateExtractValue(args, {i}) : nullptr)...};
17651769
std::apply(rule, std::move(tup));
17661770
}
17671771
} else {
@@ -1776,6 +1780,7 @@ class GradientUtils : public CacheUtility {
17761780
IRBuilder<> &Builder, Func rule) {
17771781
if (width > 1) {
17781782
for (auto diff : diffs) {
1783+
assert(diff);
17791784
assert(cast<ArrayType>(diff->getType())->getNumElements() == width);
17801785
}
17811786
Type *wrappedType = ArrayType::get(diffType, width);
Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,30 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @hypot(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define double @test_derivative(double %x, double %y) {
11+
entry:
12+
%0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.000000e+00, double %y, double 1.000000e+00)
13+
ret double %0
14+
}
15+
16+
declare double @hypot(double, double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_fwddiff(...)
20+
21+
; CHECK-LABEL: define internal double @fwddiffetester(
22+
; CHECK-NEXT: entry:
23+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double %y)
24+
; CHECK-NEXT: %1 = fmul fast double %x, %"x'"
25+
; CHECK-NEXT: %2 = fmul fast double %y, %"y'"
26+
; CHECK-NEXT: %3 = fadd fast double %1, %2
27+
; CHECK-NEXT: %4 = fdiv fast double %3, %0
28+
; CHECK-NEXT: ret double %4
29+
; CHECK-NEXT: }
30+
Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
define double @tester2(double %x) {
4+
entry:
5+
%call = call double @hypot(double %x, double 2.000000e+00)
6+
ret double %call
7+
}
8+
9+
define double @test_derivative(double %x, double %y) {
10+
entry:
11+
%0 = tail call double (...) @__enzyme_fwddiff(double (double)* nonnull @tester2, double %x, double 1.000000e+00)
12+
ret double %0
13+
}
14+
15+
declare double @hypot(double, double)
16+
17+
; Function Attrs: nounwind
18+
declare double @__enzyme_fwddiff(...)
19+
20+
; CHECK-LABEL: define internal double @fwddiffetester2(
21+
; CHECK-NEXT: entry:
22+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double 2.000000e+00)
23+
; CHECK-NEXT: %1 = fmul fast double %x, %"x'"
24+
; CHECK-NEXT: %2 = fdiv fast double %1, %0
25+
; CHECK-NEXT: ret double %2
26+
; CHECK-NEXT: }
27+
Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,47 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @hypot(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define double @tester2(double %x) {
11+
entry:
12+
%call = call double @hypot(double %x, double 2.000000e+00)
13+
ret double %call
14+
}
15+
16+
define double @test_derivative(double %x, double %y) {
17+
entry:
18+
%0 = tail call double (...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.000000e+00, double %y, double 1.000000e+00, i8* null)
19+
%1 = tail call double (...) @__enzyme_fwdsplit(double (double)* nonnull @tester2, double %x, double 1.000000e+00, i8* null)
20+
ret double %0
21+
}
22+
23+
declare double @hypot(double, double)
24+
25+
; Function Attrs: nounwind
26+
declare double @__enzyme_fwdsplit(...)
27+
28+
; CHECK-LABEL: define internal double @fwddiffetester(
29+
; CHECK-NEXT: entry:
30+
; CHECK-NEXT: tail call void @free(i8* {{(nonnull )?}}%tapeArg)
31+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double %y)
32+
; CHECK-NEXT: %1 = fmul fast double %x, %"x'"
33+
; CHECK-NEXT: %2 = fmul fast double %y, %"y'"
34+
; CHECK-NEXT: %3 = fadd fast double %1, %2
35+
; CHECK-NEXT: %4 = fdiv fast double %3, %0
36+
; CHECK-NEXT: ret double %4
37+
; CHECK-NEXT: }
38+
39+
; CHECK-LABEL: define internal double @fwddiffetester2(
40+
; CHECK-NEXT: entry:
41+
; CHECK-NEXT: tail call void @free(i8* {{(nonnull )?}}%tapeArg)
42+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double 2.000000e+00)
43+
; CHECK-NEXT: %1 = fmul fast double %x, %"x'"
44+
; CHECK-NEXT: %2 = fdiv fast double %1, %0
45+
; CHECK-NEXT: ret double %2
46+
; CHECK-NEXT: }
47+
Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,74 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
%struct.Gradients = type { double, double, double }
4+
5+
define double @tester(double %x, double %y) {
6+
entry:
7+
%call = call double @hypot(double %x, double %y)
8+
ret double %call
9+
}
10+
11+
define double @tester2(double %x) {
12+
entry:
13+
%call = call double @hypot(double %x, double 2.000000e+00)
14+
ret double %call
15+
}
16+
17+
18+
define %struct.Gradients @test_derivative(double %x, double %y) {
19+
entry:
20+
%0 = tail call %struct.Gradients (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 2.0, double 3.0, double %y, double 1.0, double 2.0, double 3.0)
21+
%1 = tail call %struct.Gradients (...) @__enzyme_fwddiff(double (double)* nonnull @tester2, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 2.0, double 3.0)
22+
ret %struct.Gradients %0
23+
}
24+
25+
declare double @hypot(double, double)
26+
27+
declare %struct.Gradients @__enzyme_fwddiff(...)
28+
29+
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'")
30+
; CHECK-NEXT: entry:
31+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double %y)
32+
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
33+
; CHECK-NEXT: %2 = extractvalue [3 x double] %"y'", 0
34+
; CHECK-NEXT: %3 = fmul fast double %x, %1
35+
; CHECK-NEXT: %4 = fmul fast double %y, %2
36+
; CHECK-NEXT: %5 = fadd fast double %3, %4
37+
; CHECK-NEXT: %6 = fdiv fast double %5, %0
38+
; CHECK-NEXT: %7 = insertvalue [3 x double] undef, double %6, 0
39+
; CHECK-NEXT: %8 = extractvalue [3 x double] %"x'", 1
40+
; CHECK-NEXT: %9 = extractvalue [3 x double] %"y'", 1
41+
; CHECK-NEXT: %10 = fmul fast double %x, %8
42+
; CHECK-NEXT: %11 = fmul fast double %y, %9
43+
; CHECK-NEXT: %12 = fadd fast double %10, %11
44+
; CHECK-NEXT: %13 = fdiv fast double %12, %0
45+
; CHECK-NEXT: %14 = insertvalue [3 x double] %7, double %13, 1
46+
; CHECK-NEXT: %15 = extractvalue [3 x double] %"x'", 2
47+
; CHECK-NEXT: %16 = extractvalue [3 x double] %"y'", 2
48+
; CHECK-NEXT: %17 = fmul fast double %x, %15
49+
; CHECK-NEXT: %18 = fmul fast double %y, %16
50+
; CHECK-NEXT: %19 = fadd fast double %17, %18
51+
; CHECK-NEXT: %20 = fdiv fast double %19, %0
52+
; CHECK-NEXT: %21 = insertvalue [3 x double] %14, double %20, 2
53+
; CHECK-NEXT: ret [3 x double] %21
54+
; CHECK-NEXT: }
55+
56+
; CHECK: define internal [3 x double] @fwddiffe3tester2(double %x, [3 x double] %"x'")
57+
; CHECK-NEXT: entry:
58+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double 2.000000e+00)
59+
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
60+
; CHECK-NEXT: %2 = fmul fast double %x, %1
61+
; CHECK-NEXT: %3 = fdiv fast double %2, %0
62+
; CHECK-NEXT: %4 = insertvalue [3 x double] undef, double %3, 0
63+
; CHECK-NEXT: %5 = extractvalue [3 x double] %"x'", 1
64+
; CHECK-NEXT: %6 = fmul fast double %x, %5
65+
; CHECK-NEXT: %7 = fdiv fast double %6, %0
66+
; CHECK-NEXT: %8 = insertvalue [3 x double] %4, double %7, 1
67+
; CHECK-NEXT: %9 = extractvalue [3 x double] %"x'", 2
68+
; CHECK-NEXT: %10 = fmul fast double %x, %9
69+
; CHECK-NEXT: %11 = fdiv fast double %10, %0
70+
; CHECK-NEXT: %12 = insertvalue [3 x double] %8, double %11, 2
71+
; CHECK-NEXT: ret [3 x double] %12
72+
; CHECK-NEXT: }
73+
74+
Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,47 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @hypot(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define double @tester2(double %x) {
11+
entry:
12+
%call = call double @hypot(double %x, double 2.000000e+00)
13+
ret double %call
14+
}
15+
16+
define double @test_derivative(double %x, double %y) {
17+
entry:
18+
%0 = tail call double (...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y)
19+
%1 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @tester2, double %x)
20+
ret double %0
21+
}
22+
23+
declare double @hypot(double, double)
24+
25+
; Function Attrs: nounwind
26+
declare double @__enzyme_autodiff(...)
27+
28+
; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn)
29+
; CHECK-NEXT: entry:
30+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double %y)
31+
; CHECK-NEXT: %1 = fmul fast double %differeturn, %x
32+
; CHECK-NEXT: %2 = fdiv fast double %1, %0
33+
; CHECK-NEXT: %3 = fmul fast double %differeturn, %y
34+
; CHECK-NEXT: %4 = fdiv fast double %3, %0
35+
; CHECK-NEXT: %5 = insertvalue { double, double } undef, double %2, 0
36+
; CHECK-NEXT: %6 = insertvalue { double, double } %5, double %4, 1
37+
; CHECK-NEXT: ret { double, double } %6
38+
; CHECK-NEXT: }
39+
40+
; CHECK: define internal { double } @diffetester2(double %x, double %differeturn)
41+
; CHECK-NEXT: entry:
42+
; CHECK-NEXT: %0 = call fast double @hypot(double %x, double 2.000000e+00)
43+
; CHECK-NEXT: %1 = fmul fast double %differeturn, %x
44+
; CHECK-NEXT: %2 = fdiv fast double %1, %0
45+
; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0
46+
; CHECK-NEXT: ret { double } %3
47+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)