Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjoint variational solver #3723

Open
wants to merge 48 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
6997430
Add adjoint variational solver.
Ig-dolci Jul 22, 2024
eb52a86
wip
Ig-dolci Jul 22, 2024
6b98f55
Update and Jacobian coefficients
Ig-dolci Jul 22, 2024
0a83c85
wip
Ig-dolci Jul 22, 2024
9387d48
wip
Ig-dolci Jul 22, 2024
be9046d
Update adjoint solver coefficients correctly.
Ig-dolci Aug 4, 2024
81da699
dd
Ig-dolci Aug 4, 2024
4968121
linting
Ig-dolci Aug 4, 2024
8e482d8
fix condition identation
Ig-dolci Aug 4, 2024
449e567
More docs; code enhacing.
Ig-dolci Aug 5, 2024
b24232c
linting
Ig-dolci Aug 5, 2024
4159f79
keep dFdm cache
Ig-dolci Aug 5, 2024
3fc41dd
LinearSolver if jacobian is constant
Ig-dolci Aug 5, 2024
ad05a12
update test jacobian is constant
Ig-dolci Aug 5, 2024
5a9d246
wip
Ig-dolci Aug 5, 2024
d004653
wip
Ig-dolci Aug 6, 2024
7c75e43
compute adjoint bdy working
Ig-dolci Aug 6, 2024
bf17de8
small changes
Ig-dolci Aug 6, 2024
022ef27
linting
Ig-dolci Aug 6, 2024
bfcf4c7
dd
Ig-dolci Aug 6, 2024
f5f705b
Adjoint variational solver when the jacobian is not constant
Ig-dolci Aug 6, 2024
bc205ab
compute adjoint boundary as a private method
Ig-dolci Aug 6, 2024
0c05743
Remove dJdu attribute; define form and boundary variables when necess…
Ig-dolci Aug 23, 2024
144a2f1
synchronize master
Ig-dolci Aug 23, 2024
aaac6fd
lint
Ig-dolci Aug 23, 2024
342f093
Update the problem.J coefficients
Ig-dolci Aug 26, 2024
bc3d765
Delegated block solvers and forms.
Ig-dolci Aug 26, 2024
9c9e376
minor changes
Ig-dolci Aug 27, 2024
1357a17
remove weak ref; minor changes; LinearSolver if Jacobian is constant
Ig-dolci Sep 10, 2024
d8345ee
lvs when not self._ad_problem._constant_jacobian.
Ig-dolci Sep 11, 2024
f0270e1
flake8
Ig-dolci Sep 11, 2024
58bec52
code for constant jacobian
Ig-dolci Sep 14, 2024
6bb1d14
Adapting test
Ig-dolci Sep 14, 2024
856de03
merge master conflict solved
Ig-dolci Sep 14, 2024
1e67f5a
flake8
Ig-dolci Sep 14, 2024
3f4306f
solvers dictionary
Ig-dolci Oct 7, 2024
0294c95
Update test for the new solvers.
Ig-dolci Oct 7, 2024
f05e8b9
Add adj_args and adj_kwargs; add a property in nlvs; remove _assemble…
Ig-dolci Oct 8, 2024
92f7f81
flake8; minimal tests modifications
Ig-dolci Oct 8, 2024
2f8c84b
flake8
Ig-dolci Oct 8, 2024
de687bd
minor changes
Ig-dolci Oct 9, 2024
23c82b6
minor changes
Ig-dolci Oct 9, 2024
fdcb2b3
LinearSolver for constant jacobian
Ig-dolci Oct 14, 2024
5fee9ee
minor changes
Ig-dolci Oct 14, 2024
01bbe5c
Keep only variational solver.
Ig-dolci Oct 14, 2024
74de2ac
Update constant Jacobian test according new adjoint solver.
Ig-dolci Oct 14, 2024
850e210
Minor chancge
Ig-dolci Oct 14, 2024
1dc5892
Test with right branch
Ig-dolci Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_value = firedrake.Function(self.collapsed_space, vec.dat)

if adj_value.ufl_shape == () or adj_value.ufl_shape[0] <= 1:
r = adj_value.dat.data_ro.sum()
R = firedrake.FunctionSpace(self.parent_space.mesh(), "R", 0)
r = firedrake.Function(R.dual(), val=adj_value.dat.data_ro.sum())
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
else:
output = []
subindices = _extract_subindices(self.function_space)
Expand Down
93 changes: 72 additions & 21 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,17 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):

adj_sol_bdy = None
if compute_bdy:
adj_sol_bdy = firedrake.Function(
self.function_space.dual(),
dJdu_copy.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)
).dat
)
adj_sol_bdy = self.compute_adj_bdy(
adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy)

