diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 675946f845..218ac9c379 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,6 +83,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint dolci/tape_recompute_count \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index a96889b7ba..994f4fc253 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -2,8 +2,9 @@ import ufl from ufl import replace from ufl.formatting.ufl2unicode import ufl2unicode +from enum import Enum -from pyadjoint import Block, stop_annotating +from pyadjoint import Block, stop_annotating, get_working_tape from pyadjoint.enlisting import Enlist import firedrake from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint @@ -24,6 +25,12 @@ def extract_subfunction(u, V): return u +class Solver(Enum): + """Enum for solver types.""" + FORWARD = 0 + ADJOINT = 1 + + class GenericSolveBlock(Block): pop_kwargs_keys = ["adj_cb", "adj_bdy_cb", "adj2_cb", "adj2_bdy_cb", "forward_args", "forward_kwargs", "adj_args", @@ -206,15 +213,17 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): adj_sol_bdy = None if compute_bdy: - adj_sol_bdy = firedrake.Function( - self.function_space.dual(), - dJdu_copy.dat - firedrake.assemble( - firedrake.action(dFdu_adj_form, adj_sol) - ).dat - ) + adj_sol_bdy = self._compute_adj_bdy( + adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) return adj_sol, adj_sol_bdy + def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): + adj_sol_bdy = firedrake.Function( + self.function_space.dual(), dJdu.dat - firedrake.assemble( + firedrake.action(dFdu_adj_form, adj_sol)).dat) + return adj_sol_bdy + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): if not self.linear and self.func == block_variable.output: @@ -626,15 +635,62 @@ def _init_solver_parameters(self, args, kwargs): super()._init_solver_parameters(args, kwargs) solve_init_params(self, args, kwargs, varform=True) + def recompute_component(self, inputs, block_variable, idx, prepared): + tape = get_working_tape() + if self._ad_solvers["recompute_count"] == tape.recompute_count - 1: + # Update how many times the block has been recomputed. + self._ad_solvers["recompute_count"] = tape.recompute_count + if self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: + self._ad_solvers["forward_nlvs"].invalidate_jacobian() + self._ad_solvers["update_adjoint"] = True + return super().recompute_component(inputs, block_variable, idx, prepared) + def _forward_solve(self, lhs, rhs, func, bcs, **kwargs): - self._ad_nlvs_replace_forms() - self._ad_nlvs.parameters.update(self.solver_params) - self._ad_nlvs.solve() - func.assign(self._ad_nlvs._problem.u) + self._ad_solver_replace_forms() + self._ad_solvers["forward_nlvs"].parameters.update(self.solver_params) + self._ad_solvers["forward_nlvs"].solve() + func.assign(self._ad_solvers["forward_nlvs"]._problem.u) return func - def _ad_assign_map(self, form): - count_map = self._ad_nlvs._problem._ad_count_map + def _adjoint_solve(self, dJdu, compute_bdy): + dJdu_copy = dJdu.copy() + # Homogenize and apply boundary conditions on adj_dFdu and dJdu. + bcs = self._homogenize_bcs() + for bc in bcs: + bc.apply(dJdu) + + if ( + self._ad_solvers["forward_nlvs"]._problem._constant_jacobian + and self._ad_solvers["update_adjoint"] + ): + # Update left hand side of the adjoint equation. + self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solvers["adjoint_lvs"].invalidate_jacobian() + self._ad_solvers["update_adjoint"] = False + elif not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: + # Update left hand side of the adjoint equation. + self._ad_solver_replace_forms(Solver.ADJOINT) + + # Update the right hand side of the adjoint equation. + # problem.F._component[1] is the right hand side of the adjoint. + self._ad_solvers["adjoint_lvs"]._problem.F._components[1].assign(dJdu) + + # Solve the adjoint linear variational solver. + self._ad_solvers["adjoint_lvs"].solve() + u_sol = self._ad_solvers["adjoint_lvs"]._problem.u + + adj_sol_bdy = None + if compute_bdy: + jac_adj = self._ad_solvers["adjoint_lvs"]._problem.J + adj_sol_bdy = self._compute_adj_bdy( + u_sol, adj_sol_bdy, jac_adj, dJdu_copy) + return u_sol, adj_sol_bdy + + def _ad_assign_map(self, form, solver): + if solver == Solver.FORWARD: + count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map + else: + count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map assign_map = {} form_ad_count_map = dict((count_map[coeff], coeff) for coeff in form.coefficients()) @@ -647,54 +703,47 @@ def _ad_assign_map(self, form): if coeff_count in form_ad_count_map: assign_map[form_ad_count_map[coeff_count]] = \ block_variable.saved_output + + if ( + solver == Solver.ADJOINT + and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian + ): + block_variable = self.get_outputs()[0] + coeff_count = block_variable.output.count() + if coeff_count in form_ad_count_map: + assign_map[form_ad_count_map[coeff_count]] = \ + block_variable.saved_output return assign_map - def _ad_assign_coefficients(self, form): - assign_map = self._ad_assign_map(form) + def _ad_assign_coefficients(self, form, solver): + assign_map = self._ad_assign_map(form, solver) for coeff, value in assign_map.items(): coeff.assign(value) - def _ad_nlvs_replace_forms(self): - problem = self._ad_nlvs._problem - self._ad_assign_coefficients(problem.F) - self._ad_assign_coefficients(problem.J) - - def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): - if "dFdu_adj" in self._adj_cache: - dFdu = self._adj_cache["dFdu_adj"] + def _ad_solver_replace_forms(self, solver=Solver.FORWARD): + if solver == Solver.FORWARD: + problem = self._ad_solvers["forward_nlvs"]._problem + self._ad_assign_coefficients(problem.F, solver) + self._ad_assign_coefficients(problem.J, solver) else: - dFdu = super()._assemble_dFdu_adj(dFdu_adj_form, **kwargs) - if self._ad_nlvs._problem._constant_jacobian: - self._adj_cache["dFdu_adj"] = dFdu - return dFdu + self._ad_assign_coefficients( + self._ad_solvers["adjoint_lvs"]._problem.J, solver) def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - dJdu = adj_inputs[0] - - F_form = self._create_F_form() - - dFdu_form = self.adj_F - dJdu = dJdu.copy() - - # Replace the form coefficients with checkpointed values. - replace_map = self._replace_map(dFdu_form) - replace_map[self.func] = self.get_outputs()[0].saved_output - dFdu_form = replace(dFdu_form, replace_map) - compute_bdy = self._should_compute_boundary_adjoint( relevant_dependencies ) - adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq( - dFdu_form, dJdu, compute_bdy - ) - self.adj_state = adj_sol + adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) + if not self.adj_state: + self.adj_state = firedrake.Function(adj_sol.function_space()) + self.adj_state.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: self.adj_bdy_cb(adj_sol_bdy) r = {} - r["form"] = F_form + r["form"] = self._create_F_form() r["adj_sol"] = adj_sol r["adj_sol_bdy"] = adj_sol_bdy return r diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index a6811002ac..a61555f1a9 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -45,7 +45,8 @@ def wrapper(self, problem, *args, **kwargs): self._ad_problem = problem self._ad_args = args self._ad_kwargs = kwargs - self._ad_nlvs = None + self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None, + "recompute_count": 0} self._ad_adj_cache = {} return wrapper @@ -58,7 +59,7 @@ def wrapper(self, **kwargs): Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic for the purposes of the adjoint computation (such as projecting fields to other function spaces for the purposes of visualisation).""" - + from firedrake import LinearVariationalSolver annotate = annotate_tape(kwargs) if annotate: tape = get_working_tape() @@ -76,13 +77,24 @@ def wrapper(self, **kwargs): solver_kwargs=self._ad_kwargs, ad_block_tag=self.ad_block_tag, **sb_kwargs) - if not self._ad_nlvs: - self._ad_nlvs = type(self)( + + # Forward variational solver. + if not self._ad_solvers["forward_nlvs"]: + self._ad_solvers["forward_nlvs"] = type(self)( self._ad_problem_clone(self._ad_problem, block.get_dependencies()), **self._ad_kwargs ) - block._ad_nlvs = self._ad_nlvs + # Adjoint variational solver. + if not self._ad_solvers["adjoint_lvs"]: + with stop_annotating(): + self._ad_solvers["adjoint_lvs"] = LinearVariationalSolver( + self._ad_adj_lvs_problem(block), *block.adj_args, **block.adj_kwargs) + if self._ad_problem._constant_jacobian: + self._ad_solvers["update_adjoint"] = False + + block._ad_solvers = self._ad_solvers + tape.add_block(block) with stop_annotating(): @@ -103,22 +115,65 @@ def _ad_problem_clone(self, problem, dependencies): affect the user-defined self._ad_problem.F, self._ad_problem.J and self._ad_problem.u expressions, we'll instead create clones of them. """ - from firedrake import Function, NonlinearVariationalProblem + from firedrake import NonlinearVariationalProblem + _ad_count_map, J_replace_map, F_replace_map = self._build_count_map( + problem.J, dependencies, Form=problem.F) + nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), + F_replace_map[problem.u_restrict], + bcs=problem.bcs, + J=replace(problem.J, J_replace_map)) + nlvp.is_linear = problem.is_linear + nlvp._constant_jacobian = problem._constant_jacobian + nlvp._ad_count_map_update(_ad_count_map) + return nlvp + + @no_annotations + def _ad_adj_lvs_problem(self, block): + """Create the adjoint variational problem.""" + from firedrake import Function, Cofunction, LinearVariationalProblem + # Homogeneous boundary conditions for the adjoint problem + # when Dirichlet boundary conditions are applied. + bcs = block._homogenize_bcs() + adj_sol = Function(block.function_space) + right_hand_side = Cofunction(block.function_space.dual()) + tmp_problem = LinearVariationalProblem( + block.adj_F, right_hand_side, + adj_sol, bcs=bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + # The `block.adj_F` coefficients hold the output references. + # We do not want to modify the user-defined values. Hence, the adjoint + # linear variational problem is created with a deep copy of the + # `block.adj_F` coefficients. + _ad_count_map, J_replace_map, _ = self._build_count_map( + block.adj_F, block._dependencies + ) + lvp = LinearVariationalProblem( + replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + bcs=tmp_problem.bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + lvp._ad_count_map_update(_ad_count_map) + del tmp_problem + return lvp + + def _build_count_map(self, J, dependencies, Form=None): + from firedrake import Function + F_replace_map = {} J_replace_map = {} - - F_coefficients = problem.F.coefficients() - J_coefficients = problem.J.coefficients() + if Form: + F_coefficients = Form.coefficients() + J_coefficients = J.coefficients() _ad_count_map = {} for block_variable in dependencies: coeff = block_variable.output - if coeff in F_coefficients and coeff not in F_replace_map: - if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real": - F_replace_map[coeff] = copy.deepcopy(coeff) - else: - F_replace_map[coeff] = coeff.copy(deepcopy=True) - _ad_count_map[F_replace_map[coeff]] = coeff.count() + if Form: + if coeff in F_coefficients and coeff not in F_replace_map: + if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real": + F_replace_map[coeff] = copy.deepcopy(coeff) + else: + F_replace_map[coeff] = coeff.copy(deepcopy=True) + _ad_count_map[F_replace_map[coeff]] = coeff.count() if coeff in J_coefficients and coeff not in J_replace_map: if coeff in F_replace_map: @@ -128,11 +183,4 @@ def _ad_problem_clone(self, problem, dependencies): else: J_replace_map[coeff] = coeff.copy() _ad_count_map[J_replace_map[coeff]] = coeff.count() - - nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), - F_replace_map[problem.u_restrict], - bcs=problem.bcs, - J=replace(problem.J, J_replace_map)) - nlvp._constant_jacobian = problem._constant_jacobian - nlvp._ad_count_map_update(_ad_count_map) - return nlvp + return _ad_count_map, J_replace_map, F_replace_map diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index c429a951cb..962f2a0194 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -840,7 +840,7 @@ def test_assign_cofunction(solve_type): solver.solve() J += assemble(((sol + Constant(1.0)) ** 2) * dx) rf = ReducedFunctional(J, Control(k)) - assert rf(k) == J + assert np.isclose(rf(k), J, rtol=1e-10) assert taylor_test(rf, k, Function(V).assign(0.1)) > 1.9 @@ -969,17 +969,15 @@ def test_lvs_constant_jacobian(constant_jacobian): solver.solve() J = assemble(v * v * dx) - assert "dFdu_adj" not in solver._ad_adj_cache + J_hat = ReducedFunctional(J, Control(u)) - dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"}) - - cached_dFdu_adj = solver._ad_adj_cache.get("dFdu_adj", None) - assert (cached_dFdu_adj is None) == (not constant_jacobian) + dJ = J_hat.derivative(options={"riesz_representation": "l2"}) assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro) - dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"}) + u_ref = Function(space, name="u").interpolate(X[0] - 0.1) + J_hat(u_ref) - assert cached_dFdu_adj is solver._ad_adj_cache.get("dFdu_adj", None) + dJ = J_hat.derivative(options={"riesz_representation": "l2"}) assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro)