Skip to content

Commit

Permalink
api: fix pickle with derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 13, 2024
1 parent 7c7994c commit 958720a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
4 changes: 2 additions & 2 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from .tools import direct, transpose
from .rsfd import d45
from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer,
Reconstructable)
Pickable)
from devito.types.utils import DimensionTuple

__all__ = ['Derivative']


class Derivative(sympy.Derivative, Differentiable, Reconstructable):
class Derivative(sympy.Derivative, Differentiable, Pickable):

"""
An unevaluated Derivative, which carries metadata (Dimensions,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
PrecomputedSparseTimeFunction)
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
from devito.data import LEFT, OWNED
from devito.finite_differences.tools import direct, transpose, left, right, centered
from devito.mpi.halo_scheme import Halo
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
MPIRegion)
Expand Down Expand Up @@ -504,6 +505,33 @@ def test_receiver(self, pickle):
assert np.all(new_rec.data == 1)
assert np.all(new_rec.coordinates.data == [[0.], [1.], [2.]])

@pytest.mark.parametrize('transpose', [direct, transpose])
@pytest.mark.parametrize('side', [left, right, centered])
@pytest.mark.parametrize('deriv_order', [1, 2])
@pytest.mark.parametrize('fd_order', [2, 4])
@pytest.mark.parametrize('x0', ["{}", "{x: x + x.spacing/2}"])
@pytest.mark.parametrize('method', ['FD', 'RSFD'])
@pytest.mark.parametrize('weights', [None, [1., 2., 3.]])
def test_derivative(self, pickle, transpose, side, deriv_order,
fd_order, x0, method, weights):
grid = Grid(shape=(3,))
x = grid.dimensions[0]
x0 = eval(x0)
f = Function(name='f', grid=grid, space_order=2)
dfdx = f.diff(x, order=deriv_order, fd_order=fd_order, side=side,
x0=x0, method=method, weights=weights)

pkl_dfdx = pickle.dumps(dfdx)
new_dfdx = pickle.loads(pkl_dfdx)

assert new_dfdx.dims == dfdx.dims
assert new_dfdx.side == dfdx.side
assert new_dfdx.fd_order == dfdx.fd_order
assert new_dfdx.deriv_order == dfdx.deriv_order
assert new_dfdx.x0 == dfdx.x0
assert new_dfdx.method == dfdx.method
assert new_dfdx.weights == dfdx.weights


class TestAdvanced:

Expand Down

0 comments on commit 958720a

Please sign in to comment.