Skip to content

Commit 6ded2a3

Browse files
authored
Merge pull request #1942 from firedrakeproject/ReubenHill/interp_hessian
Add hessian to InterpolateBlock
2 parents 5c17582 + bc97163 commit 6ded2a3

File tree

2 files changed

+432
-2
lines changed

2 files changed

+432
-2
lines changed

firedrake/adjoint/blocks.py

Lines changed: 254 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,45 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
280280

281281

282282
class InterpolateBlock(Block, Backend):
283+
r"""
284+
Annotates an interpolator.
285+
286+
Consider the block as f with 1 forward model output ``v``, and inputs ``u`` and ``g``
287+
(there can, in principle, be any number of outputs).
288+
The adjoint input is ``vhat`` (``uhat`` and ``ghat`` are adjoints to ``u`` and ``v``
289+
respectively and are shown for completeness). The downstream block is ``J``
290+
which has input ``v``.
291+
292+
::
293+
294+
_ _
295+
|J|--<--v--<--|f|--<--u--<--...
296+
¯ | ¯ |
297+
vhat | uhat
298+
|
299+
---<--g--<--...
300+
|
301+
ghat
302+
303+
(Arrows indicate forward model direction)
304+
305+
::
306+
307+
J : V ⟶ R i.e. J(v) ∈ R ∀ v ∈ V
308+
309+
Interpolation can operate on an expression which may not be linear in its
310+
arguments.
311+
312+
::
313+
314+
f : W × G ⟶ V i.e. f(u, g) ∈ V ∀ u ∈ W and g ∈ G.
315+
f = I ∘ expr
316+
I : X ⟶ V i.e. I(;x) ∈ V ∀ x ∈ X.
317+
X is infinite dimensional.
318+
expr: W × G ⟶ X i.e. expr(u, g) ∈ X ∀ u ∈ W and g ∈ G.
319+
320+
Arguments after a semicolon are linear (i.e. operation I is linear)
321+
"""
283322
def __init__(self, interpolator, *functions, **kwargs):
284323
super().__init__()
285324

@@ -314,22 +353,235 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_outputs):
314353
return replace(self.expr, self._replace_map())
315354

316355
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
356+
r"""
357+
Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
358+
Arguments after a semicolon are linear.
359+
360+
This calculates
361+
362+
::
363+
364+
uhat = vhat ⋅ d_u[f](u, g; ⋅) (for inputs[idx] ∈ W)
365+
or
366+
ghat = vhat ⋅ d_g[f](u, g; ⋅) (for inputs[idx] ∈ G)
367+
368+
where ``inputs[idx]`` specifies the derivative direction, ``vhat`` is
369+
``adj_inputs[0]`` (since we assume only one block output)
370+
and ``⋅`` denotes an unspecified operand of ``u'`` (for
371+
``inputs[idx]`` ∈ ``W``) or ``g'`` (for ``inputs[idx]`` ∈ ``G``) size (``vhat`` left
372+
multiplies the derivative).
373+
374+
::
375+
376+
f = I ∘ expr : W × G ⟶ V
377+
i.e. I(expr|_{u, g}) ∈ V ∀ u ∈ W, g ∈ G.
378+
379+
Since ``I`` is linear we get that
380+
381+
::
382+
383+
d_u[I ∘ expr](u, g; u') = I ∘ d_u[expr](u|_u, g|_g; u')
384+
d_g[I ∘ expr](u, g; g') = I ∘ d_u[expr](u|_u, g|_g; g').
385+
386+
In tensor notation
387+
388+
::
389+
390+
uhat_q^T = vhat_p^T I([dexpr/du|_u]_q)_p
391+
or
392+
ghat_q^T = vhat_p^T I([dexpr/dg|_u]_q)_p
393+
394+
the output is then
395+
396+
::
397+
398+
uhat_q = I^T([dexpr/du|_u]_q)_p vhat_p
399+
or
400+
ghat_q = I^T([dexpr/dg|_u]_q)_p vhat_p.
401+
"""
402+
if len(adj_inputs) > 1:
403+
raise(NotImplementedError("Interpolate block must have a single output"))
317404
dJdm = self.backend.derivative(prepared, inputs[idx])
318-
return self.backend.Interpolator(dJdm, self.V).interpolate(adj_inputs[0], transpose=True)
405+
return self.backend.Interpolator(dJdm, self.V).interpolate(adj_inputs[0], transpose=True).vector()
319406

