Skip to content

Commit

Permalink
api: cleanup hierachy and properties of sparse and interpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 23, 2023
1 parent 2d7778d commit a945ed4
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 222 deletions.
47 changes: 22 additions & 25 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,37 @@ def interpolate(self, *args, **kwargs):
pass


class WeightedInterpolation(GenericInterpolator):
class WeightedInterpolator(GenericInterpolator):

"""
Represent an Interpolation operation on a SparseFunction that is separable
in space, meaning hte coefficient are defined for each Dimension separately
in space, meaning the coefficients are defined for each Dimension separately
and multiplied at a given point: `w[x, y] = wx[x] * wy[y]`
"""

def __init__(self, sfunction):
self.sfunction = sfunction

@property
@cached_property
def grid(self):
return self.sfunction.grid

@property
def _weights(self):
raise NotImplementedError

@property
@cached_property
def _psym(self):
return self.sfunction._point_symbols

@property
@cached_property
def _gdim(self):
return self.grid.dimensions

def implicit_dims(self, implicit_dims):
return as_tuple(implicit_dims) + self.sfunction.dimensions

@property
@cached_property
def r(self):
return self.sfunction.r

Expand Down Expand Up @@ -313,18 +313,17 @@ def callback():
return Injection(field, expr, offset, self, callback)


class LinearInterpolator(WeightedInterpolation):

class LinearInterpolator(WeightedInterpolator):
"""
Concrete implementation of GenericInterpolator implementing a Linear interpolation
Concrete implementation of WeightedInterpolator implementing a Linear interpolation
scheme, i.e. Bilinear for 2D and Trilinear for 3D problems.
Parameters
----------
sfunction: The SparseFunction that this Interpolator operates on.
"""

@cached_property
@property
def _weights(self):
return {d: [1 - p/d.spacing, p/d.spacing]
for (d, p) in zip(self._gdim, self._psym)}
Expand All @@ -336,7 +335,16 @@ def _coeff_temps(self, implicit_dims):
for (d, pos) in zip(self._gdim, pmap)]


class PrecomputedInterpolator(WeightedInterpolation):
class PrecomputedInterpolator(WeightedInterpolator):
"""
Concrete implementation of WeightedInterpolator implementing a Precomputed
interpolation scheme, i.e. an interpolation with user provided precomputed
weigths/coefficients.
Parameters
----------
sfunction: The SparseFunction that this Interpolator operates on.
"""

def _positions(self, implicit_dims):
if self.sfunction.gridpoints is None:
Expand All @@ -346,23 +354,12 @@ def _positions(self, implicit_dims):
return []

@property
def _interp_points(self):
return range(-self.r//2 + 1, self.r//2 + 1)

@property
def _icoeffs(self):
def interpolation_coeffs(self):
return self.sfunction.interpolation_coeffs

@property
def _idim(self):
return self.sfunction.interpolation_coeffs.dimensions[-1]

@property
def _ddim(self):
return self.sfunction.interpolation_coeffs.dimensions[1]

@cached_property
def _weights(self):
return {d: [self._icoeffs.subs({self._ddim: di, self._idim: k})
ddim, cdim = self.interpolation_coeffs.dimensions[1:]
return {d: [self.interpolation_coeffs.subs({ddim: di, cdim: k})
for k in self._interp_points]
for (di, d) in enumerate(self._gdim)}
7 changes: 6 additions & 1 deletion devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cgen import dtype_to_ctype as cgen_dtype_to_ctype

__all__ = ['int2', 'int3', 'int4', 'float2', 'float3', 'float4', 'double2', # noqa
'double3', 'double4', 'dtypes_vector_mapper',
'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype',
'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len',
'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper',
'is_external_ctype', 'infer_dtype']
Expand Down Expand Up @@ -128,6 +128,11 @@ def dtype_to_mpitype(dtype):
}[dtype]


def dtype_to_mpidtype(dtype):
from devito.mpi import MPI
return MPI._typedict[np.dtype(dtype).char]


def dtype_len(dtype):
"""
Number of elements associated with one object of type `dtype`. Thus,
Expand Down
Loading

0 comments on commit a945ed4

Please sign in to comment.