return adj_sol, adj_sol_bdy

def compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
if not self.linear and self.func == block_variable.output:
Expand Down Expand Up @@ -630,6 +632,44 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
func.assign(self._ad_nlvs._problem.u)
return func

def _adjoint_solve(self, dJdu, adj_sol):
# Update the right hand side of the adjoint equation.
self._ad_dJdu.assign(dJdu)

# Update the left hand side coefficients of the adjoint equation.
if isinstance(self._ad_adj_solver, firedrake.LinearVariationalSolver):
problem = self._ad_adj_solver._problem
for block_variable in self.get_dependencies():
# The self.adj_F coefficients hold the forward output
# references.
if block_variable.output in self.adj_F.coefficients():
index = self.adj_F.coefficients().index(block_variable.output)
if isinstance(
block_variable.output, (
firedrake.Function, firedrake.Constant,
firedrake.Cofunction)):
# `problem.J` is a deep copy of `self.adj_F`.
# The indices of `self.adj_F` serve as a map for
# updating the coefficients of the adjoint solver.
problem.J.coefficients()[index].assign(
block_variable.saved_output)
bv = self.get_outputs()[0]
if bv.output in self.adj_F.coefficients():
index = self.adj_F.coefficients().index(bv.output)
problem.J.coefficients()[index].assign(
bv.checkpoint)
# Solve the adjoint equation.
self._ad_adj_solver.solve()
adj_sol.assign(self._ad_adj_solver._problem.u)
elif isinstance(self._ad_adj_solver, firedrake.LinearSolver):
# Solve the adjoint equation.
self._ad_adj_solver.solve(adj_sol, self._ad_dJdu)
else:
raise NotImplementedError(
"Only LinearVariationalSolver and LinearSolver are supported."
)
return adj_sol

def _ad_assign_map(self, form):
count_map = self._ad_nlvs._problem._ad_count_map
assign_map = {}
Expand Down Expand Up @@ -666,25 +706,36 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):
return dFdu

def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
dJdu = adj_inputs[0]

F_form = self._create_F_form()
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

dFdu_form = self.adj_F
dJdu = dJdu.copy()

# Replace the form coefficients with checkpointed values.
replace_map = self._replace_map(dFdu_form)
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdu_form = replace(dFdu_form, replace_map)

compute_bdy = self._should_compute_boundary_adjoint(
relevant_dependencies
)
adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
dFdu_form, dJdu, compute_bdy
)
# Forward form.
F_form = self._create_F_form()
dJdu = adj_inputs[0]
dJdu_copy = dJdu.copy()
adj_sol = firedrake.Function(self.function_space)
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)
# Solve the adjoint equation and update the adjoint solution
# (`adj_sol`).
adj_sol = self._adjoint_solve(dJdu, adj_sol)
self.adj_sol = adj_sol
adj_sol_bdy = None

if compute_bdy:
if isinstance(
self._ad_adj_solver, firedrake.LinearVariationalSolver
):
dFdu_adj_form = self._ad_adj_solver._problem.J
elif isinstance(self._ad_adj_solver, firedrake.LinearSolver):
# Jacobian is constant in this case.
dFdu_adj_form = self.adj_F

adj_sol_bdy = self.compute_adj_bdy(
adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy)

if self.adj_cb is not None:
self.adj_cb(adj_sol)
if self.adj_bdy_cb is not None and compute_bdy:
Expand Down
72 changes: 62 additions & 10 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def wrapper(self, problem, *args, **kwargs):
self._ad_kwargs = kwargs
self._ad_nlvs = None
self._ad_adj_cache = {}
self._ad_adj_solver = None
self._ad_dJdu = None

return wrapper

