diff --git a/src/FiniteDiff.jl b/src/FiniteDiff.jl index bc05843..2481a70 100644 --- a/src/FiniteDiff.jl +++ b/src/FiniteDiff.jl @@ -40,5 +40,6 @@ include("derivatives.jl") include("gradients.jl") include("jacobians.jl") include("hessians.jl") +include("jvp.jl") end # module diff --git a/src/jvp.jl b/src/jvp.jl new file mode 100644 index 0000000..27fe46d --- /dev/null +++ b/src/jvp.jl @@ -0,0 +1,196 @@ +mutable struct JVPCache{X1, FX1, FDType} + x1 :: X1 + fx1 :: FX1 +end + +""" + FiniteDiff.JVPCache( + x, + fdtype :: Type{T1} = Val{:forward}) + +Allocating Cache Constructor. +""" +function JVPCache( + x, + fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD} + fdtype isa Type && (fdtype = fdtype()) + JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x)) +end + +""" + FiniteDiff.JVPCache( + x, + fx1, + fdtype :: Type{T1} = Val{:forward}, + +Non-Allocating Cache Constructor. +""" +function JVPCache( + x, + fx, + fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD} + fdtype isa Type && (fdtype = fdtype()) + JVPCache{typeof(x), typeof(fx), fdtype}(x,fx) +end + +""" + FiniteDiff.finite_difference_jvp( + f, + x :: AbstractArray{<:Number}, + v :: AbstractArray{<:Number}, + fdtype :: Type{T1}=Val{:central}, + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep) + +Cache-less. +""" +function finite_difference_jvp(f, x, v, + fdtype = Val(:forward), + f_in = nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true) + + if f_in isa Nothing + fx = f(x) + else + fx = f_in + end + cache = JVPCache(x, fx, fdtype) + finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir) +end + +""" + FiniteDiff.finite_difference_jvp( + f, + x, + v, + cache::JVPCache; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + +Cached. +""" +function finite_difference_jvp( + f, + x, + v, + cache::JVPCache{X1, FX1, fdtype}, + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true) where {X1, FX1, fdtype} + + if fdtype == Val(:complex) + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") + end + + tmp = sqrt(abs(dot(_vec(x), _vec(v)))) + epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) + if fdtype == Val(:forward) + fx = f_in isa Nothing ? f(x) : f_in + x1 = @. x + epsilon * v + fx1 = f(x1) + fx1 = @. (fx1-fx)/epsilon + elseif fdtype == Val(:central) + x1 = @. x + epsilon * v + fx1 = f(x1) + x1 = @. x - epsilon * v + fx = f(x1) + fx1 = @. (fx1-fx)/(2epsilon) + else + fdtype_error(eltype(x)) + end + fx1 +end + +""" + finite_difference_jvp!( + jvp::AbstractArray{<:Number}, + f, + x::AbstractArray{<:Number}, + v::AbstractArray{<:Number}, + fdtype :: Type{T1}=Val{:forward}, + returntype :: Type{T2}=eltype(x), + f_in :: Union{T2,Nothing}=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep) + +Cache-less. +""" +function finite_difference_jvp!(jvp, + f, + x, + v, + fdtype = Val(:forward), + f_in = nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep) + if !isnothing(f_in) + cache = JVPCache(x, f_in, fdtype) + elseif fdtype == Val(:forward) + fx = zero(x) + f(fx,x) + cache = JVPCache(x, fx, fdtype) + else + cache = JVPCache(x, fdtype) + end + finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep) +end + +""" + FiniteDiff.finite_difference_jvp!( + jvp::AbstractArray{<:Number}, + f, + x::AbstractArray{<:Number}, + v::AbstractArray{<:Number}, + cache::JVPCache; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep,) + +Cached. +""" +function finite_difference_jvp!( + jvp, + f, + x, + v, + cache::JVPCache{X1, FX1, fdtype}, + f_in = nothing; + relstep = default_relstep(fdtype, eltype(x)), + absstep = relstep, + dir = true) where {X1, FX1, fdtype} + + if fdtype == Val(:complex) + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") + end + + (;x1, fx1) = cache + tmp = sqrt(abs(dot(_vec(x), _vec(v)))) + epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) + if fdtype == Val(:forward) + if f_in isa Nothing + f(fx1, x) + else + fx1 = f_in + end + @. x1 = x + epsilon * v + f(jvp, x1) + @. jvp = (jvp-fx1)/epsilon + elseif fdtype == Val(:central) + @. x1 = x - epsilon * v + f(fx1, x1) + @. x1 = x + epsilon * v + f(jvp, x1) + @. jvp = (jvp-fx1)/(2epsilon) + else + fdtype_error(eltype(x)) + end + nothing +end + +function resize!(cache::JVPCache, i::Int) + resize!(cache.x1, i) + cache.fx1 !== nothing && resize!(cache.fx1, i) + nothing +end diff --git a/test/finitedifftests.jl b/test/finitedifftests.jl index a0b31b8..e975202 100644 --- a/test/finitedifftests.jl +++ b/test/finitedifftests.jl @@ -382,38 +382,68 @@ df = zero(x) df_ref = diag(J_ref) epsilon = zero(x) forward_cache = FiniteDiff.JacobianCache(x, Val{:forward}, eltype(x)) +forward_jvp_cache = FiniteDiff.JVPCache(x, Val{:forward}) @test forward_cache.colorvec == 1:length(x) central_cache = FiniteDiff.JacobianCache(x, Val{:central}, eltype(x)) +central_jvp_cache = FiniteDiff.JVPCache(x, Val{:central}) complex_cache = FiniteDiff.JacobianCache(x, Val{:complex}, eltype(x)) f_in = copy(y) +vdir = rand(2) +jvp_ref = J_ref*vdir @time @testset "Out-of-Place Jacobian StridedArray real-valued tests" begin - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-4 - @test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-4 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-6 + @test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-6 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, central_cache), J_ref) < 1e-8 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, Val{:central}), J_ref) < 1e-8 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, complex_cache), J_ref) < 1e-14 end +@time @testset "Out-of-Place JVP StridedArray real-valued tests" begin + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6 + @test_throws Any err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, Val{:central}), jvp_ref) < 1e-8 +end + function test_iipJac(J_ref, args...; kwargs...) _J = zero(J_ref) FiniteDiff.finite_difference_jacobian!(_J, args...; kwargs...) _J end @time @testset "inPlace Jacobian StridedArray real-valued tests" begin - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-4 - @test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-4 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-6 + @test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-6 @test err_func(test_iipJac(J_ref, iipf, x, central_cache), J_ref) < 1e-8 @test err_func(test_iipJac(J_ref, iipf, x, Val{:central}), J_ref) < 1e-8 @test err_func(test_iipJac(J_ref, iipf, x, complex_cache), J_ref) < 1e-14 end +function test_iipJVP(jvp_ref, args...; kwargs...) + _jvp = zero(jvp_ref) + FiniteDiff.finite_difference_jvp!(_jvp, args...; kwargs...) + _jvp +end + +@time @testset "inPlace JVP StridedArray real-valued tests" begin + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6 + @test_throws Any err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, Val{:central}), jvp_ref) < 1e-8 +end + function iipf(fvec, x) fvec[1] = (im * x[1] + 3) * (x[2]^3 - 7) + 18 fvec[2] = sin(x[2] * exp(x[1]) - 1)