Skip to content

Commit

Permalink
mpi: drop halospots with empty iters
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 17, 2023
1 parent d1d0369 commit c34ed8e
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 21 deletions.
5 changes: 2 additions & 3 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class DeviceAccizer(PragmaDeviceAwareTransformer):

lang = AccBB

def _make_partree(self, candidates, nthreads=None, index=0):
def _make_partree(self, candidates, nthreads=None):
assert candidates

root, collapsable = self._select_candidates(candidates)
Expand All @@ -164,8 +164,7 @@ def _make_partree(self, candidates, nthreads=None, index=0):
if self._is_offloadable(root) and \
all(i.is_Affine for i in [root] + collapsable) and \
self.par_tile:
idx = min(index, len(self.par_tile) - 1)
tile = self.par_tile[idx]
tile = self.par_tile.next()
assert isinstance(tile, tuple)
nremainder = (ncollapsable + 1) - len(tile)
if nremainder >= 0:
Expand Down
7 changes: 7 additions & 0 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def _drop_halospots(iet):
if f in hs.fmapper and all(i.is_reduction for i in v):
mapper[hs].add(f)

# If a HaloSpot is outside any iteration it is not needed
for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
if not iters and halo_spots:
for hs in halo_spots:
for f in hs.fmapper:
mapper[hs].add(f)

# Transform the IET introducing the "reduced" HaloSpots
subs = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
for hs in FindNodes(HaloSpot).visit(iet)}
Expand Down
9 changes: 5 additions & 4 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
make_sections_from_imask)
from devito.symbolics import INT, ccode
from devito.tools import as_tuple, flatten, is_integer, prod
from devito.tools.data_structures import UnboundTuple
from devito.types import Symbol

__all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer',
Expand Down Expand Up @@ -347,7 +348,7 @@ def _make_threaded_prodders(self, partree):
partree = Transformer(mapper).visit(partree)
return partree

def _make_partree(self, candidates, nthreads=None, index=None):
def _make_partree(self, candidates, nthreads=None):
assert candidates

# Get the collapsable Iterations
Expand Down Expand Up @@ -465,7 +466,7 @@ def _make_nested_partree(self, partree):
def _make_parallel(self, iet):
mapper = {}
parrays = {}
for i, tree in enumerate(retrieve_iteration_tree(iet, mode='superset')):
for tree in retrieve_iteration_tree(iet, mode='superset'):
# Get the parallelizable Iterations in `tree`
candidates = filter_iterations(tree, key=self.key)
if not candidates:
Expand All @@ -477,7 +478,7 @@ def _make_parallel(self, iet):
continue

# Outer parallelism
root, partree = self._make_partree(candidates, index=i)
root, partree = self._make_partree(candidates)
if partree is None or root in mapper:
continue

Expand Down Expand Up @@ -566,7 +567,7 @@ def __init__(self, sregistry, options, platform, compiler):
super().__init__(sregistry, options, platform, compiler)

self.gpu_fit = options['gpu-fit']
self.par_tile = options['par-tile']
self.par_tile = UnboundTuple(options['par-tile'])
self.par_disabled = options['par-disabled']

def _make_threaded_prodders(self, partree):
Expand Down
14 changes: 9 additions & 5 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import singledispatch

import numpy as np
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple
from sympy import (Function, Indexed, Integer, Mul, Number,
Pow, S, Symbol, Tuple)
from sympy.core.operations import AssocOp

from devito.finite_differences import Derivative
from devito.finite_differences.differentiable import IndexDerivative
Expand Down Expand Up @@ -268,13 +270,15 @@ def sympy_dtype(expr, default):
Try to infer the data type of the expression.
Returns the default if none is found.
"""
args = expr.args
# Symbol/... without argument, check its dtype
if len(args) == 0:
if len(expr.args) == 0:
try:
return expr.dtype
except AttributeError:
return default
else:
# Infer expression dtype from its arguments
return infer_dtype([sympy_dtype(a, default) for a in expr.args])
if not (isinstance(expr.func, AssocOp) or expr.is_Pow):
return default
else:
# Infer expression dtype from its arguments
return infer_dtype([sympy_dtype(a, default) for a in expr.args])
18 changes: 18 additions & 0 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,21 @@ def next(self):
if self.curiter is None:
raise StopIteration
return next(self.curiter)


class UnboundTuple(object):
    """
    Sequence cursor whose `next()` hands out the stored items one by one and,
    once the final item has been reached, keeps returning that final item on
    every later call.

    The constructor argument is normalized through `as_tuple`; `len()` reports
    the number of stored items.
    """

    def __init__(self, items):
        self.items = as_tuple(items)
        self.last = len(self.items)
        self.current = 0

    def __len__(self):
        return self.last

    def next(self):
        # Read first: if the cursor is out of range (e.g. nothing is stored)
        # the IndexError is raised here, before the cursor is modified
        value = self.items[self.current]
        # Step forward unless the cursor already sits on the final item
        if self.current + 1 < self.last:
            self.current += 1
        return value
1 change: 1 addition & 0 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def test_simd_space_invariant(self):
assert 'omp simd' in iterations[3].pragmas[0].value

op.apply()
print(op._lib)
assert np.isclose(np.linalg.norm(f.data), 37.1458, rtol=1e-5)

def test_parallel_prec_inject(self):
Expand Down
14 changes: 7 additions & 7 deletions tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ def test_tile_insteadof_collapse(self, par_tile):
trees = retrieve_iteration_tree(op)
assert len(trees) == 6

assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[2][1].pragmas[0].value ==\
'acc parallel loop tile(32,4) present(u)'
# Only the AFFINE Iterations are tiled
assert trees[3][1].pragmas[0].value ==\
'acc parallel loop present(src,src_coords,u)'
assert trees[4][1].pragmas[0].value ==\
'acc parallel loop present(src,src_coords,u) deviceptr(r1,r2,r3)'

@pytest.mark.parametrize('par_tile', [((32, 4, 4), (8, 8)), ((32, 4), (8, 8)),
((32, 4, 4), (8, 8, 8))])
Expand All @@ -130,11 +130,11 @@ def test_multiple_tile_sizes(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 4
assert len(trees) == 6

assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[2][1].pragmas[0].value ==\
'acc parallel loop tile(8,8) present(u)'

def test_multi_tile_blocking_structure(self):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,8 +2493,10 @@ def test_adjoint_codegen(self, shape, kernel, space_order, save):
op_adj = solver.op_adj()
adj_calls = FindNodes(Call).visit(op_adj)

assert len(fwd_calls) == 1
assert len(adj_calls) == 1
# one halo, 2 * ndim memalign and free (pos temp src/rec)
sf_calls = 2 * len(shape) + 2 * len(shape)
assert len(fwd_calls) == 1 + sf_calls
assert len(adj_calls) == 1 + sf_calls

def run_adjoint_F(self, nd):
"""
Expand Down

0 comments on commit c34ed8e

Please sign in to comment.