Skip to content

Commit b433920

Browse files
authored
[AutoDiff][TF-1200] Adding derivatives for stdlib pow function. (#30580)
Adding JVP and VJP function derivatives for pow function defined in stdlib. Resolves TF-1200.
1 parent 2caef36 commit b433920

File tree

2 files changed

+119
-6
lines changed

2 files changed

+119
-6
lines changed

stdlib/public/Differentiation/TgmathDerivatives.swift.gyb

+28
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ func _${derivative_kind}Trunc<T: FloatingPoint & Differentiable> (
130130
}
131131
%end # for derivative_kind in ['jvp', 'vjp']:
132132

133+
// Unary functions
133134
%for derivative_kind in ['jvp', 'vjp']:
134135
% linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
135136
% for T in ['Float', 'Double', 'Float80']:
@@ -271,3 +272,30 @@ func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${
271272
% end # if T == 'Float80':
272273
% end # for T in ['Float', 'Double', 'Float80']:
273274
%end # for derivative_kind in ['jvp', 'vjp']:
275+
276+
// Binary functions
277+
%for T in ['Float', 'Float80']:
278+
% if T == 'Float80':
279+
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
280+
% end
281+
@inlinable
282+
@derivative(of: pow)
283+
func _vjpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T})) {
284+
let value = pow(x, y)
285+
return (value, { v in (
286+
v * y * pow(x, y - 1), v * value * log(x.isLessThanOrEqualTo(0) ? ${T}(1) : x)
287+
) })
288+
}
289+
290+
@inlinable
291+
@derivative(of: pow)
292+
func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -> ${T}) {
293+
let value = pow(x, y)
294+
return (value, { (dx, dy) in
295+
dx * y * pow(x, y - 1) + dy * value * log(x.isLessThanOrEqualTo(0) ? ${T}(1) : x)
296+
})
297+
}
298+
% if T == 'Float80':
299+
#endif
300+
% end # if T == 'Float80':
301+
%end # for T in ['Float', 'Float80']:

test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb

+91-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
2626
ulps allowed: T = 3,
2727
file: String = #file, line: UInt = #line)
2828
where T: BinaryFloatingPoint {
29-
if actual == T(expected) || actual.isNaN && expected.isNaN {
29+
if actual == T(expected) || actual.isNaN && expected.isNaN || actual.isInfinite && expected.isInfinite {
3030
return
3131
}
3232
// Compute error in ulp, compare to tolerance.
@@ -38,17 +38,40 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
3838
file: file, line: line)
3939
}
4040

41+
func computeDividedDifference<T: BinaryFloatingPoint> (
42+
_ f: (T, T) -> T,
43+
_ x: T,
44+
_ y: T,
45+
eps: T = 0.01
46+
) -> (dfdx: T, dfdy: T) {
47+
let dfdx = (f(x + eps, y) - f(x, y)) / eps
48+
let dfdy = (f(x, y + eps) - f(x, y)) / eps
49+
return (dfdx, dfdy)
50+
}
51+
4152
func checkGradient<T: BinaryFloatingPoint & Differentiable>(
4253
_ f: @differentiable (T, T) -> T,
4354
_ x: T,
44-
_ y: T)
55+
_ y: T,
56+
ulps: T = 192)
4557
where T == T.TangentVector {
4658
let eps = T(0.01)
4759
let grad = gradient(at: x, y, in: f)
48-
let dfdx = (f(x + eps, y) - f(x, y)) / eps
49-
let dfdy = (f(x, y + eps) - f(x, y)) / eps
50-
expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192)
51-
expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192)
60+
let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
61+
expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: ulps)
62+
expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: ulps)
63+
}
64+
65+
func checkDerivative<T: BinaryFloatingPoint & Differentiable>(
66+
_ f: @differentiable (T, T) -> T,
67+
_ x: T,
68+
_ y: T,
69+
ulps: T = 192)
70+
where T == T.TangentVector {
71+
let eps = T(0.01)
72+
let deriv = derivative(at: x, y, in: f)
73+
let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
74+
expectEqualWithTolerance(TestLiteralType(dfdx + dfdy), deriv, ulps: ulps)
5275
}
5376

5477
%for op in ['derivative', 'gradient']:
@@ -111,6 +134,68 @@ DerivativeTests.test("${op}_${T}") {
111134
checkGradient({ fmod($0, $1) }, x, y)
112135
%else: # if op == 'derivative'
113136
// TODO(TF-1108): Implement JVPs for `remainder` and `fmod`.
137+
%end
138+
}
139+
}
140+
141+
// pow
142+
let eps:${T} = 0.01
143+
let ulps:${T} = eps/eps.ulp
144+
145+
// Checks for negative base.
146+
for a in -3..<0 {
147+
let x = ${T}(a)
148+
for b in -3...3 {
149+
let y = ${T}(b)
150+
let expectedDx = y * pow(x, y - 1)
151+
let expectedDy = ${T}.zero
152+
let dpow = ${op}(at: x, y, in: pow)
153+
%if op == 'gradient':
154+
expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
155+
expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
156+
%else: # if op == 'derivative'
157+
expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
158+
%end
159+
}
160+
}
161+
162+
// Checks for 0 base.
163+
for b in -3...3 {
164+
let y = ${T}(b)
165+
var expectedValues: (dx: ${T}, dy: ${T})?
166+
if y.isLess(than: 0) {
167+
expectedValues = (dx: ${T}.infinity, dy: ${T}.nan)
168+
} else if y.isZero {
169+
expectedValues = (dx: ${T}.nan, dy: ${T}.zero)
170+
} else if !y.isEqual(to: 1) {
171+
expectedValues = (dx: ${T}.zero, dy: ${T}.zero)
172+
}
173+
if let (expectedDx, expectedDy) = expectedValues {
174+
let dpow = ${op}(at: 0.0, y, in: pow)
175+
%if op == 'gradient':
176+
expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
177+
expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
178+
%else: # if op == 'derivative'
179+
expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
180+
%end
181+
} else {
182+
%if op == 'gradient':
183+
checkGradient({ pow($0, $1) }, 0.0, y, ulps: ulps)
184+
%else: # if op == 'derivative'
185+
checkDerivative({ pow($0, $1) }, 0.0, y, ulps: ulps)
186+
%end
187+
}
188+
}
189+
190+
// Checks for positive base.
191+
for a in 1...3 {
192+
let x = ${T}(a)
193+
for b in -3...3 {
194+
let y = ${T}(b)
195+
%if op == 'gradient':
196+
checkGradient({ pow($0, $1) }, x, y, ulps: ulps)
197+
%else: # if op == 'derivative'
198+
checkDerivative({ pow($0, $1) }, x, y, ulps: ulps)
114199
%end
115200
}
116201
}

0 commit comments

Comments
 (0)