Skip to content

Commit

Permalink
api: revamp cross derivatives to always be nested derivatives with sh…
Browse files Browse the repository at this point in the history
…ortcuts
  • Loading branch information
mloubout committed Sep 24, 2024
1 parent fb4b50e commit d6d6759
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
15 changes: 7 additions & 8 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ def __new__(cls, expr, *dims, **kwargs):
obj = Differentiable.__new__(cls, expr, *var_count)
obj._dims = tuple(OrderedDict.fromkeys(new_dims))

skip = kwargs.get('preprocessed', False) or obj.ndims == 1

obj._fd_order = fd_o if skip else DimensionTuple(*fd_o, getters=obj._dims)
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims)
obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims)
obj._deriv_order = DimensionTuple(*as_tuple(orders), getters=obj._dims)
obj._side = kwargs.get("side")
obj._transpose = kwargs.get("transpose", direct)
obj._method = kwargs.get("method", 'FD')
Expand Down Expand Up @@ -137,7 +135,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
fd_orders = kwargs.get('fd_order')
deriv_orders = kwargs.get('deriv_order')
if len(dims) == 1:
dims = tuple([dims[0]]*max(1, deriv_orders))
dims = tuple([dims[0]]*max(1, deriv_orders[0]))
variable_count = [sympy.Tuple(s, dims.count(s))
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count
Expand Down Expand Up @@ -293,8 +291,9 @@ def _xreplace(self, subs):
except AttributeError:
return new, True

new_expr = self.expr.xreplace(subs)
subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs), True
return self._rebuild(subs=subs, expr=new_expr), True

@cached_property
def _metadata(self):
Expand Down Expand Up @@ -455,8 +454,8 @@ def _eval_fd(self, expr, **kwargs):
side=self.side)
else:
assert self.method == 'FD'
res = generic_derivative(expr, self.dims[0], as_tuple(self.fd_order)[0],
self.deriv_order, weights=self.weights,
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
self.deriv_order[0], weights=self.weights,
side=self.side, matvec=self.transpose,
x0=self.x0, expand=expand)

Expand Down
8 changes: 6 additions & 2 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def coefficients(self):
coefficients = {f.coefficients for f in self._functions}
# If there is multiple ones, we have to revert to the highest priority
# i.e we have to remove symbolic
key = lambda x: coeff_priority[x]
key = lambda x: coeff_priority.get(x, -1)
return sorted(coefficients, key=key, reverse=True)[0]

@cached_property
Expand Down Expand Up @@ -427,7 +427,11 @@ def has_free(self, *patterns):


def highest_priority(DiffOp):
prio = lambda x: getattr(x, '_fd_priority', 0)
# We want to get the object with highest priority
# We also need to make sure that the object with the largest
# set of dimensions is used when multiple ones with the same
# priority appear
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


Expand Down
16 changes: 14 additions & 2 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,18 @@ def generate_fd_shortcuts(dims, so, to=0):
from devito.finite_differences.derivative import Derivative

def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order,
fd_order=fd_order, side=side, **kwargs)
# Spearate dimension to always have cross derivatives return nested
# derivatives.
# Reverse to match the syntax `u.dxdy = (u.dx).dy` with x the inner
# derivative
dims = as_tuple(dims)[::-1]
deriv_order = as_tuple(deriv_order)[::-1]
fd_order = as_tuple(fd_order)[::-1]
deriv = Derivative(expr, dims[0], deriv_order=deriv_order[0],
fd_order=fd_order[0], side=side, **kwargs)
for (d, do, fo) in zip(dims[1:], deriv_order[1:], fd_order[1:]):
deriv = Derivative(deriv, d, deriv_order=do, fd_order=fo, side=side, **kwargs)
return deriv

all_combs = dim_with_order(dims, orders)

Expand Down Expand Up @@ -318,6 +328,8 @@ def process_weights(weights, expr):
if weights is None:
return 0, None
elif isinstance(weights, Function):
if len(weights.dimensions) == 1:
return weights.shape[0], weights.dimensions[0]
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
assert len(wdim) == 1
wdim = wdim.pop()
Expand Down
3 changes: 2 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ def _apply_coeffs(cls, expr, coefficients):
for coeff in coefficients.coefficients:
derivs = [d for d in retrieve_derivatives(expr)
if coeff.dimension in d.dims and
coeff.deriv_order == d.deriv_order]
coeff.deriv_order == d.deriv_order.get(coeff.dimension, None)]
if not derivs:
continue
mapper.update({d: d._rebuild(weights=coeff.weights) for d in derivs})
if not mapper:
return expr

return expr.xreplace(mapper)

def _evaluate(self, **kwargs):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,30 @@ def test_function_coefficients(self):

assert np.all(np.isclose(f0.data[:] - f1.data[:], 0.0, atol=1e-5, rtol=0))

def test_function_coefficients_xderiv(self):
p = Dimension('p')

nstc = 8

grid = Grid(shape=(51, 51, 51))
x, y, z = grid.dimensions

f = Function(name='f', grid=grid, space_order=(2*nstc, 0, 0),
coefficients='symbolic')
g = Function(name='g', grid=grid, space_order=(2*nstc, 0, 0))
ax = Function(name='DD2x', space_order=0, shape=(2*nstc + 1,),
dimensions=(p,))
ay = Function(name='DD2y', space_order=0, shape=(2*nstc + 1,),
dimensions=(p,))
stencil_coeffs_x_p1 = Coefficient(1, f, x, ax)
stencil_coeffs_y_p1 = Coefficient(1, f, y, ay)
stencil_coeffs = Substitutions(stencil_coeffs_x_p1, stencil_coeffs_y_p1)

eqn = Eq(g, f.dxdy, coefficients=stencil_coeffs)

op = Operator(eqn)
op()

def test_coefficients_w_xreplace(self):
"""Test custom coefficients with an xreplace before they are applied"""
grid = Grid(shape=(4, 4))
Expand Down

0 comments on commit d6d6759

Please sign in to comment.