@@ -26,7 +26,7 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
2626                                 ulps allowed: T = 3,
2727                                 file: String = #file, line: UInt = #line)
2828                                 where T: BinaryFloatingPoint {
29-   if actual == T(expected) || actual.isNaN && expected.isNaN {
29+   if actual == T(expected) || actual.isNaN && expected.isNaN || actual.isInfinite && expected.isInfinite  {
3030    return
3131  }
3232  //  Compute error in ulp, compare to tolerance.
@@ -38,17 +38,40 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
3838             file: file, line: line)
3939}
4040
41+ func computeDividedDifference<T: BinaryFloatingPoint> (
42+   _ f: (T, T) -> T,
43+   _ x: T,
44+   _ y: T,
45+   eps: T = 0.01
46+ ) -> (dfdx: T, dfdy: T) {
47+   let dfdx = (f(x + eps, y) - f(x, y)) / eps
48+   let dfdy = (f(x, y + eps) - f(x, y)) / eps
49+   return (dfdx, dfdy)
50+ }
51+ 
4152func checkGradient<T: BinaryFloatingPoint & Differentiable>(
4253  _ f: @differentiable (T, T) -> T,
4354  _ x: T,
44-   _ y: T)
55+   _ y: T,
56+   ulps: T = 192)
4557where T == T.TangentVector {
4658  let eps = T(0.01)
4759  let grad = gradient(at: x, y, in: f)
48-   let dfdx = (f(x + eps, y) - f(x, y)) / eps
49-   let dfdy = (f(x, y + eps) - f(x, y)) / eps
50-   expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192)
51-   expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192)
60+   let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
61+   expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: ulps)
62+   expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: ulps)
63+ }
64+ 
65+ func checkDerivative<T: BinaryFloatingPoint & Differentiable>(
66+   _ f: @differentiable (T, T) -> T,
67+   _ x: T,
68+   _ y: T,
69+   ulps: T = 192)
70+ where T == T.TangentVector {
71+   let eps = T(0.01)
72+   let deriv = derivative(at: x, y, in: f)
73+   let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps)
74+   expectEqualWithTolerance(TestLiteralType(dfdx + dfdy), deriv, ulps: ulps)
5275}
5376
5477%for op in ['derivative', 'gradient']:
@@ -111,6 +134,68 @@ DerivativeTests.test("${op}_${T}") {
111134      checkGradient({ fmod($0, $1) }, x, y)
112135%else: # if op == 'derivative'
113136      // TODO(TF-1108): Implement JVPs for `remainder` and `fmod`.
137+ %end
138+     }
139+   }
140+ 
141+   // pow
142+   let eps:${T} = 0.01
143+   let ulps:${T} = eps/eps.ulp
144+ 
145+   // Checks for negative base.
146+   for a in -3..<0 {
147+     let x = ${T}(a)
148+     for b in -3...3 {
149+       let y = ${T}(b)
150+       let expectedDx =  y * pow(x, y - 1)
151+       let expectedDy = ${T}.zero
152+       let dpow = ${op}(at: x, y, in: pow)
153+ %if op == 'gradient':
154+       expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
155+       expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
156+ %else: # if op == 'derivative'
157+       expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
158+ %end
159+     }
160+   }
161+ 
162+   // Checks for 0 base.
163+   for b in -3...3 {
164+     let y = ${T}(b)
165+     var expectedValues: (dx: ${T}, dy: ${T})?
166+     if y.isLess(than: 0) {
167+       expectedValues = (dx: ${T}.infinity, dy: ${T}.nan)
168+     } else if y.isZero {
169+       expectedValues = (dx: ${T}.nan, dy: ${T}.zero)
170+     } else if !y.isEqual(to: 1) {
171+       expectedValues = (dx: ${T}.zero, dy: ${T}.zero)
172+     }
173+     if let (expectedDx, expectedDy) = expectedValues {
174+       let dpow = ${op}(at: 0.0, y, in: pow)
175+ %if op == 'gradient':
176+       expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0)
177+       expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1)
178+ %else: # if op == 'derivative'
179+       expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow)
180+ %end
181+     } else {
182+ %if op == 'gradient':
183+       checkGradient({ pow($0, $1) }, 0.0, y, ulps: ulps)
184+ %else: # if op == 'derivative'
185+       checkDerivative({ pow($0, $1) }, 0.0, y, ulps: ulps)
186+ %end
187+     }
188+   }
189+ 
190+   // Checks for positive base.
191+   for a in 1...3 {
192+     let x = ${T}(a)
193+     for b in -3...3 {
194+       let y = ${T}(b)
195+ %if op == 'gradient':
196+       checkGradient({ pow($0, $1) }, x, y, ulps: ulps)
197+ %else: # if op == 'derivative'
198+       checkDerivative({ pow($0, $1) }, x, y, ulps: ulps)
114199%end
115200    }
116201  }
0 commit comments