Skip to content
3 changes: 2 additions & 1 deletion firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ def _ad_assign_map(self, form):
for block_variable in self.get_dependencies():
coeff = block_variable.output
if isinstance(coeff,
(firedrake.Coefficient, firedrake.Constant)):
(firedrake.Coefficient, firedrake.Constant,
firedrake.Cofunction)):
coeff_count = coeff.count()
if coeff_count in form_ad_count_map:
assign_map[form_ad_count_map[coeff_count]] = \
Expand Down
12 changes: 9 additions & 3 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import firedrake
from .checkpointing import disk_checkpointing, CheckpointFunction, \
CheckpointBase, checkpoint_init_data, DelegatedFunctionCheckpoint
from numbers import Number


class FunctionMixin(FloatingType):
Expand Down Expand Up @@ -230,12 +231,17 @@ def _ad_convert_riesz(self, value, options=None):

if riesz_representation != "l2" and not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")
elif not isinstance(value, (float, (Cofunction, Function))):
elif not isinstance(value, (Number, Cofunction, Function)):
raise TypeError("Expected a Cofunction, Function or a float")

if riesz_representation == "l2":
value = value.dat if isinstance(value, (Cofunction, Function)) else value
return Function(V, val=value)
if isinstance(value, (Cofunction, Function)):
return Function(V, val=value.dat)
else:
f = Function(V)
with stop_annotating():
f.assign(value)
return f

elif riesz_representation in ("L2", "H1"):
ret = Function(V)
Expand Down
20 changes: 14 additions & 6 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from ufl.form import BaseForm
from pyop2 import op2, mpi
from pyadjoint.tape import stop_annotating, annotate_tape
import firedrake.assemble
import firedrake.functionspaceimpl as functionspaceimpl
from firedrake import utils, vector, ufl_expr
from firedrake.utils import ScalarType
from firedrake.adjoint_utils.function import FunctionMixin
from firedrake.adjoint_utils.checkpointing import DelegatedFunctionCheckpoint
from firedrake.petsc import PETSc


Expand All @@ -31,7 +33,8 @@ class Cofunction(ufl.Cofunction, FunctionMixin):

@PETSc.Log.EventDecorator()
@FunctionMixin._ad_annotate_init
def __init__(self, function_space, val=None, name=None, dtype=ScalarType):
def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
count=None):
r"""
:param function_space: the :class:`.FunctionSpace`,
or :class:`.MixedFunctionSpace` on which to build this :class:`Cofunction`.
Expand All @@ -55,7 +58,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType):
raise NotImplementedError("Can't make a Cofunction defined on a "
+ str(type(function_space)))

ufl.Cofunction.__init__(self, V.ufl_function_space())
ufl.Cofunction.__init__(self, V.ufl_function_space(), count=count)

# User comm
self.comm = V.comm
Expand Down Expand Up @@ -168,7 +171,6 @@ def zero(self, subset=None):
return self.assign(0, subset=subset)

@PETSc.Log.EventDecorator()
@FunctionMixin._ad_not_implemented
@utils.known_pyop2_safe
def assign(self, expr, subset=None):
r"""Set the :class:`Cofunction` value to the pointwise value of
Expand All @@ -189,15 +191,21 @@ def assign(self, expr, subset=None):
"""
expr = ufl.as_ufl(expr)
if isinstance(expr, ufl.classes.Zero):
self.dat.zero(subset=subset)
with stop_annotating(modifies=(self,)):
self.dat.zero(subset=subset)
return self
elif (isinstance(expr, Cofunction)
and expr.function_space() == self.function_space()):
# do not annotate in case of self assignment
if annotate_tape() and self != expr:
self.block_variable = self.create_block_variable()
self.block_variable._checkpoint = DelegatedFunctionCheckpoint(expr.block_variable)
expr.dat.copy(self.dat, subset=subset)
return self
elif isinstance(expr, BaseForm):
# Enable to write down c += B where c is a Cofunction
# and B an appropriate BaseForm object
# Enable c.assign(B) where c is a Cofunction and B an appropriate BaseForm object.
# If annotation is enabled, the following operation will result in an assemble block on the
# Pyadjoint tape.
assembled_expr = firedrake.assemble(expr)
return self.assign(assembled_expr, subset=subset)

Expand Down
59 changes: 59 additions & 0 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from numpy.random import rand
from pyadjoint.tape import get_working_tape, pause_annotation, stop_annotating
from ufl.classes import Zero

from firedrake import *
from firedrake.adjoint import *
Expand Down Expand Up @@ -798,3 +799,61 @@ def test_3325():

constraint = UFLInequalityConstraint(-inner(g, g)*ds(4), control)
minimize(Jhat, method="SLSQP", constraints=constraint)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
@pytest.mark.parametrize("solve_type", ["solve", "linear_variational_solver"])
def test_assign_cofunction(solve_type):
# See https://github.com/firedrakeproject/firedrake/issues/3464 .
# This function tests the case where Cofunction assigns a
# Cofunction and a BaseForm.
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
v = TestFunction(V)
u = TrialFunction(V)
k = Function(V).assign(1.0)
a = k * u * v * dx
b = Constant(1.0) * v * dx
u0 = Cofunction(V.dual(), name="u0")
u1 = Cofunction(V.dual(), name="u1")
sol = Function(V, name="sol")
if solve_type == "linear_variational_solver":
problem = LinearVariationalProblem(lhs(a), rhs(a) + u1, sol)
solver = LinearVariationalSolver(problem)
J = 0
for i in range(2):
# This loop emulates a time-dependent problem, where the Cofunction
# added on the right-hand of the equation is updated at each time step.
u0.assign(assemble(b))
u1.assign(i * u0 + b)
if solve_type == "solve":
solve(a == u1, sol)
if solve_type == "linear_variational_solver":
solver.solve()
J += assemble(((sol + Constant(1.0)) ** 2) * dx)
rf = ReducedFunctional(J, Control(k))
assert rf(k) == J
assert taylor_test(rf, k, Function(V).assign(0.1)) > 1.9


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_assign_zero_cofunction():
# See https://github.com/firedrakeproject/firedrake/issues/3464 .
# It is expected the tape breaks since the functional loses its dependency
# on the control after the Cofunction assigns Zero.
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
v = TestFunction(V)
u = TrialFunction(V)
k = Function(V).assign(1.0)
a = u * v * dx
b = k * v * dx
u0 = Cofunction(V.dual(), name="u0")
u0.assign(b)
u0.assign(Zero())
sol = Function(V, name="c")
solve(a == u0, sol)
J = assemble(((sol + Constant(1.0)) ** 2) * dx)
# The zero assignment should break the tape and hence cause a zero
# gradient.
assert all(compute_gradient(J, Control(k)).dat.data_ro == 0.0)