Expand All @@ -58,7 +60,7 @@ def wrapper(self, **kwargs):
Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic
for the purposes of the adjoint computation (such as projecting fields to other function spaces
for the purposes of visualisation)."""

from firedrake import LinearVariationalSolver, assemble, LinearSolver, Cofunction
annotate = annotate_tape(kwargs)
if annotate:
tape = get_working_tape()
Expand All @@ -76,13 +78,36 @@ def wrapper(self, **kwargs):
solver_kwargs=self._ad_kwargs,
ad_block_tag=self.ad_block_tag,
**sb_kwargs)
# Forward variational solver.
if not self._ad_nlvs:
self._ad_nlvs = type(self)(
self._ad_problem_clone(self._ad_problem, block.get_dependencies()),
**self._ad_kwargs
)

block._ad_nlvs = self._ad_nlvs
if not self._ad_dJdu:
# Right-hand side of the adjoint equation.
self._ad_dJdu = Cofunction(block.function_space.dual())
block._ad_dJdu = self._ad_dJdu

# Adjoint solver.
with stop_annotating():
if not self._ad_adj_solver:
# Homogeneous boundary conditions for the adjoint problem
# when Dirichlet boundary conditions are applied.
bcs = block._homogenize_bcs()
if block._ad_nlvs._problem._constant_jacobian:
self._ad_adj_solver = LinearSolver(
assemble(block.adj_F, bcs=bcs),
solver_parameters=self.parameters)
else:
problem = self._ad_adj_lvs_problem(block, bcs)
self._ad_adj_solver = LinearVariationalSolver(
problem, solver_parameters=self.parameters)

block._ad_adj_solver = self._ad_adj_solver

tape.add_block(block)

with stop_annotating():
Expand All @@ -103,7 +128,41 @@ def _ad_problem_clone(self, problem, dependencies):
affect the user-defined self._ad_problem.F, self._ad_problem.J and self._ad_problem.u
expressions, we'll instead create clones of them.
"""
from firedrake import Function, NonlinearVariationalProblem
from firedrake import NonlinearVariationalProblem
_ad_count_map, F_replace_map, J_replace_map = self._build_count_map(
problem, dependencies)
nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map),
F_replace_map[problem.u_restrict],
bcs=problem.bcs,
J=replace(problem.J, J_replace_map))
nlvp.is_linear = problem.is_linear
nlvp._constant_jacobian = problem._constant_jacobian
nlvp._ad_count_map_update(_ad_count_map)
return nlvp

@no_annotations
def _ad_adj_lvs_problem(self, block, bcs):
"""Create the adjoint variational problem."""
from firedrake import Function, LinearVariationalProblem
adj_sol = Function(block.function_space)
tmp_problem = LinearVariationalProblem(
block.adj_F, block._ad_dJdu, adj_sol, bcs=bcs)
# The `block.adj_F` coefficients hold the output references.
# We do not want to modify the user-defined values. Hence, the adjoint
# linear variational problem is created with a deep copy of the forward
# outputs.
_ad_count_map, _, J_replace_map = self._build_count_map(
tmp_problem, block._dependencies)
lvp = LinearVariationalProblem(
replace(tmp_problem.J, J_replace_map), self._ad_dJdu, adj_sol,
bcs=tmp_problem.bcs)
lvp._ad_count_map_update(_ad_count_map)
del tmp_problem
return lvp

def _build_count_map(self, problem, dependencies):
from firedrake import Function

F_replace_map = {}
J_replace_map = {}

Expand All @@ -128,11 +187,4 @@ def _ad_problem_clone(self, problem, dependencies):
else:
J_replace_map[coeff] = coeff.copy()
_ad_count_map[J_replace_map[coeff]] = coeff.count()

nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map),
F_replace_map[problem.u_restrict],
bcs=problem.bcs,
J=replace(problem.J, J_replace_map))
nlvp._constant_jacobian = problem._constant_jacobian
nlvp._ad_count_map_update(_ad_count_map)
return nlvp
return _ad_count_map, F_replace_map, J_replace_map
37 changes: 28 additions & 9 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,16 +968,35 @@ def test_lvs_constant_jacobian(constant_jacobian):
solver = LinearVariationalSolver(problem)
solver.solve()
J = assemble(v * v * dx)

assert "dFdu_adj" not in solver._ad_adj_cache

dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"})

cached_dFdu_adj = solver._ad_adj_cache.get("dFdu_adj", None)
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
assert (cached_dFdu_adj is None) == (not constant_jacobian)
assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro)

dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"})

assert cached_dFdu_adj is solver._ad_adj_cache.get("dFdu_adj", None)
assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro)
@pytest.mark.skipcomplex
@pytest.mark.parametrize("constant_jacobian", [False, True])
def test_adjoint_solver_compute_bdy(constant_jacobian):
# Testing the case where is required to compute the adjoint
# boundary condition.
mesh = UnitIntervalMesh(10)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
sol = Function(space, name="sol")
# Dirichlet boundary conditions
R = FunctionSpace(mesh, "R", 0)
a = Function(R, val=1.0)
b = Function(R, val=2.0)
bc_left = DirichletBC(space, a, 1)
bc_right = DirichletBC(space, b, 2)
bc = [bc_left, bc_right]
F = dot(grad(trial), grad(test)) * dx
problem = LinearVariationalProblem(lhs(F), rhs(F), sol, bcs=bc,
constant_jacobian=constant_jacobian)
solver = LinearVariationalSolver(problem)
solver.solve()
J = assemble(sol * sol * dx)
J_hat = ReducedFunctional(J, [Control(a), Control(b)])

assert taylor_test(
J_hat, [a, b], [Function(R, val=rand(1)), Function(R, val=rand(1))]
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
) > 1.9
Loading