From c37c30213e468564be987518eabdf65f9e1abb8c Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 2 Oct 2024 21:38:44 -0400 Subject: [PATCH 01/12] add `finite_difference_jvp` Add the pushforward operation with implementation taken from jacobian but simplified. --- src/jvp | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 src/jvp diff --git a/src/jvp b/src/jvp new file mode 100644 index 0000000..42d6bd7 --- /dev/null +++ b/src/jvp @@ -0,0 +1,98 @@ +""" + FiniteDiff.finite_difference_jvp( + f, + x, + v, + fdtype = Val(:forward), + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)) + absstep=relstep) +""" +function finite_difference_jvp( + f, + x, + v + fdtype = Val(:forward), + f_in = nothing; + relstep=default_relstep(eltype(x), eltype(x)), + absstep=relstep, + dir=true) + if fdtype == Val(:complex) + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") + end + vecx = _vec(x) + vecv = _vec(v) + + tmp = sqrt(dot(vecx, vecv)) + epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) + if fdtype == Val(:forward) + fx = f_in isa Nothing ? f(x) : f_in + _x = @. x + epsilon * v + fx1 = f(_x) + return @. (fx1-fx)/epsilon + elseif fdtype == Val(:central) + _x = @. x + epsilon * v + fx1 = f(_x) + _x = @. x - epsilon * v + fx = f(_x) + return @. (fx1-fx)/(2epsilon) + else + fdtype_error(eltype(x)) + end +end + +""" + FiniteDiff.finite_difference_jvp!( + jvp::AbstractArray{<:Number}, + f, + x::AbstractArray{<:Number}, + v, + fdtype = Val(:forward), + f_in=nothing, + fx1 = nothing; + relstep=default_relstep(fdtype, eltype(x)) + absstep=relstep) +""" +function finite_difference_jvp!( + jvp, + f, + x, + v, + fdtype = Val(:forward), + f_in = nothing, + fx1 = nothing; + relstep = default_relstep(eltype(x), eltype(x)), + absstep = relstep, + dir = true) + if fdtype == Val(:complex) + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") + end + vecx = _vec(x) + vecv = _vec(v) + + tmp = sqrt(dot(vecx, vecv)) + epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) + if fdtype == Val(:forward) + if f_in isa Nothing + fx1 = copy(jvp) + f(fx1, x) + else + fx1 = f_in + end + @. x = x + epsilon * v + f(jvp, x) + @. jvp = (jvp-fx)/epsilon + elseif fdtype == Val(:central) + @. x = x - epsilon * v + if fx1 isa Nothing + fx1 = copy(jvp) + end + f(fx1, x) + @. x = x + epsilon * v + f(jvp, x) + @. jvp = (jvp-fx1)/(2epsilon) + else + fdtype_error(eltype(x)) + end + nothing +end From f0ac3bda25d7d15d949cb9e64fdb10711871adf9 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 2 Oct 2024 21:39:32 -0400 Subject: [PATCH 02/12] include file --- src/FiniteDiff.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/FiniteDiff.jl b/src/FiniteDiff.jl index bc05843..2481a70 100644 --- a/src/FiniteDiff.jl +++ b/src/FiniteDiff.jl @@ -40,5 +40,6 @@ include("derivatives.jl") include("gradients.jl") include("jacobians.jl") include("hessians.jl") +include("jvp.jl") end # module From 82fc83a3d553187d3757be0348b1691f09b37c62 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Wed, 2 Oct 2024 21:39:57 -0400 Subject: [PATCH 03/12] typo --- src/{jvp => jvp.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{jvp => jvp.jl} (100%) diff --git a/src/jvp b/src/jvp.jl similarity index 100% rename from src/jvp rename to src/jvp.jl From b0922f6a935c12c44a24253f2cd6d0ada4e226d3 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 27 Jan 2025 16:51:48 -0500 Subject: [PATCH 04/12] typo --- src/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jvp.jl b/src/jvp.jl index 42d6bd7..9aa9d0d 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -11,7 +11,7 @@ function finite_difference_jvp( f, x, - v + v, fdtype = Val(:forward), f_in = nothing; relstep=default_relstep(eltype(x), eltype(x)), From aa5b95466b15813aa4e691e871d61ba17f50f13c Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 10:56:28 -0500 Subject: [PATCH 05/12] actually add cache --- src/jvp.jl | 189 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 144 insertions(+), 45 deletions(-) diff --git a/src/jvp.jl b/src/jvp.jl index 9aa9d0d..0145732 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -1,98 +1,197 @@ +mutable struct JVPCache{X1, FX1, FDType} + x1 :: X1 + fx1 :: FX1 +end + +""" + FiniteDiff.JVPCache( + x, + fdtype :: Type{T1} = Val{:forward}) + +Allocating Cache Constructor. +""" +function JVPCache( + x, + fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD} + fdtype isa Type && (fdtype = fdtype()) + JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x)) +end + +""" + FiniteDiff.JVPCache( + x, + fx1, + fdtype :: Type{T1} = Val{:forward}, + +Non-Allocating Cache Constructor. +""" +function JVPCache( + x, + fx, + fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD} + fdtype isa Type && (fdtype = fdtype()) + JVPCache{typeof(x), typeof(fx), fdtype}(copy(x),fx) +end + +""" + FiniteDiff.finite_difference_jvp( + f, + x :: AbstractArray{<:Number}, + v :: AbstractArray{<:Number}, + fdtype :: Type{T1}=Val{:central}, + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep) + +Cache-less. +""" +function finite_difference_jvp(f, x, v, + fdtype = Val(:forward), + f_in = nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true) + + if f_in isa Nothing + fx = f(x) + else + fx = f_in + end + cache = JVPCache(x, fx, fdtype) + finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir) +end + """ FiniteDiff.finite_difference_jvp( f, x, v, - fdtype = Val(:forward), - f_in=nothing; - relstep=default_relstep(fdtype, eltype(x)) - absstep=relstep) + cache::JVPCache; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + +Cached. """ function finite_difference_jvp( f, x, v, - fdtype = Val(:forward), - f_in = nothing; - relstep=default_relstep(eltype(x), eltype(x)), + cache::JVPCache{X1, FX1, fdtype}, + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), absstep=relstep, - dir=true) + dir=true) where {X1, FX1, fdtype} + if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end - vecx = _vec(x) - vecv = _vec(v) + (; x1, fx1) = cache - tmp = sqrt(dot(vecx, vecv)) - epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) + tmp = sqrt(dot(_vec(x), _vec(v))) + epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) if fdtype == Val(:forward) fx = f_in isa Nothing ? f(x) : f_in - _x = @. x + epsilon * v - fx1 = f(_x) - return @. (fx1-fx)/epsilon + @. x1 = x + epsilon * v + fx1 = f(x1) + @. fx1 = (fx1-fx)/epsilon elseif fdtype == Val(:central) - _x = @. x + epsilon * v - fx1 = f(_x) - _x = @. x - epsilon * v - fx = f(_x) - return @. (fx1-fx)/(2epsilon) + @. x1 = x + epsilon * v + fx1 = f(x1) + @. x1 = x - epsilon * v + fx = f(x1) + @. fx1 = (fx1-fx)/(2epsilon) else fdtype_error(eltype(x)) end + fx1 end """ - FiniteDiff.finite_difference_jvp!( + finite_difference_jvp!( jvp::AbstractArray{<:Number}, f, x::AbstractArray{<:Number}, - v, - fdtype = Val(:forward), - f_in=nothing, - fx1 = nothing; - relstep=default_relstep(fdtype, eltype(x)) + v::AbstractArray{<:Number}, + fdtype :: Type{T1}=Val{:forward}, + returntype :: Type{T2}=eltype(x), + f_in :: Union{T2,Nothing}=nothing; + relstep=default_relstep(fdtype, eltype(x)), absstep=relstep) + +Cache-less. +""" +function finite_difference_jvp!(jvp, + f, + x, + v, + fdtype = Val(:forward), + f_in = nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep) + if !isnothing(f_in) + cache = JVPCache(x, f_in, fdtype) + elseif fdtype == Val(:forward) + fx = zero(x) + f(fx,x) + cache = JVPCache(x, fx, fdtype) + else + cache = JVPCache(x, fdtype) + end + finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep) +end + +""" + FiniteDiff.finite_difference_jvp!( + jvp::AbstractArray{<:Number}, + f, + x::AbstractArray{<:Number}, + v::AbstractArray{<:Number}, + cache::JVPCache; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep,) + +Cached. """ function finite_difference_jvp!( jvp, f, x, v, - fdtype = Val(:forward), - f_in = nothing, - fx1 = nothing; - relstep = default_relstep(eltype(x), eltype(x)), + cache::JVPCache{X1, FX1, fdtype}, + f_in = nothing; + relstep = default_relstep(fdtype, eltype(x)), absstep = relstep, - dir = true) + dir = true) where {X1, FX1, fdtype} + if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end - vecx = _vec(x) - vecv = _vec(v) - tmp = sqrt(dot(vecx, vecv)) - epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) + (;x1, fx1) = cache + tmp = sqrt(dot(_vec(x), _vec(v))) + epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) if fdtype == Val(:forward) if f_in isa Nothing - fx1 = copy(jvp) f(fx1, x) else fx1 = f_in end - @. x = x + epsilon * v - f(jvp, x) + @. x1 = x + epsilon * v + f(jvp, x1) @. jvp = (jvp-fx)/epsilon elseif fdtype == Val(:central) - @. x = x - epsilon * v - if fx1 isa Nothing - fx1 = copy(jvp) - end - f(fx1, x) - @. x = x + epsilon * v - f(jvp, x) + @. x1 = x - epsilon * v + f(fx1, x1) + @. x1 = x + epsilon * v + f(jvp, x1) @. jvp = (jvp-fx1)/(2epsilon) else fdtype_error(eltype(x)) end nothing end + +function resize!(cache::JVPCache, i::Int) + resize!(cache.x1, i) + cache.fx1 !== nothing && resize!(cache.fx1, i) + nothing +end From 37ccbee640d0734b8deef93f2e8be83504c3f547 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 13:07:49 -0500 Subject: [PATCH 06/12] fix sqrt sign --- src/jvp.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jvp.jl b/src/jvp.jl index 0145732..a85f6f1 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -86,7 +86,7 @@ function finite_difference_jvp( end (; x1, fx1) = cache - tmp = sqrt(dot(_vec(x), _vec(v))) + tmp = sqrt(abs(dot(_vec(x), _vec(v)))) epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) if fdtype == Val(:forward) fx = f_in isa Nothing ? f(x) : f_in @@ -167,7 +167,7 @@ function finite_difference_jvp!( end (;x1, fx1) = cache - tmp = sqrt(dot(_vec(x), _vec(v))) + tmp = sqrt(abs(dot(_vec(x), _vec(v)))) epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) if fdtype == Val(:forward) if f_in isa Nothing From 7d4dfdba5c06e71f80302aea4e934a1381eec2ef Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 13:27:31 -0500 Subject: [PATCH 07/12] typo --- src/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jvp.jl b/src/jvp.jl index a85f6f1..4aa8e2f 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -177,7 +177,7 @@ function finite_difference_jvp!( end @. x1 = x + epsilon * v f(jvp, x1) - @. jvp = (jvp-fx)/epsilon + @. jvp = (jvp-fx1)/epsilon elseif fdtype == Val(:central) @. x1 = x - epsilon * v f(fx1, x1) From 02286bc2cd399f78dbc26c266749b7b23df02caf Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 14:35:59 -0500 Subject: [PATCH 08/12] add tests --- test/finitedifftests.jl | 50 ++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/test/finitedifftests.jl b/test/finitedifftests.jl index a0b31b8..e975202 100644 --- a/test/finitedifftests.jl +++ b/test/finitedifftests.jl @@ -382,38 +382,68 @@ df = zero(x) df_ref = diag(J_ref) epsilon = zero(x) forward_cache = FiniteDiff.JacobianCache(x, Val{:forward}, eltype(x)) +forward_jvp_cache = FiniteDiff.JVPCache(x, Val{:forward}) @test forward_cache.colorvec == 1:length(x) central_cache = FiniteDiff.JacobianCache(x, Val{:central}, eltype(x)) +central_jvp_cache = FiniteDiff.JVPCache(x, Val{:central}) complex_cache = FiniteDiff.JacobianCache(x, Val{:complex}, eltype(x)) f_in = copy(y) +vdir = rand(2) +jvp_ref = J_ref*vdir @time @testset "Out-of-Place Jacobian StridedArray real-valued tests" begin - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-4 - @test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4 - @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-4 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-6 + @test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-6 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, central_cache), J_ref) < 1e-8 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, Val{:central}), J_ref) < 1e-8 @test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, complex_cache), J_ref) < 1e-14 end +@time @testset "Out-of-Place JVP StridedArray real-valued tests" begin + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6 + @test_throws Any err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8 + @test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, Val{:central}), jvp_ref) < 1e-8 +end + function test_iipJac(J_ref, args...; kwargs...) _J = zero(J_ref) FiniteDiff.finite_difference_jacobian!(_J, args...; kwargs...) _J end @time @testset "inPlace Jacobian StridedArray real-valued tests" begin - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-4 - @test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4 - @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-4 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-6 + @test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6 + @test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-6 @test err_func(test_iipJac(J_ref, iipf, x, central_cache), J_ref) < 1e-8 @test err_func(test_iipJac(J_ref, iipf, x, Val{:central}), J_ref) < 1e-8 @test err_func(test_iipJac(J_ref, iipf, x, complex_cache), J_ref) < 1e-14 end +function test_iipJVP(jvp_ref, args...; kwargs...) + _jvp = zero(jvp_ref) + FiniteDiff.finite_difference_jvp!(_jvp, args...; kwargs...) + _jvp +end + +@time @testset "inPlace JVP StridedArray real-valued tests" begin + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6 + @test_throws Any err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8 + @test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, Val{:central}), jvp_ref) < 1e-8 +end + function iipf(fvec, x) fvec[1] = (im * x[1] + 3) * (x[2]^3 - 7) + 18 fvec[2] = sin(x[2] * exp(x[1]) - 1) From 062facce8ad534f15bb9ba78b7d12c292667ec59 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 17:21:52 -0500 Subject: [PATCH 09/12] don\'t assume mutability for out of place --- src/jvp.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/jvp.jl b/src/jvp.jl index 4aa8e2f..213cd9a 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -84,21 +84,20 @@ function finite_difference_jvp( if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end - (; x1, fx1) = cache tmp = sqrt(abs(dot(_vec(x), _vec(v)))) epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir) if fdtype == Val(:forward) fx = f_in isa Nothing ? f(x) : f_in - @. x1 = x + epsilon * v + x1 = @. x + epsilon * v fx1 = f(x1) - @. fx1 = (fx1-fx)/epsilon + fx1 = @. (fx1-fx)/epsilon elseif fdtype == Val(:central) - @. x1 = x + epsilon * v + x1 = @. x + epsilon * v fx1 = f(x1) - @. x1 = x - epsilon * v + x1 = @. x - epsilon * v fx = f(x1) - @. fx1 = (fx1-fx)/(2epsilon) + fx1 = @. (fx1-fx)/epsilon else fdtype_error(eltype(x)) end From c21186bcb8d41051227618a8b4be6e54e8876f57 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 17:35:16 -0500 Subject: [PATCH 10/12] typo --- src/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jvp.jl b/src/jvp.jl index 213cd9a..2478cb3 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -97,7 +97,7 @@ function finite_difference_jvp( fx1 = f(x1) x1 = @. x - epsilon * v fx = f(x1) - fx1 = @. (fx1-fx)/epsilon + fx1 = @. (fx1-fx)/2epsilon else fdtype_error(eltype(x)) end From 39767f4e4b724a675ff45682d7ca0c6ec7c45c16 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 28 Jan 2025 17:42:04 -0500 Subject: [PATCH 11/12] Update src/jvp.jl --- src/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jvp.jl b/src/jvp.jl index 2478cb3..ee177c6 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -97,7 +97,7 @@ function finite_difference_jvp( fx1 = f(x1) x1 = @. x - epsilon * v fx = f(x1) - fx1 = @. (fx1-fx)/2epsilon + fx1 = @. (fx1-fx)/(2epsilon) else fdtype_error(eltype(x)) end From 59b1ebb03577917d7e98466a2c1a4cab5224b689 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 30 Jan 2025 20:19:28 -0500 Subject: [PATCH 12/12] Update src/jvp.jl --- src/jvp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jvp.jl b/src/jvp.jl index ee177c6..27fe46d 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -30,7 +30,7 @@ function JVPCache( fx, fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD} fdtype isa Type && (fdtype = fdtype()) - JVPCache{typeof(x), typeof(fx), fdtype}(copy(x),fx) + JVPCache{typeof(x), typeof(fx), fdtype}(x,fx) end """