320407
def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
321408
return replace(self.expr, self._replace_map())
322409

323410
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
411+
r"""
412+
Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
413+
Arguments after a semicolon are linear.
414+
415+
For a block with two inputs this calculates
416+
417+
::
418+
419+
v' = d_u[f](u, g; u') + d_g[f](u, g; g')
420+
421+
where ``u' = tlm_inputs[0]`` and ``g = tlm_inputs[1]``.
422+
423+
::
424+
425+
f = I ∘ expr : W × G ⟶ V
426+
i.e. I(expr|_{u, g}) ∈ V ∀ u ∈ W, g ∈ G.
427+
428+
Since ``I`` is linear we get that
429+
430+
::
431+
432+
d_u[I ∘ expr](u, g; u') = I ∘ d_u[expr](u|_u, g|_g; u')
433+
d_g[I ∘ expr](u, g; g') = I ∘ d_u[expr](u|_u, g|_g; g').
434+
435+
In tensor notation the output is then
436+
437+
::
438+
439+
v'_l = I([dexpr/du|_{u,g}]_k u'_k)_l + I([dexpr/du|_{u,g}]_k g'_k)_l
440+
= I([dexpr/du|_{u,g}]_k u'_k + [dexpr/du|_{u,g}]_k g'_k)_l
441+
442+
since ``I`` is linear.
443+
"""
324444
dJdm = 0.
325445

446+
assert len(inputs) == len(tlm_inputs)
326447
for i, input in enumerate(inputs):
327448
if tlm_inputs[i] is None:
328449
continue
329450
dJdm += self.backend.derivative(prepared, input, tlm_inputs[i])
330-
331451
return self.backend.Interpolator(dJdm, self.V).interpolate()
332452

