@@ -26,7 +26,7 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
26
26
ulps allowed: T = 3,
27
27
file: String = #file, line: UInt = #line)
28
28
where T: BinaryFloatingPoint {
29
- if actual == T(expected) || actual.isNaN && expected.isNaN {
29
+ if actual == T(expected) || actual.isNaN && expected.isNaN || actual.isInfinite && expected.isInfinite {
30
30
return
31
31
}
32
32
// Compute error in ulp, compare to tolerance.
@@ -38,17 +38,40 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
38
38
file: file, line: line)
39
39
}
40
40
41
+ func computeDividedDifference<T: BinaryFloatingPoint> (
42
+ _ f: (T, T) -> T,
43
+ _ x: T,
44
+ _ y: T,
45
+ eps: T = 0.01
46
+ ) -> (dfdx: T, dfdy: T) {
47
+ let dfdx = (f(x + eps, y) - f(x, y)) / eps
48
+ let dfdy = (f(x, y + eps) - f(x, y)) / eps
49
+ return (dfdx, dfdy)
50
+ }
51
+
41
52
func checkGradient<T: BinaryFloatingPoint & Differentiable>(
42
53
_ f: @differentiable (T, T) -> T,
43
54
_ x: T,
44
- _ y: T)
55
+ _ y: T,
56
+ ulps: T = 192)
45
57
where T == T.TangentVector {
46
58
let eps = T(0.01)
47
59
let grad = gradient(at: x, y, in: f)
48
- let dfdx = (f(x + eps, y) - f(x, y)) / eps
49
- let dfdy = (f(x, y + eps) - f(x, y)) / eps
50
- expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192)
51
- expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192)
60
+ let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
61
+ expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: ulps)
62
+ expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: ulps)
63
+ }
64
+
65
+ func checkDerivative<T: BinaryFloatingPoint & Differentiable>(
66
+ _ f: @differentiable (T, T) -> T,
67
+ _ x: T,
68
+ _ y: T,
69
+ ulps: T = 192)
70
+ where T == T.TangentVector {
71
+ let eps = T(0.01)
72
+ let deriv = derivative(at: x, y, in: f)
73
+ let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
74
+ expectEqualWithTolerance(TestLiteralType(dfdx + dfdy), deriv, ulps: ulps)
52
75
}
53
76
54
77
%for op in ['derivative', 'gradient']:
@@ -111,6 +134,68 @@ DerivativeTests.test("${op}_${T}") {
111
134
checkGradient({ fmod($0, $1) }, x, y)
112
135
%else: # if op == 'derivative'
113
136
// TODO(TF-1108): Implement JVPs for `remainder` and `fmod`.
137
+ %end
138
+ }
139
+ }
140
+
141
+ // pow
142
+ let eps:${T} = 0.01
143
+ let ulps:${T} = eps/eps.ulp
144
+
145
+ // Checks for negative base.
146
+ for a in -3..<0 {
147
+ let x = ${T}(a)
148
+ for b in -3...3 {
149
+ let y = ${T}(b)
150
+ let expectedDx = y * pow(x, y - 1)
151
+ let expectedDy = ${T}.zero
152
+ let dpow = ${op}(at: x, y, in: pow)
153
+ %if op == 'gradient':
154
+ expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
155
+ expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
156
+ %else: # if op == 'derivative'
157
+ expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
158
+ %end
159
+ }
160
+ }
161
+
162
+ // Checks for 0 base.
163
+ for b in -3...3 {
164
+ let y = ${T}(b)
165
+ var expectedValues: (dx: ${T}, dy: ${T})?
166
+ if y.isLess(than: 0) {
167
+ expectedValues = (dx: ${T}.infinity, dy: ${T}.nan)
168
+ } else if y.isZero {
169
+ expectedValues = (dx: ${T}.nan, dy: ${T}.zero)
170
+ } else if !y.isEqual(to: 1) {
171
+ expectedValues = (dx: ${T}.zero, dy: ${T}.zero)
172
+ }
173
+ if let (expectedDx, expectedDy) = expectedValues {
174
+ let dpow = ${op}(at: 0.0, y, in: pow)
175
+ %if op == 'gradient':
176
+ expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
177
+ expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
178
+ %else: # if op == 'derivative'
179
+ expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
180
+ %end
181
+ } else {
182
+ %if op == 'gradient':
183
+ checkGradient({ pow($0, $1) }, 0.0, y, ulps: ulps)
184
+ %else: # if op == 'derivative'
185
+ checkDerivative({ pow($0, $1) }, 0.0, y, ulps: ulps)
186
+ %end
187
+ }
188
+ }
189
+
190
+ // Checks for positive base.
191
+ for a in 1...3 {
192
+ let x = ${T}(a)
193
+ for b in -3...3 {
194
+ let y = ${T}(b)
195
+ %if op == 'gradient':
196
+ checkGradient({ pow($0, $1) }, x, y, ulps: ulps)
197
+ %else: # if op == 'derivative'
198
+ checkDerivative({ pow($0, $1) }, x, y, ulps: ulps)
114
199
%end
115
200
}
116
201
}
0 commit comments