diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f61ab2d9b..76ffa199ef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -82,6 +82,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch ufl dham/cofunction_is_terminal \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index fc5be8486a..42d5f13cda 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -268,3 +268,76 @@ def recompute_component(self, inputs, block_variable, idx, prepared): def __str__(self): deps = self.get_dependencies() return f"{deps[1]}[{self.idx}].assign({deps[0]})" + + +class CofunctionAssignBlock(Block): + """Class specifically for the case b.assign(a). + + All other cofunction assignment operations are annotated via Assemble. In + effect this means that this is the annotation of an identity operation. + + Parameters + ---------- + lhs: + The target of the assignment. + rhs: + The cofunction being assigned. + """ + + def __init__(self, lhs: firedrake.Cofunction, rhs: firedrake.Cofunction, + ad_block_tag=None): + super().__init__(ad_block_tag=ad_block_tag) + self.add_output(lhs.block_variable) + self.add_dependency(rhs) + + def recompute_component(self, inputs, block_variable, idx, prepared=None): + """Recompute the assignment. + + Parameters + ---------- + inputs : list of Function or Constant + The variable in the RHS of the assignment. + block_variable : pyadjoint.block_variable.BlockVariable + The output block variable. + idx : int + Index associated to the inputs list. + prepared : + The precomputed RHS value. + + Notes + ----- + Recomputes the block_variable only if the checkpoint was not delegated + to another :class:`~firedrake.function.Function`. + + Returns + ------- + Cofunction + Return either the firedrake cofunction or `BlockVariable` + checkpoint to which was delegated the checkpointing. + """ + assert idx == 0 # There must be only one RHS. + if isinstance(block_variable.checkpoint, DelegatedFunctionCheckpoint): + return block_variable.checkpoint + else: + output = firedrake.Cofunction( + block_variable.output.function_space() + ) + output.assign(inputs[0]) + return maybe_disk_checkpoint(output) + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, + prepared=None): + return adj_inputs[0] + + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, + block_variable, idx, + relevant_dependencies, prepared=None): + return hessian_inputs[0] + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, + prepared=None): + return tlm_inputs[0] + + def __str__(self): + deps = self.get_dependencies() + return f"assign({deps[0]})" diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 7a57e883c3..9eb6a8626c 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -80,8 +80,15 @@ def _init_solver_parameters(self, args, kwargs): self.assemble_kwargs = {} def __str__(self): - return "solve({} = {})".format(ufl2unicode(self.lhs), - ufl2unicode(self.rhs)) + try: + lhs_string = ufl2unicode(self.lhs) + except AttributeError: + lhs_string = str(self.lhs) + try: + rhs_string = ufl2unicode(self.rhs) + except AttributeError: + rhs_string = str(self.rhs) + return "solve({} = {})".format(lhs_string, rhs_string) def _create_F_form(self): # Process the equation forms, replacing values with checkpoints, @@ -692,7 +699,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, c = block_variable.output c_rep = block_variable.saved_output - if isinstance(c, firedrake.Function): + if isinstance(c, (firedrake.Function, firedrake.Cofunction)): trial_function = firedrake.TrialFunction(c.function_space()) elif isinstance(c, firedrake.Constant): mesh = F_form.ufl_domain() @@ -729,7 +736,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, replace_map[self.func] = self.get_outputs()[0].saved_output dFdm = replace(dFdm, replace_map) - dFdm = dFdm * adj_sol + if isinstance(dFdm, firedrake.Argument): + # Corner case. Should be fixed more permanently upstream in UFL. + dFdm = ufl.Action(dFdm, adj_sol) + else: + dFdm = dFdm * adj_sol dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs) return dFdm diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index d56438fec7..da790d3d7c 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -221,55 +221,18 @@ def _ad_create_checkpoint(self): return self.copy(deepcopy=True) def _ad_convert_riesz(self, value, options=None): - from firedrake import Function, Cofunction + from firedrake import Function options = {} if options is None else options riesz_representation = options.get("riesz_representation", "L2") solver_options = options.get("solver_options", {}) - V = options.get("function_space", self.function_space()) if value == 0.: # In adjoint-based differentiation, value == 0. arises only when # the functional is independent on the control variable. - return Function(V) - - if not isinstance(value, (Cofunction, Function)): - raise TypeError("Expected a Cofunction or a Function") - - if riesz_representation == "l2": - return Function(V, val=value.dat) - - elif riesz_representation in ("L2", "H1"): - if not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") - - ret = Function(V) - a = self._define_riesz_map_form(riesz_representation, V) - firedrake.solve(a == value, ret, **solver_options) - return ret - - elif callable(riesz_representation): - return riesz_representation(value) - - else: - raise ValueError( - "Unknown Riesz representation %s" % riesz_representation) - - def _define_riesz_map_form(self, riesz_representation, V): - from firedrake import TrialFunction, TestFunction + return Function(self.function_space()) - u = TrialFunction(V) - v = TestFunction(V) - if riesz_representation == "L2": - a = firedrake.inner(u, v)*firedrake.dx - - elif riesz_representation == "H1": - a = firedrake.inner(u, v)*firedrake.dx \ - + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx - - else: - raise NotImplementedError( - "Unknown Riesz representation %s" % riesz_representation) - return a + return value.riesz_representation(riesz_map=riesz_representation, + solver_options=solver_options) @no_annotations def _ad_convert_type(self, value, options=None): @@ -294,9 +257,7 @@ def _ad_restore_at_checkpoint(self, checkpoint): return checkpoint def _ad_will_add_as_dependency(self): - """Method called when the object is added as a Block dependency. - - """ + """Method called when the object is added as a Block dependency.""" with checkpoint_init_data(): super()._ad_will_add_as_dependency() @@ -305,7 +266,8 @@ def _ad_mul(self, other): from firedrake import Function r = Function(self.function_space()) - # `self` can be a Cofunction in which case only left multiplication with a scalar is allowed. + # `self` can be a Cofunction in which case only left multiplication + # with a scalar is allowed. r.assign(other * self) return r @@ -318,7 +280,10 @@ def _ad_add(self, other): return r def _ad_dot(self, other, options=None): - from firedrake import assemble + from firedrake import assemble, action, Cofunction + + if isinstance(other, Cofunction): + return assemble(action(other, self)) options = {} if options is None else options riesz_representation = options.get("riesz_representation", "L2") @@ -400,3 +365,9 @@ def _applyBinary(self, f, y): def __deepcopy__(self, memodict={}): return self.copy(deepcopy=True) + + +class CofunctionMixin(FunctionMixin): + + def _ad_dot(self, other): + return firedrake.assemble(firedrake.action(self, other)) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index b848c08f74..af28686977 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -1,22 +1,24 @@ +from functools import cached_property import numpy as np import ufl from ufl.form import BaseForm from pyop2 import op2, mpi -from pyadjoint.tape import stop_annotating, annotate_tape +from pyadjoint.tape import stop_annotating, annotate_tape, get_working_tape import firedrake.assemble import firedrake.functionspaceimpl as functionspaceimpl from firedrake import utils, vector, ufl_expr from firedrake.utils import ScalarType -from firedrake.adjoint_utils.function import FunctionMixin +from firedrake.adjoint_utils.function import CofunctionMixin from firedrake.adjoint_utils.checkpointing import DelegatedFunctionCheckpoint +from firedrake.adjoint_utils.blocks.function import CofunctionAssignBlock from firedrake.petsc import PETSc -class Cofunction(ufl.Cofunction, FunctionMixin): +class Cofunction(ufl.Cofunction, CofunctionMixin): r"""A :class:`Cofunction` represents a function on a dual space. - Like Functions, cofunctions are - represented as sums of basis functions: + + Like Functions, cofunctions are represented as sums of basis functions: .. math:: @@ -32,7 +34,7 @@ class Cofunction(ufl.Cofunction, FunctionMixin): """ @PETSc.Log.EventDecorator() - @FunctionMixin._ad_annotate_init + @CofunctionMixin._ad_annotate_init def __init__(self, function_space, val=None, name=None, dtype=ScalarType, count=None): r""" @@ -104,13 +106,13 @@ def _analyze_form_arguments(self): self._coefficients = (self,) @utils.cached_property - @FunctionMixin._ad_annotate_subfunctions + @CofunctionMixin._ad_annotate_subfunctions def subfunctions(self): r"""Extract any sub :class:`Cofunction`\s defined on the component spaces of this this :class:`Cofunction`'s :class:`.FunctionSpace`.""" return tuple(type(self)(fs, dat) for fs, dat in zip(self.function_space(), self.dat)) - @FunctionMixin._ad_annotate_subfunctions + @CofunctionMixin._ad_annotate_subfunctions def split(self): import warnings warnings.warn("The .split() method is deprecated, please use the .subfunctions property instead", category=FutureWarning) @@ -198,8 +200,15 @@ def assign(self, expr, subset=None): and expr.function_space() == self.function_space()): # do not annotate in case of self assignment if annotate_tape() and self != expr: + if subset is not None: + raise NotImplementedError("Cofunction subset assignment " + "annotation is not supported.") self.block_variable = self.create_block_variable() self.block_variable._checkpoint = DelegatedFunctionCheckpoint(expr.block_variable) + get_working_tape().add_block( + CofunctionAssignBlock(self, expr) + ) + expr.dat.copy(self.dat, subset=subset) return self elif isinstance(expr, BaseForm): @@ -211,39 +220,38 @@ def assign(self, expr, subset=None): raise ValueError('Cannot assign %s' % expr) - def riesz_representation(self, riesz_map='L2', **solver_options): - """Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + def riesz_representation(self, riesz_map='L2', *, bcs=None, + solver_options=None): + """Return the Riesz representation of this :class:`Cofunction`. - Example: For a L2 Riesz map, the Riesz representation is obtained by solving - the linear system ``Mx = self``, where M is the L2 mass matrix, i.e. M = - with u and v trial and test functions, respectively. + Example: For a L2 Riesz map, the Riesz representation is obtained by + solving the linear system ``Mx = self``, where M is the L2 mass matrix, + i.e. M = with u and v trial and test functions, respectively. Parameters ---------- - riesz_map : str or collections.abc.Callable - The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. - solver_options : dict - Solver options to pass to the linear solver: - - solver_parameters: optional solver parameters. - - nullspace: an optional :class:`.VectorSpaceBasis` (or :class:`.MixedVectorSpaceBasis`) - spanning the null space of the operator. - - transpose_nullspace: as for the nullspace, but used to make the right hand side consistent. - - near_nullspace: as for the nullspace, but used to add the near nullspace. - - options_prefix: an optional prefix used to distinguish PETSc options. - If not provided a unique prefix will be created. - Use this option if you want to pass options to the solver from the command line - in addition to through the ``solver_parameters`` dict. + riesz_map : str or ufl.sobolevspace.SobolevSpace or + collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a + callable. + bcs: DirichletBC or list of DirichletBC + Boundary conditions to apply to the Riesz map. + solver_options: dict + A dictionary of PETSc options to be passed to the solver. Returns ------- firedrake.function.Function - Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + Riesz representation of this :class:`Cofunction` with respect to + the given Riesz map. """ - return self._ad_convert_riesz(self, options={"function_space": self.function_space().dual(), - "riesz_representation": riesz_map, - "solver_options": solver_options}) + if not callable(riesz_map): + riesz_map = RieszMap(self.function_space(), riesz_map, bcs=bcs, + solver_options=solver_options) + + return riesz_map(self) - @FunctionMixin._ad_annotate_iadd + @CofunctionMixin._ad_annotate_iadd @utils.known_pyop2_safe def __iadd__(self, expr): @@ -259,7 +267,7 @@ def __iadd__(self, expr): # Let Python hit `BaseForm.__add__` which relies on ufl.FormSum. return NotImplemented - @FunctionMixin._ad_annotate_isub + @CofunctionMixin._ad_annotate_isub @utils.known_pyop2_safe def __isub__(self, expr): @@ -276,7 +284,7 @@ def __isub__(self, expr): # Let Python hit `BaseForm.__sub__` which relies on ufl.FormSum. return NotImplemented - @FunctionMixin._ad_annotate_imul + @CofunctionMixin._ad_annotate_imul def __imul__(self, expr): if np.isscalar(expr): @@ -343,3 +351,120 @@ def __str__(self): def cell_node_map(self): return self.function_space().cell_node_map() + + +class RieszMap: + """Return a map between dual and primal function spaces. + + A `RieszMap` can be called on a `Cofunction` in the appropriate space to + yield the `Function` which is the Riesz representer under the given inner + product. Conversely, it can be called on a `Function` to apply the given + inner product and return a `Cofunction`. + + Parameters + ---------- + function_space_or_inner_product: FunctionSpace or ufl.Form + The space from which to map, or a bilinear form defining an inner + product. + sobolev_space: str or ufl.sobolevspace.SobolevSpace. + Used to determine the inner product. + bcs: DirichletBC or list of DirichletBC + Boundary conditions to apply to the Riesz map. + solver_options: dict + A dictionary of PETSc options to be passed to the solver. + """ + + def __init__(self, function_space_or_inner_product=None, + sobolev_space=ufl.L2, *, bcs=None, solver_options=None): + if isinstance(function_space_or_inner_product, ufl.Form): + args = ufl.algorithms.extract_arguments( + function_space_or_inner_product + ) + if len(args) != 2: + raise ValueError(f"inner_product has arity {len(args)}, " + "should be 2.") + function_space = args[0].function_space() + inner_product = function_space_or_inner_product + else: + function_space = function_space_or_inner_product + if hasattr(function_space, "function_space"): + function_space = function_space.function_space() + if ufl.duals.is_dual(function_space): + function_space = function_space.dual() + + if str(sobolev_space) == "l2": + inner_product = "l2" + else: + from firedrake import TrialFunction, TestFunction + u = TrialFunction(function_space) + v = TestFunction(function_space) + inner_product = RieszMap._inner_product_form( + sobolev_space, u, v + ) + + self._function_space = function_space + self._inner_product = inner_product + self._bcs = bcs + self._solver_options = solver_options or {} + + @staticmethod + def _inner_product_form(sobolev_space, u, v): + from firedrake import inner, dx, grad + inner_products = { + "L2": lambda u, v: inner(u, v)*dx, + "H1": lambda u, v: inner(u, v)*dx + inner(grad(u), grad(v))*dx + } + try: + return inner_products[str(sobolev_space)](u, v) + except KeyError: + raise ValueError("No inner product defined for Sobolev space " + f"{sobolev_space}.") + + @cached_property + def _solver(self): + from firedrake import (LinearVariationalSolver, + LinearVariationalProblem, Function, Cofunction) + rhs = Cofunction(self._function_space.dual()) + soln = Function(self._function_space) + lvp = LinearVariationalProblem(self._inner_product, rhs, soln, + bcs=self._bcs) + solver = LinearVariationalSolver( + lvp, solver_parameters=self._solver_options + ) + return solver.solve, rhs, soln + + def __call__(self, value): + """Return the Riesz representer of a Function or Cofunction.""" + from firedrake import Function, Cofunction + + if ufl.duals.is_dual(value): + if value.function_space().dual() != self._function_space: + raise ValueError("Function space mismatch in RieszMap.") + output = Function(self._function_space) + + if self._inner_product == "l2": + for o, c in zip(output.subfunctions, value.subfunctions): + o.dat.data[:] = c.dat.data[:] + else: + solve, rhs, soln = self._solver + rhs.assign(value) + solve() + output = Function(self._function_space) + output.assign(soln) + elif ufl.duals.is_primal(value): + if value.function_space().dual() != self._function_space: + raise ValueError("Function space mismatch in RieszMap.") + + if self._inner_product == "l2": + output = Cofunction(self._function_space.dual()) + for o, c in zip(output.subfunctions, value.subfunctions): + o.dat.data[:] = c.dat.data[:] + else: + output = firedrake.assemble( + firedrake.action(self._inner_product, value) + ) + else: + raise ValueError( + f"Unable to ascertain if {value} is primal or dual." + ) + return output diff --git a/firedrake/function.py b/firedrake/function.py index 7307877de2..81d55b7dd3 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -18,7 +18,7 @@ from firedrake.utils import ScalarType, IntType, as_ctypes from firedrake import functionspaceimpl -from firedrake.cofunction import Cofunction +from firedrake.cofunction import Cofunction, RieszMap from firedrake import utils from firedrake import vector from firedrake.adjoint_utils import FunctionMixin @@ -468,40 +468,36 @@ def assign(self, expr, subset=None): Assigner(self, expr, subset).assign() return self - def riesz_representation(self, riesz_map='L2'): - """Return the Riesz representation of this :class:`Function` with respect to the given Riesz map. + def riesz_representation(self, riesz_map='L2', bcs=None, + solver_options=None): + """Return the Riesz representation of this :class:`Function`. - Example: For a L2 Riesz map, the Riesz representation is obtained by taking the action - of ``M`` on ``self``, where M is the L2 mass matrix, i.e. M = - with u and v trial and test functions, respectively. + Example: For a L2 Riesz map, the Riesz representation is obtained by + taking the action of ``M`` on ``self``, where M is the L2 mass matrix, + i.e. M = with u and v trial and test functions, respectively. Parameters ---------- - riesz_map : str or collections.abc.Callable - The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + riesz_map : str or ufl.sobolevspace.SobolevSpace or + collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a + callable which applies the Riesz map. + bcs: DirichletBC or list of DirichletBC + Boundary conditions to apply to the Riesz map. + solver_options: dict + A dictionary of PETSc options to be passed to the solver. Returns ------- firedrake.cofunction.Cofunction - Riesz representation of this :class:`Function` with respect to the given Riesz map. + Riesz representation of this :class:`Function` with respect to the + given Riesz map. """ - from firedrake.ufl_expr import action - from firedrake.assemble import assemble + if not callable(riesz_map): + riesz_map = RieszMap(self.function_space(), riesz_map, bcs=bcs, + solver_options=solver_options) - V = self.function_space() - if riesz_map == "l2": - return Cofunction(V.dual(), val=self.dat) - - elif riesz_map in ("L2", "H1"): - a = self._define_riesz_map_form(riesz_map, V) - return assemble(action(a, self)) - - elif callable(riesz_map): - return riesz_map(self) - - else: - raise NotImplementedError( - "Unknown Riesz representation %s" % riesz_map) + return riesz_map(self) @FunctionMixin._ad_annotate_iadd def __iadd__(self, expr): diff --git a/firedrake/ufl_expr.py b/firedrake/ufl_expr.py index e129c26037..eeadfd32d1 100644 --- a/firedrake/ufl_expr.py +++ b/firedrake/ufl_expr.py @@ -45,6 +45,12 @@ def __init__(self, function_space, number, part=None): number, part=part) self._function_space = function_space + def arguments(self): + return (self,) + + def coefficients(self): + return () + @utils.cached_property def cell_node_map(self): return self.function_space().cell_node_map diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 4aa66f84fd..53406bab9d 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -947,3 +947,23 @@ def test_riesz_representation_for_adjoints(): and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) and np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data) ) + + +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +def test_cofunction_assign_functional(): + """Test that cofunction assignment is correctly annotated. + """ + mesh = UnitIntervalMesh(5) + fs = FunctionSpace(mesh, "R", 0) + f = Function(fs) + f.assign(1.0) + f2 = Function(fs) + f2.assign(1.0) + v = TestFunction(fs) + + cof = assemble(f * v * dx) + cof2 = Cofunction(cof) + cof2.assign(cof) # Not currently taped! + J = assemble(action(cof2, f2)) + Jhat = ReducedFunctional(J, Control(f)) + assert np.allclose(float(Jhat.derivative()), 1.0)