Skip to content

Remove arguments of _forward_eval_ϵ #2736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 12 additions & 29 deletions src/Nonlinear/ReverseAD/forward_over_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ end

function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
input_ϵ = _reinterpret_unsafe(T, d.input_ϵ)
fill!(d.output_ϵ, 0.0)
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
subexpr_forward_values_ϵ =
Expand All @@ -126,22 +125,10 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
d,
subexpr,
_reinterpret_unsafe(T, d.storage_ϵ),
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
input_ϵ,
subexpr_forward_values_ϵ,
d.data.operators,
)
end
_forward_eval_ϵ(
d,
ex,
_reinterpret_unsafe(T, d.storage_ϵ),
_reinterpret_unsafe(T, d.partials_storage_ϵ),
input_ϵ,
subexpr_forward_values_ϵ,
d.data.operators,
)
_forward_eval_ϵ(d, ex, _reinterpret_unsafe(T, d.partials_storage_ϵ))
# do a reverse pass
subexpr_reverse_values_ϵ =
_reinterpret_unsafe(T, d.subexpression_reverse_values_ϵ)
Expand Down Expand Up @@ -180,11 +167,7 @@ end
_forward_eval_ϵ(
d::NLPEvaluator,
ex::Union{_FunctionStorage,_SubexpressionStorage},
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
x_values_ϵ,
subexpression_values_ϵ,
user_operators::Nonlinear.OperatorRegistry,
) where {N,T}

Evaluate the directional derivatives of the expression tree in `ex`.
Expand All @@ -198,15 +181,15 @@ This assumes that `_reverse_model(d, x)` has already been called.
function _forward_eval_ϵ(
d::NLPEvaluator,
ex::Union{_FunctionStorage,_SubexpressionStorage},
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
x_values_ϵ,
subexpression_values_ϵ,
user_operators::Nonlinear.OperatorRegistry,
) where {N,T}
partials_storage_ϵ::AbstractVector{P},
) where {N,T,P<:ForwardDiff.Partials{N,T}}
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
x_values_ϵ = reinterpret(P, d.input_ϵ)
subexpression_values_ϵ =
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
@assert length(storage_ϵ) >= length(ex.nodes)
@assert length(partials_storage_ϵ) >= length(ex.nodes)
zero_ϵ = zero(ForwardDiff.Partials{N,T})
zero_ϵ = zero(P)
# ex.nodes is already in order such that parents always appear before children
# so a backwards pass through ex.nodes is a forward pass through the tree
children_arr = SparseArrays.rowvals(ex.adj)
Expand Down Expand Up @@ -339,16 +322,16 @@ function _forward_eval_ϵ(
n_children,
)
has_hessian = Nonlinear.eval_multivariate_hessian(
user_operators,
user_operators.multivariate_operators[node.index],
d.data.operators,
d.data.operators.multivariate_operators[node.index],
H,
f_input,
)
# This might be `false` if we extend this code to all
# multivariate functions.
@assert has_hessian
for col in 1:n_children
dual = zero(ForwardDiff.Partials{N,T})
dual = zero(P)
for row in 1:n_children
# Make sure we get the lower-triangular component.
h = row >= col ? H[row, col] : H[col, row]
Expand All @@ -366,7 +349,7 @@ function _forward_eval_ϵ(
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
f′′ = Nonlinear.eval_univariate_hessian(
user_operators,
d.data.operators,
node.index,
ex.forward_storage[child_idx],
)
Expand Down
18 changes: 1 addition & 17 deletions src/Nonlinear/ReverseAD/mathoptinterface_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
d,
subexpr,
reinterpret(T, d.storage_ϵ),
reinterpret(T, subexpr.partials_storage_ϵ),
input_ϵ,
subexpr_forward_values_ϵ,
d.data.operators,
)
end
# we only need to do one reverse pass through the subexpressions as well
Expand All @@ -365,11 +361,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
_forward_eval_ϵ(
d,
something(d.objective),
reinterpret(T, d.storage_ϵ),
reinterpret(T, d.partials_storage_ϵ),
input_ϵ,
subexpr_forward_values_ϵ,
d.data.operators,
)
_reverse_eval_ϵ(
output_ϵ,
Expand All @@ -383,15 +375,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
)
end
for (i, con) in enumerate(d.constraints)
_forward_eval_ϵ(
d,
con,
reinterpret(T, d.storage_ϵ),
reinterpret(T, d.partials_storage_ϵ),
input_ϵ,
subexpr_forward_values_ϵ,
d.data.operators,
)
_forward_eval_ϵ(d, con, reinterpret(T, d.partials_storage_ϵ))
_reverse_eval_ϵ(
output_ϵ,
con,
Expand Down
Loading