builtins: Support batched initialize_function
FabioLuporini committed Aug 3, 2023
1 parent d87430f commit 9b9edb5
Showing 2 changed files with 112 additions and 63 deletions.
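For context, a minimal sketch of the batched usage this commit enables, modelled on the new test below (assuming `initialize_function` is importable from the package top level, as elsewhere in the test suite; the names `f` and `g` are illustrative):

import numpy as np
from devito import Grid, Function, initialize_function

grid = Grid(shape=(12, 12))
a = np.arange(16, dtype=np.int32).reshape((4, 4))

f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.int32)

# One call now pads both Functions at once, rather than requiring a
# separate call (and a separate Operator) per Function
initialize_function([f, g], [a, a], 4, mode='reflect')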
158 changes: 95 additions & 63 deletions devito/builtins/initializers.py
@@ -216,6 +216,66 @@ def fset(f, g):
        return f


def _initialize_function(function, data, nbl, mapper=None, mode='constant'):
    """
    Construct the symbolic objects for `initialize_function`.
    """
    nbl, slices = nbl_to_padsize(nbl, function.ndim)
    if isinstance(data, dv.Function):
        function.data[slices] = data.data[:]
    else:
        function.data[slices] = data
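    # `slices` selects the interior region; the `nbl` outer layers are
    # filled below via symbolic assignments over SubDimensions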
    lhs = []
    rhs = []
    options = []

    if mode == 'reflect' and function.grid.distributor.is_parallel:
        # Check that HALO size is appropriate
        halo = function.halo
        local_size = function.shape

        def buff(i, j):
            return [(i + k - 2*max(max(nbl))) for k in j]

        b = [min(buff(i, j)) for i, j in zip(local_size, halo)]
        if any(np.array(b) < 0):
            raise ValueError("Function `%s` halo is not sufficiently thick." % function)

    for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
        dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
        dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
        if mode == 'constant':
            subsl = nl
            subsr = d.symbolic_max - nr
        elif mode == 'reflect':
            subsl = 2*nl - 1 - dim_l
            subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
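            # Mirror about the inner boundary: with nl = 4, for instance,
            # left layers 0..3 are filled from interior points 7..4
            # (values illustrative, not from the source)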
        else:
            raise ValueError("Mode not available")
        lhs.append(function.subs({d: dim_l}))
        lhs.append(function.subs({d: dim_r}))
        rhs.append(function.subs({d: subsl}))
        rhs.append(function.subs({d: subsr}))
        options.extend([None, None])

        if mapper and d in mapper.keys():
            exprs = mapper[d]
            lhs_extra = exprs['lhs']
            rhs_extra = exprs['rhs']
            lhs.extend(as_list(lhs_extra))
            rhs.extend(as_list(rhs_extra))
            options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
            if isinstance(options_extra, list):
                options.extend(options_extra)
            else:
                options.extend([options_extra])

    if all(i is None for i in options):
        options = None

    return lhs, rhs, options


def initialize_function(function, data, nbl, mapper=None, mode='constant',
                        name=None, pad_halo=True, **kwargs):
    """
@@ -225,9 +285,9 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
    Parameters
    ----------
    function : Function
    function : Function or list of Functions
        The initialised object.
    data : ndarray or Function
    data : ndarray or Function or list of ndarray/Function
        The data used for initialisation.
    nbl : int or tuple of int or tuple of tuple of int
        Number of outer layers (such as absorbing layers for boundary damping).
@@ -286,73 +346,45 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
           [2, 3, 3, 3, 3, 2],
           [2, 2, 2, 2, 2, 2]], dtype=int32)
    """
    name = name or 'pad_%s' % function.name
    if isinstance(function, dv.TimeFunction):
    if isinstance(function, (list, tuple)):
        if not isinstance(data, (list, tuple)):
            raise TypeError("Expected a list of `data`")
        elif len(function) != len(data):
            raise ValueError("Expected %d `data` items, got %d" %
                             (len(function), len(data)))

        if mapper is not None:
            raise NotImplementedError("Unsupported `mapper` with batching")

        functions = function
        datas = data
    else:
        functions = (function,)
        datas = (data,)

    if any(isinstance(f, dv.TimeFunction) for f in functions):
        raise NotImplementedError("TimeFunctions are not currently supported.")

    if nbl == 0:
        if isinstance(data, dv.Function):
            function.data[:] = data.data[:]
        else:
            function.data[:] = data[:]
        if pad_halo:
            pad_outhalo(function)
        return

    nbl, slices = nbl_to_padsize(nbl, function.ndim)
    if isinstance(data, dv.Function):
        function.data[slices] = data.data[:]
        for f, data in zip(functions, datas):
            if isinstance(data, dv.Function):
                f.data[:] = data.data[:]
            else:
                f.data[:] = data[:]
    else:
        function.data[slices] = data
    lhs = []
    rhs = []
    options = []

    if mode == 'reflect' and function.grid.distributor.is_parallel:
        # Check that HALO size is appropriate
        halo = function.halo
        local_size = function.shape

        def buff(i, j):
            return [(i + k - 2*max(max(nbl))) for k in j]
    lhss, rhss, optionss = [], [], []
    for f, data in zip(functions, datas):
        lhs, rhs, options = _initialize_function(f, data, nbl, mapper, mode)

        b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))]
        if any(np.array(b) < 0):
            raise ValueError("Function `%s` halo is not sufficiently thick." % function)
        lhss.extend(lhs)
        rhss.extend(rhs)
        optionss.extend(options)

    for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
        dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
        dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
        if mode == 'constant':
            subsl = nl
            subsr = d.symbolic_max - nr
        elif mode == 'reflect':
            subsl = 2*nl - 1 - dim_l
            subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
        else:
            raise ValueError("Mode not available")
        lhs.append(function.subs({d: dim_l}))
        lhs.append(function.subs({d: dim_r}))
        rhs.append(function.subs({d: subsl}))
        rhs.append(function.subs({d: subsr}))
        options.extend([None, None])

        if mapper and d in mapper.keys():
            exprs = mapper[d]
            lhs_extra = exprs['lhs']
            rhs_extra = exprs['rhs']
            lhs.extend(as_list(lhs_extra))
            rhs.extend(as_list(rhs_extra))
            options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
            if isinstance(options_extra, list):
                options.extend(options_extra)
            else:
                options.extend([options_extra])

    if all(options is None for i in options):
        options = None
    assert len(lhss) == len(rhss) == len(optionss)
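    # The fused lists are dispatched to a single `assign` call below, so one
    # Operator initializes the whole batch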

    assign(lhs, rhs, options=options, name=name, **kwargs)
    name = name or 'pad_%s' % '_'.join(f.name for f in functions)
    assign(lhss, rhss, options=optionss, name=name, **kwargs)

    if pad_halo:
        pad_outhalo(function)
        for f in functions:
            pad_outhalo(f)
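As a sanity check on the `reflect` arithmetic above (`subsl = 2*nl - 1 - dim_l`), the left layers receive a mirror image of the first `nl` interior points. A plain-numpy illustration of the intended effect, not of the generated Operator; the single-axis setup and array values are illustrative:

import numpy as np

nl = 4
data = np.arange(16).reshape((4, 4))

# Pad along axis 0 only, then fill the left layers by reflection:
# padded[i] = padded[2*nl - 1 - i] for i in 0..nl-1, matching subsl
padded = np.zeros((4 + nl, 4), dtype=data.dtype)
padded[nl:] = data
for i in range(nl):
    padded[i] = padded[2*nl - 1 - i]

assert np.all(padded[:nl] == data[::-1, :])  # mirror image, as in test_batching below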
17 changes: 17 additions & 0 deletions tests/test_builtins.py
@@ -327,6 +327,23 @@ def test_if_halo_mpi(self, nbl):
        expected = np.pad(a[na//2:, na//2:], [(0, 1+nbl), (0, 1+nbl)], 'edge')
        assert np.all(f._data_with_outhalo._local == expected)

    def test_batching(self):
        grid = Grid(shape=(12, 12))

        a = np.arange(16).reshape((4, 4))

        f = Function(name='f', grid=grid, dtype=np.int32)
        g = Function(name='g', grid=grid, dtype=np.int32)
        h = Function(name='h', grid=grid, dtype=np.int32)

        initialize_function([f, g, h], [a, a, a], 4, mode='reflect')

        for i in [f, g, h]:
            assert np.all(a[:, ::-1] - np.array(i.data[4:8, 0:4]) == 0)
            assert np.all(a[:, ::-1] - np.array(i.data[4:8, 8:12]) == 0)
            assert np.all(a[::-1, :] - np.array(i.data[0:4, 4:8]) == 0)
            assert np.all(a[::-1, :] - np.array(i.data[8:12, 4:8]) == 0)


class TestBuiltinsResult(object):

