diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index d8afb630a..18ee18151 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -181,8 +181,14 @@ let ## power # literal_pow is in base.jl function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number) - y = x ^ p - _dx = _pow_grad_x(x, p, float(y)) + if isinteger(p) + tmp = x ^ (p - 1) + y = x * tmp + _dx = p * tmp + else + y = x ^ p + _dx = _pow_grad_x(x, p, float(y)) + end if iszero(Δp) # Treat this as a strong zero, to avoid NaN, and save the cost of log return y, _dx * Δx