From 9b9edb5a4d138a8dd2bff18fddede114c91153f2 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 1 Aug 2023 10:01:16 +0100 Subject: [PATCH] builtins: Support batched initialize_function --- devito/builtins/initializers.py | 158 +++++++++++++++++++------------- tests/test_builtins.py | 17 ++++ 2 files changed, 112 insertions(+), 63 deletions(-) diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index db0cd640cce..0f1e9ea5c26 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -216,6 +216,66 @@ def fset(f, g): return f +def _initialize_function(function, data, nbl, mapper=None, mode='constant'): + """ + Construct the symbolic objects for `initialize_function`. + """ + nbl, slices = nbl_to_padsize(nbl, function.ndim) + if isinstance(data, dv.Function): + function.data[slices] = data.data[:] + else: + function.data[slices] = data + lhs = [] + rhs = [] + options = [] + + if mode == 'reflect' and function.grid.distributor.is_parallel: + # Check that HALO size is appropriate + halo = function.halo + local_size = function.shape + + def buff(i, j): + return [(i + k - 2*max(max(nbl))) for k in j] + + b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))] + if any(np.array(b) < 0): + raise ValueError("Function `%s` halo is not sufficiently thick." % function) + + for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)): + dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl) + dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr) + if mode == 'constant': + subsl = nl + subsr = d.symbolic_max - nr + elif mode == 'reflect': + subsl = 2*nl - 1 - dim_l + subsr = 2*(d.symbolic_max - nr) + 1 - dim_r + else: + raise ValueError("Mode not available") + lhs.append(function.subs({d: dim_l})) + lhs.append(function.subs({d: dim_r})) + rhs.append(function.subs({d: subsl})) + rhs.append(function.subs({d: subsr})) + options.extend([None, None]) + + if mapper and d in mapper.keys(): + exprs = mapper[d] + lhs_extra = exprs['lhs'] + rhs_extra = exprs['rhs'] + lhs.extend(as_list(lhs_extra)) + rhs.extend(as_list(rhs_extra)) + options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ]) + if isinstance(options_extra, list): + options.extend(options_extra) + else: + options.extend([options_extra]) + + if all(options is None for i in options): + options = None + + return lhs, rhs, options + + def initialize_function(function, data, nbl, mapper=None, mode='constant', name=None, pad_halo=True, **kwargs): """ @@ -225,9 +285,9 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant', Parameters ---------- - function : Function + function : Function or list of Functions The initialised object. - data : ndarray or Function + data : ndarray or Function or list of ndarray/Function The data used for initialisation. nbl : int or tuple of int or tuple of tuple of int Number of outer layers (such as absorbing layers for boundary damping). @@ -286,73 +346,45 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant', [2, 3, 3, 3, 3, 2], [2, 2, 2, 2, 2, 2]], dtype=int32) """ - name = name or 'pad_%s' % function.name - if isinstance(function, dv.TimeFunction): + if isinstance(function, (list, tuple)): + if not isinstance(data, (list, tuple)): + raise TypeError("Expected a list of `data`") + elif len(function) != len(data): + raise ValueError("Expected %d `data` items, got %d" % + (len(function), len(data))) + + if mapper is not None: + raise NotImplementedError("Unsupported `mapper` with batching") + + functions = function + datas = data + else: + functions = (function,) + datas = (data,) + + if any(isinstance(f, dv.TimeFunction) for f in functions): raise NotImplementedError("TimeFunctions are not currently supported.") if nbl == 0: - if isinstance(data, dv.Function): - function.data[:] = data.data[:] - else: - function.data[:] = data[:] - if pad_halo: - pad_outhalo(function) - return - - nbl, slices = nbl_to_padsize(nbl, function.ndim) - if isinstance(data, dv.Function): - function.data[slices] = data.data[:] + for f in functions: + if isinstance(data, dv.Function): + f.data[:] = data.data[:] + else: + f.data[:] = data[:] else: - function.data[slices] = data - lhs = [] - rhs = [] - options = [] - - if mode == 'reflect' and function.grid.distributor.is_parallel: - # Check that HALO size is appropriate - halo = function.halo - local_size = function.shape - - def buff(i, j): - return [(i + k - 2*max(max(nbl))) for k in j] + lhss, rhss, optionss = [], [], [] + for f, data in zip(functions, datas): + lhs, rhs, options = _initialize_function(f, data, nbl, mapper, mode) - b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))] - if any(np.array(b) < 0): - raise ValueError("Function `%s` halo is not sufficiently thick." % function) + lhss.extend(lhs) + rhss.extend(rhs) + optionss.extend(options) - for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)): - dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl) - dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr) - if mode == 'constant': - subsl = nl - subsr = d.symbolic_max - nr - elif mode == 'reflect': - subsl = 2*nl - 1 - dim_l - subsr = 2*(d.symbolic_max - nr) + 1 - dim_r - else: - raise ValueError("Mode not available") - lhs.append(function.subs({d: dim_l})) - lhs.append(function.subs({d: dim_r})) - rhs.append(function.subs({d: subsl})) - rhs.append(function.subs({d: subsr})) - options.extend([None, None]) - - if mapper and d in mapper.keys(): - exprs = mapper[d] - lhs_extra = exprs['lhs'] - rhs_extra = exprs['rhs'] - lhs.extend(as_list(lhs_extra)) - rhs.extend(as_list(rhs_extra)) - options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ]) - if isinstance(options_extra, list): - options.extend(options_extra) - else: - options.extend([options_extra]) - - if all(options is None for i in options): - options = None + assert len(lhss) == len(rhss) == len(optionss) - assign(lhs, rhs, options=options, name=name, **kwargs) + name = name or 'pad_%s' % '_'.join(f.name for f in functions) + assign(lhss, rhss, options=optionss, name=name, **kwargs) if pad_halo: - pad_outhalo(function) + for f in functions: + pad_outhalo(f) diff --git a/tests/test_builtins.py b/tests/test_builtins.py index 0c70d63e868..346095b8162 100644 --- a/tests/test_builtins.py +++ b/tests/test_builtins.py @@ -327,6 +327,23 @@ def test_if_halo_mpi(self, nbl): expected = np.pad(a[na//2:, na//2:], [(0, 1+nbl), (0, 1+nbl)], 'edge') assert np.all(f._data_with_outhalo._local == expected) + def test_batching(self): + grid = Grid(shape=(12, 12)) + + a = np.arange(16).reshape((4, 4)) + + f = Function(name='f', grid=grid, dtype=np.int32) + g = Function(name='g', grid=grid, dtype=np.int32) + h = Function(name='h', grid=grid, dtype=np.int32) + + initialize_function([f, g, h], [a, a, a], 4, mode='reflect') + + for i in [f, g, h]: + assert np.all(a[:, ::-1] - np.array(i.data[4:8, 0:4]) == 0) + assert np.all(a[:, ::-1] - np.array(i.data[4:8, 8:12]) == 0) + assert np.all(a[::-1, :] - np.array(i.data[0:4, 4:8]) == 0) + assert np.all(a[::-1, :] - np.array(i.data[8:12, 4:8]) == 0) + class TestBuiltinsResult(object):