diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 66fffa5256..c8d5f3d0da 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -103,10 +103,8 @@ def __new__(cls, expr, *dims, **kwargs): obj = Differentiable.__new__(cls, expr, *var_count) obj._dims = tuple(OrderedDict.fromkeys(new_dims)) - skip = kwargs.get('preprocessed', False) or obj.ndims == 1 - - obj._fd_order = fd_o if skip else DimensionTuple(*fd_o, getters=obj._dims) - obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims) + obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims) + obj._deriv_order = DimensionTuple(*as_tuple(orders), getters=obj._dims) obj._side = kwargs.get("side") obj._transpose = kwargs.get("transpose", direct) obj._method = kwargs.get("method", 'FD') @@ -137,7 +135,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs): fd_orders = kwargs.get('fd_order') deriv_orders = kwargs.get('deriv_order') if len(dims) == 1: - dims = tuple([dims[0]]*max(1, deriv_orders)) + dims = tuple([dims[0]]*max(1, deriv_orders[0])) variable_count = [sympy.Tuple(s, dims.count(s)) for s in filter_ordered(dims)] return dims, deriv_orders, fd_orders, variable_count @@ -293,8 +291,9 @@ def _xreplace(self, subs): except AttributeError: return new, True + new_expr = self.expr.xreplace(subs) subs = self._ppsubs + (subs,) # Postponed substitutions - return self._rebuild(subs=subs), True + return self._rebuild(subs=subs, expr=new_expr), True @cached_property def _metadata(self): @@ -455,8 +454,8 @@ def _eval_fd(self, expr, **kwargs): side=self.side) else: assert self.method == 'FD' - res = generic_derivative(expr, self.dims[0], as_tuple(self.fd_order)[0], - self.deriv_order, weights=self.weights, + res = generic_derivative(expr, self.dims[0], self.fd_order[0], + self.deriv_order[0], weights=self.weights, side=self.side, matvec=self.transpose, x0=self.x0, expand=expand) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 9a28dfde07..f78abbd7d5 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -141,7 +141,7 @@ def coefficients(self): coefficients = {f.coefficients for f in self._functions} # If there is multiple ones, we have to revert to the highest priority # i.e we have to remove symbolic - key = lambda x: coeff_priority[x] + key = lambda x: coeff_priority.get(x, -1) return sorted(coefficients, key=key, reverse=True)[0] @cached_property @@ -427,7 +427,11 @@ def has_free(self, *patterns): def highest_priority(DiffOp): - prio = lambda x: getattr(x, '_fd_priority', 0) + # We want to get the object with highest priority + # We also need to make sure that the object with the largest + # set of dimensions is used when multiple ones with the same + # priority appear + prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions)) return sorted(DiffOp._args_diff, key=prio, reverse=True)[0] diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 31538e0e3f..8651ea15f7 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -86,8 +86,18 @@ def generate_fd_shortcuts(dims, so, to=0): from devito.finite_differences.derivative import Derivative def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs): - return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order, - fd_order=fd_order, side=side, **kwargs) + # Spearate dimension to always have cross derivatives return nested + # derivatives. + # Reverse to match the syntax `u.dxdy = (u.dx).dy` with x the inner + # derivative + dims = as_tuple(dims)[::-1] + deriv_order = as_tuple(deriv_order)[::-1] + fd_order = as_tuple(fd_order)[::-1] + deriv = Derivative(expr, dims[0], deriv_order=deriv_order[0], + fd_order=fd_order[0], side=side, **kwargs) + for (d, do, fo) in zip(dims[1:], deriv_order[1:], fd_order[1:]): + deriv = Derivative(deriv, d, deriv_order=do, fd_order=fo, side=side, **kwargs) + return deriv all_combs = dim_with_order(dims, orders) @@ -318,6 +328,8 @@ def process_weights(weights, expr): if weights is None: return 0, None elif isinstance(weights, Function): + if len(weights.dimensions) == 1: + return weights.shape[0], weights.dimensions[0] wdim = {d for d in weights.dimensions if d not in expr.dimensions} assert len(wdim) == 1 wdim = wdim.pop() diff --git a/devito/types/equation.py b/devito/types/equation.py index 546c2c4e96..662cdd0d34 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -90,12 +90,13 @@ def _apply_coeffs(cls, expr, coefficients): for coeff in coefficients.coefficients: derivs = [d for d in retrieve_derivatives(expr) if coeff.dimension in d.dims and - coeff.deriv_order == d.deriv_order] + coeff.deriv_order == d.deriv_order.get(coeff.dimension, None)] if not derivs: continue mapper.update({d: d._rebuild(weights=coeff.weights) for d in derivs}) if not mapper: return expr + return expr.xreplace(mapper) def _evaluate(self, **kwargs): diff --git a/tests/test_symbolic_coefficients.py b/tests/test_symbolic_coefficients.py index 0c520549f6..30a89b155b 100644 --- a/tests/test_symbolic_coefficients.py +++ b/tests/test_symbolic_coefficients.py @@ -80,6 +80,30 @@ def test_function_coefficients(self): assert np.all(np.isclose(f0.data[:] - f1.data[:], 0.0, atol=1e-5, rtol=0)) + def test_function_coefficients_xderiv(self): + p = Dimension('p') + + nstc = 8 + + grid = Grid(shape=(51, 51, 51)) + x, y, z = grid.dimensions + + f = Function(name='f', grid=grid, space_order=(2*nstc, 0, 0), + coefficients='symbolic') + g = Function(name='g', grid=grid, space_order=(2*nstc, 0, 0)) + ax = Function(name='DD2x', space_order=0, shape=(2*nstc + 1,), + dimensions=(p,)) + ay = Function(name='DD2y', space_order=0, shape=(2*nstc + 1,), + dimensions=(p,)) + stencil_coeffs_x_p1 = Coefficient(1, f, x, ax) + stencil_coeffs_y_p1 = Coefficient(1, f, y, ay) + stencil_coeffs = Substitutions(stencil_coeffs_x_p1, stencil_coeffs_y_p1) + + eqn = Eq(g, f.dxdy, coefficients=stencil_coeffs) + + op = Operator(eqn) + op() + def test_coefficients_w_xreplace(self): """Test custom coefficients with an xreplace before they are applied""" grid = Grid(shape=(4, 4))