453+
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies):
454+
return self.prepare_evaluate_adj(inputs, hessian_inputs, relevant_dependencies)
455+
456+
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
457+
block_variable, idx,
458+
relevant_dependencies, prepared=None):
459+
r"""
460+
Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
461+
Arguments after a semicolon are linear.
462+
463+
hessian_input is ``d_v[d_v[J]](v; v', ⋅)`` where the direction ``⋅`` is left
464+
unspecified so it can be operated upon.
465+
466+
.. warning::
467+
NOTE: This comment describes the implementation of 1 block input ``u``.
468+
(e.g. interpolating from an expression with 1 coefficient).
469+
Explaining how this works for multiple block inputs (e.g. ``u`` and ``g``) is
470+
currently too complicated for the author to do succinctly!
471+
472+
This function needs to output ``d_u[d_u[J ∘ f]](u; u', ⋅)`` where
473+
the direction ``⋅`` will be specified in another function and
474+
multiplied on the right with the output of this function.
475+
We will calculate this using the chain rule.
476+
477+
::
478+
479+
J : V ⟶ R i.e. J(v) ∈ R ∀ v ∈ V
480+
f = I ∘ expr : W ⟶ V
481+
J ∘ f : W ⟶ R i.e. J(f|u) ∈ R ∀ u ∈ V.
482+
d_u[J ∘ f] : W × W ⟶ R i.e. d_u[J ∘ f](u; u')
483+
d_u[d_u[J ∘ f]] : W × W × W ⟶ R i.e. d_u[d_u[J ∘ f]](u; u', u'')
484+
d_v[J] : V × V ⟶ R i.e. d_v[J](v; v')
485+
d_v[d_v[J]] : V × V × V ⟶ R i.e. d_v[d_v[J]](v; v', v'')
486+
487+
Chain rule:
488+
489+
::
490+
491+
d_u[J ∘ f](u; u') = d_v[J](v = f|u; v' = d_u[f](u; u'))
492+
493+
Multivariable chain rule:
494+
495+
::
496+
497+
d_u[d_u[J ∘ f]](u; u', u'') =
498+
d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v'' = d_u[f](u; u''))
499+
+ d_v'[d_v[J]](v = f|u; v' = d_u[f](u; u'), v'' = d_u[d_u[f]](u; u', u''))
500+
= d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v''=d_u[f](u; u''))
501+
+ d_v[J](v = f|u; v' = v'' = d_u[d_u[f]](u; u', u''))
502+
503+
since ``d_v[d_v[J]]`` is linear in ``v'`` so differentiating wrt to it leaves
504+
its coefficient, the bare d_v[J] operator which acts on the ``v''`` term
505+
that remains.
506+
507+
The ``d_u[d_u[f]](u; u', u'')`` term can be simplified further:
508+
509+
::
510+
511+
f = I ∘ expr : W ⟶ V i.e. I(expr|u) ∈ V ∀ u ∈ W
512+
d_u[I ∘ expr] : W × W ⟶ V i.e. d_u[I ∘ expr](u; u')
513+
d_u[d_u[I ∘ expr]] : W × W × W ⟶ V i.e. d_u[I ∘ expr](u; u', u'')
514+
d_x[I] : X × X ⟶ V i.e. d_x[I](x; x')
515+
d_x[d_x[I]] : X × X × X ⟶ V i.e. d_x[d_x[I]](x; x', x'')
516+
d_u[expr] : W × W ⟶ X i.e. d_u[expr](u; u')
517+
d_u[d_u[expr]] : W × W × W ⟶ X i.e. d_u[d_u[expr]](u; u', u'')
518+
519+
Since ``I`` is linear we get that
520+
521+
::
522+
523+
d_u[d_u[I ∘ expr]](u; u', u'') = I ∘ d_u[d_u[expr]](u; u', u'').
524+
525+
So our full hessian is:
526+
527+
::
528+
529+
d_u[d_u[J ∘ f]](u; u', u'')
530+
= d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v''=d_u[f](u; u''))
531+
+ d_v[J](v = f|u; v' = v'' = d_u[d_u[f]](u; u', u''))
532+
533+
In tensor notation
534+
535+
::
536+
537+
[d^2[J ∘ f]/du^2|_u]_{lk} u'_k u''_k =
538+
[d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k [df/du|_u]_{il} u''_l
539+
+ [dJ/dv|_{v=f_u}]_i I([d^2expr/du^2|_u]_{lk} u'_k)_i u''_l
540+
541+
In the first term:
542+
543+
::
544+
545+
[df/du|_u]_{jk} u'_k = v'_j
546+
=> [d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k
547+
= [d^2J/dv^2|_{v=f|_u}]_{ij} v'_j
548+
= hessian_input_i
549+
=> [d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k [df/du|_u]_{il}
550+
= hessian_input_i [df/du|_u]_{il}
551+
= self.evaluate_adj_component(inputs, hessian_inputs, ...)_l
552+
553+
In the second term we calculate everything explicitly though note
554+
``[dJ/dv|_{v=f_u}]_i = adj_inputs[0]_i``
555+
556+
Also, the second term is 0 if ``expr`` is linear.
557+
"""
558+
559+
if len(hessian_inputs) > 1 or len(adj_inputs) > 1:
560+
raise(NotImplementedError("Interpolate block must have a single output"))
561+
562+
component = self.evaluate_adj_component(inputs, hessian_inputs, block_variable, idx, prepared)
563+
564+
# Prepare again by replacing expression
565+
expr = replace(self.expr, self._replace_map())
566+
567+
# Calculate first derivative for each relevant block
568+
dexprdu = 0.
569+
for _, bv in relevant_dependencies:
570+
# Only take the derivative if there is a direction to take it in
571+
if bv.tlm_value is None:
572+
continue
573+
dexprdu += self.backend.derivative(expr, bv.saved_output, bv.tlm_value)
574+
575+
# Calculate the second derivative w.r.t. the specified coefficient's
576+
# saved value. Leave argument unspecified so it can be calculated with
577+
# the eventual inner product with u''.
578+
d2exprdudu = self.backend.derivative(dexprdu, block_variable.saved_output)
579+
580+
# left multiply by dJ/dv (adj_inputs[0]) - i.e. interpolate using the
581+
# transpose operator
582+
component += self.backend.Interpolator(d2exprdudu, self.V).interpolate(adj_inputs[0], transpose=True)
583+
return component.vector()
584+
333585
def prepare_recompute_component(self, inputs, relevant_outputs):
334586
return replace(self.expr, self._replace_map())
335587

0 commit comments

Comments
 (0)