From ef76cc4d5f38ec3cb5ef0ef81ca301bf28024329 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 25 Sep 2024 09:58:54 -0400 Subject: [PATCH] misc: cleanup test and fix comments --- devito/finite_differences/derivative.py | 46 ++++++++++++++----------- devito/finite_differences/tools.py | 4 +-- devito/ir/equations/algorithms.py | 4 +-- tests/test_derivatives.py | 5 +++ tests/test_tensors.py | 2 +- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index b93e3ac94c..41e60552b9 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -235,22 +235,19 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None) except AttributeError: raise TypeError("fd_order incompatible with dimensions") - # In case this was called on a cross derivative we need to propagate - # the call to the nested derivative if isinstance(self.expr, Derivative): - _fd_orders = {k: v for k, v in _fd_order.items() if k in self.expr.dims} - _x0s = {k: v for k, v in _x0.items() if k in self.expr.dims and - k not in self.dims} - new_expr = self.expr(x0=_x0s, fd_order=_fd_orders, side=side, - method=method, weights=weights) + # In case this was called on a perfect cross-derivative `u.dxdy` + # we need to propagate the call to the nested derivative + x0s = self.filter_dims(self.expr.filter_dims(_x0), neg=True) + expr = self.expr(x0=x0s, fd_order=self.expr.filter_dims(_fd_order), + side=side, method=method) else: - new_expr = self.expr + expr = self.expr - _fd_order = tuple(v for k, v in _fd_order.items() if k in self.dims) - _fd_order = DimensionTuple(*_fd_order, getters=self.dims) + _fd_order = self.filter_dims(_fd_order, as_tuple=True) return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method, - weights=weights, expr=new_expr) + weights=weights, expr=expr) def _rebuild(self, *args, **kwargs): kwargs['preprocessed'] = True @@ -305,10 +302,10 @@ def _xreplace(self, subs): # Resolve nested derivatives dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)} - new_expr = self.expr.xreplace(dsubs) + expr = self.expr.xreplace(dsubs) subs = self._ppsubs + (subs,) # Postponed substitutions - return self._rebuild(subs=subs, expr=new_expr), True + return self._rebuild(subs=subs, expr=expr), True @cached_property def _metadata(self): @@ -316,6 +313,19 @@ def _metadata(self): ret.append(self.expr.staggered or (None,)) return tuple(ret) + def filter_dims(self, col, as_tuple=False, neg=False): + """ + Filter collectiion to only keep the derivative's dimensions as keys. + """ + if neg: + filtered = {k: v for k, v in col.items() if k not in self.dims} + else: + filtered = {k: v for k, v in col.items() if k in self.dims} + if as_tuple: + return DimensionTuple(*filtered.values(), getters=self.dims) + else: + return filtered + @property def dims(self): return self._dims @@ -436,13 +446,9 @@ def _eval_fd(self, expr, **kwargs): """ # Step 1: Evaluate non-derivative x0. We currently enforce a simple 2nd order # interpolation to avoid very expensive finite differences on top of it - x0_interp = {} - x0_deriv = {} - for d, v in self.x0.items(): - if d in self.dims: - x0_deriv[d] = v - elif not d.is_Time: - x0_interp[d] = v + x0_deriv = self.filter_dims(self.x0) + x0_interp = {d: v for d, v in self.x0.items() + if d not in x0_deriv and not d.is_Time} if x0_interp and self.method == 'FD': expr = interp_for_fd(expr, x0_interp, **kwargs) diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index b9afd8ca02..3ec12cd5bc 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -86,8 +86,8 @@ def generate_fd_shortcuts(dims, so, to=0): from devito.finite_differences.derivative import Derivative def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs): - # Spearate dimension to always have cross derivatives return nested - # derivatives. + # Separate dimensions to always have cross derivatives return nested + # derivatives. E.g `u.dxdy -> u.dx.dy` dims = as_tuple(dims) deriv_order = as_tuple(deriv_order) fd_order = as_tuple(fd_order) diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index 85c4a67646..8337b5adf0 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -138,9 +138,9 @@ def _lower_exprs(expressions, subs): # Handle Array if isinstance(f, Array) and f.initvalue is not None: - initv = [_lower_exprs(i, subs) for i in f.initvalue] + initvalue = [_lower_exprs(i, subs) for i in f.initvalue] # TODO: fix rebuild to avoid new name - f = f._rebuild(name='%si' % f.name, initvalue=initv) + f = f._rebuild(name='%si' % f.name, initvalue=initvalue) mapper[i] = f.indexed[indices] # Add dimensions map to the mapper in case dimensions are used diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index f02ee59407..30893b2afb 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -404,6 +404,11 @@ def test_xderiv_x0(self): - f.dx(x0=x+h_x/2).dy(x0=y+h_y/2).evaluate assert simplify(expr) == 0 + # Check x0 is correctly set + dfdxdx = f.dx(x0=x+h_x/2).dx(x0=x-h_x/2) + assert dict(dfdxdx.x0) == {x: x-h_x/2} + assert dict(dfdxdx.expr.x0) == {x: x+h_x/2} + def test_fd_new_side(self): grid = Grid((10,)) u = Function(name="u", grid=grid, space_order=4) diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 30ac7e578f..15e18ababd 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -353,7 +353,7 @@ def test_shifted_curl_of_vector(shift, ndim): dorder = order or 4 for drv in drvs: assert drv.expr in f - assert drv.fd_order == dorder + assert drv.fd_order == (dorder,) if shift is None: assert drv.x0 == {} else: