@@ -280,6 +280,45 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
280280
281281
282282class InterpolateBlock (Block , Backend ):
283+ r"""
284+ Annotates an interpolator.
285+
286+ Consider the block as f with 1 forward model output ``v``, and inputs ``u`` and ``g``
287+ (there can, in principle, be any number of outputs).
288+ The adjoint input is ``vhat`` (``uhat`` and ``ghat`` are adjoints to ``u`` and ``v``
289+ respectively and are shown for completeness). The downstream block is ``J``
290+ which has input ``v``.
291+
292+ ::
293+
294+ _ _
295+ |J|--<--v--<--|f|--<--u--<--...
296+ ¯ | ¯ |
297+ vhat | uhat
298+ |
299+ ---<--g--<--...
300+ |
301+ ghat
302+
303+ (Arrows indicate forward model direction)
304+
305+ ::
306+
307+ J : V ⟶ R i.e. J(v) ∈ R ∀ v ∈ V
308+
309+ Interpolation can operate on an expression which may not be linear in its
310+ arguments.
311+
312+ ::
313+
314+ f : W × G ⟶ V i.e. f(u, g) ∈ V ∀ u ∈ W and g ∈ G.
315+ f = I ∘ expr
316+ I : X ⟶ V i.e. I(;x) ∈ V ∀ x ∈ X.
317+ X is infinite dimensional.
318+ expr: W × G ⟶ X i.e. expr(u, g) ∈ X ∀ u ∈ W and g ∈ G.
319+
320+ Arguments after a semicolon are linear (i.e. operation I is linear)
321+ """
283322 def __init__ (self , interpolator , * functions , ** kwargs ):
284323 super ().__init__ ()
285324
@@ -314,22 +353,235 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_outputs):
314353 return replace (self .expr , self ._replace_map ())
315354
316355 def evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx , prepared = None ):
356+ r"""
357+ Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
358+ Arguments after a semicolon are linear.
359+
360+ This calculates
361+
362+ ::
363+
364+ uhat = vhat ⋅ d_u[f](u, g; ⋅) (for inputs[idx] ∈ W)
365+ or
366+ ghat = vhat ⋅ d_g[f](u, g; ⋅) (for inputs[idx] ∈ G)
367+
368+ where ``inputs[idx]`` specifies the derivative direction, ``vhat`` is
369+ ``adj_inputs[0]`` (since we assume only one block output)
370+ and ``⋅`` denotes an unspecified operand of ``u'`` (for
371+ ``inputs[idx]`` ∈ ``W``) or ``g'`` (for ``inputs[idx]`` ∈ ``G``) size (``vhat`` left
372+ multiplies the derivative).
373+
374+ ::
375+
376+ f = I ∘ expr : W × G ⟶ V
377+ i.e. I(expr|_{u, g}) ∈ V ∀ u ∈ W, g ∈ G.
378+
379+ Since ``I`` is linear we get that
380+
381+ ::
382+
383+ d_u[I ∘ expr](u, g; u') = I ∘ d_u[expr](u|_u, g|_g; u')
384+ d_g[I ∘ expr](u, g; g') = I ∘ d_u[expr](u|_u, g|_g; g').
385+
386+ In tensor notation
387+
388+ ::
389+
390+ uhat_q^T = vhat_p^T I([dexpr/du|_u]_q)_p
391+ or
392+ ghat_q^T = vhat_p^T I([dexpr/dg|_u]_q)_p
393+
394+ the output is then
395+
396+ ::
397+
398+ uhat_q = I^T([dexpr/du|_u]_q)_p vhat_p
399+ or
400+ ghat_q = I^T([dexpr/dg|_u]_q)_p vhat_p.
401+ """
402+ if len (adj_inputs ) > 1 :
403+ raise (NotImplementedError ("Interpolate block must have a single output" ))
317404 dJdm = self .backend .derivative (prepared , inputs [idx ])
318- return self .backend .Interpolator (dJdm , self .V ).interpolate (adj_inputs [0 ], transpose = True )
405+ return self .backend .Interpolator (dJdm , self .V ).interpolate (adj_inputs [0 ], transpose = True ). vector ()
319406
320407 def prepare_evaluate_tlm (self , inputs , tlm_inputs , relevant_outputs ):
321408 return replace (self .expr , self ._replace_map ())
322409
323410 def evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx , prepared = None ):
411+ r"""
412+ Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
413+ Arguments after a semicolon are linear.
414+
415+ For a block with two inputs this calculates
416+
417+ ::
418+
419+ v' = d_u[f](u, g; u') + d_g[f](u, g; g')
420+
421+ where ``u' = tlm_inputs[0]`` and ``g = tlm_inputs[1]``.
422+
423+ ::
424+
425+ f = I ∘ expr : W × G ⟶ V
426+ i.e. I(expr|_{u, g}) ∈ V ∀ u ∈ W, g ∈ G.
427+
428+ Since ``I`` is linear we get that
429+
430+ ::
431+
432+ d_u[I ∘ expr](u, g; u') = I ∘ d_u[expr](u|_u, g|_g; u')
433+ d_g[I ∘ expr](u, g; g') = I ∘ d_u[expr](u|_u, g|_g; g').
434+
435+ In tensor notation the output is then
436+
437+ ::
438+
439+ v'_l = I([dexpr/du|_{u,g}]_k u'_k)_l + I([dexpr/du|_{u,g}]_k g'_k)_l
440+ = I([dexpr/du|_{u,g}]_k u'_k + [dexpr/du|_{u,g}]_k g'_k)_l
441+
442+ since ``I`` is linear.
443+ """
324444 dJdm = 0.
325445
446+ assert len (inputs ) == len (tlm_inputs )
326447 for i , input in enumerate (inputs ):
327448 if tlm_inputs [i ] is None :
328449 continue
329450 dJdm += self .backend .derivative (prepared , input , tlm_inputs [i ])
330-
331451 return self .backend .Interpolator (dJdm , self .V ).interpolate ()
332452
453+ def prepare_evaluate_hessian (self , inputs , hessian_inputs , adj_inputs , relevant_dependencies ):
454+ return self .prepare_evaluate_adj (inputs , hessian_inputs , relevant_dependencies )
455+
456+ def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
457+ block_variable , idx ,
458+ relevant_dependencies , prepared = None ):
459+ r"""
460+ Denote ``d_u[A]`` as the gateaux derivative in the ``u`` direction.
461+ Arguments after a semicolon are linear.
462+
463+ hessian_input is ``d_v[d_v[J]](v; v', ⋅)`` where the direction ``⋅`` is left
464+ unspecified so it can be operated upon.
465+
466+ .. warning::
467+ NOTE: This comment describes the implementation of 1 block input ``u``.
468+ (e.g. interpolating from an expression with 1 coefficient).
469+ Explaining how this works for multiple block inputs (e.g. ``u`` and ``g``) is
470+ currently too complicated for the author to do succinctly!
471+
472+ This function needs to output ``d_u[d_u[J ∘ f]](u; u', ⋅)`` where
473+ the direction ``⋅`` will be specified in another function and
474+ multiplied on the right with the output of this function.
475+ We will calculate this using the chain rule.
476+
477+ ::
478+
479+ J : V ⟶ R i.e. J(v) ∈ R ∀ v ∈ V
480+ f = I ∘ expr : W ⟶ V
481+ J ∘ f : W ⟶ R i.e. J(f|u) ∈ R ∀ u ∈ V.
482+ d_u[J ∘ f] : W × W ⟶ R i.e. d_u[J ∘ f](u; u')
483+ d_u[d_u[J ∘ f]] : W × W × W ⟶ R i.e. d_u[d_u[J ∘ f]](u; u', u'')
484+ d_v[J] : V × V ⟶ R i.e. d_v[J](v; v')
485+ d_v[d_v[J]] : V × V × V ⟶ R i.e. d_v[d_v[J]](v; v', v'')
486+
487+ Chain rule:
488+
489+ ::
490+
491+ d_u[J ∘ f](u; u') = d_v[J](v = f|u; v' = d_u[f](u; u'))
492+
493+ Multivariable chain rule:
494+
495+ ::
496+
497+ d_u[d_u[J ∘ f]](u; u', u'') =
498+ d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v'' = d_u[f](u; u''))
499+ + d_v'[d_v[J]](v = f|u; v' = d_u[f](u; u'), v'' = d_u[d_u[f]](u; u', u''))
500+ = d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v''=d_u[f](u; u''))
501+ + d_v[J](v = f|u; v' = v'' = d_u[d_u[f]](u; u', u''))
502+
503+ since ``d_v[d_v[J]]`` is linear in ``v'`` so differentiating wrt to it leaves
504+ its coefficient, the bare d_v[J] operator which acts on the ``v''`` term
505+ that remains.
506+
507+ The ``d_u[d_u[f]](u; u', u'')`` term can be simplified further:
508+
509+ ::
510+
511+ f = I ∘ expr : W ⟶ V i.e. I(expr|u) ∈ V ∀ u ∈ W
512+ d_u[I ∘ expr] : W × W ⟶ V i.e. d_u[I ∘ expr](u; u')
513+ d_u[d_u[I ∘ expr]] : W × W × W ⟶ V i.e. d_u[I ∘ expr](u; u', u'')
514+ d_x[I] : X × X ⟶ V i.e. d_x[I](x; x')
515+ d_x[d_x[I]] : X × X × X ⟶ V i.e. d_x[d_x[I]](x; x', x'')
516+ d_u[expr] : W × W ⟶ X i.e. d_u[expr](u; u')
517+ d_u[d_u[expr]] : W × W × W ⟶ X i.e. d_u[d_u[expr]](u; u', u'')
518+
519+ Since ``I`` is linear we get that
520+
521+ ::
522+
523+ d_u[d_u[I ∘ expr]](u; u', u'') = I ∘ d_u[d_u[expr]](u; u', u'').
524+
525+ So our full hessian is:
526+
527+ ::
528+
529+ d_u[d_u[J ∘ f]](u; u', u'')
530+ = d_v[d_v[J]](v = f|u; v' = d_u[f](u; u'), v''=d_u[f](u; u''))
531+ + d_v[J](v = f|u; v' = v'' = d_u[d_u[f]](u; u', u''))
532+
533+ In tensor notation
534+
535+ ::
536+
537+ [d^2[J ∘ f]/du^2|_u]_{lk} u'_k u''_k =
538+ [d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k [df/du|_u]_{il} u''_l
539+ + [dJ/dv|_{v=f_u}]_i I([d^2expr/du^2|_u]_{lk} u'_k)_i u''_l
540+
541+ In the first term:
542+
543+ ::
544+
545+ [df/du|_u]_{jk} u'_k = v'_j
546+ => [d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k
547+ = [d^2J/dv^2|_{v=f|_u}]_{ij} v'_j
548+ = hessian_input_i
549+ => [d^2J/dv^2|_{v=f|_u}]_{ij} [df/du|_u]_{jk} u'_k [df/du|_u]_{il}
550+ = hessian_input_i [df/du|_u]_{il}
551+ = self.evaluate_adj_component(inputs, hessian_inputs, ...)_l
552+
553+ In the second term we calculate everything explicitly though note
554+ ``[dJ/dv|_{v=f_u}]_i = adj_inputs[0]_i``
555+
556+ Also, the second term is 0 if ``expr`` is linear.
557+ """
558+
559+ if len (hessian_inputs ) > 1 or len (adj_inputs ) > 1 :
560+ raise (NotImplementedError ("Interpolate block must have a single output" ))
561+
562+ component = self .evaluate_adj_component (inputs , hessian_inputs , block_variable , idx , prepared )
563+
564+ # Prepare again by replacing expression
565+ expr = replace (self .expr , self ._replace_map ())
566+
567+ # Calculate first derivative for each relevant block
568+ dexprdu = 0.
569+ for _ , bv in relevant_dependencies :
570+ # Only take the derivative if there is a direction to take it in
571+ if bv .tlm_value is None :
572+ continue
573+ dexprdu += self .backend .derivative (expr , bv .saved_output , bv .tlm_value )
574+
575+ # Calculate the second derivative w.r.t. the specified coefficient's
576+ # saved value. Leave argument unspecified so it can be calculated with
577+ # the eventual inner product with u''.
578+ d2exprdudu = self .backend .derivative (dexprdu , block_variable .saved_output )
579+
580+ # left multiply by dJ/dv (adj_inputs[0]) - i.e. interpolate using the
581+ # transpose operator
582+ component += self .backend .Interpolator (d2exprdudu , self .V ).interpolate (adj_inputs [0 ], transpose = True )
583+ return component .vector ()
584+
333585 def prepare_recompute_component (self , inputs , relevant_outputs ):
334586 return replace (self .expr , self ._replace_map ())
335587
0 commit comments