Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor the riesz map out into a separate object. #3662

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl dham/cofunction_is_terminal \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
73 changes: 73 additions & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,76 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
def __str__(self):
deps = self.get_dependencies()
return f"{deps[1]}[{self.idx}].assign({deps[0]})"


class CofunctionAssignBlock(Block):
"""Class specifically for the case b.assign(a).

All other cofunction assignment operations are annotated via Assemble. In
effect this means that this is the annotation of an identity operation.

Parameters
----------
lhs:
The target of the assignment.
rhs:
The cofunction being assigned.
"""

def __init__(self, lhs: firedrake.Cofunction, rhs: firedrake.Cofunction,
ad_block_tag=None):
super().__init__(ad_block_tag=ad_block_tag)
self.add_output(lhs.block_variable)
self.add_dependency(rhs)

def recompute_component(self, inputs, block_variable, idx, prepared=None):
"""Recompute the assignment.

Parameters
----------
inputs : list of Function or Constant
The variable in the RHS of the assignment.
block_variable : pyadjoint.block_variable.BlockVariable
The output block variable.
idx : int
Index associated to the inputs list.
prepared :
The precomputed RHS value.

Notes
-----
Recomputes the block_variable only if the checkpoint was not delegated
to another :class:`~firedrake.function.Function`.

Returns
-------
Cofunction
Return either the firedrake cofunction or `BlockVariable`
checkpoint to which was delegated the checkpointing.
"""
assert idx == 0 # There must be only one RHS.
if isinstance(block_variable.checkpoint, DelegatedFunctionCheckpoint):
return block_variable.checkpoint
else:
output = firedrake.Cofunction(
block_variable.output.function_space()
)
output.assign(inputs[0])
return maybe_disk_checkpoint(output)

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
return adj_inputs[0]

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
block_variable, idx,
relevant_dependencies, prepared=None):
return hessian_inputs[0]

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
prepared=None):
return tlm_inputs[0]

def __str__(self):
deps = self.get_dependencies()
return f"assign({deps[0]})"
19 changes: 15 additions & 4 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,15 @@ def _init_solver_parameters(self, args, kwargs):
self.assemble_kwargs = {}

def __str__(self):
return "solve({} = {})".format(ufl2unicode(self.lhs),
ufl2unicode(self.rhs))
try:
lhs_string = ufl2unicode(self.lhs)
except AttributeError:
lhs_string = str(self.lhs)
try:
rhs_string = ufl2unicode(self.rhs)
except AttributeError:
rhs_string = str(self.rhs)
return "solve({} = {})".format(lhs_string, rhs_string)

def _create_F_form(self):
# Process the equation forms, replacing values with checkpoints,
Expand Down Expand Up @@ -692,7 +699,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, firedrake.Function):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
Expand Down Expand Up @@ -729,7 +736,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdm = replace(dFdm, replace_map)

dFdm = dFdm * adj_sol
if isinstance(dFdm, firedrake.Argument):
# Corner case. Should be fixed more permanently upstream in UFL.
dFdm = ufl.Action(dFdm, adj_sol)
else:
dFdm = dFdm * adj_sol
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)

return dFdm
Expand Down
63 changes: 17 additions & 46 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,55 +221,18 @@ def _ad_create_checkpoint(self):
return self.copy(deepcopy=True)

def _ad_convert_riesz(self, value, options=None):
from firedrake import Function, Cofunction
from firedrake import Function

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
V = options.get("function_space", self.function_space())
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
return Function(V)

if not isinstance(value, (Cofunction, Function)):
raise TypeError("Expected a Cofunction or a Function")

if riesz_representation == "l2":
return Function(V, val=value.dat)

elif riesz_representation in ("L2", "H1"):
if not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")

ret = Function(V)
a = self._define_riesz_map_form(riesz_representation, V)
firedrake.solve(a == value, ret, **solver_options)
return ret

elif callable(riesz_representation):
return riesz_representation(value)

else:
raise ValueError(
"Unknown Riesz representation %s" % riesz_representation)

def _define_riesz_map_form(self, riesz_representation, V):
from firedrake import TrialFunction, TestFunction
return Function(self.function_space())

u = TrialFunction(V)
v = TestFunction(V)
if riesz_representation == "L2":
a = firedrake.inner(u, v)*firedrake.dx

elif riesz_representation == "H1":
a = firedrake.inner(u, v)*firedrake.dx \
+ firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx

else:
raise NotImplementedError(
"Unknown Riesz representation %s" % riesz_representation)
return a
return value.riesz_representation(riesz_map=riesz_representation,
solver_options=solver_options)

@no_annotations
def _ad_convert_type(self, value, options=None):
Expand All @@ -294,9 +257,7 @@ def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint

def _ad_will_add_as_dependency(self):
"""Method called when the object is added as a Block dependency.

"""
"""Method called when the object is added as a Block dependency."""
with checkpoint_init_data():
super()._ad_will_add_as_dependency()

Expand All @@ -305,7 +266,8 @@ def _ad_mul(self, other):
from firedrake import Function

r = Function(self.function_space())
# `self` can be a Cofunction in which case only left multiplication with a scalar is allowed.
# `self` can be a Cofunction in which case only left multiplication
# with a scalar is allowed.
r.assign(other * self)
return r

Expand All @@ -318,7 +280,10 @@ def _ad_add(self, other):
return r

def _ad_dot(self, other, options=None):
from firedrake import assemble
from firedrake import assemble, action, Cofunction

if isinstance(other, Cofunction):
return assemble(action(other, self))

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
Expand Down Expand Up @@ -400,3 +365,9 @@ def _applyBinary(self, f, y):

def __deepcopy__(self, memodict={}):
return self.copy(deepcopy=True)


class CofunctionMixin(FunctionMixin):

def _ad_dot(self, other):
return firedrake.assemble(firedrake.action(self, other))
Loading
Loading